Skip to main content

query/optimizer/
promql_tsid_narrow_join.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow_schema::{DataType, SchemaRef};
18use datafusion::config::ConfigOptions;
19use datafusion::physical_optimizer::PhysicalOptimizerRule;
20use datafusion::physical_plan::ExecutionPlan;
21use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
22use datafusion_common::Result as DfResult;
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_expr::JoinType;
25use datafusion_physical_expr::expressions::Column;
26use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
27
/// Physical optimizer rule that switches a narrow-build PromQL hash join from
/// `Partitioned` to `CollectLeft` mode.
///
/// A join is rewritten only when all of the following hold:
/// - it is an inner `HashJoinExec` in `Partitioned` mode without a join filter;
/// - the left (build) input exposes nothing but `greptime_value`, `__tsid`, and
///   timestamp columns, and has fewer columns than the right (probe) input;
/// - the join keys include a `__tsid` = `__tsid` pair and a timestamp = timestamp pair.
///
/// This is the shape of a PromQL vector-vector arithmetic join in which one side is
/// kept narrow (no raw labels) while the other carries every output label. In
/// `Partitioned` mode both inputs are repartitioned on the join keys, which shuffles
/// the wide stream; `CollectLeft` gathers only the narrow build side and leaves the
/// wide probe side on its existing partitioning.
#[derive(Debug)]
pub struct PromqlTsidNarrowJoin;
37
38impl PhysicalOptimizerRule for PromqlTsidNarrowJoin {
39    fn optimize(
40        &self,
41        plan: Arc<dyn ExecutionPlan>,
42        _config: &ConfigOptions,
43    ) -> DfResult<Arc<dyn ExecutionPlan>> {
44        plan.transform_up(Self::rewrite_join).data()
45    }
46
47    fn name(&self) -> &str {
48        "PromqlTsidNarrowJoin"
49    }
50
51    fn schema_check(&self) -> bool {
52        true
53    }
54}
55
56impl PromqlTsidNarrowJoin {
57    fn rewrite_join(plan: Arc<dyn ExecutionPlan>) -> DfResult<Transformed<Arc<dyn ExecutionPlan>>> {
58        let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() else {
59            return Ok(Transformed::no(plan));
60        };
61
62        if !Self::should_collect_left(hash_join) {
63            return Ok(Transformed::no(plan));
64        }
65
66        Ok(Transformed::yes(
67            hash_join
68                .builder()
69                .with_partition_mode(PartitionMode::CollectLeft)
70                .reset_state()
71                .build_exec()?,
72        ))
73    }
74
75    fn should_collect_left(hash_join: &HashJoinExec) -> bool {
76        hash_join.partition_mode() == &PartitionMode::Partitioned
77            && hash_join.join_type() == &JoinType::Inner
78            && hash_join.filter().is_none()
79            && hash_join.right().schema().fields().len() > hash_join.left().schema().fields().len()
80            && Self::is_promql_value_tsid_time_schema(&hash_join.left().schema())
81            && Self::joins_on_tsid_and_time(hash_join)
82    }
83
84    fn is_promql_value_tsid_time_schema(schema: &SchemaRef) -> bool {
85        let mut has_value = false;
86        let mut has_tsid = false;
87        let mut has_time = false;
88
89        for field in schema.fields() {
90            match field.name().as_str() {
91                "greptime_value" => has_value = true,
92                DATA_SCHEMA_TSID_COLUMN_NAME => has_tsid = true,
93                _ if matches!(field.data_type(), DataType::Timestamp(_, _)) => has_time = true,
94                _ => return false,
95            }
96        }
97
98        has_value && has_tsid && has_time
99    }
100
101    fn joins_on_tsid_and_time(hash_join: &HashJoinExec) -> bool {
102        let mut has_tsid = false;
103        let mut has_time = false;
104
105        for (left, right) in hash_join.on() {
106            let (Some(left_col), Some(right_col)) = (
107                left.as_any().downcast_ref::<Column>(),
108                right.as_any().downcast_ref::<Column>(),
109            ) else {
110                return false;
111            };
112
113            if left_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
114                && right_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
115            {
116                has_tsid = true;
117            } else if matches!(
118                hash_join
119                    .left()
120                    .schema()
121                    .field(left_col.index())
122                    .data_type(),
123                DataType::Timestamp(_, _)
124            ) && matches!(
125                hash_join
126                    .right()
127                    .schema()
128                    .field(right_col.index())
129                    .data_type(),
130                DataType::Timestamp(_, _)
131            ) {
132                has_time = true;
133            }
134        }
135
136        has_tsid && has_time
137    }
138}
139
#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use datafusion::common::NullEquality;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::displayable;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::physical_plan::joins::HashJoinExec;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_expr::PhysicalExpr;

    use super::*;

    /// Name of the timestamp column used by every test schema.
    const TIME_COLUMN: &str = "greptime_timestamp";

    type JoinKeys = Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>;

    /// Wraps `fields` in an `EmptyExec` so it can serve as a join input.
    fn empty_input(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(Arc::new(Schema::new(fields))))
    }

    fn value_field() -> Field {
        Field::new("greptime_value", DataType::Float64, true)
    }

    fn tsid_field() -> Field {
        Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false)
    }

    fn time_field() -> Field {
        Field::new(
            TIME_COLUMN,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        )
    }

    /// Input with value (0), tsid (1), and timestamp (2) columns only.
    fn narrow_input() -> Arc<dyn ExecutionPlan> {
        empty_input(vec![value_field(), tsid_field(), time_field()])
    }

    /// Input with value (0), `host` label (1), tsid (2), and timestamp (3) columns.
    fn wide_input() -> Arc<dyn ExecutionPlan> {
        empty_input(vec![
            value_field(),
            Field::new("host", DataType::Utf8, true),
            tsid_field(),
            time_field(),
        ])
    }

    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds the tsid and timestamp key pairs from `(tsid, time)` column indices of
    /// the left and right inputs.
    fn tsid_time_keys(left: (usize, usize), right: (usize, usize)) -> JoinKeys {
        vec![
            (
                column(DATA_SCHEMA_TSID_COLUMN_NAME, left.0),
                column(DATA_SCHEMA_TSID_COLUMN_NAME, right.0),
            ),
            (column(TIME_COLUMN, left.1), column(TIME_COLUMN, right.1)),
        ]
    }

    /// Creates a filter-less inner hash join in `Partitioned` mode.
    fn partitioned_inner_join(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        on: JoinKeys,
        projection: Option<Vec<usize>>,
    ) -> Arc<dyn ExecutionPlan> {
        let join = HashJoinExec::try_new(
            left,
            right,
            on,
            None,
            &JoinType::Inner,
            projection,
            PartitionMode::Partitioned,
            NullEquality::NullEqualsNull,
            false,
        )
        .unwrap();
        Arc::new(join)
    }

    /// Runs the rule under test with default configuration.
    fn run_rule(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        PromqlTsidNarrowJoin
            .optimize(plan, &ConfigOptions::default())
            .unwrap()
    }

    #[test]
    fn chooses_collect_left_for_narrow_promql_build_side() {
        let join = partitioned_inner_join(
            narrow_input(),
            wide_input(),
            tsid_time_keys((1, 2), (2, 3)),
            Some(vec![0, 3, 4, 5, 6]),
        );
        let schema_before = join.schema();

        let optimized = run_rule(join);
        let rewritten = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();

        assert_eq!(rewritten.partition_mode(), &PartitionMode::CollectLeft);
        assert_eq!(optimized.schema(), schema_before);

        // The embedded projection must survive the rebuild unchanged.
        let rendered = displayable(optimized.as_ref()).one_line().to_string();
        assert!(rendered.contains(
            "projection=[greptime_value@0, greptime_value@3, host@4, __tsid@5, greptime_timestamp@6]"
        ));
    }

    #[test]
    fn keeps_partitioned_join_when_left_side_carries_labels() {
        let join = partitioned_inner_join(
            wide_input(),
            narrow_input(),
            tsid_time_keys((2, 3), (1, 2)),
            None,
        );

        let optimized = run_rule(join);
        let untouched = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();

        assert_eq!(untouched.partition_mode(), &PartitionMode::Partitioned);
    }
}