feat: support flat fields

Signed-off-by: discord9 <discord9@163.com>
discord9
2026-04-27 11:53:18 +08:00
parent 08c66ab00b
commit 481b21ee18
4 changed files with 435 additions and 106 deletions
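The core change across these files: StatsScanExec now accepts flat aggregate state columns, where each state field is itself a schema column, and the tests drop the old single-field-struct wrapping (a multi-field state such as avg now contributes one column per state field). A minimal sketch of the two schema shapes, assuming arrow_schema; the field names are borrowed from the tests below, everything else is illustrative:

use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Old shape: aggregate state wrapped in a single-field struct column,
    // forcing readers to downcast to StructArray before reaching the value.
    let nested = Schema::new(vec![Field::new(
        "count_state",
        DataType::Struct(vec![Field::new("count[count]", DataType::Int64, false)].into()),
        true,
    )]);

    // Flat shape: the state field is the column; readers downcast it directly.
    let flat = Schema::new(vec![Field::new("count[count]", DataType::Int64, false)]);

    println!("nested: {nested:?}");
    println!("flat: {flat:?}");
}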

@@ -625,3 +625,67 @@ impl Pruner {
);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn file_builder_entry_with_cached_builder(
requirements: Option<Arc<[SupportStatAggr]>>,
) -> FileBuilderEntry {
FileBuilderEntry {
builder: Some(Arc::new(FileRangeBuilder::default())),
stats_aware_skip_requirements: requirements,
remaining_ranges: 1,
waiters: Vec::new(),
}
}
#[test]
fn keeps_cached_builder_when_stats_aware_skip_requirements_match() {
let requirements = vec![SupportStatAggr::CountRows];
let skip_config = StatsAwareSkipConfig::new(requirements.clone()).unwrap();
let mut entry = file_builder_entry_with_cached_builder(Some(requirements.into()));
let cached_builder = entry.builder.as_ref().unwrap().clone();
entry.clear_builder_if_skip_requirements_changed(Some(&skip_config));
assert!(entry.builder.is_some());
assert!(Arc::ptr_eq(
entry.builder.as_ref().unwrap(),
&cached_builder
));
assert_eq!(
entry.stats_aware_skip_requirements.as_deref(),
Some(skip_config.requirements())
);
}
#[test]
fn clears_cached_builder_when_stats_aware_skip_requirements_change() {
let old_requirements = vec![SupportStatAggr::CountRows];
let new_requirements = vec![SupportStatAggr::CountNonNull {
column_name: "value".to_string(),
}];
let skip_config = StatsAwareSkipConfig::new(new_requirements).unwrap();
let mut entry = file_builder_entry_with_cached_builder(Some(old_requirements.into()));
// Pre-increment the gauge that dropping the cached builder is expected to decrement.
PRUNER_ACTIVE_BUILDERS.inc();
entry.clear_builder_if_skip_requirements_changed(Some(&skip_config));
assert!(entry.builder.is_none());
assert!(entry.stats_aware_skip_requirements.is_none());
}
#[test]
fn clears_cached_builder_when_stats_aware_skip_is_removed() {
let old_requirements = vec![SupportStatAggr::CountRows];
let mut entry = file_builder_entry_with_cached_builder(Some(old_requirements.into()));
PRUNER_ACTIVE_BUILDERS.inc();
entry.clear_builder_if_skip_requirements_changed(None);
assert!(entry.builder.is_none());
assert!(entry.stats_aware_skip_requirements.is_none());
}
}
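The tests above pin down the cache-invalidation rule for the pruner's file-range builders: a cached builder survives only while the stats-aware skip requirements it was built under exactly match the incoming scan's, including the case where neither side uses stats-aware skipping. A hypothetical stand-in for that rule (the real FileBuilderEntry also tracks remaining ranges, waiters, and the PRUNER_ACTIVE_BUILDERS gauge):

#[derive(Default)]
struct Entry {
    builder: Option<std::sync::Arc<()>>, // stand-in for Arc<FileRangeBuilder>
    requirements: Option<Vec<String>>,   // stand-in for Arc<[SupportStatAggr]>
}

impl Entry {
    fn clear_builder_if_requirements_changed(&mut self, current: Option<&[String]>) {
        // Keep the cached builder only on an exact requirements match; otherwise
        // drop both the builder and the recorded requirements, as the tests assert.
        if self.requirements.as_deref() != current {
            self.builder = None;
            self.requirements = None;
        }
    }
}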

@@ -222,23 +222,171 @@ impl<'a> RewriteTarget<'a> {
#[cfg(test)]
mod tests {
use std::io::Cursor;
use std::sync::{Arc, Mutex};
use api::v1::SemanticType;
-use common_recordbatch::EmptyRecordBatchStream;
+use bytes::Bytes;
+use common_error::ext::BoxedError;
+use common_query::aggr_stats::StatsCandidateFile;
+use common_recordbatch::{
+EmptyRecordBatchStream, RecordBatch, RecordBatches, SendableRecordBatchStream,
+};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_plan::aggregates::PhysicalGroupBy;
use datafusion::physical_plan::collect;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::scalar::ScalarValue;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
use datatypes::arrow::array::Float64Array;
use datatypes::arrow::datatypes::DataType;
use datatypes::arrow::record_batch::RecordBatch as ArrowRecordBatch;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::VectorRef;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{Float64Vector, TimestampMillisecondVector};
use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
-use store_api::region_engine::SinglePartitionScanner;
+use store_api::region_engine::{
+FileStatsItem, PrepareRequest, QueryScanContext, RegionScanner, RowGroupStatsItem,
+ScannerProperties, SendableFileStatsStream, SinglePartitionScanner,
+};
use store_api::storage::{RegionId, ScanRequest};
use super::*;
// Test fixture: a file's precomputed stats plus the batch the scanner yields for it.
#[derive(Clone, Debug)]
struct ScanTestFile {
id: String,
stats: FileStatsItem,
batch: RecordBatch,
}
#[derive(Debug)]
struct RecordingStatsScanner {
schema: Arc<Schema>,
metadata: store_api::metadata::RegionMetadataRef,
properties: ScannerProperties,
files: Vec<ScanTestFile>,
scanned_file_ids: Arc<Mutex<Vec<String>>>,
}
impl RecordingStatsScanner {
fn new(
schema: Arc<Schema>,
metadata: store_api::metadata::RegionMetadataRef,
files: Vec<ScanTestFile>,
scanned_file_ids: Arc<Mutex<Vec<String>>>,
) -> Self {
let total_rows = files.iter().map(|file| file.batch.num_rows()).sum();
Self {
schema,
metadata,
properties: ScannerProperties::default()
.with_append_mode(true)
.with_total_rows(total_rows),
files,
scanned_file_ids,
}
}
// Mirrors the engine's stats-aware skipping: skip a file when requirements are
// present and its stats form an eligible candidate that answers them.
fn should_skip_file(&self, file: &ScanTestFile) -> bool {
let requirements = self.properties.stats_aware_skip_requirements();
!requirements.is_empty()
&& StatsCandidateFile::from_file_stats(
&file.stats,
self.metadata.partition_expr.as_deref(),
requirements,
&self.metadata.schema,
)
.unwrap()
.is_some()
}
}
impl RegionScanner for RecordingStatsScanner {
fn name(&self) -> &str {
"RecordingStatsScanner"
}
fn properties(&self) -> &ScannerProperties {
&self.properties
}
fn schema(&self) -> Arc<Schema> {
self.schema.clone()
}
fn metadata(&self) -> store_api::metadata::RegionMetadataRef {
self.metadata.clone()
}
fn prepare(&mut self, request: PrepareRequest) -> std::result::Result<(), BoxedError> {
self.properties.prepare(request);
Ok(())
}
fn scan_partition(
&self,
_ctx: &QueryScanContext,
_metrics_set: &ExecutionPlanMetricsSet,
_partition: usize,
) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
let batches = self
.files
.iter()
.filter(|file| !self.should_skip_file(file))
.map(|file| {
self.scanned_file_ids.lock().unwrap().push(file.id.clone());
file.batch.clone()
})
.collect::<Vec<_>>();
Ok(RecordBatches::try_new(self.schema.clone(), batches)
.unwrap()
.as_stream())
}
fn scan_stats(
&self,
_ctx: &QueryScanContext,
) -> std::result::Result<SendableFileStatsStream, BoxedError> {
Ok(Box::pin(futures::stream::iter(
self.files.clone().into_iter().map(|file| Ok(file.stats)),
)))
}
fn has_predicate_without_region(&self) -> bool {
false
}
fn add_dyn_filter_to_predicate(
&mut self,
filter_exprs: Vec<Arc<dyn datafusion::physical_plan::PhysicalExpr>>,
) -> Vec<bool> {
vec![false; filter_exprs.len()]
}
fn set_logical_region(&mut self, logical_region: bool) {
self.properties.set_logical_region(logical_region);
}
}
impl datafusion::physical_plan::DisplayAs for RecordingStatsScanner {
fn fmt_as(
&self,
_t: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
write!(f, "RecordingStatsScanner")
}
}
fn build_count_expr(schema: arrow_schema::SchemaRef) -> Arc<AggregateFunctionExpr> {
Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![Arc::new(PhysicalColumn::new("v0", 0))])
@@ -342,6 +490,124 @@ mod tests {
Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap())
}
fn build_region_metadata(
partition_expr: Option<&str>,
) -> store_api::metadata::RegionMetadataRef {
let mut metadata_builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
metadata_builder
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
semantic_type: SemanticType::Field,
column_id: 1,
})
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
)
.with_time_index(true),
semantic_type: SemanticType::Timestamp,
column_id: 2,
})
.primary_key(vec![]);
let mut metadata = metadata_builder.build().unwrap();
metadata.set_partition_expr(partition_expr.map(str::to_string));
Arc::new(metadata)
}
// Round-trips the chunks through a real Parquet writer so the returned
// row-group metadata carries genuine statistics.
fn build_float_row_groups(chunks: &[Vec<Option<f64>>]) -> Vec<RowGroupStatsItem> {
let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"v0",
DataType::Float64,
true,
)]));
let mut buffer = Cursor::new(Vec::new());
let props = WriterProperties::builder().build();
let mut writer =
ArrowWriter::try_new(&mut buffer, arrow_schema.clone(), Some(props)).unwrap();
for chunk in chunks {
let batch = ArrowRecordBatch::try_new(
arrow_schema.clone(),
vec![Arc::new(Float64Array::from(chunk.clone()))],
)
.unwrap();
writer.write(&batch).unwrap();
}
writer.close().unwrap();
let metadata = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(buffer.into_inner()))
.unwrap()
.metadata()
.clone();
metadata
.row_groups()
.iter()
.enumerate()
.map(|(row_group_index, metadata)| RowGroupStatsItem {
row_group_index,
metadata: Arc::new(metadata.clone()),
})
.collect()
}
fn build_scan_test_file(
schema: Arc<Schema>,
id: &str,
partition_expr: &str,
values: Vec<Option<f64>>,
with_row_groups: bool,
ts_start: i64,
) -> ScanTestFile {
let batch = RecordBatch::new(
schema,
vec![
Arc::new(Float64Vector::from(values.clone())) as VectorRef,
Arc::new(TimestampMillisecondVector::from_values(
(0..values.len()).map(|offset| ts_start + offset as i64),
)) as VectorRef,
],
)
.unwrap();
ScanTestFile {
id: id.to_string(),
stats: FileStatsItem {
num_rows: Some(values.len() as u64),
file_partition_expr: Some(partition_expr.to_string()),
row_groups: if with_row_groups {
build_float_row_groups(&[values])
} else {
vec![]
},
},
batch,
}
}
fn build_recording_region_scan(
files: Vec<ScanTestFile>,
scanned_file_ids: Arc<Mutex<Vec<String>>>,
) -> Arc<RegionScanExec> {
let schema = Arc::new(Schema::new(vec![
ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
]));
let metadata = build_region_metadata(Some("host = 'a'"));
let scanner = Box::new(RecordingStatsScanner::new(
schema,
metadata,
files,
scanned_file_ids,
));
Arc::new(RegionScanExec::new(scanner, ScanRequest::default(), None).unwrap())
}
fn build_final_over_partial_plan() -> Arc<dyn ExecutionPlan> {
build_final_over_partial_plan_with(build_region_scan(true), build_count_expr, None)
}
@@ -516,6 +782,70 @@ mod tests {
assert_final_over_partial_without_union(&optimized);
}
#[tokio::test]
async fn rewrite_mixed_plan_returns_correct_result_and_scans_only_fallback_files() {
let scanned_file_ids = Arc::new(Mutex::new(Vec::new()));
let schema = Arc::new(Schema::new(vec![
ColumnSchema::new("v0", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
]));
let region_scan = build_recording_region_scan(
vec![
build_scan_test_file(
schema.clone(),
"eligible-a",
"host = 'a'",
vec![Some(1.0), None, Some(2.0)],
true,
0,
),
build_scan_test_file(
schema.clone(),
"eligible-b",
"host = 'a'",
vec![Some(3.0), Some(4.0)],
true,
10,
),
build_scan_test_file(
schema,
"fallback-c",
"host = 'a'",
vec![Some(5.0), None, Some(6.0)],
false,
20,
),
],
scanned_file_ids.clone(),
);
let plan = build_final_over_partial_plan_with(region_scan, build_count_expr, None);
let optimized = AggrStatsPhysicalRule
.optimize(plan, &ConfigOptions::default())
.unwrap();
let batches = collect(
optimized,
Arc::new(datafusion::execution::TaskContext::default()),
)
.await
.unwrap();
assert_eq!(batches.len(), 1);
let values = batches[0]
.column(0)
.as_any()
.downcast_ref::<datafusion::arrow::array::Int64Array>()
.unwrap();
assert_eq!(values.value(0), 6);
let scanned = scanned_file_ids.lock().unwrap().clone();
assert_eq!(scanned, vec!["fallback-c".to_string()]);
}
fn assert_rewritten_stats_requirement(
plan: &Arc<dyn ExecutionPlan>,
expected: &[SupportStatAggr],

@@ -66,16 +66,27 @@ fn build_state_scalar(
field: &Field,
requirement: &SupportStatAggr,
) -> Result<ScalarValue> {
-let DataType::Struct(state_fields) = field.data_type() else {
+let Some(value) = candidate.stat_value(requirement)? else {
return Err(DataFusionError::Internal(format!(
-"StatsScanExec expects struct state field, got {:?} for {}",
-field.data_type(),
-field.name()
+"StatsScanExec built an ineligible stats candidate for requirement {:?}",
+requirement
)));
};
+let DataType::Struct(state_fields) = field.data_type() else {
+let output_type = ConcreteDataType::from_arrow_type(field.data_type());
+return value.try_to_scalar_value(&output_type).map_err(|error| {
+DataFusionError::Internal(format!(
+"StatsScanExec failed to convert state value for {}: {}",
+field.name(),
+error
+))
+});
+};
if state_fields.len() != 1 {
return Err(DataFusionError::Internal(format!(
-"StatsScanExec only supports single-field state in v1, got {} fields for {}",
+"StatsScanExec only supports single-field state, got {} fields for {}",
state_fields.len(),
field.name()
)));
@@ -83,12 +94,6 @@ fn build_state_scalar(
let inner_field = state_fields[0].as_ref();
let output_type = ConcreteDataType::from_arrow_type(inner_field.data_type());
-let Some(value) = candidate.stat_value(requirement)? else {
-return Err(DataFusionError::Internal(format!(
-"StatsScanExec built an ineligible stats candidate for requirement {:?}",
-requirement
-)));
-};
let scalar = value.try_to_scalar_value(&output_type).map_err(|error| {
DataFusionError::Internal(format!(
@@ -422,19 +427,6 @@ mod tests {
}
}
-fn single_state_field(
-name: &str,
-inner_name: &str,
-inner_type: DataType,
-inner_nullable: bool,
-) -> Field {
-Field::new(
-name,
-DataType::Struct(vec![Field::new(inner_name, inner_type, inner_nullable)].into()),
-true,
-)
-}
fn build_region_metadata(partition_expr: Option<&str>) -> RegionMetadataRef {
let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
builder.push_column_metadata(ColumnMetadata {
@@ -521,21 +513,6 @@ mod tests {
batches.into_iter().next().unwrap()
}
-fn assert_struct_state_matches_field<'a>(
-batch: &'a RecordBatch,
-column_index: usize,
-expected_field: &Field,
-) -> &'a StructArray {
-let struct_array = batch
-.column(column_index)
-.as_any()
-.downcast_ref::<StructArray>()
-.unwrap();
-assert_eq!(struct_array.fields().len(), 1);
-assert_eq!(struct_array.fields()[0].as_ref(), expected_field);
-struct_array
-}
#[tokio::test]
async fn stats_scan_exec_matches_datafusion_count_state_field() {
let aggr_expr = build_datafusion_aggr_expr(count_udaf(), "count(value)");
@@ -543,12 +520,7 @@ mod tests {
assert_eq!(state_fields.len(), 1);
let inner_field = state_fields[0].as_ref().clone();
-let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-"count_state",
-inner_field.name(),
-inner_field.data_type().clone(),
-inner_field.is_nullable(),
-)]));
+let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
let region_metadata = build_region_metadata(Some("host = 'a'"));
let scanner = StaticStatsScanner {
schema: region_metadata.schema.clone(),
@@ -571,8 +543,8 @@ mod tests {
let batch = collect_single_batch(&exec).await;
-let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-let values = struct_array
+assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+let values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
@@ -587,12 +559,7 @@ mod tests {
assert_eq!(state_fields.len(), 1);
let inner_field = state_fields[0].as_ref().clone();
-let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-"min_state",
-inner_field.name(),
-inner_field.data_type().clone(),
-inner_field.is_nullable(),
-)]));
+let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
let region_metadata = build_region_metadata(Some("host = 'a'"));
let scanner = StaticStatsScanner {
schema: region_metadata.schema.clone(),
@@ -618,8 +585,8 @@ mod tests {
let batch = collect_single_batch(&exec).await;
-let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-let values = struct_array
+assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+let values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
@@ -634,12 +601,7 @@ mod tests {
assert_eq!(state_fields.len(), 1);
let inner_field = state_fields[0].as_ref().clone();
-let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-"max_state",
-inner_field.name(),
-inner_field.data_type().clone(),
-inner_field.is_nullable(),
-)]));
+let schema = Arc::new(arrow_schema::Schema::new(vec![inner_field.clone()]));
let region_metadata = build_region_metadata(Some("host = 'a'"));
let scanner = StaticStatsScanner {
schema: region_metadata.schema.clone(),
@@ -665,8 +627,8 @@ mod tests {
let batch = collect_single_batch(&exec).await;
-let struct_array = assert_struct_state_matches_field(&batch, 0, &inner_field);
-let values = struct_array
+assert_eq!(batch.schema().field(0).as_ref(), &inner_field);
+let values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
@@ -680,17 +642,12 @@ mod tests {
let state_fields = aggr_expr.state_fields().unwrap();
assert!(state_fields.len() > 1);
-let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
-"avg_state",
-DataType::Struct(
-state_fields
-.iter()
-.map(|field| field.as_ref().clone())
-.collect::<Vec<_>>()
-.into(),
-),
-true,
-)]));
+let schema = Arc::new(arrow_schema::Schema::new(
+state_fields
+.iter()
+.map(|field| field.as_ref().clone())
+.collect::<Vec<_>>(),
+));
let region_metadata = build_region_metadata(Some("host = 'a'"));
let scanner = StaticStatsScanner {
schema: region_metadata.schema.clone(),
@@ -714,18 +671,14 @@ mod tests {
let mut stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
let error = stream.next().await.unwrap().unwrap_err();
-assert!(
-error
-.to_string()
-.contains("only supports single-field state in v1")
-);
+assert!(error.to_string().contains("schema/requirement mismatch"));
}
#[tokio::test]
async fn stats_scan_exec_emits_state_rows_for_eligible_files() {
let schema = Arc::new(arrow_schema::Schema::new(vec![
single_state_field("count_state", "count[count]", DataType::Int64, false),
single_state_field("max_state", "max[max]", DataType::Int64, true),
Field::new("count[count]", DataType::Int64, false),
Field::new("max[max]", DataType::Int64, true),
]));
let requirements = vec![
SupportStatAggr::CountNonNull {
@@ -778,26 +731,16 @@ mod tests {
let batch = &batches[0];
assert_eq!(batch.num_rows(), 1);
-let count_state = batch
-.column(0)
-.as_any()
-.downcast_ref::<StructArray>()
-.unwrap();
-let count_values = count_state
+let count_values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(count_values.value(0), 7);
-let max_state = batch
+let max_values = batch
.column(1)
.as_any()
-.downcast_ref::<StructArray>()
-.unwrap();
-let max_values = max_state
-.column(0)
-.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(max_values.value(0), 9);
@@ -805,8 +748,7 @@ mod tests {
#[tokio::test]
async fn stats_scan_exec_emits_no_batches_when_all_files_fallback() {
-let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-"count_state",
+let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
"count[count]",
DataType::Int64,
false,
@@ -846,8 +788,7 @@ mod tests {
#[tokio::test]
async fn stats_scan_exec_count_rows_uses_file_num_rows_without_row_groups() {
-let schema = Arc::new(arrow_schema::Schema::new(vec![single_state_field(
-"count_state",
+let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
"count[count]",
DataType::Int64,
false,
@@ -880,12 +821,7 @@ mod tests {
let batch = collect_single_batch(&exec).await;
assert_eq!(batch.num_rows(), 2);
-let count_state = batch
-.column(0)
-.as_any()
-.downcast_ref::<StructArray>()
-.unwrap();
-let count_values = count_state
+let count_values = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()

@@ -13,7 +13,6 @@
// limitations under the License.
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
pub use store_api::region_engine::SupportStatAggr;
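A note on the imports this last file keeps: COUNT_STAR_EXPANSION is the literal 1 that DataFusion's planner substitutes for count(*), which is how a rewrite can tell a row count apart from a non-null column count. A hedged sketch of that classification, mapping a physical count aggregate onto the SupportStatAggr variants used in the tests above; classify_count and its placement are illustrative, not code from this commit:

use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
use store_api::region_engine::SupportStatAggr;

fn classify_count(aggr: &AggregateFunctionExpr) -> Option<SupportStatAggr> {
    let args = aggr.expressions();
    let [arg] = args.as_slice() else {
        return None;
    };
    if let Some(literal) = arg.as_any().downcast_ref::<Literal>() {
        // count(1), i.e. the planner's expansion of count(*): counts every row.
        return (literal.value() == &COUNT_STAR_EXPANSION).then_some(SupportStatAggr::CountRows);
    }
    // count(column): answerable from stats via the column's non-null count.
    arg.as_any()
        .downcast_ref::<PhysicalColumn>()
        .map(|column| SupportStatAggr::CountNonNull {
            column_name: column.name().to_string(),
        })
}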