mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-18 05:50:41 +00:00
feat: add a fallback parameter type inference by reading cast type (#7712)
* feat: add a fallback parameter type inference by reading cast * fix: typo * fix: lint and typo * refactor: make extract function private * refactor: fix_placeholder_types is no longer needed
This commit is contained in:
@@ -14,9 +14,11 @@
|
||||
|
||||
use std::any::Any;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::DataType;
|
||||
use async_trait::async_trait;
|
||||
use catalog::table_source::DfTableSourceProvider;
|
||||
use common_error::ext::BoxedError;
|
||||
@@ -25,6 +27,7 @@ use datafusion::common::{DFSchema, plan_err};
|
||||
use datafusion::execution::context::SessionState;
|
||||
use datafusion::sql::planner::PlannerContext;
|
||||
use datafusion_common::ToDFSchema;
|
||||
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
|
||||
use datafusion_expr::{
|
||||
Analyze, Explain, ExplainFormat, Expr as DfExpr, LogicalPlan, LogicalPlanBuilder, PlanType,
|
||||
ToStringifiedPlan, col,
|
||||
@@ -405,6 +408,89 @@ impl DfLogicalPlanner {
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts cast types for all placeholders in a logical plan.
|
||||
/// Returns a map where each placeholder ID is mapped to:
|
||||
/// - Some(DataType) if the placeholder is cast to a specific type
|
||||
/// - None if the placeholder exists but has no cast
|
||||
///
|
||||
/// Example: `$1::TEXT` returns `{"$1": Some(DataType::Utf8)}`
|
||||
///
|
||||
/// This function walks through all expressions in the logical plan,
|
||||
/// including subqueries, to identify placeholders and their cast types.
|
||||
fn extract_placeholder_cast_types(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<HashMap<String, Option<DataType>>> {
|
||||
let mut placeholder_types = HashMap::new();
|
||||
let mut casted_placeholders = HashSet::new();
|
||||
|
||||
plan.apply(|node| {
|
||||
for expr in node.expressions() {
|
||||
let _ = expr.apply(|e| {
|
||||
if let DfExpr::Cast(cast) = e
|
||||
&& let DfExpr::Placeholder(ph) = &*cast.expr
|
||||
{
|
||||
placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
|
||||
casted_placeholders.insert(ph.id.clone());
|
||||
}
|
||||
|
||||
if let DfExpr::Placeholder(ph) = e
|
||||
&& !casted_placeholders.contains(&ph.id)
|
||||
&& !placeholder_types.contains_key(&ph.id)
|
||||
{
|
||||
placeholder_types.insert(ph.id.clone(), None);
|
||||
}
|
||||
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
});
|
||||
}
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
})?;
|
||||
|
||||
Ok(placeholder_types)
|
||||
}
|
||||
|
||||
/// Gets inferred parameter types from a logical plan.
|
||||
/// Returns a map where each parameter ID is mapped to:
|
||||
/// - Some(DataType) if the parameter type could be inferred
|
||||
/// - None if the parameter type could not be inferred
|
||||
///
|
||||
/// This function first uses DataFusion's `get_parameter_types()` to infer types.
|
||||
/// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
|
||||
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
|
||||
///
|
||||
/// This is because datafusion can only infer types for a limited cases.
|
||||
///
|
||||
/// Example: For query `WHERE $1::TEXT AND $2`, DataFusion may not infer `$2`'s type,
|
||||
/// but this function will return `{"$1": Some(DataType::Utf8), "$2": None}`.
|
||||
pub fn get_inferred_parameter_types(
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<HashMap<String, Option<DataType>>> {
|
||||
let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
|
||||
|
||||
let has_none = param_types.values().any(|v| v.is_none());
|
||||
|
||||
if !has_none {
|
||||
Ok(param_types)
|
||||
} else {
|
||||
let cast_types = Self::extract_placeholder_cast_types(plan)?;
|
||||
|
||||
let mut merged = param_types;
|
||||
|
||||
for (id, opt_type) in cast_types {
|
||||
merged
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if existing.is_none() {
|
||||
*existing = opt_type.clone();
|
||||
}
|
||||
})
|
||||
.or_insert(opt_type);
|
||||
}
|
||||
|
||||
Ok(merged)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -453,3 +539,84 @@ impl LogicalPlanner for DfLogicalPlanner {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::DataType;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use session::context::QueryContext;
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::QueryEngineRef;
    use crate::parser::QueryLanguageParser;

    /// Builds a query engine holding one empty table named `test` with an int32
    /// `id` column and a string `name` column.
    async fn create_test_engine() -> QueryEngineRef {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
        ]));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![0])
            .value_indices(vec![1])
            .next_column_id(1024)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();

        crate::tests::new_query_engine_with_table(EmptyTable::from_table_info(&table_info))
    }

    /// Parses `sql` and plans it against the test engine, panicking on failure.
    async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let engine = create_test_engine().await;
        let planner = engine.planner();
        planner.plan(&stmt, QueryContext::arc()).await.unwrap()
    }

    #[tokio::test]
    async fn test_extract_placeholder_cast_types_multiple() {
        let plan = parse_sql_to_plan(
            "SELECT $1::INT, $2::TEXT, $3, $4::INTEGER FROM test WHERE $5::FLOAT > 0",
        )
        .await;
        let types = DfLogicalPlanner::extract_placeholder_cast_types(&plan).unwrap();

        // One entry per placeholder; `$3` is the only one without a cast.
        let expected = [
            ("$1", Some(DataType::Int32)),
            ("$2", Some(DataType::Utf8)),
            ("$3", None),
            ("$4", Some(DataType::Int32)),
            ("$5", Some(DataType::Float32)),
        ];

        assert_eq!(types.len(), 5);
        for (id, data_type) in &expected {
            assert_eq!(types.get(*id), Some(data_type));
        }
    }

    #[tokio::test]
    async fn test_get_inferred_parameter_types_fallback_for_udf_args() {
        // datafusion is not able to infer type for scalar function arguments
        let plan = parse_sql_to_plan(
            "SELECT parse_ident($1), parse_ident($2::TEXT) FROM test WHERE id > $3",
        )
        .await;
        let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();

        assert_eq!(types.len(), 3);

        // Indexing panics on a missing key, just like `get(..).unwrap()`.
        assert!(types["$1"].is_none(), "Expected $1 to be None");
        assert_eq!(types["$2"], Some(DataType::Utf8));
        assert_eq!(types["$3"], Some(DataType::Int32));
    }
}
|
||||
|
||||
@@ -445,6 +445,14 @@ pub enum Error {
|
||||
error: query::error::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to infer parameter types"))]
|
||||
InferParameterTypes {
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
#[snafu(source)]
|
||||
error: query::error::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("{}", reason))]
|
||||
UnexpectedResult {
|
||||
reason: String,
|
||||
@@ -721,6 +729,7 @@ impl ErrorExt for Error {
|
||||
| InvalidPromRemoteRequest { .. }
|
||||
| InvalidFlightTicket { .. }
|
||||
| InvalidPrepareStatement { .. }
|
||||
| InferParameterTypes { .. }
|
||||
| DataFrame { .. }
|
||||
| PreparedStmtTypeMismatch { .. }
|
||||
| TimePrecision { .. }
|
||||
|
||||
@@ -34,6 +34,7 @@ use opensrv_mysql::{
|
||||
StatementMetaWriter, ValueInner,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use query::planner::DfLogicalPlanner;
|
||||
use query::query_engine::DescribeResult;
|
||||
use rand::RngCore;
|
||||
use session::context::{Channel, QueryContextRef};
|
||||
@@ -45,10 +46,12 @@ use sql::statements::statement::Statement;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
use crate::SqlPlan;
|
||||
use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result};
|
||||
use crate::error::{
|
||||
self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
|
||||
};
|
||||
use crate::metrics::METRIC_AUTH_FAILURE;
|
||||
use crate::mysql::helper::{
|
||||
self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders,
|
||||
self, format_placeholder, replace_placeholders, transform_placeholders,
|
||||
};
|
||||
use crate::mysql::writer;
|
||||
use crate::mysql::writer::{create_mysql_column, handle_err};
|
||||
@@ -206,7 +209,7 @@ impl MysqlInstanceShim {
|
||||
let describe_result = self
|
||||
.do_describe(statement.clone(), query_ctx.clone())
|
||||
.await?;
|
||||
let (mut plan, schema) = if let Some(DescribeResult {
|
||||
let (plan, schema) = if let Some(DescribeResult {
|
||||
logical_plan,
|
||||
schema,
|
||||
}) = describe_result
|
||||
@@ -216,17 +219,13 @@ impl MysqlInstanceShim {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let params = if let Some(plan) = &mut plan {
|
||||
fix_placeholder_types(plan)?;
|
||||
debug!("Plan after fix placeholder types: {:#?}", plan);
|
||||
prepared_params(
|
||||
&plan
|
||||
.get_parameter_types()
|
||||
.context(DataFrameSnafu)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
.collect(),
|
||||
)?
|
||||
let params = if let Some(plan) = &plan {
|
||||
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
|
||||
.context(InferParameterTypesSnafu)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
.collect();
|
||||
prepared_params(¶m_types)?
|
||||
} else {
|
||||
dummy_params(param_num)?
|
||||
};
|
||||
@@ -293,11 +292,9 @@ impl MysqlInstanceShim {
|
||||
};
|
||||
|
||||
let outputs = match sql_plan.plan {
|
||||
Some(mut plan) => {
|
||||
fix_placeholder_types(&mut plan)?;
|
||||
let param_types = plan
|
||||
.get_parameter_types()
|
||||
.context(DataFrameSnafu)?
|
||||
Some(plan) => {
|
||||
let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
|
||||
.context(InferParameterTypesSnafu)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
@@ -13,16 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::ControlFlow;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use chrono::NaiveDate;
|
||||
use common_query::prelude::ScalarValue;
|
||||
use common_sql::convert::sql_value_to_value;
|
||||
use common_time::{Date, Timestamp};
|
||||
use datafusion_common::tree_node::{Transformed, TreeNode};
|
||||
use datafusion_expr::LogicalPlan;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::ColumnSchema;
|
||||
use datatypes::types::TimestampType;
|
||||
@@ -33,7 +29,7 @@ use snafu::ResultExt;
|
||||
use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::error::{self, DataFusionSnafu, Result};
|
||||
use crate::error::{self, Result};
|
||||
|
||||
/// Returns the placeholder string "$i".
|
||||
pub fn format_placeholder(i: usize) -> String {
|
||||
@@ -82,40 +78,6 @@ pub fn transform_placeholders(stmt: Statement) -> Statement {
|
||||
}
|
||||
}
|
||||
|
||||
/// Give placeholder that cast to certain type `data_type` the same data type as is cast to
|
||||
///
|
||||
/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types`
|
||||
pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
|
||||
let give_placeholder_types = |mut e: datafusion_expr::Expr| {
|
||||
if let datafusion_expr::Expr::Cast(cast) = &mut e {
|
||||
if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
|
||||
if ph.field.is_none() {
|
||||
ph.field = Some(Arc::new(Field::new("", cast.data_type.clone(), true)));
|
||||
common_telemetry::debug!(
|
||||
"give placeholder type {:?} to {:?}",
|
||||
cast.data_type,
|
||||
ph
|
||||
);
|
||||
Ok(Transformed::yes(e))
|
||||
} else {
|
||||
Ok(Transformed::no(e))
|
||||
}
|
||||
} else {
|
||||
Ok(Transformed::no(e))
|
||||
}
|
||||
} else {
|
||||
Ok(Transformed::no(e))
|
||||
}
|
||||
};
|
||||
let give_placeholder_types_recursively =
|
||||
|e: datafusion_expr::Expr| e.transform(give_placeholder_types);
|
||||
*plan = std::mem::take(plan)
|
||||
.transform(|p| p.map_expressions(give_placeholder_types_recursively))
|
||||
.context(DataFusionSnafu)?
|
||||
.data;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn visit_placeholders<V>(v: &mut V)
|
||||
where
|
||||
V: VisitMut,
|
||||
|
||||
@@ -34,6 +34,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement};
|
||||
use pgwire::api::{ClientInfo, ErrorHandler, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::PgWireBackendMessage;
|
||||
use query::planner::DfLogicalPlanner;
|
||||
use query::query_engine::DescribeResult;
|
||||
use session::Session;
|
||||
use session::context::QueryContextRef;
|
||||
@@ -43,7 +44,7 @@ use sql::parser::{ParseOptions, ParserContext};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::SqlPlan;
|
||||
use crate::error::{DataFusionSnafu, Result};
|
||||
use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
|
||||
use crate::postgres::types::*;
|
||||
use crate::postgres::utils::convert_err;
|
||||
use crate::postgres::{PostgresServerHandlerInner, fixtures};
|
||||
@@ -369,9 +370,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
|
||||
// client provided parameter types, can be empty if client doesn't try to parse statement
|
||||
let provided_param_types = &stmt.parameter_types;
|
||||
let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
|
||||
let param_types = plan
|
||||
.get_parameter_types()
|
||||
.context(DataFusionSnafu)
|
||||
let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
|
||||
.context(InferParameterTypesSnafu)
|
||||
.map_err(convert_err)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
|
||||
@@ -39,12 +39,13 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo};
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
use pgwire::messages::data::DataRow;
|
||||
use pgwire::types::format::FormatOptions as PgFormatOptions;
|
||||
use query::planner::DfLogicalPlanner;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ResultExt;
|
||||
|
||||
pub use self::error::{PgErrorCode, PgErrorSeverity};
|
||||
use crate::SqlPlan;
|
||||
use crate::error::{self as server_error, DataFusionSnafu, Result};
|
||||
use crate::error::{self as server_error, InferParameterTypesSnafu, Result};
|
||||
use crate::postgres::utils::convert_err;
|
||||
|
||||
pub(super) fn schema_to_pg(
|
||||
@@ -364,9 +365,8 @@ pub(super) fn parameters_to_scalar_values(
|
||||
let mut results = Vec::with_capacity(param_count);
|
||||
|
||||
let client_param_types = &portal.statement.parameter_types;
|
||||
let server_param_types = plan
|
||||
.get_parameter_types()
|
||||
.context(DataFusionSnafu)
|
||||
let server_param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
|
||||
.context(InferParameterTypesSnafu)
|
||||
.map_err(convert_err)?
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
|
||||
|
||||
Reference in New Issue
Block a user