feat(python): add optional threadpool for batch requests (#981)

Currently if a batch request is given to the remote API, each query is
sent sequentially. We should allow the user to specify a threadpool.
This commit is contained in:
Chang She
2024-02-16 20:22:22 -08:00
committed by Weston Pace
parent 26eec4bef4
commit bc850e6add
4 changed files with 37 additions and 9 deletions

View File

@@ -13,8 +13,9 @@
import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional
from typing import Optional, Union
__version__ = importlib.metadata.version("lancedb")
@@ -31,6 +32,7 @@ def connect(
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -57,7 +59,14 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional
The thread pool to use for making batch requests to the LanceDB Cloud API.
If an integer, then a ThreadPoolExecutor will be created with that
number of threads. If None, then a ThreadPoolExecutor will be created
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
Examples
--------
@@ -85,5 +94,9 @@ def connect(
api_key = os.environ.get("LANCEDB_API_KEY")
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override)
if isinstance(request_thread_pool, int):
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
return RemoteDBConnection(
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

View File

@@ -14,6 +14,7 @@
import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
@@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection):
api_key: str,
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
@@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
)
self._request_thread_pool = request_thread_pool
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"

View File

@@ -13,6 +13,7 @@
import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Optional, Union
@@ -270,15 +271,28 @@ class RemoteTable(Table):
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
)
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(self._conn._client.query(self._name, q))
results.append(submit(self._name, q))
return pa.concat_tables(
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
)
else:
result = self._conn._client.query(self._name, query)

View File

@@ -803,10 +803,8 @@ def test_count_rows(db):
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db):
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
# Probably not being parsed right by the fts
db = MockDB("~/lancedb_")
def test_hybrid_search(db, tmp_path):
db = MockDB(str(tmp_path))
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()