mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
feat(python): add optional threadpool for batch requests (#981)
Currently if a batch request is given to the remote API, each query is sent sequentially. We should allow the user to specify a threadpool.
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -31,6 +32,7 @@ def connect(
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -57,7 +59,14 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -85,5 +94,9 @@ def connect(
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||
if isinstance(request_thread_pool, int):
|
||||
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
|
||||
return RemoteDBConnection(
|
||||
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||
)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection):
|
||||
api_key: str,
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
@@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection):
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
@@ -270,15 +271,28 @@ class RemoteTable(Table):
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
if self._conn._request_thread_pool is None:
|
||||
|
||||
def submit(name, q):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
return self._conn._request_thread_pool.submit(
|
||||
self._conn._client.query, name, q
|
||||
)
|
||||
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(self._conn._client.query(self._name, q))
|
||||
results.append(submit(self._name, q))
|
||||
|
||||
return pa.concat_tables(
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
)
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
|
||||
@@ -803,10 +803,8 @@ def test_count_rows(db):
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db):
|
||||
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
|
||||
# Probably not being parsed right by the fts
|
||||
db = MockDB("~/lancedb_")
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user