diff --git a/python/Cargo.toml b/python/Cargo.toml index 6ecc2087..115ff2e0 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -22,8 +22,6 @@ pyo3 = { version = "0.21", features = ["extension-module", "abi3-py38", "gil-ref # pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] } pyo3-asyncio-0-21 = { version = "0.21.0", features = ["attributes", "tokio-runtime"] } -# Prevent dynamic linking of lzma, which comes from datafusion -lzma-sys = { version = "*", features = ["static"] } pin-project = "1.1.5" futures.workspace = true tokio = { version = "1.36.0", features = ["sync"] } @@ -35,4 +33,6 @@ pyo3-build-config = { version = "0.20.3", features = [ ] } [features] +default = ["remote"] fp16kernels = ["lancedb/fp16kernels"] +remote = ["lancedb/remote"] diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 67a64479..b394fa6f 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -19,6 +19,8 @@ from typing import Dict, Optional, Union, Any __version__ = importlib.metadata.version("lancedb") +from lancedb.remote import ClientConfig + from ._lancedb import connect as lancedb_connect from .common import URI, sanitize_uri from .db import AsyncConnection, DBConnection, LanceDBConnection @@ -120,7 +122,7 @@ async def connect_async( region: str = "us-east-1", host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, - request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + client_config: Optional[Union[ClientConfig, Dict[str, Any]]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> AsyncConnection: """Connect to a LanceDB database. @@ -148,6 +150,10 @@ async def connect_async( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. 
+ client_config: ClientConfig or dict, optional + Configuration options for the LanceDB Cloud HTTP client. If a dict, then + the keys are the attributes of the ClientConfig class. If None, then the + default configuration is used. storage_options: dict, optional Additional options for the storage backend. See available options at https://lancedb.github.io/lancedb/guides/storage/ @@ -160,7 +166,13 @@ async def connect_async( ... # For a local directory, provide a path to the database ... db = await lancedb.connect_async("~/.lancedb") ... # For object storage, use a URI prefix - ... db = await lancedb.connect_async("s3://my-bucket/lancedb") + ... db = await lancedb.connect_async("s3://my-bucket/lancedb", + ... storage_options={ + ... "aws_access_key_id": "***"}) + ... # Connect to LanceDB cloud + ... db = await lancedb.connect_async("db://my_database", api_key="ldb_...", + ... client_config={ + ... "retry_config": {"retries": 5}}) Returns ------- @@ -172,6 +184,9 @@ async def connect_async( else: read_consistency_interval_secs = None + if isinstance(client_config, dict): + client_config = ClientConfig(**client_config) + return AsyncConnection( await lancedb_connect( sanitize_uri(uri), @@ -179,6 +194,7 @@ async def connect_async( region, host_override, read_consistency_interval_secs, + client_config, storage_options, ) ) diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index c6e15c83..fdd0cfae 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -12,9 +12,12 @@ # limitations under the License. 
import abc +from dataclasses import dataclass +from datetime import timedelta from typing import List, Optional import attrs +from lancedb import __version__ import pyarrow as pa from pydantic import BaseModel @@ -62,3 +65,109 @@ class LanceDBClient(abc.ABC): def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: """Query the LanceDB server for the given table and query.""" pass + + +@dataclass +class TimeoutConfig: + """Timeout configuration for remote HTTP client. + + Attributes + ---------- + connect_timeout: Optional[timedelta] + The timeout for establishing a connection. Default is 120 seconds (2 minutes). + This can also be set via the environment variable + `LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds. + read_timeout: Optional[timedelta] + The timeout for reading data from the server. Default is 300 seconds + (5 minutes). This can also be set via the environment variable + `LANCE_CLIENT_READ_TIMEOUT`, as an integer number of seconds. + pool_idle_timeout: Optional[timedelta] + The timeout for keeping idle connections in the connection pool. Default + is 300 seconds (5 minutes). This can also be set via the environment variable + `LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer number of seconds. 
+ """ + + connect_timeout: Optional[timedelta] = None + read_timeout: Optional[timedelta] = None + pool_idle_timeout: Optional[timedelta] = None + + @staticmethod + def __to_timedelta(value) -> Optional[timedelta]: + if value is None: + return None + elif isinstance(value, timedelta): + return value + elif isinstance(value, (int, float)): + return timedelta(seconds=value) + else: + raise ValueError( + f"Invalid value for timeout: {value}, must be a timedelta " + "or number of seconds" + ) + + def __post_init__(self): + self.connect_timeout = self.__to_timedelta(self.connect_timeout) + self.read_timeout = self.__to_timedelta(self.read_timeout) + self.pool_idle_timeout = self.__to_timedelta(self.pool_idle_timeout) + + +@dataclass +class RetryConfig: + """Retry configuration for the remote HTTP client. + + Attributes + ---------- + retries: Optional[int] + The maximum number of retries for a request. Default is 3. You can also set this + via the environment variable `LANCE_CLIENT_MAX_RETRIES`. + connect_retries: Optional[int] + The maximum number of retries for connection errors. Default is 3. You can also + set this via the environment variable `LANCE_CLIENT_CONNECT_RETRIES`. + read_retries: Optional[int] + The maximum number of retries for read errors. Default is 3. You can also set + this via the environment variable `LANCE_CLIENT_READ_RETRIES`. + backoff_factor: Optional[float] + The backoff factor to apply between retries. Default is 0.25. Between each retry + the client will wait for the amount of seconds: + `{backoff factor} * (2 ** ({number of previous retries}))`. So for the default + of 0.25, the first retry will wait 0.25 seconds, the second retry will wait 0.5 + seconds, the third retry will wait 1 second, etc. + + You can also set this via the environment variable + `LANCE_CLIENT_RETRY_BACKOFF_FACTOR`. + backoff_jitter: Optional[float] + The jitter to apply to the backoff factor, in seconds. Default is 0.25. 
+ + A random value between 0 and `backoff_jitter` will be added to the backoff + factor in seconds. So for the default of 0.25 seconds, between 0 and 250 + milliseconds will be added to the sleep between each retry. + + You can also set this via the environment variable + `LANCE_CLIENT_RETRY_BACKOFF_JITTER`. + statuses: Optional[List[int]] + The HTTP status codes for which to retry the request. Default is + [429, 500, 502, 503]. + + You can also set this via the environment variable + `LANCE_CLIENT_RETRY_STATUSES`. Use a comma-separated list of integers. + """ + + retries: Optional[int] = None + connect_retries: Optional[int] = None + read_retries: Optional[int] = None + backoff_factor: Optional[float] = None + backoff_jitter: Optional[float] = None + statuses: Optional[List[int]] = None + + +@dataclass +class ClientConfig: + user_agent: str = f"LanceDB-Python-Client/{__version__}" + retry_config: Optional[RetryConfig] = None + timeout_config: Optional[TimeoutConfig] = None + + def __post_init__(self): + if isinstance(self.retry_config, dict): + self.retry_config = RetryConfig(**self.retry_config) + if isinstance(self.timeout_config, dict): + self.timeout_config = TimeoutConfig(**self.timeout_config) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 729fb550..dee183d9 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +import http.server +import threading from unittest.mock import MagicMock +import uuid import lancedb import pyarrow as pa from lancedb.remote.client import VectorQuery, VectorQueryResult +import pytest class FakeLanceDBClient: @@ -81,3 +85,57 @@ def test_create_table_with_recordbatches(): table = conn.create_table("test", [batch], schema=batch.schema) assert table.name == "test" assert client.post.call_args[0][0] == "/v1/table/test/create/" + + +def 
make_mock_http_handler(handler): + class MockLanceDBHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + handler(self) + + def do_POST(self): + handler(self) + + return MockLanceDBHandler + + +@pytest.mark.asyncio +async def test_async_remote_db(): + def handler(request): + # We created a UUID request id + request_id = request.headers["x-request-id"] + assert uuid.UUID(request_id).version == 4 + + # We set a user agent with the current library version + user_agent = request.headers["User-Agent"] + assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}" + + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"tables": []}') + + def run_server(): + with http.server.HTTPServer( + ("localhost", 8080), make_mock_http_handler(handler) + ) as server: + # we will only make one request + server.handle_request() + + handle = threading.Thread(target=run_server) + handle.start() + + db = await lancedb.connect_async( + "db://dev", + api_key="fake", + host_override="http://localhost:8080", + client_config={ + "retry_config": {"retries": 2}, + "timeout_config": { + "connect_timeout": 1, + }, + }, + ) + table_names = await db.table_names() + assert table_names == [] + + handle.join() diff --git a/python/src/connection.rs b/python/src/connection.rs index 4f7e20a9..200285a4 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -7,7 +7,7 @@ use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::From use lancedb::connection::{Connection as LanceConnection, CreateTableMode, LanceFileVersion}; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, - pyclass, pyfunction, pymethods, Bound, PyAny, PyRef, PyResult, Python, + pyclass, pyfunction, pymethods, Bound, FromPyObject, PyAny, PyRef, PyResult, Python, }; use pyo3_asyncio_0_21::tokio::future_into_py; @@ -187,6 +187,7 @@ impl Connection { } #[pyfunction] 
+#[allow(clippy::too_many_arguments)] pub fn connect( py: Python, uri: String, @@ -194,6 +195,7 @@ pub fn connect( region: Option, host_override: Option, read_consistency_interval: Option, + client_config: Option, storage_options: Option>, ) -> PyResult> { future_into_py(py, async move { @@ -214,6 +216,70 @@ pub fn connect( if let Some(storage_options) = storage_options { builder = builder.storage_options(storage_options); } + #[cfg(feature = "remote")] + if let Some(client_config) = client_config { + builder = builder.client_config(client_config.into()); + } Ok(Connection::new(builder.execute().await.infer_error()?)) }) } + +#[derive(FromPyObject)] +pub struct PyClientConfig { + user_agent: String, + retry_config: Option, + timeout_config: Option, +} + +#[derive(FromPyObject)] +pub struct PyClientRetryConfig { + retries: Option, + connect_retries: Option, + read_retries: Option, + backoff_factor: Option, + backoff_jitter: Option, + statuses: Option>, +} + +#[derive(FromPyObject)] +pub struct PyClientTimeoutConfig { + connect_timeout: Option, + read_timeout: Option, + pool_idle_timeout: Option, +} + +#[cfg(feature = "remote")] +impl From for lancedb::remote::RetryConfig { + fn from(value: PyClientRetryConfig) -> Self { + Self { + retries: value.retries, + connect_retries: value.connect_retries, + read_retries: value.read_retries, + backoff_factor: value.backoff_factor, + backoff_jitter: value.backoff_jitter, + statuses: value.statuses, + } + } +} + +#[cfg(feature = "remote")] +impl From for lancedb::remote::TimeoutConfig { + fn from(value: PyClientTimeoutConfig) -> Self { + Self { + connect_timeout: value.connect_timeout, + read_timeout: value.read_timeout, + pool_idle_timeout: value.pool_idle_timeout, + } + } +} + +#[cfg(feature = "remote")] +impl From for lancedb::remote::ClientConfig { + fn from(value: PyClientConfig) -> Self { + Self { + user_agent: value.user_agent, + retry_config: value.retry_config.map(Into::into).unwrap_or_default(), + timeout_config: 
value.timeout_config.map(Into::into).unwrap_or_default(), + } + } +} diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 644cb1f1..a42c0733 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -29,7 +29,7 @@ const REQUEST_ID_HEADER: &str = "x-request-id"; pub struct ClientConfig { pub timeout_config: TimeoutConfig, pub retry_config: RetryConfig, - /// User agent to use for requests. The default provides the libary + /// User agent to use for requests. The default provides the library /// name and version. pub user_agent: String, // TODO: how to configure request ids?