use std::sync::Arc;

use arrow_schema::{DataType, SchemaRef};
use datafusion::config::ConfigOptions;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion_common::Result as DfResult;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::JoinType;
use datafusion_physical_expr::expressions::Column;
use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
27
/// Physical optimizer rule that switches eligible PromQL hash joins from
/// [`PartitionMode::Partitioned`] to [`PartitionMode::CollectLeft`].
///
/// A join qualifies when it is a filter-free inner join whose left (build) input carries only
/// the `greptime_value`, TSID and timestamp columns, whose right input has strictly more
/// columns, and whose keys pair the TSID columns and a timestamp column of both sides.
/// Presumably this lets the wide (label-carrying) side skip the hash repartitioning that
/// `Partitioned` mode would require.
///
/// The rule is stateless, so it is cheap to construct, copy and default.
#[derive(Debug, Default, Clone, Copy)]
pub struct PromqlTsidNarrowJoin;
37
38impl PhysicalOptimizerRule for PromqlTsidNarrowJoin {
39 fn optimize(
40 &self,
41 plan: Arc<dyn ExecutionPlan>,
42 _config: &ConfigOptions,
43 ) -> DfResult<Arc<dyn ExecutionPlan>> {
44 plan.transform_up(Self::rewrite_join).data()
45 }
46
47 fn name(&self) -> &str {
48 "PromqlTsidNarrowJoin"
49 }
50
51 fn schema_check(&self) -> bool {
52 true
53 }
54}
55
56impl PromqlTsidNarrowJoin {
57 fn rewrite_join(plan: Arc<dyn ExecutionPlan>) -> DfResult<Transformed<Arc<dyn ExecutionPlan>>> {
58 let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() else {
59 return Ok(Transformed::no(plan));
60 };
61
62 if !Self::should_collect_left(hash_join) {
63 return Ok(Transformed::no(plan));
64 }
65
66 Ok(Transformed::yes(
67 hash_join
68 .builder()
69 .with_partition_mode(PartitionMode::CollectLeft)
70 .reset_state()
71 .build_exec()?,
72 ))
73 }
74
75 fn should_collect_left(hash_join: &HashJoinExec) -> bool {
76 hash_join.partition_mode() == &PartitionMode::Partitioned
77 && hash_join.join_type() == &JoinType::Inner
78 && hash_join.filter().is_none()
79 && hash_join.right().schema().fields().len() > hash_join.left().schema().fields().len()
80 && Self::is_promql_value_tsid_time_schema(&hash_join.left().schema())
81 && Self::joins_on_tsid_and_time(hash_join)
82 }
83
84 fn is_promql_value_tsid_time_schema(schema: &SchemaRef) -> bool {
85 let mut has_value = false;
86 let mut has_tsid = false;
87 let mut has_time = false;
88
89 for field in schema.fields() {
90 match field.name().as_str() {
91 "greptime_value" => has_value = true,
92 DATA_SCHEMA_TSID_COLUMN_NAME => has_tsid = true,
93 _ if matches!(field.data_type(), DataType::Timestamp(_, _)) => has_time = true,
94 _ => return false,
95 }
96 }
97
98 has_value && has_tsid && has_time
99 }
100
101 fn joins_on_tsid_and_time(hash_join: &HashJoinExec) -> bool {
102 let mut has_tsid = false;
103 let mut has_time = false;
104
105 for (left, right) in hash_join.on() {
106 let (Some(left_col), Some(right_col)) = (
107 left.as_any().downcast_ref::<Column>(),
108 right.as_any().downcast_ref::<Column>(),
109 ) else {
110 return false;
111 };
112
113 if left_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
114 && right_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
115 {
116 has_tsid = true;
117 } else if matches!(
118 hash_join
119 .left()
120 .schema()
121 .field(left_col.index())
122 .data_type(),
123 DataType::Timestamp(_, _)
124 ) && matches!(
125 hash_join
126 .right()
127 .schema()
128 .field(right_col.index())
129 .data_type(),
130 DataType::Timestamp(_, _)
131 ) {
132 has_time = true;
133 }
134 }
135
136 has_tsid && has_time
137 }
138}
139
#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use datafusion::common::NullEquality;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::displayable;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::physical_plan::joins::HashJoinExec;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_expr::PhysicalExpr;

    use super::*;

    /// Nullable `Float64` value column.
    fn value_field() -> Field {
        Field::new("greptime_value", DataType::Float64, true)
    }

    /// Nullable `Utf8` label column.
    fn host_field() -> Field {
        Field::new("host", DataType::Utf8, true)
    }

    /// Non-nullable `UInt64` TSID column.
    fn tsid_field() -> Field {
        Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false)
    }

    /// Non-nullable millisecond time index column.
    fn timestamp_field() -> Field {
        Field::new(
            "greptime_timestamp",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        )
    }

    /// Data-less input plan that exposes `fields` as its schema.
    fn empty_input(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(Arc::new(Schema::new(fields))))
    }

    /// Column reference used as one side of a join key.
    fn key(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Filter-free, partitioned inner hash join with null-equals-null key semantics.
    fn partitioned_inner_join(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
        projection: Option<Vec<usize>>,
    ) -> Arc<dyn ExecutionPlan> {
        let join = HashJoinExec::try_new(
            left,
            right,
            on,
            None,
            &JoinType::Inner,
            projection,
            PartitionMode::Partitioned,
            NullEquality::NullEqualsNull,
            false,
        )
        .unwrap();
        Arc::new(join)
    }

    /// Applies the rule under the default configuration.
    fn run_rule(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        PromqlTsidNarrowJoin
            .optimize(plan, &ConfigOptions::default())
            .unwrap()
    }

    /// Views a plan that must be a `HashJoinExec` as one.
    fn as_hash_join(plan: &Arc<dyn ExecutionPlan>) -> &HashJoinExec {
        plan.as_any().downcast_ref::<HashJoinExec>().unwrap()
    }

    #[test]
    fn chooses_collect_left_for_narrow_promql_build_side() {
        let build = empty_input(vec![value_field(), tsid_field(), timestamp_field()]);
        let probe = empty_input(vec![
            value_field(),
            host_field(),
            tsid_field(),
            timestamp_field(),
        ]);
        let join = partitioned_inner_join(
            build,
            probe,
            vec![
                (
                    key(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                    key(DATA_SCHEMA_TSID_COLUMN_NAME, 2),
                ),
                (key("greptime_timestamp", 2), key("greptime_timestamp", 3)),
            ],
            Some(vec![0, 3, 4, 5, 6]),
        );
        let schema_before = join.schema();

        let optimized = run_rule(join);

        assert_eq!(
            as_hash_join(&optimized).partition_mode(),
            &PartitionMode::CollectLeft
        );
        assert_eq!(optimized.schema(), schema_before);
        let rendered = displayable(optimized.as_ref()).one_line().to_string();
        assert!(rendered.contains(
            "projection=[greptime_value@0, greptime_value@3, host@4, __tsid@5, greptime_timestamp@6]"
        ));
    }

    #[test]
    fn keeps_partitioned_join_when_left_side_carries_labels() {
        let build = empty_input(vec![
            value_field(),
            host_field(),
            tsid_field(),
            timestamp_field(),
        ]);
        let probe = empty_input(vec![value_field(), tsid_field(), timestamp_field()]);
        let join = partitioned_inner_join(
            build,
            probe,
            vec![
                (
                    key(DATA_SCHEMA_TSID_COLUMN_NAME, 2),
                    key(DATA_SCHEMA_TSID_COLUMN_NAME, 1),
                ),
                (key("greptime_timestamp", 3), key("greptime_timestamp", 2)),
            ],
            None,
        );

        let optimized = run_rule(join);

        assert_eq!(
            as_hash_join(&optimized).partition_mode(),
            &PartitionMode::Partitioned
        );
    }
}