// common_function/scalars/json/json_get_rewriter.rs
#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
/// Expression rewrite rule for `JsonGetWithType` calls.
///
/// When a two-argument `JsonGetWithType` call is wrapped in a cast (either an
/// `Expr::Cast` node or an `arrow_cast`-style scalar function), the wrapper is
/// dropped and the requested type is appended to the call as a third argument
/// in the form of a typed null literal.
#[derive(Debug)]
pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33 fn name(&self) -> &'static str {
34 "JsonGetRewriter"
35 }
36
37 fn rewrite(
38 &self,
39 expr: Expr,
40 _schema: &DFSchema,
41 _config: &ConfigOptions,
42 ) -> Result<Transformed<Expr>> {
43 Ok(match expr {
44 Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
45 Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
46 expr => Transformed::no(expr),
47 })
48 }
49}
50
51fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
62 let Cast { expr, data_type } = cast;
63
64 let mut json_get = match *expr {
65 Expr::ScalarFunction(f)
66 if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
67 {
68 f
69 }
70 expr => {
71 return Ok(Transformed::no(Expr::Cast(Cast {
72 expr: Box::new(expr),
73 data_type,
74 })));
75 }
76 };
77
78 let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?;
79 json_get.args.push(with_type);
80 Ok(Transformed::yes(Expr::ScalarFunction(json_get)))
81}
82
83fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
96 let ScalarFunction { func, args } = cast;
97
98 let func_name = func.name().to_ascii_lowercase();
101 if !func_name.contains("arrow_cast") {
102 let original = Expr::ScalarFunction(ScalarFunction { func, args });
103 return Ok(Transformed::no(original));
104 }
105
106 if args.len() != 2 {
110 let original = Expr::ScalarFunction(ScalarFunction { func, args });
111 return Ok(Transformed::no(original));
112 }
113 let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());
114
115 let Some(with_type) = arg1
116 .as_literal()
117 .and_then(|x| x.try_as_str())
118 .flatten()
119 .and_then(parse_data_type_from_string)
120 else {
121 let original = Expr::ScalarFunction(ScalarFunction {
122 func,
123 args: vec![arg0, arg1],
124 });
125 return Ok(Transformed::no(original));
126 };
127
128 let mut json_get = match arg0 {
129 Expr::ScalarFunction(f)
130 if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
131 {
132 f
133 }
134 arg0 => {
135 let original = Expr::ScalarFunction(ScalarFunction {
136 func,
137 args: vec![arg0, arg1],
138 });
139 return Ok(Transformed::no(original));
140 }
141 };
142
143 let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
144 json_get.args.push(with_type);
145
146 let rewritten = Expr::ScalarFunction(json_get);
147 Ok(Transformed::yes(rewritten))
148}
149
150fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
152 match type_str.to_lowercase().as_str() {
153 "int8" | "tinyint" => Some(DataType::Int8),
154 "int16" | "smallint" => Some(DataType::Int16),
155 "int32" | "integer" => Some(DataType::Int32),
156 "int64" | "bigint" => Some(DataType::Int64),
157 "uint8" => Some(DataType::UInt8),
158 "uint16" => Some(DataType::UInt16),
159 "uint32" => Some(DataType::UInt32),
160 "uint64" => Some(DataType::UInt64),
161 "float32" | "real" => Some(DataType::Float32),
162 "float64" | "double" => Some(DataType::Float64),
163 "boolean" | "bool" => Some(DataType::Boolean),
164 "string" | "text" | "varchar" => Some(DataType::Utf8),
165 "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
166 "date" => Some(DataType::Date32),
167 _ => None,
168 }
169}
170
/// Unit tests for `JsonGetRewriter`.
///
/// They cover both rewrite entry points (an `Expr::Cast` node and an
/// `arrow_cast(...)` scalar function call) as well as the cases in which the
/// expression must be left untouched.
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a non-null UTF-8 string literal expression.
    fn utf8_lit(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds a `JsonGetWithType` call over `args`.
    fn json_get_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args,
        })
    }

    /// Builds `parse_json('<json>')`.
    fn parse_json_call(json: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args: vec![utf8_lit(json)],
        })
    }

    /// Builds a call to the test-only `TestAndFunction` over `args`.
    fn unrelated_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args,
        })
    }

    /// Builds `arrow_cast(<value>, '<type_name>')` using DataFusion's own
    /// `arrow_cast` function.
    fn arrow_cast_call(value: Expr, type_name: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: datafusion::functions::core::arrow_cast(),
            args: vec![value, utf8_lit(type_name)],
        })
    }

    /// Runs the rewriter on `expr` with an empty schema and default options.
    fn rewrite(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Asserts that `expr` is the string literal `expected`.
    fn assert_utf8_lit(expr: &Expr, expected: &str, position: &str) {
        match expr {
            Expr::Literal(ScalarValue::Utf8(Some(actual)), _) => assert_eq!(actual, expected),
            _ => panic!("{position} argument should be a string literal"),
        }
    }

    /// Asserts that `expr` is a null literal of type `expected`.
    fn assert_type_hint(expr: &Expr, expected: DataType) {
        match expr {
            Expr::Literal(value, _) => {
                assert!(value.is_null());
                assert_eq!(value.data_type(), expected);
            }
            _ => panic!("Third argument should be a typed null literal"),
        }
    }

    /// Asserts that `expr` is `json_get(parse_json('{"a":1}'), 'a', <null of expected>)`.
    fn assert_nested_json_get(expr: Expr, expected: DataType) {
        let Expr::ScalarFunction(func) = expr else {
            panic!("Result should be a ScalarFunction");
        };
        assert_eq!(func.args.len(), 3);

        match &func.args[0] {
            Expr::ScalarFunction(parse_func) => {
                assert!(
                    parse_func
                        .func
                        .name()
                        .to_ascii_lowercase()
                        .contains("parse_json")
                );
                assert_eq!(parse_func.args.len(), 1);
                assert_utf8_lit(&parse_func.args[0], "{\"a\":1}", "Parse json");
            }
            _ => panic!("First argument should be a parse_json function"),
        }
        assert_utf8_lit(&func.args[1], "a", "Second");
        assert_type_hint(&func.args[2], expected);
    }

    #[test]
    fn test_rewrite_regular_cast() {
        // CAST(json_get('{"a":1}', '$.a') AS Int8)
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")])),
            data_type: DataType::Int8,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);

        match result.data {
            Expr::ScalarFunction(func) => {
                assert_eq!(func.args.len(), 3);
                assert_utf8_lit(&func.args[0], "{\"a\":1}", "First");
                assert_utf8_lit(&func.args[1], "$.a", "Second");
                assert_type_hint(&func.args[2], DataType::Int8);
            }
            _ => panic!("Result should be a ScalarFunction"),
        }
    }

    #[test]
    fn test_rewrite_cast_with_nested_function() {
        // CAST(json_get(parse_json('{"a":1}'), 'a') AS Int64)
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![
                parse_json_call("{\"a\":1}"),
                utf8_lit("a"),
            ])),
            data_type: DataType::Int64,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);
        assert_nested_json_get(result.data, DataType::Int64);
    }

    #[test]
    fn test_rewrite_arrow_cast_function() {
        // arrow_cast(json_get(parse_json('{"a":1}'), 'a'), 'Int64')
        let arrow_cast_expr = arrow_cast_call(
            json_get_call(vec![parse_json_call("{\"a\":1}"), utf8_lit("a")]),
            "Int64",
        );

        let result = rewrite(arrow_cast_expr);
        assert!(result.transformed);
        assert_nested_json_get(result.data, DataType::Int64);
    }

    #[test]
    fn test_no_rewrite_for_arrow_cast_with_unknown_type() {
        // The type string is not in the rewriter's lookup table, so the call
        // must be handed back unchanged.
        let arrow_cast_expr = arrow_cast_call(
            json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
            "Decimal128(10, 2)",
        );

        let result = rewrite(arrow_cast_expr);
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        let other_func = unrelated_call(vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)]);

        let result = rewrite(other_func);
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        // Same argument shape as arrow_cast(json_get(..), 'Int64'), but the
        // outer function is not a cast, so nothing may be rewritten.
        let other_func = unrelated_call(vec![
            json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
            utf8_lit("Int64"),
        ]);

        let result = rewrite(other_func);
        assert!(!result.transformed);
    }
}