From 1234911ed3a3f5dc4e9ea20a1458a9d14b2e8296 Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Mon, 1 Sep 2025 15:00:26 +0800 Subject: [PATCH] refactor: query config options (#6781) * feat: refactor columnar and vector conversion Signed-off-by: Dennis Zhuang * feat: initialize config options from query context Signed-off-by: Dennis Zhuang * fix: failure tests Signed-off-by: Dennis Zhuang * chore: revert ColumnarValue::try_from_vector Signed-off-by: Dennis Zhuang --------- Signed-off-by: Dennis Zhuang --- Cargo.lock | 1 + src/common/function/src/scalars/math/clamp.rs | 66 ++++++++++--------- src/common/function/src/scalars/udf.rs | 11 +--- src/datatypes/src/vectors/helper.rs | 5 +- src/flow/src/batching_mode/task.rs | 2 +- src/flow/src/batching_mode/utils.rs | 4 +- src/flow/src/df_optimizer.rs | 8 ++- src/flow/src/test_utils.rs | 4 +- src/flow/src/transform.rs | 6 +- src/operator/src/error.rs | 10 +++ src/operator/src/statement/admin.rs | 23 ++----- src/query/src/datafusion.rs | 3 + src/session/Cargo.toml | 1 + src/session/src/context.rs | 8 +++ .../common/function/arithmetic.result | 8 ++- .../standalone/common/promql/scalar.result | 12 +++- 16 files changed, 103 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f6092dcec..1d68761834 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11635,6 +11635,7 @@ dependencies = [ "common-session", "common-telemetry", "common-time", + "datafusion-common", "derive_builder 0.20.2", "derive_more", "snafu 0.8.6", diff --git a/src/common/function/src/scalars/math/clamp.rs b/src/common/function/src/scalars/math/clamp.rs index eb5f349d0f..db0330b546 100644 --- a/src/common/function/src/scalars/math/clamp.rs +++ b/src/common/function/src/scalars/math/clamp.rs @@ -34,6 +34,33 @@ pub struct ClampFunction; const CLAMP_NAME: &str = "clamp"; +/// Ensure the vector is constant and not empty (i.e., all values are identical) +fn ensure_constant_vector(vector: &VectorRef) -> Result<()> { + ensure!( + !vector.is_empty(), + InvalidFuncArgsSnafu { + err_msg: "Expect at least one value", + } + ); + + if vector.is_const() { + return Ok(()); + } + + let first = vector.get_ref(0); + for i in 1..vector.len() { + let v = vector.get_ref(i); + if first != v { + return InvalidFuncArgsSnafu { + err_msg: "All values in min/max argument must be identical", + } + .fail(); + } + } + + Ok(()) +} + impl Function for ClampFunction { fn name(&self) -> &str { CLAMP_NAME @@ -80,16 +107,9 @@ impl Function for ClampFunction { ), } ); - ensure!( - (columns[1].len() == 1 || columns[1].is_const()) - && (columns[2].len() == 1 || columns[2].is_const()), - InvalidFuncArgsSnafu { - err_msg: format!( - "The second and third args should be scalar, have: {:?}, {:?}", - columns[1], columns[2] - ), - } - ); + + ensure_constant_vector(&columns[1])?; + ensure_constant_vector(&columns[2])?; with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { let input_array = columns[0].to_arrow_array(); @@ -204,15 +224,8 @@ impl Function for ClampMinFunction { ), } ); - ensure!( - columns[1].len() == 1 || columns[1].is_const(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The second arg (min) should be scalar, have: {:?}", - columns[1] - ), - } - ); + + ensure_constant_vector(&columns[1])?; with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { let input_array = columns[0].to_arrow_array(); @@ -292,15 +305,8 @@ impl Function for ClampMaxFunction { ), } ); - ensure!( - columns[1].len() == 1 || columns[1].is_const(), - InvalidFuncArgsSnafu { - 
err_msg: format!( - "The second arg (max) should be scalar, have: {:?}", - columns[1] - ), - } - ); + + ensure_constant_vector(&columns[1])?; with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { let input_array = columns[0].to_arrow_array(); @@ -537,8 +543,8 @@ mod test { let func = ClampFunction; let args = [ Arc::new(Float64Vector::from(input)) as _, - Arc::new(Float64Vector::from_vec(vec![min, min])) as _, - Arc::new(Float64Vector::from_vec(vec![max])) as _, + Arc::new(Float64Vector::from_vec(vec![min, max])) as _, + Arc::new(Float64Vector::from_vec(vec![max, min])) as _, ]; let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); diff --git a/src/common/function/src/scalars/udf.rs b/src/common/function/src/scalars/udf.rs index 8f09882af1..54eaab846f 100644 --- a/src/common/function/src/scalars/udf.rs +++ b/src/common/function/src/scalars/udf.rs @@ -16,15 +16,12 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use common_query::error::FromScalarValueSnafu; use common_query::prelude::ColumnarValue; use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_expr::ScalarUDF; use datatypes::data_type::DataType; use datatypes::prelude::*; -use datatypes::vectors::Helper; use session::context::QueryContextRef; -use snafu::ResultExt; use crate::function::{FunctionContext, FunctionRef}; use crate::state::FunctionState; @@ -76,13 +73,7 @@ impl ScalarUDFImpl for ScalarUdf { let columns = args .args .iter() - .map(|x| { - ColumnarValue::try_from(x).and_then(|y| match y { - ColumnarValue::Vector(z) => Ok(z), - ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows) - .context(FromScalarValueSnafu), - }) - }) + .map(|x| ColumnarValue::try_from(x).and_then(|y| y.try_into_vector(args.number_rows))) .collect::>>()?; let v = self .function diff --git a/src/datatypes/src/vectors/helper.rs b/src/datatypes/src/vectors/helper.rs index b867d1a7b1..8b5bdb1dd5 100644 --- a/src/datatypes/src/vectors/helper.rs +++ b/src/datatypes/src/vectors/helper.rs @@ -128,6 +128,10 @@ impl Helper { ScalarValue::Boolean(v) => { ConstantVector::new(Arc::new(BooleanVector::from(vec![v])), length) } + ScalarValue::Float16(v) => ConstantVector::new( + Arc::new(Float32Vector::from(vec![v.map(f32::from)])), + length, + ), ScalarValue::Float32(v) => { ConstantVector::new(Arc::new(Float32Vector::from(vec![v])), length) } @@ -243,7 +247,6 @@ impl Helper { | ScalarValue::LargeList(_) | ScalarValue::Dictionary(_, _) | ScalarValue::Union(_, _, _) - | ScalarValue::Float16(_) | ScalarValue::Utf8View(_) | ScalarValue::BinaryView(_) | ScalarValue::Map(_) diff --git a/src/flow/src/batching_mode/task.rs b/src/flow/src/batching_mode/task.rs index eee9eb001f..cdd2f635e5 100644 --- a/src/flow/src/batching_mode/task.rs +++ b/src/flow/src/batching_mode/task.rs @@ -714,7 +714,7 @@ impl BatchingTask { })? 
.data; // only apply optimize after complex rewrite is done - let new_plan = apply_df_optimizer(rewrite).await?; + let new_plan = apply_df_optimizer(rewrite, &query_ctx).await?; let info = PlanInfo { plan: new_plan.clone(), diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index e0708336f1..05ab1177b0 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -122,13 +122,13 @@ pub async fn sql_to_df_plan( }; let plan = engine .planner() - .plan(&query_stmt, query_ctx) + .plan(&query_stmt, query_ctx.clone()) .await .map_err(BoxedError::new) .context(ExternalSnafu)?; let plan = if optimize { - apply_df_optimizer(plan).await? + apply_df_optimizer(plan, &query_ctx).await? } else { plan }; diff --git a/src/flow/src/df_optimizer.rs b/src/flow/src/df_optimizer.rs index c2eb7b246a..75a91b6070 100644 --- a/src/flow/src/df_optimizer.rs +++ b/src/flow/src/df_optimizer.rs @@ -44,6 +44,7 @@ use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule; use query::parser::QueryLanguageParser; use query::query_engine::DefaultSerializer; use query::QueryEngine; +use session::context::QueryContextRef; use snafu::ResultExt; /// note here we are using the `substrait_proto_df` crate from the `substrait` module and /// rename it to `substrait_proto` @@ -57,8 +58,9 @@ use crate::plan::TypedPlan; // TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule` is needed pub async fn apply_df_optimizer( plan: datafusion_expr::LogicalPlan, + query_ctx: &QueryContextRef, ) -> Result { - let cfg = ConfigOptions::new(); + let cfg = query_ctx.create_config_options(); let analyzer = Analyzer::with_rules(vec![ Arc::new(CountWildcardToTimeIndexRule), Arc::new(AvgExpandRule), @@ -107,12 +109,12 @@ pub async fn sql_to_flow_plan( .context(ExternalSnafu)?; let plan = engine .planner() - .plan(&stmt, query_ctx) + .plan(&stmt, query_ctx.clone()) .await .map_err(BoxedError::new) .context(ExternalSnafu)?; - let opted_plan = apply_df_optimizer(plan).await?; + let opted_plan = apply_df_optimizer(plan, &query_ctx).await?; // TODO(discord9): add df optimization let sub_plan = DFLogicalSubstraitConvertor {} diff --git a/src/flow/src/test_utils.rs b/src/flow/src/test_utils.rs index ecaabae32d..2b47d1595c 100644 --- a/src/flow/src/test_utils.rs +++ b/src/flow/src/test_utils.rs @@ -172,7 +172,9 @@ pub async fn sql_to_substrait(engine: Arc, sql: &str) -> proto: .plan(&stmt, QueryContext::arc()) .await .unwrap(); - let plan = apply_df_optimizer(plan).await.unwrap(); + let plan = apply_df_optimizer(plan, &QueryContext::arc()) + .await + .unwrap(); // encode then decode so to rely on the impl of conversion from logical plan to substrait plan let bytes = DFLogicalSubstraitConvertor {} diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index a3ecfcd5fe..d2c2ebf1a4 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -293,7 +293,9 @@ mod test { .plan(&stmt, QueryContext::arc()) .await .unwrap(); - let plan = apply_df_optimizer(plan).await.unwrap(); + let plan = apply_df_optimizer(plan, &QueryContext::arc()) + .await + .unwrap(); // encode then decode so to rely on the impl of conversion from logical plan to substrait plan let bytes = DFLogicalSubstraitConvertor {} @@ -315,7 +317,7 @@ mod test { .plan(&stmt, QueryContext::arc()) .await .unwrap(); - let plan = apply_df_optimizer(plan).await; + let plan = apply_df_optimizer(plan, &QueryContext::arc()).await; assert!(plan.is_err()); } diff --git a/src/operator/src/error.rs 
b/src/operator/src/error.rs index 5996a4a2b2..ddf370e7c7 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -19,6 +19,7 @@ use common_error::define_into_tonic_status; use common_error::ext::{BoxedError, ErrorExt}; use common_error::status_code::StatusCode; use common_macro::stack_trace_debug; +use common_query::error::Error as QueryResult; use datafusion::parquet; use datafusion_common::DataFusionError; use datatypes::arrow::error::ArrowError; @@ -36,6 +37,14 @@ pub enum Error { location: Location, }, + #[snafu(display("Failed to cast result: `{}`", source))] + Cast { + #[snafu(source)] + source: QueryResult, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("View already exists: `{name}`"))] ViewAlreadyExists { name: String, @@ -870,6 +879,7 @@ pub type Result = std::result::Result; impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { + Error::Cast { source, .. } => source.status_code(), Error::InvalidSql { .. } | Error::InvalidConfigValue { .. } | Error::InvalidInsertRequest { .. } diff --git a/src/operator/src/statement/admin.rs b/src/operator/src/statement/admin.rs index 9cfab774d3..908f84c05e 100644 --- a/src/operator/src/statement/admin.rs +++ b/src/operator/src/statement/admin.rs @@ -32,7 +32,7 @@ use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::{Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Value as SqlValue}; use sql::statements::admin::Admin; -use crate::error::{self, ExecuteAdminFunctionSnafu, IntoVectorsSnafu, Result}; +use crate::error::{self, CastSnafu, ExecuteAdminFunctionSnafu, Result}; use crate::statement::StatementExecutor; const DUMMY_COLUMN: &str = ""; @@ -118,7 +118,7 @@ impl StatementExecutor { .collect(), return_field: Arc::new(arrow::datatypes::Field::new("result", ret_type, true)), number_rows: if args.is_empty() { 1 } else { args[0].len() }, - config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + config_options: Arc::new(query_ctx.create_config_options()), }; // Execute the async UDF @@ -134,22 +134,11 @@ impl StatementExecutor { })?; // Convert result back to VectorRef - let result = match result_columnar { - datafusion_expr::ColumnarValue::Array(array) => { - datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)? - } - datafusion_expr::ColumnarValue::Scalar(scalar) => { - let array = - scalar - .to_array_of_size(1) - .with_context(|_| ExecuteAdminFunctionSnafu { - msg: format!("Failed to convert scalar to array for {}", fn_name), - })?; - datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)? 
- } - }; + let result_columnar: common_query::prelude::ColumnarValue = + (&result_columnar).try_into().context(CastSnafu)?; + + let result_vector: VectorRef = result_columnar.try_into_vector(1).context(CastSnafu)?; - let result_vector: VectorRef = result; let column_schemas = vec![ColumnSchema::new( // Use statement as the result column name stmt.to_string(), diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 2967e3a5e8..f7577368b0 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -537,6 +537,9 @@ impl QueryEngine for DatafusionQueryEngine { } } + // configure execution options + state.config_mut().options_mut().execution.time_zone = query_ctx.timezone().to_string(); + // usually it's impossible to have both `set variable` set by sql client and // hint in header by grpc client, so only need to deal with them separately if query_ctx.configuration_parameter().allow_query_fallback() { diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index abe38c7ac3..d6ee98650f 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -22,6 +22,7 @@ common-recordbatch.workspace = true common-session.workspace = true common-telemetry.workspace = true common-time.workspace = true +datafusion-common.workspace = true derive_builder.workspace = true derive_more = { version = "1", default-features = false, features = ["debug"] } snafu.workspace = true diff --git a/src/session/src/context.rs b/src/session/src/context.rs index fac891ac82..533d01a27b 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -28,6 +28,7 @@ use common_recordbatch::cursor::RecordBatchStreamCursor; use common_telemetry::warn; use common_time::timezone::parse_timezone; use common_time::Timezone; +use datafusion_common::config::ConfigOptions; use derive_builder::Builder; use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; @@ -221,6 +222,13 @@ impl QueryContext { Arc::new(QueryContextBuilder::default().build()) } + /// Create a new datafusion's ConfigOptions instance based on the current QueryContext. 
+ pub fn create_config_options(&self) -> ConfigOptions {
+ let mut config = ConfigOptions::default();
+ config.execution.time_zone = self.timezone().to_string();
+ config
+ }
+
 pub fn with(catalog: &str, schema: &str) -> QueryContext {
 QueryContextBuilder::default()
 .current_catalog(catalog.to_string())
diff --git a/tests/cases/standalone/common/function/arithmetic.result b/tests/cases/standalone/common/function/arithmetic.result
index 17e668b24d..91087bec17 100644
--- a/tests/cases/standalone/common/function/arithmetic.result
+++ b/tests/cases/standalone/common/function/arithmetic.result
@@ -76,7 +76,13 @@ SELECT CLAMP(0.5, 0, 1);

 SELECT CLAMP(10, 1, 0);

-Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: ConstantVector([Int64(1); 1]), ConstantVector([Int64(0); 1])
+Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray
+[
+  1,
+] }, PrimitiveVector { array: PrimitiveArray
+[
+  0,
+] }

 SELECT CLAMP_MIN(10, 12);

diff --git a/tests/cases/standalone/common/promql/scalar.result b/tests/cases/standalone/common/promql/scalar.result
index f581a976b0..12a4481ad0 100644
--- a/tests/cases/standalone/common/promql/scalar.result
+++ b/tests/cases/standalone/common/promql/scalar.result
@@ -375,7 +375,17 @@ TQL EVAL (0, 15, '5s') clamp(host, 6 - 6, 6 + 6);

 -- SQLNESS SORT_RESULT 3 1
 TQL EVAL (0, 15, '5s') clamp(host, 12, 0);

-Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: ConstantVector([Float64(12.0); 3]), ConstantVector([Float64(0.0); 3])
+Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray
+[
+  12.0,
+  12.0,
+  12.0,
+] }, PrimitiveVector { array: PrimitiveArray
+[
+  0.0,
+  0.0,
+  0.0,
+] }

 -- SQLNESS SORT_RESULT 3 1
 TQL EVAL (0, 15, '5s') clamp(host{host="host1"}, -1, 6);
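
A few standalone sketches of the patterns this patch relies on follow; they use plain DataFusion/Arrow types rather than GreptimeDB's internals, so treat them as illustrations, not drop-in code.

The relaxed check in `clamp.rs` (`ensure_constant_vector`) accepts any non-empty min/max argument whose rows are all identical, instead of requiring a length-1 or `ConstantVector` input. A sketch of the same idea over Arrow arrays; `ensure_all_identical` and the sample values are made up for illustration:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::common::{DataFusionError, Result, ScalarValue};

/// Reject empty input or input whose rows are not all identical,
/// mirroring the intent of `ensure_constant_vector` in the patch.
fn ensure_all_identical(array: &ArrayRef) -> Result<()> {
    if array.is_empty() {
        return Err(DataFusionError::Plan("expect at least one value".into()));
    }
    let first = ScalarValue::try_from_array(array, 0)?;
    for i in 1..array.len() {
        if ScalarValue::try_from_array(array, i)? != first {
            return Err(DataFusionError::Plan(
                "all values in min/max argument must be identical".into(),
            ));
        }
    }
    Ok(())
}

fn main() -> Result<()> {
    // A plain (non-constant) vector with identical rows is now accepted.
    let ok: ArrayRef = Arc::new(Float64Array::from(vec![3.0, 3.0, 3.0]));
    assert!(ensure_all_identical(&ok).is_ok());

    // Mixed values are still rejected.
    let bad: ArrayRef = Arc::new(Float64Array::from(vec![3.0, 4.0]));
    assert!(ensure_all_identical(&bad).is_err());
    Ok(())
}
```

On a real `ConstantVector` the patched code short-circuits via `is_const()`; the sketch simply scans every row.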
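
`udf.rs` and `admin.rs` now convert DataFusion results with `ColumnarValue::try_into_vector(number_rows)` instead of matching on array vs. scalar by hand. DataFusion's own `ColumnarValue::into_array` does the analogous job at the Arrow level and shows why the row count has to be threaded through: a scalar argument must be materialized to that many rows. A small sketch with arbitrary example values:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::ColumnarValue;

fn main() -> Result<()> {
    let number_rows = 3;

    // An array-backed argument keeps its own length.
    let array_arg: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let materialized = ColumnarValue::Array(array_arg).into_array(number_rows)?;
    assert_eq!(materialized.len(), 3);

    // A scalar argument is expanded to `number_rows` identical values,
    // which is what the UDF body expects to receive.
    let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(42)));
    let materialized = scalar_arg.into_array(number_rows)?;
    assert_eq!(materialized.len(), 3);

    Ok(())
}
```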
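
`helper.rs` now handles `ScalarValue::Float16` instead of rejecting it, by widening the half-precision payload to `f32` and wrapping it in a constant vector. A self-contained sketch of that widening with the `half` crate and Arrow; it assumes the `half` version matches the one Arrow itself depends on, and `length` stands in for the constant-vector length:

```rust
use datafusion::arrow::array::Float32Array;
use datafusion::common::ScalarValue;
use half::f16;

fn main() {
    let length = 4;
    let scalar = ScalarValue::Float16(Some(f16::from_f32(1.5)));

    if let ScalarValue::Float16(v) = scalar {
        // Widen the f16 payload to f32, then repeat it `length` times,
        // the same shape the ConstantVector takes in the patch.
        let widened: Float32Array =
            std::iter::repeat(v.map(f32::from)).take(length).collect();
        assert_eq!(widened.len(), length);
        assert_eq!(widened.value(0), 1.5);
    }
}
```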
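
`apply_df_optimizer` now takes the query context so the flow analyzer runs with session-derived `ConfigOptions` (in particular the time zone) rather than `ConfigOptions::new()`. A sketch of that wiring against stock DataFusion, using the default analyzer rules and a trivial plan in place of the flow's rewritten plan; `"+08:00"` is an assumed example time zone:

```rust
use datafusion::common::config::ConfigOptions;
use datafusion::logical_expr::LogicalPlanBuilder;
use datafusion::optimizer::analyzer::Analyzer;

fn main() -> datafusion::common::Result<()> {
    // What QueryContext::create_config_options builds: the defaults plus
    // the session time zone taken from the query context.
    let mut cfg = ConfigOptions::default();
    cfg.execution.time_zone = "+08:00".to_string();

    // A trivial plan; the flow task passes its rewritten query plan here.
    let plan = LogicalPlanBuilder::empty(false).build()?;

    // Run analysis with the session-aware options, as apply_df_optimizer does.
    let analyzed = Analyzer::new().execute_and_check(plan, &cfg, |_, _| {})?;
    println!("{}", analyzed.display_indent());
    Ok(())
}
```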
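
The engine-side counterpart in `query/src/datafusion.rs` copies the session time zone into the per-query `SessionState` options. A sketch with a plain `SessionContext`, assuming the DataFusion version in this tree where `execution.time_zone` is a `String`:

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

fn main() {
    // Stand-in for query_ctx.timezone().to_string().
    let session_tz = "+08:00".to_string();

    // Mirrors: state.config_mut().options_mut().execution.time_zone = ...
    let mut config = SessionConfig::new();
    config.options_mut().execution.time_zone = session_tz;

    let ctx = SessionContext::new_with_config(config);
    assert_eq!(ctx.state().config().options().execution.time_zone, "+08:00");
}
```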
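
The new `Error::Cast` variant in `operator` wraps the `common_query` conversion error and forwards its status code. A generic snafu sketch of the same wrap-and-context pattern; the `ParseIntError` source and the `cast` helper are stand-ins, not the real types:

```rust
use snafu::{ResultExt, Snafu};

#[derive(Debug, Snafu)]
enum Error {
    #[snafu(display("Failed to cast result: {source}"))]
    Cast { source: std::num::ParseIntError },
}

// Callers wrap the fallible conversion with the generated context selector,
// just like `.context(CastSnafu)` in admin.rs.
fn cast(input: &str) -> Result<i64, Error> {
    input.parse::<i64>().context(CastSnafu)
}

fn main() {
    assert_eq!(cast("42").unwrap(), 42);
    assert!(cast("not a number").is_err());
}
```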