common_function/scalars/json/
json_get_rewriter.rs1#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
/// Function rewrite that pushes an explicit result type into `JsonGetWithType` calls.
///
/// A two-argument `JsonGetWithType` call that is wrapped in either a SQL `CAST`
/// or an `arrow_cast(..., '<type>')` call has the wrapper removed. The call gains a
/// third argument: a typed `NULL` literal carrying the requested type.
///
/// The rewriter is stateless, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33 fn name(&self) -> &'static str {
34 "JsonGetRewriter"
35 }
36
37 fn rewrite(
38 &self,
39 expr: Expr,
40 _schema: &DFSchema,
41 _config: &ConfigOptions,
42 ) -> Result<Transformed<Expr>> {
43 Ok(match expr {
44 Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
45 Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
46 expr => Transformed::no(expr),
47 })
48 }
49}
50
51fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
62 let Cast {
63 expr,
64 mut data_type,
65 } = cast;
66
67 let mut json_get = match *expr {
68 Expr::ScalarFunction(f)
69 if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
70 {
71 f
72 }
73 expr => {
74 return Ok(Transformed::no(Expr::Cast(Cast {
75 expr: Box::new(expr),
76 data_type,
77 })));
78 }
79 };
80
81 if data_type.is_string() {
82 data_type = DataType::Utf8View;
83 }
84 let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
85 json_get.args.push(with_type);
86 Ok(Transformed::yes(Expr::ScalarFunction(json_get)))
87}
88
89fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
102 let ScalarFunction { func, args } = cast;
103
104 let func_name = func.name().to_ascii_lowercase();
107 if !func_name.contains("arrow_cast") {
108 let original = Expr::ScalarFunction(ScalarFunction { func, args });
109 return Ok(Transformed::no(original));
110 }
111
112 if args.len() != 2 {
116 let original = Expr::ScalarFunction(ScalarFunction { func, args });
117 return Ok(Transformed::no(original));
118 }
119 let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());
120
121 let Some(with_type) = arg1
122 .as_literal()
123 .and_then(|x| x.try_as_str())
124 .flatten()
125 .and_then(parse_data_type_from_string)
126 else {
127 let original = Expr::ScalarFunction(ScalarFunction {
128 func,
129 args: vec![arg0, arg1],
130 });
131 return Ok(Transformed::no(original));
132 };
133
134 let mut json_get = match arg0 {
135 Expr::ScalarFunction(f)
136 if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
137 {
138 f
139 }
140 arg0 => {
141 let original = Expr::ScalarFunction(ScalarFunction {
142 func,
143 args: vec![arg0, arg1],
144 });
145 return Ok(Transformed::no(original));
146 }
147 };
148
149 let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
150 json_get.args.push(with_type);
151
152 let rewritten = Expr::ScalarFunction(json_get);
153 Ok(Transformed::yes(rewritten))
154}
155
156fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
158 match type_str.to_lowercase().as_str() {
159 "int8" | "tinyint" => Some(DataType::Int8),
160 "int16" | "smallint" => Some(DataType::Int16),
161 "int32" | "integer" => Some(DataType::Int32),
162 "int64" | "bigint" => Some(DataType::Int64),
163 "uint8" => Some(DataType::UInt8),
164 "uint16" => Some(DataType::UInt16),
165 "uint32" => Some(DataType::UInt32),
166 "uint64" => Some(DataType::UInt64),
167 "float32" | "real" => Some(DataType::Float32),
168 "float64" | "double" => Some(DataType::Float64),
169 "boolean" | "bool" => Some(DataType::Boolean),
170 "string" | "text" | "varchar" => Some(DataType::Utf8),
171 "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
172 "date" => Some(DataType::Date32),
173 _ => None,
174 }
175}
176
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a `Utf8` string literal.
    fn utf8_lit(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds an untyped (two-argument) `JsonGetWithType` call.
    fn json_get(json: Expr, path: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args: vec![json, utf8_lit(path)],
        })
    }

    /// Builds `arrow_cast(expr, '<type_name>')` using DataFusion's built-in UDF.
    fn arrow_cast(expr: Expr, type_name: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: datafusion::functions::core::arrow_cast(),
            args: vec![expr, utf8_lit(type_name)],
        })
    }

    /// Runs [`JsonGetRewriter`] with an empty schema and default config.
    fn run_rewriter(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Checks that `expr` is a three-argument call whose last argument is a
    /// type-hint literal of `expected`, and returns its arguments.
    fn expect_typed_json_get(expr: Expr, expected: DataType) -> Vec<Expr> {
        let Expr::ScalarFunction(func) = expr else {
            panic!("Result should be a ScalarFunction");
        };
        assert_eq!(func.args.len(), 3);
        match &func.args[2] {
            Expr::Literal(value, _) => assert_eq!(value.data_type(), expected),
            _ => panic!("Third argument should be a type-hint literal"),
        }
        func.args
    }

    #[test]
    fn test_rewrite_regular_cast() {
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get(utf8_lit("{\"a\":1}"), "$.a")),
            data_type: DataType::Int8,
        });

        let result = run_rewriter(cast_expr);
        assert!(result.transformed);

        let args = expect_typed_json_get(result.data, DataType::Int8);
        match &args[0] {
            Expr::Literal(ScalarValue::Utf8(Some(json)), _) => assert_eq!(json, "{\"a\":1}"),
            _ => panic!("First argument should be a string literal"),
        }
        match &args[1] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => assert_eq!(path, "$.a"),
            _ => panic!("Second argument should be a string literal"),
        }
    }

    #[test]
    fn test_rewrite_cast_over_parse_json() {
        let parse_json_expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args: vec![utf8_lit("{\"a\":1}")],
        });
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get(parse_json_expr, "a")),
            data_type: DataType::Int64,
        });

        let result = run_rewriter(cast_expr);
        assert!(result.transformed);

        let args = expect_typed_json_get(result.data, DataType::Int64);
        match &args[0] {
            Expr::ScalarFunction(parse_func) => {
                assert!(
                    parse_func
                        .func
                        .name()
                        .to_ascii_lowercase()
                        .contains("parse_json")
                );
                assert_eq!(parse_func.args.len(), 1);
                match &parse_func.args[0] {
                    Expr::Literal(ScalarValue::Utf8(Some(json)), _) => {
                        assert_eq!(json, "{\"a\":1}")
                    }
                    _ => panic!("Parse json argument should be a string literal"),
                }
            }
            _ => panic!("First argument should be a parse_json function"),
        }
        match &args[1] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => assert_eq!(path, "a"),
            _ => panic!("Second argument should be a string literal"),
        }
    }

    #[test]
    fn test_rewrite_cast_to_string_uses_utf8_view() {
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get(utf8_lit("{\"a\":\"x\"}"), "$.a")),
            data_type: DataType::Utf8,
        });

        let result = run_rewriter(cast_expr);
        assert!(result.transformed);
        expect_typed_json_get(result.data, DataType::Utf8View);
    }

    #[test]
    fn test_rewrite_arrow_cast_function() {
        let expr = arrow_cast(json_get(utf8_lit("{\"a\":1}"), "$.a"), "Int64");

        let result = run_rewriter(expr);
        assert!(result.transformed);

        let args = expect_typed_json_get(result.data, DataType::Int64);
        match &args[1] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => assert_eq!(path, "$.a"),
            _ => panic!("Second argument should be a string literal"),
        }
    }

    #[test]
    fn test_no_rewrite_for_arrow_cast_with_unknown_type() {
        let expr = arrow_cast(json_get(utf8_lit("{\"a\":1}"), "$.a"), "NotAType");
        assert!(!run_rewriter(expr).transformed);
    }

    #[test]
    fn test_no_rewrite_for_arrow_cast_over_other_expr() {
        let expr = arrow_cast(utf8_lit("1"), "Int64");
        assert!(!run_rewriter(expr).transformed);
    }

    #[test]
    fn test_no_rewrite_for_already_typed_json_get() {
        let first = run_rewriter(Expr::Cast(Cast {
            expr: Box::new(json_get(utf8_lit("{\"a\":1}"), "$.a")),
            data_type: DataType::Int8,
        }));
        assert!(first.transformed);

        // The rewritten call already has three arguments, so casting it again
        // must not inject a second type hint.
        let recast = Expr::Cast(Cast {
            expr: Box::new(first.data),
            data_type: DataType::Int8,
        });
        assert!(!run_rewriter(recast).transformed);
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)],
        });

        assert!(!run_rewriter(other_func).transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![
                json_get(utf8_lit("{\"a\":1}"), "$.a"),
                utf8_lit("Int64"),
            ],
        });

        assert!(!run_rewriter(other_func).transformed);
    }
}