Skip to main content

common_function/scalars/json/
json_get_rewriter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
/// Function rewrite rule that folds a cast wrapped around a two-argument
/// `json_get` call into a three-argument `json_get` call carrying the target
/// type as a typed NULL literal.
///
/// The type is stateless, so it is cheap to copy and can be built with
/// `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33    fn name(&self) -> &'static str {
34        "JsonGetRewriter"
35    }
36
37    fn rewrite(
38        &self,
39        expr: Expr,
40        _schema: &DFSchema,
41        _config: &ConfigOptions,
42    ) -> Result<Transformed<Expr>> {
43        Ok(match expr {
44            Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
45            Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
46            expr => Transformed::no(expr),
47        })
48    }
49}
50
51// Expr::Cast(
52//   Expr::ScalarFunction(
53//     json_get(column, path),
54//     <data_type>
55//   )
56// )
57// =>
58// Expr::ScalarFunction(
59//   json_get(column, path, <data_type>)
60// )
61fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
62    let Cast {
63        expr,
64        mut data_type,
65    } = cast;
66
67    let mut json_get = match *expr {
68        Expr::ScalarFunction(f)
69            if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
70        {
71            f
72        }
73        expr => {
74            return Ok(Transformed::no(Expr::Cast(Cast {
75                expr: Box::new(expr),
76                data_type,
77            })));
78        }
79    };
80
81    if data_type.is_string() {
82        data_type = DataType::Utf8View;
83    }
84    let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
85    json_get.args.push(with_type);
86    Ok(Transformed::yes(Expr::ScalarFunction(json_get)))
87}
88
89// Expr::ScalarFunction(
90//   arrow_cast(
91//     Expr::ScalarFunction(
92//       json_get(column, path),
93//     ),
94//     <data_type>
95//   )
96// )
97// =>
98// Expr::ScalarFunction(
99//   json_get(column, path, <data_type>)
100// )
101fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
102    let ScalarFunction { func, args } = cast;
103
104    // Check if this is an Arrow cast function
105    // The function name might be "arrow_cast" or similar
106    let func_name = func.name().to_ascii_lowercase();
107    if !func_name.contains("arrow_cast") {
108        let original = Expr::ScalarFunction(ScalarFunction { func, args });
109        return Ok(Transformed::no(original));
110    }
111
112    // Arrow cast function should have exactly 2 arguments:
113    // 1. The expression to cast (could be json_get)
114    // 2. The target type as a string literal
115    if args.len() != 2 {
116        let original = Expr::ScalarFunction(ScalarFunction { func, args });
117        return Ok(Transformed::no(original));
118    }
119    let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());
120
121    let Some(with_type) = arg1
122        .as_literal()
123        .and_then(|x| x.try_as_str())
124        .flatten()
125        .and_then(parse_data_type_from_string)
126    else {
127        let original = Expr::ScalarFunction(ScalarFunction {
128            func,
129            args: vec![arg0, arg1],
130        });
131        return Ok(Transformed::no(original));
132    };
133
134    let mut json_get = match arg0 {
135        Expr::ScalarFunction(f)
136            if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
137        {
138            f
139        }
140        arg0 => {
141            let original = Expr::ScalarFunction(ScalarFunction {
142                func,
143                args: vec![arg0, arg1],
144            });
145            return Ok(Transformed::no(original));
146        }
147    };
148
149    let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
150    json_get.args.push(with_type);
151
152    let rewritten = Expr::ScalarFunction(json_get);
153    Ok(Transformed::yes(rewritten))
154}
155
156// Parse a data type from a string representation
157fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
158    match type_str.to_lowercase().as_str() {
159        "int8" | "tinyint" => Some(DataType::Int8),
160        "int16" | "smallint" => Some(DataType::Int16),
161        "int32" | "integer" => Some(DataType::Int32),
162        "int64" | "bigint" => Some(DataType::Int64),
163        "uint8" => Some(DataType::UInt8),
164        "uint16" => Some(DataType::UInt16),
165        "uint32" => Some(DataType::UInt32),
166        "uint64" => Some(DataType::UInt64),
167        "float32" | "real" => Some(DataType::Float32),
168        "float64" | "double" => Some(DataType::Float64),
169        "boolean" | "bool" => Some(DataType::Boolean),
170        "string" | "text" | "varchar" => Some(DataType::Utf8),
171        "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
172        "date" => Some(DataType::Date32),
173        _ => None,
174    }
175}
176
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a non-null UTF-8 string literal expression.
    fn utf8_literal(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds an untyped, two-argument `json_get(json, path)` call.
    fn json_get_call(json: Expr, path: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args: vec![json, utf8_literal(path)],
        })
    }

    /// Builds a call to the test-only function, which is neither a cast nor
    /// `json_get`.
    fn test_and_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args,
        })
    }

    /// Runs `JsonGetRewriter` on `expr` with an empty schema and default
    /// configuration.
    fn apply_rewriter(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Asserts that `rewritten` is a `json_get` call whose third argument is a
    /// NULL literal of `expected_type`, and returns the first two arguments
    /// (JSON input and path) for further inspection.
    fn expect_typed_json_get(rewritten: Expr, expected_type: DataType) -> (Expr, Expr) {
        let Expr::ScalarFunction(func) = rewritten else {
            panic!("Result should be a ScalarFunction");
        };
        assert!(
            func.func
                .name()
                .eq_ignore_ascii_case(crate::scalars::json::JsonGetWithType::NAME),
            "Rewritten function should still be json_get"
        );

        // Original two arguments plus the injected type hint.
        assert_eq!(func.args.len(), 3);
        let mut args = func.args.into_iter();
        let json = args.next().unwrap();
        let path = args.next().unwrap();
        let type_hint = args.next().unwrap();

        match type_hint {
            Expr::Literal(value, _) => {
                assert!(value.is_null(), "Type hint should be a NULL literal");
                assert_eq!(value.data_type(), expected_type);
            }
            other => panic!("Third argument should be a typed NULL literal, got {other:?}"),
        }

        (json, path)
    }

    #[test]
    fn test_rewrite_regular_cast() {
        // json_get('{"a":1}', '$.a')::int8
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(utf8_literal("{\"a\":1}"), "$.a")),
            data_type: DataType::Int8,
        });

        let result = apply_rewriter(cast_expr);
        assert!(result.transformed);

        let (json, path) = expect_typed_json_get(result.data, DataType::Int8);
        assert_eq!(
            json,
            utf8_literal("{\"a\":1}"),
            "First argument should be the original JSON string literal"
        );
        assert_eq!(
            path,
            utf8_literal("$.a"),
            "Second argument should be the original path literal"
        );
    }

    // NOTE: despite its name, this test goes through `Expr::Cast`, not an
    // `arrow_cast` scalar function (no such function is constructed here). It
    // covers a cast whose `json_get` input is itself a nested function call.
    #[test]
    fn test_rewrite_arrow_cast_function() {
        let parse_json_expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args: vec![utf8_literal("{\"a\":1}")],
        });

        // CAST(json_get(parse_json('{"a":1}'), 'a') AS Int64)
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(parse_json_expr, "a")),
            data_type: DataType::Int64,
        });

        let result = apply_rewriter(cast_expr);
        assert!(result.transformed);

        let (json, path) = expect_typed_json_get(result.data, DataType::Int64);

        // The nested parse_json call must be carried over untouched.
        let Expr::ScalarFunction(parse_func) = json else {
            panic!("First argument should be a parse_json function");
        };
        assert!(
            parse_func
                .func
                .name()
                .to_ascii_lowercase()
                .contains("parse_json")
        );
        assert_eq!(
            parse_func.args,
            vec![utf8_literal("{\"a\":1}")],
            "Parse json argument should be the original string literal"
        );

        assert_eq!(
            path,
            utf8_literal("a"),
            "Second argument should be the original path literal"
        );
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        // A scalar function that is unrelated to json_get or casting.
        let other_func = test_and_call(vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)]);

        let result = apply_rewriter(other_func);
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        // Same argument shape as `arrow_cast(json_get(..), 'Int64')`, but the
        // outer function is not a cast, so nothing may be rewritten.
        let other_func = test_and_call(vec![
            json_get_call(utf8_literal("{\"a\":1}"), "$.a"),
            utf8_literal("Int64"),
        ]);

        let result = apply_rewriter(other_func);
        assert!(!result.transformed);
    }
}
393}