diff --git a/src/mito2/src/read/pruner.rs b/src/mito2/src/read/pruner.rs
index 2dfb41a482..9eb006b1cb 100644
--- a/src/mito2/src/read/pruner.rs
+++ b/src/mito2/src/read/pruner.rs
@@ -625,3 +625,67 @@ impl Pruner {
         );
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn file_builder_entry_with_cached_builder(
+        requirements: Option<Arc<[SupportStatAggr]>>,
+    ) -> FileBuilderEntry {
+        FileBuilderEntry {
+            builder: Some(Arc::new(FileRangeBuilder::default())),
+            stats_aware_skip_requirements: requirements,
+            remaining_ranges: 1,
+            waiters: Vec::new(),
+        }
+    }
+
+    #[test]
+    fn keeps_cached_builder_when_stats_aware_skip_requirements_match() {
+        let requirements = vec![SupportStatAggr::CountRows];
+        let skip_config = StatsAwareSkipConfig::new(requirements.clone()).unwrap();
+        let mut entry = file_builder_entry_with_cached_builder(Some(requirements.into()));
+        let cached_builder = entry.builder.as_ref().unwrap().clone();
+
+        entry.clear_builder_if_skip_requirements_changed(Some(&skip_config));
+
+        assert!(entry.builder.is_some());
+        assert!(Arc::ptr_eq(
+            entry.builder.as_ref().unwrap(),
+            &cached_builder
+        ));
+        assert_eq!(
+            entry.stats_aware_skip_requirements.as_deref(),
+            Some(skip_config.requirements())
+        );
+    }
+
+    #[test]
+    fn clears_cached_builder_when_stats_aware_skip_requirements_change() {
+        let old_requirements = vec![SupportStatAggr::CountRows];
+        let new_requirements = vec![SupportStatAggr::CountNonNull {
+            column_name: "value".to_string(),
+        }];
+        let skip_config = StatsAwareSkipConfig::new(new_requirements).unwrap();
+        let mut entry = file_builder_entry_with_cached_builder(Some(old_requirements.into()));
+
+        PRUNER_ACTIVE_BUILDERS.inc();
+        entry.clear_builder_if_skip_requirements_changed(Some(&skip_config));
+
+        assert!(entry.builder.is_none());
+        assert!(entry.stats_aware_skip_requirements.is_none());
+    }
+
+    #[test]
+    fn clears_cached_builder_when_stats_aware_skip_is_removed() {
+        let old_requirements = vec![SupportStatAggr::CountRows];
+        let mut entry = file_builder_entry_with_cached_builder(Some(old_requirements.into()));
+
+        PRUNER_ACTIVE_BUILDERS.inc();
+        entry.clear_builder_if_skip_requirements_changed(None);
+
+        assert!(entry.builder.is_none());
+        assert!(entry.stats_aware_skip_requirements.is_none());
+    }
+}
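The three pruner tests above pin down a single invalidation rule for the cached range builder: the cache survives only when the new stats-aware skip requirements are exactly the ones it was built under, and it is dropped both on a requirements change and when stats-aware skipping is turned off. A minimal std-only model of that rule (`Entry` and `Requirement` are illustrative stand-ins, not the mito2 types):

    #[derive(Clone, Debug, PartialEq)]
    enum Requirement {
        CountRows,
        CountNonNull { column_name: String },
    }

    struct Entry {
        // Stand-in for the cached FileRangeBuilder.
        builder: Option<std::sync::Arc<String>>,
        // Requirements the cached builder was built under; None when
        // stats-aware skipping was disabled.
        requirements: Option<Vec<Requirement>>,
    }

    impl Entry {
        /// Keep the cache only on an exact requirements match; otherwise drop
        /// both the builder and the recorded requirements.
        fn clear_if_changed(&mut self, new: Option<&[Requirement]>) {
            if self.requirements.as_deref() != new {
                self.builder = None;
                self.requirements = None;
            }
        }
    }

    fn main() {
        let mut entry = Entry {
            builder: Some(std::sync::Arc::new("builder".to_string())),
            requirements: Some(vec![Requirement::CountRows]),
        };
        // Same requirements: the cached builder survives.
        entry.clear_if_changed(Some(&[Requirement::CountRows]));
        assert!(entry.builder.is_some());
        // Different requirements (or None, for "skip removed"): cache dropped.
        entry.clear_if_changed(Some(&[Requirement::CountNonNull {
            column_name: "v0".to_string(),
        }]));
        assert!(entry.builder.is_none() && entry.requirements.is_none());
    }
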
diff --git a/src/query/src/optimizer/aggr_stats.rs b/src/query/src/optimizer/aggr_stats.rs
index 6560ecb4a6..109e087da0 100644
--- a/src/query/src/optimizer/aggr_stats.rs
+++ b/src/query/src/optimizer/aggr_stats.rs
@@ -222,23 +222,171 @@ impl<'a> RewriteTarget<'a> {
 
 #[cfg(test)]
 mod tests {
+    use std::io::Cursor;
+    use std::sync::{Arc, Mutex};
+
     use api::v1::SemanticType;
-    use common_recordbatch::EmptyRecordBatchStream;
+    use bytes::Bytes;
+    use common_error::ext::BoxedError;
+    use common_query::aggr_stats::StatsCandidateFile;
+    use common_recordbatch::{
+        EmptyRecordBatchStream, RecordBatch, RecordBatches, SendableRecordBatchStream,
+    };
     use datafusion::functions_aggregate::average::avg_udaf;
     use datafusion::functions_aggregate::count::count_udaf;
+    use datafusion::parquet::arrow::ArrowWriter;
+    use datafusion::parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
+    use datafusion::parquet::file::properties::WriterProperties;
     use datafusion::physical_plan::aggregates::PhysicalGroupBy;
+    use datafusion::physical_plan::collect;
+    use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
     use datafusion::scalar::ScalarValue;
     use datafusion_expr::utils::COUNT_STAR_EXPANSION;
     use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
     use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
+    use datatypes::arrow::array::Float64Array;
+    use datatypes::arrow::datatypes::DataType;
+    use datatypes::arrow::record_batch::RecordBatch as ArrowRecordBatch;
     use datatypes::data_type::ConcreteDataType;
+    use datatypes::prelude::VectorRef;
     use datatypes::schema::{ColumnSchema, Schema};
+    use datatypes::vectors::{Float64Vector, TimestampMillisecondVector};
     use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
-    use store_api::region_engine::SinglePartitionScanner;
+    use store_api::region_engine::{
+        FileStatsItem, PrepareRequest, QueryScanContext, RegionScanner, RowGroupStatsItem,
+        ScannerProperties, SendableFileStatsStream, SinglePartitionScanner,
+    };
     use store_api::storage::{RegionId, ScanRequest};
 
     use super::*;
 
+    #[derive(Clone, Debug)]
+    struct ScanTestFile {
+        id: String,
+        stats: FileStatsItem,
+        batch: RecordBatch,
+    }
+
+    #[derive(Debug)]
+    struct RecordingStatsScanner {
+        schema: Arc<Schema>,
+        metadata: store_api::metadata::RegionMetadataRef,
+        properties: ScannerProperties,
+        files: Vec<ScanTestFile>,
+        scanned_file_ids: Arc<Mutex<Vec<String>>>,
+    }
+
+    impl RecordingStatsScanner {
+        fn new(
+            schema: Arc<Schema>,
+            metadata: store_api::metadata::RegionMetadataRef,
+            files: Vec<ScanTestFile>,
+            scanned_file_ids: Arc<Mutex<Vec<String>>>,
+        ) -> Self {
+            let total_rows = files.iter().map(|file| file.batch.num_rows()).sum();
+            Self {
+                schema,
+                metadata,
+                properties: ScannerProperties::default()
+                    .with_append_mode(true)
+                    .with_total_rows(total_rows),
+                files,
+                scanned_file_ids,
+            }
+        }
+
+        fn should_skip_file(&self, file: &ScanTestFile) -> bool {
+            let requirements = self.properties.stats_aware_skip_requirements();
+            !requirements.is_empty()
+                && StatsCandidateFile::from_file_stats(
+                    &file.stats,
+                    self.metadata.partition_expr.as_deref(),
+                    requirements,
+                    &self.metadata.schema,
+                )
+                .unwrap()
+                .is_some()
+        }
+    }
+
+    impl RegionScanner for RecordingStatsScanner {
+        fn name(&self) -> &str {
+            "RecordingStatsScanner"
+        }
+
+        fn properties(&self) -> &ScannerProperties {
+            &self.properties
+        }
+
+        fn schema(&self) -> Arc<Schema> {
+            self.schema.clone()
+        }
+
+        fn metadata(&self) -> store_api::metadata::RegionMetadataRef {
+            self.metadata.clone()
+        }
+
+        fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
+            self.properties.prepare(request);
+            Ok(())
+        }
+
+        fn scan_partition(
+            &self,
+            _ctx: &QueryScanContext,
+            _metrics_set: &ExecutionPlanMetricsSet,
+            _partition: usize,
+        ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
+            let batches = self
+                .files
+                .iter()
+                .filter(|file| !self.should_skip_file(file))
+                .map(|file| {
+                    self.scanned_file_ids.lock().unwrap().push(file.id.clone());
+                    file.batch.clone()
+                })
+                .collect::<Vec<_>>();
+
+            Ok(RecordBatches::try_new(self.schema.clone(), batches)
+                .unwrap()
+                .as_stream())
+        }
+
+        fn scan_stats(
+            &self,
+            _ctx: &QueryScanContext,
+        ) -> std::result::Result<SendableFileStatsStream, BoxedError> {
+            Ok(Box::pin(futures::stream::iter(
+                self.files.clone().into_iter().map(|file| Ok(file.stats)),
+            )))
+        }
+
+        fn has_predicate_without_region(&self) -> bool {
+            false
+        }
+
+        fn add_dyn_filter_to_predicate(
+            &mut self,
+            filter_exprs: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> Vec<bool> {
+            vec![false; filter_exprs.len()]
+        }
+
+        fn set_logical_region(&mut self, logical_region: bool) {
+            self.properties.set_logical_region(logical_region);
+        }
+    }
+
+    impl datafusion::physical_plan::DisplayAs for RecordingStatsScanner {
+        fn fmt_as(
+            &self,
+            _t: datafusion::physical_plan::DisplayFormatType,
+            f: &mut std::fmt::Formatter<'_>,
+        ) -> std::fmt::Result {
+            write!(f, "RecordingStatsScanner")
+        }
+    }
+
     fn build_count_expr(schema: arrow_schema::SchemaRef) -> Arc<AggregateFunctionExpr> {
         Arc::new(
             AggregateExprBuilder::new(count_udaf(), vec![Arc::new(PhysicalColumn::new("v0", 0))])
@@ -342,6 +490,124 @@ mod tests {
         Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap())
     }
 
+    fn build_region_metadata(
+        partition_expr: Option<&str>,
+    ) -> store_api::metadata::RegionMetadataRef {
+        let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
+        metadata_builder
+            .push_column_metadata(ColumnMetadata {
+                column_schema: ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
+                semantic_type: SemanticType::Field,
+                column_id: 1,
+            })
+            .push_column_metadata(ColumnMetadata {
+                column_schema: ColumnSchema::new(
+                    "ts",
+                    ConcreteDataType::timestamp_millisecond_datatype(),
+                    false,
+                )
+                .with_time_index(true),
+                semantic_type: SemanticType::Timestamp,
+                column_id: 2,
+            })
+            .primary_key(vec![]);
+        let mut metadata = metadata_builder.build().unwrap();
+        metadata.set_partition_expr(partition_expr.map(str::to_string));
+        Arc::new(metadata)
+    }
+
+    fn build_float_row_groups(chunks: &[Vec<Option<f64>>]) -> Vec<RowGroupStatsItem> {
+        let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
+            "v0",
+            DataType::Float64,
+            true,
+        )]));
+        let mut buffer = Cursor::new(Vec::new());
+        let props = WriterProperties::builder().build();
+        let mut writer =
+            ArrowWriter::try_new(&mut buffer, arrow_schema.clone(), Some(props)).unwrap();
+
+        for chunk in chunks {
+            let batch = ArrowRecordBatch::try_new(
+                arrow_schema.clone(),
+                vec![Arc::new(Float64Array::from(chunk.clone()))],
+            )
+            .unwrap();
+            writer.write(&batch).unwrap();
+        }
+        writer.close().unwrap();
+
+        let metadata = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer.into_inner()))
+            .unwrap()
+            .metadata()
+            .clone();
+        metadata
+            .row_groups()
+            .iter()
+            .enumerate()
+            .map(|(row_group_index, metadata)| RowGroupStatsItem {
+                row_group_index,
+                metadata: Arc::new(metadata.clone()),
+            })
+            .collect()
+    }
+
+    fn build_scan_test_file(
+        schema: Arc<Schema>,
+        id: &str,
+        partition_expr: &str,
+        values: Vec<Option<f64>>,
+        with_row_groups: bool,
+        ts_start: i64,
+    ) -> ScanTestFile {
+        let batch = RecordBatch::new(
+            schema,
+            vec![
+                Arc::new(Float64Vector::from(values.clone())) as VectorRef,
+                Arc::new(TimestampMillisecondVector::from_values(
+                    (0..values.len()).map(|offset| ts_start + offset as i64),
+                )) as VectorRef,
+            ],
+        )
+        .unwrap();
+
+        ScanTestFile {
+            id: id.to_string(),
+            stats: FileStatsItem {
+                num_rows: Some(values.len() as u64),
+                file_partition_expr: Some(partition_expr.to_string()),
+                row_groups: if with_row_groups {
+                    build_float_row_groups(&[values])
+                } else {
+                    vec![]
+                },
+            },
+            batch,
+        }
+    }
+
+    fn build_recording_region_scan(
+        files: Vec<ScanTestFile>,
+        scanned_file_ids: Arc<Mutex<Vec<String>>>,
+    ) -> Arc<RegionScanExec> {
+        let schema = Arc::new(Schema::new(vec![
+            ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
+            ColumnSchema::new(
+                "ts",
+                ConcreteDataType::timestamp_millisecond_datatype(),
+                false,
+            ),
+        ]));
+        let metadata = build_region_metadata(Some("host = 'a'"));
+        let scanner = Box::new(RecordingStatsScanner::new(
+            schema,
+            metadata,
+            files,
+            scanned_file_ids,
+        ));
+        Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap())
+    }
+
     fn build_final_over_partial_plan() -> Arc<dyn ExecutionPlan> {
         build_final_over_partial_plan_with(build_region_scan(true), build_count_expr, None)
     }
@@ -516,6 +782,70 @@ mod tests {
         assert_final_over_partial_without_union(&optimized);
     }
 
+    #[tokio::test]
+    async fn rewrite_mixed_plan_returns_correct_result_and_scans_only_fallback_files() {
+        let scanned_file_ids = Arc::new(Mutex::new(Vec::new()));
+        let schema = Arc::new(Schema::new(vec![
+            ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
+            ColumnSchema::new(
+                "ts",
+                ConcreteDataType::timestamp_millisecond_datatype(),
+                false,
+            ),
+        ]));
+        let region_scan = build_recording_region_scan(
+            vec![
+                build_scan_test_file(
+                    schema.clone(),
+                    "eligible-a",
+                    "host = 'a'",
+                    vec![Some(1.0), None, Some(2.0)],
+                    true,
+                    0,
+                ),
+                build_scan_test_file(
+                    schema.clone(),
+                    "eligible-b",
+                    "host = 'a'",
+                    vec![Some(3.0), Some(4.0)],
+                    true,
+                    10,
+                ),
+                build_scan_test_file(
+                    schema,
+                    "fallback-c",
+                    "host = 'a'",
+                    vec![Some(5.0), None, Some(6.0)],
+                    false,
+                    20,
+                ),
+            ],
+            scanned_file_ids.clone(),
+        );
+
+        let plan = build_final_over_partial_plan_with(region_scan, build_count_expr, None);
+        let optimized = AggrStatsPhysicalRule
+            .optimize(plan, &ConfigOptions::default())
+            .unwrap();
+
+        let batches = collect(
+            optimized,
+            Arc::new(datafusion::execution::TaskContext::default()),
+        )
+        .await
+        .unwrap();
+        assert_eq!(batches.len(), 1);
+        let values = batches[0]
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .unwrap();
+        assert_eq!(values.value(0), 6);
+
+        let scanned = scanned_file_ids.lock().unwrap().clone();
+        assert_eq!(scanned, vec!["fallback-c".to_string()]);
+    }
+
     fn assert_rewritten_stats_requirement(
         plan: &Arc<dyn ExecutionPlan>,
         expected: &[SupportStatAggr],
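The mixed-plan test above checks two things at once: COUNT(v0) stays correct (two non-null values from each stats-eligible file plus two from the scanned one, six in total), and only the file without row-group stats ever reaches the scanner. A minimal std-only model of that eligible/fallback split (names are illustrative stand-ins, not the query-engine API):

    // Stand-in for a file plus whatever the stats subsystem knows about it.
    struct FileStats {
        id: &'static str,
        // Some(..) when row-group stats fully answer the requirement.
        non_null_from_stats: Option<u64>,
        rows: Vec<Option<f64>>,
    }

    /// Answer COUNT(column) from stats where possible, scanning only the rest.
    fn count_non_null(files: &[FileStats], scanned: &mut Vec<&'static str>) -> u64 {
        files
            .iter()
            .map(|file| match file.non_null_from_stats {
                Some(count) => count, // answered from metadata, no scan
                None => {
                    scanned.push(file.id); // fallback: the file is actually read
                    file.rows.iter().filter(|v| v.is_some()).count() as u64
                }
            })
            .sum()
    }

    fn main() {
        let files = [
            FileStats { id: "eligible-a", non_null_from_stats: Some(2), rows: vec![] },
            FileStats { id: "eligible-b", non_null_from_stats: Some(2), rows: vec![] },
            FileStats {
                id: "fallback-c",
                non_null_from_stats: None,
                rows: vec![Some(5.0), None, Some(6.0)],
            },
        ];
        let mut scanned = Vec::new();
        assert_eq!(count_non_null(&files, &mut scanned), 6); // 2 + 2 + 2
        assert_eq!(scanned, ["fallback-c"]); // only the stats-less file is read
    }
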
diff --git a/src/query/src/optimizer/aggr_stats/stat_scan.rs b/src/query/src/optimizer/aggr_stats/stat_scan.rs
index 51f24ebe91..b5f27e569b 100644
--- a/src/query/src/optimizer/aggr_stats/stat_scan.rs
+++ b/src/query/src/optimizer/aggr_stats/stat_scan.rs
@@ -66,16 +66,27 @@ fn build_state_scalar(
     field: &Field,
     requirement: &SupportStatAggr,
 ) -> Result<ScalarValue> {
-    let DataType::Struct(state_fields) = field.data_type() else {
+    let Some(value) = candidate.stat_value(requirement)? else {
         return Err(DataFusionError::Internal(format!(
-            "StatsScanExec expects struct state field, got {:?} for {}",
-            field.data_type(),
-            field.name()
+            "StatsScanExec built an ineligible stats candidate for requirement {:?}",
+            requirement
         )));
     };
+
+    let DataType::Struct(state_fields) = field.data_type() else {
+        let output_type = ConcreteDataType::from_arrow_type(field.data_type());
+        return value.try_to_scalar_value(&output_type).map_err(|error| {
+            DataFusionError::Internal(format!(
+                "StatsScanExec failed to convert state value for {}: {}",
+                field.name(),
+                error
+            ))
+        });
+    };
+
     if state_fields.len() != 1 {
         return Err(DataFusionError::Internal(format!(
-            "StatsScanExec only supports single-field state in v1, got {} fields for {}",
+            "StatsScanExec only supports single-field state, got {} fields for {}",
             state_fields.len(),
             field.name()
         )));
@@ -83,12 +94,6 @@ fn build_state_scalar(
     let inner_field = state_fields[0].as_ref();
     let output_type = ConcreteDataType::from_arrow_type(inner_field.data_type());
 
-    let Some(value) = candidate.stat_value(requirement)? else {
-        return Err(DataFusionError::Internal(format!(
-            "StatsScanExec built an ineligible stats candidate for requirement {:?}",
-            requirement
-        )));
-    };
-
     let scalar = value.try_to_scalar_value(&output_type).map_err(|error| {
         DataFusionError::Internal(format!(
@@ -422,19 +427,6 @@ mod tests {
         }
     }
 
-    fn single_state_field(
-        name: &str,
-        inner_name: &str,
-        inner_type: DataType,
-        inner_nullable: bool,
-    ) -> Field {
-        Field::new(
-            name,
-            DataType::Struct(vec![Field::new(inner_name, inner_type, inner_nullable)].into()),
-            true,
-        )
-    }
-
     fn build_region_metadata(partition_expr: Option<&str>) -> RegionMetadataRef {
         let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
         builder.push_column_metadata(ColumnMetadata {
@@ -521,21 +513,6 @@ mod tests {
         batches.into_iter().next().unwrap()
     }
 
-    fn assert_struct_state_matches_field<'a>(
-        batch: &'a RecordBatch,
-        column_index: usize,
-        expected_field: &Field,
-    ) -> &'a StructArray {
-        let struct_array = batch
-            .column(column_index)
-            .as_any()
-            .downcast_ref::<StructArray>()
-            .unwrap();
-        assert_eq!(struct_array.fields().len(), 1);
-        assert_eq!(struct_array.fields()[0].as_ref(), expected_field);
-        struct_array
-    }
-
     #[tokio::test]
     async fn stats_scan_exec_matches_datafusion_count_state_field() {
         let aggr_expr = build_datafusion_aggr_expr(count_udaf(), "count(value)");
@@ -543,12 +520,7 @@ mod tests {
         assert_eq!(state_fields.len(), 1);
         let inner_field = state_fields[0].as_ref().clone();
 
-        let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-            "count_state",
-            inner_field.name(),
-            inner_field.data_type().clone(),
-            inner_field.is_nullable(),
-        )]));
+        let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
         let region_metadata = build_region_metadata(Some("host = 'a'"));
         let scanner = StaticStatsScanner {
             schema: region_metadata.schema.clone(),
@@ -571,8 +543,8 @@ mod tests {
 
         let batch = collect_single_batch(&exec).await;
 
-        let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-        let values = struct_array
+        assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+        let values = batch
             .column(0)
             .as_any()
             .downcast_ref::<Int64Array>()
             .unwrap();
@@ -587,12 +559,7 @@ mod tests {
         assert_eq!(state_fields.len(), 1);
         let inner_field = state_fields[0].as_ref().clone();
 
-        let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-            "min_state",
-            inner_field.name(),
-            inner_field.data_type().clone(),
-            inner_field.is_nullable(),
-        )]));
+        let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
         let region_metadata = build_region_metadata(Some("host = 'a'"));
         let scanner = StaticStatsScanner {
             schema: region_metadata.schema.clone(),
@@ -618,8 +585,8 @@ mod tests {
 
         let batch = collect_single_batch(&exec).await;
 
-        let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-        let values = struct_array
+        assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+        let values = batch
             .column(0)
             .as_any()
             .downcast_ref::<Int64Array>()
             .unwrap();
@@ -634,12 +601,7 @@ mod tests {
         assert_eq!(state_fields.len(), 1);
         let inner_field = state_fields[0].as_ref().clone();
 
-        let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-            "max_state",
-            inner_field.name(),
-            inner_field.data_type().clone(),
-            inner_field.is_nullable(),
-        )]));
+        let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
         let region_metadata = build_region_metadata(Some("host = 'a'"));
         let scanner = StaticStatsScanner {
             schema: region_metadata.schema.clone(),
@@ -665,8 +627,8 @@ mod tests {
 
         let batch = collect_single_batch(&exec).await;
 
-        let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-        let values = struct_array
+        assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+        let values = batch
             .column(0)
             .as_any()
             .downcast_ref::<Int64Array>()
             .unwrap();
@@ -680,17 +642,12 @@ mod tests {
         let state_fields = aggr_expr.state_fields().unwrap();
         assert!(state_fields.len() > 1);
 
-        let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
-            "avg_state",
-            DataType::Struct(
-                state_fields
-                    .iter()
-                    .map(|field| field.as_ref().clone())
-                    .collect::<Vec<_>>()
-                    .into(),
-            ),
-            true,
-        )]));
+        let schema = Arc::new(arrow_schema::Schema::new(
+            state_fields
+                .iter()
+                .map(|field| field.as_ref().clone())
+                .collect::<Vec<_>>(),
+        ));
         let region_metadata = build_region_metadata(Some("host = 'a'"));
         let scanner = StaticStatsScanner {
             schema: region_metadata.schema.clone(),
@@ -714,18 +671,14 @@ mod tests {
         let mut stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
         let error = stream.next().await.unwrap().unwrap_err();
-        assert!(
-            error
-                .to_string()
-                .contains("only supports single-field state in v1")
-        );
+        assert!(error.to_string().contains("schema/requirement mismatch"));
     }
 
     #[tokio::test]
     async fn stats_scan_exec_emits_state_rows_for_eligible_files() {
         let schema = Arc::new(arrow_schema::Schema::new(vec![
-            single_state_field("count_state", "count[count]", DataType::Int64, false),
-            single_state_field("max_state", "max[max]", DataType::Int64, true),
+            Field::new("count[count]", DataType::Int64, false),
+            Field::new("max[max]", DataType::Int64, true),
         ]));
         let requirements = vec![
             SupportStatAggr::CountNonNull {
@@ -778,26 +731,16 @@ mod tests {
         let batch = &batches[0];
         assert_eq!(batch.num_rows(), 1);
 
-        let count_state = batch
-            .column(0)
-            .as_any()
-            .downcast_ref::<StructArray>()
-            .unwrap();
-        let count_values = count_state
+        let count_values = batch
             .column(0)
             .as_any()
             .downcast_ref::<Int64Array>()
             .unwrap();
         assert_eq!(count_values.value(0), 7);
 
-        let max_state = batch
+        let max_values = batch
             .column(1)
             .as_any()
-            .downcast_ref::<StructArray>()
-            .unwrap();
-        let max_values = max_state
-            .column(0)
-            .as_any()
             .downcast_ref::<Int64Array>()
             .unwrap();
         assert_eq!(max_values.value(0), 9);
@@ -805,8 +748,7 @@ mod tests {
 
     #[tokio::test]
     async fn stats_scan_exec_emits_no_batches_when_all_files_fallback() {
-        let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-            "count_state",
+        let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
             "count[count]",
             DataType::Int64,
             false,
@@ -846,8 +788,7 @@ mod tests {
 
     #[tokio::test]
     async fn stats_scan_exec_count_rows_uses_file_num_rows_without_row_groups() {
-        let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-            "count_state",
+        let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
             "count[count]",
             DataType::Int64,
             false,
@@ -880,12 +821,7 @@ mod tests {
         let batch = collect_single_batch(&exec).await;
         assert_eq!(batch.num_rows(), 2);
 
-        let count_state = batch
-            .column(0)
-            .as_any()
-            .downcast_ref::<StructArray>()
-            .unwrap();
-        let count_values = count_state
+        let count_values = batch
             .column(0)
             .as_any()
             .downcast_ref::<Int64Array>()
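The stat_scan.rs test updates all track one schema change: StatsScanExec now emits each aggregate's state fields as top-level columns instead of wrapping them in single-field structs. A small sketch of the before/after shape, assuming only `Schema`, `Field`, and `DataType` from the arrow-schema crate (mirroring the deleted `single_state_field` helper):

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        let inner = Field::new("count[count]", DataType::Int64, false);

        // Old shape: each state column wrapped in a single-field struct.
        let old = Schema::new(vec![Field::new(
            "count_state",
            DataType::Struct(vec![inner.clone()].into()),
            true,
        )]);

        // New shape: the state field is the column itself.
        let new = Schema::new(vec![inner.clone()]);

        assert_eq!(old.field(0).name(), "count_state");
        assert_eq!(new.field(0), &inner);
        let _ = Arc::new(new); // a SchemaRef, as the tests build it
    }
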
diff --git a/src/query/src/optimizer/aggr_stats/support_aggr.rs b/src/query/src/optimizer/aggr_stats/support_aggr.rs
index 5cdd4660b7..aa6a8541bd 100644
--- a/src/query/src/optimizer/aggr_stats/support_aggr.rs
+++ b/src/query/src/optimizer/aggr_stats/support_aggr.rs
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
 pub use store_api::region_engine::SupportStatAggr;