From a0a2942ad54fda92af8dc614043df18ba34486bb Mon Sep 17 00:00:00 2001
From: yaommen
Date: Tue, 31 Mar 2026 06:43:58 +0800
Subject: [PATCH] fix: respect max_batch_length for Rust vector and hybrid
 queries (#3172)

Fixes #1540

I could not reproduce this on current `main` from Python, but I could
still reproduce it from the Rust SDK. Python no longer reproduces because
the current Python vector/hybrid query paths re-chunk results into a
`pyarrow.Table` before returning batches.

Rust still reproduced because `max_batch_length` was passed into
planning/scanning, but vector search could still emit larger
`RecordBatch`es later in execution (for example after KNN / TopK), so the
limit was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output stream
and adds Rust regression coverage.

Before the fix, the Rust repro produced:
`num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`

After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like
`[8192, 1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();
    let rows = 10_000;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));
    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }
    println!("{sizes:?}");
    Ok(())
}
```
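For context on the approach: the fix chunks oversized batches with zero-copy
`RecordBatch::slice` calls on the output stream rather than buffering or
re-concatenating rows. A minimal standalone sketch of that chunking step (the
`chunk_batch` helper below is illustrative only, not part of this patch):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

/// Split one batch into slices of at most `max_batch_length` rows.
/// `RecordBatch::slice` is zero-copy: each chunk shares the parent's buffers.
fn chunk_batch(batch: &RecordBatch, max_batch_length: usize) -> Vec<RecordBatch> {
    if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
        // A limit of 0 means "unlimited", matching QueryExecutionOptions.
        return vec![batch.clone()];
    }
    let mut chunks = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
    let mut offset = 0;
    while offset < batch.num_rows() {
        let length = (batch.num_rows() - offset).min(max_batch_length);
        chunks.push(batch.slice(offset, length));
        offset += length;
    }
    chunks
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from_iter_values(0..8192))],
    )
    .unwrap();

    let sizes: Vec<usize> = chunk_batch(&batch, 100).iter().map(|c| c.num_rows()).collect();
    // 81 full chunks of 100 rows plus one trailing 92-row chunk.
    assert_eq!(sizes.len(), 82);
    assert!(sizes.iter().all(|&n| n <= 100));
}
```

In the patch itself, `single_batch_stream` applies this eagerly to the single
concatenated hybrid-search batch, while `MaxBatchLengthStream` applies the same
slicing lazily as batches are polled, so at most one oversized batch is held at
a time.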
Signed-off-by: yaommen
---
 rust/lancedb/src/query.rs       | 169 ++++++++++++++++++++++++++++++--
 rust/lancedb/src/table/query.rs |   7 +-
 rust/lancedb/src/utils/mod.rs   | 120 ++++++++++++++++++++++-
 3 files changed, 280 insertions(+), 16 deletions(-)

diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index a1804a79c..8f60230b0 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -5,7 +5,7 @@
 use std::sync::Arc;
 use std::{future::Future, time::Duration};
 use arrow::compute::concat_batches;
-use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
+use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_expr::Expr;
 use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@
 use lance_datafusion::exec::execute_plan;
 use lance_index::scalar::FullTextSearchQuery;
 use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::vector::DIST_COL;
-use lance_io::stream::RecordBatchStreamAdapter;
 
 use crate::DistanceType;
 use crate::error::{Error, Result};
 use crate::rerankers::rrf::RRFReranker;
 use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
 use crate::table::BaseTable;
-use crate::utils::TimeoutStream;
-use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
+use crate::utils::{MaxBatchLengthStream, TimeoutStream};
+use crate::{
+    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+    table::AnyQuery,
+};
 
 mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
     }
 }
 
+impl QueryExecutionOptions {
+    fn without_output_batch_length_limit(&self) -> Self {
+        let mut options = self.clone();
+        options.max_batch_length = 0;
+        options
+    }
+}
+
 /// A trait for a query object that can be executed to get results
 ///
 /// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        let max_batch_length = options.max_batch_length as usize;
+        let internal_options = options.without_output_batch_length_limit();
         // clone query and specify we want to include row IDs, which can be needed for reranking
         let mut fts_query = Query::new(self.parent.clone());
         fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@
         vector_query.request.base.full_text_search = None;
 
         let (fts_results, vec_results) = try_join!(
-            fts_query.execute_with_options(options.clone()),
-            vector_query.inner_execute_with_options(options)
+            fts_query.execute_with_options(internal_options.clone()),
+            vector_query.inner_execute_with_options(internal_options)
         )?;
 
         let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@
             results = results.drop_column(ROW_ID)?;
         }
 
-        Ok(SendableRecordBatchStream::from(
-            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
-        ))
+        Ok(single_batch_stream(results, max_batch_length))
     }
 
     async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@
     ) -> Result<SendableRecordBatchStream> {
         let plan = self.create_plan(options.clone()).await?;
         let inner = execute_plan(plan, Default::default())?;
+        let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
         let inner = if let Some(timeout) = options.timeout {
             TimeoutStream::new_boxed(inner, timeout)
         } else {
@@ -1265,6 +1276,25 @@
         }
     }
 }
 
+fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
+    let schema = batch.schema();
+    if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
+        return Box::pin(SimpleRecordBatchStream::new(
+            stream::iter([Ok(batch)]),
+            schema,
+        ));
+    }
+
+    let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
+    let mut offset = 0;
+    while offset < batch.num_rows() {
+        let length = (batch.num_rows() - offset).min(max_batch_length);
+        batches.push(Ok(batch.slice(offset, length)));
+        offset += length;
+    }
+    Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
+}
+
 impl ExecutableQuery for VectorQuery {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
         let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
             .unwrap()
     }
 
+    async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
+        let dataset_path = tmp_dir.path().join("large_test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("id", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    4,
+                ),
+                false,
+            ),
+        ]));
+
+        let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
+        let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+            (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
+            4,
+        );
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
+
+        let conn = connect(uri).execute().await.unwrap();
+        conn.create_table("my_table", vec![batch])
+            .execute()
+            .await
+            .unwrap()
+    }
+
+    async fn assert_stream_batches_at_most(
+        mut results: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) {
+        let mut saw_batch = false;
+        while let Some(batch) = results.next().await {
+            let batch = batch.unwrap();
+            saw_batch = true;
+            assert!(batch.num_rows() <= max_batch_length);
+        }
+        assert!(saw_batch);
+    }
+
     #[tokio::test]
     async fn test_execute_with_options() {
         let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@
         }
     }
 
+    #[tokio::test]
+    async fn test_vector_query_execute_with_options_respects_max_batch_length() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_large_vector_table(&tmp_dir, 10_000).await;
+
+        let results = table
+            .query()
+            .nearest_to(vec![0.0, 1.0, 2.0, 3.0])
+            .unwrap()
+            .limit(10_000)
+            .execute_with_options(QueryExecutionOptions {
+                max_batch_length: 100,
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+        assert_stream_batches_at_most(results, 100).await;
+    }
+
+    #[tokio::test]
+    async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let rows = 512;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
+        let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+            (0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
+            dims,
+        );
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
+        let table = conn
+            .create_table("my_table", record_batch)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let results = table
+            .query()
+            .full_text_search(FullTextSearchQuery::new("match".to_string()))
+            .limit(rows)
+            .nearest_to(&[0.0, 0.0])
+            .unwrap()
+            .execute_with_options(QueryExecutionOptions {
+                max_batch_length: 100,
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+        assert_stream_batches_at_most(results, 100).await;
+    }
+
     #[tokio::test]
     async fn test_analyze_plan() {
         let tmp_dir = tempdir().unwrap();
diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs
index abce6d325..6cbcf4e19 100644
--- a/rust/lancedb/src/table/query.rs
+++ b/rust/lancedb/src/table/query.rs
@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
 use crate::query::{
     DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
 };
-use crate::utils::{TimeoutStream, default_vector_column};
+use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
 use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
 use arrow::datatypes::{Float32Type, UInt8Type};
 use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
 ) -> Result<SendableRecordBatchStream> {
     let plan = create_plan(table, query, options.clone()).await?;
     let inner = execute_plan(plan, Default::default())?;
+    let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
     let inner = if let Some(timeout) = options.timeout {
         TimeoutStream::new_boxed(inner, timeout)
     } else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
         scanner.with_row_id();
     }
 
-    scanner.batch_size(options.max_batch_length as usize);
+    if options.max_batch_length > 0 {
+        scanner.batch_size(options.max_batch_length as usize);
+    }
 
     if query.base.fast_search {
         scanner.fast_search();
diff --git a/rust/lancedb/src/utils/mod.rs b/rust/lancedb/src/utils/mod.rs
index ffed533f6..0af8623b4 100644
--- a/rust/lancedb/src/utils/mod.rs
+++ b/rust/lancedb/src/utils/mod.rs
@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
     }
 }
 
+/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
+pub struct MaxBatchLengthStream {
+    inner: SendableRecordBatchStream,
+    max_batch_length: Option<usize>,
+    buffered_batch: Option<RecordBatch>,
+    buffered_offset: usize,
+}
+
+impl MaxBatchLengthStream {
+    pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
+        Self {
+            inner,
+            max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
+            buffered_batch: None,
+            buffered_offset: 0,
+        }
+    }
+
+    pub fn new_boxed(
+        inner: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) -> SendableRecordBatchStream {
+        if max_batch_length == 0 {
+            inner
+        } else {
+            Box::pin(Self::new(inner, max_batch_length))
+        }
+    }
+}
+
+impl RecordBatchStream for MaxBatchLengthStream {
+    fn schema(&self) -> SchemaRef {
+        self.inner.schema()
+    }
+}
+
+impl Stream for MaxBatchLengthStream {
+    type Item = DataFusionResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        loop {
+            let Some(max_batch_length) = self.max_batch_length else {
+                return Pin::new(&mut self.inner).poll_next(cx);
+            };
+
+            if let Some(batch) = self.buffered_batch.clone() {
+                if self.buffered_offset < batch.num_rows() {
+                    let remaining = batch.num_rows() - self.buffered_offset;
+                    let length = remaining.min(max_batch_length);
+                    let sliced = batch.slice(self.buffered_offset, length);
+                    self.buffered_offset += length;
+                    if self.buffered_offset >= batch.num_rows() {
+                        self.buffered_batch = None;
+                        self.buffered_offset = 0;
+                    }
+                    return std::task::Poll::Ready(Some(Ok(sliced)));
+                }
+
+                self.buffered_batch = None;
+                self.buffered_offset = 0;
+            }
+
+            match Pin::new(&mut self.inner).poll_next(cx) {
+                std::task::Poll::Ready(Some(Ok(batch))) => {
+                    if batch.num_rows() <= max_batch_length {
+                        return std::task::Poll::Ready(Some(Ok(batch)));
+                    }
+                    self.buffered_batch = Some(batch);
+                    self.buffered_offset = 0;
+                }
+                other => return other,
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use arrow_array::Int32Array;
@@ -470,7 +549,7 @@
         assert_eq!(string_to_datatype(string), Some(expected));
     }
 
-    fn sample_batch() -> RecordBatch {
+    fn sample_batch(num_rows: i32) -> RecordBatch {
         let schema = Arc::new(Schema::new(vec![Field::new(
             "col1",
             DataType::Int32,
@@ -478,14 +557,14 @@
         )]));
         RecordBatch::try_new(
             schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+            vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
         )
         .unwrap()
     }
 
     #[tokio::test]
     async fn test_timeout_stream() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
 
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -515,7 +594,7 @@
 
     #[tokio::test]
     async fn test_timeout_stream_zero_duration() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
 
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -534,7 +613,7 @@
 
     #[tokio::test]
     async fn test_timeout_stream_completes_normally() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
 
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -552,4 +631,35 @@
         // Stream should be empty now
         assert!(timeout_stream.next().await.is_none());
     }
+
+    async fn collect_batch_sizes(
+        stream: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) -> Vec<usize> {
+        let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
+        sliced_stream
+            .by_ref()
+            .map(|batch| batch.unwrap().num_rows())
+            .collect::<Vec<_>>()
+            .await
+    }
+
+    #[tokio::test]
+    async fn test_max_batch_length_stream_behaviors() {
+        let schema = sample_batch(7).schema();
+        let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
+
+        let sendable_stream: SendableRecordBatchStream =
+            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
+        assert_eq!(
+            collect_batch_sizes(sendable_stream, 3).await,
+            vec![2, 3, 3, 1]
+        );
+
+        let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
+        ));
+        assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
+    }
 }