diff --git a/src/servers/benches/prom_decode.rs b/src/servers/benches/prom_decode.rs index 28f86c844c..dc404918a1 100644 --- a/src/servers/benches/prom_decode.rs +++ b/src/servers/benches/prom_decode.rs @@ -16,9 +16,9 @@ use std::time::Duration; use api::prom_store::remote::WriteRequest; use bytes::Bytes; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use prost::Message; -use servers::http::PromValidationMode; +use servers::http::{PromValidationMode, validate_label_name}; use servers::prom_store::to_grpc_row_insert_requests; use servers::proto::{PromSeriesProcessor, PromWriteRequest}; @@ -56,7 +56,7 @@ fn bench_decode_prom_request(c: &mut Criterion) { |b, &mode| { b.iter(|| { let data = data.clone(); - prom_request.merge(data, mode, &mut p).unwrap(); + prom_request.decode(data, mode, &mut p).unwrap(); prom_request.as_row_insert_requests(); }); }, @@ -66,5 +66,98 @@ fn bench_decode_prom_request(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_decode_prom_request); +/// Benchmark comparing UTF-8 string validation (`decode_string`) vs +/// direct byte-level Prometheus label name validation (`decode_label_name`). +fn bench_label_name_validation(c: &mut Criterion) { + let mut group = c.benchmark_group("label_name_validation"); + group.measurement_time(Duration::from_secs(3)); + + // Test inputs: typical Prometheus label names of varying lengths. + let test_names: Vec<(&str, &[u8])> = vec![ + ("short", b"__name__"), + ("medium", b"http_request_duration_seconds"), + ( + "long", + b"very_long_label_name_that_might_appear_in_a_real_prometheus_metric_configuration", + ), + ]; + + let strict = PromValidationMode::Strict; + + for (label, name_bytes) in &test_names { + // Benchmark decode_string (UTF-8 validation only) + group.bench_with_input( + BenchmarkId::new("decode_string", label), + name_bytes, + |b, bytes| { + b.iter(|| { + black_box(strict.decode_string(black_box(bytes)).unwrap()); + }); + }, + ); + + // Benchmark decode_label_name (byte-level ASCII check + unchecked conversion) + group.bench_with_input( + BenchmarkId::new("decode_label_name", label), + name_bytes, + |b, bytes| { + b.iter(|| black_box(strict.decode_label_name(black_box(bytes)).unwrap())); + }, + ); + + // Benchmark is_valid_prom_label_name_bytes alone (byte check only, no String allocation) + group.bench_with_input( + BenchmarkId::new("is_valid_prom_label_name_bytes", label), + name_bytes, + |b, bytes| { + b.iter(|| { + black_box(validate_label_name(black_box(bytes))); + }); + }, + ); + } + + group.finish(); +} + +/// Benchmark comparing `std::str::from_utf8` vs `simdutf8::basic::from_utf8` +/// across varying input data lengths. +fn bench_utf8_validation(c: &mut Criterion) { + let mut group = c.benchmark_group("utf8_validation"); + group.measurement_time(Duration::from_secs(3)); + + // Generate valid ASCII/UTF-8 byte slices of varying lengths. + // Uses a repeating pattern of typical label characters. + let pattern = b"http_request_duration_seconds_total_bucket"; + let lengths: Vec = vec![8, 32, 64, 128, 256, 512, 1024, 4096, 16384, 65536]; + + for &len in &lengths { + let data: Vec = pattern.iter().copied().cycle().take(len).collect(); + + group.bench_with_input(BenchmarkId::new("std_from_utf8", len), &data, |b, data| { + b.iter(|| { + black_box(std::str::from_utf8(black_box(data)).unwrap()); + }); + }); + + group.bench_with_input( + BenchmarkId::new("simdutf8_basic_from_utf8", len), + &data, + |b, data| { + b.iter(|| { + black_box(simdutf8::basic::from_utf8(black_box(data)).unwrap()); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_decode_prom_request, + bench_label_name_validation, + bench_utf8_validation +); criterion_main!(benches); diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 53402fdd54..9181551cde 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -176,8 +176,62 @@ pub enum PromValidationMode { Unchecked, } +/// Returns `true` if the given byte slice is a valid Prometheus label name, +/// i.e., it matches `[a-zA-Z_][a-zA-Z0-9_]*`. +/// +/// Since the allowed characters are pure ASCII, valid label names are +/// also valid UTF-8 by definition. +const IS_VALID_LABEL_REST: [bool; 256] = { + let mut table = [false; 256]; + let mut i = 0; + while i < 256 { + let b = i as u8; + table[i] = b.is_ascii_alphanumeric() || b == b'_'; + i += 1; + } + table +}; + +#[inline] +pub fn validate_label_name(name: &[u8]) -> bool { + if name.is_empty() { + return false; + } + let first = name[0]; + if !(first.is_ascii_alphabetic() || first == b'_') { + return false; + } + + let mut rest = &name[1..]; + while rest.len() >= 8 { + let res = IS_VALID_LABEL_REST[rest[0] as usize] as u8 + & IS_VALID_LABEL_REST[rest[1] as usize] as u8 + & IS_VALID_LABEL_REST[rest[2] as usize] as u8 + & IS_VALID_LABEL_REST[rest[3] as usize] as u8 + & IS_VALID_LABEL_REST[rest[4] as usize] as u8 + & IS_VALID_LABEL_REST[rest[5] as usize] as u8 + & IS_VALID_LABEL_REST[rest[6] as usize] as u8 + & IS_VALID_LABEL_REST[rest[7] as usize] as u8; + + if res == 0 { + return false; + } + rest = &rest[8..]; + } + + for &b in rest { + if !IS_VALID_LABEL_REST[b as usize] { + return false; + } + } + + true +} + impl PromValidationMode { /// Decodes provided bytes to [String] with optional UTF-8 validation. + /// + /// Use this for label **values** and other general string fields. pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result { let result = match self { PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) { @@ -193,14 +247,28 @@ impl PromValidationMode { Ok(result) } - pub(crate) fn validate_bytes(&self, bytes: &[u8]) -> std::result::Result<(), DecodeError> { - match self { - PromValidationMode::Strict => { - simdutf8::basic::from_utf8(bytes).map_err(|_| DecodeError::new("invalid utf-8"))?; - Ok(()) - } - PromValidationMode::Lossy | PromValidationMode::Unchecked => Ok(()), + /// Decodes provided bytes to a label name [`&str`] with Prometheus label name validation. + /// + /// The check is always performed regardless of [`PromValidationMode`], as required by + /// the [Prometheus data model](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels). + pub fn decode_label_name<'a>( + &self, + bytes: &'a [u8], + ) -> std::result::Result<&'a str, DecodeError> { + if !validate_label_name(bytes) { + debug!( + "Invalid Prometheus label name: {:?}, must match [a-zA-Z_][a-zA-Z0-9_]*", + bytes + ); + return Err(DecodeError::new(format!( + "invalid prometheus label name: '{}', must match [a-zA-Z_][a-zA-Z0-9_]*", + String::from_utf8_lossy(bytes) + ))); } + + // SAFETY: `validate_label_name` only allows ASCII bytes, + // and ASCII is always valid UTF-8. + Ok(unsafe { std::str::from_utf8_unchecked(bytes) }) } } @@ -1759,4 +1827,80 @@ mod test { assert_eq!(ResponseFormat::Null.as_str(), "null"); assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1"); } + + #[test] + fn test_decode_label_name_strict() { + let strict = PromValidationMode::Strict; + + // Valid Prometheus label names + assert!(strict.decode_label_name(b"__name__").is_ok()); + assert!(strict.decode_label_name(b"job").is_ok()); + assert!(strict.decode_label_name(b"instance").is_ok()); + assert!(strict.decode_label_name(b"_private").is_ok()); + assert!(strict.decode_label_name(b"label_with_underscores").is_ok()); + assert!(strict.decode_label_name(b"abc123").is_ok()); + assert!(strict.decode_label_name(b"A").is_ok()); + assert!(strict.decode_label_name(b"_").is_ok()); + + // Invalid: starts with digit + assert!(strict.decode_label_name(b"0abc").is_err()); + assert!(strict.decode_label_name(b"123").is_err()); + + // Invalid: contains special characters + assert!(strict.decode_label_name(b"label-name").is_err()); + assert!(strict.decode_label_name(b"label.name").is_err()); + assert!(strict.decode_label_name(b"label name").is_err()); + assert!(strict.decode_label_name(b"label/name").is_err()); + + // Invalid: empty + assert!(strict.decode_label_name(b"").is_err()); + + // Invalid: non-ASCII UTF-8 + assert!(strict.decode_label_name("ラベル".as_bytes()).is_err()); + + // Invalid UTF-8 bytes should fail + assert!(strict.decode_label_name(&[0xff, 0xfe]).is_err()); + } + + #[test] + fn test_decode_label_name_lossy() { + let lossy = PromValidationMode::Lossy; + + // Label name validation is always enforced. + assert!(lossy.decode_label_name(b"__name__").is_ok()); + assert!(lossy.decode_label_name(b"label-name").is_err()); + assert!(lossy.decode_label_name(b"0abc").is_err()); + + // Invalid UTF-8 bytes fail the label-name byte check. + assert!(lossy.decode_label_name(&[0xff, 0xfe]).is_err()); + } + + #[test] + fn test_decode_label_name_unchecked() { + let unchecked = PromValidationMode::Unchecked; + + // Label name validation is always enforced. + assert!(unchecked.decode_label_name(b"__name__").is_ok()); + assert!(unchecked.decode_label_name(b"label-name").is_err()); + assert!(unchecked.decode_label_name(b"0abc").is_err()); + } + + #[test] + fn test_is_valid_prom_label_name_bytes() { + use super::validate_label_name; + + assert!(validate_label_name(b"__name__")); + assert!(validate_label_name(b"job")); + assert!(validate_label_name(b"_")); + assert!(validate_label_name(b"A")); + assert!(validate_label_name(b"abc123")); + assert!(validate_label_name(b"_leading_underscore")); + + assert!(!validate_label_name(b"")); + assert!(!validate_label_name(b"0starts_with_digit")); + assert!(!validate_label_name(b"has-dash")); + assert!(!validate_label_name(b"has.dot")); + assert!(!validate_label_name(b"has space")); + assert!(!validate_label_name(&[0xff, 0xfe])); + } } diff --git a/src/servers/src/http/prom_store.rs b/src/servers/src/http/prom_store.rs index 62bb2383be..d8891e0d98 100644 --- a/src/servers/src/http/prom_store.rs +++ b/src/servers/src/http/prom_store.rs @@ -45,7 +45,7 @@ use crate::query_handler::{PipelineHandlerRef, PromStoreProtocolHandlerRef, Prom pub const PHYSICAL_TABLE_PARAM: &str = "physical_table"; lazy_static! { - static ref PROM_WRITE_REQUEST_POOL: Pool = + static ref PROM_WRITE_REQUEST_POOL: Pool> = Pool::new(256, PromWriteRequest::default); } @@ -222,7 +222,7 @@ pub fn decode_remote_write_request( body: Bytes, prom_validation_mode: PromValidationMode, processor: &mut PromSeriesProcessor, -) -> Result { +) -> Result> { let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_DECODE_ELAPSED.start_timer(); // due to vmagent's limitation, there is a chance that vmagent is @@ -241,7 +241,7 @@ pub fn decode_remote_write_request( let mut request = PROM_WRITE_REQUEST_POOL.pull(PromWriteRequest::default); request - .merge(buf, prom_validation_mode, processor) + .decode(buf, prom_validation_mode, processor) .context(error::DecodePromRemoteRequestSnafu)?; Ok(std::mem::take(&mut request.table_data)) } diff --git a/src/servers/src/prom_row_builder.rs b/src/servers/src/prom_row_builder.rs index df38f9b5fe..a951b5df8e 100644 --- a/src/servers/src/prom_row_builder.rs +++ b/src/servers/src/prom_row_builder.rs @@ -18,11 +18,12 @@ use api::prom_store::remote::Sample; use api::v1::helper::{field_column_schema, time_index_column_schema}; use api::v1::value::ValueData; use api::v1::{ColumnDataType, ColumnSchema, Row, RowInsertRequest, Rows, SemanticType, Value}; +use bytes::Bytes; use common_query::prelude::{greptime_timestamp, greptime_value}; use pipeline::{ContextOpt, ContextReq}; use prost::DecodeError; -use crate::http::PromValidationMode; +use crate::http::{PromValidationMode, validate_label_name}; use crate::proto::PromLabel; use crate::repeated_field::Clear; @@ -35,18 +36,21 @@ pub struct PromCtx { /// [TablesBuilder] serves as an intermediate container to build [RowInsertRequests]. #[derive(Default, Debug)] -pub struct TablesBuilder { +pub struct TablesBuilder<'a> { // schema -> table -> table_builder - pub tables: HashMap>, + pub tables: HashMap>>, + /// Raw request data. + raw_data: Bytes, } -impl Clear for TablesBuilder { +impl<'a> Clear for TablesBuilder<'a> { fn clear(&mut self) { self.tables.clear(); + self.raw_data = Bytes::new(); } } -impl TablesBuilder { +impl<'a> TablesBuilder<'a> { /// Gets table builder with given table name. Creates an empty [TableBuilder] if not exist. pub(crate) fn get_or_create_table_builder( &mut self, @@ -54,7 +58,7 @@ impl TablesBuilder { table_name: String, label_num: usize, row_num: usize, - ) -> &mut TableBuilder { + ) -> &mut TableBuilder<'a> { self.tables .entry(prom_ctx) .or_default() @@ -90,30 +94,34 @@ impl TablesBuilder { req }) } + + pub(crate) fn set_raw_data(&mut self, buf: Bytes) { + self.raw_data = buf; + } } /// Builder for one table. #[derive(Debug)] -pub struct TableBuilder { +pub struct TableBuilder<'a> { /// Column schemas. schema: Vec, /// Rows written. rows: Vec, /// Indices of columns inside `schema`. - col_indexes: HashMap, usize>, + col_indexes: HashMap<&'a [u8], usize>, } -impl Default for TableBuilder { +impl<'a> Default for TableBuilder<'a> { fn default() -> Self { Self::with_capacity(2, 0) } } -impl TableBuilder { +impl<'a> TableBuilder<'a> { pub(crate) fn with_capacity(cols: usize, rows: usize) -> Self { let mut col_indexes = HashMap::with_capacity_and_hasher(cols, Default::default()); - col_indexes.insert(greptime_timestamp().as_bytes().to_owned(), 0); - col_indexes.insert(greptime_value().as_bytes().to_owned(), 1); + col_indexes.insert(greptime_timestamp().as_bytes(), 0); + col_indexes.insert(greptime_value().as_bytes(), 1); let mut schema = Vec::with_capacity(cols); schema.push(time_index_column_schema( @@ -142,25 +150,32 @@ impl TableBuilder { let mut row = vec![Value { value_data: None }; self.col_indexes.len()]; for PromLabel { name, value } in labels { - prom_validation_mode.validate_bytes(name)?; - let raw_tag_name = name; + if !validate_label_name(name) { + return Err(DecodeError::new(format!( + "Invalid label name: `{}`", + String::from_utf8_lossy(name) + ))); + } + let raw_tag_name = *name; let tag_value = Some(ValueData::StringValue( prom_validation_mode.decode_string(value)?, )); let tag_num = self.col_indexes.len(); - if let Some(e) = self.col_indexes.get_mut(*raw_tag_name) { + if let Some(e) = self.col_indexes.get_mut(raw_tag_name) { row[*e].value_data = tag_value; continue; } - let tag_name = prom_validation_mode.decode_string(raw_tag_name)?; + + // Safety: we've validated the label name is valid in line 152. + let tag_name = unsafe { std::str::from_utf8_unchecked(raw_tag_name) }; self.schema.push(ColumnSchema { - column_name: tag_name.clone(), + column_name: tag_name.to_owned(), datatype: ColumnDataType::String as i32, semantic_type: SemanticType::Tag as i32, ..Default::default() }); - self.col_indexes.insert(tag_name.into_bytes(), tag_num); + self.col_indexes.insert(raw_tag_name, tag_num); row.push(Value { value_data: tag_value, @@ -220,6 +235,7 @@ mod tests { use crate::http::PromValidationMode; use crate::prom_row_builder::TableBuilder; use crate::proto::PromLabel; + #[test] fn test_table_builder() { let mut builder = TableBuilder::default(); diff --git a/src/servers/src/proto.rs b/src/servers/src/proto.rs index d46489124a..62ab275808 100644 --- a/src/servers/src/proto.rs +++ b/src/servers/src/proto.rs @@ -206,9 +206,9 @@ impl PromTimeSeries { } } - fn add_to_table_data( + fn add_to_table_data<'a>( &mut self, - table_builders: &mut TablesBuilder, + table_builders: &mut TablesBuilder<'a>, prom_validation_mode: PromValidationMode, ) -> Result<(), DecodeError> { let label_num = self.labels.len(); @@ -236,30 +236,32 @@ impl PromTimeSeries { } #[derive(Default, Debug)] -pub struct PromWriteRequest { - pub(crate) table_data: TablesBuilder, +pub struct PromWriteRequest<'a> { + pub(crate) table_data: TablesBuilder<'a>, series: PromTimeSeries, } -impl Clear for PromWriteRequest { +impl<'a> Clear for PromWriteRequest<'a> { fn clear(&mut self) { self.table_data.clear(); } } -impl PromWriteRequest { +impl<'a> PromWriteRequest<'a> { pub fn as_row_insert_requests(&mut self) -> ContextReq { self.table_data.as_insert_requests() } - // todo(hl): maybe use &[u8] can reduce the overhead introduced with Bytes. - pub fn merge( + /// Decode the buf. + pub fn decode( &mut self, mut buf: Bytes, prom_validation_mode: PromValidationMode, processor: &mut PromSeriesProcessor, ) -> Result<(), DecodeError> { const STRUCT_NAME: &str = "PromWriteRequest"; + // Keep a reference to the underlying buffer so the decoded raw bytes won't be dangling. + self.table_data.set_raw_data(buf.clone()); while buf.has_remaining() { let (tag, wire_type) = decode_key(&mut buf)?; assert_eq!(WireType::LengthDelimited, wire_type); @@ -360,7 +362,7 @@ impl PromSeriesProcessor { let mut vec_pipeline_map = Vec::new(); let mut pipeline_map = BTreeMap::new(); for l in series.labels.iter() { - let name = prom_validation_mode.decode_string(l.name)?; + let name = prom_validation_mode.decode_label_name(l.name)?; let value = prom_validation_mode.decode_string(l.value)?; pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into())); } @@ -470,7 +472,7 @@ mod tests { let mut p = PromSeriesProcessor::default_processor(); prom_write_request.clear(); prom_write_request - .merge(data.clone(), PromValidationMode::Strict, &mut p) + .decode(data.clone(), PromValidationMode::Strict, &mut p) .unwrap(); let req = prom_write_request.as_row_insert_requests();