mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 02:50:44 +00:00
dont print apikey in remote client toString, add hostoverride to python client (#353)
This commit is contained in:
@@ -48,11 +48,16 @@ class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attr.field(default=None)
|
||||
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
url = (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
@@ -66,6 +71,8 @@ class RestfulLanceDBClient:
|
||||
}
|
||||
if self.region == "local": # Local test mode
|
||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||
if self.host_override:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -30,14 +30,22 @@ from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
class RemoteDBConnection(DBConnection):
|
||||
"""A connection to a remote LanceDB database."""
|
||||
|
||||
def __init__(self, db_url: str, api_key: str, region: str):
|
||||
def __init__(
|
||||
self,
|
||||
db_url: str,
|
||||
api_key: str,
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
|
||||
Reference in New Issue
Block a user