resolve PR comments

Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
luofucong
2026-05-08 18:52:13 +08:00
parent ee135caeb9
commit 7deb711fb9
2 changed files with 93 additions and 89 deletions

View File

@@ -40,92 +40,111 @@ impl FunctionRewrite for JsonGetRewriter {
_schema: &DFSchema,
_config: &ConfigOptions,
) -> Result<Transformed<Expr>> {
let transform = match &expr {
Expr::Cast(cast) => rewrite_json_get_cast(cast),
Expr::ScalarFunction(scalar_func) => rewrite_arrow_cast_json_get(scalar_func),
_ => None,
};
Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
Ok(match expr {
Expr::Cast(cast) => inject_type_from_cast_expr(cast)?,
Expr::ScalarFunction(cast) => inject_type_from_cast_func(cast)?,
expr => Transformed::no(expr),
})
}
}
fn is_json_get_function_call(scalar_func: &ScalarFunction) -> bool {
scalar_func.func.name().to_ascii_lowercase() == JsonGetWithType::NAME
&& scalar_func.args.len() == 2
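// Rewrites a cast over a two-argument `json_get` call into a three-argument
// `json_get` call whose extra argument is a null literal of the cast's target type:
//
// Expr::Cast(
//     expr: Expr::ScalarFunction(json_get(column, path)),
//     data_type: <data_type>,
// )
// =>
// Expr::ScalarFunction(
//     json_get(column, path, <null literal of data_type>)
// )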
/// Folds the target type of a `CAST` into the `json_get` call it wraps.
///
/// When the cast's inner expression is a scalar function whose name matches
/// `JsonGetWithType::NAME` (ASCII case-insensitively) and which has exactly two
/// arguments, a null literal of the cast's data type is appended as a third
/// argument and the function call itself is returned as `Transformed::yes`.
/// Any other cast is rebuilt unchanged and returned as `Transformed::no`.
///
/// # Errors
///
/// Propagates the error from `ScalarValue::try_new_null` if a null value
/// cannot be created for the cast's data type.
fn inject_type_from_cast_expr(cast: Cast) -> Result<Transformed<Expr>> {
    let Cast {
        expr: inner,
        data_type: target_type,
    } = cast;

    match *inner {
        Expr::ScalarFunction(mut call)
            if call.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME)
                && call.args.len() == 2 =>
        {
            // The typed null carries only the requested type; its value is irrelevant.
            let typed_null = ScalarValue::try_new_null(&target_type)?;
            call.args.push(Expr::Literal(typed_null, None));
            Ok(Transformed::yes(Expr::ScalarFunction(call)))
        }
        // Not a two-argument `json_get`: hand the cast back untouched.
        other => Ok(Transformed::no(Expr::Cast(Cast {
            expr: Box::new(other),
            data_type: target_type,
        }))),
    }
}
fn rewrite_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
let scalar_func = extract_scalar_function(&cast.expr)?;
if is_json_get_function_call(scalar_func) {
let null_expr = Expr::Literal(ScalarValue::Null, None);
let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
expr: Box::new(null_expr),
data_type: cast.data_type.clone(),
});
// Expr::ScalarFunction(
// arrow_cast(
// Expr::ScalarFunction(
// json_get(column, path),
// ),
// <data_type>
// )
// )
// =>
// Expr::ScalarFunction(
// json_get(column, path, <data_type>)
// )
fn inject_type_from_cast_func(cast: ScalarFunction) -> Result<Transformed<Expr>> {
let ScalarFunction { func, args } = cast;
let mut args = scalar_func.args.clone();
args.push(null_cast);
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func: scalar_func.func.clone(),
args,
})))
} else {
None
}
}
// Handle Arrow cast function: cast(json_get(a, 'path'), 'Int64')
fn rewrite_arrow_cast_json_get(scalar_func: &ScalarFunction) -> Option<Transformed<Expr>> {
// Check if this is an Arrow cast function
// The function name might be "arrow_cast" or similar
let func_name = scalar_func.func.name().to_ascii_lowercase();
let func_name = func.name().to_ascii_lowercase();
if !func_name.contains("arrow_cast") {
return None;
let original = Expr::ScalarFunction(ScalarFunction { func, args });
return Ok(Transformed::no(original));
}
// Arrow cast function should have exactly 2 arguments:
// 1. The expression to cast (could be json_get)
// 2. The target type as a string literal
if scalar_func.args.len() != 2 {
return None;
if args.len() != 2 {
let original = Expr::ScalarFunction(ScalarFunction { func, args });
return Ok(Transformed::no(original));
}
let [arg0, arg1] = args.try_into().unwrap_or_else(|_| unreachable!());
// Extract the inner json_get function
let json_get_func = extract_scalar_function(&scalar_func.args[0])?;
// Check if it's a json_get function
if is_json_get_function_call(json_get_func) {
// Get the target type from the second argument
let target_type = extract_string_literal(&scalar_func.args[1])?;
let data_type = parse_data_type_from_string(&target_type)?;
// Create the null expression with the same type
let null_expr = Expr::Literal(ScalarValue::Null, None);
let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
expr: Box::new(null_expr),
data_type,
let Some(with_type) = arg1
.as_literal()
.and_then(|x| x.try_as_str())
.flatten()
.and_then(parse_data_type_from_string)
else {
let original = Expr::ScalarFunction(ScalarFunction {
func,
args: vec![arg0, arg1],
});
return Ok(Transformed::no(original));
};
// Create the new json_get_with_type function with the null parameter
let mut args = json_get_func.args.clone();
args.push(null_cast);
let mut json_get = match arg0 {
Expr::ScalarFunction(f)
if f.func.name().eq_ignore_ascii_case(JsonGetWithType::NAME) && f.args.len() == 2 =>
{
f
}
arg0 => {
let original = Expr::ScalarFunction(ScalarFunction {
func,
args: vec![arg0, arg1],
});
return Ok(Transformed::no(original));
}
};
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func: json_get_func.func.clone(),
args,
})))
} else {
None
}
}
let with_type = ScalarValue::try_new_null(&with_type).map(|x| Expr::Literal(x, None))?;
json_get.args.push(with_type);
// Extract string literal from an expression
fn extract_string_literal(expr: &Expr) -> Option<String> {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s.clone()),
_ => None,
}
let rewritten = Expr::ScalarFunction(json_get);
Ok(Transformed::yes(rewritten))
}
// Parse a data type from a string representation
@@ -149,13 +168,6 @@ fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
}
}
/// Returns the inner [`ScalarFunction`] when `expr` is an `Expr::ScalarFunction`,
/// and `None` for every other expression variant.
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
    if let Expr::ScalarFunction(scalar_func) = expr {
        Some(scalar_func)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
use arrow_schema::DataType;
@@ -221,12 +233,8 @@ mod tests {
// Third argument should be a typed null (Int8) that carries the target type
match &func.args[2] {
Expr::Cast(Cast { expr, data_type }) => {
assert_eq!(*data_type, DataType::Int8);
match expr.as_ref() {
Expr::Literal(ScalarValue::Null, _) => {}
_ => panic!("Third argument should be a null cast"),
}
Expr::Literal(value, _) => {
assert_eq!(value.data_type(), DataType::Int8);
}
_ => panic!("Third argument should be a cast expression"),
}
@@ -314,12 +322,8 @@ mod tests {
// Third argument should be a typed null (Int64) that carries the target type
match &func.args[2] {
Expr::Cast(Cast { expr, data_type }) => {
assert_eq!(*data_type, DataType::Int64);
match expr.as_ref() {
Expr::Literal(ScalarValue::Null, _) => {}
_ => panic!("Third argument should be a null cast"),
}
Expr::Literal(value, _) => {
assert_eq!(value.data_type(), DataType::Int64);
}
_ => panic!("Third argument should be a cast expression"),
}

View File

@@ -97,9 +97,9 @@ fn deduce_json_type(expr: &Expr) -> Result<Option<(String, JsonNativeType)>> {
let Some(Expr::Column(column)) = f.args.first() else {
return plan_err!(
"First argument of {} is expected to be a column expr, actual: {}",
"First argument of {} is expected to be a column expr, actual: {:?}",
JsonGetWithType::NAME,
f.args[0]
f.args.first()
);
};
@@ -111,9 +111,9 @@ fn deduce_json_type(expr: &Expr) -> Result<Option<(String, JsonNativeType)>> {
.flatten()
else {
return plan_err!(
"Second argument of {} is expected to be a string literal, actual: {}",
"Second argument of {} is expected to be a string literal, actual: {:?}",
JsonGetWithType::NAME,
f.args[1]
f.args.get(1)
);
};