diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index 4259b587ba..aaac1e3124 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod constant_term; +pub mod count_nest_aggr; pub mod count_wildcard; pub mod parallelize_scan; pub mod pass_distribution; diff --git a/src/query/src/optimizer/count_nest_aggr.rs b/src/query/src/optimizer/count_nest_aggr.rs new file mode 100644 index 0000000000..89ba426074 --- /dev/null +++ b/src/query/src/optimizer/count_nest_aggr.rs @@ -0,0 +1,346 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; +use std::sync::Arc; + +use datafusion::config::ConfigOptions; +use datafusion::functions_aggregate::count::count_udaf; +use datafusion::logical_expr::{Extension, LogicalPlan, LogicalPlanBuilder, Sort}; +use datafusion_common::Result; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_expr::{Expr, UserDefinedLogicalNodeCore, lit}; +use promql::extension_plan::{InstantManipulate, SeriesDivide, SeriesNormalize}; +use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME; + +use crate::QueryEngineContext; +use crate::optimizer::ExtensionAnalyzerRule; + +/// Rewrites `count(() by (...))` into a presence-based +/// group count. +/// +/// This stays intentionally narrow: +/// - the outer aggregate must be plain `count` +/// - the inner aggregate must be a plain aggregate whose result existence is equivalent to input +/// group existence +/// - the inner input must be the direct instant-vector-selector plan +/// - the outer count must only group by the evaluation timestamp +#[derive(Debug)] +pub struct CountNestAggrRule; + +impl ExtensionAnalyzerRule for CountNestAggrRule { + fn analyze( + &self, + plan: LogicalPlan, + _ctx: &QueryEngineContext, + _config: &ConfigOptions, + ) -> Result { + plan.transform_down(&Self::rewrite_plan).map(|x| x.data) + } +} + +impl CountNestAggrRule { + fn rewrite_plan(plan: LogicalPlan) -> Result> { + let LogicalPlan::Sort(sort) = plan else { + return Ok(Transformed::no(plan)); + }; + + if let Some(rewritten) = Self::try_rewrite_sort(&sort)? { + Ok(Transformed::yes(rewritten)) + } else { + Ok(Transformed::no(LogicalPlan::Sort(sort))) + } + } + + fn try_rewrite_sort(sort: &Sort) -> Result> { + if sort.fetch.is_some() { + return Ok(None); + } + + let LogicalPlan::Aggregate(outer_agg) = sort.input.as_ref() else { + return Ok(None); + }; + if outer_agg.group_expr.len() != 1 || outer_agg.aggr_expr.len() != 1 { + return Ok(None); + } + let outer_time_expr = outer_agg.group_expr[0].clone(); + let outer_count_arg = + match Self::aggregate_if(&outer_agg.aggr_expr[0], |name| name == "count") { + Some((_, arg)) => arg, + None => return Ok(None), + }; + + let LogicalPlan::Sort(inner_sort) = outer_agg.input.as_ref() else { + return Ok(None); + }; + if inner_sort.fetch.is_some() { + return Ok(None); + } + + let LogicalPlan::Aggregate(inner_agg) = inner_sort.input.as_ref() else { + return Ok(None); + }; + if inner_agg.aggr_expr.len() != 1 || inner_agg.group_expr.is_empty() { + return Ok(None); + } + let (inner_is_count, inner_value_expr) = + match Self::aggregate_if(&inner_agg.aggr_expr[0], |name| { + Self::is_supported_inner_aggregate(name) + }) { + Some((name, arg)) => (name == "count", arg), + None => return Ok(None), + }; + let Expr::Column(_) = inner_value_expr else { + return Ok(None); + }; + + let Expr::Column(outer_count_column) = outer_count_arg else { + return Ok(None); + }; + let inner_output_field = inner_agg.schema.field(inner_agg.group_expr.len()); + if outer_count_column.name != *inner_output_field.name() { + return Ok(None); + } + + if !Self::is_projection_chain_to_instant(inner_agg.input.as_ref()) { + return Ok(None); + } + + if !inner_agg + .group_expr + .iter() + .all(|expr| matches!(expr, Expr::Column(_))) + { + return Ok(None); + } + + let Some(time_expr_pos) = inner_agg + .group_expr + .iter() + .position(|expr| expr == &outer_time_expr) + else { + return Ok(None); + }; + + let mut presence_group_exprs = Vec::with_capacity(inner_agg.group_expr.len()); + presence_group_exprs.push(outer_time_expr.clone()); + presence_group_exprs.extend( + inner_agg + .group_expr + .iter() + .enumerate() + .filter(|(idx, _)| *idx != time_expr_pos) + .map(|(_, expr)| expr.clone()), + ); + + let mut required_input_columns = + Self::collect_required_input_columns(&presence_group_exprs, inner_value_expr); + required_input_columns.extend(Self::collect_required_instant_columns( + inner_agg.input.as_ref(), + )); + let presence_source = Self::rebuild_projection_chain_to_instant( + inner_agg.input.as_ref(), + &required_input_columns, + )?; + + let outer_value_name = outer_agg + .schema + .field(outer_agg.group_expr.len()) + .name() + .clone(); + let mut presence_input = LogicalPlanBuilder::from(presence_source); + if !inner_is_count { + presence_input = presence_input.filter(inner_value_expr.clone().is_not_null())?; + } + let presence_input = presence_input + .project(presence_group_exprs.clone())? + .distinct()? + .build()?; + + let rewritten = LogicalPlanBuilder::from(presence_input) + .aggregate( + outer_agg.group_expr.clone(), + vec![count_udaf().call(vec![lit(1_i64)]).alias(outer_value_name)], + )? + .sort(sort.expr.clone())? + .build()?; + + Ok(Some(rewritten)) + } + + fn collect_required_input_columns(group_exprs: &[Expr], value_expr: &Expr) -> HashSet { + let mut required = HashSet::new(); + + for expr in group_exprs { + if let Expr::Column(column) = expr { + required.insert(column.name.clone()); + } + } + if let Expr::Column(column) = value_expr { + // Keep the value column in the pruned instant input so `InstantManipulate` + // can still perform stale-NaN filtering before we project down to keys. + required.insert(column.name.clone()); + } + + required + } + + fn collect_required_instant_columns(plan: &LogicalPlan) -> HashSet { + let mut required = HashSet::new(); + Self::collect_required_instant_columns_into(plan, &mut required); + required + } + + fn collect_required_instant_columns_into(plan: &LogicalPlan, required: &mut HashSet) { + match plan { + LogicalPlan::Projection(projection) => { + Self::collect_required_instant_columns_into(projection.input.as_ref(), required); + } + LogicalPlan::Extension(extension) => { + for expr in extension.node.expressions() { + if let Expr::Column(column) = expr { + required.insert(column.name); + } + } + + if extension.node.as_any().is::() + && extension.node.inputs()[0] + .schema() + .fields() + .iter() + .any(|field| field.name() == DATA_SCHEMA_TSID_COLUMN_NAME) + { + required.insert(DATA_SCHEMA_TSID_COLUMN_NAME.to_string()); + } + + if let Some(input) = extension.node.inputs().into_iter().next() { + Self::collect_required_instant_columns_into(input, required); + } + } + _ => {} + } + } + + fn aggregate_if(expr: &Expr, accept_name: F) -> Option<(&str, &Expr)> + where + F: FnOnce(&str) -> bool, + { + let Expr::AggregateFunction(func) = expr else { + return None; + }; + let name = func.func.name(); + if !accept_name(name) + || func.params.filter.is_some() + || func.params.distinct + || !func.params.order_by.is_empty() + || func.params.args.len() != 1 + { + return None; + } + + Some((name, &func.params.args[0])) + } + + fn is_supported_inner_aggregate(name: &str) -> bool { + matches!( + name, + "count" | "sum" | "avg" | "min" | "max" | "stddev_pop" | "var_pop" + ) + } + + fn is_projection_chain_to_instant(plan: &LogicalPlan) -> bool { + let mut current = plan; + loop { + match current { + LogicalPlan::Projection(projection) => current = projection.input.as_ref(), + LogicalPlan::Extension(ext) => { + return ext.node.as_any().is::(); + } + _ => return false, + } + } + } + + fn rebuild_projection_chain_to_instant( + plan: &LogicalPlan, + required_columns: &HashSet, + ) -> Result { + match plan { + LogicalPlan::Projection(projection) => { + let input = Self::rebuild_projection_chain_to_instant( + projection.input.as_ref(), + required_columns, + )?; + LogicalPlanBuilder::from(input) + .project(projection.expr.clone())? + .build() + } + LogicalPlan::Extension(extension) => { + if let Some(instant) = extension.node.as_any().downcast_ref::() { + let input = + Self::prune_instant_input(extension.node.inputs()[0], required_columns)?; + return Ok(LogicalPlan::Extension(Extension { + node: Arc::new(instant.with_exprs_and_inputs(vec![], vec![input])?), + })); + } + + Ok(plan.clone()) + } + _ => Ok(plan.clone()), + } + } + + fn prune_instant_input( + plan: &LogicalPlan, + required_columns: &HashSet, + ) -> Result { + match plan { + LogicalPlan::Extension(extension) => { + if let Some(normalize) = extension.node.as_any().downcast_ref::() { + let input = + Self::prune_instant_input(extension.node.inputs()[0], required_columns)?; + return Ok(LogicalPlan::Extension(Extension { + node: Arc::new(normalize.with_exprs_and_inputs(vec![], vec![input])?), + })); + } + + if let Some(divide) = extension.node.as_any().downcast_ref::() { + let divide_input = extension.node.inputs()[0].clone(); + + let projection_exprs = divide_input + .schema() + .fields() + .iter() + .filter(|field| required_columns.contains(field.name())) + .map(|field| { + Expr::Column(datafusion_common::Column::from_name(field.name().clone())) + }) + .collect::>(); + let projected_input = LogicalPlanBuilder::from(divide_input) + .project(projection_exprs)? + .build()?; + + return Ok(LogicalPlan::Extension(Extension { + node: Arc::new( + divide.with_exprs_and_inputs(vec![], vec![projected_input])?, + ), + })); + } + + Ok(plan.clone()) + } + _ => Ok(plan.clone()), + } + } +} diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index f522dc567a..6b206b9d8d 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -278,17 +278,22 @@ impl DfLogicalPlanner { let table_provider = DfTableSourceProvider::new( self.engine_state.catalog_manager().clone(), self.engine_state.disallow_cross_catalog_query(), - query_ctx, + query_ctx.clone(), plan_decoder, self.session_state .config_options() .sql_parser .enable_ident_normalization, ); - PromPlanner::stmt_to_plan(table_provider, stmt, &self.engine_state) + let plan = PromPlanner::stmt_to_plan(table_provider, stmt, &self.engine_state) .await .map_err(BoxedError::new) - .context(QueryPlanSnafu) + .context(QueryPlanSnafu)?; + + let context = QueryEngineContext::new(self.session_state.clone(), query_ctx); + Ok(self + .engine_state + .optimize_by_extension_rules(plan, &context)?) } #[tracing::instrument(skip_all)] @@ -571,15 +576,22 @@ mod tests { use std::sync::Arc; use arrow_schema::DataType; + use catalog::RegisterTableRequest; + use catalog::memory::MemoryCatalogManager; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; use session::context::QueryContext; + use store_api::metric_engine_consts::{ + DATA_SCHEMA_TABLE_ID_COLUMN_NAME, DATA_SCHEMA_TSID_COLUMN_NAME, LOGICAL_TABLE_METADATA_KEY, + METRIC_ENGINE_NAME, + }; use table::metadata::{TableInfoBuilder, TableMetaBuilder}; use table::test_util::EmptyTable; use super::*; - use crate::QueryEngineRef; - use crate::parser::QueryLanguageParser; + use crate::parser::{PromQuery, QueryLanguageParser}; + use crate::{QueryEngineFactory, QueryEngineRef}; async fn create_test_engine() -> QueryEngineRef { let columns = vec![ @@ -600,6 +612,109 @@ mod tests { crate::tests::new_query_engine_with_table(table) } + fn create_promql_test_engine() -> QueryEngineRef { + let catalog_manager = MemoryCatalogManager::with_default_setup(); + let physical_table_name = "phy"; + let physical_table_id = 999u32; + + let physical_schema = Arc::new(Schema::new(vec![ + ColumnSchema::new( + DATA_SCHEMA_TABLE_ID_COLUMN_NAME.to_string(), + ConcreteDataType::uint32_datatype(), + false, + ), + ColumnSchema::new( + DATA_SCHEMA_TSID_COLUMN_NAME.to_string(), + ConcreteDataType::uint64_datatype(), + false, + ), + ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false), + ColumnSchema::new( + "timestamp", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true), + ])); + let physical_meta = TableMetaBuilder::empty() + .schema(physical_schema) + .primary_key_indices(vec![0, 1, 2, 3]) + .value_indices(vec![4, 5]) + .engine(METRIC_ENGINE_NAME.to_string()) + .next_column_id(1024) + .build() + .unwrap(); + let physical_info = TableInfoBuilder::default() + .table_id(physical_table_id) + .name(physical_table_name) + .meta(physical_meta) + .build() + .unwrap(); + catalog_manager + .register_table_sync(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: physical_table_name.to_string(), + table_id: physical_table_id, + table: EmptyTable::from_table_info(&physical_info), + }) + .unwrap(); + + let mut options = table::requests::TableOptions::default(); + options.extra_options.insert( + LOGICAL_TABLE_METADATA_KEY.to_string(), + physical_table_name.to_string(), + ); + let logical_schema = Arc::new(Schema::new(vec![ + ColumnSchema::new("tag_0", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("tag_1", ConcreteDataType::string_datatype(), false), + ColumnSchema::new( + "timestamp", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ColumnSchema::new("field_0", ConcreteDataType::float64_datatype(), true), + ])); + let logical_meta = TableMetaBuilder::empty() + .schema(logical_schema) + .primary_key_indices(vec![0, 1]) + .value_indices(vec![3]) + .engine(METRIC_ENGINE_NAME.to_string()) + .options(options) + .next_column_id(1024) + .build() + .unwrap(); + let logical_info = TableInfoBuilder::default() + .table_id(1024) + .name("some_metric") + .meta(logical_meta) + .build() + .unwrap(); + catalog_manager + .register_table_sync(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "some_metric".to_string(), + table_id: 1024, + table: EmptyTable::from_table_info(&logical_info), + }) + .unwrap(); + + QueryEngineFactory::new( + catalog_manager, + None, + None, + None, + None, + false, + crate::options::QueryOptions::default(), + ) + .query_engine() + } + async fn parse_sql_to_plan(sql: &str) -> LogicalPlan { let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap(); let engine = create_test_engine().await; @@ -610,6 +725,25 @@ mod tests { .unwrap() } + async fn parse_promql_to_plan(query: &str) -> LogicalPlan { + let engine = create_promql_test_engine(); + let query_ctx = QueryContext::arc(); + let stmt = QueryLanguageParser::parse_promql( + &PromQuery { + query: query.to_string(), + start: "0".to_string(), + end: "10".to_string(), + step: "5s".to_string(), + lookback: "300s".to_string(), + alias: None, + }, + &query_ctx, + ) + .unwrap(); + + engine.planner().plan(&stmt, query_ctx).await.unwrap() + } + #[tokio::test] async fn test_extract_placeholder_cast_types_multiple() { let plan = parse_sql_to_plan( @@ -646,6 +780,72 @@ mod tests { assert_eq!(type_3, &Some(DataType::Int32)); } + #[tokio::test] + async fn test_plan_pql_applies_extension_rules() { + for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] { + let plan = parse_promql_to_plan(&format!( + "sum(irate(some_metric[1h])) / scalar(count({inner_agg}(some_metric) by (tag_0)))" + )) + .await; + let plan_str = plan.display_indent_schema().to_string(); + assert!(plan_str.contains("Distinct:"), "{inner_agg}: {plan_str}"); + } + } + + #[tokio::test] + async fn test_plan_pql_filters_null_only_groups_for_non_count_inner_aggs() { + let count_plan = parse_promql_to_plan("scalar(count(count(some_metric) by (tag_0)))").await; + let count_plan_str = count_plan.display_indent_schema().to_string(); + assert!( + !count_plan_str.contains("field_0 IS NOT NULL"), + "{count_plan_str}" + ); + + for inner_agg in ["sum", "avg", "min", "max", "stddev", "stdvar"] { + let plan = parse_promql_to_plan(&format!( + "scalar(count({inner_agg}(some_metric) by (tag_0)))" + )) + .await; + let plan_str = plan.display_indent_schema().to_string(); + assert!( + plan_str.contains("field_0 IS NOT NULL"), + "{inner_agg}: {plan_str}" + ); + } + } + + #[tokio::test] + async fn test_plan_pql_skips_extension_rules_for_non_direct_or_unsupported_inner_agg() { + for query in [ + "sum(irate(some_metric[1h])) / scalar(count(sum(irate(some_metric[1h])) by (tag_0)))", + "sum(irate(some_metric[1h])) / scalar(count(group(some_metric) by (tag_0)))", + ] { + let plan = parse_promql_to_plan(query).await; + let plan_str = plan.display_indent_schema().to_string(); + assert!(!plan_str.contains("Distinct:"), "{query}: {plan_str}"); + } + } + + #[tokio::test] + async fn test_plan_sql_does_not_apply_nested_count_rule() { + let plan = parse_sql_to_plan( + "SELECT id, count(inner_count) \ + FROM ( \ + SELECT id, count(name) AS inner_count \ + FROM test \ + GROUP BY id \ + ORDER BY id \ + LIMIT 1000000 \ + ) t \ + GROUP BY id \ + ORDER BY id", + ) + .await; + + let plan_str = plan.display_indent_schema().to_string(); + assert!(!plan_str.contains("Distinct:"), "{plan_str}"); + } + #[tokio::test] async fn test_get_inferred_parameter_types_subquery() { let plan = parse_sql_to_plan( diff --git a/src/query/src/promql/planner.rs b/src/query/src/promql/planner.rs index b6f4f2d28f..23d654d2b6 100644 --- a/src/query/src/promql/planner.rs +++ b/src/query/src/promql/planner.rs @@ -4056,6 +4056,7 @@ mod test { use table::test_util::EmptyTable; use super::*; + use crate::QueryEngineContext; use crate::options::QueryOptions; use crate::parser::QueryLanguageParser; @@ -4073,6 +4074,64 @@ mod test { ) } + async fn build_optimized_promql_plan( + table_provider: DfTableSourceProvider, + eval_stmt: &EvalStmt, + ) -> LogicalPlan { + let state = build_query_engine_state(); + let raw_plan = PromPlanner::stmt_to_plan(table_provider, eval_stmt, &state) + .await + .unwrap(); + let context = QueryEngineContext::new(state.session_state(), QueryContext::arc()); + state + .optimize_by_extension_rules(raw_plan, &context) + .unwrap() + } + + async fn build_optimized_tsid_plan( + query: &str, + num_tag: usize, + num_field: usize, + end_secs: u64, + lookback_secs: u64, + ) -> String { + let eval_stmt = EvalStmt { + expr: parser::parse(query).unwrap(), + start: UNIX_EPOCH, + end: UNIX_EPOCH + .checked_add(Duration::from_secs(end_secs)) + .unwrap(), + interval: Duration::from_secs(5), + lookback_delta: Duration::from_secs(lookback_secs), + }; + let table_provider = build_test_table_provider_with_tsid( + &[(DEFAULT_SCHEMA_NAME.to_string(), "some_metric".to_string())], + num_tag, + num_field, + ) + .await; + + build_optimized_promql_plan(table_provider, &eval_stmt) + .await + .display_indent_schema() + .to_string() + } + + async fn assert_nested_count_rewrite_applies(query: &str, expected_outer_agg: &str) { + let plan_str = build_optimized_tsid_plan(query, 2, 1, 100_000, 1).await; + + assert!(plan_str.contains("PromSeriesDivide: tags=[\"__tsid\"]")); + assert!(plan_str.contains("Projection: some_metric.timestamp, some_metric.tag_0")); + assert!(plan_str.contains("Distinct:")); + assert!(plan_str.contains(expected_outer_agg), "{plan_str}"); + assert!(!plan_str.contains("PromSeriesDivide: tags=[\"tag_0\"]")); + } + + async fn assert_nested_count_rewrite_missing(query: &str, num_tag: usize, lookback_secs: u64) { + let plan_str = build_optimized_tsid_plan(query, num_tag, 1, 100_000, lookback_secs).await; + assert!(!plan_str.contains("Distinct:"), "{plan_str}"); + } + async fn build_test_table_provider( table_name_tuples: &[(String, String)], num_tag: usize, @@ -4685,6 +4744,117 @@ mod test { ); } + #[tokio::test] + async fn scalar_count_count_range_keeps_full_window() { + let plan_str = build_optimized_tsid_plan( + "scalar(count(count(some_metric) by (tag_0)))", + 1, + 1, + 100_000, + 1, + ) + .await; + assert!(plan_str.contains("ScalarCalculate: tags=[]")); + assert!(plan_str.contains("PromInstantManipulate: range=[0..100000000]")); + assert!(!plan_str.contains("PromInstantManipulate: range=[99999000..99999000]")); + } + + #[tokio::test] + async fn scalar_count_count_rewrite_applies_inside_binary_expr_for_tsid_input() { + let plan_str = build_optimized_tsid_plan( + "sum(irate(some_metric[1h])) / scalar(count(count(some_metric) by (tag_0)))", + 2, + 1, + 10, + 300, + ) + .await; + assert!(plan_str.contains("Distinct:"), "{plan_str}"); + } + + #[tokio::test] + async fn nested_count_rewrite_keeps_full_series_key_with_tsid_input() { + assert_nested_count_rewrite_applies( + "count(count(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(count(some_metric.field_0))]]" + ) + .await; + } + + #[tokio::test] + async fn nested_sum_count_rewrite_keeps_full_series_key_with_tsid_input() { + assert_nested_count_rewrite_applies( + "count(sum(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(sum(some_metric.field_0))]]" + ) + .await; + } + + #[tokio::test] + async fn nested_supported_inner_aggs_rewrite_apply_for_tsid_input() { + for (query, expected_outer_agg) in [ + ( + "count(avg(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(avg(some_metric.field_0))]]", + ), + ( + "count(min(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(min(some_metric.field_0))]]", + ), + ( + "count(max(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(max(some_metric.field_0))]]", + ), + ( + "count(stddev(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(stddev_pop(some_metric.field_0))]]", + ), + ( + "count(stdvar(some_metric) by (tag_0))", + "Aggregate: groupBy=[[some_metric.timestamp]], aggr=[[count(Int64(1)) AS count(var_pop(some_metric.field_0))]]", + ), + ] { + assert_nested_count_rewrite_applies(query, expected_outer_agg).await; + } + } + + #[tokio::test] + async fn nested_non_count_inner_aggs_rewrite_filter_null_values_for_tsid_input() { + let count_plan = + build_optimized_tsid_plan("count(count(some_metric) by (tag_0))", 2, 1, 100_000, 1) + .await; + assert!( + !count_plan.contains("some_metric.field_0 IS NOT NULL"), + "{count_plan}" + ); + + for query in [ + "count(sum(some_metric) by (tag_0))", + "count(avg(some_metric) by (tag_0))", + "count(min(some_metric) by (tag_0))", + "count(max(some_metric) by (tag_0))", + "count(stddev(some_metric) by (tag_0))", + "count(stdvar(some_metric) by (tag_0))", + ] { + let plan_str = build_optimized_tsid_plan(query, 2, 1, 100_000, 1).await; + assert!( + plan_str.contains("Filter: some_metric.field_0 IS NOT NULL"), + "{query}: {plan_str}" + ); + } + } + + #[tokio::test] + async fn nested_unsupported_or_non_direct_inner_aggs_do_not_rewrite() { + assert_nested_count_rewrite_missing("count(group(some_metric) by (tag_0))", 2, 1).await; + assert_nested_count_rewrite_missing( + "count(sum(irate(some_metric[1h])) by (tag_0))", + 2, + 300, + ) + .await; + } + #[tokio::test] async fn physical_table_name_is_not_leaked_in_plan() { let prom_expr = parser::parse("some_metric").unwrap(); diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index a45fc4c896..f696c8b53e 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -60,6 +60,7 @@ use crate::dist_plan::{ use crate::metrics::{QUERY_MEMORY_POOL_REJECTED_TOTAL, QUERY_MEMORY_POOL_USAGE_BYTES}; use crate::optimizer::ExtensionAnalyzerRule; use crate::optimizer::constant_term::MatchesConstantTermOptimizer; +use crate::optimizer::count_nest_aggr::CountNestAggrRule; use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule; use crate::optimizer::parallelize_scan::ParallelizeScan; use crate::optimizer::pass_distribution::PassDistribution; @@ -146,6 +147,7 @@ impl QueryEngineState { // The [`TypeConversionRule`] must be at first extension_rules.insert(0, Arc::new(TypeConversionRule) as _); + extension_rules.push(Arc::new(CountNestAggrRule) as _); // Apply the datafusion rules let mut analyzer = Analyzer::new(); diff --git a/tests/cases/standalone/common/promql/scalar.result b/tests/cases/standalone/common/promql/scalar.result index c5c3e5ebd1..c3292b4f5c 100644 --- a/tests/cases/standalone/common/promql/scalar.result +++ b/tests/cases/standalone/common/promql/scalar.result @@ -136,6 +136,42 @@ TQL EVAL (0, 15, '5s') scalar(count(count(host) by (host))); | 1970-01-01T00:00:15 | 2.0 | +---------------------+--------------------------------+ +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(sum(host) by (host))); + ++---------------------+------------------------------+ +| ts | scalar(count(sum(host.val))) | ++---------------------+------------------------------+ +| 1970-01-01T00:00:00 | 2.0 | +| 1970-01-01T00:00:05 | 2.0 | +| 1970-01-01T00:00:10 | 2.0 | +| 1970-01-01T00:00:15 | 2.0 | ++---------------------+------------------------------+ + +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(avg(host) by (host))); + ++---------------------+------------------------------+ +| ts | scalar(count(avg(host.val))) | ++---------------------+------------------------------+ +| 1970-01-01T00:00:00 | 2.0 | +| 1970-01-01T00:00:05 | 2.0 | +| 1970-01-01T00:00:10 | 2.0 | +| 1970-01-01T00:00:15 | 2.0 | ++---------------------+------------------------------+ + +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(stddev(host) by (host))); + ++---------------------+-------------------------------------+ +| ts | scalar(count(stddev_pop(host.val))) | ++---------------------+-------------------------------------+ +| 1970-01-01T00:00:00 | 2.0 | +| 1970-01-01T00:00:05 | 2.0 | +| 1970-01-01T00:00:10 | 2.0 | +| 1970-01-01T00:00:15 | 2.0 | ++---------------------+-------------------------------------+ + -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') scalar(host{host="host1"} + scalar(host{host="host2"})); @@ -516,7 +552,99 @@ TQL EVAL (0, 15, '5s') clamp_max(clamp(host{host="host1"}, 0, 15), 6); | 1970-01-01T00:00:15 | 6.0 | host1 | +---------------------+---------------------------------------------------------+-------+ -Drop table host; +DROP TABLE host; + +Affected Rows: 0 + +CREATE TABLE presence_metric ( + ts timestamp(3) time index, + instance STRING, + cpu STRING, + shard STRING, + val DOUBLE, + PRIMARY KEY (instance, cpu, shard), +); + +Affected Rows: 0 + +INSERT INTO TABLE presence_metric VALUES + (0, 'i1', 'cpu0', 'a', 1.0), + (0, 'i1', 'cpu0', 'b', 2.0), + (0, 'i1', 'cpu1', 'a', 10.0), + (0, 'i1', 'cpu2', 'a', 20.0), + (0, 'i2', 'cpu9', 'a', 100.0), + (200000, 'i1', 'cpu0', 'a', 'NAN'::DOUBLE), + (200000, 'i1', 'cpu0', 'b', 'NAN'::DOUBLE), + (200000, 'i1', 'cpu1', 'a', 11.0), + (200000, 'i1', 'cpu2', 'a', NULL), + (200000, 'i2', 'cpu9', 'a', 101.0), + (400000, 'i1', 'cpu1', 'a', 12.0), + (400000, 'i2', 'cpu9', 'a', 102.0), + (600000, 'i1', 'cpu0', 'a', 7.0), + (600000, 'i1', 'cpu0', 'b', 8.0), + (600000, 'i2', 'cpu9', 'a', 103.0); + +Affected Rows: 15 + +-- NaN drops `cpu0` from the grouped count, while the NULL sample on `cpu2` +-- still leaves a zero-valued row in `count(...) by (cpu)`. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') count(presence_metric{instance="i1"}) by (cpu); + ++------+---------------------+----------------------------+ +| cpu | ts | count(presence_metric.val) | ++------+---------------------+----------------------------+ +| cpu0 | 1970-01-01T00:00:00 | 2 | +| cpu0 | 1970-01-01T00:10:00 | 2 | +| cpu1 | 1970-01-01T00:00:00 | 1 | +| cpu1 | 1970-01-01T00:03:20 | 1 | +| cpu1 | 1970-01-01T00:06:40 | 1 | +| cpu1 | 1970-01-01T00:10:00 | 1 | +| cpu2 | 1970-01-01T00:00:00 | 1 | +| cpu2 | 1970-01-01T00:03:20 | 0 | +| cpu2 | 1970-01-01T00:06:40 | 0 | ++------+---------------------+----------------------------+ + +-- Nested-count rewrite should preserve grouped presence after stale-NaN filtering and null-value pruning. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(count(presence_metric{instance="i1"}) by (cpu))); + ++---------------------+-------------------------------------------+ +| ts | scalar(count(count(presence_metric.val))) | ++---------------------+-------------------------------------------+ +| 1970-01-01T00:00:00 | 3.0 | +| 1970-01-01T00:03:20 | 2.0 | +| 1970-01-01T00:06:40 | 2.0 | +| 1970-01-01T00:10:00 | 2.0 | ++---------------------+-------------------------------------------+ + +-- Non-count inner aggregates must drop NULL-only groups before the outer count. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(sum(presence_metric{instance="i1"}) by (cpu))); + ++---------------------+-----------------------------------------+ +| ts | scalar(count(sum(presence_metric.val))) | ++---------------------+-----------------------------------------+ +| 1970-01-01T00:00:00 | 3.0 | +| 1970-01-01T00:03:20 | 1.0 | +| 1970-01-01T00:06:40 | 1.0 | +| 1970-01-01T00:10:00 | 2.0 | ++---------------------+-----------------------------------------+ + +-- False case: outer `by (instance)` keeps multiple series at the scalar input, so scalar should still yield NaN. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(count(presence_metric) by (instance, cpu)) by (instance)); + ++---------------------+-------------------------------------------+ +| ts | scalar(count(count(presence_metric.val))) | ++---------------------+-------------------------------------------+ +| 1970-01-01T00:00:00 | NaN | +| 1970-01-01T00:03:20 | NaN | +| 1970-01-01T00:06:40 | NaN | +| 1970-01-01T00:10:00 | NaN | ++---------------------+-------------------------------------------+ + +DROP TABLE presence_metric; Affected Rows: 0 diff --git a/tests/cases/standalone/common/promql/scalar.sql b/tests/cases/standalone/common/promql/scalar.sql index b4007bbf15..662f9665fe 100644 --- a/tests/cases/standalone/common/promql/scalar.sql +++ b/tests/cases/standalone/common/promql/scalar.sql @@ -43,6 +43,15 @@ TQL EVAL (0, 15, '5s') scalar(host{host="host1"}) + host; -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') scalar(count(count(host) by (host))); +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(sum(host) by (host))); + +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(avg(host) by (host))); + +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 15, '5s') scalar(count(stddev(host) by (host))); + -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') scalar(host{host="host1"} + scalar(host{host="host2"})); @@ -149,4 +158,49 @@ TQL EVAL (0, 15, '5s') clamp(clamp_min(host{host="host1"}, 1), 0, 12); -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') clamp_max(clamp(host{host="host1"}, 0, 15), 6); -Drop table host; +DROP TABLE host; + +CREATE TABLE presence_metric ( + ts timestamp(3) time index, + instance STRING, + cpu STRING, + shard STRING, + val DOUBLE, + PRIMARY KEY (instance, cpu, shard), +); + +INSERT INTO TABLE presence_metric VALUES + (0, 'i1', 'cpu0', 'a', 1.0), + (0, 'i1', 'cpu0', 'b', 2.0), + (0, 'i1', 'cpu1', 'a', 10.0), + (0, 'i1', 'cpu2', 'a', 20.0), + (0, 'i2', 'cpu9', 'a', 100.0), + (200000, 'i1', 'cpu0', 'a', 'NAN'::DOUBLE), + (200000, 'i1', 'cpu0', 'b', 'NAN'::DOUBLE), + (200000, 'i1', 'cpu1', 'a', 11.0), + (200000, 'i1', 'cpu2', 'a', NULL), + (200000, 'i2', 'cpu9', 'a', 101.0), + (400000, 'i1', 'cpu1', 'a', 12.0), + (400000, 'i2', 'cpu9', 'a', 102.0), + (600000, 'i1', 'cpu0', 'a', 7.0), + (600000, 'i1', 'cpu0', 'b', 8.0), + (600000, 'i2', 'cpu9', 'a', 103.0); + +-- NaN drops `cpu0` from the grouped count, while the NULL sample on `cpu2` +-- still leaves a zero-valued row in `count(...) by (cpu)`. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') count(presence_metric{instance="i1"}) by (cpu); + +-- Nested-count rewrite should preserve grouped presence after stale-NaN filtering and null-value pruning. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(count(presence_metric{instance="i1"}) by (cpu))); + +-- Non-count inner aggregates must drop NULL-only groups before the outer count. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(sum(presence_metric{instance="i1"}) by (cpu))); + +-- False case: outer `by (instance)` keeps multiple series at the scalar input, so scalar should still yield NaN. +-- SQLNESS SORT_RESULT 3 1 +TQL EVAL (0, 600, '200s') scalar(count(count(presence_metric) by (instance, cpu)) by (instance)); + +DROP TABLE presence_metric; diff --git a/tests/cases/standalone/tql-explain-analyze/tsid_column.result b/tests/cases/standalone/tql-explain-analyze/tsid_column.result index 84544b1655..4a7a875060 100644 --- a/tests/cases/standalone/tql-explain-analyze/tsid_column.result +++ b/tests/cases/standalone/tql-explain-analyze/tsid_column.result @@ -112,10 +112,63 @@ TQL ANALYZE (0, 10, '5s') sum(irate(tsid_metric[1h])) / scalar(count(count(tsid |_|_|_AggregateExec: mode=FinalPartitioned, gby=[ts@0 as ts], aggr=[count(count(tsid_metric.val))] REDACTED |_|_|_RepartitionExec: partitioning=REDACTED |_|_|_AggregateExec: mode=Partial, gby=[ts@0 as ts], aggr=[count(count(tsid_metric.val))] REDACTED -|_|_|_ProjectionExec: expr=[ts@1 as ts, count(tsid_metric.val)@2 as count(tsid_metric.val)] REDACTED -|_|_|_AggregateExec: mode=FinalPartitioned, gby=[job@0 as job, ts@1 as ts], aggr=[count(tsid_metric.val)] REDACTED +|_|_|_ProjectionExec: expr=[ts@0 as ts] REDACTED +|_|_|_AggregateExec: mode=FinalPartitioned, gby=[ts@0 as ts, job@1 as job], aggr=[] REDACTED |_|_|_RepartitionExec: partitioning=REDACTED -|_|_|_AggregateExec: mode=Partial, gby=[job@1 as job, ts@2 as ts], aggr=[count(tsid_metric.val)] REDACTED +|_|_|_AggregateExec: mode=Partial, gby=[ts@0 as ts, job@1 as job], aggr=[] REDACTED +|_|_|_ProjectionExec: expr=[ts@3 as ts, job@1 as job] REDACTED +|_|_|_PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[ts] REDACTED +|_|_|_PromSeriesDivideExec: tags=["__tsid"] REDACTED +|_|_|_ProjectionExec: expr=[val@1 as val, job@3 as job, __tsid@2 as __tsid, ts@0 as ts] REDACTED +|_|_|_SeriesScan: region=REDACTED, "partition_count":{"count":1, "mem_ranges":1, "files":0, "file_ranges":0}, "distribution":"PerSeries" REDACTED +|_|_|_| +| 1_| 0_|_SortPreservingMergeExec: [ts@0 ASC NULLS LAST] REDACTED +|_|_|_SortExec: expr=[ts@0 ASC NULLS LAST], preserve_partitioning=[true] REDACTED +|_|_|_AggregateExec: mode=FinalPartitioned, gby=[ts@0 as ts], aggr=[sum(prom_irate(ts_range,val))] REDACTED +|_|_|_RepartitionExec: partitioning=REDACTED +|_|_|_AggregateExec: mode=Partial, gby=[ts@0 as ts], aggr=[sum(prom_irate(ts_range,val))] REDACTED +|_|_|_FilterExec: prom_irate(ts_range,val)@1 IS NOT NULL REDACTED +|_|_|_ProjectionExec: expr=[ts@2 as ts, prom_irate(ts_range@3, val@0) as prom_irate(ts_range,val)] REDACTED +|_|_|_PromRangeManipulateExec: req range=[0..10000], interval=[5000], eval range=[3600000], time index=[ts] REDACTED +|_|_|_PromSeriesNormalizeExec: offset=[0], time index=[ts], filter NaN: [true] REDACTED +|_|_|_PromSeriesDivideExec: tags=["__tsid"] REDACTED +|_|_|_ProjectionExec: expr=[val@1 as val, __tsid@2 as __tsid, ts@0 as ts] REDACTED +|_|_|_SeriesScan: region=REDACTED, "partition_count":{"count":1, "mem_ranges":1, "files":0, "file_ranges":0}, "distribution":"PerSeries" REDACTED +|_|_|_| +|_|_| Total rows: 2_| ++-+-+-+ + +-- SQLNESS REPLACE (metrics.*) REDACTED +-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED +-- SQLNESS REPLACE (-+) - +-- SQLNESS REPLACE (\s\s+) _ +-- SQLNESS REPLACE (peers.*) REDACTED +-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED +-- SQLNESS REPLACE (Hash.*) REDACTED +TQL ANALYZE (0, 10, '5s') sum(irate(tsid_metric[1h])) / scalar(count(sum(tsid_metric) by (job))); + ++-+-+-+ +| stage | node | plan_| ++-+-+-+ +| 0_| 0_|_ProjectionExec: expr=[ts@1 as ts, sum(prom_irate(ts_range,val))@2 / scalar(count(sum(tsid_metric.val)))@0 as lhs.sum(prom_irate(ts_range,val)) / rhs.scalar(count(sum(tsid_metric.val)))] REDACTED +|_|_|_REDACTED +|_|_|_ScalarCalculateExec: tags=[] REDACTED +|_|_|_CoalescePartitionsExec REDACTED +|_|_|_MergeScanExec: REDACTED +|_|_|_CooperativeExec REDACTED +|_|_|_MergeScanExec: REDACTED +|_|_|_| +| 1_| 0_|_SortPreservingMergeExec: [ts@0 ASC NULLS LAST] REDACTED +|_|_|_SortExec: expr=[ts@0 ASC NULLS LAST], preserve_partitioning=[true] REDACTED +|_|_|_AggregateExec: mode=FinalPartitioned, gby=[ts@0 as ts], aggr=[count(sum(tsid_metric.val))] REDACTED +|_|_|_RepartitionExec: partitioning=REDACTED +|_|_|_AggregateExec: mode=Partial, gby=[ts@0 as ts], aggr=[count(sum(tsid_metric.val))] REDACTED +|_|_|_ProjectionExec: expr=[ts@0 as ts] REDACTED +|_|_|_AggregateExec: mode=FinalPartitioned, gby=[ts@0 as ts, job@1 as job], aggr=[] REDACTED +|_|_|_RepartitionExec: partitioning=REDACTED +|_|_|_AggregateExec: mode=Partial, gby=[ts@0 as ts, job@1 as job], aggr=[] REDACTED +|_|_|_ProjectionExec: expr=[ts@1 as ts, job@0 as job] REDACTED +|_|_|_FilterExec: val@0 IS NOT NULL, projection=[job@1, ts@2] REDACTED |_|_|_ProjectionExec: expr=[val@0 as val, job@1 as job, ts@3 as ts] REDACTED |_|_|_PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[ts] REDACTED |_|_|_PromSeriesDivideExec: tags=["__tsid"] REDACTED diff --git a/tests/cases/standalone/tql-explain-analyze/tsid_column.sql b/tests/cases/standalone/tql-explain-analyze/tsid_column.sql index 7b3de23f33..dedce2dfb1 100644 --- a/tests/cases/standalone/tql-explain-analyze/tsid_column.sql +++ b/tests/cases/standalone/tql-explain-analyze/tsid_column.sql @@ -51,6 +51,14 @@ TQL ANALYZE (0, 10, '5s') sum by (job, instance) (tsid_metric); -- SQLNESS REPLACE (Hash.*) REDACTED TQL ANALYZE (0, 10, '5s') sum(irate(tsid_metric[1h])) / scalar(count(count(tsid_metric) by (job))); +-- SQLNESS REPLACE (metrics.*) REDACTED +-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED +-- SQLNESS REPLACE (-+) - +-- SQLNESS REPLACE (\s\s+) _ +-- SQLNESS REPLACE (peers.*) REDACTED +-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED +-- SQLNESS REPLACE (Hash.*) REDACTED +TQL ANALYZE (0, 10, '5s') sum(irate(tsid_metric[1h])) / scalar(count(sum(tsid_metric) by (job))); + DROP TABLE tsid_metric; DROP TABLE tsid_physical; -