mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-25 23:49:58 +00:00
fix: do not step when aggr has order by/filter (#7015)
* fix: not applied — Signed-off-by: discord9 <discord9@163.com>
* chore: address review comments — Signed-off-by: discord9 <discord9@163.com>
* test: confirm order by is not pushed down — Signed-off-by: discord9 <discord9@163.com>
--------- Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -324,6 +324,17 @@ impl AggregateUDFImpl for StateWrapper {
|
|||||||
is_distinct: false,
|
is_distinct: false,
|
||||||
};
|
};
|
||||||
let state_fields = self.inner.state_fields(state_fields_args)?;
|
let state_fields = self.inner.state_fields(state_fields_args)?;
|
||||||
|
|
||||||
|
let state_fields = state_fields
|
||||||
|
.into_iter()
|
||||||
|
.map(|f| {
|
||||||
|
let mut f = f.as_ref().clone();
|
||||||
|
// since state can be null when no input rows, so make all fields nullable
|
||||||
|
f.set_nullable(true);
|
||||||
|
Arc::new(f)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let struct_field = DataType::Struct(state_fields.into());
|
let struct_field = DataType::Struct(state_fields.into());
|
||||||
Ok(struct_field)
|
Ok(struct_field)
|
||||||
}
|
}
|
||||||
@@ -388,6 +399,38 @@ impl Accumulator for StateAccum {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_array())
|
.map(|s| s.to_array())
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
let array_type = array
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.data_type().clone())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let expected_type: Vec<_> = self
|
||||||
|
.state_fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| f.data_type().clone())
|
||||||
|
.collect();
|
||||||
|
if array_type != expected_type {
|
||||||
|
debug!(
|
||||||
|
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
|
||||||
|
self.state_fields.len(),
|
||||||
|
array.len(),
|
||||||
|
self.state_fields,
|
||||||
|
array_type,
|
||||||
|
);
|
||||||
|
let guess_schema = array
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(index, array)| {
|
||||||
|
Field::new(
|
||||||
|
format!("col_{index}[mismatch_state]").as_str(),
|
||||||
|
array.data_type().clone(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Fields>();
|
||||||
|
let arr = StructArray::try_new(guess_schema, array, None)?;
|
||||||
|
|
||||||
|
return Ok(ScalarValue::Struct(Arc::new(arr)));
|
||||||
|
}
|
||||||
let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
|
let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
|
||||||
Ok(ScalarValue::Struct(Arc::new(struct_array)))
|
Ok(ScalarValue::Struct(Arc::new(struct_array)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ use common_recordbatch::error::Result as RecordBatchResult;
|
|||||||
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
|
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
|
||||||
use common_telemetry::init_default_ut_logging;
|
use common_telemetry::init_default_ut_logging;
|
||||||
use datafusion::datasource::DefaultTableSource;
|
use datafusion::datasource::DefaultTableSource;
|
||||||
use datafusion::functions_aggregate::expr_fn::avg;
|
use datafusion::functions_aggregate::expr_fn::{avg, last_value};
|
||||||
use datafusion::functions_aggregate::min_max::{max, min};
|
use datafusion::functions_aggregate::min_max::{max, min};
|
||||||
use datafusion_common::JoinType;
|
use datafusion_common::JoinType;
|
||||||
use datafusion_expr::expr::ScalarFunction;
|
use datafusion_expr::expr::ScalarFunction;
|
||||||
@@ -1494,3 +1494,36 @@ fn date_bin_ts_group_by() {
|
|||||||
.join("\n");
|
.join("\n");
|
||||||
assert_eq!(expected, result.to_string());
|
assert_eq!(expected, result.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// check that `last_value(ts order by ts)` won't be push down
|
||||||
|
#[test]
|
||||||
|
fn test_not_push_down_aggr_order_by() {
|
||||||
|
init_default_ut_logging();
|
||||||
|
let test_table = TestTable::table_with_name(0, "t".to_string());
|
||||||
|
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
|
||||||
|
DfTableProviderAdapter::new(test_table),
|
||||||
|
)));
|
||||||
|
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
|
||||||
|
.unwrap()
|
||||||
|
.aggregate(
|
||||||
|
Vec::<Expr>::new(),
|
||||||
|
vec![last_value(col("ts"), vec![col("ts").sort(true, false)])],
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = ConfigOptions::default();
|
||||||
|
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
|
||||||
|
|
||||||
|
let expected = [
|
||||||
|
"Aggregate: groupBy=[[]], aggr=[[last_value(t.ts) ORDER BY [t.ts ASC NULLS LAST]]]",
|
||||||
|
" Projection: t.pk1, t.pk2, t.pk3, t.ts, t.number",
|
||||||
|
" MergeScan [is_placeholder=false, remote_input=[",
|
||||||
|
"TableScan: t",
|
||||||
|
"]]",
|
||||||
|
]
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
assert_eq!(expected, result.to_string());
|
||||||
|
}
|
||||||
|
|||||||
@@ -79,8 +79,11 @@ pub fn step_aggr_to_upper_aggr(
|
|||||||
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
|
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
|
||||||
aggr_exprs.iter().all(|expr| {
|
aggr_exprs.iter().all(|expr| {
|
||||||
if let Some(aggr_func) = get_aggr_func(expr) {
|
if let Some(aggr_func) = get_aggr_func(expr) {
|
||||||
if aggr_func.params.distinct {
|
if aggr_func.params.distinct
|
||||||
// Distinct aggregate functions are not steppable(yet).
|
|| !aggr_func.params.order_by.is_empty()
|
||||||
|
|| aggr_func.params.filter.is_some()
|
||||||
|
{
|
||||||
|
// Distinct aggregate functions/order by/filter in aggr args are not steppable(yet).
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user