From 1440924955774198cf83d2f221ec7ddd8d260c16 Mon Sep 17 00:00:00 2001 From: fys <40801205+fengys1996@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:35:21 +0800 Subject: [PATCH] fix: nested projection missing roots (#7993) * fix: nested projection missing roots * add docs and unit test * fix: cr by ai * move some computations to new method --- src/mito2/src/sst/parquet/file_range.rs | 5 - src/mito2/src/sst/parquet/flat_format.rs | 39 ++- src/mito2/src/sst/parquet/read_columns.rs | 160 +++++++++-- src/mito2/src/sst/parquet/reader.rs | 67 +++-- src/mito2/src/sst/parquet/reader/stream.rs | 314 +++++++++++++++++++++ 5 files changed, 525 insertions(+), 60 deletions(-) create mode 100644 src/mito2/src/sst/parquet/reader/stream.rs diff --git a/src/mito2/src/sst/parquet/file_range.rs b/src/mito2/src/sst/parquet/file_range.rs index e8ae1a788a..8034310eee 100644 --- a/src/mito2/src/sst/parquet/file_range.rs +++ b/src/mito2/src/sst/parquet/file_range.rs @@ -260,11 +260,6 @@ impl FileRangeContext { } } - /// Returns the path of the file to read. - pub(crate) fn file_path(&self) -> &str { - self.reader_builder.file_path() - } - /// Returns filters pushed down. pub(crate) fn filters(&self) -> &[SimpleFilterContext] { &self.base.filters diff --git a/src/mito2/src/sst/parquet/flat_format.rs b/src/mito2/src/sst/parquet/flat_format.rs index f4c2ea3eca..bf3709e310 100644 --- a/src/mito2/src/sst/parquet/flat_format.rs +++ b/src/mito2/src/sst/parquet/flat_format.rs @@ -255,6 +255,15 @@ impl FlatReadFormat { } } + /// Gets the projected output schema produced by parquet reading. + pub(crate) fn output_arrow_schema(&self) -> Result { + let schema = self + .arrow_schema() + .project(self.projection_indices()) + .context(ComputeArrowSnafu)?; + Ok(Arc::new(schema)) + } + /// Gets the metadata of the SST. pub(crate) fn metadata(&self) -> &RegionMetadataRef { match &self.parquet_adapter { @@ -790,6 +799,8 @@ impl FlatReadFormat { #[cfg(test)] mod tests { + use std::sync::Arc; + use api::v1::SemanticType; use datatypes::prelude::ConcreteDataType; use datatypes::schema::ColumnSchema; @@ -797,8 +808,10 @@ mod tests { use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder}; use store_api::storage::RegionId; - use super::field_column_start; - use crate::sst::{FlatSchemaOptions, flat_sst_arrow_schema_column_num}; + use super::{FlatReadFormat, field_column_start}; + use crate::sst::{ + FlatSchemaOptions, flat_sst_arrow_schema_column_num, to_flat_sst_arrow_schema, + }; /// Builds a `RegionMetadata` with the given number of tags and fields. fn build_metadata( @@ -872,4 +885,26 @@ mod tests { ); } } + + #[test] + fn test_output_arrow_schema_uses_projection() { + let metadata = Arc::new(build_metadata(1, 2, PrimaryKeyEncoding::Dense)); + let read_format = FlatReadFormat::new( + metadata.clone(), + [0_u32, 2_u32].into_iter(), + None, + "test", + false, + ) + .unwrap(); + + let output_schema = read_format.output_arrow_schema().unwrap(); + let expected = Arc::new( + to_flat_sst_arrow_schema(&metadata, &FlatSchemaOptions::default()) + .project(read_format.projection_indices()) + .unwrap(), + ); + + assert_eq!(expected, output_schema); + } } diff --git a/src/mito2/src/sst/parquet/read_columns.rs b/src/mito2/src/sst/parquet/read_columns.rs index f0f35a4099..ee177822d1 100644 --- a/src/mito2/src/sst/parquet/read_columns.rs +++ b/src/mito2/src/sst/parquet/read_columns.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use parquet::arrow::ProjectionMask; use parquet::schema::types::SchemaDescriptor; @@ -104,30 +104,87 @@ impl ParquetReadColumn { } } -/// Builds a projection mask from parquet read columns. -pub fn build_projection_mask( +/// Projection plan built for a parquet file. +#[derive(Clone)] +pub struct ProjectionMaskPlan { + /// `mask` is the projection mask applied to the parquet reader. + pub mask: ProjectionMask, + /// A boolean mask in output schema order indicating whether each + /// projected root column is physically present in the parquet + /// read result. + /// + /// - `true`: the column exists in the input `RecordBatch`. + /// - `false`: the column is missing (e.g., due to unmatched nested + /// paths) and must be synthesized during post-processing (typically + /// filled with null/default values). + /// + /// The length of `projected_root_presence` is always equal to the + /// number of fields in the output schema. + pub projected_root_presence: Vec, +} + +/// Builds a projection mask plan for reading a parquet file. +/// +/// `parquet_read_cols` defines the requested root columns and optional +/// nested paths to read. +/// +/// `parquet_schema_desc` is the schema descriptor of the current parquet +/// file. It is used to resolve requested nested paths to actual leaf +/// column indices. +/// +/// See [`ProjectionMaskPlan`] for the returned value. +/// +/// For example, if the query requests `j.a` and `k`, but the current +/// parquet file only contains leaves under `j.b` and `k`, then the +/// returned plan keeps `k` in the projection mask and marks `j` as +/// not present in the output, so it can be synthesized during +/// post-processing. +pub fn build_projection_plan( parquet_read_cols: &ParquetReadColumns, parquet_schema_desc: &SchemaDescriptor, -) -> ProjectionMask { - if parquet_read_cols.has_nested() { - let leaf_indices = build_parquet_leaves_indices(parquet_schema_desc, parquet_read_cols); - ProjectionMask::leaves(parquet_schema_desc, leaf_indices) - } else { - ProjectionMask::roots(parquet_schema_desc, parquet_read_cols.root_indices_iter()) +) -> ProjectionMaskPlan { + if !parquet_read_cols.has_nested() { + let mask = + ProjectionMask::roots(parquet_schema_desc, parquet_read_cols.root_indices_iter()); + return ProjectionMaskPlan { + mask, + projected_root_presence: vec![true; parquet_read_cols.columns().len()], + }; + } + + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(parquet_schema_desc, parquet_read_cols); + + let projected_root_presence = parquet_read_cols + .columns() + .iter() + .map(|col| matched_roots.contains(&col.root_index())) + .collect(); + + let mask = ProjectionMask::leaves(parquet_schema_desc, leaf_indices); + ProjectionMaskPlan { + mask, + projected_root_presence, } } -/// Builds parquet leaf-column indices from parquet read columns. +/// Builds parquet leaf-column indices for reading a parquet file. +/// +/// Returns `(leaf_indices, matched_roots)`: +/// - `leaf_indices`: matched parquet leaf column indices +/// - `matched_roots`: root column indices that match at least one leaf in the +/// current parquet schema. fn build_parquet_leaves_indices( parquet_schema_desc: &SchemaDescriptor, projection: &ParquetReadColumns, -) -> Vec { +) -> (Vec, HashSet) { let mut map = HashMap::with_capacity(projection.cols.len()); for col in &projection.cols { map.insert(col.root_index, &col.nested_paths); } let mut leaf_indices = Vec::new(); + let mut matched_roots = HashSet::with_capacity(projection.cols.len()); for (leaf_idx, leaf_col) in parquet_schema_desc.columns().iter().enumerate() { let root_idx = parquet_schema_desc.get_column_root_idx(leaf_idx); let Some(nested_paths) = map.get(&root_idx) else { @@ -135,6 +192,7 @@ fn build_parquet_leaves_indices( }; if nested_paths.is_empty() { leaf_indices.push(leaf_idx); + matched_roots.insert(root_idx); continue; } @@ -144,9 +202,10 @@ fn build_parquet_leaves_indices( .any(|nested_path| leaf_path.starts_with(nested_path)) { leaf_indices.push(leaf_idx); + matched_roots.insert(root_idx); } } - leaf_indices + (leaf_indices, matched_roots) } #[cfg(test)] @@ -158,6 +217,20 @@ mod tests { use super::*; + #[test] + fn test_build_projection_mask_without_nested_paths() { + let parquet_schema_desc = build_test_nested_parquet_schema(); + let projection = ParquetReadColumns::from_deduped_root_indices([0, 1]); + + let plan = build_projection_plan(&projection, &parquet_schema_desc); + + assert_eq!(vec![true, true], plan.projected_root_presence); + assert_eq!( + ProjectionMask::roots(&parquet_schema_desc, [0, 1]), + plan.mask + ); + } + #[test] fn test_reads_whole_root() { let parquet_schema_desc = build_test_nested_parquet_schema(); @@ -170,10 +243,10 @@ mod tests { has_nested: false, }; - assert_eq!( - vec![0, 1, 2], - build_parquet_leaves_indices(&parquet_schema_desc, &projection) - ); + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(&parquet_schema_desc, &projection); + assert_eq!(vec![0, 1, 2], leaf_indices); + assert_eq!(HashSet::from([0]), matched_roots); } #[test] @@ -194,10 +267,10 @@ mod tests { has_nested: true, }; - assert_eq!( - vec![1, 2, 3], - build_parquet_leaves_indices(&parquet_schema_desc, &projection) - ); + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(&parquet_schema_desc, &projection); + assert_eq!(vec![1, 2, 3], leaf_indices); + assert_eq!(HashSet::from([0, 1]), matched_roots); } #[test] @@ -212,10 +285,10 @@ mod tests { has_nested: true, }; - assert_eq!( - vec![1, 2], - build_parquet_leaves_indices(&parquet_schema_desc, &projection) - ); + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(&parquet_schema_desc, &projection); + assert_eq!(vec![1, 2], leaf_indices); + assert_eq!(HashSet::from([0]), matched_roots); } #[test] @@ -230,9 +303,36 @@ mod tests { has_nested: true, }; + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(&parquet_schema_desc, &projection); + assert_eq!(vec![1], leaf_indices); + assert_eq!(HashSet::from([0]), matched_roots); + } + + #[test] + fn test_build_projection_mask_with_unmatched_roots() { + let parquet_schema_desc = build_test_nested_parquet_schema(); + + let projection = ParquetReadColumns { + cols: vec![ + ParquetReadColumn { + root_index: 0, + nested_paths: vec![vec!["j".to_string(), "missing".to_string()]], + }, + ParquetReadColumn { + root_index: 1, + nested_paths: vec![], + }, + ], + has_nested: true, + }; + + let plan = build_projection_plan(&projection, &parquet_schema_desc); + + assert_eq!(vec![false, true], plan.projected_root_presence); assert_eq!( - vec![1], - build_parquet_leaves_indices(&parquet_schema_desc, &projection) + ProjectionMask::leaves(&parquet_schema_desc, vec![3]), + plan.mask ); } @@ -251,10 +351,10 @@ mod tests { has_nested: true, }; - assert_eq!( - vec![0, 2], - build_parquet_leaves_indices(&parquet_schema_desc, &projection) - ); + let (leaf_indices, matched_roots) = + build_parquet_leaves_indices(&parquet_schema_desc, &projection); + assert_eq!(vec![0, 2], leaf_indices); + assert_eq!(HashSet::from([0]), matched_roots); } // Test schema: diff --git a/src/mito2/src/sst/parquet/reader.rs b/src/mito2/src/sst/parquet/reader.rs index 6942c8223d..def97d8f67 100644 --- a/src/mito2/src/sst/parquet/reader.rs +++ b/src/mito2/src/sst/parquet/reader.rs @@ -14,6 +14,8 @@ //! Parquet reader. +mod stream; + #[cfg(feature = "vector_index")] use std::collections::BTreeSet; use std::collections::HashSet; @@ -25,7 +27,7 @@ use common_recordbatch::filter::SimpleFilterEvaluator; use common_telemetry::{tracing, warn}; use datafusion_expr::Expr; use datatypes::arrow::array::ArrayRef; -use datatypes::arrow::datatypes::Field; +use datatypes::arrow::datatypes::{Field, SchemaRef}; use datatypes::arrow::record_batch::RecordBatch; use datatypes::data_type::ConcreteDataType; use datatypes::prelude::DataType; @@ -44,6 +46,7 @@ use store_api::region_request::PathType; use store_api::storage::{ColumnId, FileId}; use table::predicate::Predicate; +use self::stream::{ParquetErrorAdapter, ProjectedRecordBatchStream}; use crate::cache::index::result_cache::PredicateKey; use crate::cache::{CacheStrategy, CachedSstMeta}; #[cfg(feature = "vector_index")] @@ -78,7 +81,9 @@ use crate::sst::parquet::metadata::MetadataLoader; use crate::sst::parquet::prefilter::{ PrefilterContextBuilder, execute_prefilter, is_usable_primary_key_filter, }; -use crate::sst::parquet::read_columns::{ParquetReadColumns, build_projection_mask}; +use crate::sst::parquet::read_columns::{ + ParquetReadColumns, ProjectionMaskPlan, build_projection_plan, +}; use crate::sst::parquet::row_group::ParquetFetchMetrics; use crate::sst::parquet::row_selection::RowGroupSelection; use crate::sst::parquet::stats::RowGroupPruningStats; @@ -405,7 +410,8 @@ impl ParquetReaderBuilder { read_format.projection_indices().iter().copied(), ); - let projection_mask = build_projection_mask(&parquet_read_cols, parquet_schema_desc); + let projection_plan = build_projection_plan(&parquet_read_cols, parquet_schema_desc); + let selection = self .row_groups_to_read(&read_format, &parquet_meta, &mut metrics.filter_metrics) .await; @@ -484,13 +490,16 @@ impl ParquetReaderBuilder { parquet_meta.file_metadata().schema_descr(), ); + let output_schema = read_format.output_arrow_schema()?; + let reader_builder = RowGroupReaderBuilder { file_handle: self.file_handle.clone(), file_path, parquet_meta, arrow_metadata, + output_schema, object_store: self.object_store.clone(), - projection: projection_mask, + projection: projection_plan, cache_strategy: self.cache_strategy.clone(), prefilter_builder, }; @@ -1633,10 +1642,12 @@ pub(crate) struct RowGroupReaderBuilder { parquet_meta: Arc, /// Arrow reader metadata for building async stream. arrow_metadata: ArrowReaderMetadata, + /// Projected output schema aligned with `projection.projected_root_presence`. + output_schema: SchemaRef, /// Object store as an Operator. object_store: ObjectStore, /// Projection mask. - projection: ProjectionMask, + projection: ProjectionMaskPlan, /// Cache. cache_strategy: CacheStrategy, /// Pre-built prefilter state. `None` if prefiltering is not applicable. @@ -1691,19 +1702,20 @@ impl RowGroupReaderBuilder { pub(crate) async fn build( &self, build_ctx: RowGroupBuildContext<'_>, - ) -> Result> { + ) -> Result { let prefilter_ctx = self.prefilter_builder.as_ref().map(|b| b.build()); let Some(mut prefilter_ctx) = prefilter_ctx else { // No prefilter applicable, build stream with full projection. - return self + let stream = self .build_with_projection( build_ctx.row_group_idx, build_ctx.row_selection, - self.projection.clone(), + self.projection.mask.clone(), build_ctx.fetch_metrics, ) - .await; + .await?; + return self.make_projected_stream(stream); }; let prefilter_start = Instant::now(); @@ -1716,13 +1728,27 @@ impl RowGroupReaderBuilder { let refined_selection = Some(prefilter_result.refined_selection); - self.build_with_projection( - build_ctx.row_group_idx, - refined_selection, - self.projection.clone(), - build_ctx.fetch_metrics, + let stream = self + .build_with_projection( + build_ctx.row_group_idx, + refined_selection, + self.projection.mask.clone(), + build_ctx.fetch_metrics, + ) + .await?; + self.make_projected_stream(stream) + } + + fn make_projected_stream( + &self, + stream: ParquetRecordBatchStream, + ) -> Result { + let stream = ParquetErrorAdapter::new(stream, self.file_path.clone()); + ProjectedRecordBatchStream::new( + stream, + self.projection.projected_root_presence.clone(), + self.output_schema.clone(), ) - .await } /// Builds a [ParquetRecordBatchStream] with a custom projection mask. @@ -1996,17 +2022,14 @@ pub(crate) struct FlatRowGroupReader { /// Context for file ranges. context: FileRangeContextRef, /// Inner parquet record batch stream. - stream: ParquetRecordBatchStream, + stream: ProjectedRecordBatchStream, /// Cached sequence array to override sequences. override_sequence: Option, } impl FlatRowGroupReader { /// Creates a new flat reader from file range. - pub(crate) fn new( - context: FileRangeContextRef, - stream: ParquetRecordBatchStream, - ) -> Self { + pub(crate) fn new(context: FileRangeContextRef, stream: ProjectedRecordBatchStream) -> Self { // The batch length from the reader should be less than or equal to DEFAULT_READ_BATCH_SIZE. let override_sequence = context .read_format() @@ -2023,9 +2046,7 @@ impl FlatRowGroupReader { pub(crate) async fn next_batch(&mut self) -> Result> { match self.stream.next().await { Some(batch_result) => { - let record_batch = batch_result.context(ReadParquetSnafu { - path: self.context.file_path(), - })?; + let record_batch = batch_result?; let record_batch = self .context diff --git a/src/mito2/src/sst/parquet/reader/stream.rs b/src/mito2/src/sst/parquet/reader/stream.rs new file mode 100644 index 0000000000..adc0f44112 --- /dev/null +++ b/src/mito2/src/sst/parquet/reader/stream.rs @@ -0,0 +1,314 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use datatypes::arrow::array::new_null_array; +use datatypes::arrow::datatypes::SchemaRef; +use datatypes::arrow::record_batch::RecordBatch; +use futures::Stream; +use parquet::arrow::async_reader::ParquetRecordBatchStream; +use snafu::{IntoError, ResultExt, ensure}; + +use crate::error::{NewRecordBatchSnafu, ReadParquetSnafu, Result, UnexpectedSnafu}; +use crate::sst::parquet::async_reader::SstAsyncFileReader; + +/// Wraps a parquet record batch stream and fills missing projected root columns. +/// +/// Nested projection may ask parquet to read leaves under a root column. If none +/// of the requested leaves exists in the current parquet file, parquet decoding +/// omits the whole root from the physical `RecordBatch`. The logical projection +/// still contains that root, so this wrapper restores the output shape by +/// inserting a root-level null array. +pub struct MissingColFiller { + /// Inner stream that yields record batches from parquet reader. + inner: S, + /// Output schema expected by the upper reader. + output_schema: SchemaRef, + /// Whether each projected root exists in the physical batch returned by parquet. + projected_root_matches: Vec, + /// Number of columns expected from the physical batch returned by parquet. + expected_input_col_num: usize, + /// Whether all projected roots are present and the stream can pass batches through. + all_matched: bool, +} + +pub(crate) type ProjectedRecordBatchStream = MissingColFiller; + +impl MissingColFiller +where + S: Stream>, +{ + pub fn new( + inner: S, + projected_root_matches: Vec, + output_schema: SchemaRef, + ) -> Result> { + ensure!( + projected_root_matches.len() == output_schema.fields().len(), + UnexpectedSnafu { + reason: format!( + "MissingColFiller projected root matches len {} does not match output schema columns {}", + projected_root_matches.len(), + output_schema.fields().len() + ), + } + ); + + let expected_input_col_num = projected_root_matches + .iter() + .filter(|matched| **matched) + .count(); + let all_matched = projected_root_matches.iter().all(|&m| m); + + Ok(MissingColFiller { + inner, + output_schema, + projected_root_matches, + expected_input_col_num, + all_matched, + }) + } +} + +impl Stream for MissingColFiller +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Ready(Some(Ok(rb))) => { + let output_schema = &this.output_schema; + let rb = if this.all_matched { + rb + } else { + fill_missing_cols( + rb, + output_schema, + &this.projected_root_matches, + this.expected_input_col_num, + )? + }; + Poll::Ready(Some(Ok(rb))) + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +fn fill_missing_cols( + rb: RecordBatch, + output_schema: &SchemaRef, + projected_root_matches: &[bool], + expected_input_col_num: usize, +) -> Result { + ensure!( + rb.columns().len() == expected_input_col_num, + UnexpectedSnafu { + reason: format!( + "MissingColFiller expected {} input columns but got {}", + expected_input_col_num, + rb.columns().len() + ), + } + ); + + let mut cols = Vec::with_capacity(projected_root_matches.len()); + let mut idx = 0; + + for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) { + if *matched { + cols.push(rb.column(idx).clone()); + idx += 1; + } else { + cols.push(new_null_array(field.data_type(), rb.num_rows())); + } + } + + RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu) +} + +/// Maps parquet stream errors into mito errors before batches enter the filler. +pub(crate) struct ParquetErrorAdapter { + inner: ParquetRecordBatchStream, + path: String, +} + +impl ParquetErrorAdapter { + pub(crate) fn new(inner: ParquetRecordBatchStream, path: String) -> Self { + Self { inner, path } + } +} + +impl Stream for ParquetErrorAdapter { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Ready(Some(Ok(rb))) => Poll::Ready(Some(Ok(rb))), + Poll::Ready(Some(Err(err))) => { + Poll::Ready(Some(Err( + ReadParquetSnafu { path: &this.path }.into_error(err) + ))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray}; + use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema}; + use futures::{StreamExt, stream}; + + use super::*; + + #[tokio::test] + async fn test_filler_with_all_projected_roots_match() { + let output_schema = schema([ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + ]); + let input = RecordBatch::try_new( + output_schema.clone(), + vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])], + ) + .unwrap(); + let stream = stream::iter([Ok(input.clone())]); + + let mut filler = + MissingColFiller::new(stream, vec![true, true], output_schema.clone()).unwrap(); + let output = filler.next().await.unwrap().unwrap(); + + assert_eq!(input, output); + assert!(filler.next().await.is_none()); + } + + #[tokio::test] + async fn test_filler_with_fills_null_root_columns() { + let input_schema = schema([Field::new("a", DataType::Int64, true)]); + let output_schema = schema([ + Field::new("a", DataType::Int64, true), + Field::new("missing", DataType::Utf8, true), + Field::new("c", DataType::Int64, true), + ]); + let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap(); + let stream = stream::iter([Ok(input)]); + + let mut filler = + MissingColFiller::new(stream, vec![true, false, false], output_schema.clone()).unwrap(); + let output = filler.next().await.unwrap().unwrap(); + + assert_eq!(output_schema, output.schema()); + assert_eq!(3, output.num_columns()); + assert_eq!( + &[Some(10), Some(20)], + output + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect::>() + .as_slice() + ); + assert_eq!(DataType::Utf8, *output.column(1).data_type()); + assert_eq!(output.num_rows(), output.column(1).null_count()); + assert_eq!(DataType::Int64, *output.column(2).data_type()); + assert_eq!(output.num_rows(), output.column(2).null_count()); + } + + #[tokio::test] + async fn test_filler_with_fills_missing_struct_root_column() { + let input_schema = schema([Field::new("a", DataType::Int64, true)]); + let struct_type = DataType::Struct(Fields::from(vec![ + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Utf8, true), + ])); + let output_schema = schema([ + Field::new("a", DataType::Int64, true), + Field::new("missing_struct", struct_type.clone(), true), + ]); + let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap(); + let stream = stream::iter([Ok(input)]); + + let mut filler = + MissingColFiller::new(stream, vec![true, false], output_schema.clone()).unwrap(); + let output = filler.next().await.unwrap().unwrap(); + + assert_eq!(output_schema, output.schema()); + assert_eq!(2, output.num_columns()); + assert_eq!(struct_type, output.column(1).data_type().clone()); + assert_eq!(output.num_rows(), output.column(1).null_count()); + } + + #[tokio::test] + async fn test_filler_with_reject_projection_len_mismatch() { + let output_schema = schema([Field::new("a", DataType::Int64, true)]); + let stream = stream::iter([]); + + let err = match MissingColFiller::new(stream, vec![true, false], output_schema) { + Ok(_) => panic!("MissingColFiller should reject projection length mismatch"), + Err(err) => err, + }; + + assert!(err.to_string().contains("projected root matches len 2")); + } + + #[tokio::test] + async fn test_filler_reject_with_input_column_mismatch() { + let input_schema = schema([Field::new("a", DataType::Int64, true)]); + let output_schema = schema([ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("missing", DataType::Int64, true), + ]); + let input = RecordBatch::try_new(input_schema, vec![int_array([1, 2])]).unwrap(); + let stream = stream::iter([Ok(input)]); + + let mut filler = + MissingColFiller::new(stream, vec![true, true, false], output_schema).unwrap(); + let err = filler.next().await.unwrap().unwrap_err(); + + assert!( + err.to_string() + .contains("expected 2 input columns but got 1") + ); + } + + fn schema(fields: impl IntoIterator) -> SchemaRef { + Arc::new(Schema::new(fields.into_iter().collect::>())) + } + + fn int_array(values: impl IntoIterator) -> ArrayRef { + Arc::new(Int64Array::from_iter_values(values)) + } + + fn string_array(values: impl IntoIterator) -> ArrayRef { + Arc::new(StringArray::from_iter_values(values)) + } +}