diff --git a/python/lancedb/__init__.py b/python/lancedb/__init__.py index 7c04d865..91ca8f62 100644 --- a/python/lancedb/__init__.py +++ b/python/lancedb/__init__.py @@ -13,8 +13,9 @@ import importlib.metadata import os +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import Optional +from typing import Optional, Union __version__ = importlib.metadata.version("lancedb") @@ -31,6 +32,7 @@ def connect( region: str = "us-east-1", host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, + request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, ) -> DBConnection: """Connect to a LanceDB database. @@ -57,7 +59,14 @@ def connect( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. - + request_thread_pool: int or ThreadPoolExecutor, optional + The thread pool to use for making batch requests to the LanceDB Cloud API. + If an integer, then a ThreadPoolExecutor will be created with that + number of threads. If None, then a ThreadPoolExecutor will be created + with the default number of threads. If a ThreadPoolExecutor, then that + executor will be used for making requests. This is for LanceDB Cloud + only and is only used when making batch requests (i.e., passing in + multiple queries to the search method at once). Examples -------- @@ -85,5 +94,9 @@ def connect( api_key = os.environ.get("LANCEDB_API_KEY") if api_key is None: raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}") - return RemoteDBConnection(uri, api_key, region, host_override) + if isinstance(request_thread_pool, int): + request_thread_pool = ThreadPoolExecutor(request_thread_pool) + return RemoteDBConnection( + uri, api_key, region, host_override, request_thread_pool=request_thread_pool + ) return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 0c0bd46a..3d15a6df 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -14,6 +14,7 @@ import inspect import logging import uuid +from concurrent.futures import ThreadPoolExecutor from typing import Iterable, List, Optional, Union from urllib.parse import urlparse @@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection): api_key: str, region: str, host_override: Optional[str] = None, + request_thread_pool: Optional[ThreadPoolExecutor] = None, ): """Connect to a remote LanceDB database.""" parsed = urlparse(db_url) @@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection): self._client = RestfulLanceDBClient( self.db_name, region, api_key, host_override ) + self._request_thread_pool = request_thread_pool def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index a38e7861..a8766eef 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -13,6 +13,7 @@ import logging import uuid +from concurrent.futures import Future from functools import cached_property from typing import Dict, Optional, Union @@ -270,15 +271,28 @@ class RemoteTable(Table): and len(query.vector) > 0 and not isinstance(query.vector[0], float) ): + if self._conn._request_thread_pool is None: + + def submit(name, q): + f = Future() + f.set_result(self._conn._client.query(name, q)) + return f + else: + + def submit(name, q): + return self._conn._request_thread_pool.submit( + self._conn._client.query, name, q + ) + results = [] for v in query.vector: v = list(v) q = query.copy() q.vector = v - results.append(self._conn._client.query(self._name, q)) + results.append(submit(self._name, q)) return pa.concat_tables( - [add_index(r.to_arrow(), i) for i, r in enumerate(results)] + [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)] ) else: result = self._conn._client.query(self._name, query) diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 3c29ed35..9282a5c6 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -803,10 +803,8 @@ def test_count_rows(db): assert table.count_rows(filter="text='bar'") == 1 -def test_hybrid_search(db): - # hardcoding temporarily.. this test is failing with tmp_path mockdb. - # Probably not being parsed right by the fts - db = MockDB("~/lancedb_") +def test_hybrid_search(db, tmp_path): + db = MockDB(str(tmp_path)) # Create a LanceDB table schema with a vector and a text column emb = EmbeddingFunctionRegistry.get_instance().get("test")()