From 5489e215a307ae92355b5967e42f19c71b6e3880 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Thu, 12 Dec 2024 16:17:34 -0800 Subject: [PATCH] Support storage options and folder prefix --- python/python/lancedb/__init__.py | 10 +++++++++- python/python/lancedb/remote/client.py | 14 ++++++++++++++ python/python/lancedb/remote/db.py | 7 ++++++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 798e7d11..10b681c4 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -35,6 +35,7 @@ def connect( host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + storage_options: Optional[Dict[str, str]] = None, **kwargs, ) -> DBConnection: """Connect to a LanceDB database. @@ -70,6 +71,9 @@ def connect( executor will be used for making requests. This is for LanceDB Cloud only and is only used when making batch requests (i.e., passing in multiple queries to the search method at once). + storage_options: dict, optional + Additional options for the storage backend. See available options at + https://lancedb.github.io/lancedb/guides/storage/ Examples -------- @@ -105,12 +109,16 @@ def connect( region, host_override, request_thread_pool=request_thread_pool, + storage_options=storage_options, **kwargs, ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") - return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) + return LanceDBConnection( + uri, + read_consistency_interval=read_consistency_interval, + ) async def connect_async( diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 4975e39d..c6554a9d 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -52,6 +52,7 @@ def _read_ipc(resp: requests.Response) -> pa.Table: @attrs.define(slots=False) class RestfulLanceDBClient: db_name: str + db_prefix: str | None region: str api_key: Credential host_override: Optional[str] = attrs.field(default=None) @@ -60,6 +61,7 @@ class RestfulLanceDBClient: connection_timeout: float = attrs.field(default=120.0, kw_only=True) read_timeout: float = attrs.field(default=300.0, kw_only=True) + storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True) @functools.cached_property def session(self) -> requests.Session: @@ -92,6 +94,18 @@ class RestfulLanceDBClient: headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" if self.host_override: headers["x-lancedb-database"] = self.db_name + if self.storage_options: + if self.storage_options.get("account_name") is not None: + headers["x-azure-storage-account-name"] = self.storage_options[ + "account_name" + ] + if self.storage_options.get("azure_storage_account_name") is not None: + headers["x-azure-storage-account-name"] = self.storage_options[ + "azure_storage_account_name" + ] + if self.db_prefix: + headers["x-lancedb-database-prefix"] = self.db_prefix + return headers @staticmethod diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 66e01360..89a88654 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -15,7 +15,7 @@ import inspect import logging import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union from urllib.parse import urlparse from cachetools import TTLCache @@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection): request_thread_pool: Optional[ThreadPoolExecutor] = None, connection_timeout: float = 120.0, read_timeout: float = 300.0, + storage_options: Optional[Dict[str, str]] = None, ): """Connect to a remote LanceDB database.""" parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") self.db_name = parsed.netloc + prefix = parsed.path.lstrip("/") + self.db_prefix = None if not prefix else prefix self.api_key = api_key self._client = RestfulLanceDBClient( self.db_name, + self.db_prefix, region, api_key, host_override, connection_timeout=connection_timeout, read_timeout=read_timeout, + storage_options=storage_options, ) self._request_thread_pool = request_thread_pool self._table_cache = TTLCache(maxsize=10000, ttl=300)