mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
fix: limit and offset support paginating through FTS and vector search results (#2592)
Adds tests to ensure that users can paginate through simple scan, FTS, and vector search results using `limit` and `offset`. Tests upstream work: https://github.com/lancedb/lance/pull/4318 Closes #2459
This commit is contained in:
@@ -563,7 +563,7 @@ describe("When creating an index", () => {
|
||||
|
||||
// test offset
|
||||
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
|
||||
expect(rst.numRows).toBe(1);
|
||||
expect(rst.numRows).toBe(2);
|
||||
|
||||
// test nprobes
|
||||
rst = await tbl.query().nearestTo(queryVec).limit(2).nprobes(50).toArrow();
|
||||
@@ -702,7 +702,7 @@ describe("When creating an index", () => {
|
||||
|
||||
// test offset
|
||||
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
|
||||
expect(rst.numRows).toBe(1);
|
||||
expect(rst.numRows).toBe(2);
|
||||
|
||||
// test ef
|
||||
rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow();
|
||||
|
||||
@@ -1357,9 +1357,10 @@ mod tests {
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
|
||||
use rand::seq::IndexedRandom;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::{connect, database::CreateTableMode, Table};
|
||||
use crate::{connect, database::CreateTableMode, index::Index, Table};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setters_getters() {
|
||||
@@ -1465,7 +1466,7 @@ mod tests {
|
||||
|
||||
while let Some(batch) = stream.next().await {
|
||||
// pre filter should return 10 rows
|
||||
assert!(batch.expect("should be Ok").num_rows() == 10);
|
||||
assert_eq!(batch.expect("should be Ok").num_rows(), 10);
|
||||
}
|
||||
|
||||
let query = table
|
||||
@@ -1480,7 +1481,7 @@ mod tests {
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
// pre filter should return 10 rows
|
||||
assert!(batch.expect("should be Ok").num_rows() == 9);
|
||||
assert_eq!(batch.expect("should be Ok").num_rows(), 10);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1941,6 +1942,125 @@ mod tests {
|
||||
assert_eq!(2, batch.num_columns());
|
||||
}
|
||||
|
||||
// TODO: Implement a good FTS test data generator in lance_datagen.
|
||||
fn fts_test_data(nrows: usize) -> RecordBatch {
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("text", DataType::Utf8, false),
|
||||
ArrowField::new("id", DataType::Int32, false),
|
||||
]));
|
||||
|
||||
let ids: Int32Array = (1..=nrows as i32).collect();
|
||||
|
||||
// Sample 1 - 3 tokens for each string value
|
||||
let tokens = ["a", "b", "c", "d", "e"];
|
||||
use rand::{rng, Rng};
|
||||
|
||||
let mut rng = rng();
|
||||
let text: StringArray = (0..nrows)
|
||||
.map(|_| {
|
||||
let num_tokens = rng.random_range(1..=3); // 1 to 3 tokens
|
||||
let selected_tokens: Vec<&str> = tokens
|
||||
.choose_multiple(&mut rng, num_tokens)
|
||||
.cloned()
|
||||
.collect();
|
||||
Some(selected_tokens.join(" "))
|
||||
})
|
||||
.collect();
|
||||
|
||||
RecordBatch::try_new(schema, vec![Arc::new(text), Arc::new(ids)]).unwrap()
|
||||
}
|
||||
|
||||
async fn run_query_request(table: &dyn BaseTable, query: AnyQuery) -> RecordBatch {
|
||||
use lance::io::RecordBatchStream;
|
||||
let stream = table.query(&query, Default::default()).await.unwrap();
|
||||
let schema = stream.schema();
|
||||
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
|
||||
arrow::compute::concat_batches(&schema, &batches).unwrap()
|
||||
}
|
||||
|
||||
async fn test_pagination(table: &dyn BaseTable, full_query: AnyQuery, page_size: usize) {
|
||||
// Get full results
|
||||
let full_results = run_query_request(table, full_query.clone()).await;
|
||||
|
||||
// Then use limit & offset to do paginated queries, assert each
|
||||
// is the same as a slice of the full results
|
||||
let mut offset = 0;
|
||||
while offset < full_results.num_rows() {
|
||||
let mut paginated_query = full_query.clone();
|
||||
let limit = page_size.min(full_results.num_rows() - offset);
|
||||
match &mut paginated_query {
|
||||
AnyQuery::Query(query)
|
||||
| AnyQuery::VectorQuery(VectorQueryRequest { base: query, .. }) => {
|
||||
query.limit = Some(limit);
|
||||
query.offset = Some(offset);
|
||||
}
|
||||
}
|
||||
let paginated_results = run_query_request(table, paginated_query).await;
|
||||
let expected_slice = full_results.slice(offset, limit);
|
||||
assert_eq!(
|
||||
paginated_results, expected_slice,
|
||||
"Paginated results do not match expected slice at offset {}, for page size {}",
|
||||
offset, page_size
|
||||
);
|
||||
offset += page_size;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pagination_with_scan() {
|
||||
let db = connect("memory://test").execute().await.unwrap();
|
||||
let table = db
|
||||
.create_table("test_table", make_non_empty_batches())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let query = AnyQuery::Query(table.query().into_request());
|
||||
test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
|
||||
test_pagination(table.base_table().as_ref(), query, 10).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pagination_with_fts() {
|
||||
let db = connect("memory://test").execute().await.unwrap();
|
||||
let data = fts_test_data(400);
|
||||
let schema = data.schema();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data)], schema);
|
||||
let table = db.create_table("test_table", data).execute().await.unwrap();
|
||||
|
||||
table
|
||||
.create_index(&["text"], Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let query = table
|
||||
.query()
|
||||
.full_text_search(FullTextSearchQuery::new("test".into()))
|
||||
.into_request();
|
||||
let query = AnyQuery::Query(query);
|
||||
test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
|
||||
test_pagination(table.base_table().as_ref(), query, 10).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pagination_with_vector_query() {
|
||||
let db = connect("memory://test").execute().await.unwrap();
|
||||
let table = db
|
||||
.create_table("test_table", make_non_empty_batches())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let query_vector = vec![0.1_f32, 0.2, 0.3, 0.4];
|
||||
let query = table
|
||||
.query()
|
||||
.nearest_to(query_vector.as_slice())
|
||||
.unwrap()
|
||||
.limit(50)
|
||||
.into_request();
|
||||
let query = AnyQuery::VectorQuery(query);
|
||||
test_pagination(table.base_table().as_ref(), query.clone(), 3).await;
|
||||
test_pagination(table.base_table().as_ref(), query, 10).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -401,6 +401,7 @@ pub enum Filter {
|
||||
}
|
||||
|
||||
/// A query that can be used to search a LanceDB table
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AnyQuery {
|
||||
Query(QueryRequest),
|
||||
VectorQuery(VectorQueryRequest),
|
||||
@@ -2387,20 +2388,13 @@ impl BaseTable for NativeTable {
|
||||
|
||||
let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?;
|
||||
let is_binary = matches!(element_type, DataType::UInt8);
|
||||
let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0);
|
||||
if is_binary {
|
||||
let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
|
||||
let query_vector = query_vector.as_primitive::<UInt8Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
scanner.nearest(&column, query_vector, top_k)?;
|
||||
} else {
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector.as_ref(),
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
scanner.nearest(&column, query_vector.as_ref(), top_k)?;
|
||||
}
|
||||
scanner.minimum_nprobes(query.minimum_nprobes);
|
||||
if let Some(maximum_nprobes) = query.maximum_nprobes {
|
||||
|
||||
Reference in New Issue
Block a user