fix: nested projection missing roots (#7993)

* fix: nested projection missing roots

* add docs and unit test

* fix: address AI code review comments

* move some computations to a new method
fys
2026-04-22 17:35:21 +08:00
committed by GitHub
parent 6649c14938
commit 1440924955
5 changed files with 525 additions and 60 deletions


@@ -260,11 +260,6 @@ impl FileRangeContext {
}
}
/// Returns the path of the file to read.
pub(crate) fn file_path(&self) -> &str {
self.reader_builder.file_path()
}
/// Returns filters pushed down.
pub(crate) fn filters(&self) -> &[SimpleFilterContext] {
&self.base.filters


@@ -255,6 +255,15 @@ impl FlatReadFormat {
}
}
/// Gets the projected output schema produced by the parquet reader.
pub(crate) fn output_arrow_schema(&self) -> Result<SchemaRef> {
let schema = self
.arrow_schema()
.project(self.projection_indices())
.context(ComputeArrowSnafu)?;
Ok(Arc::new(schema))
}
/// Gets the metadata of the SST.
pub(crate) fn metadata(&self) -> &RegionMetadataRef {
match &self.parquet_adapter {
@@ -790,6 +799,8 @@ impl FlatReadFormat {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::SemanticType;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::ColumnSchema;
@@ -797,8 +808,10 @@ mod tests {
use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder};
use store_api::storage::RegionId;
use super::field_column_start;
use crate::sst::{FlatSchemaOptions, flat_sst_arrow_schema_column_num};
use super::{FlatReadFormat, field_column_start};
use crate::sst::{
FlatSchemaOptions, flat_sst_arrow_schema_column_num, to_flat_sst_arrow_schema,
};
/// Builds a `RegionMetadata` with the given number of tags and fields.
fn build_metadata(
@@ -872,4 +885,26 @@ mod tests {
);
}
}
#[test]
fn test_output_arrow_schema_uses_projection() {
let metadata = Arc::new(build_metadata(1, 2, PrimaryKeyEncoding::Dense));
let read_format = FlatReadFormat::new(
metadata.clone(),
[0_u32, 2_u32].into_iter(),
None,
"test",
false,
)
.unwrap();
let output_schema = read_format.output_arrow_schema().unwrap();
let expected = Arc::new(
to_flat_sst_arrow_schema(&metadata, &FlatSchemaOptions::default())
.project(read_format.projection_indices())
.unwrap(),
);
assert_eq!(expected, output_schema);
}
}


@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;
@@ -104,30 +104,87 @@ impl ParquetReadColumn {
}
}
/// Builds a projection mask from parquet read columns.
pub fn build_projection_mask(
/// Projection plan built for a parquet file.
#[derive(Clone)]
pub struct ProjectionMaskPlan {
/// The projection mask applied to the parquet reader.
pub mask: ProjectionMask,
/// A boolean mask in output schema order indicating whether each
/// projected root column is physically present in the parquet
/// read result.
///
/// - `true`: the column exists in the input `RecordBatch`.
/// - `false`: the column is missing (e.g., due to unmatched nested
/// paths) and must be synthesized during post-processing (typically
/// filled with null/default values).
///
/// The length of `projected_root_presence` is always equal to the
/// number of fields in the output schema.
pub projected_root_presence: Vec<bool>,
}
/// Builds a projection mask plan for reading a parquet file.
///
/// `parquet_read_cols` defines the requested root columns and optional
/// nested paths to read.
///
/// `parquet_schema_desc` is the schema descriptor of the current parquet
/// file. It is used to resolve requested nested paths to actual leaf
/// column indices.
///
/// See [`ProjectionMaskPlan`] for the returned value.
///
/// For example, if the query requests `j.a` and `k`, but the current
/// parquet file only contains leaves under `j.b` and `k`, then the
/// returned plan keeps `k` in the projection mask and marks `j` as
/// not present in the output, so it can be synthesized during
/// post-processing.
pub fn build_projection_plan(
parquet_read_cols: &ParquetReadColumns,
parquet_schema_desc: &SchemaDescriptor,
) -> ProjectionMask {
if parquet_read_cols.has_nested() {
let leaf_indices = build_parquet_leaves_indices(parquet_schema_desc, parquet_read_cols);
ProjectionMask::leaves(parquet_schema_desc, leaf_indices)
} else {
ProjectionMask::roots(parquet_schema_desc, parquet_read_cols.root_indices_iter())
) -> ProjectionMaskPlan {
if !parquet_read_cols.has_nested() {
let mask =
ProjectionMask::roots(parquet_schema_desc, parquet_read_cols.root_indices_iter());
return ProjectionMaskPlan {
mask,
projected_root_presence: vec![true; parquet_read_cols.columns().len()],
};
}
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(parquet_schema_desc, parquet_read_cols);
let projected_root_presence = parquet_read_cols
.columns()
.iter()
.map(|col| matched_roots.contains(&col.root_index()))
.collect();
let mask = ProjectionMask::leaves(parquet_schema_desc, leaf_indices);
ProjectionMaskPlan {
mask,
projected_root_presence,
}
}
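
To make the doc example above concrete, here is a minimal sketch (not part of this commit) that runs `build_projection_plan` against a file containing only `j.b` and `k` while the query asks for `j.a` and `k`; the message type, test name, and expected leaf index are assumptions for illustration only.

#[test]
fn sketch_build_projection_plan_partial_nested_match() {
    use std::sync::Arc;

    use parquet::schema::parser::parse_message_type;
    use parquet::schema::types::SchemaDescriptor;

    // Hypothetical file schema: leaves are `j.b` (index 0) and `k` (index 1).
    let message_type = "
    message sketch {
        required group j {
            required int64 b;
        }
        required int64 k;
    }";
    let desc = SchemaDescriptor::new(Arc::new(parse_message_type(message_type).unwrap()));
    // Request `j.a` under root 0 and the whole of `k` at root 1.
    let projection = ParquetReadColumns {
        cols: vec![
            ParquetReadColumn {
                root_index: 0,
                nested_paths: vec![vec!["j".to_string(), "a".to_string()]],
            },
            ParquetReadColumn {
                root_index: 1,
                nested_paths: vec![],
            },
        ],
        has_nested: true,
    };
    let plan = build_projection_plan(&projection, &desc);
    // No leaf under `j` matches `j.a`, so root 0 must be synthesized later;
    // `k` stays in the mask via its only leaf.
    assert_eq!(vec![false, true], plan.projected_root_presence);
    assert_eq!(ProjectionMask::leaves(&desc, vec![1]), plan.mask);
}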
/// Builds parquet leaf-column indices from parquet read columns.
/// Builds parquet leaf-column indices for reading a parquet file.
///
/// Returns `(leaf_indices, matched_roots)`:
/// - `leaf_indices`: matched parquet leaf column indices.
/// - `matched_roots`: root column indices that match at least one leaf in the
/// current parquet schema.
fn build_parquet_leaves_indices(
parquet_schema_desc: &SchemaDescriptor,
projection: &ParquetReadColumns,
) -> Vec<usize> {
) -> (Vec<usize>, HashSet<usize>) {
let mut map = HashMap::with_capacity(projection.cols.len());
for col in &projection.cols {
map.insert(col.root_index, &col.nested_paths);
}
let mut leaf_indices = Vec::new();
let mut matched_roots = HashSet::with_capacity(projection.cols.len());
for (leaf_idx, leaf_col) in parquet_schema_desc.columns().iter().enumerate() {
let root_idx = parquet_schema_desc.get_column_root_idx(leaf_idx);
let Some(nested_paths) = map.get(&root_idx) else {
@@ -135,6 +192,7 @@ fn build_parquet_leaves_indices(
};
if nested_paths.is_empty() {
leaf_indices.push(leaf_idx);
matched_roots.insert(root_idx);
continue;
}
@@ -144,9 +202,10 @@ fn build_parquet_leaves_indices(
.any(|nested_path| leaf_path.starts_with(nested_path))
{
leaf_indices.push(leaf_idx);
matched_roots.insert(root_idx);
}
}
leaf_indices
(leaf_indices, matched_roots)
}
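
As background for the matching above, a small sketch (again not part of the commit, with a hypothetical message type) of how parquet leaf indices map back to root indices via `get_column_root_idx`:

#[test]
fn sketch_leaf_to_root_index_mapping() {
    use std::sync::Arc;

    use parquet::schema::parser::parse_message_type;
    use parquet::schema::types::SchemaDescriptor;

    let message_type = "
    message sketch {
        required group j {
            required int64 a;
            required int64 b;
        }
        required int64 k;
    }";
    let desc = SchemaDescriptor::new(Arc::new(parse_message_type(message_type).unwrap()));
    // Leaves are enumerated depth-first: `j.a` = 0, `j.b` = 1, `k` = 2.
    assert_eq!(3, desc.columns().len());
    // Each leaf resolves back to the index of its root column:
    // both leaves of `j` map to root 0, while `k` maps to root 1.
    assert_eq!(0, desc.get_column_root_idx(0));
    assert_eq!(0, desc.get_column_root_idx(1));
    assert_eq!(1, desc.get_column_root_idx(2));
}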
#[cfg(test)]
@@ -158,6 +217,20 @@ mod tests {
use super::*;
#[test]
fn test_build_projection_mask_without_nested_paths() {
let parquet_schema_desc = build_test_nested_parquet_schema();
let projection = ParquetReadColumns::from_deduped_root_indices([0, 1]);
let plan = build_projection_plan(&projection, &parquet_schema_desc);
assert_eq!(vec![true, true], plan.projected_root_presence);
assert_eq!(
ProjectionMask::roots(&parquet_schema_desc, [0, 1]),
plan.mask
);
}
#[test]
fn test_reads_whole_root() {
let parquet_schema_desc = build_test_nested_parquet_schema();
@@ -170,10 +243,10 @@ mod tests {
has_nested: false,
};
assert_eq!(
vec![0, 1, 2],
build_parquet_leaves_indices(&parquet_schema_desc, &projection)
);
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(&parquet_schema_desc, &projection);
assert_eq!(vec![0, 1, 2], leaf_indices);
assert_eq!(HashSet::from([0]), matched_roots);
}
#[test]
@@ -194,10 +267,10 @@ mod tests {
has_nested: true,
};
assert_eq!(
vec![1, 2, 3],
build_parquet_leaves_indices(&parquet_schema_desc, &projection)
);
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(&parquet_schema_desc, &projection);
assert_eq!(vec![1, 2, 3], leaf_indices);
assert_eq!(HashSet::from([0, 1]), matched_roots);
}
#[test]
@@ -212,10 +285,10 @@ mod tests {
has_nested: true,
};
assert_eq!(
vec![1, 2],
build_parquet_leaves_indices(&parquet_schema_desc, &projection)
);
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(&parquet_schema_desc, &projection);
assert_eq!(vec![1, 2], leaf_indices);
assert_eq!(HashSet::from([0]), matched_roots);
}
#[test]
@@ -230,9 +303,36 @@ mod tests {
has_nested: true,
};
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(&parquet_schema_desc, &projection);
assert_eq!(vec![1], leaf_indices);
assert_eq!(HashSet::from([0]), matched_roots);
}
#[test]
fn test_build_projection_mask_with_unmatched_roots() {
let parquet_schema_desc = build_test_nested_parquet_schema();
let projection = ParquetReadColumns {
cols: vec![
ParquetReadColumn {
root_index: 0,
nested_paths: vec![vec!["j".to_string(), "missing".to_string()]],
},
ParquetReadColumn {
root_index: 1,
nested_paths: vec![],
},
],
has_nested: true,
};
let plan = build_projection_plan(&projection, &parquet_schema_desc);
assert_eq!(vec![false, true], plan.projected_root_presence);
assert_eq!(
vec![1],
build_parquet_leaves_indices(&parquet_schema_desc, &projection)
ProjectionMask::leaves(&parquet_schema_desc, vec![3]),
plan.mask
);
}
@@ -251,10 +351,10 @@ mod tests {
has_nested: true,
};
assert_eq!(
vec![0, 2],
build_parquet_leaves_indices(&parquet_schema_desc, &projection)
);
let (leaf_indices, matched_roots) =
build_parquet_leaves_indices(&parquet_schema_desc, &projection);
assert_eq!(vec![0, 2], leaf_indices);
assert_eq!(HashSet::from([0]), matched_roots);
}
// Test schema:


@@ -14,6 +14,8 @@
//! Parquet reader.
mod stream;
#[cfg(feature = "vector_index")]
use std::collections::BTreeSet;
use std::collections::HashSet;
@@ -25,7 +27,7 @@ use common_recordbatch::filter::SimpleFilterEvaluator;
use common_telemetry::{tracing, warn};
use datafusion_expr::Expr;
use datatypes::arrow::array::ArrayRef;
use datatypes::arrow::datatypes::Field;
use datatypes::arrow::datatypes::{Field, SchemaRef};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::DataType;
@@ -44,6 +46,7 @@ use store_api::region_request::PathType;
use store_api::storage::{ColumnId, FileId};
use table::predicate::Predicate;
use self::stream::{ParquetErrorAdapter, ProjectedRecordBatchStream};
use crate::cache::index::result_cache::PredicateKey;
use crate::cache::{CacheStrategy, CachedSstMeta};
#[cfg(feature = "vector_index")]
@@ -78,7 +81,9 @@ use crate::sst::parquet::metadata::MetadataLoader;
use crate::sst::parquet::prefilter::{
PrefilterContextBuilder, execute_prefilter, is_usable_primary_key_filter,
};
use crate::sst::parquet::read_columns::{ParquetReadColumns, build_projection_mask};
use crate::sst::parquet::read_columns::{
ParquetReadColumns, ProjectionMaskPlan, build_projection_plan,
};
use crate::sst::parquet::row_group::ParquetFetchMetrics;
use crate::sst::parquet::row_selection::RowGroupSelection;
use crate::sst::parquet::stats::RowGroupPruningStats;
@@ -405,7 +410,8 @@ impl ParquetReaderBuilder {
read_format.projection_indices().iter().copied(),
);
let projection_mask = build_projection_mask(&parquet_read_cols, parquet_schema_desc);
let projection_plan = build_projection_plan(&parquet_read_cols, parquet_schema_desc);
let selection = self
.row_groups_to_read(&read_format, &parquet_meta, &mut metrics.filter_metrics)
.await;
@@ -484,13 +490,16 @@ impl ParquetReaderBuilder {
parquet_meta.file_metadata().schema_descr(),
);
let output_schema = read_format.output_arrow_schema()?;
let reader_builder = RowGroupReaderBuilder {
file_handle: self.file_handle.clone(),
file_path,
parquet_meta,
arrow_metadata,
output_schema,
object_store: self.object_store.clone(),
projection: projection_mask,
projection: projection_plan,
cache_strategy: self.cache_strategy.clone(),
prefilter_builder,
};
@@ -1633,10 +1642,12 @@ pub(crate) struct RowGroupReaderBuilder {
parquet_meta: Arc<ParquetMetaData>,
/// Arrow reader metadata for building async stream.
arrow_metadata: ArrowReaderMetadata,
/// Projected output schema aligned with `projection.projected_root_presence`.
output_schema: SchemaRef,
/// Object store as an Operator.
object_store: ObjectStore,
/// Projection mask.
projection: ProjectionMask,
projection: ProjectionMaskPlan,
/// Cache.
cache_strategy: CacheStrategy,
/// Pre-built prefilter state. `None` if prefiltering is not applicable.
@@ -1691,19 +1702,20 @@ impl RowGroupReaderBuilder {
pub(crate) async fn build(
&self,
build_ctx: RowGroupBuildContext<'_>,
) -> Result<ParquetRecordBatchStream<SstAsyncFileReader>> {
) -> Result<ProjectedRecordBatchStream> {
let prefilter_ctx = self.prefilter_builder.as_ref().map(|b| b.build());
let Some(mut prefilter_ctx) = prefilter_ctx else {
// No prefilter applicable, build stream with full projection.
return self
let stream = self
.build_with_projection(
build_ctx.row_group_idx,
build_ctx.row_selection,
self.projection.clone(),
self.projection.mask.clone(),
build_ctx.fetch_metrics,
)
.await;
.await?;
return self.make_projected_stream(stream);
};
let prefilter_start = Instant::now();
@@ -1716,13 +1728,27 @@ impl RowGroupReaderBuilder {
let refined_selection = Some(prefilter_result.refined_selection);
self.build_with_projection(
build_ctx.row_group_idx,
refined_selection,
self.projection.clone(),
build_ctx.fetch_metrics,
let stream = self
.build_with_projection(
build_ctx.row_group_idx,
refined_selection,
self.projection.mask.clone(),
build_ctx.fetch_metrics,
)
.await?;
self.make_projected_stream(stream)
}
fn make_projected_stream(
&self,
stream: ParquetRecordBatchStream<SstAsyncFileReader>,
) -> Result<ProjectedRecordBatchStream> {
let stream = ParquetErrorAdapter::new(stream, self.file_path.clone());
ProjectedRecordBatchStream::new(
stream,
self.projection.projected_root_presence.clone(),
self.output_schema.clone(),
)
.await
}
/// Builds a [ParquetRecordBatchStream] with a custom projection mask.
@@ -1996,17 +2022,14 @@ pub(crate) struct FlatRowGroupReader {
/// Context for file ranges.
context: FileRangeContextRef,
/// Inner parquet record batch stream.
stream: ParquetRecordBatchStream<SstAsyncFileReader>,
stream: ProjectedRecordBatchStream,
/// Cached sequence array to override sequences.
override_sequence: Option<ArrayRef>,
}
impl FlatRowGroupReader {
/// Creates a new flat reader from file range.
pub(crate) fn new(
context: FileRangeContextRef,
stream: ParquetRecordBatchStream<SstAsyncFileReader>,
) -> Self {
pub(crate) fn new(context: FileRangeContextRef, stream: ProjectedRecordBatchStream) -> Self {
// The batch length from the reader should be less than or equal to DEFAULT_READ_BATCH_SIZE.
let override_sequence = context
.read_format()
@@ -2023,9 +2046,7 @@ impl FlatRowGroupReader {
pub(crate) async fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
match self.stream.next().await {
Some(batch_result) => {
let record_batch = batch_result.context(ReadParquetSnafu {
path: self.context.file_path(),
})?;
let record_batch = batch_result?;
let record_batch = self
.context


@@ -0,0 +1,314 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::task::{Context, Poll};
use datatypes::arrow::array::new_null_array;
use datatypes::arrow::datatypes::SchemaRef;
use datatypes::arrow::record_batch::RecordBatch;
use futures::Stream;
use parquet::arrow::async_reader::ParquetRecordBatchStream;
use snafu::{IntoError, ResultExt, ensure};
use crate::error::{NewRecordBatchSnafu, ReadParquetSnafu, Result, UnexpectedSnafu};
use crate::sst::parquet::async_reader::SstAsyncFileReader;
/// Wraps a parquet record batch stream and fills missing projected root columns.
///
/// Nested projection may ask parquet to read leaves under a root column. If none
/// of the requested leaves exists in the current parquet file, parquet decoding
/// omits the whole root from the physical `RecordBatch`. The logical projection
/// still contains that root, so this wrapper restores the output shape by
/// inserting a root-level null array.
pub struct MissingColFiller<S> {
/// Inner stream that yields record batches from parquet reader.
inner: S,
/// Output schema expected by the upper reader.
output_schema: SchemaRef,
/// Whether each projected root exists in the physical batch returned by parquet.
projected_root_matches: Vec<bool>,
/// Number of columns expected from the physical batch returned by parquet.
expected_input_col_num: usize,
/// Whether all projected roots are present and the stream can pass batches through.
all_matched: bool,
}
pub(crate) type ProjectedRecordBatchStream = MissingColFiller<ParquetErrorAdapter>;
impl<S> MissingColFiller<S>
where
S: Stream<Item = Result<RecordBatch>>,
{
pub fn new(
inner: S,
projected_root_matches: Vec<bool>,
output_schema: SchemaRef,
) -> Result<MissingColFiller<S>> {
ensure!(
projected_root_matches.len() == output_schema.fields().len(),
UnexpectedSnafu {
reason: format!(
"MissingColFiller projected root matches len {} does not match output schema columns {}",
projected_root_matches.len(),
output_schema.fields().len()
),
}
);
let expected_input_col_num = projected_root_matches
.iter()
.filter(|matched| **matched)
.count();
let all_matched = projected_root_matches.iter().all(|&m| m);
Ok(MissingColFiller {
inner,
output_schema,
projected_root_matches,
expected_input_col_num,
all_matched,
})
}
}
impl<S> Stream for MissingColFiller<S>
where
S: Stream<Item = Result<RecordBatch>> + Unpin,
{
type Item = Result<RecordBatch>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(rb))) => {
let output_schema = &this.output_schema;
let rb = if this.all_matched {
rb
} else {
fill_missing_cols(
rb,
output_schema,
&this.projected_root_matches,
this.expected_input_col_num,
)?
};
Poll::Ready(Some(Ok(rb)))
}
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
fn fill_missing_cols(
rb: RecordBatch,
output_schema: &SchemaRef,
projected_root_matches: &[bool],
expected_input_col_num: usize,
) -> Result<RecordBatch> {
ensure!(
rb.columns().len() == expected_input_col_num,
UnexpectedSnafu {
reason: format!(
"MissingColFiller expected {} input columns but got {}",
expected_input_col_num,
rb.columns().len()
),
}
);
let mut cols = Vec::with_capacity(projected_root_matches.len());
let mut idx = 0;
for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
if *matched {
cols.push(rb.column(idx).clone());
idx += 1;
} else {
cols.push(new_null_array(field.data_type(), rb.num_rows()));
}
}
RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
}
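
A minimal sketch of the helper's contract in isolation, assuming an in-module test context; the column names and test name are hypothetical, unlike the stream-level tests further below which belong to the commit:

#[test]
fn sketch_fill_missing_cols_contract() {
    use std::sync::Arc;

    use datatypes::arrow::array::Int64Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    let input_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    let output_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("missing", DataType::Utf8, true),
    ]));
    let input = RecordBatch::try_new(
        input_schema,
        vec![Arc::new(Int64Array::from(vec![1, 2])) as _],
    )
    .unwrap();
    // One physical column in, two logical columns out: the unmatched root
    // is synthesized as an all-null array of the declared data type.
    let output = fill_missing_cols(input, &output_schema, &[true, false], 1).unwrap();
    assert_eq!(2, output.num_columns());
    assert_eq!(output.num_rows(), output.column(1).null_count());
}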
/// Maps parquet stream errors into mito errors before batches enter the filler.
pub(crate) struct ParquetErrorAdapter {
inner: ParquetRecordBatchStream<SstAsyncFileReader>,
path: String,
}
impl ParquetErrorAdapter {
pub(crate) fn new(inner: ParquetRecordBatchStream<SstAsyncFileReader>, path: String) -> Self {
Self { inner, path }
}
}
impl Stream for ParquetErrorAdapter {
type Item = Result<RecordBatch>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(rb))) => Poll::Ready(Some(Ok(rb))),
Poll::Ready(Some(Err(err))) => {
Poll::Ready(Some(Err(
ReadParquetSnafu { path: &this.path }.into_error(err)
)))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
use futures::{StreamExt, stream};
use super::*;
#[tokio::test]
async fn test_filler_with_all_projected_roots_match() {
let output_schema = schema([
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Utf8, true),
]);
let input = RecordBatch::try_new(
output_schema.clone(),
vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])],
)
.unwrap();
let stream = stream::iter([Ok(input.clone())]);
let mut filler =
MissingColFiller::new(stream, vec![true, true], output_schema.clone()).unwrap();
let output = filler.next().await.unwrap().unwrap();
assert_eq!(input, output);
assert!(filler.next().await.is_none());
}
#[tokio::test]
async fn test_filler_with_fills_null_root_columns() {
let input_schema = schema([Field::new("a", DataType::Int64, true)]);
let output_schema = schema([
Field::new("a", DataType::Int64, true),
Field::new("missing", DataType::Utf8, true),
Field::new("c", DataType::Int64, true),
]);
let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
let stream = stream::iter([Ok(input)]);
let mut filler =
MissingColFiller::new(stream, vec![true, false, false], output_schema.clone()).unwrap();
let output = filler.next().await.unwrap().unwrap();
assert_eq!(output_schema, output.schema());
assert_eq!(3, output.num_columns());
assert_eq!(
&[Some(10), Some(20)],
output
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.iter()
.collect::<Vec<_>>()
.as_slice()
);
assert_eq!(DataType::Utf8, *output.column(1).data_type());
assert_eq!(output.num_rows(), output.column(1).null_count());
assert_eq!(DataType::Int64, *output.column(2).data_type());
assert_eq!(output.num_rows(), output.column(2).null_count());
}
#[tokio::test]
async fn test_filler_with_fills_missing_struct_root_column() {
let input_schema = schema([Field::new("a", DataType::Int64, true)]);
let struct_type = DataType::Struct(Fields::from(vec![
Field::new("x", DataType::Int64, true),
Field::new("y", DataType::Utf8, true),
]));
let output_schema = schema([
Field::new("a", DataType::Int64, true),
Field::new("missing_struct", struct_type.clone(), true),
]);
let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
let stream = stream::iter([Ok(input)]);
let mut filler =
MissingColFiller::new(stream, vec![true, false], output_schema.clone()).unwrap();
let output = filler.next().await.unwrap().unwrap();
assert_eq!(output_schema, output.schema());
assert_eq!(2, output.num_columns());
assert_eq!(struct_type, output.column(1).data_type().clone());
assert_eq!(output.num_rows(), output.column(1).null_count());
}
#[tokio::test]
async fn test_filler_with_reject_projection_len_mismatch() {
let output_schema = schema([Field::new("a", DataType::Int64, true)]);
let stream = stream::iter([]);
let err = match MissingColFiller::new(stream, vec![true, false], output_schema) {
Ok(_) => panic!("MissingColFiller should reject projection length mismatch"),
Err(err) => err,
};
assert!(err.to_string().contains("projected root matches len 2"));
}
#[tokio::test]
async fn test_filler_reject_with_input_column_mismatch() {
let input_schema = schema([Field::new("a", DataType::Int64, true)]);
let output_schema = schema([
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Int64, true),
Field::new("missing", DataType::Int64, true),
]);
let input = RecordBatch::try_new(input_schema, vec![int_array([1, 2])]).unwrap();
let stream = stream::iter([Ok(input)]);
let mut filler =
MissingColFiller::new(stream, vec![true, true, false], output_schema).unwrap();
let err = filler.next().await.unwrap().unwrap_err();
assert!(
err.to_string()
.contains("expected 2 input columns but got 1")
);
}
fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
Arc::new(Schema::new(fields.into_iter().collect::<Vec<_>>()))
}
fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
Arc::new(Int64Array::from_iter_values(values))
}
fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
Arc::new(StringArray::from_iter_values(values))
}
}