Support storage options and folder prefix

This commit is contained in:
Lu Qiu
2024-12-12 16:17:34 -08:00
parent bc0814767b
commit 5489e215a3
3 changed files with 29 additions and 2 deletions

View File

@@ -35,6 +35,7 @@ def connect(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
storage_options: Optional[Dict[str, str]] = None,
**kwargs,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -70,6 +71,9 @@ def connect(
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
Examples
--------
@@ -105,12 +109,16 @@ def connect(
region,
host_override,
request_thread_pool=request_thread_pool,
storage_options=storage_options,
**kwargs,
)
if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
return LanceDBConnection(
uri,
read_consistency_interval=read_consistency_interval,
)
async def connect_async(

View File

@@ -52,6 +52,7 @@ def _read_ipc(resp: requests.Response) -> pa.Table:
@attrs.define(slots=False)
class RestfulLanceDBClient:
db_name: str
db_prefix: str | None
region: str
api_key: Credential
host_override: Optional[str] = attrs.field(default=None)
@@ -60,6 +61,7 @@ class RestfulLanceDBClient:
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)
@functools.cached_property
def session(self) -> requests.Session:
@@ -92,6 +94,18 @@ class RestfulLanceDBClient:
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override:
headers["x-lancedb-database"] = self.db_name
if self.storage_options:
if self.storage_options.get("account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"account_name"
]
if self.storage_options.get("azure_storage_account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"azure_storage_account_name"
]
if self.db_prefix:
headers["x-lancedb-database-prefix"] = self.db_prefix
return headers
@staticmethod

View File

@@ -15,7 +15,7 @@ import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
from cachetools import TTLCache
@@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection):
request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0,
read_timeout: float = 300.0,
storage_options: Optional[Dict[str, str]] = None,
):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc
prefix = parsed.path.lstrip("/")
self.db_prefix = None if not prefix else prefix
self.api_key = api_key
self._client = RestfulLanceDBClient(
self.db_name,
self.db_prefix,
region,
api_key,
host_override,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
storage_options=storage_options,
)
self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300)