mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-02 01:42:57 +00:00
feat(python): bind python async remote client to rust client (#1700)
Closes [#1638](https://github.com/lancedb/lancedb/issues/1638) This just binds the Python Async client to the Rust remote client.
This commit is contained in:
@@ -19,6 +19,8 @@ from typing import Dict, Optional, Union, Any
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from lancedb.remote import ClientConfig
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
@@ -120,7 +122,7 @@ async def connect_async(
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
client_config: Optional[Union[ClientConfig, Dict[str, Any]]] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
@@ -148,6 +150,10 @@ async def connect_async(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
default configuration is used.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
@@ -160,7 +166,13 @@ async def connect_async(
|
||||
... # For a local directory, provide a path to the database
|
||||
... db = await lancedb.connect_async("~/.lancedb")
|
||||
... # For object storage, use a URI prefix
|
||||
... db = await lancedb.connect_async("s3://my-bucket/lancedb")
|
||||
... db = await lancedb.connect_async("s3://my-bucket/lancedb",
|
||||
... storage_options={
|
||||
... "aws_access_key_id": "***"})
|
||||
... # Connect to LanceDB cloud
|
||||
... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
|
||||
... client_config={
|
||||
... "retry_config": {"retries": 5}})
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -172,6 +184,9 @@ async def connect_async(
|
||||
else:
|
||||
read_consistency_interval_secs = None
|
||||
|
||||
if isinstance(client_config, dict):
|
||||
client_config = ClientConfig(**client_config)
|
||||
|
||||
return AsyncConnection(
|
||||
await lancedb_connect(
|
||||
sanitize_uri(uri),
|
||||
@@ -179,6 +194,7 @@ async def connect_async(
|
||||
region,
|
||||
host_override,
|
||||
read_consistency_interval_secs,
|
||||
client_config,
|
||||
storage_options,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -12,9 +12,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import attrs
|
||||
from lancedb import __version__
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -62,3 +65,109 @@ class LanceDBClient(abc.ABC):
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query the LanceDB server for the given table and query."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeoutConfig:
|
||||
"""Timeout configuration for remote HTTP client.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
connect_timeout: Optional[timedelta]
|
||||
The timeout for establishing a connection. Default is 120 seconds (2 minutes).
|
||||
This can also be set via the environment variable
|
||||
`LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds.
|
||||
read_timeout: Optional[timedelta]
|
||||
The timeout for reading data from the server. Default is 300 seconds
|
||||
(5 minutes). This can also be set via the environment variable
|
||||
`LANCE_CLIENT_READ_TIMEOUT`, as an integer number of seconds.
|
||||
pool_idle_timeout: Optional[timedelta]
|
||||
The timeout for keeping idle connections in the connection pool. Default
|
||||
is 300 seconds (5 minutes). This can also be set via the environment variable
|
||||
`LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer number of seconds.
|
||||
"""
|
||||
|
||||
connect_timeout: Optional[timedelta] = None
|
||||
read_timeout: Optional[timedelta] = None
|
||||
pool_idle_timeout: Optional[timedelta] = None
|
||||
|
||||
@staticmethod
|
||||
def __to_timedelta(value) -> Optional[timedelta]:
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, timedelta):
|
||||
return value
|
||||
elif isinstance(value, (int, float)):
|
||||
return timedelta(seconds=value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid value for timeout: {value}, must be a timedelta "
|
||||
"or number of seconds"
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.connect_timeout = self.__to_timedelta(self.connect_timeout)
|
||||
self.read_timeout = self.__to_timedelta(self.read_timeout)
|
||||
self.pool_idle_timeout = self.__to_timedelta(self.pool_idle_timeout)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Retry configuration for the remote HTTP client.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
retries: Optional[int]
|
||||
The maximum number of retries for a request. Default is 3. You can also set this
|
||||
via the environment variable `LANCE_CLIENT_MAX_RETRIES`.
|
||||
connect_retries: Optional[int]
|
||||
The maximum number of retries for connection errors. Default is 3. You can also
|
||||
set this via the environment variable `LANCE_CLIENT_CONNECT_RETRIES`.
|
||||
read_retries: Optional[int]
|
||||
The maximum number of retries for read errors. Default is 3. You can also set
|
||||
this via the environment variable `LANCE_CLIENT_READ_RETRIES`.
|
||||
backoff_factor: Optional[float]
|
||||
The backoff factor to apply between retries. Default is 0.25. Between each retry
|
||||
the client will wait for the amount of seconds:
|
||||
`{backoff factor} * (2 ** ({number of previous retries}))`. So for the default
|
||||
of 0.25, the first retry will wait 0.25 seconds, the second retry will wait 0.5
|
||||
seconds, the third retry will wait 1 second, etc.
|
||||
|
||||
You can also set this via the environment variable
|
||||
`LANCE_CLIENT_RETRY_BACKOFF_FACTOR`.
|
||||
backoff_jitter: Optional[float]
|
||||
The jitter to apply to the backoff factor, in seconds. Default is 0.25.
|
||||
|
||||
A random value between 0 and `backoff_jitter` will be added to the backoff
|
||||
factor in seconds. So for the default of 0.25 seconds, between 0 and 250
|
||||
milliseconds will be added to the sleep between each retry.
|
||||
|
||||
You can also set this via the environment variable
|
||||
`LANCE_CLIENT_RETRY_BACKOFF_JITTER`.
|
||||
statuses: Optional[List[int]
|
||||
The HTTP status codes for which to retry the request. Default is
|
||||
[429, 500, 502, 503].
|
||||
|
||||
You can also set this via the environment variable
|
||||
`LANCE_CLIENT_RETRY_STATUSES`. Use a comma-separated list of integers.
|
||||
"""
|
||||
|
||||
retries: Optional[int] = None
|
||||
connect_retries: Optional[int] = None
|
||||
read_retries: Optional[int] = None
|
||||
backoff_factor: Optional[float] = None
|
||||
backoff_jitter: Optional[float] = None
|
||||
statuses: Optional[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: Optional[RetryConfig] = None
|
||||
timeout_config: Optional[TimeoutConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
self.retry_config = RetryConfig(**self.retry_config)
|
||||
if isinstance(self.timeout_config, dict):
|
||||
self.timeout_config = TimeoutConfig(**self.timeout_config)
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import http.server
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
import pytest
|
||||
|
||||
|
||||
class FakeLanceDBClient:
|
||||
@@ -81,3 +85,57 @@ def test_create_table_with_recordbatches():
|
||||
table = conn.create_table("test", [batch], schema=batch.schema)
|
||||
assert table.name == "test"
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
||||
|
||||
|
||||
def make_mock_http_handler(handler):
|
||||
class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
handler(self)
|
||||
|
||||
def do_POST(self):
|
||||
handler(self)
|
||||
|
||||
return MockLanceDBHandler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_remote_db():
|
||||
def handler(request):
|
||||
# We created a UUID request id
|
||||
request_id = request.headers["x-request-id"]
|
||||
assert uuid.UUID(request_id).version == 4
|
||||
|
||||
# We set a user agent with the current library version
|
||||
user_agent = request.headers["User-Agent"]
|
||||
assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}"
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
def run_server():
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
# we will only make one request
|
||||
server.handle_request()
|
||||
|
||||
handle = threading.Thread(target=run_server)
|
||||
handle.start()
|
||||
|
||||
db = await lancedb.connect_async(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override="http://localhost:8080",
|
||||
client_config={
|
||||
"retry_config": {"retries": 2},
|
||||
"timeout_config": {
|
||||
"connect_timeout": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
table_names = await db.table_names()
|
||||
assert table_names == []
|
||||
|
||||
handle.join()
|
||||
|
||||
Reference in New Issue
Block a user