Skip to main content

common_function/scalars/json/
json_get_rewriter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
/// Function rewrite rule that folds a cast wrapped around a two-argument `json_get` call into
/// the call itself, by appending a typed NULL literal as a third argument.
///
/// Two cast shapes are recognized: `Expr::Cast` and a scalar function call to `arrow_cast`.
#[derive(Debug)]
pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33    fn name(&self) -> &'static str {
34        "JsonGetRewriter"
35    }
36
37    fn rewrite(
38        &self,
39        expr: Expr,
40        _schema: &DFSchema,
41        _config: &ConfigOptions,
42    ) -> Result<Transformed<Expr>> {
43        Ok(match expr {
44            Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
45            Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
46            expr => Transformed::no(expr),
47        })
48    }
49}
50
51// Expr::Cast(
52//   Expr::ScalarFunction(
53//     json_get(column, path),
54//     <data_type>
55//   )
56// )
57// =>
58// Expr::ScalarFunction(
59//   json_get(column, path, <data_type>)
60// )
61fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
62    let Cast { expr, data_type } = cast;
63
64    let mut json_get = match *expr {
65        Expr::ScalarFunction(f)
66            if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
67        {
68            f
69        }
70        expr => {
71            return Ok(Transformed::no(Expr::Cast(Cast {
72                expr: Box::new(expr),
73                data_type,
74            })));
75        }
76    };
77
78    let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
79    json_get.args.push(with_type);
80    Ok(Transformed::yes(Expr::ScalarFunction(json_get)))
81}
82
83// Expr::ScalarFunction(
84//   arrow_cast(
85//     Expr::ScalarFunction(
86//       json_get(column, path),
87//     ),
88//     <data_type>
89//   )
90// )
91// =>
92// Expr::ScalarFunction(
93//   json_get(column, path, <data_type>)
94// )
95fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
96    let ScalarFunction { func, args } = cast;
97
98    // Check if this is an Arrow cast function
99    // The function name might be "arrow_cast" or similar
100    let func_name = func.name().to_ascii_lowercase();
101    if !func_name.contains("arrow_cast") {
102        let original = Expr::ScalarFunction(ScalarFunction { func, args });
103        return Ok(Transformed::no(original));
104    }
105
106    // Arrow cast function should have exactly 2 arguments:
107    // 1. The expression to cast (could be json_get)
108    // 2. The target type as a string literal
109    if args.len() != 2 {
110        let original = Expr::ScalarFunction(ScalarFunction { func, args });
111        return Ok(Transformed::no(original));
112    }
113    let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());
114
115    let Some(with_type) = arg1
116        .as_literal()
117        .and_then(|x| x.try_as_str())
118        .flatten()
119        .and_then(parse_data_type_from_string)
120    else {
121        let original = Expr::ScalarFunction(ScalarFunction {
122            func,
123            args: vec![arg0, arg1],
124        });
125        return Ok(Transformed::no(original));
126    };
127
128    let mut json_get = match arg0 {
129        Expr::ScalarFunction(f)
130            if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
131        {
132            f
133        }
134        arg0 => {
135            let original = Expr::ScalarFunction(ScalarFunction {
136                func,
137                args: vec![arg0, arg1],
138            });
139            return Ok(Transformed::no(original));
140        }
141    };
142
143    let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
144    json_get.args.push(with_type);
145
146    let rewritten = Expr::ScalarFunction(json_get);
147    Ok(Transformed::yes(rewritten))
148}
149
150// Parse a data type from a string representation
151fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
152    match type_str.to_lowercase().as_str() {
153        "int8" | "tinyint" => Some(DataType::Int8),
154        "int16" | "smallint" => Some(DataType::Int16),
155        "int32" | "integer" => Some(DataType::Int32),
156        "int64" | "bigint" => Some(DataType::Int64),
157        "uint8" => Some(DataType::UInt8),
158        "uint16" => Some(DataType::UInt16),
159        "uint32" => Some(DataType::UInt32),
160        "uint64" => Some(DataType::UInt64),
161        "float32" | "real" => Some(DataType::Float32),
162        "float64" | "double" => Some(DataType::Float64),
163        "boolean" | "bool" => Some(DataType::Boolean),
164        "string" | "text" | "varchar" => Some(DataType::Utf8),
165        "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
166        "date" => Some(DataType::Date32),
167        _ => None,
168    }
169}
170
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a non-null UTF-8 string literal expression.
    fn utf8_lit(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds a `json_get` call with the given arguments.
    fn json_get_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args,
        })
    }

    /// Builds a `parse_json` call with the given arguments.
    fn parse_json_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args,
        })
    }

    /// Builds a call to the unrelated test function used as a "not a cast" outer function.
    fn test_and_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args,
        })
    }

    /// Runs the rewriter on `expr` with an empty schema and default configuration.
    fn rewrite(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Asserts that `expr` is a NULL literal of type `expected`.
    fn assert_typed_null(expr: &Expr, expected: DataType) {
        match expr {
            Expr::Literal(value, _) => {
                assert!(value.is_null());
                assert_eq!(value.data_type(), expected);
            }
            _ => panic!("Third argument should be a typed NULL literal"),
        }
    }

    /// Asserts that `expr` is the UTF-8 string literal `expected`.
    fn assert_utf8_lit(expr: &Expr, expected: &str, what: &str) {
        match expr {
            Expr::Literal(ScalarValue::Utf8(Some(actual)), _) => assert_eq!(actual, expected),
            _ => panic!("{what} should be a string literal"),
        }
    }

    #[test]
    fn test_rewrite_regular_cast() {
        // json_get('{"a":1}', '$.a')::int8
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")])),
            data_type: DataType::Int8,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);

        let Expr::ScalarFunction(func) = result.data else {
            panic!("Result should be a ScalarFunction");
        };

        // The cast is gone; what remains is the json_get call itself.
        assert!(func.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME));

        // Original two arguments plus the typed NULL literal carrying the target type.
        assert_eq!(func.args.len(), 3);
        assert_utf8_lit(&func.args[0], "{\"a\":1}", "First argument");
        assert_utf8_lit(&func.args[1], "$.a", "Second argument");
        assert_typed_null(&func.args[2], DataType::Int8);
    }

    // NOTE: despite its name this test does not build an `arrow_cast` call; it exercises the
    // `Expr::Cast` path with a nested `parse_json(...)` as the first `json_get` argument.
    #[test]
    fn test_rewrite_arrow_cast_function() {
        // CAST(json_get(parse_json('{"a":1}'), 'a') AS Int64)
        let json_get_expr = json_get_call(vec![
            parse_json_call(vec![utf8_lit("{\"a\":1}")]),
            utf8_lit("a"),
        ]);
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_expr),
            data_type: DataType::Int64,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);

        let Expr::ScalarFunction(func) = result.data else {
            panic!("Result should be a ScalarFunction");
        };

        // Original two arguments plus the typed NULL literal carrying the target type.
        assert_eq!(func.args.len(), 3);

        // The nested parse_json call must be carried over untouched.
        match &func.args[0] {
            Expr::ScalarFunction(parse_func) => {
                assert!(
                    parse_func
                        .func
                        .name()
                        .to_ascii_lowercase()
                        .contains("parse_json")
                );
                assert_eq!(parse_func.args.len(), 1);
                assert_utf8_lit(&parse_func.args[0], "{\"a\":1}", "Parse json argument");
            }
            _ => panic!("First argument should be a parse_json function"),
        }

        assert_utf8_lit(&func.args[1], "a", "Second argument");
        assert_typed_null(&func.args[2], DataType::Int64);
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        // A scalar function that is neither a cast nor json_get.
        let other_func = test_and_call(vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)]);

        let result = rewrite(other_func);
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        // Same argument shape as `arrow_cast(json_get(..), 'Int64')`, but the outer function
        // is not a cast, so nothing may be rewritten.
        let other_func = test_and_call(vec![
            json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
            utf8_lit("Int64"),
        ]);

        let result = rewrite(other_func);
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_rewrite_for_cast_of_non_json_get() {
        // A cast whose operand is not json_get must come back unchanged.
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(utf8_lit("1")),
            data_type: DataType::Int8,
        });

        let result = rewrite(cast_expr.clone());
        assert!(!result.transformed);
        assert_eq!(result.data, cast_expr);
    }

    #[test]
    fn test_parse_data_type_from_string() {
        assert_eq!(parse_data_type_from_string("Int64"), Some(DataType::Int64));
        assert_eq!(parse_data_type_from_string("BIGINT"), Some(DataType::Int64));
        assert_eq!(parse_data_type_from_string("varchar"), Some(DataType::Utf8));
        assert_eq!(parse_data_type_from_string("no_such_type"), None);
    }
}
387}