From dcf53c45065bc54a00d05d96d603afd7c972ac22 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 15 Aug 2025 08:55:12 -0700 Subject: [PATCH] fix: limit and offset support paginating through FTS and vector search results (#2592) Adds tests to ensure that users can paginate through simple scan, FTS, and vector search results using `limit` and `offset`. Tests upstream work: https://github.com/lancedb/lance/pull/4318 Closes #2459 --- nodejs/__test__/table.test.ts | 4 +- rust/lancedb/src/query.rs | 126 +++++++++++++++++++++++++++++++++- rust/lancedb/src/table.rs | 14 ++-- 3 files changed, 129 insertions(+), 15 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 7717be4d..b49e2ce8 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -563,7 +563,7 @@ describe("When creating an index", () => { // test offset rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow(); - expect(rst.numRows).toBe(1); + expect(rst.numRows).toBe(2); // test nprobes rst = await tbl.query().nearestTo(queryVec).limit(2).nprobes(50).toArrow(); @@ -702,7 +702,7 @@ describe("When creating an index", () => { // test offset rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow(); - expect(rst.numRows).toBe(1); + expect(rst.numRows).toBe(2); // test ef rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow(); diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index ce8857c0..9663379d 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -1357,9 +1357,10 @@ mod tests { use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use futures::{StreamExt, TryStreamExt}; use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; + use rand::seq::IndexedRandom; use tempfile::tempdir; - use crate::{connect, database::CreateTableMode, Table}; + use crate::{connect, database::CreateTableMode, index::Index, Table}; #[tokio::test] async fn test_setters_getters() { @@ -1465,7 +1466,7 @@ mod tests { while let Some(batch) = stream.next().await { // pre filter should return 10 rows - assert!(batch.expect("should be Ok").num_rows() == 10); + assert_eq!(batch.expect("should be Ok").num_rows(), 10); } let query = table @@ -1480,7 +1481,7 @@ mod tests { // should only have one batch while let Some(batch) = stream.next().await { // pre filter should return 10 rows - assert!(batch.expect("should be Ok").num_rows() == 9); + assert_eq!(batch.expect("should be Ok").num_rows(), 10); } } @@ -1941,6 +1942,125 @@ mod tests { assert_eq!(2, batch.num_columns()); } + // TODO: Implement a good FTS test data generator in lance_datagen. + fn fts_test_data(nrows: usize) -> RecordBatch { + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("text", DataType::Utf8, false), + ArrowField::new("id", DataType::Int32, false), + ])); + + let ids: Int32Array = (1..=nrows as i32).collect(); + + // Sample 1 - 3 tokens for each string value + let tokens = ["a", "b", "c", "d", "e"]; + use rand::{rng, Rng}; + + let mut rng = rng(); + let text: StringArray = (0..nrows) + .map(|_| { + let num_tokens = rng.random_range(1..=3); // 1 to 3 tokens + let selected_tokens: Vec<&str> = tokens + .choose_multiple(&mut rng, num_tokens) + .cloned() + .collect(); + Some(selected_tokens.join(" ")) + }) + .collect(); + + RecordBatch::try_new(schema, vec![Arc::new(text), Arc::new(ids)]).unwrap() + } + + async fn run_query_request(table: &dyn BaseTable, query: AnyQuery) -> RecordBatch { + use lance::io::RecordBatchStream; + let stream = table.query(&query, Default::default()).await.unwrap(); + let schema = stream.schema(); + let batches = stream.try_collect::>().await.unwrap(); + arrow::compute::concat_batches(&schema, &batches).unwrap() + } + + async fn test_pagination(table: &dyn BaseTable, full_query: AnyQuery, page_size: usize) { + // Get full results + let full_results = run_query_request(table, full_query.clone()).await; + + // Then use limit & offset to do paginated queries, assert each + // is the same as a slice of the full results + let mut offset = 0; + while offset < full_results.num_rows() { + let mut paginated_query = full_query.clone(); + let limit = page_size.min(full_results.num_rows() - offset); + match &mut paginated_query { + AnyQuery::Query(query) + | AnyQuery::VectorQuery(VectorQueryRequest { base: query, .. }) => { + query.limit = Some(limit); + query.offset = Some(offset); + } + } + let paginated_results = run_query_request(table, paginated_query).await; + let expected_slice = full_results.slice(offset, limit); + assert_eq!( + paginated_results, expected_slice, + "Paginated results do not match expected slice at offset {}, for page size {}", + offset, page_size + ); + offset += page_size; + } + } + + #[tokio::test] + async fn test_pagination_with_scan() { + let db = connect("memory://test").execute().await.unwrap(); + let table = db + .create_table("test_table", make_non_empty_batches()) + .execute() + .await + .unwrap(); + let query = AnyQuery::Query(table.query().into_request()); + test_pagination(table.base_table().as_ref(), query.clone(), 3).await; + test_pagination(table.base_table().as_ref(), query, 10).await; + } + + #[tokio::test] + async fn test_pagination_with_fts() { + let db = connect("memory://test").execute().await.unwrap(); + let data = fts_test_data(400); + let schema = data.schema(); + let data = RecordBatchIterator::new(vec![Ok(data)], schema); + let table = db.create_table("test_table", data).execute().await.unwrap(); + + table + .create_index(&["text"], Index::FTS(Default::default())) + .execute() + .await + .unwrap(); + let query = table + .query() + .full_text_search(FullTextSearchQuery::new("test".into())) + .into_request(); + let query = AnyQuery::Query(query); + test_pagination(table.base_table().as_ref(), query.clone(), 3).await; + test_pagination(table.base_table().as_ref(), query, 10).await; + } + + #[tokio::test] + async fn test_pagination_with_vector_query() { + let db = connect("memory://test").execute().await.unwrap(); + let table = db + .create_table("test_table", make_non_empty_batches()) + .execute() + .await + .unwrap(); + let query_vector = vec![0.1_f32, 0.2, 0.3, 0.4]; + let query = table + .query() + .nearest_to(query_vector.as_slice()) + .unwrap() + .limit(50) + .into_request(); + let query = AnyQuery::VectorQuery(query); + test_pagination(table.base_table().as_ref(), query.clone(), 3).await; + test_pagination(table.base_table().as_ref(), query, 10).await; + } + #[tokio::test] async fn test_take_offsets() { let tmp_dir = tempdir().unwrap(); diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index cee85932..8c700110 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -401,6 +401,7 @@ pub enum Filter { } /// A query that can be used to search a LanceDB table +#[derive(Debug, Clone)] pub enum AnyQuery { Query(QueryRequest), VectorQuery(VectorQueryRequest), @@ -2387,20 +2388,13 @@ impl BaseTable for NativeTable { let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?; let is_binary = matches!(element_type, DataType::UInt8); + let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0); if is_binary { let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?; let query_vector = query_vector.as_primitive::(); - scanner.nearest( - &column, - query_vector, - query.base.limit.unwrap_or(DEFAULT_TOP_K), - )?; + scanner.nearest(&column, query_vector, top_k)?; } else { - scanner.nearest( - &column, - query_vector.as_ref(), - query.base.limit.unwrap_or(DEFAULT_TOP_K), - )?; + scanner.nearest(&column, query_vector.as_ref(), top_k)?; } scanner.minimum_nprobes(query.minimum_nprobes); if let Some(maximum_nprobes) = query.maximum_nprobes {