feat: get stat for aggr

Signed-off-by: discord9 <discord9@163.com>
discord9
2026-04-15 15:18:18 +08:00
parent a5686f0042
commit 3e53a562cf
18 changed files with 2100 additions and 251 deletions

View File

@@ -330,6 +330,39 @@ impl StateWrapper {
acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
Ok(acc_args)
}
/// Builds a state scalar from explicit state-field values.
///
/// The caller must provide one scalar per state field in the wrapper's state layout.
/// This method is responsible only for validating the current wrapper state type and
/// assembling the final struct scalar from those explicit field values.
pub fn value_from_custom_state_fields(
&self,
arg_types: &[DataType],
state_values: Vec<ScalarValue>,
) -> datafusion_common::Result<ScalarValue> {
let DataType::Struct(fields) = self.return_type(arg_types)? else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected struct state type for {}, got non-struct return type",
self.name()
)));
};
if fields.len() != state_values.len() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected {} state fields for {}, got {}",
fields.len(),
self.name(),
state_values.len()
)));
}
let arrays = state_values
.into_iter()
.map(|value| value.to_array())
.collect::<datafusion_common::Result<Vec<_>>>()?;
let struct_array = build_state_struct_array(&fields, arrays)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}
}
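// A minimal usage sketch of `value_from_custom_state_fields` (illustrative only, not
// part of this change): given a `StateWrapper` built around DataFusion's `count` UDAF,
// `count` keeps a single Int64 state column, so one scalar is enough to assemble the
// struct-typed state value.
#[allow(dead_code)]
fn example_count_state_from_row_count(
    count_wrapper: &StateWrapper,
    rows: i64,
) -> datafusion_common::Result<ScalarValue> {
    // One scalar per state field; count's only state field is an Int64 counter.
    count_wrapper
        .value_from_custom_state_fields(&[DataType::Int64], vec![ScalarValue::Int64(Some(rows))])
}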
impl AggregateUDFImpl for StateWrapper {
@@ -472,13 +505,59 @@ impl AggregateUDFImpl for StateWrapper {
};
let array = ret.to_array().ok()?;
let struct_array = StructArray::new(fields.clone(), vec![array], None);
let ret = ScalarValue::Struct(Arc::new(struct_array));
Some(ret)
}
}
fn build_state_struct_array(
fields: &Fields,
arrays: Vec<ArrayRef>,
) -> datafusion_common::Result<StructArray> {
let array_type = arrays
.iter()
.map(|array| array.data_type().clone())
.collect::<Vec<_>>();
let expected_type = fields
.iter()
.map(|field| field.data_type().clone())
.collect::<Vec<_>>();
if array_type != expected_type {
// Keep this fallback intentionally lenient.
//
// Historically the wrapper path has tolerated state-schema drift as long as the
// physical state columns remain positionally compatible. This shows up most clearly
// in order-sensitive aggregates such as first_value/last_value, where DataFusion-side
// state metadata and the arrays we need to wrap may not line up exactly. The merge
// path consumes state columns by position, not by field metadata, so preserving a
// struct wrapper here is more compatible than failing eagerly on field/type mismatch.
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
fields.len(),
arrays.len(),
fields,
array_type,
);
let guess_schema = arrays
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
return StructArray::try_new(guess_schema, arrays, None)
.map_err(|err| datafusion_common::DataFusionError::ArrowError(Box::new(err), None));
}
StructArray::try_new(fields.clone(), arrays, None)
.map_err(|err| datafusion_common::DataFusionError::ArrowError(Box::new(err), None))
}
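// Illustrative sketch (not part of this change): when the declared state layout says
// Float64 but the actual state array is Int64, the lenient fallback above still
// yields a struct, renaming the column positionally to `col_0[mismatch_state]`
// instead of failing.
#[allow(dead_code)]
fn example_mismatch_fallback() -> datafusion_common::Result<StructArray> {
    let declared = Fields::from(vec![Field::new("sum", DataType::Float64, true)]);
    // Deliberately produce an Int64 state array that disagrees with `declared`.
    let arrays = vec![ScalarValue::Int64(Some(42)).to_array()?];
    build_state_struct_array(&declared, arrays)
}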
/// The wrapper's input is the same as the original aggregate function's input,
/// and the output is the state function's output.
#[derive(Debug)]
@@ -510,42 +589,9 @@ impl StateGroupsAccum {
}
fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
Ok(Arc::new(build_state_struct_array(
&self.state_fields,
arrays,
)?))
}
}
@@ -621,44 +667,11 @@ impl Accumulator for StateAccum {
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
let state = self.inner.state()?;
let arrays = state
.iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>, _>>()?;
let struct_array = build_state_struct_array(&self.state_fields, arrays)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}
@@ -860,7 +873,10 @@ impl Accumulator for MergeAccum {
"State fields mismatch, expected: {:?}, got: {:?}",
self.state_fields, fields
);
// Intentionally continue here for compatibility with the wrapper's historical
// behavior: downstream merge logic uses the struct columns positionally, and some
// DataFusion/order-sensitive aggregate paths can produce equivalent state payloads
// whose field metadata does not exactly match our locally expected schema.
}
// now fields should be the same, so we can merge the batch

View File

@@ -28,6 +28,7 @@ use datafusion::datasource::DefaultTableSource;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::max_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
@@ -291,6 +292,50 @@ fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
state_wrapper.create_groups_accumulator(acc_args).unwrap()
}
fn test_state_scalar_for_type(data_type: &DataType) -> ScalarValue {
match data_type {
DataType::Float64 => ScalarValue::Float64(Some(1.5)),
DataType::UInt64 => ScalarValue::UInt64(Some(2)),
DataType::Int64 => ScalarValue::Int64(Some(3)),
_ => panic!("unsupported test data type: {data_type:?}"),
}
}
#[test]
fn test_value_from_custom_state_fields_single_field() {
let wrapper = StateWrapper::new((*max_udaf()).clone()).unwrap();
let value = wrapper
.value_from_custom_state_fields(&[DataType::Int64], vec![ScalarValue::Int64(Some(7))])
.unwrap();
let ScalarValue::Struct(array) = value else {
panic!("expected struct state")
};
assert_eq!(1, array.columns().len());
assert_eq!(DataType::Int64, array.column(0).data_type().clone());
}
#[test]
fn test_value_from_custom_state_fields_multi_field() {
let wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap();
let DataType::Struct(fields) = wrapper.return_type(&[DataType::Float64]).unwrap() else {
panic!("expected struct state type")
};
let values = fields
.iter()
.map(|field| test_state_scalar_for_type(field.data_type()))
.collect::<Vec<_>>();
let value = wrapper
.value_from_custom_state_fields(&[DataType::Float64], values)
.unwrap();
let ScalarValue::Struct(array) = value else {
panic!("expected struct state")
};
assert_eq!(fields.len(), array.columns().len());
}
#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();

View File

@@ -28,6 +28,7 @@ pub(crate) mod prune;
pub(crate) mod pruner;
pub mod range;
pub(crate) mod range_cache;
pub(crate) mod scan_input_stats;
pub mod scan_region;
pub mod scan_util;
pub(crate) mod seq_scan;

View File

@@ -0,0 +1,433 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use api::v1::SemanticType;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use datafusion_common::pruning::PruningStatistics;
use datafusion_common::{Column, ScalarValue};
use datatypes::arrow::array::{Array, AsArray, UInt64Array};
use datatypes::arrow::compute::{cast, max, max_boolean, max_string, min, min_boolean, min_string};
use datatypes::arrow::datatypes::{
DataType as ArrowDataType, Float32Type, Float64Type, Int32Type, Int64Type,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt32Type, UInt64Type,
};
use datatypes::value::Value;
use store_api::metadata::RegionMetadata;
use store_api::scan_stats::{
RegionScanColumnStats as RegionScanColumnInputStats,
RegionScanFileStats as RegionScanFileInputStats, RegionScanStats as RegionScanInputStats,
};
use crate::read::scan_region::ScanInput;
use crate::sst::file::FileHandle;
use crate::sst::parquet::format::ReadFormat;
use crate::sst::parquet::stats::RowGroupPruningStats;
pub(crate) fn build_scan_input_stats(
input: &ScanInput,
metadata: &RegionMetadata,
) -> std::result::Result<RegionScanInputStats, BoxedError> {
let files = input
.files
.iter()
.enumerate()
.map(|(index, file)| {
let partition_expr_matches_region = file_partition_expr_matches_region(metadata, file)?;
Ok(RegionScanFileInputStats {
file_ordinal: index,
exact_num_rows: exact_file_num_rows(file),
time_range: exact_file_time_range(file),
field_stats: build_file_field_stats(input, metadata, file)?,
partition_expr_matches_region,
})
})
.collect::<std::result::Result<Vec<_>, BoxedError>>()?;
Ok(RegionScanInputStats { files })
}
fn exact_file_num_rows(file: &FileHandle) -> Option<usize> {
Some(file.num_rows())
}
fn exact_file_time_range(
file: &FileHandle,
) -> Option<(common_time::Timestamp, common_time::Timestamp)> {
(file.meta_ref().num_row_groups != 0).then_some(file.time_range())
}
fn build_file_field_stats(
input: &ScanInput,
metadata: &RegionMetadata,
file: &FileHandle,
) -> std::result::Result<HashMap<String, RegionScanColumnInputStats>, BoxedError> {
// TODO(ruihang): extract stats only for columns referenced by the supported aggregates
// instead of eagerly materializing every field column for every file.
let Some(parquet_meta) = input
.cache_strategy
.get_parquet_meta_data_from_mem_cache(file.file_id())
else {
return Ok(HashMap::new());
};
let region_metadata = Arc::new(metadata.clone());
let file_path = format!("{:?}", file.file_id());
let read_format = ReadFormat::new(
region_metadata.clone(),
None,
input.flat_format,
Some(parquet_meta.file_metadata().schema_descr().num_columns()),
&file_path,
false,
)
.map_err(BoxedError::new)?;
let row_groups = parquet_meta.row_groups();
let pruning_stats =
RowGroupPruningStats::new(row_groups, &read_format, Some(region_metadata), false);
metadata
.column_metadatas
.iter()
.filter(|column| column.semantic_type == SemanticType::Field)
.filter_map(|column| {
let stats = build_field_column_stats(
column.column_schema.name.as_str(),
row_groups,
&pruning_stats,
)
.transpose();
match stats {
Some(Ok(stats)) => Some(Ok((column.column_schema.name.to_string(), stats))),
Some(Err(err)) => Some(Err(err)),
None => None,
}
})
.collect()
}
fn build_field_column_stats(
column_name: &str,
row_groups: &[parquet::file::metadata::RowGroupMetaData],
pruning_stats: &impl PruningStatistics,
) -> std::result::Result<Option<RegionScanColumnInputStats>, BoxedError> {
let column = Column::from_name(column_name);
let min_value = aggregate_column_min_value(pruning_stats.min_values(&column).as_deref())?;
let max_value = aggregate_column_max_value(pruning_stats.max_values(&column).as_deref())?;
let exact_non_null_rows =
aggregate_exact_non_null_rows(pruning_stats.null_counts(&column).as_deref(), row_groups)?;
if min_value.is_none() && max_value.is_none() && exact_non_null_rows.is_none() {
return Ok(None);
}
Ok(Some(RegionScanColumnInputStats {
min_value,
max_value,
exact_non_null_rows,
}))
}
fn aggregate_column_min_value(
values: Option<&dyn Array>,
) -> std::result::Result<Option<Value>, BoxedError> {
aggregate_column_extreme_value(values, true)
}
fn aggregate_column_max_value(
values: Option<&dyn Array>,
) -> std::result::Result<Option<Value>, BoxedError> {
aggregate_column_extreme_value(values, false)
}
fn aggregate_column_extreme_value(
values: Option<&dyn Array>,
is_min: bool,
) -> std::result::Result<Option<Value>, BoxedError> {
let Some(values) = values else {
return Ok(None);
};
if values.is_empty() || values.null_count() > 0 {
return Ok(None);
}
if let Some(value) = aggregate_column_extreme_value_with_compute(values, is_min)? {
return Ok(Some(value));
}
aggregate_column_extreme_value_fallback(values, is_min)
}
fn aggregate_column_extreme_value_with_compute(
values: &dyn Array,
is_min: bool,
) -> std::result::Result<Option<Value>, BoxedError> {
if let ArrowDataType::Dictionary(_, value_type) = values.data_type() {
let casted = cast(values, value_type.as_ref()).map_err(|err| {
BoxedError::new(PlainError::new(
format!("failed to cast dictionary stats array to value type: {err}"),
StatusCode::Unexpected,
))
})?;
return aggregate_column_extreme_value_with_compute(casted.as_ref(), is_min);
}
macro_rules! compute_primitive_extreme {
($array_ty:ty, $variant:ident) => {{
let array = values.as_primitive::<$array_ty>();
let scalar = if is_min {
min(array).map(|value| ScalarValue::$variant(Some(value)))
} else {
max(array).map(|value| ScalarValue::$variant(Some(value)))
};
scalar
.map(|value| Value::try_from(value).map_err(BoxedError::new))
.transpose()
}};
}
macro_rules! compute_timestamp_extreme {
($array_ty:ty, $variant:ident, $tz:expr) => {{
let array = values.as_primitive::<$array_ty>();
let scalar = if is_min {
min(array).map(|value| ScalarValue::$variant(Some(value), $tz.clone()))
} else {
max(array).map(|value| ScalarValue::$variant(Some(value), $tz.clone()))
};
scalar
.map(|value| Value::try_from(value).map_err(BoxedError::new))
.transpose()
}};
}
match values.data_type() {
ArrowDataType::Boolean => {
let array = values.as_boolean();
let scalar = if is_min {
min_boolean(array).map(|value| ScalarValue::Boolean(Some(value)))
} else {
max_boolean(array).map(|value| ScalarValue::Boolean(Some(value)))
};
scalar
.map(|value| Value::try_from(value).map_err(BoxedError::new))
.transpose()
}
ArrowDataType::Utf8 => {
let array = values.as_string::<i32>();
let scalar = if is_min {
min_string(array)
} else {
max_string(array)
}
.map(|value| ScalarValue::Utf8(Some(value.to_string())));
scalar
.map(|value| Value::try_from(value).map_err(BoxedError::new))
.transpose()
}
ArrowDataType::LargeUtf8 => {
let array = values.as_string::<i64>();
let scalar = if is_min {
min_string(array)
} else {
max_string(array)
}
.map(|value| ScalarValue::LargeUtf8(Some(value.to_string())));
scalar
.map(|value| Value::try_from(value).map_err(BoxedError::new))
.transpose()
}
ArrowDataType::UInt32 => compute_primitive_extreme!(UInt32Type, UInt32),
ArrowDataType::UInt64 => compute_primitive_extreme!(UInt64Type, UInt64),
ArrowDataType::Int32 => compute_primitive_extreme!(Int32Type, Int32),
ArrowDataType::Int64 => compute_primitive_extreme!(Int64Type, Int64),
ArrowDataType::Float32 => compute_primitive_extreme!(Float32Type, Float32),
ArrowDataType::Float64 => compute_primitive_extreme!(Float64Type, Float64),
ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, tz) => {
compute_timestamp_extreme!(TimestampSecondType, TimestampSecond, tz)
}
ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, tz) => {
compute_timestamp_extreme!(TimestampMillisecondType, TimestampMillisecond, tz)
}
ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Microsecond, tz) => {
compute_timestamp_extreme!(TimestampMicrosecondType, TimestampMicrosecond, tz)
}
ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Nanosecond, tz) => {
compute_timestamp_extreme!(TimestampNanosecondType, TimestampNanosecond, tz)
}
_ => Ok(None),
}
}
fn aggregate_column_extreme_value_fallback(
values: &dyn Array,
is_min: bool,
) -> std::result::Result<Option<Value>, BoxedError> {
let scalars = (0..values.len())
.map(|index| {
ScalarValue::try_from_array(values, index).map_err(|err| {
BoxedError::new(PlainError::new(
format!("failed to extract scalar value from stats array: {err}"),
StatusCode::Unexpected,
))
})
})
.collect::<std::result::Result<Vec<_>, _>>()?;
let mut iter = scalars
.into_iter()
.map(|value| Value::try_from(value).map_err(BoxedError::new));
let Some(first) = iter.next() else {
return Ok(None);
};
let first = first?;
iter.try_fold(first, |current, value| {
let value = value?;
let next = if is_min {
Value::min(current, value)
} else {
Value::max(current, value)
};
Ok::<_, BoxedError>(next)
})
.map(Some)
}
fn aggregate_exact_non_null_rows(
null_counts: Option<&dyn Array>,
row_groups: &[parquet::file::metadata::RowGroupMetaData],
) -> std::result::Result<Option<usize>, BoxedError> {
let Some(null_counts) = null_counts else {
return Ok(None);
};
if null_counts.null_count() > 0 {
return Ok(None);
}
let Some(null_counts) = null_counts.as_any().downcast_ref::<UInt64Array>() else {
return Ok(None);
};
row_groups
.iter()
.zip(null_counts.iter())
.try_fold(0usize, |acc, (row_group, null_count)| {
let row_count = usize::try_from(row_group.num_rows()).map_err(|err| {
BoxedError::new(PlainError::new(
format!("failed to convert row group row count to usize: {err}"),
StatusCode::Unexpected,
))
})?;
let null_count = usize::try_from(null_count.unwrap_or_default()).map_err(|err| {
BoxedError::new(PlainError::new(
format!("failed to convert parquet null count to usize: {err}"),
StatusCode::Unexpected,
))
})?;
Ok::<_, BoxedError>(acc + row_count.saturating_sub(null_count))
})
.map(Some)
}
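// A small sketch of the arithmetic above (illustrative, not part of this change):
// with exact per-row-group row counts and null counts, the non-null total is just the
// sum of `row_count - null_count` over all row groups, e.g. [100, 50] rows with
// [10, 0] nulls gives 140 non-null rows.
#[allow(dead_code)]
fn example_non_null_rows(row_counts: &[usize], null_counts: &[usize]) -> usize {
    row_counts
        .iter()
        .zip(null_counts)
        .map(|(rows, nulls)| rows.saturating_sub(*nulls))
        .sum()
}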
fn file_partition_expr_matches_region(
metadata: &RegionMetadata,
file: &FileHandle,
) -> std::result::Result<bool, BoxedError> {
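// Compare serialized partition expressions; a file with no recorded partition expr
// only matches a region that also has none.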
let file_partition_expr = file
.meta_ref()
.partition_expr
.as_ref()
.map(|expr| expr.as_json_str())
.transpose()
.map_err(BoxedError::new)?;
Ok(file_partition_expr == metadata.partition_expr)
}
#[cfg(test)]
mod tests {
use common_time::timestamp::TimeUnit as TimestampUnit;
use datatypes::arrow::array::{
DictionaryArray, Int64Array, StringArray, TimestampMillisecondArray, UInt64Array,
};
use datatypes::arrow::datatypes::Int32Type;
use super::*;
#[test]
fn test_aggregate_column_extreme_value_uses_numeric_fast_path() {
let values = UInt64Array::from(vec![Some(7), Some(2), Some(11)]);
let min_value = aggregate_column_min_value(Some(&values)).unwrap();
let max_value = aggregate_column_max_value(Some(&values)).unwrap();
assert_eq!(Some(Value::UInt64(2)), min_value);
assert_eq!(Some(Value::UInt64(11)), max_value);
}
#[test]
fn test_aggregate_column_extreme_value_uses_string_fast_path() {
let values = StringArray::from(vec![Some("delta"), Some("alpha"), Some("gamma")]);
let min_value = aggregate_column_min_value(Some(&values)).unwrap();
let max_value = aggregate_column_max_value(Some(&values)).unwrap();
assert_eq!(Some(Value::String("alpha".into())), min_value);
assert_eq!(Some(Value::String("gamma".into())), max_value);
}
#[test]
fn test_aggregate_column_extreme_value_uses_timestamp_fast_path() {
let values = TimestampMillisecondArray::from(vec![Some(7), Some(2), Some(11)]);
let min_value = aggregate_column_min_value(Some(&values)).unwrap();
let max_value = aggregate_column_max_value(Some(&values)).unwrap();
assert_eq!(
Some(Value::Timestamp(common_time::Timestamp::new(
2,
TimestampUnit::Millisecond
))),
min_value
);
assert_eq!(
Some(Value::Timestamp(common_time::Timestamp::new(
11,
TimestampUnit::Millisecond
))),
max_value
);
}
#[test]
fn test_aggregate_column_extreme_value_dictionary_falls_back() {
let values =
DictionaryArray::<Int32Type>::from_iter([Some("delta"), Some("alpha"), Some("gamma")]);
let min_value = aggregate_column_min_value(Some(&values)).unwrap();
let max_value = aggregate_column_max_value(Some(&values)).unwrap();
assert_eq!(Some(Value::String("alpha".into())), min_value);
assert_eq!(Some(Value::String("gamma".into())), max_value);
}
#[test]
fn test_aggregate_column_extreme_value_returns_none_when_any_stats_are_null() {
let values = Int64Array::from(vec![Some(7), None, Some(11)]);
assert_eq!(None, aggregate_column_min_value(Some(&values)).unwrap());
assert_eq!(None, aggregate_column_max_value(Some(&values)).unwrap());
}
}

View File

@@ -32,6 +32,7 @@ use store_api::metadata::RegionMetadataRef;
use store_api::region_engine::{
PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
};
use store_api::scan_stats::RegionScanStats;
use store_api::storage::TimeSeriesRowSelector;
use tokio::sync::Semaphore;
@@ -43,6 +44,7 @@ use crate::read::last_row::{FlatLastRowReader, LastRowReader};
use crate::read::merge::MergeReaderBuilder;
use crate::read::pruner::{PartitionPruner, Pruner};
use crate::read::range::RangeMeta;
use crate::read::scan_input_stats::build_scan_input_stats;
use crate::read::scan_region::{ScanInput, StreamContext};
use crate::read::scan_util::{
PartitionMetrics, PartitionMetricsList, SplitRecordBatchStream, scan_file_ranges,
@@ -669,6 +671,14 @@ impl RegionScanner for SeqScan {
predicate.is_some()
}
fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
build_scan_input_stats(
&self.stream_ctx.input,
self.stream_ctx.input.mapper.metadata(),
)
.map(Some)
}
fn add_dyn_filter_to_predicate(
&mut self,
filter_exprs: Vec<Arc<dyn datafusion::physical_plan::PhysicalExpr>>,

View File

@@ -36,6 +36,7 @@ use store_api::metadata::RegionMetadataRef;
use store_api::region_engine::{
PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
};
use store_api::scan_stats::RegionScanStats;
use tokio::sync::Semaphore;
use tokio::sync::mpsc::error::{SendTimeoutError, TrySendError};
use tokio::sync::mpsc::{self, Receiver, Sender};
@@ -45,6 +46,7 @@ use crate::error::{
ScanSeriesSnafu, TooManyFilesToReadSnafu,
};
use crate::read::pruner::{PartitionPruner, Pruner};
use crate::read::scan_input_stats::build_scan_input_stats;
use crate::read::scan_region::{ScanInput, StreamContext};
use crate::read::scan_util::{PartitionMetrics, PartitionMetricsList, SeriesDistributorMetrics};
use crate::read::seq_scan::{SeqScan, build_flat_sources, build_sources};
@@ -364,6 +366,14 @@ impl RegionScanner for SeriesScan {
predicate.is_some()
}
fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
build_scan_input_stats(
&self.stream_ctx.input,
self.stream_ctx.input.mapper.metadata(),
)
.map(Some)
}
fn add_dyn_filter_to_predicate(
&mut self,
filter_exprs: Vec<Arc<dyn datafusion::physical_plan::PhysicalExpr>>,

View File

@@ -32,9 +32,11 @@ use store_api::metadata::RegionMetadataRef;
use store_api::region_engine::{
PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
};
use store_api::scan_stats::RegionScanStats;
use crate::error::{PartitionOutOfRangeSnafu, Result};
use crate::read::pruner::{PartitionPruner, Pruner};
use crate::read::scan_input_stats::build_scan_input_stats;
use crate::read::scan_region::{ScanInput, StreamContext};
use crate::read::scan_util::{
PartitionMetrics, PartitionMetricsList, scan_file_ranges, scan_flat_file_ranges,
@@ -495,6 +497,14 @@ impl RegionScanner for UnorderedScan {
predicate.is_some()
}
fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
build_scan_input_stats(
&self.stream_ctx.input,
self.stream_ctx.input.mapper.metadata(),
)
.map(Some)
}
fn add_dyn_filter_to_predicate(
&mut self,
filter_exprs: Vec<Arc<dyn datafusion::physical_plan::PhysicalExpr>>,

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod aggregate_stats;
pub mod aggr_stats;
pub mod constant_term;
pub mod count_wildcard;
pub mod parallelize_scan;

View File

@@ -0,0 +1,111 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion_common::Result as DfResult;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datatypes::arrow::datatypes::DataType;
use table::table::scan::RegionScanExec;
mod check;
mod split;
#[cfg(test)]
mod tests;
use check::RewriteCheck;
#[derive(Debug)]
pub struct AggregateStats;
/// All supported aggregates that can be computed from statistics.
#[derive(Debug, Clone, PartialEq, Eq)]
enum StatsAgg {
CountStar,
CountField {
column_name: String,
arg_type: DataType,
},
CountTimeIndex {
arg_type: DataType,
},
MinField {
column_name: String,
arg_type: DataType,
},
MinTimeIndex {
arg_type: DataType,
},
MaxField {
column_name: String,
arg_type: DataType,
},
MaxTimeIndex {
arg_type: DataType,
},
}
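// Illustrative mapping (as implemented by `RewriteCheck::parse_agg` in check.rs, shown
// later in this change): `count(*)` or the COUNT_STAR_EXPANSION literal -> `CountStar`;
// `count`/`min`/`max` over the time-index column -> the `*TimeIndex` variants;
// over a field column -> the `*Field` variants. Aggregates over tag columns and any
// other shapes are rejected.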
impl PhysicalOptimizerRule for AggregateStats {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> DfResult<Arc<dyn ExecutionPlan>> {
Self::do_optimize(plan)
}
fn name(&self) -> &str {
"aggregate_stats"
}
fn schema_check(&self) -> bool {
true
}
}
impl AggregateStats {
fn do_optimize(plan: Arc<dyn ExecutionPlan>) -> DfResult<Arc<dyn ExecutionPlan>> {
let result = plan
.transform_down(|plan| {
let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(plan));
};
let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else {
return Ok(Transformed::no(plan));
};
let check = RewriteCheck::new(aggregate_exec, region_scan);
if let Some(reason) = check.skip_reason()? {
debug!("Skip aggregate stats optimization: {reason}");
return Ok(Transformed::no(plan));
}
Ok(Transformed::no(plan))
})?
.data;
Ok(result)
}
fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> {
let child = aggregate_exec.children().into_iter().next()?;
child.as_any().downcast_ref::<RegionScanExec>()
}
}

View File

@@ -0,0 +1,221 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_telemetry::debug;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion_common::Result as DfResult;
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::expressions::{Column, Literal};
use table::table::scan::RegionScanExec;
use super::StatsAgg;
use super::split::{StatsAggExt, has_partition_expr_mismatch};
#[derive(Debug)]
pub(super) struct RewriteCheck<'a> {
aggregate_exec: &'a AggregateExec,
region_scan: &'a RegionScanExec,
}
impl<'a> RewriteCheck<'a> {
pub(super) fn new(aggregate_exec: &'a AggregateExec, region_scan: &'a RegionScanExec) -> Self {
Self {
aggregate_exec,
region_scan,
}
}
pub(super) fn skip_reason(&self) -> DfResult<Option<RejectReason>> {
// MVP only handles global aggregates over append-mode region scans. Anything
// else falls back to the normal execution path.
if !self.region_scan.append_mode() || !self.aggregate_exec.group_expr().is_empty() {
return Ok(Some(RejectReason::UnsupportedPlan));
}
let aggs = match self.parse_aggs() {
Ok(aggs) => aggs,
Err(reason) => return Ok(Some(reason)),
};
let scan_input_stats = self.try_scan_stats();
if self.stats_unavailable(scan_input_stats.as_ref()) {
return Ok(Some(RejectReason::StatsUnavailable));
}
if !aggs
.iter()
.all(|agg| agg.has_stats_files(scan_input_stats.as_ref().unwrap()))
{
return Ok(Some(RejectReason::NoStatsFiles));
}
Ok(None)
}
fn try_scan_stats(&self) -> Option<store_api::scan_stats::RegionScanStats> {
match self.region_scan.scan_input_stats() {
Ok(stats) => stats,
Err(err) => {
debug!(
"Skip aggregate stats optimization: failed to collect scan input stats: {err}"
);
None
}
}
}
fn parse_aggs(&self) -> Result<Vec<StatsAgg>, RejectReason> {
let aggr_exprs = self.aggregate_exec.aggr_expr();
if aggr_exprs.is_empty() {
return Err(RejectReason::UnsupportedAggregate);
}
aggr_exprs.iter().map(|expr| self.parse_agg(expr)).collect()
}
fn parse_agg(&self, expr: &AggregateFunctionExpr) -> Result<StatsAgg, RejectReason> {
if !is_supported_aggregate_name(expr.fun().name()) {
return Err(RejectReason::UnsupportedAggregate);
}
Self::check_agg_shape(expr)?;
let inputs = expr.expressions();
let name = expr.fun().name().to_ascii_lowercase();
// COUNT(*) is usually rewritten to COUNT(time-index) before this physical optimizer
// runs, so CountStar is mostly a defensive fallback
if name == "count" && is_count_star_expr(&inputs) {
return Ok(StatsAgg::CountStar);
}
if inputs.len() != 1 {
return Err(RejectReason::UnsupportedAggregate);
}
let Some(column) = inputs[0].as_any().downcast_ref::<Column>() else {
return Err(RejectReason::UnsupportedAggregate);
};
let column_name = column.name().to_string();
let arg_type = inputs[0]
.data_type(self.aggregate_exec.input_schema().as_ref())
.map_err(|_| RejectReason::UnsupportedAggregate)?;
if self.is_tag_column(&column_name) {
return Err(RejectReason::UnsupportedPlan);
}
let is_time_index = column_name == self.region_scan.time_index();
match (name.as_str(), is_time_index) {
("count", true) => Ok(StatsAgg::CountTimeIndex { arg_type }),
("count", false) => Ok(StatsAgg::CountField {
column_name,
arg_type,
}),
("min", true) => Ok(StatsAgg::MinTimeIndex { arg_type }),
("min", false) => Ok(StatsAgg::MinField {
column_name,
arg_type,
}),
("max", true) => Ok(StatsAgg::MaxTimeIndex { arg_type }),
("max", false) => Ok(StatsAgg::MaxField {
column_name,
arg_type,
}),
_ => Err(RejectReason::UnsupportedAggregate),
}
}
pub(super) fn check_agg_shape(expr: &AggregateFunctionExpr) -> Result<(), RejectReason> {
if expr.is_distinct()
|| expr.ignore_nulls()
|| expr.is_reversed()
|| !expr.order_bys().is_empty()
{
return Err(RejectReason::UnsupportedAggregate);
}
Ok(())
}
fn is_tag_column(&self, column_name: &str) -> bool {
self.region_scan
.tag_columns()
.iter()
.any(|tag| tag == column_name)
}
fn stats_unavailable(
&self,
scan_input_stats: Option<&store_api::scan_stats::RegionScanStats>,
) -> bool {
// These cases keep the plan shape eligible in principle, but the scan cannot
// provide metadata that is trustworthy enough for a safe stats-only rewrite.
self.region_scan.has_predicate_without_region()
|| scan_input_stats.is_none()
|| has_partition_expr_mismatch(scan_input_stats)
}
}
#[derive(Debug)]
pub(super) enum RejectReason {
/// The physical plan shape is outside the current safe rewrite envelope.
UnsupportedPlan,
/// At least one aggregate function or aggregate shape is unsupported.
UnsupportedAggregate,
/// The scan cannot provide trustworthy stats for this rewrite attempt.
StatsUnavailable,
/// The aggregate shape is supported, but no file contributes metadata-only results.
NoStatsFiles,
}
impl std::fmt::Display for RejectReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RejectReason::UnsupportedPlan => write!(
f,
"aggregate stats MVP does not support this plan shape safely"
),
RejectReason::UnsupportedAggregate => write!(
f,
"aggregate stats MVP only supports the narrowed no-GROUP-BY count/min/max matrix"
),
RejectReason::StatsUnavailable => write!(
f,
"aggregate stats MVP found a supported shape, but trustworthy statistics are unavailable for it"
),
RejectReason::NoStatsFiles => write!(
f,
"aggregate stats rewrite requires at least one stats-backed file after safety checks"
),
}
}
}
pub(super) fn is_supported_aggregate_name(name: &str) -> bool {
matches!(name.to_ascii_lowercase().as_str(), "min" | "max" | "count")
}
fn is_count_star_expr(inputs: &[std::sync::Arc<dyn PhysicalExpr>]) -> bool {
match inputs {
// Keep the legacy empty-input shape as a compatibility fallback.
[] => true,
[arg] => arg
.as_any()
.downcast_ref::<Literal>()
.is_some_and(|lit| lit.value() == &COUNT_STAR_EXPANSION),
_ => false,
}
}

View File

@@ -0,0 +1,399 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_function::aggrs::aggr_wrapper::StateWrapper;
use common_time::Timestamp;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion_common::{Result as DfResult, ScalarValue};
use datatypes::arrow::datatypes::DataType;
use datatypes::data_type::ConcreteDataType;
use datatypes::value::Value;
use store_api::scan_stats::RegionScanStats;
use super::StatsAgg;
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum FileStatsRequirement {
FileExactRowCount,
FileTimeRange,
RowGroupMinMax,
RowGroupNullCount,
}
/// Splits scan input files into two buckets for one aggregate rewrite path.
///
/// `stats_file_ordinals` and `scan_file_ordinals` store
/// `RegionScanFileStats::file_ordinal`, not indexes into
/// `RegionScanStats::files`. The optimizer later uses these ordinals to decide
/// which physical files can be answered from metadata and which still need a
/// real scan.
///
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct FileSplit<T> {
/// File ordinals whose stats are sufficient to contribute to the rewrite.
pub stats_file_ordinals: Vec<usize>,
/// File ordinals that still need to be scanned because required stats are missing.
pub scan_file_ordinals: Vec<usize>,
/// Aggregate contribution computed from `stats_file_ordinals` only.
pub stats: T,
}
impl<T> FileSplit<T> {
fn has_stats_files(&self) -> bool {
!self.stats_file_ordinals.is_empty()
}
}
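// Illustrative sketch (not part of this change) of the ordinal bookkeeping described
// above: the vectors carry `file_ordinal` values rather than positions inside
// `RegionScanStats::files`, so e.g. file 2 can be answered from metadata while file 5
// still needs a real scan, even though this split only holds two entries.
#[allow(dead_code)]
fn example_ordinal_split() -> FileSplit<usize> {
    FileSplit {
        stats_file_ordinals: vec![2],
        scan_file_ordinals: vec![5],
        stats: 42, // contribution computed from file 2's metadata only
    }
}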
#[derive(Debug, Clone, PartialEq, Eq, Default)]
/// File-level time bounds collected from metadata-only files.
pub(super) struct TimeBounds {
pub min: Option<Timestamp>,
pub max: Option<Timestamp>,
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
/// Field value bounds collected from metadata-only files.
pub(super) struct ValueBounds {
pub min: Option<Value>,
pub max: Option<Value>,
}
pub(super) type CountStarFileSplit = FileSplit<usize>;
pub(super) type FieldCountFileSplit = FileSplit<usize>;
pub(super) type TimeFileSplit = FileSplit<TimeBounds>;
pub(super) type FieldMinMaxFileSplit = FileSplit<ValueBounds>;
pub(super) trait StatsAggExt {
fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool;
}
impl StatsAggExt for StatsAgg {
fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool {
match self {
StatsAgg::CountStar => split_count_star_files(scan_input_stats).has_stats_files(),
StatsAgg::CountField { column_name, .. } => {
split_count_field_files(scan_input_stats, column_name).has_stats_files()
}
StatsAgg::CountTimeIndex { .. }
| StatsAgg::MinTimeIndex { .. }
| StatsAgg::MaxTimeIndex { .. } => split_time_files(scan_input_stats).has_stats_files(),
StatsAgg::MinField { column_name, .. } | StatsAgg::MaxField { column_name, .. } => {
split_min_max_field_files(scan_input_stats, column_name).has_stats_files()
}
}
}
}
#[cfg(test)]
impl StatsAgg {
pub(super) fn file_stats_requirement(&self) -> FileStatsRequirement {
match self {
StatsAgg::CountStar => FileStatsRequirement::FileExactRowCount,
StatsAgg::CountField { .. } => FileStatsRequirement::RowGroupNullCount,
StatsAgg::CountTimeIndex { .. }
| StatsAgg::MinTimeIndex { .. }
| StatsAgg::MaxTimeIndex { .. } => FileStatsRequirement::FileTimeRange,
StatsAgg::MinField { .. } | StatsAgg::MaxField { .. } => {
FileStatsRequirement::RowGroupMinMax
}
}
}
}
pub(super) fn has_partition_expr_mismatch(scan_input_stats: Option<&RegionScanStats>) -> bool {
scan_input_stats
.map(|stats| {
stats
.files
.iter()
.any(|file| !file.partition_expr_matches_region)
})
.unwrap_or(false)
}
pub(super) fn split_count_star_files(scan_input_stats: &RegionScanStats) -> CountStarFileSplit {
scan_input_stats.files.iter().fold(
FileSplit {
stats_file_ordinals: Vec::new(),
scan_file_ordinals: Vec::new(),
stats: 0,
},
|mut split, file| {
if file.partition_expr_matches_region
&& let Some(num_rows) = file.exact_num_rows
{
split.stats_file_ordinals.push(file.file_ordinal);
split.stats += num_rows;
return split;
}
split.scan_file_ordinals.push(file.file_ordinal);
split
},
)
}
pub(super) fn split_time_files(scan_input_stats: &RegionScanStats) -> TimeFileSplit {
scan_input_stats.files.iter().fold(
FileSplit {
stats_file_ordinals: Vec::new(),
scan_file_ordinals: Vec::new(),
stats: TimeBounds::default(),
},
|mut split, file| {
if file.partition_expr_matches_region
&& let Some((min_ts, max_ts)) = file.time_range
{
split.stats_file_ordinals.push(file.file_ordinal);
split.stats.min = Some(match split.stats.min {
Some(current) => Timestamp::min(current, min_ts),
None => min_ts,
});
split.stats.max = Some(match split.stats.max {
Some(current) => Timestamp::max(current, max_ts),
None => max_ts,
});
return split;
}
split.scan_file_ordinals.push(file.file_ordinal);
split
},
)
}
pub(super) fn split_count_field_files(
scan_input_stats: &RegionScanStats,
column_name: &str,
) -> FieldCountFileSplit {
scan_input_stats.files.iter().fold(
FileSplit {
stats_file_ordinals: Vec::new(),
scan_file_ordinals: Vec::new(),
stats: 0,
},
|mut split, file| {
if file.partition_expr_matches_region
&& let Some(non_null_rows) = file
.field_stats
.get(column_name)
.and_then(|stats| stats.exact_non_null_rows)
{
split.stats_file_ordinals.push(file.file_ordinal);
split.stats += non_null_rows;
return split;
}
split.scan_file_ordinals.push(file.file_ordinal);
split
},
)
}
pub(super) fn split_min_max_field_files(
scan_input_stats: &RegionScanStats,
column_name: &str,
) -> FieldMinMaxFileSplit {
scan_input_stats.files.iter().fold(
FileSplit {
stats_file_ordinals: Vec::new(),
scan_file_ordinals: Vec::new(),
stats: ValueBounds::default(),
},
|mut split, file| {
if file.partition_expr_matches_region
&& let Some(stats) = file.field_stats.get(column_name)
&& let (Some(min_value), Some(max_value)) =
(stats.min_value.clone(), stats.max_value.clone())
{
split.stats_file_ordinals.push(file.file_ordinal);
split.stats.min = Some(match split.stats.min {
Some(current) => Value::min(current, min_value),
None => min_value,
});
split.stats.max = Some(match split.stats.max {
Some(current) => Value::max(current, max_value),
None => max_value,
});
return split;
}
split.scan_file_ordinals.push(file.file_ordinal);
split
},
)
}
// These helpers are implemented for the upcoming mixed stats-plus-scan rewrite in task 3.
#[allow(dead_code)]
pub(super) fn partial_state_from_stats(
aggregate: &StatsAgg,
scan_input_stats: &RegionScanStats,
) -> DfResult<Option<ScalarValue>> {
match aggregate {
StatsAgg::CountStar => {
let split = split_count_star_files(scan_input_stats);
if !split.has_stats_files() {
return Ok(None);
}
let wrapper = StateWrapper::new((*count_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
&[],
vec![ScalarValue::Int64(Some(split.stats as i64))],
)
.map(Some)
}
StatsAgg::CountField {
arg_type,
column_name,
} => {
let split = split_count_field_files(scan_input_stats, column_name);
if !split.has_stats_files() {
return Ok(None);
}
let wrapper = StateWrapper::new((*count_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![ScalarValue::Int64(Some(split.stats as i64))],
)
.map(Some)
}
StatsAgg::CountTimeIndex { arg_type } => {
let split = split_count_star_files(scan_input_stats);
if !split.has_stats_files() {
return Ok(None);
}
let wrapper = StateWrapper::new((*count_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![ScalarValue::Int64(Some(split.stats as i64))],
)
.map(Some)
}
StatsAgg::MinField {
arg_type,
column_name,
} => {
let split = split_min_max_field_files(scan_input_stats, column_name);
let Some(value) = split.stats.min.as_ref() else {
return Ok(None);
};
let wrapper = StateWrapper::new((*min_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![stats_value_scalar(value, arg_type)?],
)
.map(Some)
}
StatsAgg::MinTimeIndex { arg_type } => {
let split = split_time_files(scan_input_stats);
let Some(timestamp) = split.stats.min else {
return Ok(None);
};
let wrapper = StateWrapper::new((*min_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![timestamp_scalar_value(&timestamp, arg_type)?],
)
.map(Some)
}
StatsAgg::MaxField {
arg_type,
column_name,
} => {
let split = split_min_max_field_files(scan_input_stats, column_name);
let Some(value) = split.stats.max.as_ref() else {
return Ok(None);
};
let wrapper = StateWrapper::new((*max_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![stats_value_scalar(value, arg_type)?],
)
.map(Some)
}
StatsAgg::MaxTimeIndex { arg_type } => {
let split = split_time_files(scan_input_stats);
let Some(timestamp) = split.stats.max else {
return Ok(None);
};
let wrapper = StateWrapper::new((*max_udaf()).clone())?;
wrapper
.value_from_custom_state_fields(
std::slice::from_ref(arg_type),
vec![timestamp_scalar_value(&timestamp, arg_type)?],
)
.map(Some)
}
}
}
#[allow(dead_code)]
fn timestamp_scalar_value(timestamp: &Timestamp, arg_type: &DataType) -> DfResult<ScalarValue> {
match arg_type {
DataType::Timestamp(unit, tz) => {
let converted = timestamp.convert_to((*unit).into()).ok_or_else(|| {
datafusion_common::DataFusionError::Internal(format!(
"failed to convert timestamp {timestamp:?} to {unit:?}"
))
})?;
Ok(match unit {
datatypes::arrow::datatypes::TimeUnit::Second => {
ScalarValue::TimestampSecond(Some(converted.value()), tz.clone())
}
datatypes::arrow::datatypes::TimeUnit::Millisecond => {
ScalarValue::TimestampMillisecond(Some(converted.value()), tz.clone())
}
datatypes::arrow::datatypes::TimeUnit::Microsecond => {
ScalarValue::TimestampMicrosecond(Some(converted.value()), tz.clone())
}
datatypes::arrow::datatypes::TimeUnit::Nanosecond => {
ScalarValue::TimestampNanosecond(Some(converted.value()), tz.clone())
}
})
}
_ => Err(datafusion_common::DataFusionError::Internal(format!(
"expected timestamp arg type, got {arg_type:?}"
))),
}
}
#[allow(dead_code)]
fn stats_value_scalar(value: &Value, arg_type: &DataType) -> DfResult<ScalarValue> {
let output_type = ConcreteDataType::try_from(arg_type).map_err(|err| {
datafusion_common::DataFusionError::Internal(format!(
"failed to convert arrow type {arg_type:?} to concrete type: {err}"
))
})?;
value.try_to_scalar_value(&output_type).map_err(|err| {
datafusion_common::DataFusionError::Internal(format!(
"failed to convert stats value {value:?} to scalar for {arg_type:?}: {err}"
))
})
}

View File

@@ -0,0 +1,705 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use arrow::array::{Int64Array, TimestampMillisecondArray};
use arrow::datatypes::{Field, Schema};
use common_time::Timestamp;
use common_time::timestamp::TimeUnit as TimestampUnit;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::{Column, Literal, PhysicalSortExpr};
use datatypes::arrow::array::AsArray;
use datatypes::arrow::datatypes::{DataType, TimeUnit};
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema as GreptimeSchema};
use datatypes::value::Value;
use session::context::QueryContext;
use store_api::scan_stats::{
RegionScanColumnStats as RegionScanColumnInputStats,
RegionScanFileStats as RegionScanFileInputStats, RegionScanStats as RegionScanInputStats,
};
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::test_util::EmptyTable;
use super::StatsAgg;
use super::check::{RejectReason, RewriteCheck, is_supported_aggregate_name};
use super::split::{
FileStatsRequirement, StatsAggExt, has_partition_expr_mismatch, partial_state_from_stats,
split_count_field_files, split_count_star_files, split_min_max_field_files, split_time_files,
};
use crate::parser::QueryLanguageParser;
use crate::tests::new_query_engine_with_table;
fn test_timestamp(value: i64) -> Timestamp {
Timestamp::new(value, TimestampUnit::Millisecond)
}
fn field_stats(
exact_non_null_rows: Option<usize>,
min_value: Option<Value>,
max_value: Option<Value>,
) -> HashMap<String, RegionScanColumnInputStats> {
HashMap::from([(
"value".to_string(),
RegionScanColumnInputStats {
min_value,
max_value,
exact_non_null_rows,
},
)])
}
fn build_test_aggr_expr(
distinct: bool,
ignore_nulls: bool,
order_by: bool,
) -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int64,
true,
)]));
let args = vec![Arc::new(Column::new("value", 0)) as Arc<dyn PhysicalExpr>];
let mut builder = AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args)
.schema(schema)
.alias("count(value)");
if distinct {
builder = builder.with_distinct(true);
}
if ignore_nulls {
builder = builder.ignore_nulls();
}
if order_by {
builder = builder.order_by(vec![PhysicalSortExpr {
expr: Arc::new(Column::new("value", 0)),
options: Default::default(),
}]);
}
builder.build()
}
fn build_count_star_aggr_expr() -> datafusion_common::Result<AggregateFunctionExpr> {
let schema = Arc::new(Schema::empty());
let args = vec![Arc::new(Literal::new(COUNT_STAR_EXPANSION)) as Arc<dyn PhysicalExpr>];
AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args)
.schema(schema)
.alias("count(*)")
.build()
}
fn new_test_engine() -> crate::QueryEngineRef {
let columns = vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true),
];
let schema = Arc::new(GreptimeSchema::new(columns));
let table_meta = TableMetaBuilder::empty()
.schema(schema)
.primary_key_indices(vec![0])
.value_indices(vec![1])
.next_column_id(1024)
.build()
.unwrap();
let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap();
let table = EmptyTable::from_table_info(&table_info);
new_query_engine_with_table(table)
}
async fn parse_sql_to_plan(sql: &str) -> LogicalPlan {
let query_ctx = QueryContext::arc();
let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
new_test_engine()
.planner()
.plan(&stmt, query_ctx)
.await
.unwrap()
}
#[test]
fn test_supported_aggregate_names() {
assert!(is_supported_aggregate_name("min"));
assert!(is_supported_aggregate_name("MAX"));
assert!(is_supported_aggregate_name("count"));
assert!(!is_supported_aggregate_name("sum"));
assert!(!is_supported_aggregate_name("avg(value)"));
}
#[test]
fn test_file_stats_requirement_matrix() {
let count_star = StatsAgg::CountStar;
assert_eq!(
count_star.file_stats_requirement(),
FileStatsRequirement::FileExactRowCount
);
let time_min = StatsAgg::MinTimeIndex {
arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
};
assert_eq!(
time_min.file_stats_requirement(),
FileStatsRequirement::FileTimeRange
);
let field_count = StatsAgg::CountField {
column_name: "value".to_string(),
arg_type: DataType::Int64,
};
assert_eq!(
field_count.file_stats_requirement(),
FileStatsRequirement::RowGroupNullCount
);
}
#[test]
fn test_count_star_file_stats_eligibility() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: None,
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(42),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: false,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(7),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
assert!(StatsAgg::CountStar.has_stats_files(&stats));
}
#[test]
fn test_split_count_star_files() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: None,
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: HashMap::new(),
partition_expr_matches_region: false,
},
],
};
let split = split_count_star_files(&stats);
assert_eq!(split.stats_file_ordinals, vec![0]);
assert_eq!(split.scan_file_ordinals, vec![1, 2]);
assert_eq!(split.stats, 3);
}
#[test]
fn test_split_count_star_files_keeps_zero_row_files_stats_eligible() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(0),
time_range: None,
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: None,
time_range: None,
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
let split = split_count_star_files(&stats);
assert_eq!(split.stats_file_ordinals, vec![0]);
assert_eq!(split.scan_file_ordinals, vec![1]);
assert_eq!(split.stats, 0);
}
#[test]
fn test_supported_aggregate_rejects_distinct_ignore_nulls_and_order_by_shapes() {
let distinct = build_test_aggr_expr(true, false, false).unwrap();
let ignore_nulls = build_test_aggr_expr(false, true, false).unwrap();
let order_by = build_test_aggr_expr(false, false, true).unwrap();
assert!(matches!(
RewriteCheck::check_agg_shape(&distinct),
Err(RejectReason::UnsupportedAggregate)
));
assert!(matches!(
RewriteCheck::check_agg_shape(&ignore_nulls),
Err(RejectReason::UnsupportedAggregate)
));
assert!(matches!(
RewriteCheck::check_agg_shape(&order_by),
Err(RejectReason::UnsupportedAggregate)
));
}
#[test]
fn test_count_star_expansion_is_treated_as_count_star() {
let expr = build_count_star_aggr_expr().unwrap();
assert_eq!(expr.fun().name(), "count");
assert_eq!(expr.expressions().len(), 1);
assert!(
expr.expressions()[0]
.as_any()
.downcast_ref::<Literal>()
.is_some_and(|lit| lit.value() == &COUNT_STAR_EXPANSION)
);
}
#[tokio::test]
async fn test_sql_count_star_is_planned_with_count_star_expansion() {
let plan = parse_sql_to_plan("select count(*) from test").await;
let LogicalPlan::Projection(projection) = plan else {
panic!("expected projection over aggregate plan, got {plan:?}");
};
let LogicalPlan::Aggregate(aggregate) = projection.input.as_ref() else {
panic!("expected aggregate input, got {:?}", projection.input);
};
assert_eq!(aggregate.aggr_expr.len(), 1);
let Expr::AggregateFunction(AggregateFunction {
func,
params: AggregateFunctionParams { args, .. },
}) = &aggregate.aggr_expr[0]
else {
panic!(
"expected aggregate function expr, got {:?}",
aggregate.aggr_expr[0]
);
};
assert_eq!(func.name(), "count");
assert_eq!(args.len(), 1);
assert!(matches!(
&args[0],
Expr::Literal(value, _) if value == &COUNT_STAR_EXPANSION
));
}
#[test]
fn test_partition_expr_mismatch_detection() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(1),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(2),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: false,
},
],
};
assert!(has_partition_expr_mismatch(Some(&stats)));
}
#[test]
fn test_time_file_stats_eligibility() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: None,
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: None,
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
assert!(
StatsAgg::MinTimeIndex {
arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
}
.has_stats_files(&stats)
);
}
#[test]
fn test_split_time_files() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(50), test_timestamp(70))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: None,
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 3,
exact_num_rows: Some(6),
time_range: Some((test_timestamp(5), test_timestamp(100))),
field_stats: HashMap::new(),
partition_expr_matches_region: false,
},
],
};
let split = split_time_files(&stats);
assert_eq!(split.stats_file_ordinals, vec![0, 2]);
assert_eq!(split.scan_file_ordinals, vec![1, 3]);
assert_eq!(split.stats.min, Some(test_timestamp(10)));
assert_eq!(split.stats.max, Some(test_timestamp(70)));
}
#[test]
fn test_count_field_file_stats_eligibility() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: field_stats(Some(4), Some(Value::Int64(1)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
],
};
assert!(
StatsAgg::CountField {
column_name: "value".to_string(),
arg_type: DataType::Int64,
}
.has_stats_files(&stats)
);
}
#[test]
fn test_split_count_field_files() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(8))),
partition_expr_matches_region: true,
},
],
};
let split = split_count_field_files(&stats, "value");
assert_eq!(split.stats_file_ordinals, vec![0, 2]);
assert_eq!(split.scan_file_ordinals, vec![1]);
assert_eq!(split.stats, 6);
}
#[test]
fn test_min_max_field_stats_eligibility() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: false,
},
],
};
assert!(
StatsAgg::MinField {
column_name: "value".to_string(),
arg_type: DataType::Int64,
}
.has_stats_files(&stats)
);
}
#[test]
fn test_split_min_max_field_files() {
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(4)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 2,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(50), test_timestamp(60))),
field_stats: field_stats(Some(4), Some(Value::Int64(1)), Some(Value::Int64(7))),
partition_expr_matches_region: true,
},
],
};
let split = split_min_max_field_files(&stats, "value");
assert_eq!(split.stats_file_ordinals, vec![0, 2]);
assert_eq!(split.scan_file_ordinals, vec![1]);
assert_eq!(split.stats.min, Some(Value::Int64(1)));
assert_eq!(split.stats.max, Some(Value::Int64(9)));
}
#[test]
fn test_partial_state_from_stats_count_star() {
let aggregate = StatsAgg::CountStar;
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
let value = partial_state_from_stats(&aggregate, &stats)
.unwrap()
.unwrap();
let array = value.to_array().unwrap();
let struct_array = array.as_struct();
let count_values = struct_array
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(count_values.value(0), 7);
}
#[test]
fn test_partial_state_from_stats_count_field() {
let aggregate = StatsAgg::CountField {
column_name: "value".to_string(),
arg_type: DataType::Int64,
};
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
],
};
let value = partial_state_from_stats(&aggregate, &stats)
.unwrap()
.unwrap();
let array = value.to_array().unwrap();
let struct_array = array.as_struct();
let count_values = struct_array
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(count_values.value(0), 6);
}
#[test]
fn test_partial_state_from_stats_min_time() {
let aggregate = StatsAgg::MinTimeIndex {
arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
};
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(50), test_timestamp(70))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(5),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: HashMap::new(),
partition_expr_matches_region: true,
},
],
};
let value = partial_state_from_stats(&aggregate, &stats)
.unwrap()
.unwrap();
let array = value.to_array().unwrap();
let struct_array = array.as_struct();
let ts_values = struct_array
.column(0)
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap();
assert_eq!(ts_values.value(0), 10);
}
#[test]
fn test_partial_state_from_stats_max_field() {
let aggregate = StatsAgg::MaxField {
column_name: "value".to_string(),
arg_type: DataType::Int64,
};
let stats = RegionScanInputStats {
files: vec![
RegionScanFileInputStats {
file_ordinal: 0,
exact_num_rows: Some(3),
time_range: Some((test_timestamp(10), test_timestamp(20))),
field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
partition_expr_matches_region: true,
},
RegionScanFileInputStats {
file_ordinal: 1,
exact_num_rows: Some(4),
time_range: Some((test_timestamp(30), test_timestamp(40))),
field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
partition_expr_matches_region: true,
},
],
};
let value = partial_state_from_stats(&aggregate, &stats)
.unwrap()
.unwrap();
let array = value.to_array().unwrap();
let struct_array = array.as_struct();
let max_values = struct_array
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(max_values.value(0), 9);
}

View File

@@ -1,177 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion_common::Result as DfResult;
use datafusion_common::tree_node::{Transformed, TreeNode};
use table::table::scan::RegionScanExec;
#[derive(Debug)]
pub struct AggregateStats;
impl PhysicalOptimizerRule for AggregateStats {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> DfResult<Arc<dyn ExecutionPlan>> {
Self::do_optimize(plan)
}
fn name(&self) -> &str {
"aggregate_stats"
}
fn schema_check(&self) -> bool {
true
}
}
impl AggregateStats {
fn do_optimize(plan: Arc<dyn ExecutionPlan>) -> DfResult<Arc<dyn ExecutionPlan>> {
let result = plan
.transform_down(|plan| {
let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(plan));
};
let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else {
return Ok(Transformed::no(plan));
};
let eligibility = AggregateStatsEligibility::new(aggregate_exec, region_scan);
if let Some(reason) = eligibility.ineligible_reason() {
debug!("Skip aggregate stats optimization: {reason}");
return Ok(Transformed::no(plan));
}
// TODO(ruihang): implement mixed stats-plus-scan rewrite in follow-up tasks.
Ok(Transformed::no(plan))
})?
.data;
Ok(result)
}
fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> {
let child = aggregate_exec.children().into_iter().next()?;
child.as_any().downcast_ref::<RegionScanExec>()
}
}
#[derive(Debug)]
struct AggregateStatsEligibility<'a> {
aggregate_exec: &'a AggregateExec,
region_scan: &'a RegionScanExec,
}
impl<'a> AggregateStatsEligibility<'a> {
fn new(aggregate_exec: &'a AggregateExec, region_scan: &'a RegionScanExec) -> Self {
Self {
aggregate_exec,
region_scan,
}
}
fn ineligible_reason(&self) -> Option<EligibilityRejection> {
if !self.region_scan.append_mode() {
return Some(EligibilityRejection::NonAppendOnly);
}
if !self.aggregate_exec.group_expr().is_empty() {
return Some(EligibilityRejection::GroupedAggregate);
}
if !self.has_supported_aggregates() {
return Some(EligibilityRejection::UnsupportedAggregate);
}
if !self.has_stats_eligible_candidates() {
return Some(EligibilityRejection::NoStatsEligibleFiles);
}
None
}
fn has_supported_aggregates(&self) -> bool {
let aggr_exprs = self.aggregate_exec.aggr_expr();
!aggr_exprs.is_empty()
&& aggr_exprs
.iter()
.all(|expr| is_supported_aggregate_name(expr.name()))
}
fn has_stats_eligible_candidates(&self) -> bool {
// TODO(ruihang): replace this scaffold with per-file stats classification.
self.region_scan.total_rows() > 0
}
}
#[derive(Debug)]
enum EligibilityRejection {
NonAppendOnly,
GroupedAggregate,
UnsupportedAggregate,
NoStatsEligibleFiles,
}
impl std::fmt::Display for EligibilityRejection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EligibilityRejection::NonAppendOnly => {
write!(f, "aggregate stats MVP only supports append-only scans")
}
EligibilityRejection::GroupedAggregate => {
write!(f, "aggregate stats MVP does not support GROUP BY yet")
}
EligibilityRejection::UnsupportedAggregate => {
write!(
f,
"aggregate stats MVP only supports min/max/count aggregates"
)
}
EligibilityRejection::NoStatsEligibleFiles => {
write!(
f,
"aggregate stats rewrite requires at least one stats-eligible file"
)
}
}
}
}
fn is_supported_aggregate_name(name: &str) -> bool {
let normalized = name.split('(').next().unwrap_or(name).to_ascii_lowercase();
matches!(normalized.as_str(), "min" | "max" | "count")
}
#[cfg(test)]
mod tests {
use super::is_supported_aggregate_name;
#[test]
fn test_supported_aggregate_names() {
assert!(is_supported_aggregate_name("min"));
assert!(is_supported_aggregate_name("max(value)"));
assert!(is_supported_aggregate_name("count(*)"));
assert!(!is_supported_aggregate_name("sum(value)"));
assert!(!is_supported_aggregate_name("avg(value)"));
}
}

View File

@@ -59,7 +59,7 @@ use crate::dist_plan::{
};
use crate::metrics::{QUERY_MEMORY_POOL_REJECTED_TOTAL, QUERY_MEMORY_POOL_USAGE_BYTES};
use crate::optimizer::ExtensionAnalyzerRule;
use crate::optimizer::aggregate_stats::AggregateStats;
use crate::optimizer::aggr_stats::AggregateStats;
use crate::optimizer::constant_term::MatchesConstantTermOptimizer;
use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use crate::optimizer::parallelize_scan::ParallelizeScan;

View File

@@ -26,6 +26,7 @@ pub mod mito_engine_options;
pub mod path_utils;
pub mod region_engine;
pub mod region_request;
pub mod scan_stats;
pub mod sst_entry;
pub mod storage;

View File

@@ -37,6 +37,7 @@ use crate::metadata::RegionMetadataRef;
use crate::region_request::{
BatchRegionDdlRequest, RegionCatchupRequest, RegionOpenRequest, RegionRequest,
};
use crate::scan_stats::RegionScanStats;
use crate::storage::{FileId, RegionId, ScanRequest, SequenceNumber};
/// The settable region role state.
@@ -451,6 +452,11 @@ pub trait RegionScanner: Debug + DisplayAs + Send {
/// Checks whether any predicate other than the region partition exprs may be executed in this scanner.
fn has_predicate_without_region(&self) -> bool;
/// Returns file-level scan statistics for the current scanner input when available.
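/// The default implementation returns `Ok(None)`, which opts the scanner out of
/// stats-based aggregate optimizations; scanners that track trusted per-file
/// metadata can override this to surface a `RegionScanStats` snapshot.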
fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
Ok(None)
}
/// Add the given dynamic filter expressions to the predicate of the scanner.
/// Returns a vector of booleans indicating which filter expressions were applied.
/// `true` indicates the filter expression was applied (it will be used by the scanner to prune row groups by their statistics),

View File

@@ -0,0 +1,45 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Trusted, scan-derived statistics exposed by region scanners for higher-level optimizations.
//!
//! These are not raw storage-format statistics. They are conservative summaries that scanner
//! implementations can safely expose for optimizer consumption.
use std::collections::HashMap;
use common_time::Timestamp;
use datatypes::value::Value;
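/// Per-column statistics for one file, keyed by column name in
/// [`RegionScanFileStats::field_stats`].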
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegionScanColumnStats {
pub min_value: Option<Value>,
pub max_value: Option<Value>,
pub exact_non_null_rows: Option<usize>,
}
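/// Statistics describing a single file that feeds the region scan.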
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegionScanFileStats {
/// Stable ordinal identifying this file within one `RegionScanStats` snapshot.
pub file_ordinal: usize,
pub exact_num_rows: Option<usize>,
pub time_range: Option<(Timestamp, Timestamp)>,
pub field_stats: HashMap<String, RegionScanColumnStats>,
pub partition_expr_matches_region: bool,
}
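/// Snapshot of per-file statistics for the files that make up one scan input.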
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegionScanStats {
pub files: Vec<RegionScanFileStats>,
}
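// Illustrative sketch, not part of the original change: one way a scanner
// implementation might assemble a `RegionScanStats` snapshot from per-file
// metadata. The module name, helper, and literal values are hypothetical, and it
// assumes the `Timestamp::new_millisecond` constructor from `common_time`.
#[cfg(test)]
mod example_usage {
use std::collections::HashMap;
use common_time::Timestamp;
use datatypes::value::Value;
use super::*;
// Builds a one-file snapshot carrying an exact row count, a time range, and
// per-column min/max/non-null statistics for the hypothetical "value" column.
fn example_snapshot() -> RegionScanStats {
let mut field_stats = HashMap::new();
field_stats.insert(
"value".to_string(),
RegionScanColumnStats {
min_value: Some(Value::Int64(1)),
max_value: Some(Value::Int64(9)),
exact_non_null_rows: Some(4),
},
);
RegionScanStats {
files: vec![RegionScanFileStats {
file_ordinal: 0,
exact_num_rows: Some(5),
time_range: Some((
Timestamp::new_millisecond(10),
Timestamp::new_millisecond(20),
)),
field_stats,
partition_expr_matches_region: true,
}],
}
}
#[test]
fn build_example_snapshot() {
let snapshot = example_snapshot();
assert_eq!(snapshot.files.len(), 1);
assert!(snapshot.files[0].field_stats.contains_key("value"));
}
}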

View File

@@ -48,6 +48,7 @@ use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
use store_api::region_engine::{
PartitionRange, PrepareRequest, QueryScanContext, RegionScannerRef,
};
use store_api::scan_stats::RegionScanStats;
use store_api::storage::{ScanRequest, TimeSeriesDistribution};
use crate::table::metrics::StreamMetrics;
@@ -315,6 +316,18 @@ impl RegionScanExec {
self.total_rows
}
pub fn has_predicate_without_region(&self) -> bool {
self.scanner.lock().unwrap().has_predicate_without_region()
}
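/// Fetches trusted per-file scan statistics from the underlying region scanner,
/// mapping any engine error into a DataFusion external error.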
pub fn scan_input_stats(&self) -> DfResult<Option<RegionScanStats>> {
self.scanner
.lock()
.unwrap()
.scan_input_stats()
.map_err(|err| DataFusionError::External(err.into()))
}
pub fn with_distinguish_partition_range(&self, distinguish_partition_range: bool) {
let mut scanner = self.scanner.lock().unwrap();
// set distinguish_partition_range won't fail