mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-14 03:50:39 +00:00
fix: type inference for sql rewrite (#8052)
fix: type inference for rewrited sql
This commit is contained in:
@@ -26,8 +26,8 @@ use common_telemetry::tracing;
|
||||
use datafusion::common::{DFSchema, plan_err};
|
||||
use datafusion::execution::context::SessionState;
|
||||
use datafusion::sql::planner::PlannerContext;
|
||||
use datafusion_common::ToDFSchema;
|
||||
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
|
||||
use datafusion_common::{ScalarValue, ToDFSchema};
|
||||
use datafusion_expr::expr::{Exists, InSubquery};
|
||||
use datafusion_expr::{
|
||||
Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
|
||||
@@ -451,6 +451,19 @@ impl DfLogicalPlanner {
|
||||
casted_placeholders.insert(ph.id.clone());
|
||||
}
|
||||
|
||||
// Handle arrow_cast(Placeholder, 'type_string') generated by SQL rewriter
|
||||
if let DfExpr::ScalarFunction(scalar_func) = e
|
||||
&& scalar_func.name() == "arrow_cast"
|
||||
&& scalar_func.args.len() == 2
|
||||
&& let DfExpr::Placeholder(ph) = &scalar_func.args[0]
|
||||
&& let DfExpr::Literal(ScalarValue::Utf8(Some(type_str)), _) =
|
||||
&scalar_func.args[1]
|
||||
&& let Ok(data_type) = type_str.parse::<DataType>()
|
||||
{
|
||||
placeholder_types.insert(ph.id.clone(), Some(data_type));
|
||||
casted_placeholders.insert(ph.id.clone());
|
||||
}
|
||||
|
||||
// Handle bare (non-casted) placeholders
|
||||
if let DfExpr::Placeholder(ph) = e
|
||||
&& !casted_placeholders.contains(&ph.id)
|
||||
@@ -869,4 +882,25 @@ mod tests {
|
||||
assert_eq!(types.get("$3"), Some(&Some(DataType::Int32)));
|
||||
assert_eq!(types.get("$4"), Some(&Some(DataType::Utf8)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_inferred_parameter_types_arrow_cast() {
|
||||
let plan = parse_sql_to_plan("SELECT $1::INT64, $2::FLOAT64, $3::INT16, $4::INT32, $5::UINT8, $6::UINT16, $7::UINT32").await;
|
||||
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
|
||||
|
||||
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
|
||||
assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
|
||||
assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
|
||||
assert_eq!(types.get("$4"), Some(&Some(DataType::Int32)));
|
||||
assert_eq!(types.get("$5"), Some(&Some(DataType::UInt8)));
|
||||
assert_eq!(types.get("$6"), Some(&Some(DataType::UInt16)));
|
||||
assert_eq!(types.get("$7"), Some(&Some(DataType::UInt32)));
|
||||
|
||||
let plan = parse_sql_to_plan("SELECT $1::INT8, $2::FLOAT8, $3::INT2, $4::INT8").await;
|
||||
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
|
||||
|
||||
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
|
||||
assert_eq!(types.get("$2"), Some(&Some(DataType::Float64)));
|
||||
assert_eq!(types.get("$3"), Some(&Some(DataType::Int16)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,8 +398,8 @@ pub(super) fn parameters_to_scalar_values(
|
||||
return Err(invalid_parameter_error(
|
||||
"unknown_parameter_type",
|
||||
Some(format!(
|
||||
"Cannot get parameter type information for parameter {}",
|
||||
idx
|
||||
"Cannot get type for parameter {}, try to provide a type using ${}::<type>",
|
||||
idx, idx
|
||||
)),
|
||||
));
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user