diff --git a/src/servers/src/grpc/prom_query_gateway.rs b/src/servers/src/grpc/prom_query_gateway.rs index fb9d605e70..3ee7902f2c 100644 --- a/src/servers/src/grpc/prom_query_gateway.rs +++ b/src/servers/src/grpc/prom_query_gateway.rs @@ -122,7 +122,7 @@ impl PrometheusGatewayService { let result = self.handler.do_query(&query, ctx).await; let (metric_name, mut result_type) = match retrieve_metric_name_and_result_type(&query.query) { - Ok((metric_name, result_type)) => (metric_name.unwrap_or_default(), result_type), + Ok((metric_name, result_type)) => (metric_name, result_type), Err(err) => { return PrometheusJsonResponse::error(err.status_code(), err.output_msg()) } diff --git a/src/servers/src/http/prometheus.rs b/src/servers/src/http/prometheus.rs index acb39ab4cd..b55f48a000 100644 --- a/src/servers/src/http/prometheus.rs +++ b/src/servers/src/http/prometheus.rs @@ -318,7 +318,7 @@ async fn do_instant_query( ) -> PrometheusJsonResponse { let result = handler.do_query(prom_query, query_ctx).await; let (metric_name, result_type) = match retrieve_metric_name_and_result_type(&prom_query.query) { - Ok((metric_name, result_type)) => (metric_name.unwrap_or_default(), result_type), + Ok((metric_name, result_type)) => (metric_name, result_type), Err(err) => return PrometheusJsonResponse::error(err.status_code(), err.output_msg()), }; PrometheusJsonResponse::from_query_result(result, metric_name, result_type).await @@ -428,7 +428,7 @@ async fn do_range_query( let result = handler.do_query(prom_query, query_ctx).await; let metric_name = match retrieve_metric_name_and_result_type(&prom_query.query) { Err(err) => return PrometheusJsonResponse::error(err.status_code(), err.output_msg()), - Ok((metric_name, _)) => metric_name.unwrap_or_default(), + Ok((metric_name, _)) => metric_name, }; PrometheusJsonResponse::from_query_result(result, metric_name, ValueType::Matrix).await } @@ -824,13 +824,52 @@ pub(crate) fn try_update_catalog_schema(ctx: &mut QueryContext, catalog: &str, s } fn promql_expr_to_metric_name(expr: &PromqlExpr) -> Option { - find_metric_name_and_matchers(expr, |name, matchers| { - name.clone().or(matchers - .find_matchers(METRIC_NAME) - .into_iter() - .next() - .map(|m| m.value)) - }) + let mut metric_names = HashSet::new(); + collect_metric_names(expr, &mut metric_names); + + // Return the metric name only if there's exactly one unique metric name + if metric_names.len() == 1 { + metric_names.into_iter().next() + } else { + None + } +} + +/// Recursively collect all metric names from a PromQL expression +fn collect_metric_names(expr: &PromqlExpr, metric_names: &mut HashSet) { + match expr { + PromqlExpr::Aggregate(AggregateExpr { expr, .. }) => { + collect_metric_names(expr, metric_names) + } + PromqlExpr::Unary(UnaryExpr { expr }) => collect_metric_names(expr, metric_names), + PromqlExpr::Binary(BinaryExpr { lhs, rhs, .. }) => { + collect_metric_names(lhs, metric_names); + collect_metric_names(rhs, metric_names); + } + PromqlExpr::Paren(ParenExpr { expr }) => collect_metric_names(expr, metric_names), + PromqlExpr::Subquery(SubqueryExpr { expr, .. }) => collect_metric_names(expr, metric_names), + PromqlExpr::VectorSelector(VectorSelector { name, matchers, .. }) => { + if let Some(name) = name { + metric_names.insert(name.clone()); + } else if let Some(matcher) = matchers.find_matchers(METRIC_NAME).into_iter().next() { + metric_names.insert(matcher.value); + } + } + PromqlExpr::MatrixSelector(MatrixSelector { vs, .. }) => { + let VectorSelector { name, matchers, .. } = vs; + if let Some(name) = name { + metric_names.insert(name.clone()); + } else if let Some(matcher) = matchers.find_matchers(METRIC_NAME).into_iter().next() { + metric_names.insert(matcher.value); + } + } + PromqlExpr::Call(Call { args, .. }) => { + args.args + .iter() + .for_each(|e| collect_metric_names(e, metric_names)); + } + PromqlExpr::NumberLiteral(_) | PromqlExpr::StringLiteral(_) | PromqlExpr::Extension(_) => {} + } } fn find_metric_name_and_matchers(expr: &PromqlExpr, f: F) -> Option @@ -1114,51 +1153,11 @@ async fn retrieve_field_names( /// Try to parse and extract the name of referenced metric from the promql query. /// -/// Returns the metric name if a single metric is referenced, otherwise None. +/// Returns the metric name if exactly one unique metric is referenced, otherwise None. +/// Multiple references to the same metric are allowed. fn retrieve_metric_name_from_promql(query: &str) -> Option { let promql_expr = promql_parser::parser::parse(query).ok()?; - - struct MetricNameVisitor { - metric_name: Option, - } - - impl promql_parser::util::ExprVisitor for MetricNameVisitor { - type Error = (); - - fn pre_visit(&mut self, plan: &PromqlExpr) -> std::result::Result { - let query_metric_name = match plan { - PromqlExpr::VectorSelector(vs) => vs - .matchers - .find_matchers(METRIC_NAME) - .into_iter() - .next() - .map(|m| m.value) - .or_else(|| vs.name.clone()), - PromqlExpr::MatrixSelector(ms) => ms - .vs - .matchers - .find_matchers(METRIC_NAME) - .into_iter() - .next() - .map(|m| m.value) - .or_else(|| ms.vs.name.clone()), - _ => return Ok(true), - }; - - // set it to empty string if multiple metrics are referenced. - if self.metric_name.is_some() && query_metric_name.is_some() { - self.metric_name = Some(String::new()); - } else { - self.metric_name = query_metric_name.or_else(|| self.metric_name.clone()); - } - - Ok(true) - } - } - - let mut visitor = MetricNameVisitor { metric_name: None }; - promql_parser::util::walk_expr(&mut visitor, &promql_expr).ok()?; - visitor.metric_name + promql_expr_to_metric_name(&promql_expr) } #[derive(Debug, Default, Serialize, Deserialize)] @@ -1275,3 +1274,205 @@ pub async fn parse_query( PrometheusJsonResponse::error(StatusCode::InvalidArguments, "query is required") } } + +#[cfg(test)] +mod tests { + use promql_parser::parser::value::ValueType; + + use super::*; + + struct TestCase { + name: &'static str, + promql: &'static str, + expected_metric: Option<&'static str>, + expected_type: ValueType, + should_error: bool, + } + + #[test] + fn test_retrieve_metric_name_and_result_type() { + let test_cases = &[ + // Single metric cases + TestCase { + name: "simple metric", + promql: "cpu_usage", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "metric with selector", + promql: r#"cpu_usage{instance="localhost"}"#, + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "metric with range selector", + promql: "cpu_usage[5m]", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Matrix, + should_error: false, + }, + TestCase { + name: "metric with __name__ matcher", + promql: r#"{__name__="cpu_usage"}"#, + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "metric with unary operator", + promql: "-cpu_usage", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + // Aggregation and function cases + TestCase { + name: "metric with aggregation", + promql: "sum(cpu_usage)", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "complex aggregation", + promql: r#"sum by (instance) (cpu_usage{job="node"})"#, + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + // Same metric binary operations + TestCase { + name: "same metric addition", + promql: "cpu_usage + cpu_usage", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "metric with scalar addition", + promql: r#"sum(rate(cpu_usage{job="node"}[5m])) by (instance) + 100"#, + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + // Multiple metrics cases + TestCase { + name: "different metrics addition", + promql: "cpu_usage + memory_usage", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "different metrics subtraction", + promql: "network_in - network_out", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: false, + }, + // Unless operator cases + TestCase { + name: "unless with different metrics", + promql: "cpu_usage unless memory_usage", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: false, + }, + TestCase { + name: "unless with same metric", + promql: "cpu_usage unless cpu_usage", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Vector, + should_error: false, + }, + // Subquery cases + TestCase { + name: "basic subquery", + promql: "cpu_usage[5m:1m]", + expected_metric: Some("cpu_usage"), + expected_type: ValueType::Matrix, + should_error: false, + }, + TestCase { + name: "subquery with multiple metrics", + promql: "(cpu_usage + memory_usage)[5m:1m]", + expected_metric: None, + expected_type: ValueType::Matrix, + should_error: false, + }, + // Literal values + TestCase { + name: "scalar value", + promql: "42", + expected_metric: None, + expected_type: ValueType::Scalar, + should_error: false, + }, + TestCase { + name: "string literal", + promql: r#""hello world""#, + expected_metric: None, + expected_type: ValueType::String, + should_error: false, + }, + // Error cases + TestCase { + name: "invalid syntax", + promql: "cpu_usage{invalid=", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: true, + }, + TestCase { + name: "empty query", + promql: "", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: true, + }, + TestCase { + name: "malformed brackets", + promql: "cpu_usage[5m", + expected_metric: None, + expected_type: ValueType::Vector, + should_error: true, + }, + ]; + + for test_case in test_cases { + let result = retrieve_metric_name_and_result_type(test_case.promql); + + if test_case.should_error { + assert!( + result.is_err(), + "Test '{}' should have failed but succeeded with: {:?}", + test_case.name, + result + ); + } else { + let (metric_name, value_type) = result.unwrap_or_else(|e| { + panic!( + "Test '{}' should have succeeded but failed with error: {}", + test_case.name, e + ) + }); + + let expected_metric_name = test_case.expected_metric.map(|s| s.to_string()); + assert_eq!( + metric_name, expected_metric_name, + "Test '{}': metric name mismatch. Expected: {:?}, Got: {:?}", + test_case.name, expected_metric_name, metric_name + ); + + assert_eq!( + value_type, test_case.expected_type, + "Test '{}': value type mismatch. Expected: {:?}, Got: {:?}", + test_case.name, test_case.expected_type, value_type + ); + } + } + } +} diff --git a/src/servers/src/http/result/prometheus_resp.rs b/src/servers/src/http/result/prometheus_resp.rs index 0209ed293b..bba507b676 100644 --- a/src/servers/src/http/result/prometheus_resp.rs +++ b/src/servers/src/http/result/prometheus_resp.rs @@ -118,7 +118,7 @@ impl PrometheusJsonResponse { /// Convert from `Result` pub async fn from_query_result( result: Result, - metric_name: String, + metric_name: Option, result_type: ValueType, ) -> Self { let response: Result = try { @@ -182,7 +182,7 @@ impl PrometheusJsonResponse { /// Convert [RecordBatches] to [PromData] fn record_batches_to_data( batches: RecordBatches, - metric_name: String, + metric_name: Option, result_type: ValueType, ) -> Result { // infer semantic type of each column from schema. @@ -230,7 +230,6 @@ impl PrometheusJsonResponse { reason: "no value column found".to_string(), })?; - let metric_name = (METRIC_NAME, metric_name.as_str()); // Preserves the order of output tags. // Tag order matters, e.g., after sorc and sort_desc, the output order must be kept. let mut buffer = IndexMap::, Vec<(f64, String)>>::new(); @@ -276,9 +275,10 @@ impl PrometheusJsonResponse { } // retrieve tags - // TODO(ruihang): push table name `__metric__` let mut tags = Vec::with_capacity(num_label_columns + 1); - tags.push(metric_name); + if let Some(metric_name) = &metric_name { + tags.push((METRIC_NAME, metric_name.as_str())); + } for (tag_column, tag_name) in tag_columns.iter().zip(tag_names.iter()) { // TODO(ruihang): add test for NULL tag if let Some(tag_value) = tag_column.get_data(row_index) {