fix: do not step when aggr has order by/filter (#7015)

* fix: not applied

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* test: confirm order by is not pushed down

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
discord9
2025-09-25 16:43:18 +08:00
committed by GitHub
parent 06a4f0abea
commit 11a08d1381
3 changed files with 82 additions and 3 deletions
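
The "step" in the title refers to the distributed planner's two-step aggregation rewrite: a partial aggregate is pushed below the MergeScan onto the remote inputs, and a final aggregate merges the partial states above it. This commit excludes aggregate expressions that carry an ORDER BY or FILTER clause from that rewrite, the same way DISTINCT was already excluded, so such aggregates are evaluated in a single step over the merged input. A hedged sketch of the two expression shapes, built with the same expression builders the new test below uses:

use datafusion::functions_aggregate::expr_fn::{avg, last_value};
use datafusion_expr::col;

// Fine to step: an ordinary aggregate with no DISTINCT/ORDER BY/FILTER.
let steppable = avg(col("number"));
// Must not be stepped: the split rewrite does not handle the per-aggregate
// ORDER BY (yet), so this expression is kept whole above the MergeScan.
let not_steppable = last_value(col("ts"), vec![col("ts").sort(true, false)]);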

View File

@@ -324,6 +324,17 @@ impl AggregateUDFImpl for StateWrapper {
is_distinct: false,
};
let state_fields = self.inner.state_fields(state_fields_args)?;
let state_fields = state_fields
.into_iter()
.map(|f| {
let mut f = f.as_ref().clone();
// the state can be null when there are no input rows, so make all fields nullable
f.set_nullable(true);
Arc::new(f)
})
.collect::<Vec<_>>();
let struct_field = DataType::Struct(state_fields.into());
Ok(struct_field)
}
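
The hunk above makes every state field nullable before assembling the struct type: when the wrapped accumulator has seen no input rows, its serialized state can be null, and a non-nullable field would reject that. A minimal standalone sketch of the same transformation, with made-up field names and assuming arrow's `Field`/`DataType` API:

use std::sync::Arc;
use arrow::datatypes::{DataType, Field};

// Hypothetical state fields as an aggregate UDF might declare them.
let state_fields = vec![
    Arc::new(Field::new("sum", DataType::Float64, false)),
    Arc::new(Field::new("count", DataType::UInt64, false)),
];
// Make every field nullable: with no input rows the state can be null.
let state_fields = state_fields
    .into_iter()
    .map(|f| {
        let mut f = f.as_ref().clone();
        f.set_nullable(true);
        Arc::new(f)
    })
    .collect::<Vec<_>>();
// The state is then exposed as a single struct column of nullable fields.
let state_type = DataType::Struct(state_fields.into());
assert!(matches!(state_type, DataType::Struct(_)));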
@@ -388,6 +399,38 @@ impl Accumulator for StateAccum {
.iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>, _>>()?;
let array_type = array
.iter()
.map(|a| a.data_type().clone())
.collect::<Vec<_>>();
let expected_type: Vec<_> = self
.state_fields
.iter()
.map(|f| f.data_type().clone())
.collect();
if array_type != expected_type {
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
self.state_fields.len(),
array.len(),
self.state_fields,
array_type,
);
let guess_schema = array
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
let arr = StructArray::try_new(guess_schema, array, None)?;
return Ok(ScalarValue::Struct(Arc::new(arr)));
}
let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}
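
The guard above avoids a hard failure inside `StructArray::try_new` when the inner accumulator's state arrays do not line up with the declared state fields: the mismatch is logged and the arrays are wrapped in a struct whose synthesized field names flag the problem. A standalone sketch of that fallback naming, with a made-up array and assuming arrow's `StructArray`/`Fields` API:

use std::sync::Arc;
use arrow::array::{ArrayRef, Int64Array, StructArray};
use arrow::datatypes::{Field, Fields};

// Pretend the accumulator handed back one Int64 column where something else was declared.
let array: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![42_i64]))];
// Guess a schema from the arrays themselves, marking the fields so the
// mismatch is visible in the column names.
let guess_schema = array
    .iter()
    .enumerate()
    .map(|(index, array)| {
        Field::new(
            format!("col_{index}[mismatch_state]"),
            array.data_type().clone(),
            true,
        )
    })
    .collect::<Fields>();
let fallback = StructArray::try_new(guess_schema, array, None).unwrap();
assert_eq!(fallback.num_columns(), 1);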

View File

@@ -23,7 +23,7 @@ use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use common_telemetry::init_default_ut_logging;
use datafusion::datasource::DefaultTableSource;
use datafusion::functions_aggregate::expr_fn::avg;
use datafusion::functions_aggregate::expr_fn::{avg, last_value};
use datafusion::functions_aggregate::min_max::{max, min};
use datafusion_common::JoinType;
use datafusion_expr::expr::ScalarFunction;
@@ -1494,3 +1494,36 @@ fn date_bin_ts_group_by() {
.join("\n");
assert_eq!(expected, result.to_string());
}
/// Check that `last_value(ts ORDER BY ts)` won't be pushed down.
#[test]
fn test_not_push_down_aggr_order_by() {
init_default_ut_logging();
let test_table = TestTable::table_with_name(0, "t".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(test_table),
)));
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.aggregate(
Vec::<Expr>::new(),
vec![last_value(col("ts"), vec![col("ts").sort(true, false)])],
)
.unwrap()
.build()
.unwrap();
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = [
"Aggregate: groupBy=[[]], aggr=[[last_value(t.ts) ORDER BY [t.ts ASC NULLS LAST]]]",
" Projection: t.pk1, t.pk2, t.pk3, t.ts, t.number",
" MergeScan [is_placeholder=false, remote_input=[",
"TableScan: t",
"]]",
]
.join("\n");
assert_eq!(expected, result.to_string());
}
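
For contrast with the test above, a hedged sketch of the steppable case using the same scaffolding: a plain `avg(number)` carries no ORDER BY or FILTER, so the analyzer remains free to split it around the MergeScan. The exact split plan text is deliberately not asserted here, since it is not spelled out in this diff:

#[test]
fn test_push_down_plain_aggr_sketch() {
    init_default_ut_logging();
    let test_table = TestTable::table_with_name(0, "t".to_string());
    let table_source = Arc::new(DefaultTableSource::new(Arc::new(
        DfTableProviderAdapter::new(test_table),
    )));
    let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
        .unwrap()
        .aggregate(Vec::<Expr>::new(), vec![avg(col("number"))])
        .unwrap()
        .build()
        .unwrap();
    let config = ConfigOptions::default();
    // Only check that analysis succeeds; a steppable aggregate is eligible to
    // be rewritten into partial/final steps around the MergeScan.
    let _analyzed = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
}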

View File

@@ -79,8 +79,11 @@ pub fn step_aggr_to_upper_aggr(
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
aggr_exprs.iter().all(|expr| {
if let Some(aggr_func) = get_aggr_func(expr) {
if aggr_func.params.distinct {
// Distinct aggregate functions are not steppable(yet).
if aggr_func.params.distinct
|| !aggr_func.params.order_by.is_empty()
|| aggr_func.params.filter.is_some()
{
// Distinct aggregate functions, ORDER BY, and FILTER on aggregate args are not steppable (yet).
return false;
}
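
Given the signature shown above, a small hedged illustration of how the tightened check behaves; an aggregate carrying a FILTER clause would be rejected the same way, but only the ORDER BY case is constructed here because its builder already appears in the test file:

use datafusion::functions_aggregate::expr_fn::{avg, last_value};
use datafusion_expr::col;

// Ordinary aggregates are still steppable.
assert!(is_all_aggr_exprs_steppable(&[avg(col("number"))]));
// ORDER BY (like DISTINCT and FILTER) now disables stepping.
assert!(!is_all_aggr_exprs_steppable(&[last_value(
    col("ts"),
    vec![col("ts").sort(true, false)],
)]));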