Mirror of https://github.com/lancedb/lancedb.git, synced 2026-05-16 11:30:41 +00:00
fix: respect max_batch_length for Rust vector and hybrid queries (#3172)
Fixes #1540

I could not reproduce this on current `main` from Python, but I could still reproduce it from the Rust SDK. Python no longer reproduces because the current Python vector/hybrid query paths re-chunk results into a `pyarrow.Table` before returning batches. Rust still reproduced because `max_batch_length` was passed into planning/scanning, but vector search could still emit larger `RecordBatch`es later in execution (for example after KNN / TopK), so it was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output stream and adds Rust regression coverage.

Before the fix, the Rust repro produced: `num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`. After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like `[8192, 1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();

    let rows = 10_000;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));
    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }
    println!("{sizes:?}");
    Ok(())
}
```

Signed-off-by: yaommen <myanstu@163.com>
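For context on what the fix enforces, here is a minimal, self-contained sketch of the slicing rule. The function name `slice_to_max` is invented here for illustration; the PR implements the same arithmetic inside the new `MaxBatchLengthStream` wrapper and `single_batch_stream` helper. Each oversized `RecordBatch` is re-emitted as zero-copy slices of at most `max_batch_length` rows, with `0` meaning unlimited:

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

// Illustrative only: split one batch into slices of at most `max_batch_length`
// rows. `0` is treated as "no limit", matching QueryExecutionOptions semantics.
fn slice_to_max(batch: &RecordBatch, max_batch_length: usize) -> Vec<RecordBatch> {
    if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
        return vec![batch.clone()];
    }
    let mut out = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
    let mut offset = 0;
    while offset < batch.num_rows() {
        let length = (batch.num_rows() - offset).min(max_batch_length);
        // RecordBatch::slice is zero-copy; slices share the parent's buffers.
        out.push(batch.slice(offset, length));
        offset += length;
    }
    out
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from_iter_values(0..7))]).unwrap();
    let sizes: Vec<usize> = slice_to_max(&batch, 3).iter().map(|b| b.num_rows()).collect();
    assert_eq!(sizes, vec![3, 3, 1]); // 7 rows with a limit of 3 -> 3, 3, 1
    println!("{sizes:?}");
}
```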
@@ -5,7 +5,7 @@ use std::sync::Arc;
 use std::{future::Future, time::Duration};
 
 use arrow::compute::concat_batches;
-use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
+use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_expr::Expr;
 use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
 use lance_index::scalar::FullTextSearchQuery;
 use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::vector::DIST_COL;
 use lance_io::stream::RecordBatchStreamAdapter;
 
 use crate::DistanceType;
 use crate::error::{Error, Result};
 use crate::rerankers::rrf::RRFReranker;
 use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
 use crate::table::BaseTable;
-use crate::utils::TimeoutStream;
-use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
+use crate::utils::{MaxBatchLengthStream, TimeoutStream};
+use crate::{
+    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
+    table::AnyQuery,
+};
 
 mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
     }
 }
 
+impl QueryExecutionOptions {
+    fn without_output_batch_length_limit(&self) -> Self {
+        let mut options = self.clone();
+        options.max_batch_length = 0;
+        options
+    }
+}
+
 /// A trait for a query object that can be executed to get results
 ///
 /// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        let max_batch_length = options.max_batch_length as usize;
+        let internal_options = options.without_output_batch_length_limit();
         // clone query and specify we want to include row IDs, which can be needed for reranking
         let mut fts_query = Query::new(self.parent.clone());
         fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@ impl VectorQuery {
 
         vector_query.request.base.full_text_search = None;
         let (fts_results, vec_results) = try_join!(
-            fts_query.execute_with_options(options.clone()),
-            vector_query.inner_execute_with_options(options)
+            fts_query.execute_with_options(internal_options.clone()),
+            vector_query.inner_execute_with_options(internal_options)
         )?;
 
         let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@ impl VectorQuery {
             results = results.drop_column(ROW_ID)?;
         }
 
-        Ok(SendableRecordBatchStream::from(
-            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
-        ))
+        Ok(single_batch_stream(results, max_batch_length))
     }
 
     async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@ impl VectorQuery {
     ) -> Result<SendableRecordBatchStream> {
         let plan = self.create_plan(options.clone()).await?;
         let inner = execute_plan(plan, Default::default())?;
+        let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
         let inner = if let Some(timeout) = options.timeout {
             TimeoutStream::new_boxed(inner, timeout)
         } else {
@@ -1265,6 +1276,25 @@ impl VectorQuery {
     }
 }
 
+fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
+    let schema = batch.schema();
+    if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
+        return Box::pin(SimpleRecordBatchStream::new(
+            stream::iter([Ok(batch)]),
+            schema,
+        ));
+    }
+
+    let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
+    let mut offset = 0;
+    while offset < batch.num_rows() {
+        let length = (batch.num_rows() - offset).min(max_batch_length);
+        batches.push(Ok(batch.slice(offset, length)));
+        offset += length;
+    }
+    Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
+}
+
 impl ExecutableQuery for VectorQuery {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
         let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
             .unwrap()
     }
 
+    async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
+        let dataset_path = tmp_dir.path().join("large_test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("id", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    4,
+                ),
+                false,
+            ),
+        ]));
+
+        let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
+        let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+            (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
+            4,
+        );
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
+
+        let conn = connect(uri).execute().await.unwrap();
+        conn.create_table("my_table", vec![batch])
+            .execute()
+            .await
+            .unwrap()
+    }
+
+    async fn assert_stream_batches_at_most(
+        mut results: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) {
+        let mut saw_batch = false;
+        while let Some(batch) = results.next().await {
+            let batch = batch.unwrap();
+            saw_batch = true;
+            assert!(batch.num_rows() <= max_batch_length);
+        }
+        assert!(saw_batch);
+    }
+
     #[tokio::test]
     async fn test_execute_with_options() {
         let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_vector_query_execute_with_options_respects_max_batch_length() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_large_vector_table(&tmp_dir, 10_000).await;
+
+        let results = table
+            .query()
+            .nearest_to(vec![0.0, 1.0, 2.0, 3.0])
+            .unwrap()
+            .limit(10_000)
+            .execute_with_options(QueryExecutionOptions {
+                max_batch_length: 100,
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+        assert_stream_batches_at_most(results, 100).await;
+    }
+
+    #[tokio::test]
+    async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let rows = 512;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
+        let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+            (0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
+            dims,
+        );
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
+        let table = conn
+            .create_table("my_table", record_batch)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let results = table
+            .query()
+            .full_text_search(FullTextSearchQuery::new("match".to_string()))
+            .limit(rows)
+            .nearest_to(&[0.0, 0.0])
+            .unwrap()
+            .execute_with_options(QueryExecutionOptions {
+                max_batch_length: 100,
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+        assert_stream_batches_at_most(results, 100).await;
+    }
+
     #[tokio::test]
     async fn test_analyze_plan() {
         let tmp_dir = tempdir().unwrap();
@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
 use crate::query::{
     DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
 };
-use crate::utils::{TimeoutStream, default_vector_column};
+use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
 use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
 use arrow::datatypes::{Float32Type, UInt8Type};
 use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
 ) -> Result<DatasetRecordBatchStream> {
     let plan = create_plan(table, query, options.clone()).await?;
     let inner = execute_plan(plan, Default::default())?;
+    let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
     let inner = if let Some(timeout) = options.timeout {
         TimeoutStream::new_boxed(inner, timeout)
     } else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
         scanner.with_row_id();
     }
 
-    scanner.batch_size(options.max_batch_length as usize);
+    if options.max_batch_length > 0 {
+        scanner.batch_size(options.max_batch_length as usize);
+    }
 
     if query.base.fast_search {
         scanner.fast_search();
@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
     }
 }
 
+/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
+pub struct MaxBatchLengthStream {
+    inner: SendableRecordBatchStream,
+    max_batch_length: Option<usize>,
+    buffered_batch: Option<RecordBatch>,
+    buffered_offset: usize,
+}
+
+impl MaxBatchLengthStream {
+    pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
+        Self {
+            inner,
+            max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
+            buffered_batch: None,
+            buffered_offset: 0,
+        }
+    }
+
+    pub fn new_boxed(
+        inner: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) -> SendableRecordBatchStream {
+        if max_batch_length == 0 {
+            inner
+        } else {
+            Box::pin(Self::new(inner, max_batch_length))
+        }
+    }
+}
+
+impl RecordBatchStream for MaxBatchLengthStream {
+    fn schema(&self) -> SchemaRef {
+        self.inner.schema()
+    }
+}
+
+impl Stream for MaxBatchLengthStream {
+    type Item = DataFusionResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        loop {
+            let Some(max_batch_length) = self.max_batch_length else {
+                return Pin::new(&mut self.inner).poll_next(cx);
+            };
+
+            if let Some(batch) = self.buffered_batch.clone() {
+                if self.buffered_offset < batch.num_rows() {
+                    let remaining = batch.num_rows() - self.buffered_offset;
+                    let length = remaining.min(max_batch_length);
+                    let sliced = batch.slice(self.buffered_offset, length);
+                    self.buffered_offset += length;
+                    if self.buffered_offset >= batch.num_rows() {
+                        self.buffered_batch = None;
+                        self.buffered_offset = 0;
+                    }
+                    return std::task::Poll::Ready(Some(Ok(sliced)));
+                }
+
+                self.buffered_batch = None;
+                self.buffered_offset = 0;
+            }
+
+            match Pin::new(&mut self.inner).poll_next(cx) {
+                std::task::Poll::Ready(Some(Ok(batch))) => {
+                    if batch.num_rows() <= max_batch_length {
+                        return std::task::Poll::Ready(Some(Ok(batch)));
+                    }
+                    self.buffered_batch = Some(batch);
+                    self.buffered_offset = 0;
+                }
+                other => return other,
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use arrow_array::Int32Array;
@@ -470,7 +549,7 @@ mod tests {
         assert_eq!(string_to_datatype(string), Some(expected));
     }
 
-    fn sample_batch() -> RecordBatch {
+    fn sample_batch(num_rows: i32) -> RecordBatch {
         let schema = Arc::new(Schema::new(vec![Field::new(
             "col1",
             DataType::Int32,
@@ -478,14 +557,14 @@ mod tests {
         )]));
         RecordBatch::try_new(
             schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+            vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
         )
         .unwrap()
     }
 
     #[tokio::test]
     async fn test_timeout_stream() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
 
@@ -515,7 +594,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_timeout_stream_zero_duration() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
 
@@ -534,7 +613,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_timeout_stream_completes_normally() {
-        let batch = sample_batch();
+        let batch = sample_batch(3);
         let schema = batch.schema();
         let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
 
@@ -552,4 +631,35 @@ mod tests {
         // Stream should be empty now
         assert!(timeout_stream.next().await.is_none());
     }
+
+    async fn collect_batch_sizes(
+        stream: SendableRecordBatchStream,
+        max_batch_length: usize,
+    ) -> Vec<usize> {
+        let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
+        sliced_stream
+            .by_ref()
+            .map(|batch| batch.unwrap().num_rows())
+            .collect::<Vec<_>>()
+            .await
+    }
+
+    #[tokio::test]
+    async fn test_max_batch_length_stream_behaviors() {
+        let schema = sample_batch(7).schema();
+        let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
+
+        let sendable_stream: SendableRecordBatchStream =
+            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
+        assert_eq!(
+            collect_batch_sizes(sendable_stream, 3).await,
+            vec![2, 3, 3, 1]
+        );
+
+        let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
+        ));
+        assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
+    }
 }