feat: configurable timeout for LanceDB Cloud queries (#1090)

This commit is contained in:
Rob Meng
2024-03-11 16:15:48 -04:00
committed by Weston Pace
parent 89ce417452
commit 35bc4f3078
3 changed files with 20 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ def connect(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
**kwargs,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -98,7 +99,12 @@ def connect(
if isinstance(request_thread_pool, int):
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
return RemoteDBConnection(
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
uri,
api_key,
region,
host_override,
request_thread_pool=request_thread_pool,
**kwargs,
)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

View File

@@ -58,6 +58,9 @@ class RestfulLanceDBClient:
closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
@functools.cached_property
def session(self) -> requests.Session:
sess = requests.Session()
@@ -117,7 +120,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri),
params=params,
headers=self.headers,
timeout=(120.0, 300.0),
timeout=(self.connection_timeout, self.read_timeout),
) as resp:
self._check_status(resp)
return resp.json()
@@ -159,7 +162,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri),
headers=headers,
params=params,
timeout=(120.0, 300.0),
timeout=(self.connection_timeout, self.read_timeout),
**req_kwargs,
) as resp:
self._check_status(resp)

View File

@@ -41,6 +41,8 @@ class RemoteDBConnection(DBConnection):
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0,
read_timeout: float = 300.0,
):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
@@ -49,7 +51,12 @@ class RemoteDBConnection(DBConnection):
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
self.db_name,
region,
api_key,
host_override,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
)
self._request_thread_pool = request_thread_pool