Skip to main content

query/optimizer/
pass_distribution.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion::physical_plan::projection::ProjectionExec;
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23use datafusion_physical_expr::utils::map_columns_before_projection;
24
25use crate::dist_plan::MergeScanExec;
26
/// A [`PhysicalOptimizerRule`] that pushes distribution requirements from parent
/// operators down to [`MergeScanExec`] leaves, letting each scan be rebuilt with the
/// required distribution so no extra shuffle is needed above it.
///
/// This rule is expected to be run before [`EnforceDistribution`].
///
/// The rule carries no state, so it is cheap to copy and has a trivial `Default`.
///
/// [`EnforceDistribution`]: datafusion::physical_optimizer::enforce_distribution::EnforceDistribution
/// [`MergeScanExec`]: crate::dist_plan::MergeScanExec
#[derive(Debug, Default, Clone, Copy)]
pub struct PassDistribution;
36
37impl PhysicalOptimizerRule for PassDistribution {
38    fn optimize(
39        &self,
40        plan: Arc<dyn ExecutionPlan>,
41        config: &ConfigOptions,
42    ) -> DfResult<Arc<dyn ExecutionPlan>> {
43        Self::do_optimize(plan, config)
44    }
45
46    fn name(&self) -> &str {
47        "PassDistributionRule"
48    }
49
50    fn schema_check(&self) -> bool {
51        false
52    }
53}
54
55impl PassDistribution {
56    fn do_optimize(
57        plan: Arc<dyn ExecutionPlan>,
58        _config: &ConfigOptions,
59    ) -> DfResult<Arc<dyn ExecutionPlan>> {
60        // Start from root with no requirement
61        Self::rewrite_with_distribution(plan, None)
62    }
63
64    /// Top-down rewrite that propagates distribution requirements to children.
65    fn rewrite_with_distribution(
66        plan: Arc<dyn ExecutionPlan>,
67        current_req: Option<Distribution>,
68    ) -> DfResult<Arc<dyn ExecutionPlan>> {
69        // If this is a MergeScanExec, try to apply the current requirement.
70        if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
71            && let Some(distribution) = current_req.as_ref()
72            && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
73        {
74            // Leaf node; no children to process
75            return Ok(Arc::new(new_plan) as _);
76        }
77
78        // Compute per-child requirements from the current node.
79        let children = plan.children();
80        if children.is_empty() {
81            return Ok(plan);
82        }
83
84        let required = plan.required_input_distribution();
85        let mut new_children = Vec::with_capacity(children.len());
86        for (idx, child) in children.into_iter().enumerate() {
87            let child_req = match required.get(idx) {
88                Some(Distribution::UnspecifiedDistribution) if idx == 0 => {
89                    Self::map_hash_requirement_through_projection(plan.as_ref(), &current_req)
90                }
91                Some(Distribution::UnspecifiedDistribution) => None,
92                None => current_req.clone(),
93                Some(req) => Some(req.clone()),
94            };
95            let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
96            new_children.push(new_child);
97        }
98
99        // Rebuild the node only if any child changed (pointer inequality)
100        let unchanged = plan
101            .children()
102            .into_iter()
103            .zip(new_children.iter())
104            .all(|(old, new)| Arc::ptr_eq(old, new));
105        if unchanged {
106            Ok(plan)
107        } else {
108            plan.with_new_children(new_children)
109        }
110    }
111
112    fn map_hash_requirement_through_projection(
113        plan: &dyn ExecutionPlan,
114        current_req: &Option<Distribution>,
115    ) -> Option<Distribution> {
116        let Some(Distribution::HashPartitioned(required_exprs)) = current_req else {
117            return None;
118        };
119
120        let projection = plan.as_any().downcast_ref::<ProjectionExec>()?;
121        let proj_exprs = projection
122            .expr()
123            .iter()
124            .map(|expr| (Arc::clone(&expr.expr), expr.alias.clone()))
125            .collect::<Vec<_>>();
126        let mapped = map_columns_before_projection(required_exprs, &proj_exprs);
127
128        (mapped.len() == required_exprs.len()).then_some(Distribution::HashPartitioned(mapped))
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use async_trait::async_trait;
    use common_query::request::QueryRequest;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::common::NullEquality;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
    use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr};
    use datafusion::physical_plan::{ExecutionPlanProperties, Partitioning};
    use datafusion_expr::{JoinType, LogicalPlanBuilder};
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::Column as PhysicalColumn;
    use session::ReadPreference;
    use session::context::QueryContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;
    use table::table_name::TableName;

    use super::*;
    use crate::error::Result as QueryResult;
    use crate::region_query::RegionQueryHandler;

    /// Region query handler stub; these tests only rewrite plans and must never
    /// reach remote execution.
    struct NoopRegionQueryHandler;

    #[async_trait]
    impl RegionQueryHandler for NoopRegionQueryHandler {
        async fn do_get(
            &self,
            _read_preference: ReadPreference,
            _request: QueryRequest,
        ) -> QueryResult<SendableRecordBatchStream> {
            unreachable!("pass distribution tests should not execute remote queries")
        }
    }

    /// A partitioned hash join over `projection(merge scan)` and a bare merge scan:
    /// after the rule runs, both merge scans must be hash-partitioned on the join keys.
    #[test]
    fn passes_hash_requirement_through_projection_to_merge_scan() {
        let schema = test_schema();
        // Join keys. Both join inputs expose them at indices 1 and 2: the scan schema
        // has them there, and the projection below reorders columns to keep it so.
        let tsid_key = || partition_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1);
        let ts_key = || partition_column("greptime_timestamp", 2);

        let left: Arc<dyn ExecutionPlan> = Arc::new(
            ProjectionExec::try_new(
                vec![
                    ProjectionExpr::new(partition_column("greptime_value", 3), "greptime_value"),
                    ProjectionExpr::new(tsid_key(), DATA_SCHEMA_TSID_COLUMN_NAME),
                    ProjectionExpr::new(ts_key(), "greptime_timestamp"),
                ],
                Arc::new(test_merge_scan_exec(schema.clone())),
            )
            .unwrap(),
        );
        let right: Arc<dyn ExecutionPlan> = Arc::new(test_merge_scan_exec(schema));

        let join: Arc<dyn ExecutionPlan> = Arc::new(
            HashJoinExec::try_new(
                left,
                right,
                vec![(tsid_key(), tsid_key()), (ts_key(), ts_key())],
                None,
                &JoinType::Inner,
                None,
                PartitionMode::Partitioned,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        );

        let optimized = PassDistribution
            .optimize(join, &ConfigOptions::default())
            .unwrap();

        let hash_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();
        let left_projection = hash_join
            .left()
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .unwrap();

        let expected_keys = [DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"];
        assert_merge_scan_hash_partitioning(
            left_projection.input().output_partitioning(),
            "expected left merge scan hash partitioning",
            &expected_keys,
        );
        assert_merge_scan_hash_partitioning(
            hash_join.right().output_partitioning(),
            "expected right merge scan hash partitioning",
            &expected_keys,
        );
    }

    /// Asserts `partitioning` is hash partitioning into 32 partitions (the target
    /// partition count given to [`test_merge_scan_exec`]) on exactly
    /// `expected_columns`, in order. Panics with `missing_hash_msg` when it is not
    /// hash partitioning at all.
    fn assert_merge_scan_hash_partitioning(
        partitioning: &Partitioning,
        missing_hash_msg: &str,
        expected_columns: &[&str],
    ) {
        let Partitioning::Hash(exprs, count) = partitioning else {
            panic!("{missing_hash_msg}");
        };
        assert_eq!(*count, 32);
        assert_eq!(column_names(exprs), expected_columns);
    }

    /// Builds a two-region `MergeScanExec` over an empty logical plan, partitioned
    /// by the tsid and timestamp columns, with 32 target partitions.
    fn test_merge_scan_exec(schema: SchemaRef) -> MergeScanExec {
        let session_state = SessionStateBuilder::new().with_default_features().build();
        let partition_cols = [DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
            .into_iter()
            .map(|name| {
                (
                    name.to_string(),
                    BTreeSet::from([datafusion_common::Column::from_name(name)]),
                )
            })
            .collect::<BTreeMap<_, _>>();
        let plan = LogicalPlanBuilder::empty(false).build().unwrap();

        MergeScanExec::new(
            &session_state,
            TableName::new("greptime", "public", "test"),
            vec![RegionId::new(1, 0), RegionId::new(1, 1)],
            plan,
            schema.as_ref(),
            Arc::new(NoopRegionQueryHandler),
            QueryContext::arc(),
            32,
            partition_cols,
        )
        .unwrap()
    }

    /// Metric-engine-like schema: `host`, tsid, timestamp, value (indices 0..=3).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("host", DataType::Utf8, true),
            Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
            Field::new(
                "greptime_timestamp",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("greptime_value", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Physical column reference `name` at `index`.
    fn partition_column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(PhysicalColumn::new(name, index))
    }

    /// Names of `exprs`, which must all be plain physical columns.
    fn column_names(exprs: &[Arc<dyn PhysicalExpr>]) -> Vec<&str> {
        let mut names = Vec::with_capacity(exprs.len());
        for expr in exprs {
            let column = expr.as_any().downcast_ref::<PhysicalColumn>().unwrap();
            names.push(column.name());
        }
        names
    }
}