Skip to main content

query/datafusion/
json_expr_planner.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::{Arc, LazyLock};
16
17use arrow_schema::Field;
18use arrow_schema::extension::ExtensionType;
19use common_function::scalars::json::json_get::JsonGetWithType;
20use common_function::scalars::udf::create_udf;
21use datafusion_common::arrow::datatypes::DataType;
22use datafusion_common::{Column, DFSchema, Result, ScalarValue, TableReference};
23use datafusion_expr::expr::{BinaryExpr, ScalarFunction};
24use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
25use datafusion_expr::{Expr, ExprSchemable, Operator, ScalarUDF};
26use datatypes::extension::json::JsonExtensionType;
27use either::Either;
28use sqlparser::ast::BinaryOperator;
29
/// Rewrites JSON-aware SQL expressions into DataFusion expressions.
///
/// This planner handles two cases:
/// - Rewrites compound identifiers on JSON extension columns into a `json_get` function call.
///   The second argument is a string literal made of the remaining identifier parts joined
///   with `.`.
///   For example, `select a.b.c` => `select json_get(a, 'b.c')`.
/// - Pushes an "expected type" argument into an untyped (two-argument) `json_get` call when it
///   is an operand of a supported arithmetic, comparison, logical or bitwise binary operator.
///   The hint is a typed `NULL` literal carrying the data type of the *other* operand. This
///   tells `json_get` the wanted data type when it deals with variant JSON values.
///   For example, `select json_get(a, 'b.c') + 1` => `select json_get(a, 'b.c', NULL::Int64) + 1`.
#[derive(Debug)]
pub(crate) struct JsonExprPlanner;
41
42impl ExprPlanner for JsonExprPlanner {
43    fn plan_binary_op(
44        &self,
45        expr: RawBinaryExpr,
46        schema: &DFSchema,
47    ) -> Result<PlannerResult<RawBinaryExpr>> {
48        let RawBinaryExpr {
49            op,
50            mut left,
51            mut right,
52        } = expr;
53
54        if extract_untyped_json_get(&mut left).is_none()
55            && extract_untyped_json_get(&mut right).is_none()
56        {
57            return Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }));
58        }
59
60        let Some(expr_op) = parse_sql_op(&op) else {
61            return Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }));
62        };
63
64        let left_type = left.get_type(schema)?;
65        let right_type = right.get_type(schema)?;
66        let left = push_json_get_type_arg(left, right_type)?;
67        let right = push_json_get_type_arg(right, left_type)?;
68        match (left, right) {
69            (Either::Left(left), Either::Left(right)) => {
70                Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
71            }
72            (left, right) => Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr::new(
73                Box::new(left.into_inner()),
74                expr_op,
75                Box::new(right.into_inner()),
76            )))),
77        }
78    }
79
80    fn plan_compound_identifier(
81        &self,
82        field: &Field,
83        qualifier: Option<&TableReference>,
84        nested_names: &[String],
85    ) -> Result<PlannerResult<Vec<Expr>>> {
86        if field.extension_type_name() != Some(JsonExtensionType::NAME) {
87            return Ok(PlannerResult::Original(Vec::new()));
88        }
89
90        static JSON_GET_UDF: LazyLock<Arc<ScalarUDF>> =
91            LazyLock::new(|| Arc::new(create_udf(Arc::new(JsonGetWithType::default()))));
92
93        let json_get = JSON_GET_UDF.clone();
94        let path = nested_names.join(".");
95        Ok(PlannerResult::Planned(Expr::ScalarFunction(
96            ScalarFunction::new_udf(
97                json_get,
98                vec![
99                    Expr::Column(Column::from((qualifier, field))),
100                    Expr::Literal(ScalarValue::Utf8(Some(path)), None),
101                ],
102            ),
103        )))
104    }
105}
106
107fn extract_untyped_json_get(expr: &mut Expr) -> Option<&mut ScalarFunction> {
108    match expr {
109        Expr::ScalarFunction(f)
110            if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
111        {
112            Some(f)
113        }
114        _ => None,
115    }
116}
117
118fn push_json_get_type_arg(mut expr: Expr, data_type: DataType) -> Result<Either<Expr, Expr>> {
119    let Some(json_get) = extract_untyped_json_get(&mut expr) else {
120        return Ok(Either::Left(expr));
121    };
122
123    let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
124    json_get.args.push(with_type);
125
126    Ok(Either::Right(expr))
127}
128
129fn parse_sql_op(op: &BinaryOperator) -> Option<Operator> {
130    match *op {
131        BinaryOperator::Plus => Some(Operator::Plus),
132        BinaryOperator::Minus => Some(Operator::Minus),
133        BinaryOperator::Multiply => Some(Operator::Multiply),
134        BinaryOperator::Divide => Some(Operator::Divide),
135        BinaryOperator::Modulo => Some(Operator::Modulo),
136        BinaryOperator::Gt => Some(Operator::Gt),
137        BinaryOperator::GtEq => Some(Operator::GtEq),
138        BinaryOperator::Lt => Some(Operator::Lt),
139        BinaryOperator::LtEq => Some(Operator::LtEq),
140        BinaryOperator::Eq => Some(Operator::Eq),
141        BinaryOperator::NotEq => Some(Operator::NotEq),
142        BinaryOperator::And => Some(Operator::And),
143        BinaryOperator::Or => Some(Operator::Or),
144        BinaryOperator::BitwiseAnd => Some(Operator::BitwiseAnd),
145        BinaryOperator::BitwiseOr => Some(Operator::BitwiseOr),
146        BinaryOperator::BitwiseXor => Some(Operator::BitwiseXor),
147        _ => None,
148    }
149}
150
#[cfg(test)]
mod tests {
    use arrow_schema::Fields;
    use datatypes::extension::json::JsonMetadata;

    use super::*;

    fn json_get_expr(base: Expr, path: &str) -> Expr {
        let json_get = Arc::new(create_udf(Arc::new(JsonGetWithType::default())));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            json_get,
            vec![
                base,
                Expr::Literal(ScalarValue::Utf8(Some(path.to_string())), None),
            ],
        ))
    }

    /// A binary literal holding the JSON document `{"a": 1}`.
    fn json_literal() -> Expr {
        Expr::Literal(ScalarValue::Binary(Some(b"{\"a\": 1}".to_vec())), None)
    }

    fn json_field(name: &str) -> Field {
        Field::new(name, DataType::Binary, true)
            .with_extension_type(JsonExtensionType::new(Arc::new(JsonMetadata::default())))
    }

    /// Schema with a single nullable `Int64` column named `value`.
    fn value_schema() -> Result<DFSchema> {
        DFSchema::from_unqualified_fields(
            Fields::from(vec![Field::new("value", DataType::Int64, true)]),
            Default::default(),
        )
    }

    fn value_column() -> Expr {
        Expr::Column(Column::new_unqualified("value"))
    }

    /// Asserts that `expr` is a `json_get` call with `expected_args` arguments and returns them.
    fn assert_json_get_args(expr: &Expr, expected_args: usize) -> &[Expr] {
        match expr {
            Expr::ScalarFunction(func) => {
                assert_eq!(func.func.name(), JsonGetWithType::NAME);
                assert_eq!(func.args.len(), expected_args);
                &func.args
            }
            other => panic!("expected json_get call, got {other:?}"),
        }
    }

    #[test]
    fn test_plan_binary_op() -> Result<()> {
        let planner = JsonExprPlanner;
        let schema = value_schema()?;

        // An untyped json_get on the left is hinted with the right operand's type.
        let planned = planner.plan_binary_op(
            RawBinaryExpr {
                op: BinaryOperator::Eq,
                left: json_get_expr(json_literal(), "a"),
                right: value_column(),
            },
            &schema,
        )?;

        match planned {
            PlannerResult::Planned(Expr::BinaryExpr(expr)) => {
                assert_eq!(expr.op, Operator::Eq);
                let args = assert_json_get_args(expr.left.as_ref(), 3);
                assert_eq!(args[2], Expr::Literal(ScalarValue::Int64(None), None));
                assert_eq!(expr.right.as_ref(), &value_column());
            }
            other => panic!("expected planned binary expression, got {other:?}"),
        }

        // An untyped json_get on the right is hinted with the left operand's type.
        let planned = planner.plan_binary_op(
            RawBinaryExpr {
                op: BinaryOperator::Lt,
                left: value_column(),
                right: json_get_expr(json_literal(), "a"),
            },
            &schema,
        )?;

        match planned {
            PlannerResult::Planned(Expr::BinaryExpr(expr)) => {
                assert_eq!(expr.op, Operator::Lt);
                assert_eq!(expr.left.as_ref(), &value_column());
                let args = assert_json_get_args(expr.right.as_ref(), 3);
                assert_eq!(args[2], Expr::Literal(ScalarValue::Int64(None), None));
            }
            other => panic!("expected planned binary expression, got {other:?}"),
        }

        Ok(())
    }

    #[test]
    fn test_plan_binary_op_keeps_original() -> Result<()> {
        let planner = JsonExprPlanner;
        let schema = value_schema()?;

        // A json_get operand combined with an unsupported operator must reach the operator
        // check and be returned untouched (still two arguments).
        let original = planner.plan_binary_op(
            RawBinaryExpr {
                op: BinaryOperator::StringConcat,
                left: json_get_expr(json_literal(), "a"),
                right: Expr::Literal(ScalarValue::Utf8(Some("x".to_string())), None),
            },
            &schema,
        )?;

        match original {
            PlannerResult::Original(expr) => {
                assert!(matches!(expr.op, BinaryOperator::StringConcat));
                assert_json_get_args(&expr.left, 2);
                assert_eq!(
                    expr.right,
                    Expr::Literal(ScalarValue::Utf8(Some("x".to_string())), None)
                );
            }
            other => {
                panic!("expected original expression for unsupported operator, got {other:?}")
            }
        }

        // Without any untyped json_get operand the expression is left to other planners.
        let original = planner.plan_binary_op(
            RawBinaryExpr {
                op: BinaryOperator::Eq,
                left: value_column(),
                right: Expr::Literal(ScalarValue::Int64(Some(1)), None),
            },
            &schema,
        )?;

        match original {
            PlannerResult::Original(expr) => {
                assert!(matches!(expr.op, BinaryOperator::Eq));
                assert_eq!(expr.left, value_column());
                assert_eq!(expr.right, Expr::Literal(ScalarValue::Int64(Some(1)), None));
            }
            other => panic!("expected original expression without json_get, got {other:?}"),
        }

        // A json_get that already carries a type hint is not rewritten again.
        let mut typed = json_get_expr(json_literal(), "a");
        if let Expr::ScalarFunction(func) = &mut typed {
            func.args
                .push(Expr::Literal(ScalarValue::Utf8(None), None));
        }
        let original = planner.plan_binary_op(
            RawBinaryExpr {
                op: BinaryOperator::Eq,
                left: typed,
                right: value_column(),
            },
            &schema,
        )?;

        match original {
            PlannerResult::Original(expr) => {
                let args = assert_json_get_args(&expr.left, 3);
                assert_eq!(args[2], Expr::Literal(ScalarValue::Utf8(None), None));
                assert_eq!(expr.right, value_column());
            }
            other => {
                panic!("expected original expression for an already-typed json_get, got {other:?}")
            }
        }

        Ok(())
    }

    #[test]
    fn test_plan_compound_identifier() -> Result<()> {
        let planner = JsonExprPlanner;
        let qualifier = TableReference::bare("events");
        let nested_names = vec!["payload".to_string(), "cpu".to_string()];

        let planned = planner.plan_compound_identifier(
            &json_field("labels"),
            Some(&qualifier),
            &nested_names,
        )?;

        match planned {
            PlannerResult::Planned(Expr::ScalarFunction(func)) => {
                assert_eq!(func.func.name(), JsonGetWithType::NAME);
                assert_eq!(func.args.len(), 2);
                assert_eq!(
                    func.args[0],
                    Expr::Column(Column::new(Some(qualifier.clone()), "labels"))
                );
                assert_eq!(
                    func.args[1],
                    Expr::Literal(ScalarValue::Utf8(Some("payload.cpu".to_string())), None)
                );
            }
            other => panic!("expected json_get scalar function, got {other:?}"),
        }

        let original = planner.plan_compound_identifier(
            &Field::new("plain", DataType::Utf8, true),
            Some(&qualifier),
            &nested_names,
        )?;

        match original {
            PlannerResult::Original(exprs) => assert!(exprs.is_empty()),
            other => panic!("expected original empty result for non-json field, got {other:?}"),
        }

        Ok(())
    }
}