fix: limit and offset support paginating through FTS and vector search results (#2592)

Adds tests to ensure that users can paginate through simple scan, FTS,
and vector search results using `limit` and `offset`.

Tests upstream work: https://github.com/lancedb/lance/pull/4318

Closes #2459
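
As a rough sketch of the usage pattern these tests exercise (based on the Node.js test changes below; `tbl` and `queryVec` are placeholders for an open table and a query vector, and the builder-call order mirrors the tests):

    // Page through vector search results, two rows at a time.
    // Each page re-runs the same search with a growing offset.
    const pageSize = 2;
    for (let offset = 0; ; offset += pageSize) {
      const page = await tbl
        .query()
        .limit(pageSize)
        .offset(offset)
        .nearestTo(queryVec)
        .toArrow();
      if (page.numRows === 0) break; // ran out of results
      // ... consume the page ...
    }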
Author: Will Jones
Date: 2025-08-15 08:55:12 -07:00
Committed by: GitHub
Parent: 941eada703
Commit: dcf53c4506
3 changed files with 129 additions and 15 deletions

View File

@@ -563,7 +563,7 @@ describe("When creating an index", () => {
     // test offset
     rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
-    expect(rst.numRows).toBe(1);
+    expect(rst.numRows).toBe(2);

     // test nprobes
     rst = await tbl.query().nearestTo(queryVec).limit(2).nprobes(50).toArrow();
@@ -702,7 +702,7 @@ describe("When creating an index", () => {
     // test offset
     rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
-    expect(rst.numRows).toBe(1);
+    expect(rst.numRows).toBe(2);

     // test ef
     rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow();

View File

@@ -1357,9 +1357,10 @@ mod tests {
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::{StreamExt, TryStreamExt};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
+    use rand::seq::IndexedRandom;
     use tempfile::tempdir;

-    use crate::{connect, database::CreateTableMode, Table};
+    use crate::{connect, database::CreateTableMode, index::Index, Table};

     #[tokio::test]
     async fn test_setters_getters() {
@@ -1465,7 +1466,7 @@ mod tests {
         while let Some(batch) = stream.next().await {
             // pre filter should return 10 rows
-            assert!(batch.expect("should be Ok").num_rows() == 10);
+            assert_eq!(batch.expect("should be Ok").num_rows(), 10);
         }

         let query = table
@@ -1480,7 +1481,7 @@ mod tests {
         // should only have one batch
         while let Some(batch) = stream.next().await {
             // pre filter should return 10 rows
-            assert!(batch.expect("should be Ok").num_rows() == 9);
+            assert_eq!(batch.expect("should be Ok").num_rows(), 10);
         }
     }
@@ -1941,6 +1942,125 @@ mod tests {
         assert_eq!(2, batch.num_columns());
     }

+    // TODO: Implement a good FTS test data generator in lance_datagen.
+    fn fts_test_data(nrows: usize) -> RecordBatch {
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new("id", DataType::Int32, false),
+        ]));
+        let ids: Int32Array = (1..=nrows as i32).collect();
+
+        // Sample 1 - 3 tokens for each string value
+        let tokens = ["a", "b", "c", "d", "e"];
+        use rand::{rng, Rng};
+        let mut rng = rng();
+        let text: StringArray = (0..nrows)
+            .map(|_| {
+                let num_tokens = rng.random_range(1..=3); // 1 to 3 tokens
+                let selected_tokens: Vec<&str> = tokens
+                    .choose_multiple(&mut rng, num_tokens)
+                    .cloned()
+                    .collect();
+                Some(selected_tokens.join(" "))
+            })
+            .collect();
+
+        RecordBatch::try_new(schema, vec![Arc::new(text), Arc::new(ids)]).unwrap()
+    }
+
+    async fn run_query_request(table: &dyn BaseTable, query: AnyQuery) -> RecordBatch {
+        use lance::io::RecordBatchStream;
+        let stream = table.query(&query, Default::default()).await.unwrap();
+        let schema = stream.schema();
+        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
+        arrow::compute::concat_batches(&schema, &batches).unwrap()
+    }
+
+    async fn test_pagination(table: &dyn BaseTable, full_query: AnyQuery, page_size: usize) {
+        // Get full results
+        let full_results = run_query_request(table, full_query.clone()).await;
+
+        // Then use limit & offset to do paginated queries, assert each
+        // is the same as a slice of the full results
+        let mut offset = 0;
+        while offset < full_results.num_rows() {
+            let mut paginated_query = full_query.clone();
+            let limit = page_size.min(full_results.num_rows() - offset);
+            match &mut paginated_query {
+                AnyQuery::Query(query)
+                | AnyQuery::VectorQuery(VectorQueryRequest { base: query, .. }) => {
+                    query.limit = Some(limit);
+                    query.offset = Some(offset);
+                }
+            }
+            let paginated_results = run_query_request(table, paginated_query).await;
+            let expected_slice = full_results.slice(offset, limit);
+            assert_eq!(
+                paginated_results, expected_slice,
+                "Paginated results do not match expected slice at offset {}, for page size {}",
+                offset, page_size
+            );
+            offset += page_size;
+        }
+    }
+
+    #[tokio::test]
+    async fn test_pagination_with_scan() {
+        let db = connect("memory://test").execute().await.unwrap();
+        let table = db
+            .create_table("test_table", make_non_empty_batches())
+            .execute()
+            .await
+            .unwrap();
+
+        let query = AnyQuery::Query(table.query().into_request());
+        test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
+        test_pagination(table.base_table().as_ref(), query, 10).await;
+    }
+
+    #[tokio::test]
+    async fn test_pagination_with_fts() {
+        let db = connect("memory://test").execute().await.unwrap();
+        let data = fts_test_data(400);
+        let schema = data.schema();
+        let data = RecordBatchIterator::new(vec![Ok(data)], schema);
+        let table = db.create_table("test_table", data).execute().await.unwrap();
+        table
+            .create_index(&["text"], Index::FTS(Default::default()))
+            .execute()
+            .await
+            .unwrap();
+
+        let query = table
+            .query()
+            .full_text_search(FullTextSearchQuery::new("test".into()))
+            .into_request();
+        let query = AnyQuery::Query(query);
+        test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
+        test_pagination(table.base_table().as_ref(), query, 10).await;
+    }
+
+    #[tokio::test]
+    async fn test_pagination_with_vector_query() {
+        let db = connect("memory://test").execute().await.unwrap();
+        let table = db
+            .create_table("test_table", make_non_empty_batches())
+            .execute()
+            .await
+            .unwrap();
+
+        let query_vector = vec![0.1_f32, 0.2, 0.3, 0.4];
+        let query = table
+            .query()
+            .nearest_to(query_vector.as_slice())
+            .unwrap()
+            .limit(50)
+            .into_request();
+        let query = AnyQuery::VectorQuery(query);
+        test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
+        test_pagination(table.base_table().as_ref(), query, 10).await;
+    }
+
     #[tokio::test]
     async fn test_take_offsets() {
         let tmp_dir = tempdir().unwrap();

View File

@@ -401,6 +401,7 @@ pub enum Filter {
 }

 /// A query that can be used to search a LanceDB table
+#[derive(Debug, Clone)]
 pub enum AnyQuery {
     Query(QueryRequest),
     VectorQuery(VectorQueryRequest),
@@ -2387,20 +2388,13 @@ impl BaseTable for NativeTable {
         let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?;
         let is_binary = matches!(element_type, DataType::UInt8);
+        let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0);
         if is_binary {
             let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
             let query_vector = query_vector.as_primitive::<UInt8Type>();
-            scanner.nearest(
-                &column,
-                query_vector,
-                query.base.limit.unwrap_or(DEFAULT_TOP_K),
-            )?;
+            scanner.nearest(&column, query_vector, top_k)?;
         } else {
-            scanner.nearest(
-                &column,
-                query_vector.as_ref(),
-                query.base.limit.unwrap_or(DEFAULT_TOP_K),
-            )?;
+            scanner.nearest(&column, query_vector.as_ref(), top_k)?;
         }

         scanner.minimum_nprobes(query.minimum_nprobes);
         if let Some(maximum_nprobes) = query.maximum_nprobes {
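
If I'm reading this last hunk right, it is the substance of the fix: the scanner is now asked for `limit + offset` nearest neighbors, so that dropping the first `offset` rows still leaves a full page. Plugging in the numbers from the Node.js test above (`limit(2).offset(1)`):

    // limit = 2, offset = 1, as in the Node.js test above.
    // Before: nearest() fetched only the top 2; the offset dropped 1 -> 1 row.
    // After:  nearest() fetches the top 2 + 1 = 3; the offset drops 1 -> 2 rows.
    const limit = 2;
    const offset = 1;
    const topK = limit + offset; // matches top_k in the hunk above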