diff --git a/src/common/function/src/aggrs/aggr_wrapper.rs b/src/common/function/src/aggrs/aggr_wrapper.rs index d76565b4dd..686482aa8a 100644 --- a/src/common/function/src/aggrs/aggr_wrapper.rs +++ b/src/common/function/src/aggrs/aggr_wrapper.rs @@ -324,6 +324,17 @@ impl AggregateUDFImpl for StateWrapper { is_distinct: false, }; let state_fields = self.inner.state_fields(state_fields_args)?; + + let state_fields = state_fields + .into_iter() + .map(|f| { + let mut f = f.as_ref().clone(); + // since state can be null when no input rows, so make all fields nullable + f.set_nullable(true); + Arc::new(f) + }) + .collect::>(); + let struct_field = DataType::Struct(state_fields.into()); Ok(struct_field) } @@ -388,6 +399,38 @@ impl Accumulator for StateAccum { .iter() .map(|s| s.to_array()) .collect::, _>>()?; + let array_type = array + .iter() + .map(|a| a.data_type().clone()) + .collect::>(); + let expected_type: Vec<_> = self + .state_fields + .iter() + .map(|f| f.data_type().clone()) + .collect(); + if array_type != expected_type { + debug!( + "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}", + self.state_fields.len(), + array.len(), + self.state_fields, + array_type, + ); + let guess_schema = array + .iter() + .enumerate() + .map(|(index, array)| { + Field::new( + format!("col_{index}[mismatch_state]").as_str(), + array.data_type().clone(), + true, + ) + }) + .collect::(); + let arr = StructArray::try_new(guess_schema, array, None)?; + + return Ok(ScalarValue::Struct(Arc::new(arr))); + } let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?; Ok(ScalarValue::Struct(Arc::new(struct_array))) } diff --git a/src/query/src/dist_plan/analyzer/test.rs b/src/query/src/dist_plan/analyzer/test.rs index 722e888f14..0e1947f48c 100644 --- a/src/query/src/dist_plan/analyzer/test.rs +++ b/src/query/src/dist_plan/analyzer/test.rs @@ -23,7 +23,7 @@ use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream}; use common_telemetry::init_default_ut_logging; use datafusion::datasource::DefaultTableSource; -use datafusion::functions_aggregate::expr_fn::avg; +use datafusion::functions_aggregate::expr_fn::{avg, last_value}; use datafusion::functions_aggregate::min_max::{max, min}; use datafusion_common::JoinType; use datafusion_expr::expr::ScalarFunction; @@ -1494,3 +1494,36 @@ fn date_bin_ts_group_by() { .join("\n"); assert_eq!(expected, result.to_string()); } + +/// check that `last_value(ts order by ts)` won't be push down +#[test] +fn test_not_push_down_aggr_order_by() { + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .aggregate( + Vec::::new(), + vec![last_value(col("ts"), vec![col("ts").sort(true, false)])], + ) + .unwrap() + .build() + .unwrap(); + + let config = ConfigOptions::default(); + let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap(); + + let expected = [ + "Aggregate: groupBy=[[]], aggr=[[last_value(t.ts) ORDER BY [t.ts ASC NULLS LAST]]]", + " Projection: t.pk1, t.pk2, t.pk3, t.ts, t.number", + " MergeScan [is_placeholder=false, remote_input=[", + "TableScan: t", + "]]", + ] + .join("\n"); + + assert_eq!(expected, result.to_string()); +} diff --git a/src/query/src/dist_plan/commutativity.rs b/src/query/src/dist_plan/commutativity.rs index 2fbeaf004f..008e19c583 100644 --- a/src/query/src/dist_plan/commutativity.rs +++ b/src/query/src/dist_plan/commutativity.rs @@ -79,8 +79,11 @@ pub fn step_aggr_to_upper_aggr( pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool { aggr_exprs.iter().all(|expr| { if let Some(aggr_func) = get_aggr_func(expr) { - if aggr_func.params.distinct { - // Distinct aggregate functions are not steppable(yet). + if aggr_func.params.distinct + || !aggr_func.params.order_by.is_empty() + || aggr_func.params.filter.is_some() + { + // Distinct aggregate functions/order by/filter in aggr args are not steppable(yet). return false; }