mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 03:12:57 +00:00
dont print apikey in remote client toString, add hostoverride to python client (#353)
This commit is contained in:
@@ -19,7 +19,11 @@ from .schema import vector
|
||||
|
||||
|
||||
def connect(
|
||||
uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2"
|
||||
uri: URI,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-west-2",
|
||||
host_override: Optional[str] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -55,5 +59,5 @@ def connect(
|
||||
if isinstance(uri, str) and uri.startswith("db://"):
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region)
|
||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||
return LanceDBConnection(uri)
|
||||
|
||||
@@ -48,11 +48,16 @@ class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attr.field(default=None)
|
||||
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
url = (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
@@ -66,6 +71,8 @@ class RestfulLanceDBClient:
|
||||
}
|
||||
if self.region == "local": # Local test mode
|
||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||
if self.host_override:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -30,14 +30,22 @@ from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
class RemoteDBConnection(DBConnection):
|
||||
"""A connection to a remote LanceDB database."""
|
||||
|
||||
def __init__(self, db_url: str, api_key: str, region: str):
|
||||
def __init__(
|
||||
self,
|
||||
db_url: str,
|
||||
api_key: str,
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
|
||||
Reference in New Issue
Block a user