mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-17 21:40:37 +00:00
test: update unit test by passing extra sort columns (#7030)
* tests: fix unit test by passing one sort columns Signed-off-by: discord9 <discord9@163.com> * chore: per copilot Signed-off-by: discord9 <discord9@163.com> --------- Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -17,7 +17,9 @@ use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array};
|
||||
use arrow::array::{
|
||||
ArrayRef, BooleanArray, Float64Array, Int64Array, TimestampMillisecondArray, UInt64Array,
|
||||
};
|
||||
use arrow::record_batch::RecordBatch;
|
||||
use arrow_schema::SchemaRef;
|
||||
use common_telemetry::init_default_ut_logging;
|
||||
@@ -164,6 +166,20 @@ impl DummyTableProvider {
|
||||
record_batch: Mutex::new(record_batch),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_ts(record_batch: Option<RecordBatch>) -> Self {
|
||||
Self {
|
||||
schema: Arc::new(arrow_schema::Schema::new(vec![
|
||||
Field::new("number", DataType::Int64, true),
|
||||
Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
false,
|
||||
),
|
||||
])),
|
||||
record_batch: Mutex::new(record_batch),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DummyTableProvider {
|
||||
@@ -226,6 +242,21 @@ fn dummy_table_scan() -> LogicalPlan {
|
||||
)
|
||||
}
|
||||
|
||||
fn dummy_table_scan_with_ts() -> LogicalPlan {
|
||||
let table_provider = Arc::new(DummyTableProvider::with_ts(None));
|
||||
let table_source = DefaultTableSource::new(table_provider);
|
||||
LogicalPlan::TableScan(
|
||||
TableScan::try_new(
|
||||
TableReference::bare("Number"),
|
||||
Arc::new(table_source),
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sum_udaf() {
|
||||
let ctx = SessionContext::new();
|
||||
@@ -556,15 +587,15 @@ async fn test_last_value_order_by_udaf() {
|
||||
let last_value = (*last_value).clone();
|
||||
|
||||
let original_aggr = Aggregate::try_new(
|
||||
Arc::new(dummy_table_scan()),
|
||||
Arc::new(dummy_table_scan_with_ts()),
|
||||
vec![],
|
||||
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
|
||||
Arc::new(last_value.clone()),
|
||||
vec![Expr::Column(Column::new_unqualified("number"))],
|
||||
vec![Expr::Column(Column::new_unqualified("ts"))],
|
||||
false,
|
||||
None,
|
||||
vec![datafusion_expr::expr::Sort::new(
|
||||
Expr::Column(Column::new_unqualified("number")),
|
||||
Expr::Column(Column::new_unqualified("ts")),
|
||||
true,
|
||||
true,
|
||||
)],
|
||||
@@ -579,15 +610,15 @@ async fn test_last_value_order_by_udaf() {
|
||||
|
||||
let expected_aggr_state_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
Arc::new(dummy_table_scan()),
|
||||
Arc::new(dummy_table_scan_with_ts()),
|
||||
vec![],
|
||||
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
|
||||
state_func,
|
||||
vec![Expr::Column(Column::new_unqualified("number"))],
|
||||
vec![Expr::Column(Column::new_unqualified("ts"))],
|
||||
false,
|
||||
None,
|
||||
vec![datafusion_expr::expr::Sort::new(
|
||||
Expr::Column(Column::new_unqualified("number")),
|
||||
Expr::Column(Column::new_unqualified("ts")),
|
||||
true,
|
||||
true,
|
||||
)],
|
||||
@@ -607,11 +638,19 @@ async fn test_last_value_order_by_udaf() {
|
||||
assert_eq!(
|
||||
res.lower_state.schema().as_arrow(),
|
||||
&arrow_schema::Schema::new(vec![Field::new(
|
||||
"__last_value_state(number) ORDER BY [number ASC NULLS FIRST]",
|
||||
"__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]",
|
||||
DataType::Struct(
|
||||
vec![
|
||||
Field::new("last_value[last_value]", DataType::Int64, true),
|
||||
Field::new("number", DataType::Int64, true), // ordering field is added to state fields too
|
||||
Field::new(
|
||||
"last_value[last_value]",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
true
|
||||
),
|
||||
Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
true
|
||||
), // ordering field is added to state fields too
|
||||
Field::new("is_set", DataType::Boolean, true)
|
||||
]
|
||||
.into()
|
||||
@@ -620,46 +659,6 @@ async fn test_last_value_order_by_udaf() {
|
||||
)])
|
||||
);
|
||||
|
||||
let expected_merge_fn = MergeWrapper::new(
|
||||
last_value.clone(),
|
||||
Arc::new(
|
||||
AggregateExprBuilder::new(
|
||||
Arc::new(last_value.clone()),
|
||||
vec![Arc::new(
|
||||
datafusion::physical_expr::expressions::Column::new("number", 0),
|
||||
)],
|
||||
)
|
||||
.schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
|
||||
.alias("last_value(number) ORDER BY [number ASC NULLS FIRST]")
|
||||
.build()
|
||||
.unwrap(),
|
||||
),
|
||||
vec![DataType::Int64],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_merge_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
Arc::new(fixed_aggr_state_plan.clone()),
|
||||
vec![],
|
||||
vec![
|
||||
Expr::AggregateFunction(AggregateFunction::new_udf(
|
||||
Arc::new(expected_merge_fn.into()),
|
||||
vec![Expr::Column(Column::new_unqualified(
|
||||
"__last_value_state(number) ORDER BY [number ASC NULLS FIRST]",
|
||||
))],
|
||||
false,
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
))
|
||||
.alias("last_value(number) ORDER BY [number ASC NULLS FIRST]"),
|
||||
],
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
|
||||
.create_physical_plan(&fixed_aggr_state_plan, &ctx.state())
|
||||
.await
|
||||
@@ -670,36 +669,79 @@ async fn test_last_value_order_by_udaf() {
|
||||
.unwrap();
|
||||
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
|
||||
|
||||
let expected_merge_fn = MergeWrapper::new(
|
||||
last_value.clone(),
|
||||
aggr_func_expr.clone(),
|
||||
vec![DataType::Timestamp(
|
||||
arrow_schema::TimeUnit::Millisecond,
|
||||
None,
|
||||
)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_merge_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
Arc::new(fixed_aggr_state_plan.clone()),
|
||||
vec![],
|
||||
vec![
|
||||
Expr::AggregateFunction(AggregateFunction::new_udf(
|
||||
Arc::new(expected_merge_fn.into()),
|
||||
vec![Expr::Column(Column::new_unqualified(
|
||||
"__last_value_state(ts) ORDER BY [ts ASC NULLS FIRST]",
|
||||
))],
|
||||
false,
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
))
|
||||
.alias("last_value(ts) ORDER BY [ts ASC NULLS FIRST]"),
|
||||
],
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
|
||||
|
||||
// evaluate the state function
|
||||
let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
|
||||
let values = vec![Arc::new(input) as arrow::array::ArrayRef];
|
||||
let input = Arc::new(TimestampMillisecondArray::from(vec![
|
||||
Some(1),
|
||||
Some(2),
|
||||
None,
|
||||
Some(3),
|
||||
])) as arrow::array::ArrayRef;
|
||||
// Note: because an ordering is specified, the input must have two columns: one for the value and one for the ordering
|
||||
let values = vec![input.clone(), input];
|
||||
|
||||
state_accum.update_batch(&values).unwrap();
|
||||
|
||||
let state = state_accum.state().unwrap();
|
||||
|
||||
// FIXME(discord9): once DataFusion fixes the issue where the last_value UDAF state fields are incorrect (the ordering field is missing when the `last` field is itself part of the ordering),
|
||||
// then change it back to 3 fields
|
||||
assert_eq!(state.len(), 2); // last value weird optimization(or maybe bug?) that it only has 2 state fields now
|
||||
assert_eq!(state[0], ScalarValue::Int64(Some(3)));
|
||||
assert_eq!(state[1], ScalarValue::Boolean(Some(true)));
|
||||
assert_eq!(state.len(), 3);
|
||||
assert_eq!(state[0], ScalarValue::TimestampMillisecond(Some(3), None));
|
||||
assert_eq!(state[1], ScalarValue::TimestampMillisecond(Some(3), None));
|
||||
assert_eq!(state[2], ScalarValue::Boolean(Some(true)));
|
||||
|
||||
let eval_res = state_accum.evaluate().unwrap();
|
||||
let expected = Arc::new(
|
||||
StructArray::try_new(
|
||||
vec![
|
||||
Field::new("col_0[mismatch_state]", DataType::Int64, true),
|
||||
Field::new("col_1[mismatch_state]", DataType::Boolean, true),
|
||||
// Field::new("last_value[last_value]", DataType::Int64, true),
|
||||
// Field::new("number", DataType::Int64, true),
|
||||
// Field::new("is_set", DataType::Boolean, true),
|
||||
Field::new(
|
||||
"last_value[last_value]",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
true,
|
||||
),
|
||||
Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
true,
|
||||
),
|
||||
Field::new("is_set", DataType::Boolean, true),
|
||||
]
|
||||
.into(),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(vec![Some(3)])),
|
||||
// Arc::new(Int64Array::from(vec![Some(3)])),
|
||||
Arc::new(TimestampMillisecondArray::from(vec![Some(3)])),
|
||||
Arc::new(TimestampMillisecondArray::from(vec![Some(3)])),
|
||||
Arc::new(BooleanArray::from(vec![Some(true)])),
|
||||
],
|
||||
None,
|
||||
|
||||
Reference in New Issue
Block a user