diff --git a/src/common/function/src/scalars/json/json_get_rewriter.rs b/src/common/function/src/scalars/json/json_get_rewriter.rs index 69cea9d443..137b307412 100644 --- a/src/common/function/src/scalars/json/json_get_rewriter.rs +++ b/src/common/function/src/scalars/json/json_get_rewriter.rs @@ -40,92 +40,111 @@ impl FunctionRewrite for JsonGetRewriter { _schema: &DFSchema, _config: &ConfigOptions, ) -> Result> { - let transform = match &expr { - Expr::Cast(cast) => rewrite_json_get_cast(cast), - Expr::ScalarFunction(scalar_func) => rewrite_arrow_cast_json_get(scalar_func), - _ => None, - }; - Ok(transform.unwrap_or_else(|| Transformed::no(expr))) + Ok(match expr { + Expr::Cast(cast) => inject_type_from_cast_expr(cast)?, + Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?, + expr => Transformed::no(expr), + }) } } -fn is_json_get_function_call(scalar_func: &ScalarFunction) -> bool { - scalar_func.func.name().to_ascii_lowercase() == JsonGetWithType::NAME - && scalar_func.args.len() == 2 +// Expr::Cast( +// Expr::ScalarFunction( +// json_get(column, path), +// +// ) +// ) +// => +// Expr::ScalarFunction( +// json_get(column, path, ) +// ) +fn inject_type_from_cast_expr(cast: Cast) -> Result> { + let Cast { expr, data_type } = cast; + + let mut json_get = match *expr { + Expr::ScalarFunction(f) + if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 => + { + f + } + expr => { + return Ok(Transformed::no(Expr::Cast(Cast { + expr: Box::new(expr), + data_type, + }))); + } + }; + + let with_type = ScalarValue::try_new_null(&data_type).map(|x| Expr::Literal(x, None))?; + json_get.args.push(with_type); + Ok(Transformed::yes(Expr::ScalarFunction(json_get))) } -fn rewrite_json_get_cast(cast: &Cast) -> Option> { - let scalar_func = extract_scalar_function(&cast.expr)?; - if is_json_get_function_call(scalar_func) { - let null_expr = Expr::Literal(ScalarValue::Null, None); - let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast { - expr: Box::new(null_expr), - data_type: cast.data_type.clone(), - }); +// Expr::ScalarFunction( +// arrow_cast( +// Expr::ScalarFunction( +// json_get(column, path), +// ), +// +// ) +// ) +// => +// Expr::ScalarFunction( +// json_get(column, path, ) +// ) +fn inject_type_from_cast_func(cast: ScalarFunction) -> Result> { + let ScalarFunction { func, args } = cast; - let mut args = scalar_func.args.clone(); - args.push(null_cast); - - Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { - func: scalar_func.func.clone(), - args, - }))) - } else { - None - } -} - -// Handle Arrow cast function: cast(json_get(a, 'path'), 'Int64') -fn rewrite_arrow_cast_json_get(scalar_func: &ScalarFunction) -> Option> { // Check if this is an Arrow cast function // The function name might be "arrow_cast" or similar - let func_name = scalar_func.func.name().to_ascii_lowercase(); + let func_name = func.name().to_ascii_lowercase(); if !func_name.contains("arrow_cast") { - return None; + let original = Expr::ScalarFunction(ScalarFunction { func, args }); + return Ok(Transformed::no(original)); } // Arrow cast function should have exactly 2 arguments: // 1. The expression to cast (could be json_get) // 2. The target type as a string literal - if scalar_func.args.len() != 2 { - return None; + if args.len() != 2 { + let original = Expr::ScalarFunction(ScalarFunction { func, args }); + return Ok(Transformed::no(original)); } + let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!()); - // Extract the inner json_get function - let json_get_func = extract_scalar_function(&scalar_func.args[0])?; - - // Check if it's a json_get function - if is_json_get_function_call(json_get_func) { - // Get the target type from the second argument - let target_type = extract_string_literal(&scalar_func.args[1])?; - let data_type = parse_data_type_from_string(&target_type)?; - - // Create the null expression with the same type - let null_expr = Expr::Literal(ScalarValue::Null, None); - let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast { - expr: Box::new(null_expr), - data_type, + let Some(with_type) = arg1 + .as_literal() + .and_then(|x| x.try_as_str()) + .flatten() + .and_then(parse_data_type_from_string) + else { + let original = Expr::ScalarFunction(ScalarFunction { + func, + args: vec![arg0, arg1], }); + return Ok(Transformed::no(original)); + }; - // Create the new json_get_with_type function with the null parameter - let mut args = json_get_func.args.clone(); - args.push(null_cast); + let mut json_get = match arg0 { + Expr::ScalarFunction(f) + if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 => + { + f + } + arg0 => { + let original = Expr::ScalarFunction(ScalarFunction { + func, + args: vec![arg0, arg1], + }); + return Ok(Transformed::no(original)); + } + }; - Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { - func: json_get_func.func.clone(), - args, - }))) - } else { - None - } -} + let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?; + json_get.args.push(with_type); -// Extract string literal from an expression -fn extract_string_literal(expr: &Expr) -> Option { - match expr { - Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s.clone()), - _ => None, - } + let rewritten = Expr::ScalarFunction(json_get); + Ok(Transformed::yes(rewritten)) } // Parse a data type from a string representation @@ -149,13 +168,6 @@ fn parse_data_type_from_string(type_str: &str) -> Option { } } -fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { - match expr { - Expr::ScalarFunction(func) => Some(func), - _ => None, - } -} - #[cfg(test)] mod tests { use arrow_schema::DataType; @@ -221,12 +233,8 @@ mod tests { // Third argument should be a null cast to Int8 match &func.args[2] { - Expr::Cast(Cast { expr, data_type }) => { - assert_eq!(*data_type, DataType::Int8); - match expr.as_ref() { - Expr::Literal(ScalarValue::Null, _) => {} - _ => panic!("Third argument should be a null cast"), - } + Expr::Literal(value, _) => { + assert_eq!(value.data_type(), DataType::Int8); } _ => panic!("Third argument should be a cast expression"), } @@ -314,12 +322,8 @@ mod tests { // Third argument should be a null cast to Int64 match &func.args[2] { - Expr::Cast(Cast { expr, data_type }) => { - assert_eq!(*data_type, DataType::Int64); - match expr.as_ref() { - Expr::Literal(ScalarValue::Null, _) => {} - _ => panic!("Third argument should be a null cast"), - } + Expr::Literal(value, _) => { + assert_eq!(value.data_type(), DataType::Int64); } _ => panic!("Third argument should be a cast expression"), } diff --git a/src/query/src/optimizer/json_type_concretize.rs b/src/query/src/optimizer/json_type_concretize.rs index 61f6d7081f..9c7e764efe 100644 --- a/src/query/src/optimizer/json_type_concretize.rs +++ b/src/query/src/optimizer/json_type_concretize.rs @@ -97,9 +97,9 @@ fn deduce_json_type(expr: &Expr) -> Result> { let Some(Expr::Column(column)) = f.args.first() else { return plan_err!( - "First argument of {} is expected to be a column expr, actual: {}", + "First argument of {} is expected to be a column expr, actual: {:?}", JsonGetWithType::NAME, - f.args[0] + f.args.first() ); }; @@ -111,9 +111,9 @@ fn deduce_json_type(expr: &Expr) -> Result> { .flatten() else { return plan_err!( - "Second argument of {} is expected to be a string literal, actual: {}", + "Second argument of {} is expected to be a string literal, actual: {:?}", JsonGetWithType::NAME, - f.args[1] + f.args.get(1) ); };