diff --git a/src/common/function/src/aggrs/aggr_wrapper.rs b/src/common/function/src/aggrs/aggr_wrapper.rs
index 6242ab9454..7b67e11fba 100644
--- a/src/common/function/src/aggrs/aggr_wrapper.rs
+++ b/src/common/function/src/aggrs/aggr_wrapper.rs
@@ -330,6 +330,39 @@ impl StateWrapper {
         acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
         Ok(acc_args)
     }
+
+    /// Builds a state scalar from explicit state-field values.
+    ///
+    /// The caller must provide one scalar per state field in the wrapper's state layout.
+    /// This method only validates the wrapper's current state type and assembles the
+    /// final struct scalar from those explicit field values.
+    pub fn value_from_custom_state_fields(
+        &self,
+        arg_types: &[DataType],
+        state_values: Vec<ScalarValue>,
+    ) -> datafusion_common::Result<ScalarValue> {
+        let DataType::Struct(fields) = self.return_type(arg_types)? else {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected struct state type for {}, got non-struct return type",
+                self.name()
+            )));
+        };
+        if fields.len() != state_values.len() {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected {} state fields for {}, got {}",
+                fields.len(),
+                self.name(),
+                state_values.len()
+            )));
+        }
+
+        let arrays = state_values
+            .into_iter()
+            .map(|value| value.to_array())
+            .collect::<datafusion_common::Result<Vec<_>>>()?;
+        let struct_array = build_state_struct_array(&fields, arrays)?;
+        Ok(ScalarValue::Struct(Arc::new(struct_array)))
+    }
 }
 
 impl AggregateUDFImpl for StateWrapper {
@@ -472,13 +505,59 @@ impl AggregateUDFImpl for StateWrapper {
         };
 
         let array = ret.to_array().ok()?;
-        let struct_array = StructArray::new(fields.clone(), vec![array], None);
+        let struct_array = build_state_struct_array(&fields, vec![array]).ok()?;
         let ret = ScalarValue::Struct(Arc::new(struct_array));
         Some(ret)
     }
 }
 
+fn build_state_struct_array(
+    fields: &Fields,
+    arrays: Vec<ArrayRef>,
+) -> datafusion_common::Result<StructArray> {
+    let array_type = arrays
+        .iter()
+        .map(|array| array.data_type().clone())
+        .collect::<Vec<_>>();
+    let expected_type = fields
+        .iter()
+        .map(|field| field.data_type().clone())
+        .collect::<Vec<_>>();
+    if array_type != expected_type {
+        // Keep this fallback intentionally lenient.
+        //
+        // Historically the wrapper path has tolerated state-schema drift as long as the
+        // physical state columns remain positionally compatible. This shows up most clearly
+        // in order-sensitive aggregates such as first_value/last_value, where DataFusion-side
+        // state metadata and the arrays we need to wrap may not line up exactly. The merge
+        // path consumes state columns by position, not by field metadata, so preserving a
+        // struct wrapper here is more compatible than failing eagerly on field/type mismatch.
+        debug!(
+            "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
+            fields.len(),
+            arrays.len(),
+            fields,
+            array_type,
+        );
+        let guess_schema = arrays
+            .iter()
+            .enumerate()
+            .map(|(index, array)| {
+                Field::new(
+                    format!("col_{index}[mismatch_state]").as_str(),
+                    array.data_type().clone(),
+                    true,
+                )
+            })
+            .collect::<Fields>();
+        return StructArray::try_new(guess_schema, arrays, None)
+            .map_err(|err| datafusion_common::DataFusionError::ArrowError(Box::new(err), None));
+    }
+
+    StructArray::try_new(fields.clone(), arrays, None)
+        .map_err(|err| datafusion_common::DataFusionError::ArrowError(Box::new(err), None))
+}
+
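As a quick illustration of the lenient fallback (not part of the patch; it assumes this module's imports plus a hypothetical avg-style declared state of `[sum: Float64, count: UInt64]`), a positional mismatch is re-labelled rather than rejected:

    let declared = Fields::from(vec![
        Field::new("sum", DataType::Float64, true),
        Field::new("count", DataType::UInt64, true),
    ]);
    // The second column drifts to Float64, so the strict path cannot apply;
    // the fallback keeps both columns and renames them positionally.
    let arrays: Vec<ArrayRef> = vec![
        Arc::new(Float64Array::from(vec![1.5])),
        Arc::new(Float64Array::from(vec![2.0])),
    ];
    let wrapped = build_state_struct_array(&declared, arrays)?;
    assert_eq!(
        wrapped.column_names(),
        vec!["col_0[mismatch_state]", "col_1[mismatch_state]"]
    );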
 /// The wrapper's input is the same as the original aggregate function's input,
 /// and the output is the state function's output.
 #[derive(Debug)]
@@ -510,42 +589,9 @@ impl StateGroupsAccum {
     }
 
     fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
-        let array_type = arrays
-            .iter()
-            .map(|array| array.data_type().clone())
-            .collect::<Vec<_>>();
-        let expected_type = self
-            .state_fields
-            .iter()
-            .map(|field| field.data_type().clone())
-            .collect::<Vec<_>>();
-        if array_type != expected_type {
-            debug!(
-                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
-                self.state_fields.len(),
-                arrays.len(),
-                self.state_fields,
-                array_type,
-            );
-            let guess_schema = arrays
-                .iter()
-                .enumerate()
-                .map(|(index, array)| {
-                    Field::new(
-                        format!("col_{index}[mismatch_state]").as_str(),
-                        array.data_type().clone(),
-                        true,
-                    )
-                })
-                .collect::<Fields>();
-            let array = StructArray::try_new(guess_schema, arrays, None)?;
-            return Ok(Arc::new(array));
-        }
-
-        Ok(Arc::new(StructArray::try_new(
-            self.state_fields.clone(),
+        Ok(Arc::new(build_state_struct_array(
+            &self.state_fields,
             arrays,
-            None,
         )?))
     }
 }
@@ -621,44 +667,11 @@ impl Accumulator for StateAccum {
 
     fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
         let state = self.inner.state()?;
-        let array = state
+        let arrays = state
             .iter()
             .map(|s| s.to_array())
            .collect::<Result<Vec<_>, _>>()?;
-        let array_type = array
-            .iter()
-            .map(|a| a.data_type().clone())
-            .collect::<Vec<_>>();
-        let expected_type: Vec<_> = self
-            .state_fields
-            .iter()
-            .map(|f| f.data_type().clone())
-            .collect();
-        if array_type != expected_type {
-            debug!(
-                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
-                self.state_fields.len(),
-                array.len(),
-                self.state_fields,
-                array_type,
-            );
-            let guess_schema = array
-                .iter()
-                .enumerate()
-                .map(|(index, array)| {
-                    Field::new(
-                        format!("col_{index}[mismatch_state]").as_str(),
-                        array.data_type().clone(),
-                        true,
-                    )
-                })
-                .collect::<Fields>();
-            let arr = StructArray::try_new(guess_schema, array, None)?;
-
-            return Ok(ScalarValue::Struct(Arc::new(arr)));
-        }
-
-        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
+        let struct_array = build_state_struct_array(&self.state_fields, arrays)?;
         Ok(ScalarValue::Struct(Arc::new(struct_array)))
     }
 
@@ -860,7 +873,10 @@ impl Accumulator for MergeAccum {
                 "State fields mismatch, expected: {:?}, got: {:?}",
                 self.state_fields, fields
             );
-            // state fields mismatch might be acceptable by datafusion, continue
+            // Intentionally continue here for compatibility with the wrapper's historical
+            // behavior: downstream merge logic uses the struct columns positionally, and some
+            // DataFusion/order-sensitive aggregate paths can produce equivalent state payloads
+            // whose field metadata does not exactly match our locally expected schema.
         }
         // now fields should be the same, so we can merge the batch
diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs
index de3a77df6b..516babb114 100644
--- a/src/common/function/src/aggrs/aggr_wrapper/tests.rs
+++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs
@@ -28,6 +28,7 @@ use datafusion::datasource::DefaultTableSource;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
 use datafusion::functions_aggregate::average::avg_udaf;
 use datafusion::functions_aggregate::count::count_udaf;
+use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
 use datafusion::optimizer::AnalyzerRule;
 use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
@@ -291,6 +292,50 @@ fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
     state_wrapper.create_groups_accumulator(acc_args).unwrap()
 }
 
+fn test_state_scalar_for_type(data_type: &DataType) -> ScalarValue {
+    match data_type {
+        DataType::Float64 => ScalarValue::Float64(Some(1.5)),
+        DataType::UInt64 => ScalarValue::UInt64(Some(2)),
+        DataType::Int64 => ScalarValue::Int64(Some(3)),
+        _ => panic!("unsupported test data type: {data_type:?}"),
+    }
+}
+
+#[test]
+fn test_value_from_custom_state_fields_single_field() {
+    let wrapper = StateWrapper::new((*max_udaf()).clone()).unwrap();
+    let value = wrapper
+        .value_from_custom_state_fields(&[DataType::Int64], vec![ScalarValue::Int64(Some(7))])
+        .unwrap();
+
+    let ScalarValue::Struct(array) = value else {
+        panic!("expected struct state")
+    };
+    assert_eq!(1, array.columns().len());
+    assert_eq!(DataType::Int64, array.column(0).data_type().clone());
+}
+
+#[test]
+fn test_value_from_custom_state_fields_multi_field() {
+    let wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap();
+    let DataType::Struct(fields) = wrapper.return_type(&[DataType::Float64]).unwrap() else {
+        panic!("expected struct state type")
+    };
+
+    let values = fields
+        .iter()
+        .map(|field| test_state_scalar_for_type(field.data_type()))
+        .collect::<Vec<_>>();
+    let value = wrapper
+        .value_from_custom_state_fields(&[DataType::Float64], values)
+        .unwrap();
+
+    let ScalarValue::Struct(array) = value else {
+        panic!("expected struct state")
+    };
+    assert_eq!(fields.len(), array.columns().len());
+}
+
 #[tokio::test]
 async fn test_sum_udaf() {
     let ctx = SessionContext::new();
diff --git a/src/mito2/src/read.rs b/src/mito2/src/read.rs
index 240a99c247..57507c86f4 100644
--- a/src/mito2/src/read.rs
+++ b/src/mito2/src/read.rs
@@ -28,6 +28,7 @@ pub(crate) mod prune;
 pub(crate) mod pruner;
 pub mod range;
 pub(crate) mod range_cache;
+pub(crate) mod scan_input_stats;
 pub mod scan_region;
 pub mod scan_util;
 pub(crate) mod seq_scan;
diff --git a/src/mito2/src/read/scan_input_stats.rs b/src/mito2/src/read/scan_input_stats.rs
new file mode 100644
index 0000000000..4e43d0cc4d
--- /dev/null
+++ b/src/mito2/src/read/scan_input_stats.rs
@@ -0,0 +1,433 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use api::v1::SemanticType;
+use common_error::ext::{BoxedError, PlainError};
+use common_error::status_code::StatusCode;
+use datafusion_common::pruning::PruningStatistics;
+use datafusion_common::{Column, ScalarValue};
+use datatypes::arrow::array::{Array, AsArray, UInt64Array};
+use datatypes::arrow::compute::{cast, max, max_boolean, max_string, min, min_boolean, min_string};
+use datatypes::arrow::datatypes::{
+    DataType as ArrowDataType, Float32Type, Float64Type, Int32Type, Int64Type,
+    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+    TimestampSecondType, UInt32Type, UInt64Type,
+};
+use datatypes::value::Value;
+use store_api::metadata::RegionMetadata;
+use store_api::scan_stats::{
+    RegionScanColumnStats as RegionScanColumnInputStats,
+    RegionScanFileStats as RegionScanFileInputStats, RegionScanStats as RegionScanInputStats,
+};
+
+use crate::read::scan_region::ScanInput;
+use crate::sst::file::FileHandle;
+use crate::sst::parquet::format::ReadFormat;
+use crate::sst::parquet::stats::RowGroupPruningStats;
+
+pub(crate) fn build_scan_input_stats(
+    input: &ScanInput,
+    metadata: &RegionMetadata,
+) -> std::result::Result<RegionScanInputStats, BoxedError> {
+    let files = input
+        .files
+        .iter()
+        .enumerate()
+        .map(|(index, file)| {
+            let partition_expr_matches_region = file_partition_expr_matches_region(metadata, file)?;
+            Ok(RegionScanFileInputStats {
+                file_ordinal: index,
+                exact_num_rows: exact_file_num_rows(file),
+                time_range: exact_file_time_range(file),
+                field_stats: build_file_field_stats(input, metadata, file)?,
+                partition_expr_matches_region,
+            })
+        })
+        .collect::<Result<Vec<_>, BoxedError>>()?;
+
+    Ok(RegionScanInputStats { files })
+}
+
+fn exact_file_num_rows(file: &FileHandle) -> Option<usize> {
+    Some(file.num_rows())
+}
+
+fn exact_file_time_range(
+    file: &FileHandle,
+) -> Option<(common_time::Timestamp, common_time::Timestamp)> {
+    (file.meta_ref().num_row_groups != 0).then_some(file.time_range())
+}
+
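For reference, a minimal sketch of the per-file payload these helpers assemble (all values hypothetical; `field_stats` is filled in by `build_file_field_stats` below):

    let file_stats = RegionScanFileInputStats {
        file_ordinal: 0,
        exact_num_rows: Some(1_000),
        time_range: Some((
            common_time::Timestamp::new_millisecond(10),
            common_time::Timestamp::new_millisecond(60),
        )),
        field_stats: std::collections::HashMap::new(),
        partition_expr_matches_region: true,
    };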
+fn build_file_field_stats(
+    input: &ScanInput,
+    metadata: &RegionMetadata,
+    file: &FileHandle,
+) -> std::result::Result<HashMap<String, RegionScanColumnInputStats>, BoxedError> {
+    // TODO(ruihang): extract stats only for columns referenced by the supported aggregates
+    // instead of eagerly materializing every field column for every file.
+    let Some(parquet_meta) = input
+        .cache_strategy
+        .get_parquet_meta_data_from_mem_cache(file.file_id())
+    else {
+        return Ok(HashMap::new());
+    };
+
+    let region_metadata = Arc::new(metadata.clone());
+    let file_path = format!("{:?}", file.file_id());
+    let read_format = ReadFormat::new(
+        region_metadata.clone(),
+        None,
+        input.flat_format,
+        Some(parquet_meta.file_metadata().schema_descr().num_columns()),
+        &file_path,
+        false,
+    )
+    .map_err(BoxedError::new)?;
+    let row_groups = parquet_meta.row_groups();
+    let pruning_stats =
+        RowGroupPruningStats::new(row_groups, &read_format, Some(region_metadata), false);
+
+    metadata
+        .column_metadatas
+        .iter()
+        .filter(|column| column.semantic_type == SemanticType::Field)
+        .filter_map(|column| {
+            let stats = build_field_column_stats(
+                column.column_schema.name.as_str(),
+                row_groups,
+                &pruning_stats,
+            )
+            .transpose();
+            match stats {
+                Some(Ok(stats)) => Some(Ok((column.column_schema.name.to_string(), stats))),
+                Some(Err(err)) => Some(Err(err)),
+                None => None,
+            }
+        })
+        .collect()
+}
+
+fn build_field_column_stats(
+    column_name: &str,
+    row_groups: &[parquet::file::metadata::RowGroupMetaData],
+    pruning_stats: &impl PruningStatistics,
+) -> std::result::Result<Option<RegionScanColumnInputStats>, BoxedError> {
+    let column = Column::from_name(column_name);
+    let min_value = aggregate_column_min_value(pruning_stats.min_values(&column).as_deref())?;
+    let max_value = aggregate_column_max_value(pruning_stats.max_values(&column).as_deref())?;
+    let exact_non_null_rows =
+        aggregate_exact_non_null_rows(pruning_stats.null_counts(&column).as_deref(), row_groups)?;
+
+    if min_value.is_none() && max_value.is_none() && exact_non_null_rows.is_none() {
+        return Ok(None);
+    }
+
+    Ok(Some(RegionScanColumnInputStats {
+        min_value,
+        max_value,
+        exact_non_null_rows,
+    }))
+}
+
+fn aggregate_column_min_value(
+    values: Option<&dyn Array>,
+) -> std::result::Result<Option<Value>, BoxedError> {
+    aggregate_column_extreme_value(values, true)
+}
+
+fn aggregate_column_max_value(
+    values: Option<&dyn Array>,
+) -> std::result::Result<Option<Value>, BoxedError> {
+    aggregate_column_extreme_value(values, false)
+}
+
+fn aggregate_column_extreme_value(
+    values: Option<&dyn Array>,
+    is_min: bool,
+) -> std::result::Result<Option<Value>, BoxedError> {
+    let Some(values) = values else {
+        return Ok(None);
+    };
+    if values.is_empty() || values.null_count() > 0 {
+        return Ok(None);
+    }
+
+    if let Some(value) = aggregate_column_extreme_value_with_compute(values, is_min)? {
+        return Ok(Some(value));
+    }
+
+    aggregate_column_extreme_value_fallback(values, is_min)
+}
+
+fn aggregate_column_extreme_value_with_compute(
+    values: &dyn Array,
+    is_min: bool,
+) -> std::result::Result<Option<Value>, BoxedError> {
+    if let ArrowDataType::Dictionary(_, value_type) = values.data_type() {
+        let casted = cast(values, value_type.as_ref()).map_err(|err| {
+            BoxedError::new(PlainError::new(
+                format!("failed to cast dictionary stats array to value type: {err}"),
+                StatusCode::Unexpected,
+            ))
+        })?;
+        return aggregate_column_extreme_value_with_compute(casted.as_ref(), is_min);
+    }
+
+    macro_rules! compute_primitive_extreme {
+        ($array_ty:ty, $variant:ident) => {{
+            let array = values.as_primitive::<$array_ty>();
+            let scalar = if is_min {
+                min(array).map(|value| ScalarValue::$variant(Some(value)))
+            } else {
+                max(array).map(|value| ScalarValue::$variant(Some(value)))
+            };
+            scalar
+                .map(|value| Value::try_from(value).map_err(BoxedError::new))
+                .transpose()
+        }};
+    }
+
+    macro_rules! compute_timestamp_extreme {
+        ($array_ty:ty, $variant:ident, $tz:expr) => {{
+            let array = values.as_primitive::<$array_ty>();
+            let scalar = if is_min {
+                min(array).map(|value| ScalarValue::$variant(Some(value), $tz.clone()))
+            } else {
+                max(array).map(|value| ScalarValue::$variant(Some(value), $tz.clone()))
+            };
+            scalar
+                .map(|value| Value::try_from(value).map_err(BoxedError::new))
+                .transpose()
+        }};
+    }
+
+    match values.data_type() {
+        ArrowDataType::Boolean => {
+            let array = values.as_boolean();
+            let scalar = if is_min {
+                min_boolean(array).map(|value| ScalarValue::Boolean(Some(value)))
+            } else {
+                max_boolean(array).map(|value| ScalarValue::Boolean(Some(value)))
+            };
+            scalar
+                .map(|value| Value::try_from(value).map_err(BoxedError::new))
+                .transpose()
+        }
+        ArrowDataType::Utf8 => {
+            let array = values.as_string::<i32>();
+            let scalar = if is_min {
+                min_string(array)
+            } else {
+                max_string(array)
+            }
+            .map(|value| ScalarValue::Utf8(Some(value.to_string())));
+            scalar
+                .map(|value| Value::try_from(value).map_err(BoxedError::new))
+                .transpose()
+        }
+        ArrowDataType::LargeUtf8 => {
+            let array = values.as_string::<i64>();
+            let scalar = if is_min {
+                min_string(array)
+            } else {
+                max_string(array)
+            }
+            .map(|value| ScalarValue::LargeUtf8(Some(value.to_string())));
+            scalar
+                .map(|value| Value::try_from(value).map_err(BoxedError::new))
+                .transpose()
+        }
+        ArrowDataType::UInt32 => compute_primitive_extreme!(UInt32Type, UInt32),
+        ArrowDataType::UInt64 => compute_primitive_extreme!(UInt64Type, UInt64),
+        ArrowDataType::Int32 => compute_primitive_extreme!(Int32Type, Int32),
+        ArrowDataType::Int64 => compute_primitive_extreme!(Int64Type, Int64),
+        ArrowDataType::Float32 => compute_primitive_extreme!(Float32Type, Float32),
+        ArrowDataType::Float64 => compute_primitive_extreme!(Float64Type, Float64),
+        ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Second, tz) => {
+            compute_timestamp_extreme!(TimestampSecondType, TimestampSecond, tz)
+        }
+        ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, tz) => {
+            compute_timestamp_extreme!(TimestampMillisecondType, TimestampMillisecond, tz)
+        }
+        ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Microsecond, tz) => {
+            compute_timestamp_extreme!(TimestampMicrosecondType, TimestampMicrosecond, tz)
+        }
+        ArrowDataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Nanosecond, tz) => {
+            compute_timestamp_extreme!(TimestampNanosecondType, TimestampNanosecond, tz)
+        }
+        _ => Ok(None),
+    }
+}
+
+fn aggregate_column_extreme_value_fallback(
+    values: &dyn Array,
+    is_min: bool,
+) -> std::result::Result<Option<Value>, BoxedError> {
+    let scalars = (0..values.len())
+        .map(|index| {
+            ScalarValue::try_from_array(values, index).map_err(|err| {
+                BoxedError::new(PlainError::new(
+                    format!("failed to extract scalar value from stats array: {err}"),
+                    StatusCode::Unexpected,
+                ))
+            })
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+    let mut iter = scalars
+        .into_iter()
+        .map(|value| Value::try_from(value).map_err(BoxedError::new));
+    let Some(first) = iter.next() else {
+        return Ok(None);
+    };
+    let first = first?;
+
+    iter.try_fold(first, |current, value| {
+        let value = value?;
+        let next = if is_min {
+            Value::min(current, value)
+        } else {
+            Value::max(current, value)
+        };
+        Ok::<_, BoxedError>(next)
+    })
+    .map(Some)
+}
+
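The fallback can also be exercised directly; a minimal sketch (assuming this module's context, with `StringArray` from `datatypes::arrow::array`) shows the scalar-by-scalar fold:

    let values = StringArray::from(vec![Some("b"), Some("a")]);
    // ScalarValue::try_from_array plus a Value::min fold picks the smallest entry.
    let min = aggregate_column_extreme_value_fallback(&values, true).unwrap();
    assert_eq!(min, Some(Value::String("a".into())));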
else { + return Ok(None); + }; + if null_counts.null_count() > 0 { + return Ok(None); + } + let Some(null_counts) = null_counts.as_any().downcast_ref::() else { + return Ok(None); + }; + + row_groups + .iter() + .zip(null_counts.iter()) + .try_fold(0usize, |acc, (row_group, null_count)| { + let row_count = usize::try_from(row_group.num_rows()).map_err(|err| { + BoxedError::new(PlainError::new( + format!("failed to convert row group row count to usize: {err}"), + StatusCode::Unexpected, + )) + })?; + let null_count = usize::try_from(null_count.unwrap_or_default()).map_err(|err| { + BoxedError::new(PlainError::new( + format!("failed to convert parquet null count to usize: {err}"), + StatusCode::Unexpected, + )) + })?; + Ok::<_, BoxedError>(acc + row_count.saturating_sub(null_count)) + }) + .map(Some) +} + +fn file_partition_expr_matches_region( + metadata: &RegionMetadata, + file: &FileHandle, +) -> std::result::Result { + let file_partition_expr = file + .meta_ref() + .partition_expr + .as_ref() + .map(|expr| expr.as_json_str()) + .transpose() + .map_err(BoxedError::new)?; + Ok(file_partition_expr == metadata.partition_expr) +} + +#[cfg(test)] +mod tests { + use common_time::timestamp::TimeUnit as TimestampUnit; + use datatypes::arrow::array::{ + DictionaryArray, Int64Array, StringArray, TimestampMillisecondArray, UInt64Array, + }; + use datatypes::arrow::datatypes::Int32Type; + + use super::*; + + #[test] + fn test_aggregate_column_extreme_value_uses_numeric_fast_path() { + let values = UInt64Array::from(vec![Some(7), Some(2), Some(11)]); + + let min_value = aggregate_column_min_value(Some(&values)).unwrap(); + let max_value = aggregate_column_max_value(Some(&values)).unwrap(); + + assert_eq!(Some(Value::UInt64(2)), min_value); + assert_eq!(Some(Value::UInt64(11)), max_value); + } + + #[test] + fn test_aggregate_column_extreme_value_uses_string_fast_path() { + let values = StringArray::from(vec![Some("delta"), Some("alpha"), Some("gamma")]); + + let min_value = aggregate_column_min_value(Some(&values)).unwrap(); + let max_value = aggregate_column_max_value(Some(&values)).unwrap(); + + assert_eq!(Some(Value::String("alpha".into())), min_value); + assert_eq!(Some(Value::String("gamma".into())), max_value); + } + + #[test] + fn test_aggregate_column_extreme_value_uses_timestamp_fast_path() { + let values = TimestampMillisecondArray::from(vec![Some(7), Some(2), Some(11)]); + + let min_value = aggregate_column_min_value(Some(&values)).unwrap(); + let max_value = aggregate_column_max_value(Some(&values)).unwrap(); + + assert_eq!( + Some(Value::Timestamp(common_time::Timestamp::new( + 2, + TimestampUnit::Millisecond + ))), + min_value + ); + assert_eq!( + Some(Value::Timestamp(common_time::Timestamp::new( + 11, + TimestampUnit::Millisecond + ))), + max_value + ); + } + + #[test] + fn test_aggregate_column_extreme_value_dictionary_falls_back() { + let values = + DictionaryArray::::from_iter([Some("delta"), Some("alpha"), Some("gamma")]); + + let min_value = aggregate_column_min_value(Some(&values)).unwrap(); + let max_value = aggregate_column_max_value(Some(&values)).unwrap(); + + assert_eq!(Some(Value::String("alpha".into())), min_value); + assert_eq!(Some(Value::String("gamma".into())), max_value); + } + + #[test] + fn test_aggregate_column_extreme_value_returns_none_when_any_stats_are_null() { + let values = Int64Array::from(vec![Some(7), None, Some(11)]); + + assert_eq!(None, aggregate_column_min_value(Some(&values)).unwrap()); + assert_eq!(None, 
+#[cfg(test)]
+mod tests {
+    use common_time::timestamp::TimeUnit as TimestampUnit;
+    use datatypes::arrow::array::{
+        DictionaryArray, Int64Array, StringArray, TimestampMillisecondArray, UInt64Array,
+    };
+    use datatypes::arrow::datatypes::Int32Type;
+
+    use super::*;
+
+    #[test]
+    fn test_aggregate_column_extreme_value_uses_numeric_fast_path() {
+        let values = UInt64Array::from(vec![Some(7), Some(2), Some(11)]);
+
+        let min_value = aggregate_column_min_value(Some(&values)).unwrap();
+        let max_value = aggregate_column_max_value(Some(&values)).unwrap();
+
+        assert_eq!(Some(Value::UInt64(2)), min_value);
+        assert_eq!(Some(Value::UInt64(11)), max_value);
+    }
+
+    #[test]
+    fn test_aggregate_column_extreme_value_uses_string_fast_path() {
+        let values = StringArray::from(vec![Some("delta"), Some("alpha"), Some("gamma")]);
+
+        let min_value = aggregate_column_min_value(Some(&values)).unwrap();
+        let max_value = aggregate_column_max_value(Some(&values)).unwrap();
+
+        assert_eq!(Some(Value::String("alpha".into())), min_value);
+        assert_eq!(Some(Value::String("gamma".into())), max_value);
+    }
+
+    #[test]
+    fn test_aggregate_column_extreme_value_uses_timestamp_fast_path() {
+        let values = TimestampMillisecondArray::from(vec![Some(7), Some(2), Some(11)]);
+
+        let min_value = aggregate_column_min_value(Some(&values)).unwrap();
+        let max_value = aggregate_column_max_value(Some(&values)).unwrap();
+
+        assert_eq!(
+            Some(Value::Timestamp(common_time::Timestamp::new(
+                2,
+                TimestampUnit::Millisecond
+            ))),
+            min_value
+        );
+        assert_eq!(
+            Some(Value::Timestamp(common_time::Timestamp::new(
+                11,
+                TimestampUnit::Millisecond
+            ))),
+            max_value
+        );
+    }
+
+    #[test]
+    fn test_aggregate_column_extreme_value_dictionary_falls_back() {
+        let values =
+            DictionaryArray::<Int32Type>::from_iter([Some("delta"), Some("alpha"), Some("gamma")]);
+
+        let min_value = aggregate_column_min_value(Some(&values)).unwrap();
+        let max_value = aggregate_column_max_value(Some(&values)).unwrap();
+
+        assert_eq!(Some(Value::String("alpha".into())), min_value);
+        assert_eq!(Some(Value::String("gamma".into())), max_value);
+    }
+
+    #[test]
+    fn test_aggregate_column_extreme_value_returns_none_when_any_stats_are_null() {
+        let values = Int64Array::from(vec![Some(7), None, Some(11)]);
+
+        assert_eq!(None, aggregate_column_min_value(Some(&values)).unwrap());
+        assert_eq!(None, aggregate_column_max_value(Some(&values)).unwrap());
+    }
+}
diff --git a/src/mito2/src/read/seq_scan.rs b/src/mito2/src/read/seq_scan.rs
index a1b3b8f350..b9a5fe85ed 100644
--- a/src/mito2/src/read/seq_scan.rs
+++ b/src/mito2/src/read/seq_scan.rs
@@ -32,6 +32,7 @@ use store_api::metadata::RegionMetadataRef;
 use store_api::region_engine::{
     PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
 };
+use store_api::scan_stats::RegionScanStats;
 use store_api::storage::TimeSeriesRowSelector;
 use tokio::sync::Semaphore;
 
@@ -43,6 +44,7 @@ use crate::read::last_row::{FlatLastRowReader, LastRowReader};
 use crate::read::merge::MergeReaderBuilder;
 use crate::read::pruner::{PartitionPruner, Pruner};
 use crate::read::range::RangeMeta;
+use crate::read::scan_input_stats::build_scan_input_stats;
 use crate::read::scan_region::{ScanInput, StreamContext};
 use crate::read::scan_util::{
     PartitionMetrics, PartitionMetricsList, SplitRecordBatchStream, scan_file_ranges,
@@ -669,6 +671,14 @@ impl RegionScanner for SeqScan {
         predicate.is_some()
     }
 
+    fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
+        build_scan_input_stats(
+            &self.stream_ctx.input,
+            self.stream_ctx.input.mapper.metadata(),
+        )
+        .map(Some)
+    }
+
     fn add_dyn_filter_to_predicate(
         &mut self,
         filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
diff --git a/src/mito2/src/read/series_scan.rs b/src/mito2/src/read/series_scan.rs
index 2d6994d0af..39764183e6 100644
--- a/src/mito2/src/read/series_scan.rs
+++ b/src/mito2/src/read/series_scan.rs
@@ -36,6 +36,7 @@ use store_api::metadata::RegionMetadataRef;
 use store_api::region_engine::{
     PartitionRange, PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
 };
+use store_api::scan_stats::RegionScanStats;
 use tokio::sync::Semaphore;
 use tokio::sync::mpsc::error::{SendTimeoutError, TrySendError};
 use tokio::sync::mpsc::{self, Receiver, Sender};
@@ -45,6 +46,7 @@ use crate::error::{
     ScanSeriesSnafu, TooManyFilesToReadSnafu,
 };
 use crate::read::pruner::{PartitionPruner, Pruner};
+use crate::read::scan_input_stats::build_scan_input_stats;
 use crate::read::scan_region::{ScanInput, StreamContext};
 use crate::read::scan_util::{PartitionMetrics, PartitionMetricsList, SeriesDistributorMetrics};
 use crate::read::seq_scan::{SeqScan, build_flat_sources, build_sources};
@@ -364,6 +366,14 @@ impl RegionScanner for SeriesScan {
         predicate.is_some()
     }
 
+    fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
+        build_scan_input_stats(
+            &self.stream_ctx.input,
+            self.stream_ctx.input.mapper.metadata(),
+        )
+        .map(Some)
+    }
+
     fn add_dyn_filter_to_predicate(
         &mut self,
         filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
diff --git a/src/mito2/src/read/unordered_scan.rs b/src/mito2/src/read/unordered_scan.rs
index 2d557e8871..ea2a86e09d 100644
--- a/src/mito2/src/read/unordered_scan.rs
+++ b/src/mito2/src/read/unordered_scan.rs
@@ -32,9 +32,11 @@ use store_api::metadata::RegionMetadataRef;
 use store_api::region_engine::{
     PrepareRequest, QueryScanContext, RegionScanner, ScannerProperties,
 };
+use store_api::scan_stats::RegionScanStats;
 
 use crate::error::{PartitionOutOfRangeSnafu, Result};
 use crate::read::pruner::{PartitionPruner, Pruner};
+use crate::read::scan_input_stats::build_scan_input_stats;
 use crate::read::scan_region::{ScanInput, StreamContext};
 use crate::read::scan_util::{
     PartitionMetrics, PartitionMetricsList, scan_file_ranges, scan_flat_file_ranges,
@@ -495,6 +497,14 @@ impl RegionScanner for UnorderedScan {
         predicate.is_some()
     }
 
+    fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
+        build_scan_input_stats(
+            &self.stream_ctx.input,
+            self.stream_ctx.input.mapper.metadata(),
+        )
+        .map(Some)
+    }
+
     fn add_dyn_filter_to_predicate(
         &mut self,
         filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs
index ff22971670..646a8ef36c 100644
--- a/src/query/src/optimizer.rs
+++ b/src/query/src/optimizer.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-pub mod aggregate_stats;
+pub mod aggr_stats;
 pub mod constant_term;
 pub mod count_wildcard;
 pub mod parallelize_scan;
diff --git a/src/query/src/optimizer/aggr_stats.rs b/src/query/src/optimizer/aggr_stats.rs
new file mode 100644
index 0000000000..31af3469eb
--- /dev/null
+++ b/src/query/src/optimizer/aggr_stats.rs
@@ -0,0 +1,111 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use common_telemetry::debug;
+use datafusion::config::ConfigOptions;
+use datafusion::physical_optimizer::PhysicalOptimizerRule;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::aggregates::AggregateExec;
+use datafusion_common::Result as DfResult;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datatypes::arrow::datatypes::DataType;
+use table::table::scan::RegionScanExec;
+
+mod check;
+mod split;
+#[cfg(test)]
+mod tests;
+
+use check::RewriteCheck;
+
+#[derive(Debug)]
+pub struct AggregateStats;
+
+/// All aggregates that the statistics-based rewrite supports.
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum StatsAgg {
+    CountStar,
+    CountField {
+        column_name: String,
+        arg_type: DataType,
+    },
+    CountTimeIndex {
+        arg_type: DataType,
+    },
+    MinField {
+        column_name: String,
+        arg_type: DataType,
+    },
+    MinTimeIndex {
+        arg_type: DataType,
+    },
+    MaxField {
+        column_name: String,
+        arg_type: DataType,
+    },
+    MaxTimeIndex {
+        arg_type: DataType,
+    },
+}
+
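For a table with time index `ts` and field `value` (names hypothetical, and assuming `TimeUnit` from `datatypes::arrow::datatypes`), `SELECT max(value), min(ts) FROM t` with no GROUP BY would parse into roughly:

    let aggs = vec![
        StatsAgg::MaxField {
            column_name: "value".to_string(),
            arg_type: DataType::Int64,
        },
        StatsAgg::MinTimeIndex {
            arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
        },
    ];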
+impl PhysicalOptimizerRule for AggregateStats {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &ConfigOptions,
+    ) -> DfResult<Arc<dyn ExecutionPlan>> {
+        Self::do_optimize(plan)
+    }
+
+    fn name(&self) -> &str {
+        "aggregate_stats"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+impl AggregateStats {
+    fn do_optimize(plan: Arc<dyn ExecutionPlan>) -> DfResult<Arc<dyn ExecutionPlan>> {
+        let result = plan
+            .transform_down(|plan| {
+                let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
+                    return Ok(Transformed::no(plan));
+                };
+
+                let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else {
+                    return Ok(Transformed::no(plan));
+                };
+
+                let check = RewriteCheck::new(aggregate_exec, region_scan);
+                if let Some(reason) = check.skip_reason()? {
+                    debug!("Skip aggregate stats optimization: {reason}");
+                    return Ok(Transformed::no(plan));
+                }
+
+                Ok(Transformed::no(plan))
+            })?
+            .data;
+
+        Ok(result)
+    }
+
+    fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> {
+        let child = aggregate_exec.children().into_iter().next()?;
+        child.as_any().downcast_ref::<RegionScanExec>()
+    }
+}
diff --git a/src/query/src/optimizer/aggr_stats/check.rs b/src/query/src/optimizer/aggr_stats/check.rs
new file mode 100644
index 0000000000..3dbeef2315
--- /dev/null
+++ b/src/query/src/optimizer/aggr_stats/check.rs
@@ -0,0 +1,221 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_telemetry::debug;
+use datafusion::physical_plan::aggregates::AggregateExec;
+use datafusion_common::Result as DfResult;
+use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
+use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
+use datafusion_physical_expr::expressions::{Column, Literal};
+use table::table::scan::RegionScanExec;
+
+use super::StatsAgg;
+use super::split::{StatsAggExt, has_partition_expr_mismatch};
+
+#[derive(Debug)]
+pub(super) struct RewriteCheck<'a> {
+    aggregate_exec: &'a AggregateExec,
+    region_scan: &'a RegionScanExec,
+}
+
+impl<'a> RewriteCheck<'a> {
+    pub(super) fn new(aggregate_exec: &'a AggregateExec, region_scan: &'a RegionScanExec) -> Self {
+        Self {
+            aggregate_exec,
+            region_scan,
+        }
+    }
+
+    pub(super) fn skip_reason(&self) -> DfResult<Option<RejectReason>> {
+        // The MVP only handles global aggregates over append-mode region scans. Anything
+        // else falls back to the normal execution path.
+        if !self.region_scan.append_mode() || !self.aggregate_exec.group_expr().is_empty() {
+            return Ok(Some(RejectReason::UnsupportedPlan));
+        }
+
+        let aggs = match self.parse_aggs() {
+            Ok(aggs) => aggs,
+            Err(reason) => return Ok(Some(reason)),
+        };
+
+        let scan_input_stats = self.try_scan_stats();
+        if self.stats_unavailable(scan_input_stats.as_ref()) {
+            return Ok(Some(RejectReason::StatsUnavailable));
+        }
+
+        if !aggs
+            .iter()
+            .all(|agg| agg.has_stats_files(scan_input_stats.as_ref().unwrap()))
+        {
+            return Ok(Some(RejectReason::NoStatsFiles));
+        }
+
+        Ok(None)
+    }
+
+    fn try_scan_stats(&self) -> Option<RegionScanStats> {
+        match self.region_scan.scan_input_stats() {
+            Ok(stats) => stats,
+            Err(err) => {
+                debug!(
+                    "Skip aggregate stats optimization: failed to collect scan input stats: {err}"
+                );
+                None
+            }
+        }
+    }
+
+    fn parse_aggs(&self) -> Result<Vec<StatsAgg>, RejectReason> {
+        let aggr_exprs = self.aggregate_exec.aggr_expr();
+        if aggr_exprs.is_empty() {
+            return Err(RejectReason::UnsupportedAggregate);
+        }
+
+        aggr_exprs.iter().map(|expr| self.parse_agg(expr)).collect()
+    }
+
+    fn parse_agg(&self, expr: &AggregateFunctionExpr) -> Result<StatsAgg, RejectReason> {
+        if !is_supported_aggregate_name(expr.fun().name()) {
+            return Err(RejectReason::UnsupportedAggregate);
+        }
+
+        Self::check_agg_shape(expr)?;
+
+        let inputs = expr.expressions();
+        let name = expr.fun().name().to_ascii_lowercase();
+        // COUNT(*) is usually rewritten to COUNT(time-index) before this physical
+        // optimizer runs, so CountStar is mostly a defensive fallback.
+        if name == "count" && is_count_star_expr(&inputs) {
+            return Ok(StatsAgg::CountStar);
+        }
+
+        if inputs.len() != 1 {
+            return Err(RejectReason::UnsupportedAggregate);
+        }
+
+        let Some(column) = inputs[0].as_any().downcast_ref::<Column>() else {
+            return Err(RejectReason::UnsupportedAggregate);
+        };
+        let column_name = column.name().to_string();
+        let arg_type = inputs[0]
+            .data_type(self.aggregate_exec.input_schema().as_ref())
+            .map_err(|_| RejectReason::UnsupportedAggregate)?;
+
+        if self.is_tag_column(&column_name) {
+            return Err(RejectReason::UnsupportedPlan);
+        }
+
+        let is_time_index = column_name == self.region_scan.time_index();
+        match (name.as_str(), is_time_index) {
+            ("count", true) => Ok(StatsAgg::CountTimeIndex { arg_type }),
+            ("count", false) => Ok(StatsAgg::CountField {
+                column_name,
+                arg_type,
+            }),
+            ("min", true) => Ok(StatsAgg::MinTimeIndex { arg_type }),
+            ("min", false) => Ok(StatsAgg::MinField {
+                column_name,
+                arg_type,
+            }),
+            ("max", true) => Ok(StatsAgg::MaxTimeIndex { arg_type }),
+            ("max", false) => Ok(StatsAgg::MaxField {
+                column_name,
+                arg_type,
+            }),
+            _ => Err(RejectReason::UnsupportedAggregate),
+        }
+    }
+
+    pub(super) fn check_agg_shape(expr: &AggregateFunctionExpr) -> Result<(), RejectReason> {
+        if expr.is_distinct()
+            || expr.ignore_nulls()
+            || expr.is_reversed()
+            || !expr.order_bys().is_empty()
+        {
+            return Err(RejectReason::UnsupportedAggregate);
+        }
+
+        Ok(())
+    }
+
+    fn is_tag_column(&self, column_name: &str) -> bool {
+        self.region_scan
+            .tag_columns()
+            .iter()
+            .any(|tag| tag == column_name)
+    }
+
+    fn stats_unavailable(
+        &self,
+        scan_input_stats: Option<&store_api::scan_stats::RegionScanStats>,
+    ) -> bool {
+        // These cases keep the plan shape eligible in principle, but the scan cannot
+        // provide metadata that is trustworthy enough for a safe stats-only rewrite.
+        self.region_scan.has_predicate_without_region()
+            || scan_input_stats.is_none()
+            || has_partition_expr_mismatch(scan_input_stats)
+    }
+}
+
+#[derive(Debug)]
+pub(super) enum RejectReason {
+    /// The physical plan shape is outside the current safe rewrite envelope.
+    UnsupportedPlan,
+    /// At least one aggregate function or aggregate shape is unsupported.
+    UnsupportedAggregate,
+    /// The scan cannot provide trustworthy stats for this rewrite attempt.
+    StatsUnavailable,
+    /// The aggregate shape is supported, but no file contributes metadata-only results.
+    NoStatsFiles,
+}
+
+impl std::fmt::Display for RejectReason {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            RejectReason::UnsupportedPlan => write!(
+                f,
+                "aggregate stats MVP does not support this plan shape safely"
+            ),
+            RejectReason::UnsupportedAggregate => write!(
+                f,
+                "aggregate stats MVP only supports the narrowed no-GROUP-BY count/min/max matrix"
+            ),
+            RejectReason::StatsUnavailable => write!(
+                f,
+                "aggregate stats MVP found a supported shape, but trustworthy statistics are unavailable for it"
+            ),
+            RejectReason::NoStatsFiles => write!(
+                f,
+                "aggregate stats rewrite requires at least one stats-backed file after safety checks"
+            ),
+        }
+    }
+}
+
+pub(super) fn is_supported_aggregate_name(name: &str) -> bool {
+    matches!(name.to_ascii_lowercase().as_str(), "min" | "max" | "count")
+}
+
+fn is_count_star_expr(inputs: &[std::sync::Arc<dyn PhysicalExpr>]) -> bool {
+    match inputs {
+        // Keep the legacy empty-input shape as a compatibility fallback.
+        [] => true,
+        [arg] => arg
+            .as_any()
+            .downcast_ref::<Literal>()
+            .is_some_and(|lit| lit.value() == &COUNT_STAR_EXPANSION),
+        _ => false,
+    }
+}
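DataFusion plans `count(*)` as `count(COUNT_STAR_EXPANSION)`, a literal `1`, which is what the single-element arm above matches; a minimal sketch (assuming module-level access to the private `is_count_star_expr`):

    let inputs: Vec<Arc<dyn PhysicalExpr>> =
        vec![Arc::new(Literal::new(COUNT_STAR_EXPANSION))];
    assert!(is_count_star_expr(&inputs));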
diff --git a/src/query/src/optimizer/aggr_stats/split.rs b/src/query/src/optimizer/aggr_stats/split.rs
new file mode 100644
index 0000000000..e49183d96b
--- /dev/null
+++ b/src/query/src/optimizer/aggr_stats/split.rs
@@ -0,0 +1,399 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_function::aggrs::aggr_wrapper::StateWrapper;
+use common_time::Timestamp;
+use datafusion::functions_aggregate::count::count_udaf;
+use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
+use datafusion_common::{Result as DfResult, ScalarValue};
+use datatypes::arrow::datatypes::DataType;
+use datatypes::data_type::ConcreteDataType;
+use datatypes::value::Value;
+use store_api::scan_stats::RegionScanStats;
+
+use super::StatsAgg;
+
+#[cfg(test)]
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(super) enum FileStatsRequirement {
+    FileExactRowCount,
+    FileTimeRange,
+    RowGroupMinMax,
+    RowGroupNullCount,
+}
+
+/// Splits scan input files into two buckets for one aggregate rewrite path.
+///
+/// `stats_file_ordinals` and `scan_file_ordinals` store
+/// `RegionScanFileStats::file_ordinal`, not indexes into
+/// `RegionScanStats::files`. The optimizer later uses these ordinals to decide
+/// which physical files can be answered from metadata and which still need a
+/// real scan.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(super) struct FileSplit<T> {
+    /// File ordinals whose stats are sufficient to contribute to the rewrite.
+    pub stats_file_ordinals: Vec<usize>,
+    /// File ordinals that still need to be scanned because required stats are missing.
+    pub scan_file_ordinals: Vec<usize>,
+    /// Aggregate contribution computed from `stats_file_ordinals` only.
+    pub stats: T,
+}
+
+impl<T> FileSplit<T> {
+    fn has_stats_files(&self) -> bool {
+        !self.stats_file_ordinals.is_empty()
+    }
+}
+
+/// File-level time bounds collected from metadata-only files.
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub(super) struct TimeBounds {
+    pub min: Option<Timestamp>,
+    pub max: Option<Timestamp>,
+}
+
+/// Field value bounds collected from metadata-only files.
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub(super) struct ValueBounds {
+    pub min: Option<Value>,
+    pub max: Option<Value>,
+}
+
+pub(super) type CountStarFileSplit = FileSplit<usize>;
+pub(super) type FieldCountFileSplit = FileSplit<usize>;
+pub(super) type TimeFileSplit = FileSplit<TimeBounds>;
+pub(super) type FieldMinMaxFileSplit = FileSplit<ValueBounds>;
+
+pub(super) trait StatsAggExt {
+    fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool;
+}
+
+impl StatsAggExt for StatsAgg {
+    fn has_stats_files(&self, scan_input_stats: &RegionScanStats) -> bool {
+        match self {
+            StatsAgg::CountStar => split_count_star_files(scan_input_stats).has_stats_files(),
+            StatsAgg::CountField { column_name, .. } => {
+                split_count_field_files(scan_input_stats, column_name).has_stats_files()
+            }
+            StatsAgg::CountTimeIndex { .. }
+            | StatsAgg::MinTimeIndex { .. }
+            | StatsAgg::MaxTimeIndex { .. } => split_time_files(scan_input_stats).has_stats_files(),
+            StatsAgg::MinField { column_name, .. } | StatsAgg::MaxField { column_name, .. } => {
+                split_min_max_field_files(scan_input_stats, column_name).has_stats_files()
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+impl StatsAgg {
+    pub(super) fn file_stats_requirement(&self) -> FileStatsRequirement {
+        match self {
+            StatsAgg::CountStar => FileStatsRequirement::FileExactRowCount,
+            StatsAgg::CountField { .. } => FileStatsRequirement::RowGroupNullCount,
+            StatsAgg::CountTimeIndex { .. }
+            | StatsAgg::MinTimeIndex { .. }
+            | StatsAgg::MaxTimeIndex { .. } => FileStatsRequirement::FileTimeRange,
+            StatsAgg::MinField { .. } | StatsAgg::MaxField { .. } => {
+                FileStatsRequirement::RowGroupMinMax
+            }
+        }
+    }
+}
+
+pub(super) fn has_partition_expr_mismatch(scan_input_stats: Option<&RegionScanStats>) -> bool {
+    scan_input_stats
+        .map(|stats| {
+            stats
+                .files
+                .iter()
+                .any(|file| !file.partition_expr_matches_region)
+        })
+        .unwrap_or(false)
+}
+
+pub(super) fn split_count_star_files(scan_input_stats: &RegionScanStats) -> CountStarFileSplit {
+    scan_input_stats.files.iter().fold(
+        FileSplit {
+            stats_file_ordinals: Vec::new(),
+            scan_file_ordinals: Vec::new(),
+            stats: 0,
+        },
+        |mut split, file| {
+            if file.partition_expr_matches_region
+                && let Some(num_rows) = file.exact_num_rows
+            {
+                split.stats_file_ordinals.push(file.file_ordinal);
+                split.stats += num_rows;
+                return split;
+            }
+
+            split.scan_file_ordinals.push(file.file_ordinal);
+            split
+        },
+    )
+}
+
+pub(super) fn split_time_files(scan_input_stats: &RegionScanStats) -> TimeFileSplit {
+    scan_input_stats.files.iter().fold(
+        FileSplit {
+            stats_file_ordinals: Vec::new(),
+            scan_file_ordinals: Vec::new(),
+            stats: TimeBounds::default(),
+        },
+        |mut split, file| {
+            if file.partition_expr_matches_region
+                && let Some((min_ts, max_ts)) = file.time_range
+            {
+                split.stats_file_ordinals.push(file.file_ordinal);
+                split.stats.min = Some(match split.stats.min {
+                    Some(current) => Timestamp::min(current, min_ts),
+                    None => min_ts,
+                });
+                split.stats.max = Some(match split.stats.max {
+                    Some(current) => Timestamp::max(current, max_ts),
+                    None => max_ts,
+                });
+                return split;
+            }
+
+            split.scan_file_ordinals.push(file.file_ordinal);
+            split
+        },
+    )
+}
+
+pub(super) fn split_count_field_files(
+    scan_input_stats: &RegionScanStats,
+    column_name: &str,
+) -> FieldCountFileSplit {
+    scan_input_stats.files.iter().fold(
+        FileSplit {
+            stats_file_ordinals: Vec::new(),
+            scan_file_ordinals: Vec::new(),
+            stats: 0,
+        },
+        |mut split, file| {
+            if file.partition_expr_matches_region
+                && let Some(non_null_rows) = file
+                    .field_stats
+                    .get(column_name)
+                    .and_then(|stats| stats.exact_non_null_rows)
+            {
+                split.stats_file_ordinals.push(file.file_ordinal);
+                split.stats += non_null_rows;
+                return split;
+            }
+
+            split.scan_file_ordinals.push(file.file_ordinal);
+            split
+        },
+    )
+}
+
+pub(super) fn split_min_max_field_files(
+    scan_input_stats: &RegionScanStats,
+    column_name: &str,
+) -> FieldMinMaxFileSplit {
+    scan_input_stats.files.iter().fold(
+        FileSplit {
+            stats_file_ordinals: Vec::new(),
+            scan_file_ordinals: Vec::new(),
+            stats: ValueBounds::default(),
+        },
+        |mut split, file| {
+            if file.partition_expr_matches_region
+                && let Some(stats) = file.field_stats.get(column_name)
+                && let (Some(min_value), Some(max_value)) =
+                    (stats.min_value.clone(), stats.max_value.clone())
+            {
+                split.stats_file_ordinals.push(file.file_ordinal);
+                split.stats.min = Some(match split.stats.min {
+                    Some(current) => Value::min(current, min_value),
+                    None => min_value,
+                });
+                split.stats.max = Some(match split.stats.max {
+                    Some(current) => Value::max(current, max_value),
+                    None => max_value,
+                });
+                return split;
+            }
+
+            split.scan_file_ordinals.push(file.file_ordinal);
+            split
+        },
+    )
+}
+
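A small sketch of the split semantics (assuming the `store_api::scan_stats` struct names used above; values hypothetical): a file whose field stats carry both bounds lands in the stats bucket and contributes to the running `ValueBounds`:

    use std::collections::HashMap;
    use store_api::scan_stats::{RegionScanColumnStats, RegionScanFileStats, RegionScanStats};

    let stats = RegionScanStats {
        files: vec![RegionScanFileStats {
            file_ordinal: 0,
            exact_num_rows: Some(10),
            time_range: None,
            field_stats: HashMap::from([(
                "value".to_string(),
                RegionScanColumnStats {
                    min_value: Some(Value::Int64(1)),
                    max_value: Some(Value::Int64(9)),
                    exact_non_null_rows: Some(10),
                },
            )]),
            partition_expr_matches_region: true,
        }],
    };
    let split = split_min_max_field_files(&stats, "value");
    assert_eq!(split.stats_file_ordinals, vec![0]);
    assert_eq!(split.stats.min, Some(Value::Int64(1)));
    assert_eq!(split.stats.max, Some(Value::Int64(9)));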
+// These helpers are implemented for the upcoming mixed stats-plus-scan rewrite in task 3.
+#[allow(dead_code)]
+pub(super) fn partial_state_from_stats(
+    aggregate: &StatsAgg,
+    scan_input_stats: &RegionScanStats,
+) -> DfResult<Option<ScalarValue>> {
+    match aggregate {
+        StatsAgg::CountStar => {
+            let split = split_count_star_files(scan_input_stats);
+            if !split.has_stats_files() {
+                return Ok(None);
+            }
+
+            let wrapper = StateWrapper::new((*count_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    &[],
+                    vec![ScalarValue::Int64(Some(split.stats as i64))],
+                )
+                .map(Some)
+        }
+        StatsAgg::CountField {
+            arg_type,
+            column_name,
+        } => {
+            let split = split_count_field_files(scan_input_stats, column_name);
+            if !split.has_stats_files() {
+                return Ok(None);
+            }
+
+            let wrapper = StateWrapper::new((*count_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![ScalarValue::Int64(Some(split.stats as i64))],
+                )
+                .map(Some)
+        }
+        StatsAgg::CountTimeIndex { arg_type } => {
+            let split = split_count_star_files(scan_input_stats);
+            if !split.has_stats_files() {
+                return Ok(None);
+            }
+
+            let wrapper = StateWrapper::new((*count_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![ScalarValue::Int64(Some(split.stats as i64))],
+                )
+                .map(Some)
+        }
+        StatsAgg::MinField {
+            arg_type,
+            column_name,
+        } => {
+            let split = split_min_max_field_files(scan_input_stats, column_name);
+            let Some(value) = split.stats.min.as_ref() else {
+                return Ok(None);
+            };
+
+            let wrapper = StateWrapper::new((*min_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![stats_value_scalar(value, arg_type)?],
+                )
+                .map(Some)
+        }
+        StatsAgg::MinTimeIndex { arg_type } => {
+            let split = split_time_files(scan_input_stats);
+            let Some(timestamp) = split.stats.min else {
+                return Ok(None);
+            };
+
+            let wrapper = StateWrapper::new((*min_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![timestamp_scalar_value(&timestamp, arg_type)?],
+                )
+                .map(Some)
+        }
+        StatsAgg::MaxField {
+            arg_type,
+            column_name,
+        } => {
+            let split = split_min_max_field_files(scan_input_stats, column_name);
+            let Some(value) = split.stats.max.as_ref() else {
+                return Ok(None);
+            };
+
+            let wrapper = StateWrapper::new((*max_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![stats_value_scalar(value, arg_type)?],
+                )
+                .map(Some)
+        }
+        StatsAgg::MaxTimeIndex { arg_type } => {
+            let split = split_time_files(scan_input_stats);
+            let Some(timestamp) = split.stats.max else {
+                return Ok(None);
+            };
+
+            let wrapper = StateWrapper::new((*max_udaf()).clone())?;
+            wrapper
+                .value_from_custom_state_fields(
+                    std::slice::from_ref(arg_type),
+                    vec![timestamp_scalar_value(&timestamp, arg_type)?],
+                )
+                .map(Some)
+        }
+    }
+}
+
+#[allow(dead_code)]
+fn timestamp_scalar_value(timestamp: &Timestamp, arg_type: &DataType) -> DfResult<ScalarValue> {
+    match arg_type {
+        DataType::Timestamp(unit, tz) => {
+            let converted = timestamp.convert_to((*unit).into()).ok_or_else(|| {
+                datafusion_common::DataFusionError::Internal(format!(
+                    "failed to convert timestamp {timestamp:?} to {unit:?}"
+                ))
+            })?;
+            Ok(match unit {
+                datatypes::arrow::datatypes::TimeUnit::Second => {
+                    ScalarValue::TimestampSecond(Some(converted.value()), tz.clone())
+                }
+                datatypes::arrow::datatypes::TimeUnit::Millisecond => {
+                    ScalarValue::TimestampMillisecond(Some(converted.value()), tz.clone())
+                }
+                datatypes::arrow::datatypes::TimeUnit::Microsecond => {
+                    ScalarValue::TimestampMicrosecond(Some(converted.value()), tz.clone())
+                }
+                datatypes::arrow::datatypes::TimeUnit::Nanosecond => {
+                    ScalarValue::TimestampNanosecond(Some(converted.value()), tz.clone())
+                }
+            })
+        }
+        _ => Err(datafusion_common::DataFusionError::Internal(format!(
+            "expected timestamp arg type, got {arg_type:?}"
+        ))),
+    }
+}
+
+#[allow(dead_code)]
+fn stats_value_scalar(value: &Value, arg_type: &DataType) -> DfResult<ScalarValue> {
+    let output_type = ConcreteDataType::try_from(arg_type).map_err(|err| {
+        datafusion_common::DataFusionError::Internal(format!(
+            "failed to convert arrow type {arg_type:?} to concrete type: {err}"
+        ))
+    })?;
+    value.try_to_scalar_value(&output_type).map_err(|err| {
+        datafusion_common::DataFusionError::Internal(format!(
+            "failed to convert stats value {value:?} to scalar for {arg_type:?}: {err}"
+        ))
+    })
+}
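Putting the pieces together, a sketch of how a single metadata-only file would fold into a wrapper state (assumptions: `ts` is a millisecond time index, and the `store_api::scan_stats` struct names are as imported above):

    let stats = RegionScanStats {
        files: vec![store_api::scan_stats::RegionScanFileStats {
            file_ordinal: 0,
            exact_num_rows: Some(10),
            time_range: Some((
                Timestamp::new_millisecond(10),
                Timestamp::new_millisecond(60),
            )),
            field_stats: Default::default(),
            partition_expr_matches_region: true,
        }],
    };
    let agg = StatsAgg::MinTimeIndex {
        arg_type: DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
    };
    // min(ts) folds to timestamp 10 and comes back wrapped as min's single-field
    // state struct, ready for the merge path.
    let state = partial_state_from_stats(&agg, &stats).unwrap().unwrap();
    assert!(matches!(state, ScalarValue::Struct(_)));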
HashMap::from([( + "value".to_string(), + RegionScanColumnInputStats { + min_value, + max_value, + exact_non_null_rows, + }, + )]) +} + +fn build_test_aggr_expr( + distinct: bool, + ignore_nulls: bool, + order_by: bool, +) -> datafusion_common::Result { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + let args = vec![Arc::new(Column::new("value", 0)) as Arc]; + let mut builder = AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args) + .schema(schema) + .alias("count(value)"); + + if distinct { + builder = builder.with_distinct(true); + } + if ignore_nulls { + builder = builder.ignore_nulls(); + } + if order_by { + builder = builder.order_by(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("value", 0)), + options: Default::default(), + }]); + } + + builder.build() +} + +fn build_count_star_aggr_expr() -> datafusion_common::Result { + let schema = Arc::new(Schema::empty()); + let args = vec![Arc::new(Literal::new(COUNT_STAR_EXPANSION)) as Arc]; + + AggregateExprBuilder::new(Arc::new((*count_udaf()).clone()), args) + .schema(schema) + .alias("count(*)") + .build() +} + +fn new_test_engine() -> crate::QueryEngineRef { + let columns = vec![ + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ), + ColumnSchema::new("value", ConcreteDataType::int64_datatype(), true), + ]; + let schema = Arc::new(GreptimeSchema::new(columns)); + let table_meta = TableMetaBuilder::empty() + .schema(schema) + .primary_key_indices(vec![0]) + .value_indices(vec![1]) + .next_column_id(1024) + .build() + .unwrap(); + let table_info = TableInfoBuilder::new("test", table_meta).build().unwrap(); + let table = EmptyTable::from_table_info(&table_info); + + new_query_engine_with_table(table) +} + +async fn parse_sql_to_plan(sql: &str) -> LogicalPlan { + let query_ctx = QueryContext::arc(); + let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap(); + new_test_engine() + .planner() + .plan(&stmt, query_ctx) + .await + .unwrap() +} + +#[test] +fn test_supported_aggregate_names() { + assert!(is_supported_aggregate_name("min")); + assert!(is_supported_aggregate_name("MAX")); + assert!(is_supported_aggregate_name("count")); + assert!(!is_supported_aggregate_name("sum")); + assert!(!is_supported_aggregate_name("avg(value)")); +} + +#[test] +fn test_file_stats_requirement_matrix() { + let count_star = StatsAgg::CountStar; + assert_eq!( + count_star.file_stats_requirement(), + FileStatsRequirement::FileExactRowCount + ); + + let time_min = StatsAgg::MinTimeIndex { + arg_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }; + assert_eq!( + time_min.file_stats_requirement(), + FileStatsRequirement::FileTimeRange + ); + + let field_count = StatsAgg::CountField { + column_name: "value".to_string(), + arg_type: DataType::Int64, + }; + assert_eq!( + field_count.file_stats_requirement(), + FileStatsRequirement::RowGroupNullCount + ); +} + +#[test] +fn test_count_star_file_stats_eligibility() { + let stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: None, + time_range: Some((test_timestamp(10), test_timestamp(20))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: Some(42), + time_range: Some((test_timestamp(30), test_timestamp(40))), + field_stats: HashMap::new(), + partition_expr_matches_region: false, + }, + RegionScanFileInputStats { + file_ordinal: 2, + 
exact_num_rows: Some(7), + time_range: Some((test_timestamp(50), test_timestamp(60))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + ], + }; + + assert!(StatsAgg::CountStar.has_stats_files(&stats)); +} + +#[test] +fn test_split_count_star_files() { + let stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(3), + time_range: Some((test_timestamp(10), test_timestamp(20))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: None, + time_range: Some((test_timestamp(30), test_timestamp(40))), + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 2, + exact_num_rows: Some(5), + time_range: Some((test_timestamp(50), test_timestamp(60))), + field_stats: HashMap::new(), + partition_expr_matches_region: false, + }, + ], + }; + + let split = split_count_star_files(&stats); + assert_eq!(split.stats_file_ordinals, vec![0]); + assert_eq!(split.scan_file_ordinals, vec![1, 2]); + assert_eq!(split.stats, 3); +} + +#[test] +fn test_split_count_star_files_keeps_zero_row_files_stats_eligible() { + let stats = RegionScanInputStats { + files: vec![ + RegionScanFileInputStats { + file_ordinal: 0, + exact_num_rows: Some(0), + time_range: None, + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + RegionScanFileInputStats { + file_ordinal: 1, + exact_num_rows: None, + time_range: None, + field_stats: HashMap::new(), + partition_expr_matches_region: true, + }, + ], + }; + + let split = split_count_star_files(&stats); + assert_eq!(split.stats_file_ordinals, vec![0]); + assert_eq!(split.scan_file_ordinals, vec![1]); + assert_eq!(split.stats, 0); +} + +#[test] +fn test_supported_aggregate_rejects_distinct_ignore_nulls_and_order_by_shapes() { + let distinct = build_test_aggr_expr(true, false, false).unwrap(); + let ignore_nulls = build_test_aggr_expr(false, true, false).unwrap(); + let order_by = build_test_aggr_expr(false, false, true).unwrap(); + + assert!(matches!( + RewriteCheck::check_agg_shape(&distinct), + Err(RejectReason::UnsupportedAggregate) + )); + assert!(matches!( + RewriteCheck::check_agg_shape(&ignore_nulls), + Err(RejectReason::UnsupportedAggregate) + )); + assert!(matches!( + RewriteCheck::check_agg_shape(&order_by), + Err(RejectReason::UnsupportedAggregate) + )); +} + +#[test] +fn test_count_star_expansion_is_treated_as_count_star() { + let expr = build_count_star_aggr_expr().unwrap(); + + assert_eq!(expr.fun().name(), "count"); + assert_eq!(expr.expressions().len(), 1); + assert!( + expr.expressions()[0] + .as_any() + .downcast_ref::() + .is_some_and(|lit| lit.value() == &COUNT_STAR_EXPANSION) + ); +} + +#[tokio::test] +async fn test_sql_count_star_is_planned_with_count_star_expansion() { + let plan = parse_sql_to_plan("select count(*) from test").await; + let LogicalPlan::Projection(projection) = plan else { + panic!("expected projection over aggregate plan, got {plan:?}"); + }; + let LogicalPlan::Aggregate(aggregate) = projection.input.as_ref() else { + panic!("expected aggregate input, got {:?}", projection.input); + }; + + assert_eq!(aggregate.aggr_expr.len(), 1); + let Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. 
+#[tokio::test]
+async fn test_sql_count_star_is_planned_with_count_star_expansion() {
+    let plan = parse_sql_to_plan("select count(*) from test").await;
+    let LogicalPlan::Projection(projection) = plan else {
+        panic!("expected projection over aggregate plan, got {plan:?}");
+    };
+    let LogicalPlan::Aggregate(aggregate) = projection.input.as_ref() else {
+        panic!("expected aggregate input, got {:?}", projection.input);
+    };
+
+    assert_eq!(aggregate.aggr_expr.len(), 1);
+    let Expr::AggregateFunction(AggregateFunction {
+        func,
+        params: AggregateFunctionParams { args, .. },
+    }) = &aggregate.aggr_expr[0]
+    else {
+        panic!(
+            "expected aggregate function expr, got {:?}",
+            aggregate.aggr_expr[0]
+        );
+    };
+
+    assert_eq!(func.name(), "count");
+    assert_eq!(args.len(), 1);
+    assert!(matches!(
+        &args[0],
+        Expr::Literal(value, _) if value == &COUNT_STAR_EXPANSION
+    ));
+}
+
+#[test]
+fn test_partition_expr_mismatch_detection() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(1),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(2),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: false,
+            },
+        ],
+    };
+
+    assert!(has_partition_expr_mismatch(Some(&stats)));
+}
+
+#[test]
+fn test_time_file_stats_eligibility() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: None,
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: None,
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    assert!(
+        StatsAgg::MinTimeIndex {
+            arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
+        }
+        .has_stats_files(&stats)
+    );
+}
+
+#[test]
+fn test_split_time_files() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(50), test_timestamp(70))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: None,
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 2,
+                exact_num_rows: Some(5),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 3,
+                exact_num_rows: Some(6),
+                time_range: Some((test_timestamp(5), test_timestamp(100))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: false,
+            },
+        ],
+    };
+
+    let split = split_time_files(&stats);
+    assert_eq!(split.stats_file_ordinals, vec![0, 2]);
+    assert_eq!(split.scan_file_ordinals, vec![1, 3]);
+    assert_eq!(split.stats.min, Some(test_timestamp(10)));
+    assert_eq!(split.stats.max, Some(test_timestamp(70)));
+}
+
+#[test]
+fn test_count_field_file_stats_eligibility() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(5),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: field_stats(Some(4), Some(Value::Int64(1)), Some(Value::Int64(9))),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    assert!(
+        StatsAgg::CountField {
+            column_name: "value".to_string(),
+            arg_type: DataType::Int64,
+        }
+        .has_stats_files(&stats)
+    );
+}
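+
+// The split_* helpers partition files by eligibility: files carrying the required
+// per-column statistics are answered from stats (here, by summing exact non-null row
+// counts), while the rest fall back to the scan side.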
+#[test]
+fn test_split_count_field_files() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 2,
+                exact_num_rows: Some(5),
+                time_range: Some((test_timestamp(50), test_timestamp(60))),
+                field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(8))),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let split = split_count_field_files(&stats, "value");
+    assert_eq!(split.stats_file_ordinals, vec![0, 2]);
+    assert_eq!(split.scan_file_ordinals, vec![1]);
+    assert_eq!(split.stats, 6);
+}
+
+#[test]
+fn test_min_max_field_stats_eligibility() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: false,
+            },
+        ],
+    };
+
+    assert!(
+        StatsAgg::MinField {
+            column_name: "value".to_string(),
+            arg_type: DataType::Int64,
+        }
+        .has_stats_files(&stats)
+    );
+}
+
+#[test]
+fn test_split_min_max_field_files() {
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: field_stats(Some(2), Some(Value::Int64(4)), Some(Value::Int64(9))),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 2,
+                exact_num_rows: Some(5),
+                time_range: Some((test_timestamp(50), test_timestamp(60))),
+                field_stats: field_stats(Some(4), Some(Value::Int64(1)), Some(Value::Int64(7))),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let split = split_min_max_field_files(&stats, "value");
+    assert_eq!(split.stats_file_ordinals, vec![0, 2]);
+    assert_eq!(split.scan_file_ordinals, vec![1]);
+    assert_eq!(split.stats.min, Some(Value::Int64(1)));
+    assert_eq!(split.stats.max, Some(Value::Int64(9)));
+}
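+
+// The tests below cover the end-to-end stats path: fold file statistics into a single
+// struct-typed partial aggregate state, then read the result back out of the struct's
+// first column.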
+#[test]
+fn test_partial_state_from_stats_count_star() {
+    let aggregate = StatsAgg::CountStar;
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let value = partial_state_from_stats(&aggregate, &stats)
+        .unwrap()
+        .unwrap();
+    let array = value.to_array().unwrap();
+    let struct_array = array.as_struct();
+    let count_values = struct_array
+        .column(0)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+    assert_eq!(count_values.value(0), 7);
+}
+
+#[test]
+fn test_partial_state_from_stats_count_field() {
+    let aggregate = StatsAgg::CountField {
+        column_name: "value".to_string(),
+        arg_type: DataType::Int64,
+    };
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let value = partial_state_from_stats(&aggregate, &stats)
+        .unwrap()
+        .unwrap();
+    let array = value.to_array().unwrap();
+    let struct_array = array.as_struct();
+    let count_values = struct_array
+        .column(0)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+    assert_eq!(count_values.value(0), 6);
+}
+
+#[test]
+fn test_partial_state_from_stats_min_time() {
+    let aggregate = StatsAgg::MinTimeIndex {
+        arg_type: DataType::Timestamp(TimeUnit::Millisecond, None),
+    };
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(50), test_timestamp(70))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(5),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: HashMap::new(),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let value = partial_state_from_stats(&aggregate, &stats)
+        .unwrap()
+        .unwrap();
+    let array = value.to_array().unwrap();
+    let struct_array = array.as_struct();
+    let ts_values = struct_array
+        .column(0)
+        .as_any()
+        .downcast_ref::<TimestampMillisecondArray>()
+        .unwrap();
+    assert_eq!(ts_values.value(0), 10);
+}
+
+#[test]
+fn test_partial_state_from_stats_max_field() {
+    let aggregate = StatsAgg::MaxField {
+        column_name: "value".to_string(),
+        arg_type: DataType::Int64,
+    };
+    let stats = RegionScanInputStats {
+        files: vec![
+            RegionScanFileInputStats {
+                file_ordinal: 0,
+                exact_num_rows: Some(3),
+                time_range: Some((test_timestamp(10), test_timestamp(20))),
+                field_stats: field_stats(Some(2), Some(Value::Int64(1)), Some(Value::Int64(3))),
+                partition_expr_matches_region: true,
+            },
+            RegionScanFileInputStats {
+                file_ordinal: 1,
+                exact_num_rows: Some(4),
+                time_range: Some((test_timestamp(30), test_timestamp(40))),
+                field_stats: field_stats(Some(4), Some(Value::Int64(5)), Some(Value::Int64(9))),
+                partition_expr_matches_region: true,
+            },
+        ],
+    };
+
+    let value = partial_state_from_stats(&aggregate, &stats)
+        .unwrap()
+        .unwrap();
+    let array = value.to_array().unwrap();
+    let struct_array = array.as_struct();
+    let max_values = struct_array
+        .column(0)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+    assert_eq!(max_values.value(0), 9);
+}
diff --git a/src/query/src/optimizer/aggregate_stats.rs b/src/query/src/optimizer/aggregate_stats.rs
deleted file mode 100644
index b07af77c60..0000000000
--- a/src/query/src/optimizer/aggregate_stats.rs
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2023 Greptime Team
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-use common_telemetry::debug;
-use datafusion::config::ConfigOptions;
-use datafusion::physical_optimizer::PhysicalOptimizerRule;
-use datafusion::physical_plan::ExecutionPlan;
-use datafusion::physical_plan::aggregates::AggregateExec;
-use datafusion_common::Result as DfResult;
-use datafusion_common::tree_node::{Transformed, TreeNode};
-use table::table::scan::RegionScanExec;
-
-#[derive(Debug)]
-pub struct AggregateStats;
-
-impl PhysicalOptimizerRule for AggregateStats {
-    fn optimize(
-        &self,
-        plan: Arc<dyn ExecutionPlan>,
-        _config: &ConfigOptions,
-    ) -> DfResult<Arc<dyn ExecutionPlan>> {
-        Self::do_optimize(plan)
-    }
-
-    fn name(&self) -> &str {
-        "aggregate_stats"
-    }
-
-    fn schema_check(&self) -> bool {
-        true
-    }
-}
-
-impl AggregateStats {
-    fn do_optimize(plan: Arc<dyn ExecutionPlan>) -> DfResult<Arc<dyn ExecutionPlan>> {
-        let result = plan
-            .transform_down(|plan| {
-                let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
-                    return Ok(Transformed::no(plan));
-                };
-
-                let Some(region_scan) = Self::extract_region_scan(aggregate_exec) else {
-                    return Ok(Transformed::no(plan));
-                };
-
-                let eligibility = AggregateStatsEligibility::new(aggregate_exec, region_scan);
-                if let Some(reason) = eligibility.ineligible_reason() {
-                    debug!("Skip aggregate stats optimization: {reason}");
-                    return Ok(Transformed::no(plan));
-                }
-
-                // TODO(ruihang): implement mixed stats-plus-scan rewrite in follow-up tasks.
-                Ok(Transformed::no(plan))
-            })?
-            .data;
-
-        Ok(result)
-    }
-
-    fn extract_region_scan(aggregate_exec: &AggregateExec) -> Option<&RegionScanExec> {
-        let child = aggregate_exec.children().into_iter().next()?;
-        child.as_any().downcast_ref::<RegionScanExec>()
-    }
-}
-
-#[derive(Debug)]
-struct AggregateStatsEligibility<'a> {
-    aggregate_exec: &'a AggregateExec,
-    region_scan: &'a RegionScanExec,
-}
-
-impl<'a> AggregateStatsEligibility<'a> {
-    fn new(aggregate_exec: &'a AggregateExec, region_scan: &'a RegionScanExec) -> Self {
-        Self {
-            aggregate_exec,
-            region_scan,
-        }
-    }
-
-    fn ineligible_reason(&self) -> Option<EligibilityRejection> {
-        if !self.region_scan.append_mode() {
-            return Some(EligibilityRejection::NonAppendOnly);
-        }
-
-        if !self.aggregate_exec.group_expr().is_empty() {
-            return Some(EligibilityRejection::GroupedAggregate);
-        }
-
-        if !self.has_supported_aggregates() {
-            return Some(EligibilityRejection::UnsupportedAggregate);
-        }
-
-        if !self.has_stats_eligible_candidates() {
-            return Some(EligibilityRejection::NoStatsEligibleFiles);
-        }
-
-        None
-    }
-
-    fn has_supported_aggregates(&self) -> bool {
-        let aggr_exprs = self.aggregate_exec.aggr_expr();
-        !aggr_exprs.is_empty()
-            && aggr_exprs
-                .iter()
-                .all(|expr| is_supported_aggregate_name(expr.name()))
-    }
-
-    fn has_stats_eligible_candidates(&self) -> bool {
-        // TODO(ruihang): replace this scaffold with per-file stats classification.
-        self.region_scan.total_rows() > 0
-    }
-}
-
-#[derive(Debug)]
-enum EligibilityRejection {
-    NonAppendOnly,
-    GroupedAggregate,
-    UnsupportedAggregate,
-    NoStatsEligibleFiles,
-}
-
-impl std::fmt::Display for EligibilityRejection {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            EligibilityRejection::NonAppendOnly => {
-                write!(f, "aggregate stats MVP only supports append-only scans")
-            }
-            EligibilityRejection::GroupedAggregate => {
-                write!(f, "aggregate stats MVP does not support GROUP BY yet")
-            }
-            EligibilityRejection::UnsupportedAggregate => {
-                write!(
-                    f,
-                    "aggregate stats MVP only supports min/max/count aggregates"
-                )
-            }
-            EligibilityRejection::NoStatsEligibleFiles => {
-                write!(
-                    f,
-                    "aggregate stats rewrite requires at least one stats-eligible file"
-                )
-            }
-        }
-    }
-}
-
-fn is_supported_aggregate_name(name: &str) -> bool {
-    let normalized = name.split('(').next().unwrap_or(name).to_ascii_lowercase();
-    matches!(normalized.as_str(), "min" | "max" | "count")
-}
-
-#[cfg(test)]
-mod tests {
-    use super::is_supported_aggregate_name;
-
-    #[test]
-    fn test_supported_aggregate_names() {
-        assert!(is_supported_aggregate_name("min"));
-        assert!(is_supported_aggregate_name("max(value)"));
-        assert!(is_supported_aggregate_name("count(*)"));
-        assert!(!is_supported_aggregate_name("sum(value)"));
-        assert!(!is_supported_aggregate_name("avg(value)"));
-    }
-}
diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs
index 8a7383342f..a680a79ad5 100644
--- a/src/query/src/query_engine/state.rs
+++ b/src/query/src/query_engine/state.rs
@@ -59,7 +59,7 @@ use crate::dist_plan::{
 };
 use crate::metrics::{QUERY_MEMORY_POOL_REJECTED_TOTAL, QUERY_MEMORY_POOL_USAGE_BYTES};
 use crate::optimizer::ExtensionAnalyzerRule;
-use crate::optimizer::aggregate_stats::AggregateStats;
+use crate::optimizer::aggr_stats::AggregateStats;
 use crate::optimizer::constant_term::MatchesConstantTermOptimizer;
 use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
 use crate::optimizer::parallelize_scan::ParallelizeScan;
diff --git a/src/store-api/src/lib.rs b/src/store-api/src/lib.rs
index 4df594fc67..ccb3abc8f0 100644
--- a/src/store-api/src/lib.rs
+++ b/src/store-api/src/lib.rs
@@ -26,6 +26,7 @@ pub mod mito_engine_options;
 pub mod path_utils;
 pub mod region_engine;
 pub mod region_request;
+pub mod scan_stats;
 pub mod sst_entry;
 pub mod storage;
diff --git a/src/store-api/src/region_engine.rs b/src/store-api/src/region_engine.rs
index b3f460d01d..54507c58f3 100644
--- a/src/store-api/src/region_engine.rs
+++ b/src/store-api/src/region_engine.rs
@@ -37,6 +37,7 @@ use crate::metadata::RegionMetadataRef;
 use crate::region_request::{
     BatchRegionDdlRequest, RegionCatchupRequest, RegionOpenRequest, RegionRequest,
 };
+use crate::scan_stats::RegionScanStats;
 use crate::storage::{FileId, RegionId, ScanRequest, SequenceNumber};
 
 /// The settable region role state.
@@ -451,6 +452,11 @@ pub trait RegionScanner: Debug + DisplayAs + Send {
     /// Check if there is any predicate, excluding region partition exprs, that may be executed in this scanner.
     fn has_predicate_without_region(&self) -> bool;
 
+    /// Returns file-level scan statistics for the current scanner input when available.
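+    ///
+    /// The default implementation returns `Ok(None)`, meaning the engine exposes no
+    /// trusted statistics and stats-based rewrites are skipped for that scanner.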
+    fn scan_input_stats(&self) -> Result<Option<RegionScanStats>, BoxedError> {
+        Ok(None)
+    }
+
     /// Add the given dynamic filter expressions to the predicate of the scanner.
     /// Returns a vector of booleans indicating which filter expressions were applied.
     /// true indicates the filter expression was applied (will be used by the scanner to prune row groups by statistics),
diff --git a/src/store-api/src/scan_stats.rs b/src/store-api/src/scan_stats.rs
new file mode 100644
index 0000000000..fbe7865293
--- /dev/null
+++ b/src/store-api/src/scan_stats.rs
@@ -0,0 +1,45 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Trusted, scan-derived statistics exposed by region scanners for higher-level optimizations.
+//!
+//! These are not raw storage-format statistics. They are conservative summaries that scanner
+//! implementations can safely expose for optimizer consumption.
+
+use std::collections::HashMap;
+
+use common_time::Timestamp;
+use datatypes::value::Value;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct RegionScanColumnStats {
+    pub min_value: Option<Value>,
+    pub max_value: Option<Value>,
+    pub exact_non_null_rows: Option<u64>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct RegionScanFileStats {
+    /// Stable file ordinal within one `RegionScanStats` snapshot.
+    pub file_ordinal: usize,
+    pub exact_num_rows: Option<u64>,
+    pub time_range: Option<(Timestamp, Timestamp)>,
+    pub field_stats: HashMap<String, RegionScanColumnStats>,
+    pub partition_expr_matches_region: bool,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct RegionScanStats {
+    pub files: Vec<RegionScanFileStats>,
+}
diff --git a/src/table/src/table/scan.rs b/src/table/src/table/scan.rs
index a49f727d99..94f35a8167 100644
--- a/src/table/src/table/scan.rs
+++ b/src/table/src/table/scan.rs
@@ -48,6 +48,7 @@ use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
 use store_api::region_engine::{
     PartitionRange, PrepareRequest, QueryScanContext, RegionScannerRef,
 };
+use store_api::scan_stats::RegionScanStats;
 use store_api::storage::{ScanRequest, TimeSeriesDistribution};
 
 use crate::table::metrics::StreamMetrics;
@@ -315,6 +316,18 @@ impl RegionScanExec {
         self.total_rows
     }
 
+    pub fn has_predicate_without_region(&self) -> bool {
+        self.scanner.lock().unwrap().has_predicate_without_region()
+    }
+
+    pub fn scan_input_stats(&self) -> DfResult<Option<RegionScanStats>> {
+        self.scanner
+            .lock()
+            .unwrap()
+            .scan_input_stats()
+            .map_err(|err| DataFusionError::External(err.into()))
+    }
+
     pub fn with_distinguish_partition_range(&self, distinguish_partition_range: bool) {
         let mut scanner = self.scanner.lock().unwrap();
         // set distinguish_partition_range won't fail