diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index b85320a495..ffbfff5ee2 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -19,6 +19,7 @@ pub mod count_wildcard; pub(crate) mod json_type_concretize; pub mod parallelize_scan; pub mod pass_distribution; +pub mod promql_tsid_narrow_join; pub mod remove_duplicate; pub mod scan_hint; pub mod string_normalization; diff --git a/src/query/src/optimizer/promql_tsid_narrow_join.rs b/src/query/src/optimizer/promql_tsid_narrow_join.rs new file mode 100644 index 0000000000..419415662e --- /dev/null +++ b/src/query/src/optimizer/promql_tsid_narrow_join.rs @@ -0,0 +1,271 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow_schema::{DataType, SchemaRef}; +use datafusion::config::ConfigOptions; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; +use datafusion_common::Result as DfResult; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_expr::JoinType; +use datafusion_physical_expr::expressions::Column; +use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME; + +/// Chooses a broadcast-style hash join for the PromQL vector-vector shape where +/// the build side only carries value, `__tsid`, and timestamp columns. +/// +/// PromQL arithmetic joins often keep one side narrow (without raw labels) and the other side wide +/// with all output labels. Partitioning both sides shuffles the wide stream. +/// `CollectLeft` only gathers the narrow build side and lets the wide probe side +/// keep its existing partitioning. +#[derive(Debug)] +pub struct PromqlTsidNarrowJoin; + +impl PhysicalOptimizerRule for PromqlTsidNarrowJoin { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> DfResult> { + plan.transform_up(Self::rewrite_join).data() + } + + fn name(&self) -> &str { + "PromqlTsidNarrowJoin" + } + + fn schema_check(&self) -> bool { + true + } +} + +impl PromqlTsidNarrowJoin { + fn rewrite_join(plan: Arc) -> DfResult>> { + let Some(hash_join) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + if !Self::should_collect_left(hash_join) { + return Ok(Transformed::no(plan)); + } + + Ok(Transformed::yes( + hash_join + .builder() + .with_partition_mode(PartitionMode::CollectLeft) + .reset_state() + .build_exec()?, + )) + } + + fn should_collect_left(hash_join: &HashJoinExec) -> bool { + hash_join.partition_mode() == &PartitionMode::Partitioned + && hash_join.join_type() == &JoinType::Inner + && hash_join.filter().is_none() + && hash_join.right().schema().fields().len() > hash_join.left().schema().fields().len() + && Self::is_promql_value_tsid_time_schema(&hash_join.left().schema()) + && Self::joins_on_tsid_and_time(hash_join) + } + + fn is_promql_value_tsid_time_schema(schema: &SchemaRef) -> bool { + let mut has_value = false; + let mut has_tsid = false; + let mut has_time = false; + + for field in schema.fields() { + match field.name().as_str() { + "greptime_value" => has_value = true, + DATA_SCHEMA_TSID_COLUMN_NAME => has_tsid = true, + _ if matches!(field.data_type(), DataType::Timestamp(_, _)) => has_time = true, + _ => return false, + } + } + + has_value && has_tsid && has_time + } + + fn joins_on_tsid_and_time(hash_join: &HashJoinExec) -> bool { + let mut has_tsid = false; + let mut has_time = false; + + for (left, right) in hash_join.on() { + let (Some(left_col), Some(right_col)) = ( + left.as_any().downcast_ref::(), + right.as_any().downcast_ref::(), + ) else { + return false; + }; + + if left_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME + && right_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME + { + has_tsid = true; + } else if matches!( + hash_join + .left() + .schema() + .field(left_col.index()) + .data_type(), + DataType::Timestamp(_, _) + ) && matches!( + hash_join + .right() + .schema() + .field(right_col.index()) + .data_type(), + DataType::Timestamp(_, _) + ) { + has_time = true; + } + } + + has_tsid && has_time + } +} + +#[cfg(test)] +mod tests { + use arrow_schema::{DataType, Field, Schema, TimeUnit}; + use datafusion::common::NullEquality; + use datafusion::physical_optimizer::PhysicalOptimizerRule; + use datafusion::physical_plan::displayable; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion::physical_plan::joins::HashJoinExec; + use datafusion_common::config::ConfigOptions; + use datafusion_physical_expr::PhysicalExpr; + + use super::*; + + #[test] + fn chooses_collect_left_for_narrow_promql_build_side() { + let left = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("greptime_value", DataType::Float64, true), + Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false), + Field::new( + "greptime_timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ])))) as Arc; + let right = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("greptime_value", DataType::Float64, true), + Field::new("host", DataType::Utf8, true), + Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false), + Field::new( + "greptime_timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ])))) as Arc; + let on = vec![ + ( + Arc::new(Column::new(DATA_SCHEMA_TSID_COLUMN_NAME, 1)) as Arc, + Arc::new(Column::new(DATA_SCHEMA_TSID_COLUMN_NAME, 2)) as Arc, + ), + ( + Arc::new(Column::new("greptime_timestamp", 2)) as Arc, + Arc::new(Column::new("greptime_timestamp", 3)) as Arc, + ), + ]; + let join = Arc::new( + HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + Some(vec![0, 3, 4, 5, 6]), + PartitionMode::Partitioned, + NullEquality::NullEqualsNull, + false, + ) + .unwrap(), + ) as Arc; + let original_schema = join.schema(); + + let optimized = PromqlTsidNarrowJoin + .optimize(join, &ConfigOptions::default()) + .unwrap(); + let optimized_join = optimized.as_any().downcast_ref::().unwrap(); + + assert_eq!(optimized_join.partition_mode(), &PartitionMode::CollectLeft); + assert_eq!(optimized.schema(), original_schema); + assert!( + displayable(optimized.as_ref()) + .one_line() + .to_string() + .contains( + "projection=[greptime_value@0, greptime_value@3, host@4, __tsid@5, greptime_timestamp@6]" + ) + ); + } + + #[test] + fn keeps_partitioned_join_when_left_side_carries_labels() { + let left = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("greptime_value", DataType::Float64, true), + Field::new("host", DataType::Utf8, true), + Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false), + Field::new( + "greptime_timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ])))) as Arc; + let right = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("greptime_value", DataType::Float64, true), + Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false), + Field::new( + "greptime_timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ])))) as Arc; + let join = Arc::new( + HashJoinExec::try_new( + left, + right, + vec![ + ( + Arc::new(Column::new(DATA_SCHEMA_TSID_COLUMN_NAME, 2)) + as Arc, + Arc::new(Column::new(DATA_SCHEMA_TSID_COLUMN_NAME, 1)) + as Arc, + ), + ( + Arc::new(Column::new("greptime_timestamp", 3)) as Arc, + Arc::new(Column::new("greptime_timestamp", 2)) as Arc, + ), + ], + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNull, + false, + ) + .unwrap(), + ) as Arc; + + let optimized = PromqlTsidNarrowJoin + .optimize(join, &ConfigOptions::default()) + .unwrap(); + let optimized_join = optimized.as_any().downcast_ref::().unwrap(); + + assert_eq!(optimized_join.partition_mode(), &PartitionMode::Partitioned); + } +} diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 45a5700781..4262428091 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -66,6 +66,7 @@ use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule; use crate::optimizer::json_type_concretize::JsonTypeConcretizeRule; use crate::optimizer::parallelize_scan::ParallelizeScan; use crate::optimizer::pass_distribution::PassDistribution; +use crate::optimizer::promql_tsid_narrow_join::PromqlTsidNarrowJoin; use crate::optimizer::remove_duplicate::RemoveDuplicate; use crate::optimizer::scan_hint::ScanHintRule; use crate::optimizer::string_normalization::StringNormalizationRule; @@ -189,9 +190,13 @@ impl QueryEngineState { physical_optimizer .rules .insert(6, Arc::new(PassDistribution)); + // Prefer collecting narrow PromQL build sides over repartitioning wide label streams. + physical_optimizer + .rules + .insert(7, Arc::new(PromqlTsidNarrowJoin)); // Enforce sorting AFTER custom rules that modify the plan structure physical_optimizer.rules.insert( - 7, + 8, Arc::new(datafusion::physical_optimizer::enforce_sorting::EnforceSorting {}), ); // Add rule for windowed sort diff --git a/tests/cases/standalone/common/promql/tsid_binary_join_regression.result b/tests/cases/standalone/common/promql/tsid_binary_join_regression.result index 3640291dc3..d414eb6bba 100644 --- a/tests/cases/standalone/common/promql/tsid_binary_join_regression.result +++ b/tests/cases/standalone/common/promql/tsid_binary_join_regression.result @@ -71,11 +71,11 @@ TQL ANALYZE (0, 5, '5s') tsid_binary_join_left / tsid_binary_join_right; | stage | node | plan_| +-+-+-+ | 0_| 0_|_ProjectionExec: expr=[host@2 as host, job@3 as job, ts@5 as ts, __tsid@4 as __tsid, greptime_value@0 / greptime_value@1 as tsid_binary_join_left.greptime_value / tsid_binary_join_right.greptime_value] REDACTED -|_|_|_HashJoinExec: mode=Partitioned, join_type=Inner, on=[(__tsid@1, __tsid@3), (ts@2, ts@4)], projection=[greptime_value@0, greptime_value@3, host@4, job@5, __tsid@6, ts@7], NullsEqual: true REDACTED -|_|_|_RepartitionExec: partitioning=Hash([__tsid@1, ts@2],REDACTED +|_|_|_HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(__tsid@1, __tsid@3), (ts@2, ts@4)], projection=[greptime_value@0, greptime_value@3, host@4, job@5, __tsid@6, ts@7], NullsEqual: true REDACTED +|_|_|_CoalescePartitionsExec REDACTED |_|_|_ProjectionExec: expr=[greptime_value@0 as greptime_value, __tsid@3 as __tsid, ts@4 as ts] REDACTED |_|_|_MergeScanExec: REDACTED -|_|_|_RepartitionExec: partitioning=Hash([__tsid@3, ts@4],REDACTED +|_|_|_CooperativeExec REDACTED |_|_|_MergeScanExec: REDACTED |_|_|_| | 1_| 0_|_PromInstantManipulateExec: range=[0..5000], lookback=[300000], interval=[5000], time index=[ts] REDACTED @@ -189,11 +189,11 @@ TQL ANALYZE (0, 5, '5s') tsid_binary_join_left > bool tsid_binary_join_right; | stage | node | plan_| +-+-+-+ | 0_| 0_|_ProjectionExec: expr=[host@2 as host, job@3 as job, ts@5 as ts, __tsid@4 as __tsid, CAST(greptime_value@1 < greptime_value@0 AS Float64) as tsid_binary_join_left.greptime_value > tsid_binary_join_right.greptime_value] REDACTED -|_|_|_HashJoinExec: mode=Partitioned, join_type=Inner, on=[(__tsid@1, __tsid@3), (ts@2, ts@4)], projection=[greptime_value@0, greptime_value@3, host@4, job@5, __tsid@6, ts@7], NullsEqual: true REDACTED -|_|_|_RepartitionExec: partitioning=Hash([__tsid@1, ts@2],REDACTED +|_|_|_HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(__tsid@1, __tsid@3), (ts@2, ts@4)], projection=[greptime_value@0, greptime_value@3, host@4, job@5, __tsid@6, ts@7], NullsEqual: true REDACTED +|_|_|_CoalescePartitionsExec REDACTED |_|_|_ProjectionExec: expr=[greptime_value@0 as greptime_value, __tsid@3 as __tsid, ts@4 as ts] REDACTED |_|_|_MergeScanExec: REDACTED -|_|_|_RepartitionExec: partitioning=Hash([__tsid@3, ts@4],REDACTED +|_|_|_CooperativeExec REDACTED |_|_|_MergeScanExec: REDACTED |_|_|_| | 1_| 0_|_PromInstantManipulateExec: range=[0..5000], lookback=[300000], interval=[5000], time index=[ts] REDACTED diff --git a/tests/cases/standalone/common/tql-explain-analyze/explain.result b/tests/cases/standalone/common/tql-explain-analyze/explain.result index 65532d738b..e60a6b74f6 100644 --- a/tests/cases/standalone/common/tql-explain-analyze/explain.result +++ b/tests/cases/standalone/common/tql-explain-analyze/explain.result @@ -182,6 +182,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test; | physical_plan after FilterPushdown_| SAME TEXT AS ABOVE_| | physical_plan after parallelize_scan_| SAME TEXT AS ABOVE_| | physical_plan after PassDistributionRule_| SAME TEXT AS ABOVE_| +| physical_plan after PromqlTsidNarrowJoin_| SAME TEXT AS ABOVE_| | physical_plan after EnforceSorting_| SAME TEXT AS ABOVE_| | physical_plan after EnforceDistribution_| SAME TEXT AS ABOVE_| | physical_plan after CombinePartialFinalAggregate_| SAME TEXT AS ABOVE_| @@ -332,6 +333,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series; | physical_plan after FilterPushdown_| SAME TEXT AS ABOVE_| | physical_plan after parallelize_scan_| SAME TEXT AS ABOVE_| | physical_plan after PassDistributionRule_| SAME TEXT AS ABOVE_| +| physical_plan after PromqlTsidNarrowJoin_| SAME TEXT AS ABOVE_| | physical_plan after EnforceSorting_| SAME TEXT AS ABOVE_| | physical_plan after EnforceDistribution_| SAME TEXT AS ABOVE_| | physical_plan after CombinePartialFinalAggregate_| SAME TEXT AS ABOVE_| @@ -654,6 +656,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test_nano; | physical_plan after FilterPushdown_| SAME TEXT AS ABOVE_| | physical_plan after parallelize_scan_| SAME TEXT AS ABOVE_| | physical_plan after PassDistributionRule_| SAME TEXT AS ABOVE_| +| physical_plan after PromqlTsidNarrowJoin_| SAME TEXT AS ABOVE_| | physical_plan after EnforceSorting_| OutputRequirementExec: order_by=[], dist_by=Unspecified_| |_|_PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[j]_| |_|_PromSeriesDivideExec: tags=["k"]_|