From 58c37f588dcb75239364f007d60ffbeaebb3f93c Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 13 Jan 2023 14:27:31 +0800 Subject: [PATCH] feat: plan some aggregate expr in PromQL planner (#870) Signed-off-by: Ruihang Xia Signed-off-by: Ruihang Xia --- src/promql/src/error.rs | 16 +- src/promql/src/planner.rs | 327 +++++++++++++++++++++++++++++++++++--- 2 files changed, 317 insertions(+), 26 deletions(-) diff --git a/src/promql/src/error.rs b/src/promql/src/error.rs index 4a055729f0..d95d500f7e 100644 --- a/src/promql/src/error.rs +++ b/src/promql/src/error.rs @@ -15,7 +15,7 @@ use std::any::Any; use common_error::prelude::*; -use promql_parser::parser::Expr as PromExpr; +use promql_parser::parser::{Expr as PromExpr, TokenType}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] @@ -23,6 +23,12 @@ pub enum Error { #[snafu(display("Unsupported expr type: {}", name))] UnsupportedExpr { name: String, backtrace: Backtrace }, + #[snafu(display("Unexpected token: {}", token))] + UnexpectedToken { + token: TokenType, + backtrace: Backtrace, + }, + #[snafu(display("Internal error during build DataFusion plan, error: {}", source))] DataFusionPlanning { source: datafusion::error::DataFusionError, @@ -41,6 +47,12 @@ pub enum Error { #[snafu(display("Cannot find value columns in table {}", table))] ValueNotFound { table: String, backtrace: Backtrace }, + #[snafu(display("Cannot find label in table {}, source: {}", table, source))] + LabelNotFound { + table: String, + source: datafusion::error::DataFusionError, + }, + #[snafu(display("Cannot find the table {}", table))] TableNotFound { table: String, @@ -90,7 +102,9 @@ impl ErrorExt for Error { TimeIndexNotFound { .. } | ValueNotFound { .. } | UnsupportedExpr { .. } + | UnexpectedToken { .. } | MultipleVector { .. } + | LabelNotFound { .. } | ExpectExpr { .. } => StatusCode::InvalidArguments, UnknownTable { .. } | TableNotFound { .. } diff --git a/src/promql/src/planner.rs b/src/promql/src/planner.rs index 2799c6d2ba..b580d6cd77 100644 --- a/src/promql/src/planner.rs +++ b/src/promql/src/planner.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; use datafusion::datasource::DefaultTableSource; +use datafusion::logical_expr::expr::AggregateFunction; use datafusion::logical_expr::{ - BinaryExpr, BuiltinScalarFunction, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Operator, + AggregateFunction as AggregateFunctionEnum, BinaryExpr, BuiltinScalarFunction, Extension, + Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::optimizer::utils; use datafusion::prelude::{Column, Expr as DfExpr}; @@ -26,16 +29,16 @@ use datafusion::scalar::ScalarValue; use datafusion::sql::planner::ContextProvider; use datafusion::sql::TableReference; use promql_parser::label::{MatchOp, Matchers, METRIC_NAME}; -use promql_parser::parser::{EvalStmt, Expr as PromExpr, Function}; +use promql_parser::parser::{token, EvalStmt, Expr as PromExpr, Function, TokenType}; use snafu::{OptionExt, ResultExt}; use table::table::adapter::DfTableProviderAdapter; use crate::error::{ - DataFusionPlanningSnafu, ExpectExprSnafu, MultipleVectorSnafu, Result, TableNameNotFoundSnafu, - TableNotFoundSnafu, TimeIndexNotFoundSnafu, UnknownTableSnafu, UnsupportedExprSnafu, - ValueNotFoundSnafu, + DataFusionPlanningSnafu, ExpectExprSnafu, LabelNotFoundSnafu, MultipleVectorSnafu, Result, + TableNameNotFoundSnafu, TableNotFoundSnafu, TimeIndexNotFoundSnafu, UnexpectedPlanExprSnafu, + UnexpectedTokenSnafu, UnknownTableSnafu, UnsupportedExprSnafu, ValueNotFoundSnafu, }; -use crate::extension_plan::{InstantManipulate, Millisecond, SeriesNormalize}; +use crate::extension_plan::{InstantManipulate, Millisecond, RangeManipulate, SeriesNormalize}; #[derive(Default, Debug, Clone)] struct PromPlannerContext { @@ -79,10 +82,55 @@ impl PromPlanner { pub fn prom_expr_to_plan(&mut self, prom_expr: PromExpr) -> Result { let res = match &prom_expr { - PromExpr::AggregateExpr { .. } => UnsupportedExprSnafu { - name: "Prom Aggregate", + PromExpr::AggregateExpr { + op, + expr, + // TODO(ruihang): support param + param: _param, + grouping, + without, + } => { + let input = self.prom_expr_to_plan(*expr.clone())?; + + // calculate columns to group by + let schema = input.schema(); + let group_columns_indices = grouping + .iter() + .map(|label| { + schema + .index_of_column_by_name(None, label) + .with_context(|_| LabelNotFoundSnafu { + table: self.ctx.table_name.clone().unwrap(), + }) + }) + .collect::>>()?; + let value_names = self.ctx.value_columns.iter().collect::>(); + let group_exprs = schema + .fields() + .iter() + .enumerate() + .filter_map(|(i, field)| { + if *without != group_columns_indices.contains(&i) + && Some(field.name()) != self.ctx.time_index_column.as_ref() + && !value_names.contains(&field.name()) + { + Some(DfExpr::Column(Column::from(field.name()))) + } else { + None + } + }) + .collect::>(); + + // convert op and value columns to aggregate exprs + let aggr_exprs = self.create_aggregate_exprs(*op)?; + + // create plan + LogicalPlanBuilder::from(input) + .aggregate(group_exprs, aggr_exprs) + .context(DataFusionPlanningSnafu)? + .build() + .context(DataFusionPlanningSnafu)? } - .fail()?, PromExpr::UnaryExpr { .. } => UnsupportedExprSnafu { name: "Prom Unary Expr", } @@ -136,10 +184,47 @@ impl PromPlanner { node: Arc::new(manipulate), }) } - PromExpr::MatrixSelector { .. } => UnsupportedExprSnafu { - name: "Prom Matrix Selector", + PromExpr::MatrixSelector { + vector_selector, + range, + } => { + let normalize = match &**vector_selector { + PromExpr::VectorSelector { + name: _, + offset, + start_or_end: _, + label_matchers, + } => { + let matchers = self.preprocess_label_matchers(label_matchers)?; + self.setup_context()?; + self.selector_to_series_normalize_plan(*offset, matchers)? + } + _ => UnexpectedPlanExprSnafu { + desc: format!( + "MatrixSelector must contains a VectorSelector, but found {vector_selector:?}", + ), + } + .fail()?, + }; + let manipulate = RangeManipulate::new( + self.ctx.start, + self.ctx.end, + self.ctx.interval, + // TODO(ruihang): convert via Timestamp datatypes to support different time units + range.as_millis() as _, + self.ctx + .time_index_column + .clone() + .expect("time index should be set in `setup_context`"), + self.ctx.value_columns.clone(), + normalize, + ) + .context(DataFusionPlanningSnafu)?; + + LogicalPlan::Extension(Extension { + node: Arc::new(manipulate), + }) } - .fail()?, PromExpr::Call { func, args } => { let args = self.create_function_args(args)?; let input = @@ -387,6 +472,41 @@ impl PromPlanner { table: self.ctx.table_name.clone().unwrap(), }) } + + fn create_aggregate_exprs(&self, op: TokenType) -> Result> { + let aggr = match op { + token::T_SUM => AggregateFunctionEnum::Sum, + token::T_AVG => AggregateFunctionEnum::Avg, + token::T_COUNT => AggregateFunctionEnum::Count, + token::T_MIN => AggregateFunctionEnum::Min, + token::T_MAX => AggregateFunctionEnum::Max, + token::T_GROUP => AggregateFunctionEnum::Grouping, + token::T_STDDEV => AggregateFunctionEnum::Stddev, + token::T_STDVAR => AggregateFunctionEnum::Variance, + token::T_TOPK | token::T_BOTTOMK | token::T_COUNT_VALUES | token::T_QUANTILE => { + UnsupportedExprSnafu { + name: op.to_string(), + } + .fail()? + } + _ => UnexpectedTokenSnafu { token: op }.fail()?, + }; + + let exprs = self + .ctx + .value_columns + .iter() + .map(|col| { + DfExpr::AggregateFunction(AggregateFunction { + fun: aggr.clone(), + args: vec![DfExpr::Column(Column::from_name(col))], + distinct: false, + filter: None, + }) + }) + .collect(); + Ok(exprs) + } } #[derive(Default, Debug)] @@ -474,19 +594,19 @@ mod test { } // { - // input: `abs(some_metric{foo!="bar"})`, - // expected: &Call{ - // Func: MustGetFunction("abs"), - // Args: Expressions{ - // &VectorSelector{ - // Name: "some_metric", - // LabelMatchers: []*labels.Matcher{ - // MustLabelMatcher(labels.MatchNotEqual, "foo", "bar"), - // MustLabelMatcher(labels.MatchEqual, model.MetricNameLabel, "some_metric"), - // }, - // }, - // }, - // }, + // input: `abs(some_metric{foo!="bar"})`, + // expected: &Call{ + // Func: MustGetFunction("abs"), + // Args: Expressions{ + // &VectorSelector{ + // Name: "some_metric", + // LabelMatchers: []*labels.Matcher{ + // MustLabelMatcher(labels.MatchNotEqual, "foo", "bar"), + // MustLabelMatcher(labels.MatchEqual, model.MetricNameLabel, "some_metric"), + // }, + // }, + // }, + // }, // }, async fn do_single_instant_function_call(fn_name: &'static str, plan_name: &str) { let prom_expr = PromExpr::Call { @@ -689,4 +809,161 @@ mod test { async fn single_rad() { do_single_instant_function_call("rad", "").await; } + + // { + // input: "avg by (foo)(some_metric)", + // expected: &AggregateExpr{ + // Op: AVG, + // Expr: &VectorSelector{ + // Name: "some_metric", + // LabelMatchers: []*labels.Matcher{ + // MustLabelMatcher(labels.MatchEqual, model.MetricNameLabel, "some_metric"), + // }, + // PosRange: PositionRange{ + // Start: 13, + // End: 24, + // }, + // }, + // Grouping: []string{"foo"}, + // PosRange: PositionRange{ + // Start: 0, + // End: 25, + // }, + // }, + // }, + async fn do_aggregate_expr_plan(op: TokenType, name: &str) { + let prom_expr = PromExpr::AggregateExpr { + op, + expr: Box::new(PromExpr::VectorSelector { + name: Some("some_metric".to_owned()), + offset: None, + start_or_end: None, + label_matchers: Matchers { + matchers: vec![ + Matcher { + op: MatchOp::NotEqual, + name: "tag_0".to_string(), + value: "bar".to_string(), + }, + Matcher { + op: MatchOp::Equal, + name: METRIC_NAME.to_string(), + value: "some_metric".to_string(), + }, + ], + }, + }), + param: Box::new(PromExpr::empty_vector_selector()), + grouping: vec![String::from("tag_1")], + without: false, + }; + let mut eval_stmt = EvalStmt { + expr: prom_expr, + start: UNIX_EPOCH, + end: UNIX_EPOCH + .checked_add(Duration::from_secs(100_000)) + .unwrap(), + interval: Duration::from_secs(5), + lookback_delta: Duration::from_secs(1), + }; + + // test group by + let context_provider = build_test_context_provider("some_metric".to_string(), 2, 2).await; + let plan = PromPlanner::stmt_to_plan(eval_stmt.clone(), context_provider).unwrap(); + let expected_no_without = String::from( + "Aggregate: groupBy=[[some_metric.tag_1]], aggr=[[TEMPLATE(some_metric.field_0), TEMPLATE(some_metric.field_1)]] [tag_1:Utf8, TEMPLATE(some_metric.field_0):Float64;N, TEMPLATE(some_metric.field_1):Float64;N]\ + \n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n PromSeriesNormalize: offset=[0], time index=[timestamp] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n Filter: tag_0 != Utf8(\"bar\") AND timestamp >= TimestampMillisecond(0, None) AND timestamp <= TimestampMillisecond(100000000, None) [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n TableScan: some_metric, unsupported_filters=[tag_0 != Utf8(\"bar\"), timestamp >= TimestampMillisecond(0, None), timestamp <= TimestampMillisecond(100000000, None)] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]") + .replace("TEMPLATE", name); + assert_eq!( + plan.display_indent_schema().to_string(), + expected_no_without + ); + + // test group without + if let PromExpr::AggregateExpr { without, .. } = &mut eval_stmt.expr { + *without = true; + } + let context_provider = build_test_context_provider("some_metric".to_string(), 2, 2).await; + let plan = PromPlanner::stmt_to_plan(eval_stmt, context_provider).unwrap(); + let expected_without = String::from( + "Aggregate: groupBy=[[some_metric.tag_0]], aggr=[[TEMPLATE(some_metric.field_0), TEMPLATE(some_metric.field_1)]] [tag_0:Utf8, TEMPLATE(some_metric.field_0):Float64;N, TEMPLATE(some_metric.field_1):Float64;N]\ + \n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n PromSeriesNormalize: offset=[0], time index=[timestamp] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n Filter: tag_0 != Utf8(\"bar\") AND timestamp >= TimestampMillisecond(0, None) AND timestamp <= TimestampMillisecond(100000000, None) [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]\ + \n TableScan: some_metric, unsupported_filters=[tag_0 != Utf8(\"bar\"), timestamp >= TimestampMillisecond(0, None), timestamp <= TimestampMillisecond(100000000, None)] [tag_0:Utf8, tag_1:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N]") + .replace("TEMPLATE", name); + assert_eq!(plan.display_indent_schema().to_string(), expected_without); + } + + #[tokio::test] + async fn aggregate_sum() { + do_aggregate_expr_plan(token::T_SUM, "SUM").await; + } + + #[tokio::test] + async fn aggregate_avg() { + do_aggregate_expr_plan(token::T_AVG, "AVG").await; + } + + #[tokio::test] + #[should_panic] // output type doesn't match + async fn aggregate_count() { + do_aggregate_expr_plan(token::T_COUNT, "COUNT").await; + } + + #[tokio::test] + async fn aggregate_min() { + do_aggregate_expr_plan(token::T_MIN, "MIN").await; + } + + #[tokio::test] + async fn aggregate_max() { + do_aggregate_expr_plan(token::T_MAX, "MAX").await; + } + + #[tokio::test] + #[should_panic] // output type doesn't match + async fn aggregate_group() { + do_aggregate_expr_plan(token::T_GROUP, "GROUPING").await; + } + + #[tokio::test] + async fn aggregate_stddev() { + do_aggregate_expr_plan(token::T_STDDEV, "STDDEV").await; + } + + #[tokio::test] + #[should_panic] // schema doesn't match + async fn aggregate_stdvar() { + do_aggregate_expr_plan(token::T_STDVAR, "STDVAR").await; + } + + #[tokio::test] + #[should_panic] + async fn aggregate_top_k() { + do_aggregate_expr_plan(token::T_TOPK, "").await; + } + + #[tokio::test] + #[should_panic] + async fn aggregate_bottom_k() { + do_aggregate_expr_plan(token::T_BOTTOMK, "").await; + } + + #[tokio::test] + #[should_panic] + async fn aggregate_count_values() { + do_aggregate_expr_plan(token::T_COUNT_VALUES, "").await; + } + + #[tokio::test] + #[should_panic] + async fn aggregate_quantile() { + do_aggregate_expr_plan(token::T_QUANTILE, "").await; + } + + // TODO(ruihang): add range fn tests once exprs are ready. }