From 7773bda7ee31f9c433eae52bee73b17ed8908d80 Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Fri, 29 Dec 2023 15:33:03 -0800
Subject: [PATCH] feat(python): first cut batch queries for remote api (#753)

issue separate requests under the hood and concatenate results
---
 python/lancedb/query.py        |  4 +++-
 python/lancedb/remote/table.py | 29 +++++++++++++++++++++++++++--
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/python/lancedb/query.py b/python/lancedb/query.py
index 743602ad..3bdc763b 100644
--- a/python/lancedb/query.py
+++ b/python/lancedb/query.py
@@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
     vector_column: str = VECTOR_COLUMN_NAME
 
     # vector to search for
-    vector: List[float]
+    vector: Union[List[float], List[List[float]]]
 
     # sql filter to refine the query with
     filter: Optional[str] = None
@@ -421,6 +421,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         vector and the returned vectors.
         """
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
+        if isinstance(vector[0], np.ndarray):
+            vector = [v.tolist() for v in vector]
         query = Query(
             vector=vector,
             filter=self._where,
diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py
index 158728fb..e09011a7 100644
--- a/python/lancedb/remote/table.py
+++ b/python/lancedb/remote/table.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import uuid
 from functools import cached_property
 from typing import Dict, Optional, Union
@@ -227,8 +228,24 @@ class RemoteTable(Table):
         return LanceVectorQueryBuilder(self, query, vector_column_name)
 
     def _execute_query(self, query: Query) -> pa.Table:
-        result = self._conn._client.query(self._name, query)
-        return self._conn._loop.run_until_complete(result).to_arrow()
+        if (
+            query.vector is not None
+            and len(query.vector) > 0
+            and not isinstance(query.vector[0], float)
+        ):
+            futures = []
+            for v in query.vector:
+                v = list(v)
+                q = query.copy()
+                q.vector = v
+                futures.append(self._conn._client.query(self._name, q))
+            result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
+            return pa.concat_tables(
+                [add_index(r.to_arrow(), i) for i, r in enumerate(result)]
+            )
+        else:
+            result = self._conn._client.query(self._name, query)
+            return self._conn._loop.run_until_complete(result).to_arrow()
 
     def delete(self, predicate: str):
         """Delete rows from the table.
@@ -342,3 +359,11 @@ class RemoteTable(Table):
         self._conn._loop.run_until_complete(
             self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
         )
+
+
+def add_index(tbl: pa.Table, i: int) -> pa.Table:
+    return tbl.add_column(
+        0,
+        pa.field("query_index", pa.uint32()),
+        pa.array([i] * len(tbl), pa.uint32()),
+    )