feat(python): first cut batch queries for remote api (#753)

Issue separate requests under the hood and concatenate the results.
This commit is contained in:
Chang She
2023-12-29 15:33:03 -08:00
committed by GitHub
parent 392777952f
commit 7773bda7ee
2 changed files with 30 additions and 3 deletions

View File

@@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
vector_column: str = VECTOR_COLUMN_NAME
# vector to search for
vector: List[float]
vector: Union[List[float], List[List[float]]]
# sql filter to refine the query with
filter: Optional[str] = None
@@ -421,6 +421,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
vector and the returned vectors.
"""
vector = self._query if isinstance(self._query, list) else self._query.tolist()
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
query = Query(
vector=vector,
filter=self._where,

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from functools import cached_property
from typing import Dict, Optional, Union
@@ -227,8 +228,24 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:
    """Run ``query`` against the remote table and return the rows as Arrow.

    When ``query.vector`` is a single vector, one request is sent and its
    result is returned as-is. When it holds several vectors, one request
    is issued per vector, the requests are awaited together, and the
    per-vector results are concatenated; ``add_index`` tags each
    per-vector table with the position of the vector that produced it.

    Parameters
    ----------
    query: Query
        The query to execute. ``query.vector`` may be a flat list of
        numbers or a list of such lists.

    Returns
    -------
    pa.Table
        The query result, or the concatenation of all per-vector results.
    """
    vectors = query.vector
    # A flat vector has numeric elements. Ints are accepted as well as
    # floats so an int-valued single vector is not mistaken for a batch
    # (which would fail when trying to iterate over an int below).
    is_batch = (
        vectors is not None
        and len(vectors) > 0
        and not isinstance(vectors[0], (int, float))
    )
    if not is_batch:
        result = self._conn._client.query(self._name, query)
        return self._conn._loop.run_until_complete(result).to_arrow()

    async def _run_batch():
        # Gather from inside a coroutine so the awaitables are scheduled on
        # the loop that is actually driving them; calling asyncio.gather on
        # bare coroutines with no running loop is deprecated.
        pending = []
        for v in vectors:
            q = query.copy()
            q.vector = list(v)
            pending.append(self._conn._client.query(self._name, q))
        return await asyncio.gather(*pending)

    results = self._conn._loop.run_until_complete(_run_batch())
    return pa.concat_tables(
        [add_index(r.to_arrow(), i) for i, r in enumerate(results)]
    )
def delete(self, predicate: str):
"""Delete rows from the table.
@@ -342,3 +359,11 @@ class RemoteTable(Table):
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
    """Return ``tbl`` with a ``query_index`` column inserted at position 0.

    The new column is typed uint32 and holds ``i`` in every row, one value
    per row of ``tbl``.
    """
    index_type = pa.uint32()
    index_field = pa.field("query_index", index_type)
    index_values = pa.array([i] * tbl.num_rows, index_type)
    return tbl.add_column(0, index_field, index_values)