1use std::sync::Arc;
16
17use datafusion::config::ConfigOptions;
18use datafusion::physical_optimizer::PhysicalOptimizerRule;
19use datafusion::physical_plan::ExecutionPlan;
20use datafusion::physical_plan::projection::ProjectionExec;
21use datafusion_common::Result as DfResult;
22use datafusion_physical_expr::Distribution;
23use datafusion_physical_expr::utils::map_columns_before_projection;
24
25use crate::dist_plan::MergeScanExec;
26
/// Physical optimizer rule that pushes distribution (partitioning)
/// requirements from upstream operators down into `MergeScanExec` nodes,
/// so the remote scan can emit data already partitioned as its parent
/// (e.g. a partitioned hash join) requires.
#[derive(Debug)]
pub struct PassDistribution;
36
impl PhysicalOptimizerRule for PassDistribution {
    /// Runs the rule over the whole plan; all real work is delegated to
    /// [`PassDistribution::do_optimize`].
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        Self::do_optimize(plan, config)
    }

    fn name(&self) -> &str {
        "PassDistributionRule"
    }

    /// Post-rule schema checking is disabled for this rule.
    fn schema_check(&self) -> bool {
        false
    }
}
54
55impl PassDistribution {
56 fn do_optimize(
57 plan: Arc<dyn ExecutionPlan>,
58 _config: &ConfigOptions,
59 ) -> DfResult<Arc<dyn ExecutionPlan>> {
60 Self::rewrite_with_distribution(plan, None)
62 }
63
64 fn rewrite_with_distribution(
66 plan: Arc<dyn ExecutionPlan>,
67 current_req: Option<Distribution>,
68 ) -> DfResult<Arc<dyn ExecutionPlan>> {
69 if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
71 && let Some(distribution) = current_req.as_ref()
72 && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
73 {
74 return Ok(Arc::new(new_plan) as _);
76 }
77
78 let children = plan.children();
80 if children.is_empty() {
81 return Ok(plan);
82 }
83
84 let required = plan.required_input_distribution();
85 let mut new_children = Vec::with_capacity(children.len());
86 for (idx, child) in children.into_iter().enumerate() {
87 let child_req = match required.get(idx) {
88 Some(Distribution::UnspecifiedDistribution) if idx == 0 => {
89 Self::map_hash_requirement_through_projection(plan.as_ref(), ¤t_req)
90 }
91 Some(Distribution::UnspecifiedDistribution) => None,
92 None => current_req.clone(),
93 Some(req) => Some(req.clone()),
94 };
95 let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
96 new_children.push(new_child);
97 }
98
99 let unchanged = plan
101 .children()
102 .into_iter()
103 .zip(new_children.iter())
104 .all(|(old, new)| Arc::ptr_eq(old, new));
105 if unchanged {
106 Ok(plan)
107 } else {
108 plan.with_new_children(new_children)
109 }
110 }
111
112 fn map_hash_requirement_through_projection(
113 plan: &dyn ExecutionPlan,
114 current_req: &Option<Distribution>,
115 ) -> Option<Distribution> {
116 let Some(Distribution::HashPartitioned(required_exprs)) = current_req else {
117 return None;
118 };
119
120 let projection = plan.as_any().downcast_ref::<ProjectionExec>()?;
121 let proj_exprs = projection
122 .expr()
123 .iter()
124 .map(|expr| (Arc::clone(&expr.expr), expr.alias.clone()))
125 .collect::<Vec<_>>();
126 let mapped = map_columns_before_projection(required_exprs, &proj_exprs);
127
128 (mapped.len() == required_exprs.len()).then_some(Distribution::HashPartitioned(mapped))
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use async_trait::async_trait;
    use common_query::request::QueryRequest;
    use common_recordbatch::SendableRecordBatchStream;
    use datafusion::common::NullEquality;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
    use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr};
    use datafusion::physical_plan::{ExecutionPlanProperties, Partitioning};
    use datafusion_expr::{JoinType, LogicalPlanBuilder};
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::Column as PhysicalColumn;
    use session::ReadPreference;
    use session::context::QueryContext;
    use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
    use store_api::storage::RegionId;
    use table::table_name::TableName;

    use super::*;
    use crate::error::Result as QueryResult;
    use crate::region_query::RegionQueryHandler;

    // Stub region query handler: the optimizer tests only build plans and
    // never execute them, so any remote fetch is a test bug.
    struct NoopRegionQueryHandler;

    #[async_trait]
    impl RegionQueryHandler for NoopRegionQueryHandler {
        async fn do_get(
            &self,
            _read_preference: ReadPreference,
            _request: QueryRequest,
        ) -> QueryResult<SendableRecordBatchStream> {
            unreachable!("pass distribution tests should not execute remote queries")
        }
    }

    // End-to-end check of the rule: a partitioned hash join over
    // (projection -> merge scan) and a bare merge scan. After optimization
    // both merge scans should report hash partitioning on the join keys,
    // with the left keys mapped back through the projection's reordering.
    #[test]
    fn passes_hash_requirement_through_projection_to_merge_scan() {
        let schema = test_schema();
        let left_merge_scan = Arc::new(test_merge_scan_exec(schema.clone()));
        let right_merge_scan = Arc::new(test_merge_scan_exec(schema.clone()));
        // The projection reorders columns (value, tsid, timestamp), so the
        // join's requirement on its output must be remapped to input indices.
        let left_projection = Arc::new(
            ProjectionExec::try_new(
                vec![
                    ProjectionExpr::new(partition_column("greptime_value", 3), "greptime_value"),
                    ProjectionExpr::new(
                        partition_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                        DATA_SCHEMA_TSID_COLUMN_NAME,
                    ),
                    ProjectionExpr::new(
                        partition_column("greptime_timestamp", 2),
                        "greptime_timestamp",
                    ),
                ],
                left_merge_scan,
            )
            .unwrap(),
        ) as Arc<dyn datafusion::physical_plan::ExecutionPlan>;
        // Partitioned-mode join requires hash distribution on (tsid, ts)
        // from both inputs.
        let join = Arc::new(
            HashJoinExec::try_new(
                left_projection,
                right_merge_scan,
                vec![
                    (
                        partition_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                        partition_column(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                    ),
                    (
                        partition_column("greptime_timestamp", 2),
                        partition_column("greptime_timestamp", 2),
                    ),
                ],
                None,
                &JoinType::Inner,
                None,
                PartitionMode::Partitioned,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        ) as Arc<dyn datafusion::physical_plan::ExecutionPlan>;

        let optimized = PassDistribution
            .optimize(join, &ConfigOptions::default())
            .unwrap();
        let hash_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();
        let left_projection = hash_join
            .left()
            .as_any()
            .downcast_ref::<ProjectionExec>()
            .unwrap();
        // Inspect partitioning BELOW the projection: the requirement must
        // have reached the merge scan itself.
        let left_partitioning = left_projection.input().output_partitioning();
        let right_partitioning = hash_join.right().output_partitioning();

        let Partitioning::Hash(left_exprs, left_count) = left_partitioning else {
            panic!("expected left merge scan hash partitioning");
        };
        let Partitioning::Hash(right_exprs, right_count) = right_partitioning else {
            panic!("expected right merge scan hash partitioning");
        };

        // 32 matches the target partition count passed to MergeScanExec::new.
        assert_eq!(*left_count, 32);
        assert_eq!(*right_count, 32);
        assert_eq!(
            column_names(left_exprs),
            vec![DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
        );
        assert_eq!(
            column_names(right_exprs),
            vec![DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
        );
    }

    // Builds a MergeScanExec over two regions whose partition columns are
    // (tsid, greptime_timestamp) — the columns the join hashes on.
    fn test_merge_scan_exec(schema: SchemaRef) -> MergeScanExec {
        let session_state = SessionStateBuilder::new().with_default_features().build();
        let partition_cols = BTreeMap::from([
            (
                DATA_SCHEMA_TSID_COLUMN_NAME.to_string(),
                BTreeSet::from([datafusion_common::Column::from_name(
                    DATA_SCHEMA_TSID_COLUMN_NAME,
                )]),
            ),
            (
                "greptime_timestamp".to_string(),
                BTreeSet::from([datafusion_common::Column::from_name("greptime_timestamp")]),
            ),
        ]);
        // The logical plan content is irrelevant here; only the physical
        // node's partitioning behavior is under test.
        let plan = LogicalPlanBuilder::empty(false).build().unwrap();

        MergeScanExec::new(
            &session_state,
            TableName::new("greptime", "public", "test"),
            vec![RegionId::new(1, 0), RegionId::new(1, 1)],
            plan,
            schema.as_ref(),
            Arc::new(NoopRegionQueryHandler),
            QueryContext::arc(),
            32,
            partition_cols,
        )
        .unwrap()
    }

    // Four-column metric-engine-like schema: host, tsid, timestamp, value.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
            Field::new(
                "greptime_timestamp",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("greptime_value", DataType::Float64, true),
        ]))
    }

    // Shorthand for a physical column expression with an explicit index.
    fn partition_column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(PhysicalColumn::new(name, index))
    }

    // Extracts the column names from hash-partitioning expressions,
    // panicking if any expression is not a plain column.
    fn column_names(exprs: &[Arc<dyn PhysicalExpr>]) -> Vec<&str> {
        exprs
            .iter()
            .map(|expr| {
                expr.as_any()
                    .downcast_ref::<PhysicalColumn>()
                    .unwrap()
                    .name()
            })
            .collect()
    }
}