feat: search multiple query vectors as one query (#1811)

Allows users to pass multiple query vector as part of a single query
plan. This just runs the queries in parallel without any further
optimization. It's mostly a convenience.

Previously, I think this was only handled by the sync Python remote API.
This makes it common across all SDKs.

Closes https://github.com/lancedb/lancedb/issues/1803

```python
>>> import lancedb
>>> import asyncio
>>> 
>>> async def main():
...     db = await lancedb.connect_async("./demo")
...     table = await db.create_table("demo", [{"id": 1, "vector": [1, 2, 3]}, {"id": 2, "vector": [4, 5, 6]}], mode="overwrite")
...     return await table.query().nearest_to([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]).limit(1).to_pandas()
... 
>>> asyncio.run(main())
   query_index  id           vector  _distance
0            2   2  [4.0, 5.0, 6.0]        0.0
1            1   2  [4.0, 5.0, 6.0]        0.0
2            0   1  [1.0, 2.0, 3.0]        0.0
```
This commit is contained in:
Will Jones
2024-11-13 16:05:16 -08:00
committed by GitHub
parent 0fd8a50bd7
commit abd75e0ead
9 changed files with 366 additions and 64 deletions

View File

@@ -1491,7 +1491,7 @@ class AsyncQuery(AsyncQueryBase):
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple]] = None
self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
@@ -1529,10 +1529,30 @@ class AsyncQuery(AsyncQueryBase):
Vector searches always have a [limit][]. If `limit` has not been called then
a default `limit` of 10 will be used.
Typically, a single vector is passed in as the query. However, you can also
pass in multiple vectors. This can be useful if you want to find the nearest
vectors to multiple query vectors. This is not expected to be faster than
making multiple queries concurrently; it is just a convenience method.
If multiple vectors are passed in then an additional column `query_index`
will be added to the results. This column will contain the index of the
query vector that the result is nearest to.
"""
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
if (
isinstance(query_vector, list)
and len(query_vector) > 0
and not isinstance(query_vector[0], (float, int))
):
# multiple have been passed
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
new_self = self._inner.nearest_to(query_vectors[0])
for v in query_vectors[1:]:
new_self.add_query_vector(v)
return AsyncVectorQuery(new_self)
else:
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []

View File

@@ -229,6 +229,17 @@ def test_query_sync_maximal():
)
def test_query_sync_multiple_vectors():
def handler(_body):
return pa.table({"id": [1]})
with query_test_table(handler) as table:
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
assert len(results) == 2
results.sort(key=lambda x: x["query_index"])
assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]
def test_query_sync_fts():
def handler(body):
assert body == {

View File

@@ -142,6 +142,13 @@ impl VectorQuery {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
self.inner = self.inner.clone().add_query_vector(array).infer_error()?;
Ok(())
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}