feat(python, rust): expose offset in query (#1556)

PR is part of #1555
This commit is contained in:
Gagan Bhullar
2024-09-05 09:33:07 -06:00
committed by GitHub
parent 2b8e872be0
commit b24810a011
7 changed files with 101 additions and 1 deletions

View File

@@ -73,6 +73,7 @@ class Query:
def where(self, filter: str): ...
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> Query: ...
async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
@@ -83,6 +84,7 @@ class VectorQuery:
def select(self, columns: List[str]): ...
def select_with_projection(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def column(self, column: str): ...
def distance_type(self, distance_type: str): ...
def postfilter(self): ...

View File

@@ -85,6 +85,8 @@ class Query(pydantic.BaseModel):
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
offset: int
The offset to start fetching results from
"""
vector_column: Optional[str] = None
@@ -119,6 +121,8 @@ class Query(pydantic.BaseModel):
with_row_id: bool = False
offset: int = 0
class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search,
@@ -233,6 +237,7 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"):
self._table = table
self._limit = 10
self._offset = 0
self._columns = None
self._where = None
self._prefilter = False
@@ -371,6 +376,25 @@ class LanceQueryBuilder(ABC):
self._limit = limit
return self
def offset(self, offset: int) -> LanceQueryBuilder:
"""Set the offset for the results.
Parameters
----------
offset: int
The offset to start fetching results from.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
if offset is None or offset <= 0:
self._offset = 0
else:
self._offset = offset
return self
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
"""Set the columns to return.
@@ -649,6 +673,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
offset=self._offset,
)
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
@@ -780,6 +805,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"columns": self._fts_columns,
},
vector=[],
offset=self._offset,
)
results = self._table._execute_query(query)
results = results.read_all()
@@ -1220,6 +1246,18 @@ class AsyncQueryBase(object):
self._inner.limit(limit)
return self
def offset(self, offset: int) -> AsyncQuery:
"""
Set the offset for the results.
Parameters
----------
offset: int
The offset to start fetching results from.
"""
self._inner.offset(offset)
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:

View File

@@ -1708,6 +1708,7 @@ class LanceTable(Table):
full_text_query=query.full_text_query,
with_row_id=query.with_row_id,
batch_size=batch_size,
offset=query.offset,
).to_reader()
def _do_merge(

View File

@@ -51,6 +51,7 @@ class MockTable:
"refine_factor": query.refine_factor,
},
batch_size=batch_size,
offset=query.offset,
).to_reader()
@@ -106,6 +107,13 @@ def test_cast(table):
assert r0.float_field == 1.0
def test_offset(table):
results_without_offset = LanceVectorQueryBuilder(table, [0, 0], "vector")
assert len(results_without_offset.to_pandas()) == 2
results_with_offset = LanceVectorQueryBuilder(table, [0, 0], "vector").offset(1)
assert len(results_with_offset.to_pandas()) == 1
def test_query_builder(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -269,7 +277,10 @@ async def test_query_async(table_async: AsyncTable):
table_async.query().select({"foo": "id", "bar": "id + 1"}),
expected_columns=["foo", "bar"],
)
await check_query(table_async.query().limit(1), expected_num_rows=1)
await check_query(table_async.query().offset(1), expected_num_rows=1)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
)