diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index c4642637..55c3db99 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -73,6 +73,7 @@ class Query: def where(self, filter: str): ... def select(self, columns: Tuple[str, str]): ... def limit(self, limit: int): ... + def offset(self, offset: int): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> Query: ... async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ... @@ -83,6 +84,7 @@ class VectorQuery: def select(self, columns: List[str]): ... def select_with_projection(self, columns: Tuple[str, str]): ... def limit(self, limit: int): ... + def offset(self, offset: int): ... def column(self, column: str): ... def distance_type(self, distance_type: str): ... def postfilter(self): ... diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 9da90987..9c9c69ae 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -85,6 +85,8 @@ class Query(pydantic.BaseModel): - See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. + offset: int + The offset to start fetching results from """ vector_column: Optional[str] = None @@ -119,6 +121,8 @@ class Query(pydantic.BaseModel): with_row_id: bool = False + offset: int = 0 + class LanceQueryBuilder(ABC): """An abstract query builder. Subclasses are defined for vector search, @@ -233,6 +237,7 @@ class LanceQueryBuilder(ABC): def __init__(self, table: "Table"): self._table = table self._limit = 10 + self._offset = 0 self._columns = None self._where = None self._prefilter = False @@ -371,6 +376,25 @@ class LanceQueryBuilder(ABC): self._limit = limit return self + def offset(self, offset: int) -> LanceQueryBuilder: + """Set the offset for the results. + + Parameters + ---------- + offset: int + The offset to start fetching results from. 
+ + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + if offset is None or offset <= 0: + self._offset = 0 + else: + self._offset = offset + return self + def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder: """Set the columns to return. @@ -649,6 +673,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): refine_factor=self._refine_factor, vector_column=self._vector_column, with_row_id=self._with_row_id, + offset=self._offset, ) result_set = self._table._execute_query(query, batch_size) if self._reranker is not None: @@ -780,6 +805,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): "columns": self._fts_columns, }, vector=[], + offset=self._offset, ) results = self._table._execute_query(query) results = results.read_all() @@ -1220,6 +1246,18 @@ class AsyncQueryBase(object): self._inner.limit(limit) return self + def offset(self, offset: int) -> AsyncQuery: + """ + Set the offset for the results. + + Parameters + ---------- + offset: int + The offset to start fetching results from. 
+ """ + self._inner.offset(offset) + return self + async def to_batches( self, *, max_batch_length: Optional[int] = None ) -> AsyncRecordBatchReader: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 46df91c2..7d3ebaa0 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1708,6 +1708,7 @@ class LanceTable(Table): full_text_query=query.full_text_query, with_row_id=query.with_row_id, batch_size=batch_size, + offset=query.offset, ).to_reader() def _do_merge( diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index ae50c991..11750e4d 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -51,6 +51,7 @@ class MockTable: "refine_factor": query.refine_factor, }, batch_size=batch_size, + offset=query.offset, ).to_reader() @@ -106,6 +107,13 @@ def test_cast(table): assert r0.float_field == 1.0 +def test_offset(table): + results_without_offset = LanceVectorQueryBuilder(table, [0, 0], "vector") + assert len(results_without_offset.to_pandas()) == 2 + results_with_offset = LanceVectorQueryBuilder(table, [0, 0], "vector").offset(1) + assert len(results_with_offset.to_pandas()) == 1 + + def test_query_builder(table): rs = ( LanceVectorQueryBuilder(table, [0, 0], "vector") @@ -269,7 +277,10 @@ async def test_query_async(table_async: AsyncTable): table_async.query().select({"foo": "id", "bar": "id + 1"}), expected_columns=["foo", "bar"], ) + await check_query(table_async.query().limit(1), expected_num_rows=1) + await check_query(table_async.query().offset(1), expected_num_rows=1) + await check_query( table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2 ) diff --git a/python/src/query.rs b/python/src/query.rs index f88e60b4..42bd4a13 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -64,6 +64,10 @@ impl Query { self.inner = self.inner.clone().limit(limit as usize); } + pub fn offset(&mut self, offset: u32) { + self.inner = 
 self.inner.clone().offset(offset as usize); + } + pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult { let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?; let array = make_array(data); @@ -138,6 +142,10 @@ impl VectorQuery { self.inner = self.inner.clone().limit(limit as usize); } + pub fn offset(&mut self, offset: u32) { + self.inner = self.inner.clone().offset(offset as usize); + } + pub fn column(&mut self, column: String) { self.inner = self.inner.clone().column(&column); } diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 714200ae..d2895668 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -338,6 +338,12 @@ pub trait QueryBase { /// it will default to 10. fn limit(self, limit: usize) -> Self; + /// Set the offset of the query. + /// + /// By default, it fetches starting with the first row. + /// This method can be used to skip the first `offset` rows. + fn offset(self, offset: usize) -> Self; + /// Only return rows which match the filter. /// /// The filter should be supplied as an SQL query string. For example: @@ -408,6 +414,11 @@ impl QueryBase for T { self } + fn offset(mut self, offset: usize) -> Self { + self.mut_query().offset = Some(offset); + self + } + fn only_if(mut self, filter: impl AsRef) -> Self { self.mut_query().filter = Some(filter.as_ref().to_string()); self @@ -520,6 +531,9 @@ pub struct Query { /// limit the number of rows to return. pub(crate) limit: Option, + /// Offset of the query. + pub(crate) offset: Option, + /// Apply filter to the returned rows. 
 pub(crate) filter: Option, @@ -541,6 +555,7 @@ impl Query { Self { parent, limit: None, + offset: None, filter: None, full_text_search: None, select: Select::All, @@ -858,6 +873,7 @@ mod tests { let query = table .query() .limit(100) + .offset(1) .nearest_to(&[9.8, 8.7]) .unwrap() .nprobes(1000) @@ -870,6 +886,7 @@ mod tests { new_vector ); assert_eq!(query.base.limit.unwrap(), 100); + assert_eq!(query.base.offset.unwrap(), 1); assert_eq!(query.nprobes, 1000); assert!(query.use_index); assert_eq!(query.distance_type, Some(DistanceType::Cosine)); @@ -916,10 +933,26 @@ mod tests { let result = query.execute().await; let mut stream = result.expect("should have result"); // should only have one batch + while let Some(batch) = stream.next().await { // pre filter should return 10 rows assert!(batch.expect("should be Ok").num_rows() == 10); } + + let query = table + .query() + .limit(10) + .offset(1) + .only_if(String::from("id % 2 == 0")) + .nearest_to(&[0.1; 4]) + .unwrap(); + let result = query.execute().await; + let mut stream = result.expect("should have result"); + // should only have one batch + while let Some(batch) = stream.next().await { + // with limit 10 and offset 1, pre filter should return 9 rows + assert!(batch.expect("should be Ok").num_rows() == 9); + } } #[tokio::test] diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 12254819..f1942f0e 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1852,9 +1852,16 @@ impl TableInternal for NativeTable { query_vector, query.base.limit.unwrap_or(DEFAULT_TOP_K), )?; + scanner.limit( + query.base.limit.map(|limit| limit as i64), + query.base.offset.map(|offset| offset as i64), + )?; } else { // If there is no vector query, it's ok to not have a limit - scanner.limit(query.base.limit.map(|limit| limit as i64), None)?; + scanner.limit( + query.base.limit.map(|limit| limit as i64), + query.base.offset.map(|offset| offset as i64), + )?; } scanner.nprobs(query.nprobes);