feat(python): bind python async remote client to rust client (#1700)

Closes [#1638](https://github.com/lancedb/lancedb/issues/1638)

This just binds the Python Async client to the Rust remote client.
This commit is contained in:
Will Jones
2024-10-01 15:46:59 -07:00
committed by GitHub
parent a416925ca1
commit f305f34d9b
6 changed files with 255 additions and 6 deletions

View File

@@ -19,6 +19,8 @@ from typing import Dict, Optional, Union, Any
__version__ = importlib.metadata.version("lancedb")
from lancedb.remote import ClientConfig
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection
@@ -120,7 +122,7 @@ async def connect_async(
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
client_config: Optional[Union[ClientConfig, Dict[str, Any]]] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -148,6 +150,10 @@ async def connect_async(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
client_config: ClientConfig or dict, optional
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
the keys are the attributes of the ClientConfig class. If None, then the
default configuration is used.
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
@@ -160,7 +166,13 @@ async def connect_async(
... # For a local directory, provide a path to the database
... db = await lancedb.connect_async("~/.lancedb")
... # For object storage, use a URI prefix
... db = await lancedb.connect_async("s3://my-bucket/lancedb")
... db = await lancedb.connect_async("s3://my-bucket/lancedb",
... storage_options={
... "aws_access_key_id": "***"})
... # Connect to LanceDB cloud
... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
... client_config={
... "retry_config": {"retries": 5}})
Returns
-------
@@ -172,6 +184,9 @@ async def connect_async(
else:
read_consistency_interval_secs = None
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
return AsyncConnection(
await lancedb_connect(
sanitize_uri(uri),
@@ -179,6 +194,7 @@ async def connect_async(
region,
host_override,
read_consistency_interval_secs,
client_config,
storage_options,
)
)

View File

@@ -12,9 +12,12 @@
# limitations under the License.
import abc
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional
import attrs
from lancedb import __version__
import pyarrow as pa
from pydantic import BaseModel
@@ -62,3 +65,109 @@ class LanceDBClient(abc.ABC):
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query the LanceDB server for the given table and query."""
pass
@dataclass
class TimeoutConfig:
    """Timeout configuration for remote HTTP client.

    Each timeout may be given as a ``timedelta`` or as a plain number of
    seconds (``int`` or ``float``); numbers are converted to ``timedelta`` when
    the instance is created.

    Attributes
    ----------
    connect_timeout: Optional[timedelta]
        The timeout for establishing a connection. Default is 120 seconds (2 minutes).
        This can also be set via the environment variable
        `LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds.
    read_timeout: Optional[timedelta]
        The timeout for reading data from the server. Default is 300 seconds
        (5 minutes). This can also be set via the environment variable
        `LANCE_CLIENT_READ_TIMEOUT`, as an integer number of seconds.
    pool_idle_timeout: Optional[timedelta]
        The timeout for keeping idle connections in the connection pool. Default
        is 300 seconds (5 minutes). This can also be set via the environment variable
        `LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer number of seconds.

    Raises
    ------
    ValueError
        If a timeout is neither None, a ``timedelta``, nor a number of seconds.
    """

    connect_timeout: Optional[timedelta] = None
    read_timeout: Optional[timedelta] = None
    pool_idle_timeout: Optional[timedelta] = None

    @staticmethod
    def __to_timedelta(value) -> Optional[timedelta]:
        """Normalize a timeout given as None, a timedelta, or seconds."""
        if value is None or isinstance(value, timedelta):
            return value
        # bool is a subclass of int, so without this guard `True` would be
        # silently accepted as a one-second timeout.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return timedelta(seconds=value)
        raise ValueError(
            f"Invalid value for timeout: {value}, must be a timedelta "
            "or number of seconds"
        )

    def __post_init__(self):
        self.connect_timeout = self.__to_timedelta(self.connect_timeout)
        self.read_timeout = self.__to_timedelta(self.read_timeout)
        self.pool_idle_timeout = self.__to_timedelta(self.pool_idle_timeout)
@dataclass
class RetryConfig:
    """Retry configuration for the remote HTTP client.

    Every field defaults to None; the defaults listed below are the ones
    described for the client when a field is left unset.

    Attributes
    ----------
    retries: Optional[int]
        The maximum number of retries for a request. Default is 3. You can also set this
        via the environment variable `LANCE_CLIENT_MAX_RETRIES`.
    connect_retries: Optional[int]
        The maximum number of retries for connection errors. Default is 3. You can also
        set this via the environment variable `LANCE_CLIENT_CONNECT_RETRIES`.
    read_retries: Optional[int]
        The maximum number of retries for read errors. Default is 3. You can also set
        this via the environment variable `LANCE_CLIENT_READ_RETRIES`.
    backoff_factor: Optional[float]
        The backoff factor to apply between retries. Default is 0.25. Between each retry
        the client will wait for the amount of seconds:
        `{backoff factor} * (2 ** ({number of previous retries}))`. So for the default
        of 0.25, the first retry will wait 0.25 seconds, the second retry will wait 0.5
        seconds, the third retry will wait 1 second, etc.

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_BACKOFF_FACTOR`.
    backoff_jitter: Optional[float]
        The jitter to apply to the backoff factor, in seconds. Default is 0.25.

        A random value between 0 and `backoff_jitter` will be added to the backoff
        factor in seconds. So for the default of 0.25 seconds, between 0 and 250
        milliseconds will be added to the sleep between each retry.

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_BACKOFF_JITTER`.
    statuses: Optional[List[int]]
        The HTTP status codes for which to retry the request. Default is
        [429, 500, 502, 503].

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_STATUSES`. Use a comma-separated list of integers.
    """

    retries: Optional[int] = None
    connect_retries: Optional[int] = None
    read_retries: Optional[int] = None
    backoff_factor: Optional[float] = None
    backoff_jitter: Optional[float] = None
    statuses: Optional[List[int]] = None
@dataclass
class ClientConfig:
    """Configuration for the remote HTTP client.

    Attributes
    ----------
    user_agent: str
        The user agent string for the client. Defaults to
        `LanceDB-Python-Client/{version}` with the installed package version.
    retry_config: Optional[RetryConfig]
        Retry settings. A dict is also accepted; its keys are passed as keyword
        arguments to `RetryConfig`.
    timeout_config: Optional[TimeoutConfig]
        Timeout settings. A dict is also accepted; its keys are passed as
        keyword arguments to `TimeoutConfig`.
    """

    user_agent: str = f"LanceDB-Python-Client/{__version__}"
    retry_config: Optional[RetryConfig] = None
    timeout_config: Optional[TimeoutConfig] = None

    def __post_init__(self):
        # Promote plain dicts to their config classes; anything else is left
        # untouched.
        nested = (
            ("retry_config", RetryConfig),
            ("timeout_config", TimeoutConfig),
        )
        for attr_name, config_cls in nested:
            raw = getattr(self, attr_name)
            if isinstance(raw, dict):
                setattr(self, attr_name, config_cls(**raw))

View File

@@ -1,11 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import http.server
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
class FakeLanceDBClient:
@@ -81,3 +85,57 @@ def test_create_table_with_recordbatches():
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
def make_mock_http_handler(handler):
    """Create a request-handler class that forwards GET and POST to `handler`.

    Parameters
    ----------
    handler: callable
        Called with the request-handler instance for every GET or POST request.

    Returns
    -------
    A subclass of `http.server.BaseHTTPRequestHandler`.
    """

    class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
        def _forward(self):
            handler(self)

        # Both verbs share the same behavior: hand the request to the callback.
        do_GET = _forward
        do_POST = _forward

    return MockLanceDBHandler
@pytest.mark.asyncio
async def test_async_remote_db():
    """Check the async remote client against a one-shot local HTTP server.

    The mock server verifies that the request carries a version-4 UUID in
    `x-request-id` and the library's `User-Agent`, then answers with an empty
    table list.
    """

    def handler(request):
        # We created a UUID request id
        request_id = request.headers["x-request-id"]
        assert uuid.UUID(request_id).version == 4

        # We set a user agent with the current library version
        user_agent = request.headers["User-Agent"]
        assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    # Bind in the main thread, on an OS-assigned port: the socket is already
    # listening before the client connects, and the test cannot collide with
    # another process that happens to be using a fixed port.
    server = http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    )
    # Without a timeout, handle_request() would block forever if the client
    # never reaches the server, and the join below would hang the test run.
    server.timeout = 10
    port = server.server_address[1]

    # we will only make one request
    handle = threading.Thread(target=server.handle_request)
    handle.start()
    try:
        db = await lancedb.connect_async(
            "db://dev",
            api_key="fake",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
            },
        )
        table_names = await db.table_names()
        assert table_names == []
    finally:
        # Always reap the server thread and release the socket, even when the
        # request or an assertion above fails.
        handle.join()
        server.server_close()