feat: Implements an iterator to read the RecordBatch in BulkPart (#6647)

* feat: impl RecordBatchIter for BulkPart

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: rename BulkPartIter to EncodedBulkPartIter

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: add iter benchmark

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: filter by primary key columns

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: move struct definitions

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: bulk iter for flat schema

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: iter filter benchmark

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: fix compiler errors

Signed-off-by: evenyag <realevenyag@gmail.com>

* fix: use correct sequence array to compare

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: remove RecordBatchIter

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: update comments

Signed-off-by: evenyag <realevenyag@gmail.com>

* style: fix clippy

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: apply projection first

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: address comment

No need to check the number of rows after filtering

Signed-off-by: evenyag <realevenyag@gmail.com>

---------

Signed-off-by: evenyag <realevenyag@gmail.com>
Author: Yingwen
Date: 2025-08-05 16:11:28 +08:00
Committed by: GitHub
Parent: 9c3b83e84d
Commit: 50f7f61fdc
6 changed files with 405 additions and 15 deletions
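For orientation, a minimal sketch of how the new BulkPartRecordBatchIter is driven. The calls mirror the benchmark and tests in this diff; the construction of metadata and record_batch is assumed to happen elsewhere (for example via BulkPartConverter):

// Sketch only: assumes metadata: RegionMetadataRef and record_batch: RecordBatch
// are already built (e.g. by BulkPartConverter::convert, as in the benchmark below).
let context = Arc::new(BulkIterContext::new(
    metadata.clone(),
    &None, // no projection
    None,  // no predicate
    true,  // flat format, required by BulkPartRecordBatchIter
));
let iter = BulkPartRecordBatchIter::new(record_batch, context, None);
for batch_result in iter {
    // Each yielded item is a projected and filtered RecordBatch.
    let _batch = batch_result.unwrap();
}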


@@ -21,7 +21,9 @@ use datafusion_common::Column;
use datafusion_expr::{lit, Expr};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::ColumnSchema;
use mito2::memtable::bulk::context::BulkIterContext;
use mito2::memtable::bulk::part::BulkPartConverter;
use mito2::memtable::bulk::part_reader::BulkPartRecordBatchIter;
use mito2::memtable::partition_tree::{PartitionTreeConfig, PartitionTreeMemtable};
use mito2::memtable::time_series::TimeSeriesMemtable;
use mito2::memtable::{KeyValues, Memtable};
@@ -421,11 +423,83 @@ fn bulk_part_converter(c: &mut Criterion) {
}
}
fn bulk_part_record_batch_iter_filter(c: &mut Criterion) {
let metadata = Arc::new(cpu_metadata());
let schema = to_flat_sst_arrow_schema(&metadata, &FlatSchemaOptions::default());
let start_sec = 1710043200;
let mut group = c.benchmark_group("bulk_part_record_batch_iter_filter");
// Pre-create RecordBatch and primary key arrays
let (record_batch_with_filter, record_batch_no_filter) = {
let generator = CpuDataGenerator::new(metadata.clone(), 4096, start_sec, start_sec + 1);
let codec = Arc::new(DensePrimaryKeyCodec::new(&metadata));
let mut converter = BulkPartConverter::new(&metadata, schema, 4096, codec, true);
if let Some(kvs) = generator.iter().next() {
converter.append_key_values(&kvs).unwrap();
}
let bulk_part = converter.convert().unwrap();
let record_batch = bulk_part.batch;
(record_batch.clone(), record_batch)
};
// Pre-create predicate
let generator = CpuDataGenerator::new(metadata.clone(), 4096, start_sec, start_sec + 1);
let predicate = generator.random_host_filter();
// Benchmark with hostname filter using non-encoded primary keys
group.bench_function("4096_rows_with_hostname_filter", |b| {
b.iter(|| {
// Create context for BulkPartRecordBatchIter with predicate
let context = Arc::new(BulkIterContext::new(
metadata.clone(),
&None, // No projection
Some(predicate.clone()), // With hostname filter
true,
));
// Create and iterate over BulkPartRecordBatchIter with filter
let iter =
BulkPartRecordBatchIter::new(record_batch_with_filter.clone(), context, None);
// Consume all batches
for batch_result in iter {
let _batch = batch_result.unwrap();
}
});
});
// Benchmark without filter for comparison
group.bench_function("4096_rows_no_filter", |b| {
b.iter(|| {
// Create context for BulkPartRecordBatchIter without predicate
let context = Arc::new(BulkIterContext::new(
metadata.clone(),
&None, // No projection
None, // No predicate
true,
));
// Create and iterate over BulkPartRecordBatchIter
let iter = BulkPartRecordBatchIter::new(record_batch_no_filter.clone(), context, None);
// Consume all batches
for batch_result in iter {
let _batch = batch_result.unwrap();
}
});
});
}
criterion_group!(
benches,
write_rows,
full_scan,
filter_1_host,
bulk_part_converter,
bulk_part_record_batch_iter_filter
);
criterion_main!(benches);


@@ -27,10 +27,10 @@ use crate::memtable::{
};
#[allow(unused)]
mod context;
pub mod context;
#[allow(unused)]
pub mod part;
mod part_reader;
pub mod part_reader;
mod row_group_reader;
#[derive(Debug)]


@@ -24,22 +24,24 @@ use store_api::storage::ColumnId;
use table::predicate::Predicate;
use crate::sst::parquet::file_range::RangeBase;
use crate::sst::parquet::flat_format::FlatReadFormat;
use crate::sst::parquet::format::ReadFormat;
use crate::sst::parquet::reader::SimpleFilterContext;
use crate::sst::parquet::stats::RowGroupPruningStats;
pub(crate) type BulkIterContextRef = Arc<BulkIterContext>;
pub(crate) struct BulkIterContext {
pub struct BulkIterContext {
pub(crate) base: RangeBase,
pub(crate) predicate: Option<Predicate>,
}
impl BulkIterContext {
pub(crate) fn new(
pub fn new(
region_metadata: RegionMetadataRef,
projection: &Option<&[ColumnId]>,
predicate: Option<Predicate>,
flat_format: bool,
) -> Self {
let codec = build_primary_key_codec(&region_metadata);
@@ -54,7 +56,7 @@ impl BulkIterContext {
})
.collect();
let read_format = build_read_format(region_metadata, projection);
let read_format = build_read_format(region_metadata, projection, flat_format);
Self {
base: RangeBase {
@@ -99,9 +101,10 @@ impl BulkIterContext {
fn build_read_format(
region_metadata: RegionMetadataRef,
projection: &Option<&[ColumnId]>,
flat_format: bool,
) -> ReadFormat {
let read_format = if let Some(column_ids) = &projection {
ReadFormat::new(region_metadata, column_ids.iter().copied())
ReadFormat::new(region_metadata, column_ids.iter().copied(), flat_format)
} else {
// No projection, lists all column ids to read.
ReadFormat::new(
@@ -110,6 +113,7 @@ fn build_read_format(
.column_metadatas
.iter()
.map(|col| col.column_id),
flat_format,
)
};


@@ -58,7 +58,7 @@ use crate::error::{
EncodeSnafu, NewRecordBatchSnafu, Result,
};
use crate::memtable::bulk::context::BulkIterContextRef;
use crate::memtable::bulk::part_reader::BulkPartIter;
use crate::memtable::bulk::part_reader::EncodedBulkPartIter;
use crate::memtable::time_series::{ValueBuilder, Values};
use crate::memtable::BoxedBatchIterator;
use crate::sst::parquet::format::{PrimaryKeyArray, ReadFormat};
@@ -520,7 +520,7 @@ impl EncodedBulkPart {
return Ok(None);
}
let iter = BulkPartIter::try_new(
let iter = EncodedBulkPartIter::try_new(
context,
row_groups_to_read,
self.metadata.parquet_metadata.clone(),
@@ -1243,6 +1243,7 @@ mod tests {
part.metadata.region_metadata.clone(),
&Some(projection.as_slice()),
None,
false,
)),
None,
)
@@ -1294,6 +1295,7 @@ mod tests {
part.metadata.region_metadata.clone(),
&None,
predicate,
false,
));
let mut reader = part
.read(context, None)
@@ -1324,6 +1326,7 @@ mod tests {
Some(Predicate::new(vec![datafusion_expr::col("ts").eq(
datafusion_expr::lit(ScalarValue::TimestampMillisecond(Some(300), None)),
)])),
false,
));
assert!(part.read(context, None).unwrap().is_none());


@@ -13,29 +13,36 @@
// limitations under the License.
use std::collections::VecDeque;
use std::ops::BitAnd;
use std::sync::Arc;
use bytes::Bytes;
use datatypes::arrow::array::{BooleanArray, Scalar, UInt64Array};
use datatypes::arrow::buffer::BooleanBuffer;
use datatypes::arrow::record_batch::RecordBatch;
use parquet::arrow::ProjectionMask;
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::SequenceNumber;
use crate::error;
use crate::error::{self, ComputeArrowSnafu};
use crate::memtable::bulk::context::BulkIterContextRef;
use crate::memtable::bulk::row_group_reader::{
MemtableRowGroupReader, MemtableRowGroupReaderBuilder,
};
use crate::read::Batch;
use crate::sst::parquet::flat_format::sequence_column_index;
use crate::sst::parquet::reader::MaybeFilter;
/// Iterator for reading data inside a bulk part.
pub struct BulkPartIter {
pub struct EncodedBulkPartIter {
row_groups_to_read: VecDeque<usize>,
current_reader: Option<PruneReader>,
builder: MemtableRowGroupReaderBuilder,
sequence: Option<SequenceNumber>,
}
impl BulkPartIter {
impl EncodedBulkPartIter {
/// Creates a new [EncodedBulkPartIter].
pub(crate) fn try_new(
context: BulkIterContextRef,
@@ -92,7 +99,7 @@ impl BulkPartIter {
}
}
impl Iterator for BulkPartIter {
impl Iterator for EncodedBulkPartIter {
type Item = error::Result<Batch>;
fn next(&mut self) -> Option<Self::Item> {
@@ -153,3 +160,294 @@ impl PruneReader {
self.row_group_reader = reader;
}
}
/// Iterator for a record batch in a bulk part.
pub struct BulkPartRecordBatchIter {
/// The RecordBatch to read from
record_batch: Option<RecordBatch>,
/// Iterator context for filtering
context: BulkIterContextRef,
/// Sequence number filter.
sequence: Option<Scalar<UInt64Array>>,
}
impl BulkPartRecordBatchIter {
/// Creates a new [BulkPartRecordBatchIter] from a RecordBatch.
pub fn new(
record_batch: RecordBatch,
context: BulkIterContextRef,
sequence: Option<SequenceNumber>,
) -> Self {
assert!(context.read_format().as_flat().is_some());
let sequence = sequence.map(UInt64Array::new_scalar);
Self {
record_batch: Some(record_batch),
context,
sequence,
}
}
/// Applies projection to the RecordBatch if needed.
fn apply_projection(&self, record_batch: RecordBatch) -> error::Result<RecordBatch> {
let projection_indices = self.context.read_format().projection_indices();
if projection_indices.len() == record_batch.num_columns() {
return Ok(record_batch);
}
record_batch
.project(projection_indices)
.context(ComputeArrowSnafu)
}
// TODO(yingwen): Support sparse encoding which doesn't have decoded primary key columns.
/// Applies both predicate filtering and sequence filtering in a single pass.
/// Returns None if the filtered batch is empty.
fn apply_combined_filters(
&self,
record_batch: RecordBatch,
) -> error::Result<Option<RecordBatch>> {
let num_rows = record_batch.num_rows();
let mut combined_filter = None;
// First, apply predicate filters.
if !self.context.base.filters.is_empty() {
let num_rows = record_batch.num_rows();
let mut mask = BooleanBuffer::new_set(num_rows);
// Run filters one by one and combine their results, similar to RangeBase::precise_filter
for filter_ctx in &self.context.base.filters {
let filter = match filter_ctx.filter() {
MaybeFilter::Filter(f) => f,
// Column matches.
MaybeFilter::Matched => continue,
// Column doesn't match, filter the entire batch.
MaybeFilter::Pruned => return Ok(None),
};
// Safety: We checked the format type in new().
let Some(column_index) = self
.context
.read_format()
.as_flat()
.unwrap()
.projected_index_by_id(filter_ctx.column_id())
else {
continue;
};
let array = record_batch.column(column_index);
let result = filter
.evaluate_array(array)
.context(crate::error::RecordBatchSnafu)?;
mask = mask.bitand(&result);
}
// Convert the mask to BooleanArray
combined_filter = Some(BooleanArray::from(mask));
}
// Filters rows by the given `sequence`. Only preserves rows with sequence less than or equal to `sequence`.
if let Some(sequence) = &self.sequence {
let sequence_column =
record_batch.column(sequence_column_index(record_batch.num_columns()));
let sequence_filter =
datatypes::arrow::compute::kernels::cmp::lt_eq(sequence_column, sequence)
.context(ComputeArrowSnafu)?;
// Combine with existing filter using AND operation
combined_filter = match combined_filter {
None => Some(sequence_filter),
Some(existing_filter) => {
let and_result =
datatypes::arrow::compute::and(&existing_filter, &sequence_filter)
.context(ComputeArrowSnafu)?;
Some(and_result)
}
};
}
// Apply the combined filter if any filters were applied
let Some(filter_array) = combined_filter else {
// No filters applied, return original batch
return Ok(Some(record_batch));
};
let select_count = filter_array.true_count();
if select_count == 0 {
return Ok(None);
}
if select_count == num_rows {
return Ok(Some(record_batch));
}
let filtered_batch =
datatypes::arrow::compute::filter_record_batch(&record_batch, &filter_array)
.context(ComputeArrowSnafu)?;
Ok(Some(filtered_batch))
}
fn process_batch(&mut self, record_batch: RecordBatch) -> error::Result<Option<RecordBatch>> {
// Apply projection first.
let projected_batch = self.apply_projection(record_batch)?;
// Apply combined filtering (both predicate and sequence filters)
let Some(filtered_batch) = self.apply_combined_filters(projected_batch)? else {
return Ok(None);
};
Ok(Some(filtered_batch))
}
}
impl Iterator for BulkPartRecordBatchIter {
type Item = error::Result<RecordBatch>;
fn next(&mut self) -> Option<Self::Item> {
let record_batch = self.record_batch.take()?;
self.process_batch(record_batch).transpose()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::SemanticType;
use datafusion_expr::{col, lit};
use datatypes::arrow::array::{ArrayRef, Int64Array, StringArray, UInt64Array, UInt8Array};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::ColumnSchema;
use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
use store_api::storage::RegionId;
use table::predicate::Predicate;
use super::*;
use crate::memtable::bulk::context::BulkIterContext;
#[test]
fn test_bulk_part_record_batch_iter() {
// Create a simple schema
let schema = Arc::new(Schema::new(vec![
Field::new("key1", DataType::Utf8, false),
Field::new("field1", DataType::Int64, false),
Field::new(
"timestamp",
DataType::Timestamp(datatypes::arrow::datatypes::TimeUnit::Millisecond, None),
false,
),
Field::new(
"__primary_key",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
false,
),
Field::new("__sequence", DataType::UInt64, false),
Field::new("__op_type", DataType::UInt8, false),
]));
// Create test data
let key1 = Arc::new(StringArray::from_iter_values(["key1", "key2", "key3"]));
let field1 = Arc::new(Int64Array::from(vec![11, 12, 13]));
let timestamp = Arc::new(datatypes::arrow::array::TimestampMillisecondArray::from(
vec![1000, 2000, 3000],
));
// Create primary key dictionary array
use datatypes::arrow::array::{BinaryArray, DictionaryArray, UInt32Array};
let values = Arc::new(BinaryArray::from_iter_values([b"key1", b"key2", b"key3"]));
let keys = UInt32Array::from(vec![0, 1, 2]);
let primary_key = Arc::new(DictionaryArray::new(keys, values));
let sequence = Arc::new(UInt64Array::from(vec![1, 2, 3]));
let op_type = Arc::new(UInt8Array::from(vec![1, 1, 1])); // PUT operations
let record_batch = RecordBatch::try_new(
schema,
vec![
key1,
field1,
timestamp,
primary_key.clone(),
sequence,
op_type,
],
)
.unwrap();
// Create a minimal region metadata for testing
let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
builder
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"key1",
ConcreteDataType::string_datatype(),
false,
),
semantic_type: SemanticType::Tag,
column_id: 0,
})
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"field1",
ConcreteDataType::int64_datatype(),
false,
),
semantic_type: SemanticType::Field,
column_id: 1,
})
.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"timestamp",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
semantic_type: SemanticType::Timestamp,
column_id: 2,
})
.primary_key(vec![0]);
let region_metadata = builder.build().unwrap();
// Create context
let context = Arc::new(BulkIterContext::new(
Arc::new(region_metadata.clone()),
&None, // No projection
None, // No predicate
true,
));
// Iterates all rows.
let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
assert_eq!(1, result.len());
assert_eq!(3, result[0].num_rows());
assert_eq!(6, result[0].num_columns());
// Creates iter with sequence filter (only include sequences <= 2)
let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context, Some(2));
let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
assert_eq!(1, result.len());
let expect_sequence = Arc::new(UInt64Array::from(vec![1, 2])) as ArrayRef;
assert_eq!(
&expect_sequence,
result[0].column(result[0].num_columns() - 2)
);
assert_eq!(6, result[0].num_columns());
let context = Arc::new(BulkIterContext::new(
Arc::new(region_metadata),
&Some(&[0, 2]),
Some(Predicate::new(vec![col("key1").eq(lit("key2"))])),
true,
));
// Creates iter with projection and predicate.
let iter = BulkPartRecordBatchIter::new(record_batch.clone(), context.clone(), None);
let result: Vec<_> = iter.map(|rb| rb.unwrap()).collect();
assert_eq!(1, result.len());
assert_eq!(1, result[0].num_rows());
assert_eq!(5, result[0].num_columns());
let expect_sequence = Arc::new(UInt64Array::from(vec![2])) as ArrayRef;
assert_eq!(
&expect_sequence,
result[0].column(result[0].num_columns() - 2)
);
}
}
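The mask handling in apply_combined_filters above uses standard arrow kernels. A rough, self-contained sketch of the same pattern, written against the arrow crate directly rather than the mito2 types and with made-up column values, looks like this:

use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int64Array, UInt64Array};
use arrow::compute::kernels::cmp::lt_eq;
use arrow::compute::{and, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("field1", DataType::Int64, false),
        Field::new("__sequence", DataType::UInt64, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(vec![10, 20, 30])) as ArrayRef,
            Arc::new(UInt64Array::from(vec![1, 2, 3])),
        ],
    )?;

    // Predicate mask, e.g. the result of evaluating field1 < 30 row by row.
    let predicate_mask = BooleanArray::from(vec![true, true, false]);

    // Sequence filter: keep rows whose __sequence is <= the visible sequence (2 here).
    let visible_sequence = UInt64Array::new_scalar(2);
    let sequence_mask = lt_eq(batch.column(1), &visible_sequence)?;

    // AND the two masks together and filter the whole batch in a single pass.
    let combined = and(&predicate_mask, &sequence_mask)?;
    let filtered = filter_record_batch(&batch, &combined)?;
    assert_eq!(filtered.num_rows(), 2);
    Ok(())
}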


@@ -140,12 +140,16 @@ pub enum ReadFormat {
}
impl ReadFormat {
// TODO(yingwen): Add a flag to choose format type.
pub(crate) fn new(
metadata: RegionMetadataRef,
column_ids: impl Iterator<Item = ColumnId>,
flat_format: bool,
) -> Self {
Self::new_primary_key(metadata, column_ids)
if flat_format {
Self::new_flat(metadata, column_ids)
} else {
Self::new_primary_key(metadata, column_ids)
}
}
/// Creates a helper to read the primary key format.
@@ -171,6 +175,13 @@ impl ReadFormat {
}
}
pub(crate) fn as_flat(&self) -> Option<&FlatReadFormat> {
match self {
ReadFormat::Flat(format) => Some(format),
_ => None,
}
}
/// Gets the arrow schema of the SST file.
///
/// This schema is computed from the region metadata but should be the same
@@ -1201,7 +1212,7 @@ mod tests {
.iter()
.map(|col| col.column_id)
.collect();
let read_format = ReadFormat::new(metadata, column_ids.iter().copied());
let read_format = ReadFormat::new(metadata, column_ids.iter().copied(), false);
let columns: Vec<ArrayRef> = vec![
Arc::new(Int64Array::from(vec![1, 1, 10, 10])), // field1