diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index cd33b3b8..5f9def39 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -35,6 +35,7 @@ def connect( host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + **kwargs, ) -> DBConnection: """Connect to a LanceDB database. @@ -99,7 +100,12 @@ def connect( if isinstance(request_thread_pool, int): request_thread_pool = ThreadPoolExecutor(request_thread_pool) return RemoteDBConnection( - uri, api_key, region, host_override, request_thread_pool=request_thread_pool + uri, + api_key, + region, + host_override, + request_thread_pool=request_thread_pool, + **kwargs, ) return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 11e18ab3..4975e39d 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -58,6 +58,9 @@ class RestfulLanceDBClient: closed: bool = attrs.field(default=False, init=False) + connection_timeout: float = attrs.field(default=120.0, kw_only=True) + read_timeout: float = attrs.field(default=300.0, kw_only=True) + @functools.cached_property def session(self) -> requests.Session: sess = requests.Session() @@ -117,7 +120,7 @@ class RestfulLanceDBClient: urljoin(self.url, uri), params=params, headers=self.headers, - timeout=(120.0, 300.0), + timeout=(self.connection_timeout, self.read_timeout), ) as resp: self._check_status(resp) return resp.json() @@ -159,7 +162,7 @@ class RestfulLanceDBClient: urljoin(self.url, uri), headers=headers, params=params, - timeout=(120.0, 300.0), + timeout=(self.connection_timeout, self.read_timeout), **req_kwargs, ) as resp: self._check_status(resp) diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 3d15a6df..f2ded712 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -41,6 +41,8 @@ class RemoteDBConnection(DBConnection): region: str, host_override: Optional[str] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None, + connection_timeout: float = 120.0, + read_timeout: float = 300.0, ): """Connect to a remote LanceDB database.""" parsed = urlparse(db_url) @@ -49,7 +51,12 @@ class RemoteDBConnection(DBConnection): self.db_name = parsed.netloc self.api_key = api_key self._client = RestfulLanceDBClient( - self.db_name, region, api_key, host_override + self.db_name, + region, + api_key, + host_override, + connection_timeout=connection_timeout, + read_timeout=read_timeout, ) self._request_thread_pool = request_thread_pool