diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs index 744cc8db6e..ea0bf7445e 100644 --- a/src/common/function/src/aggrs/aggr_wrapper/tests.rs +++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs @@ -17,7 +17,9 @@ use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array}; +use arrow::array::{ + ArrayRef, BooleanArray, Float64Array, Int64Array, TimestampMillisecondArray, UInt64Array, +}; use arrow::record_batch::RecordBatch; use arrow_schema::SchemaRef; use common_telemetry::init_default_ut_logging; @@ -164,6 +166,20 @@ impl DummyTableProvider { record_batch: Mutex::new(record_batch), } } + + pub fn with_ts(record_batch: Option) -> Self { + Self { + schema: Arc::new(arrow_schema::Schema::new(vec![ + Field::new("number", DataType::Int64, true), + Field::new( + "ts", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + false, + ), + ])), + record_batch: Mutex::new(record_batch), + } + } } impl Default for DummyTableProvider { @@ -226,6 +242,21 @@ fn dummy_table_scan() -> LogicalPlan { ) } +fn dummy_table_scan_with_ts() -> LogicalPlan { + let table_provider = Arc::new(DummyTableProvider::with_ts(None)); + let table_source = DefaultTableSource::new(table_provider); + LogicalPlan::TableScan( + TableScan::try_new( + TableReference::bare("Number"), + Arc::new(table_source), + None, + vec![], + None, + ) + .unwrap(), + ) +} + #[tokio::test] async fn test_sum_udaf() { let ctx = SessionContext::new(); @@ -556,15 +587,15 @@ async fn test_last_value_order_by_udaf() { let last_value = (*last_value).clone(); let original_aggr = Aggregate::try_new( - Arc::new(dummy_table_scan()), + Arc::new(dummy_table_scan_with_ts()), vec![], vec![Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(last_value.clone()), - vec![Expr::Column(Column::new_unqualified("number"))], + vec![Expr::Column(Column::new_unqualified("ts"))], false, None, vec![datafusion_expr::expr::Sort::new( - Expr::Column(Column::new_unqualified("number")), + Expr::Column(Column::new_unqualified("ts")), true, true, )], @@ -579,15 +610,15 @@ async fn test_last_value_order_by_udaf() { let expected_aggr_state_plan = LogicalPlan::Aggregate( Aggregate::try_new( - Arc::new(dummy_table_scan()), + Arc::new(dummy_table_scan_with_ts()), vec![], vec![Expr::AggregateFunction(AggregateFunction::new_udf( state_func, - vec![Expr::Column(Column::new_unqualified("number"))], + vec![Expr::Column(Column::new_unqualified("ts"))], false, None, vec![datafusion_expr::expr::Sort::new( - Expr::Column(Column::new_unqualified("number")), + Expr::Column(Column::new_unqualified("ts")), true, true, )], @@ -607,11 +638,19 @@ async fn test_last_value_order_by_udaf() { assert_eq!( res.lower_state.schema().as_arrow(), &arrow_schema::Schema::new(vec![Field::new( - "__last_value_state(number) ORDER BY [number ASC NULLS FIRST]", + "__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]", DataType::Struct( vec![ - Field::new("last_value[last_value]", DataType::Int64, true), - Field::new("number", DataType::Int64, true), // ordering field is added to state fields too + Field::new( + "last_value[last_value]", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + true + ), + Field::new( + "ts", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + true + ), // ordering field is added to state fields too Field::new("is_set", DataType::Boolean, true) ] .into() @@ -620,46 +659,6 @@ async fn test_last_value_order_by_udaf() { )]) ); - let expected_merge_fn = MergeWrapper::new( - last_value.clone(), - Arc::new( - AggregateExprBuilder::new( - Arc::new(last_value.clone()), - vec![Arc::new( - datafusion::physical_expr::expressions::Column::new("number", 0), - )], - ) - .schema(Arc::new(dummy_table_scan().schema().as_arrow().clone())) - .alias("last_value(number) ORDER BY [number ASC NULLS FIRST]") - .build() - .unwrap(), - ), - vec![DataType::Int64], - ) - .unwrap(); - - let expected_merge_plan = LogicalPlan::Aggregate( - Aggregate::try_new( - Arc::new(fixed_aggr_state_plan.clone()), - vec![], - vec![ - Expr::AggregateFunction(AggregateFunction::new_udf( - Arc::new(expected_merge_fn.into()), - vec![Expr::Column(Column::new_unqualified( - "__last_value_state(number) ORDER BY [number ASC NULLS FIRST]", - ))], - false, - None, - vec![], - None, - )) - .alias("last_value(number) ORDER BY [number ASC NULLS FIRST]"), - ], - ) - .unwrap(), - ); - assert_eq!(&res.upper_merge, &expected_merge_plan); - let phy_aggr_state_plan = DefaultPhysicalPlanner::default() .create_physical_plan(&fixed_aggr_state_plan, &ctx.state()) .await @@ -670,36 +669,79 @@ async fn test_last_value_order_by_udaf() { .unwrap(); let aggr_func_expr = &aggr_exec.aggr_expr()[0]; + let expected_merge_fn = MergeWrapper::new( + last_value.clone(), + aggr_func_expr.clone(), + vec![DataType::Timestamp( + arrow_schema::TimeUnit::Millisecond, + None, + )], + ) + .unwrap(); + + let expected_merge_plan = LogicalPlan::Aggregate( + Aggregate::try_new( + Arc::new(fixed_aggr_state_plan.clone()), + vec![], + vec![ + Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::new(expected_merge_fn.into()), + vec![Expr::Column(Column::new_unqualified( + "__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]", + ))], + false, + None, + vec![], + None, + )) + .alias("last_value(ts) ORDER BY [ts ASC NULLS FIRST]"), + ], + ) + .unwrap(), + ); + assert_eq!(&res.upper_merge, &expected_merge_plan); + let mut state_accum = aggr_func_expr.create_accumulator().unwrap(); // evaluate the state function - let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); - let values = vec![Arc::new(input) as arrow::array::ArrayRef]; + let input = Arc::new(TimestampMillisecondArray::from(vec![ + Some(1), + Some(2), + None, + Some(3), + ])) as arrow::array::ArrayRef; + // notice since sorting exist, the input must have two columns, one for the value, one for the ordering + let values = vec![input.clone(), input]; state_accum.update_batch(&values).unwrap(); let state = state_accum.state().unwrap(); - // FIXME(discord9): once datafusion fixes the issue that last_value udaf state fields are not correctly(missing ordering field if `last` field is part of ordering field) - // then change it back to 3 fields - assert_eq!(state.len(), 2); // last value weird optimization(or maybe bug?) that it only has 2 state fields now - assert_eq!(state[0], ScalarValue::Int64(Some(3))); - assert_eq!(state[1], ScalarValue::Boolean(Some(true))); + assert_eq!(state.len(), 3); + assert_eq!(state[0], ScalarValue::TimestampMillisecond(Some(3), None)); + assert_eq!(state[1], ScalarValue::TimestampMillisecond(Some(3), None)); + assert_eq!(state[2], ScalarValue::Boolean(Some(true))); let eval_res = state_accum.evaluate().unwrap(); let expected = Arc::new( StructArray::try_new( vec![ - Field::new("col_0[mismatch_state]", DataType::Int64, true), - Field::new("col_1[mismatch_state]", DataType::Boolean, true), - // Field::new("last_value[last_value]", DataType::Int64, true), - // Field::new("number", DataType::Int64, true), - // Field::new("is_set", DataType::Boolean, true), + Field::new( + "last_value[last_value]", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + true, + ), + Field::new("is_set", DataType::Boolean, true), ] .into(), vec![ - Arc::new(Int64Array::from(vec![Some(3)])), - // Arc::new(Int64Array::from(vec![Some(3)])), + Arc::new(TimestampMillisecondArray::from(vec![Some(3)])), + Arc::new(TimestampMillisecondArray::from(vec![Some(3)])), Arc::new(BooleanArray::from(vec![Some(true)])), ], None,