test: update unit test by passing extra sort columns (#7030)

* tests: fix unit test by passing one sort columns

Signed-off-by: discord9 <discord9@163.com>

* chore: per copilot

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2025-09-28 11:22:43 +08:00
committed by GitHub
parent 0717773f62
commit 8bcf4a8ab5

View File

@@ -17,7 +17,9 @@ use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array};
use arrow::array::{
ArrayRef, BooleanArray, Float64Array, Int64Array, TimestampMillisecondArray, UInt64Array,
};
use arrow::record_batch::RecordBatch;
use arrow_schema::SchemaRef;
use common_telemetry::init_default_ut_logging;
@@ -164,6 +166,20 @@ impl DummyTableProvider {
record_batch: Mutex::new(record_batch),
}
}
pub fn with_ts(record_batch: Option<RecordBatch>) -> Self {
Self {
schema: Arc::new(arrow_schema::Schema::new(vec![
Field::new("number", DataType::Int64, true),
Field::new(
"ts",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
false,
),
])),
record_batch: Mutex::new(record_batch),
}
}
}
impl Default for DummyTableProvider {
@@ -226,6 +242,21 @@ fn dummy_table_scan() -> LogicalPlan {
)
}
fn dummy_table_scan_with_ts() -> LogicalPlan {
let table_provider = Arc::new(DummyTableProvider::with_ts(None));
let table_source = DefaultTableSource::new(table_provider);
LogicalPlan::TableScan(
TableScan::try_new(
TableReference::bare("Number"),
Arc::new(table_source),
None,
vec![],
None,
)
.unwrap(),
)
}
#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();
@@ -556,15 +587,15 @@ async fn test_last_value_order_by_udaf() {
let last_value = (*last_value).clone();
let original_aggr = Aggregate::try_new(
Arc::new(dummy_table_scan()),
Arc::new(dummy_table_scan_with_ts()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(last_value.clone()),
vec![Expr::Column(Column::new_unqualified("number"))],
vec![Expr::Column(Column::new_unqualified("ts"))],
false,
None,
vec![datafusion_expr::expr::Sort::new(
Expr::Column(Column::new_unqualified("number")),
Expr::Column(Column::new_unqualified("ts")),
true,
true,
)],
@@ -579,15 +610,15 @@ async fn test_last_value_order_by_udaf() {
let expected_aggr_state_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(dummy_table_scan()),
Arc::new(dummy_table_scan_with_ts()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
state_func,
vec![Expr::Column(Column::new_unqualified("number"))],
vec![Expr::Column(Column::new_unqualified("ts"))],
false,
None,
vec![datafusion_expr::expr::Sort::new(
Expr::Column(Column::new_unqualified("number")),
Expr::Column(Column::new_unqualified("ts")),
true,
true,
)],
@@ -607,11 +638,19 @@ async fn test_last_value_order_by_udaf() {
assert_eq!(
res.lower_state.schema().as_arrow(),
&arrow_schema::Schema::new(vec![Field::new(
"__last_value_state(number) ORDER BY [number ASC NULLS FIRST]",
"__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]",
DataType::Struct(
vec![
Field::new("last_value[last_value]", DataType::Int64, true),
Field::new("number", DataType::Int64, true), // ordering field is added to state fields too
Field::new(
"last_value[last_value]",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
true
),
Field::new(
"ts",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
true
), // ordering field is added to state fields too
Field::new("is_set", DataType::Boolean, true)
]
.into()
@@ -620,46 +659,6 @@ async fn test_last_value_order_by_udaf() {
)])
);
let expected_merge_fn = MergeWrapper::new(
last_value.clone(),
Arc::new(
AggregateExprBuilder::new(
Arc::new(last_value.clone()),
vec![Arc::new(
datafusion::physical_expr::expressions::Column::new("number", 0),
)],
)
.schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
.alias("last_value(number) ORDER BY [number ASC NULLS FIRST]")
.build()
.unwrap(),
),
vec![DataType::Int64],
)
.unwrap();
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(fixed_aggr_state_plan.clone()),
vec![],
vec![
Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(expected_merge_fn.into()),
vec![Expr::Column(Column::new_unqualified(
"__last_value_state(number) ORDER BY [number ASC NULLS FIRST]",
))],
false,
None,
vec![],
None,
))
.alias("last_value(number) ORDER BY [number ASC NULLS FIRST]"),
],
)
.unwrap(),
);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&fixed_aggr_state_plan, &ctx.state())
.await
@@ -670,36 +669,79 @@ async fn test_last_value_order_by_udaf() {
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let expected_merge_fn = MergeWrapper::new(
last_value.clone(),
aggr_func_expr.clone(),
vec![DataType::Timestamp(
arrow_schema::TimeUnit::Millisecond,
None,
)],
)
.unwrap();
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(fixed_aggr_state_plan.clone()),
vec![],
vec![
Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(expected_merge_fn.into()),
vec![Expr::Column(Column::new_unqualified(
"__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]",
))],
false,
None,
vec![],
None,
))
.alias("last_value(ts) ORDER BY [ts ASC NULLS FIRST]"),
],
)
.unwrap(),
);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
// evaluate the state function
let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
let values = vec![Arc::new(input) as arrow::array::ArrayRef];
let input = Arc::new(TimestampMillisecondArray::from(vec![
Some(1),
Some(2),
None,
Some(3),
])) as arrow::array::ArrayRef;
// notice since sorting exist, the input must have two columns, one for the value, one for the ordering
let values = vec![input.clone(), input];
state_accum.update_batch(&values).unwrap();
let state = state_accum.state().unwrap();
// FIXME(discord9): once datafusion fixes the issue that last_value udaf state fields are not correctly(missing ordering field if `last` field is part of ordering field)
// then change it back to 3 fields
assert_eq!(state.len(), 2); // last value weird optimization(or maybe bug?) that it only has 2 state fields now
assert_eq!(state[0], ScalarValue::Int64(Some(3)));
assert_eq!(state[1], ScalarValue::Boolean(Some(true)));
assert_eq!(state.len(), 3);
assert_eq!(state[0], ScalarValue::TimestampMillisecond(Some(3), None));
assert_eq!(state[1], ScalarValue::TimestampMillisecond(Some(3), None));
assert_eq!(state[2], ScalarValue::Boolean(Some(true)));
let eval_res = state_accum.evaluate().unwrap();
let expected = Arc::new(
StructArray::try_new(
vec![
Field::new("col_0[mismatch_state]", DataType::Int64, true),
Field::new("col_1[mismatch_state]", DataType::Boolean, true),
// Field::new("last_value[last_value]", DataType::Int64, true),
// Field::new("number", DataType::Int64, true),
// Field::new("is_set", DataType::Boolean, true),
Field::new(
"last_value[last_value]",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
true,
),
Field::new(
"ts",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
true,
),
Field::new("is_set", DataType::Boolean, true),
]
.into(),
vec![
Arc::new(Int64Array::from(vec![Some(3)])),
// Arc::new(Int64Array::from(vec![Some(3)])),
Arc::new(TimestampMillisecondArray::from(vec![Some(3)])),
Arc::new(TimestampMillisecondArray::from(vec![Some(3)])),
Arc::new(BooleanArray::from(vec![Some(true)])),
],
None,