feat(python, rust): expose offset in query (#1556)

PR is part of #1555
This commit is contained in:
Gagan Bhullar
2024-09-05 09:33:07 -06:00
committed by GitHub
parent 2b8e872be0
commit b24810a011
7 changed files with 101 additions and 1 deletions

View File

@@ -73,6 +73,7 @@ class Query:
def where(self, filter: str): ...
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> Query: ...
# NOTE(review): parameter spelling corrected from `max_batch_legnth`; confirm it matches the name used by the native binding.
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
@@ -83,6 +84,7 @@ class VectorQuery:
def select(self, columns: List[str]): ...
def select_with_projection(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def column(self, column: str): ...
def distance_type(self, distance_type: str): ...
def postfilter(self): ...

View File

@@ -85,6 +85,8 @@ class Query(pydantic.BaseModel):
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
offset: int
The offset to start fetching results from
"""
vector_column: Optional[str] = None
@@ -119,6 +121,8 @@ class Query(pydantic.BaseModel):
with_row_id: bool = False
offset: int = 0
class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search,
@@ -233,6 +237,7 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"):
self._table = table
self._limit = 10
self._offset = 0
self._columns = None
self._where = None
self._prefilter = False
@@ -371,6 +376,25 @@ class LanceQueryBuilder(ABC):
self._limit = limit
return self
def offset(self, offset: int) -> LanceQueryBuilder:
    """Set the offset for the results.

    Parameters
    ----------
    offset: int
        The offset to start fetching results from. ``None``, zero and
        negative values are all stored as 0.

    Returns
    -------
    LanceQueryBuilder
        The LanceQueryBuilder object.
    """
    # Normalise missing or non-positive input to 0 so the stored offset is
    # never None or negative.
    self._offset = 0 if offset is None or offset <= 0 else offset
    return self
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
"""Set the columns to return.
@@ -649,6 +673,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
offset=self._offset,
)
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
@@ -780,6 +805,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"columns": self._fts_columns,
},
vector=[],
offset=self._offset,
)
results = self._table._execute_query(query)
results = results.read_all()
@@ -1220,6 +1246,18 @@ class AsyncQueryBase(object):
self._inner.limit(limit)
return self
def offset(self, offset: int) -> AsyncQuery:
    """
    Set the offset for the results.

    Parameters
    ----------
    offset: int
        The offset to start fetching results from.

    Returns
    -------
    AsyncQuery
        This query object, so that calls can be chained.
    """
    # The value is forwarded unchanged to the wrapped query object.
    self._inner.offset(offset)
    return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:

View File

@@ -1708,6 +1708,7 @@ class LanceTable(Table):
full_text_query=query.full_text_query,
with_row_id=query.with_row_id,
batch_size=batch_size,
offset=query.offset,
).to_reader()
def _do_merge(

View File

@@ -51,6 +51,7 @@ class MockTable:
"refine_factor": query.refine_factor,
},
batch_size=batch_size,
offset=query.offset,
).to_reader()
@@ -106,6 +107,13 @@ def test_cast(table):
assert r0.float_field == 1.0
def test_offset(table):
    """An offset of 1 on a vector query yields one row fewer than no offset."""
    # Without an offset the query over the fixture table returns two rows.
    query_without_offset = LanceVectorQueryBuilder(table, [0, 0], "vector")
    assert len(query_without_offset.to_pandas()) == 2

    # With an offset of 1 a single row is returned.
    query_with_offset = LanceVectorQueryBuilder(table, [0, 0], "vector").offset(1)
    assert len(query_with_offset.to_pandas()) == 1
def test_query_builder(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -269,7 +277,10 @@ async def test_query_async(table_async: AsyncTable):
table_async.query().select({"foo": "id", "bar": "id + 1"}),
expected_columns=["foo", "bar"],
)
await check_query(table_async.query().limit(1), expected_num_rows=1)
await check_query(table_async.query().offset(1), expected_num_rows=1)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
)

View File

@@ -64,6 +64,10 @@ impl Query {
self.inner = self.inner.clone().limit(limit as usize);
}
/// Set the query offset by forwarding it to the wrapped query.
pub fn offset(&mut self, offset: u32) {
    // The wrapped builder takes a `usize`, so convert before forwarding.
    let rows_to_skip = offset as usize;
    // The builder method consumes its receiver, hence the clone-and-replace.
    self.inner = self.inner.clone().offset(rows_to_skip);
}
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
@@ -138,6 +142,10 @@ impl VectorQuery {
self.inner = self.inner.clone().limit(limit as usize);
}
/// Set the vector query offset by forwarding it to the wrapped query.
pub fn offset(&mut self, offset: u32) {
    // The wrapped builder takes a `usize`, so convert before forwarding.
    let rows_to_skip = offset as usize;
    // The builder method consumes its receiver, hence the clone-and-replace.
    self.inner = self.inner.clone().offset(rows_to_skip);
}
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}

View File

@@ -338,6 +338,12 @@ pub trait QueryBase {
/// it will default to 10.
fn limit(self, limit: usize) -> Self;
/// Set the offset of the query.
///
/// By default, the query returns rows starting with the first one.
/// Use this method to skip the first `offset` rows of the result.
fn offset(self, offset: usize) -> Self;
/// Only return rows which match the filter.
///
/// The filter should be supplied as an SQL query string. For example:
@@ -408,6 +414,11 @@ impl<T: HasQuery> QueryBase for T {
self
}
/// Record `offset` on the underlying query and hand the builder back.
fn offset(mut self, offset: usize) -> Self {
    let query = self.mut_query();
    query.offset = Some(offset);
    self
}
fn only_if(mut self, filter: impl AsRef<str>) -> Self {
self.mut_query().filter = Some(filter.as_ref().to_string());
self
@@ -520,6 +531,9 @@ pub struct Query {
/// limit the number of rows to return.
pub(crate) limit: Option<usize>,
/// Offset of the query.
pub(crate) offset: Option<usize>,
/// Apply filter to the returned rows.
pub(crate) filter: Option<String>,
@@ -541,6 +555,7 @@ impl Query {
Self {
parent,
limit: None,
offset: None,
filter: None,
full_text_search: None,
select: Select::All,
@@ -858,6 +873,7 @@ mod tests {
let query = table
.query()
.limit(100)
.offset(1)
.nearest_to(&[9.8, 8.7])
.unwrap()
.nprobes(1000)
@@ -870,6 +886,7 @@ mod tests {
new_vector
);
assert_eq!(query.base.limit.unwrap(), 100);
assert_eq!(query.base.offset.unwrap(), 1);
assert_eq!(query.nprobes, 1000);
assert!(query.use_index);
assert_eq!(query.distance_type, Some(DistanceType::Cosine));
@@ -916,10 +933,26 @@ mod tests {
let result = query.execute().await;
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
// pre filter should return 10 rows
assert!(batch.expect("should be Ok").num_rows() == 10);
}
let query = table
.query()
.limit(10)
.offset(1)
.only_if(String::from("id % 2 == 0"))
.nearest_to(&[0.1; 4])
.unwrap();
let result = query.execute().await;
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
// with an offset of 1, pre filter should return 9 rows (limit of 10 minus the skipped row)
assert!(batch.expect("should be Ok").num_rows() == 9);
}
}
#[tokio::test]

View File

@@ -1852,9 +1852,16 @@ impl TableInternal for NativeTable {
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
}
scanner.nprobs(query.nprobes);