fix: respect max_batch_length for Rust vector and hybrid queries (#3172)

Fixes #1540

I could not reproduce this on current `main` from Python, but I could
still reproduce it from the Rust SDK.

Python no longer reproduces because the current Python vector/hybrid
query paths re-chunk results into a `pyarrow.Table` before returning
batches. Rust still reproduced because `max_batch_length` was passed
into planning/scanning, but vector search could still emit larger
`RecordBatch`es later in execution (for example after KNN / TopK), so it
was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output
stream and adds Rust regression coverage.

Before the fix, the Rust repro produced:
`num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`

After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like `[8192,
1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();

    let rows = 10_000;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));

    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }

    println!("{sizes:?}");
    Ok(())
}
```

Signed-off-by: yaommen <myanstu@163.com>
This commit is contained in:
yaommen
2026-03-31 06:43:58 +08:00
committed by GitHub
parent e3d53dd185
commit a0a2942ad5
3 changed files with 280 additions and 16 deletions

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::DistanceType;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
table::AnyQuery,
};
mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
}
}
impl QueryExecutionOptions {
fn without_output_batch_length_limit(&self) -> Self {
let mut options = self.clone();
options.max_batch_length = 0;
options
}
}
/// A trait for a query object that can be executed to get results
///
/// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let max_batch_length = options.max_batch_length as usize;
let internal_options = options.without_output_batch_length_limit();
// clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@ impl VectorQuery {
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
fts_query.execute_with_options(internal_options.clone()),
vector_query.inner_execute_with_options(internal_options)
)?;
let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@ impl VectorQuery {
results = results.drop_column(ROW_ID)?;
}
Ok(SendableRecordBatchStream::from(
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
Ok(single_batch_stream(results, max_batch_length))
}
async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@ impl VectorQuery {
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -1265,6 +1276,25 @@ impl VectorQuery {
}
}
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
let schema = batch.schema();
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
return Box::pin(SimpleRecordBatchStream::new(
stream::iter([Ok(batch)]),
schema,
));
}
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
let mut offset = 0;
while offset < batch.num_rows() {
let length = (batch.num_rows() - offset).min(max_batch_length);
batches.push(Ok(batch.slice(offset, length)));
offset += length;
}
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
}
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
.unwrap()
}
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
let dataset_path = tmp_dir.path().join("large_test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
4,
),
false,
),
]));
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
4,
);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
let conn = connect(uri).execute().await.unwrap();
conn.create_table("my_table", vec![batch])
.execute()
.await
.unwrap()
}
async fn assert_stream_batches_at_most(
mut results: SendableRecordBatchStream,
max_batch_length: usize,
) {
let mut saw_batch = false;
while let Some(batch) = results.next().await {
let batch = batch.unwrap();
saw_batch = true;
assert!(batch.num_rows() <= max_batch_length);
}
assert!(saw_batch);
}
#[tokio::test]
async fn test_execute_with_options() {
let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@ mod tests {
}
}
#[tokio::test]
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let table = make_large_vector_table(&tmp_dir, 10_000).await;
let results = table
.query()
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
.unwrap()
.limit(10_000)
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let rows = 512;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
dims,
);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
let table = conn
.create_table("my_table", record_batch)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let results = table
.query()
.full_text_search(FullTextSearchQuery::new("match".to_string()))
.limit(rows)
.nearest_to(&[0.0, 0.0])
.unwrap()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_analyze_plan() {
let tmp_dir = tempdir().unwrap();

View File

@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
use crate::query::{
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
};
use crate::utils::{TimeoutStream, default_vector_column};
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
) -> Result<DatasetRecordBatchStream> {
let plan = create_plan(table, query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
scanner.with_row_id();
}
scanner.batch_size(options.max_batch_length as usize);
if options.max_batch_length > 0 {
scanner.batch_size(options.max_batch_length as usize);
}
if query.base.fast_search {
scanner.fast_search();

View File

@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
}
}
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
pub struct MaxBatchLengthStream {
inner: SendableRecordBatchStream,
max_batch_length: Option<usize>,
buffered_batch: Option<RecordBatch>,
buffered_offset: usize,
}
impl MaxBatchLengthStream {
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
Self {
inner,
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
buffered_batch: None,
buffered_offset: 0,
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
max_batch_length: usize,
) -> SendableRecordBatchStream {
if max_batch_length == 0 {
inner
} else {
Box::pin(Self::new(inner, max_batch_length))
}
}
}
impl RecordBatchStream for MaxBatchLengthStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for MaxBatchLengthStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
let Some(max_batch_length) = self.max_batch_length else {
return Pin::new(&mut self.inner).poll_next(cx);
};
if let Some(batch) = self.buffered_batch.clone() {
if self.buffered_offset < batch.num_rows() {
let remaining = batch.num_rows() - self.buffered_offset;
let length = remaining.min(max_batch_length);
let sliced = batch.slice(self.buffered_offset, length);
self.buffered_offset += length;
if self.buffered_offset >= batch.num_rows() {
self.buffered_batch = None;
self.buffered_offset = 0;
}
return std::task::Poll::Ready(Some(Ok(sliced)));
}
self.buffered_batch = None;
self.buffered_offset = 0;
}
match Pin::new(&mut self.inner).poll_next(cx) {
std::task::Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() <= max_batch_length {
return std::task::Poll::Ready(Some(Ok(batch)));
}
self.buffered_batch = Some(batch);
self.buffered_offset = 0;
}
other => return other,
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_array::Int32Array;
@@ -470,7 +549,7 @@ mod tests {
assert_eq!(string_to_datatype(string), Some(expected));
}
fn sample_batch() -> RecordBatch {
fn sample_batch(num_rows: i32) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"col1",
DataType::Int32,
@@ -478,14 +557,14 @@ mod tests {
)]));
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)
.unwrap()
}
#[tokio::test]
async fn test_timeout_stream() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -515,7 +594,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_zero_duration() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -534,7 +613,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_completes_normally() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -552,4 +631,35 @@ mod tests {
// Stream should be empty now
assert!(timeout_stream.next().await.is_none());
}
async fn collect_batch_sizes(
stream: SendableRecordBatchStream,
max_batch_length: usize,
) -> Vec<usize> {
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
sliced_stream
.by_ref()
.map(|batch| batch.unwrap().num_rows())
.collect::<Vec<_>>()
.await
}
#[tokio::test]
async fn test_max_batch_length_stream_behaviors() {
let schema = sample_batch(7).schema();
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
assert_eq!(
collect_batch_sizes(sendable_stream, 3).await,
vec![2, 3, 3, 1]
);
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
));
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
}
}