diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 798e7d11..10b681c4 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -35,6 +35,7 @@ def connect( host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + storage_options: Optional[Dict[str, str]] = None, **kwargs, ) -> DBConnection: """Connect to a LanceDB database. @@ -70,6 +71,9 @@ def connect( executor will be used for making requests. This is for LanceDB Cloud only and is only used when making batch requests (i.e., passing in multiple queries to the search method at once). + storage_options: dict, optional + Additional options for the storage backend. See available options at + https://lancedb.github.io/lancedb/guides/storage/ Examples -------- @@ -105,12 +109,16 @@ def connect( region, host_override, request_thread_pool=request_thread_pool, + storage_options=storage_options, **kwargs, ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") - return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) + return LanceDBConnection( + uri, + read_consistency_interval=read_consistency_interval, + ) async def connect_async( diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 4975e39d..c6554a9d 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -52,6 +52,7 @@ def _read_ipc(resp: requests.Response) -> pa.Table: @attrs.define(slots=False) class RestfulLanceDBClient: db_name: str + db_prefix: str | None region: str api_key: Credential host_override: Optional[str] = attrs.field(default=None) @@ -60,6 +61,7 @@ class RestfulLanceDBClient: connection_timeout: float = attrs.field(default=120.0, kw_only=True) read_timeout: float = attrs.field(default=300.0, kw_only=True) + storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True) @functools.cached_property def session(self) -> requests.Session: @@ -92,6 +94,18 @@ class RestfulLanceDBClient: headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" if self.host_override: headers["x-lancedb-database"] = self.db_name + if self.storage_options: + if self.storage_options.get("account_name") is not None: + headers["x-azure-storage-account-name"] = self.storage_options[ + "account_name" + ] + if self.storage_options.get("azure_storage_account_name") is not None: + headers["x-azure-storage-account-name"] = self.storage_options[ + "azure_storage_account_name" + ] + if self.db_prefix: + headers["x-lancedb-database-prefix"] = self.db_prefix + return headers @staticmethod diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 66e01360..89a88654 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -15,7 +15,7 @@ import inspect import logging import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union from urllib.parse import urlparse from cachetools import TTLCache @@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection): request_thread_pool: Optional[ThreadPoolExecutor] = None, connection_timeout: float = 120.0, read_timeout: float = 300.0, + storage_options: Optional[Dict[str, str]] = None, ): """Connect to a remote LanceDB database.""" parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") self.db_name = parsed.netloc + prefix = parsed.path.lstrip("/") + self.db_prefix = None if not prefix else prefix self.api_key = api_key self._client = RestfulLanceDBClient( self.db_name, + self.db_prefix, region, api_key, host_override, connection_timeout=connection_timeout, read_timeout=read_timeout, + storage_options=storage_options, ) self._request_thread_pool = request_thread_pool self._table_cache = TTLCache(maxsize=10000, ttl=300)