From 91cab3b556eab8aa4f399ec10e120092cdb9055b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 5 Nov 2024 13:44:39 -0800 Subject: [PATCH] feat(python): transition Python remote sdk to use Rust implementation (#1701) * Replaces Python implementation of Remote SDK with Rust one. * Drops dependency on `attrs` and `cachetools`. Makes `requests` an optional dependency used only for the embeddings feature. * Adds dependency on `nest-asyncio`. This was required to get hybrid search working. * Deprecates the `request_thread_pool` parameter. We now use the tokio threadpool. * Stops caching the `schema` on a remote table. Schema is mutable and there's no mechanism in place to invalidate the cache. * Removes the client-side resolution of the vector column. We should already be resolving this server-side. --- python/pyproject.toml | 5 +- python/python/lancedb/__init__.py | 21 +- python/python/lancedb/_lancedb.pyi | 2 + python/python/lancedb/db.py | 12 + python/python/lancedb/embeddings/jinaai.py | 3 +- python/python/lancedb/index.py | 2 + python/python/lancedb/remote/__init__.py | 57 +-- python/python/lancedb/remote/arrow.py | 25 -- python/python/lancedb/remote/client.py | 269 -------------- .../lancedb/remote/connection_timeout.py | 115 ------ python/python/lancedb/remote/db.py | 146 ++++---- python/python/lancedb/remote/table.py | 230 ++++-------- python/python/lancedb/rerankers/jinaai.py | 3 +- python/python/lancedb/table.py | 56 ++- python/python/tests/test_embeddings_slow.py | 2 +- python/python/tests/test_index.py | 6 +- python/python/tests/test_remote_client.py | 96 ----- python/python/tests/test_remote_db.py | 335 +++++++++++++----- python/src/connection.rs | 11 + python/src/index.rs | 20 +- rust/lancedb/src/connection.rs | 6 +- 21 files changed, 521 insertions(+), 901 deletions(-) delete mode 100644 python/python/lancedb/remote/arrow.py delete mode 100644 python/python/lancedb/remote/client.py delete mode 100644 python/python/lancedb/remote/connection_timeout.py delete mode 
100644 python/python/tests/test_remote_client.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 170f947b..86be43a1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,13 +3,11 @@ name = "lancedb" # version in Cargo.toml dependencies = [ "deprecation", + "nest-asyncio~=1.0", "pylance==0.19.2-beta.3", - "requests>=2.31.0", "tqdm>=4.27.0", "pydantic>=1.10", - "attrs>=21.3.0", "packaging", - "cachetools", "overrides>=0.7", ] description = "lancedb" @@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] embeddings = [ + "requests>=2.31.0", "openai>=1.6.1", "sentence-transformers", "torch", diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index b394fa6f..2c5e521d 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any __version__ = importlib.metadata.version("lancedb") -from lancedb.remote import ClientConfig - from ._lancedb import connect as lancedb_connect from .common import URI, sanitize_uri from .db import AsyncConnection, DBConnection, LanceDBConnection -from .remote.db import RemoteDBConnection +from .remote import ClientConfig from .schema import vector from .table import AsyncTable @@ -37,6 +35,7 @@ def connect( host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + client_config: Union[ClientConfig, Dict[str, Any], None] = None, **kwargs: Any, ) -> DBConnection: """Connect to a LanceDB database. @@ -64,14 +63,10 @@ def connect( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. 
- request_thread_pool: int or ThreadPoolExecutor, optional - The thread pool to use for making batch requests to the LanceDB Cloud API. - If an integer, then a ThreadPoolExecutor will be created with that - number of threads. If None, then a ThreadPoolExecutor will be created - with the default number of threads. If a ThreadPoolExecutor, then that - executor will be used for making requests. This is for LanceDB Cloud - only and is only used when making batch requests (i.e., passing in - multiple queries to the search method at once). + client_config: ClientConfig or dict, optional + Configuration options for the LanceDB Cloud HTTP client. If a dict, then + the keys are the attributes of the ClientConfig class. If None, then the + default configuration is used. Examples -------- @@ -94,6 +89,8 @@ def connect( conn : DBConnection A connection to a LanceDB database. """ + from .remote.db import RemoteDBConnection + if isinstance(uri, str) and uri.startswith("db://"): if api_key is None: api_key = os.environ.get("LANCEDB_API_KEY") @@ -106,7 +103,9 @@ def connect( api_key, region, host_override, + # TODO: remove this (deprecation warning downstream) request_thread_pool=request_thread_pool, + client_config=client_config, **kwargs, ) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 33b1f07c..bc4d6617 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -36,6 +36,8 @@ class Connection(object): data_storage_version: Optional[str] = None, enable_v2_manifest_paths: Optional[bool] = None, ) -> Table: ... + async def rename_table(self, old_name: str, new_name: str) -> None: ... + async def drop_table(self, name: str) -> None: ... class Table: def name(self) -> str: ... 
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 6af4cdb8..0a9e27d8 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -817,6 +817,18 @@ class AsyncConnection(object): table = await self._inner.open_table(name, storage_options, index_cache_size) return AsyncTable(table) + async def rename_table(self, old_name: str, new_name: str): + """Rename a table in the database. + + Parameters + ---------- + old_name: str + The current name of the table. + new_name: str + The new name of the table. + """ + await self._inner.rename_table(old_name, new_name) + async def drop_table(self, name: str): """Drop a table from the database. diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py index 6619627d..5f89d97c 100644 --- a/python/python/lancedb/embeddings/jinaai.py +++ b/python/python/lancedb/embeddings/jinaai.py @@ -13,7 +13,6 @@ import os import io -import requests import base64 from urllib.parse import urlparse from pathlib import Path @@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction): return [result["embedding"] for result in sorted_embeddings] def _init_client(self): + import requests + if JinaEmbeddings._session is None: if self.api_key is None and os.environ.get("JINA_API_KEY") is None: api_key_not_found_help("jina") diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index b7e44b52..a1b06a29 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -467,6 +467,8 @@ class IvfPq: The default value is 256. 
""" + if distance_type is not None: + distance_type = distance_type.lower() self._inner = LanceDbIndex.ivf_pq( distance_type=distance_type, num_partitions=num_partitions, diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 98cbd2e5..e834c226 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -11,62 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import timedelta from typing import List, Optional -import attrs from lancedb import __version__ -import pyarrow as pa -from pydantic import BaseModel -from lancedb.common import VECTOR_COLUMN_NAME - -__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"] - - -class VectorQuery(BaseModel): - # vector to search for - vector: List[float] - - # sql filter to refine the query with - filter: Optional[str] = None - - # top k results to return - k: int - - # # metrics - _metric: str = "L2" - - # which columns to return in the results - columns: Optional[List[str]] = None - - # optional query parameters for tuning the results, - # e.g. 
`{"nprobes": "10", "refine_factor": "10"}` - nprobes: int = 10 - - refine_factor: Optional[int] = None - - vector_column: str = VECTOR_COLUMN_NAME - - fast_search: bool = False - - -@attrs.define -class VectorQueryResult: - # for now the response is directly seralized into a pandas dataframe - tbl: pa.Table - - def to_arrow(self) -> pa.Table: - return self.tbl - - -class LanceDBClient(abc.ABC): - @abc.abstractmethod - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - """Query the LanceDB server for the given table and query.""" - pass +__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"] @dataclass @@ -165,8 +116,8 @@ class RetryConfig: @dataclass class ClientConfig: user_agent: str = f"LanceDB-Python-Client/{__version__}" - retry_config: Optional[RetryConfig] = None - timeout_config: Optional[TimeoutConfig] = None + retry_config: RetryConfig = field(default_factory=RetryConfig) + timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig) def __post_init__(self): if isinstance(self.retry_config, dict): diff --git a/python/python/lancedb/remote/arrow.py b/python/python/lancedb/remote/arrow.py deleted file mode 100644 index ac39e247..00000000 --- a/python/python/lancedb/remote/arrow.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Iterable, Union -import pyarrow as pa - - -def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes: - """Serialize a PyArrow Table to IPC binary.""" - sink = pa.BufferOutputStream() - if isinstance(table, Iterable): - table = pa.Table.from_batches(table) - with pa.ipc.new_stream(sink, table.schema) as writer: - writer.write_table(table) - return sink.getvalue().to_pybytes() diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py deleted file mode 100644 index d546e92f..00000000 --- a/python/python/lancedb/remote/client.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import functools -import logging -import os -from typing import Any, Callable, Dict, List, Optional, Union -from urllib.parse import urljoin - -import attrs -import pyarrow as pa -import requests -from pydantic import BaseModel -from requests.adapters import HTTPAdapter -from urllib3 import Retry - -from lancedb.common import Credential -from lancedb.remote import VectorQuery, VectorQueryResult -from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory -from lancedb.remote.errors import LanceDBClientError - -ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream" - - -def _check_not_closed(f): - @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if self.closed: - raise ValueError("Connection is closed") - return f(self, *args, **kwargs) - - return wrapped - - -def _read_ipc(resp: requests.Response) -> pa.Table: - resp_body = resp.content - with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: - return reader.read_all() - - -@attrs.define(slots=False) -class RestfulLanceDBClient: - db_name: str - region: str - api_key: Credential - host_override: Optional[str] = attrs.field(default=None) - - closed: bool = attrs.field(default=False, init=False) - - connection_timeout: float = attrs.field(default=120.0, kw_only=True) - read_timeout: float = attrs.field(default=300.0, kw_only=True) - - @functools.cached_property - def session(self) -> requests.Session: - sess = requests.Session() - - retry_adapter_instance = retry_adapter(retry_adapter_options()) - sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance) - - adapter_class = LanceDBClientHTTPAdapterFactory() - sess.mount("https://", adapter_class()) - return sess - - @property - def url(self) -> str: - return ( - self.host_override - or f"https://{self.db_name}.{self.region}.api.lancedb.com" - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - return False # Do not suppress exceptions - - def 
close(self): - self.session.close() - self.closed = True - - @functools.cached_property - def headers(self) -> Dict[str, str]: - headers = { - "x-api-key": self.api_key, - } - if self.region == "local": # Local test mode - headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" - if self.host_override: - headers["x-lancedb-database"] = self.db_name - return headers - - @staticmethod - def _check_status(resp: requests.Response): - # Leaving request id empty for now, as we'll be replacing this impl - # with the Rust one shortly. - if resp.status_code == 404: - raise LanceDBClientError( - f"Not found: {resp.text}", request_id="", status_code=404 - ) - elif 400 <= resp.status_code < 500: - raise LanceDBClientError( - f"Bad Request: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - elif 500 <= resp.status_code < 600: - raise LanceDBClientError( - f"Internal Server Error: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - elif resp.status_code != 200: - raise LanceDBClientError( - f"Unknown Error: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - - @_check_not_closed - def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): - """Send a GET request and returns the deserialized response payload.""" - if isinstance(params, BaseModel): - params: Dict[str, Any] = params.dict(exclude_none=True) - with self.session.get( - urljoin(self.url, uri), - params=params, - headers=self.headers, - timeout=(self.connection_timeout, self.read_timeout), - ) as resp: - self._check_status(resp) - return resp.json() - - @_check_not_closed - def post( - self, - uri: str, - data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None, - params: Optional[Dict[str, Any]] = None, - content_type: Optional[str] = None, - deserialize: Callable = lambda resp: resp.json(), - request_id: Optional[str] = None, - ) -> Dict[str, Any]: - 
"""Send a POST request and returns the deserialized response payload. - - Parameters - ---------- - uri : str - The uri to send the POST request to. - data: Union[Dict[str, Any], BaseModel] - request_id: Optional[str] - Optional client side request id to be sent in the request headers. - - """ - if isinstance(data, BaseModel): - data: Dict[str, Any] = data.dict(exclude_none=True) - if isinstance(data, bytes): - req_kwargs = {"data": data} - else: - req_kwargs = {"json": data} - - headers = self.headers.copy() - if content_type is not None: - headers["content-type"] = content_type - if request_id is not None: - headers["x-request-id"] = request_id - with self.session.post( - urljoin(self.url, uri), - headers=headers, - params=params, - timeout=(self.connection_timeout, self.read_timeout), - **req_kwargs, - ) as resp: - self._check_status(resp) - return deserialize(resp) - - @_check_not_closed - def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]: - """List all tables in the database.""" - if page_token is None: - page_token = "" - json = self.get("/v1/table/", {"limit": limit, "page_token": page_token}) - return json["tables"] - - @_check_not_closed - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - """Query a table.""" - tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) - return VectorQueryResult(tbl) - - def mount_retry_adapter_for_table(self, table_name: str) -> None: - """ - Adds an http adapter to session that will retry retryable requests to the table. 
- """ - retry_options = retry_adapter_options(methods=["GET", "POST"]) - retry_adapter_instance = retry_adapter(retry_options) - session = self.session - - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance - ) - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/describe/"), - retry_adapter_instance, - ) - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/index/list/"), - retry_adapter_instance, - ) - - -def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]: - return { - "retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")), - "connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")), - "read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")), - "backoff_factor": float( - os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25") - ), - "backoff_jitter": float( - os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25") - ), - "statuses": [ - int(i.strip()) - for i in os.environ.get( - "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503" - ).split(",") - ], - "methods": methods, - } - - -def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter: - total_retries = options["retries"] - connect_retries = options["connect_retries"] - read_retries = options["read_retries"] - backoff_factor = options["backoff_factor"] - backoff_jitter = options["backoff_jitter"] - statuses = options["statuses"] - methods = frozenset(options["methods"]) - logging.debug( - f"Setting up retry adapter with {total_retries} retries," # noqa G003 - + f"connect retries {connect_retries}, read retries {read_retries}," - + f"backoff factor {backoff_factor}, statuses {statuses}, " - + f"methods {methods}" - ) - - return HTTPAdapter( - max_retries=Retry( - total=total_retries, - connect=connect_retries, - read=read_retries, - backoff_factor=backoff_factor, - backoff_jitter=backoff_jitter, - status_forcelist=statuses, - allowed_methods=methods, - ) - ) diff --git 
a/python/python/lancedb/remote/connection_timeout.py b/python/python/lancedb/remote/connection_timeout.py deleted file mode 100644 index f9d18e56..00000000 --- a/python/python/lancedb/remote/connection_timeout.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This module contains an adapter that will close connections if they have not been -# used before a certain timeout. This is necessary because some load balancers will -# close connections after a certain amount of time, but the request module may not yet -# have received the FIN/ACK and will try to reuse the connection. -# -# TODO some of the code here can be simplified if/when this PR is merged: -# https://github.com/urllib3/urllib3/pull/3275 - -import datetime -import logging -import os - -from requests.adapters import HTTPAdapter -from urllib3.connection import HTTPSConnection -from urllib3.connectionpool import HTTPSConnectionPool -from urllib3.poolmanager import PoolManager - - -def get_client_connection_timeout() -> int: - return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300")) - - -class LanceDBHTTPSConnection(HTTPSConnection): - """ - HTTPSConnection that tracks the last time it was used. 
- """ - - idle_timeout: datetime.timedelta - last_activity: datetime.datetime - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.last_activity = datetime.datetime.now() - - def request(self, *args, **kwargs): - self.last_activity = datetime.datetime.now() - super().request(*args, **kwargs) - - def is_expired(self): - return datetime.datetime.now() - self.last_activity > self.idle_timeout - - -def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int): - """ - Creates a connection pool class that can be used to close idle connections. - """ - - class LanceDBHTTPSConnectionPool(HTTPSConnectionPool): - # override the connection class - ConnectionCls = LanceDBHTTPSConnection - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _get_conn(self, timeout: float | None = None): - logging.debug("Getting https connection") - conn = super()._get_conn(timeout) - if conn.is_expired(): - logging.debug("Closing expired connection") - conn.close() - - return conn - - def _new_conn(self): - conn = super()._new_conn() - conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout) - return conn - - return LanceDBHTTPSConnectionPool - - -class LanceDBClientPoolManager(PoolManager): - def __init__( - self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs - ): - super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs) - # inject our connection pool impl - connection_pool_class = LanceDBHTTPSConnectionPoolFactory( - client_idle_timeout=client_idle_timeout - ) - self.pool_classes_by_scheme["https"] = connection_pool_class - - -def LanceDBClientHTTPAdapterFactory(): - """ - Creates an HTTPAdapter class that can be used to close idle connections - """ - - # closure over the timeout - client_idle_timeout = get_client_connection_timeout() - - class LanceDBClientRequestHTTPAdapter(HTTPAdapter): - def init_poolmanager(self, connections, maxsize, block=False): - # inject our pool manager impl - 
self.poolmanager = LanceDBClientPoolManager( - client_idle_timeout=client_idle_timeout, - num_pools=connections, - maxsize=maxsize, - block=block, - ) - - return LanceDBClientRequestHTTPAdapter diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index bb7554a4..51ef389e 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -11,13 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +from datetime import timedelta import logging -import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union from urllib.parse import urlparse +import warnings -from cachetools import TTLCache +from lancedb import connect_async +from lancedb.remote import ClientConfig import pyarrow as pa from overrides import override @@ -25,10 +28,8 @@ from ..common import DATA from ..db import DBConnection from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel -from ..table import Table, sanitize_create_table +from ..table import Table from ..util import validate_table_name -from .arrow import to_ipc_binary -from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient class RemoteDBConnection(DBConnection): @@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection): region: str, host_override: Optional[str] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None, - connection_timeout: float = 120.0, - read_timeout: float = 300.0, + client_config: Union[ClientConfig, Dict[str, Any], None] = None, + connection_timeout: Optional[float] = None, + read_timeout: Optional[float] = None, ): """Connect to a remote LanceDB database.""" + + if isinstance(client_config, dict): + client_config = ClientConfig(**client_config) + elif client_config is None: + client_config = ClientConfig() + + # These are legacy 
options from the old Python-based client. We keep them + # here for backwards compatibility, but will remove them in a future release. + if request_thread_pool is not None: + warnings.warn( + "request_thread_pool is no longer used and will be removed in " + "a future release.", + DeprecationWarning, + ) + + if connection_timeout is not None: + warnings.warn( + "connection_timeout is deprecated and will be removed in a future " + "release. Please use client_config.timeout_config.connect_timeout " + "instead.", + DeprecationWarning, + ) + client_config.timeout_config.connect_timeout = timedelta( + seconds=connection_timeout + ) + + if read_timeout is not None: + warnings.warn( + "read_timeout is deprecated and will be removed in a future release. " + "Please use client_config.timeout_config.read_timeout instead.", + DeprecationWarning, + ) + client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout) + parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") - self._uri = str(db_url) self.db_name = parsed.netloc - self.api_key = api_key - self._client = RestfulLanceDBClient( - self.db_name, - region, - api_key, - host_override, - connection_timeout=connection_timeout, - read_timeout=read_timeout, + + import nest_asyncio + + nest_asyncio.apply() + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self.client_config = client_config + + self._conn = self._loop.run_until_complete( + connect_async( + db_url, + api_key=api_key, + region=region, + host_override=host_override, + client_config=client_config, + ) ) - self._request_thread_pool = request_thread_pool - self._table_cache = TTLCache(maxsize=10000, ttl=300) def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" @@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection): ------- An iterator of table names. 
""" - while True: - result = self._client.list_tables(limit, page_token) - - if len(result) > 0: - page_token = result[len(result) - 1] - else: - break - for item in result: - self._table_cache[item] = True - yield item + return self._loop.run_until_complete( + self._conn.table_names(start_after=page_token, limit=limit) + ) @override def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: @@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection): """ from .table import RemoteTable - self._client.mount_retry_adapter_for_table(name) - if index_cache_size is not None: logging.info( "index_cache_size is ignored in LanceDb Cloud" " (there is no local cache to configure)" ) - # check if table exists - if self._table_cache.get(name) is None: - self._client.post(f"/v1/table/{name}/describe/") - self._table_cache[name] = True - - return RemoteTable(self, name) + table = self._loop.run_until_complete(self._conn.open_table(name)) + return RemoteTable(table, self.db_name, self._loop) @override def create_table( @@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection): "Please vote https://github.com/lancedb/lancedb/issues/626 " "for this feature." 
) - if mode is not None: - logging.warning("mode is not yet supported on LanceDB Cloud.") - - data, schema = sanitize_create_table( - data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value - ) from .table import RemoteTable - data = to_ipc_binary(data) - request_id = uuid.uuid4().hex - - self._client.post( - f"/v1/table/{name}/create/", - data=data, - request_id=request_id, - content_type=ARROW_STREAM_CONTENT_TYPE, + table = self._loop.run_until_complete( + self._conn.create_table( + name, + data, + mode=mode, + schema=schema, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) ) - - self._table_cache[name] = True - return RemoteTable(self, name) + return RemoteTable(table, self.db_name, self._loop) @override def drop_table(self, name: str): @@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection): name: str The name of the table. """ - - self._client.post( - f"/v1/table/{name}/drop/", - ) - self._table_cache.pop(name, default=None) + self._loop.run_until_complete(self._conn.drop_table(name)) @override def rename_table(self, cur_name: str, new_name: str): @@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection): new_name: str The new name of the table. """ - self._client.post( - f"/v1/table/{cur_name}/rename/", - data={"new_table_name": new_name}, - ) - self._table_cache.pop(cur_name, default=None) - self._table_cache[new_name] = True + self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name)) async def close(self): """Close the connection to the database.""" diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 986fbced..e2d88b98 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -11,53 +11,56 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging -import uuid -from concurrent.futures import Future from functools import cached_property from typing import Dict, Iterable, List, Optional, Union, Literal +from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList import pyarrow as pa -from lance import json_to_schema from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.merge import LanceMergeInsertBuilder from lancedb.embeddings import EmbeddingFunctionRegistry from ..query import LanceVectorQueryBuilder, LanceQueryBuilder -from ..table import Query, Table, _sanitize_data -from ..util import value_to_sql, infer_vector_column_name -from .arrow import to_ipc_binary -from .client import ARROW_STREAM_CONTENT_TYPE -from .db import RemoteDBConnection +from ..table import AsyncTable, Query, Table class RemoteTable(Table): - def __init__(self, conn: RemoteDBConnection, name: str): - self._conn = conn - self.name = name + def __init__( + self, + table: AsyncTable, + db_name: str, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._loop = loop + self._table = table + self.db_name = db_name + + @property + def name(self) -> str: + """The name of the table""" + return self._table.name def __repr__(self) -> str: - return f"RemoteTable({self._conn.db_name}.{self.name})" + return f"RemoteTable({self.db_name}.{self.name})" def __len__(self) -> int: self.count_rows(None) - @cached_property + @property def schema(self) -> pa.Schema: """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of this Table """ - resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") - schema = json_to_schema(resp["schema"]) - return schema + return self._loop.run_until_complete(self._table.schema()) @property def version(self) -> int: """Get the current version of the table""" - resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") - return resp["version"] + return self._loop.run_until_complete(self._table.version()) 
@cached_property def embedding_functions(self) -> dict: @@ -84,20 +87,18 @@ class RemoteTable(Table): def list_indices(self): """List all the indices on the table""" - resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/") - return resp + return self._loop.run_until_complete(self._table.list_indices()) def index_stats(self, index_uuid: str): """List all the stats of a specified index""" - resp = self._conn._client.post( - f"/v1/table/{self.name}/index/{index_uuid}/stats/" - ) - return resp + return self._loop.run_until_complete(self._table.index_stats(index_uuid)) def create_scalar_index( self, column: str, index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar", + *, + replace: bool = False, ): """Creates a scalar index Parameters @@ -107,20 +108,23 @@ class RemoteTable(Table): or string column. index_type : str The index type of the scalar index. Must be "scalar" (BTREE), - "BTREE", "BITMAP", or "LABEL_LIST" + "BTREE", "BITMAP", or "LABEL_LIST", + replace : bool + If True, replace the existing index with the new one. 
""" + if index_type == "scalar" or index_type == "BTREE": + config = BTree() + elif index_type == "BITMAP": + config = Bitmap() + elif index_type == "LABEL_LIST": + config = LabelList() + else: + raise ValueError(f"Unknown index type: {index_type}") - data = { - "column": column, - "index_type": index_type, - "replace": True, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_scalar_index/", data=data + self._loop.run_until_complete( + self._table.create_index(column, config=config, replace=replace) ) - return resp - def create_fts_index( self, column: str, @@ -128,15 +132,10 @@ class RemoteTable(Table): replace: bool = False, with_position: bool = True, ): - data = { - "column": column, - "index_type": "FTS", - "replace": replace, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_index/", data=data + config = FTS(with_position=with_position) + self._loop.run_until_complete( + self._table.create_index(column, config=config, replace=replace) ) - return resp def create_index( self, @@ -204,17 +203,22 @@ class RemoteTable(Table): "Existing indexes will always be replaced." ) - data = { - "column": vector_column_name, - "index_type": index_type, - "metric_type": metric, - "index_cache_size": index_cache_size, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_index/", data=data - ) + index_type = index_type.upper() + if index_type == "VECTOR" or index_type == "IVF_PQ": + config = IvfPq(distance_type=metric) + elif index_type == "IVF_HNSW_PQ": + config = HnswPq(distance_type=metric) + elif index_type == "IVF_HNSW_SQ": + config = HnswSq(distance_type=metric) + else: + raise ValueError( + f"Unknown vector index type: {index_type}. Valid options are" + " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" + ) - return resp + self._loop.run_until_complete( + self._table.create_index(vector_column_name, config=config) + ) def add( self, @@ -246,22 +250,10 @@ class RemoteTable(Table): The value to use when filling vectors. 
Only used if on_bad_vectors="fill". """ - data, _ = _sanitize_data( - data, - self.schema, - metadata=self.schema.metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - payload = to_ipc_binary(data) - - request_id = uuid.uuid4().hex - - self._conn._client.post( - f"/v1/table/{self.name}/insert/", - data=payload, - params={"request_id": request_id, "mode": mode}, - content_type=ARROW_STREAM_CONTENT_TYPE, + self._loop.run_until_complete( + self._table.add( + data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) ) def search( @@ -337,12 +329,6 @@ class RemoteTable(Table): # empty query builder is not supported in saas, raise error if query is None and query_type != "hybrid": raise ValueError("Empty query is not supported") - vector_column_name = infer_vector_column_name( - schema=self.schema, - query_type=query_type, - query=query, - vector_column_name=vector_column_name, - ) return LanceQueryBuilder.create( self, @@ -356,37 +342,9 @@ class RemoteTable(Table): def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - if ( - query.vector is not None - and len(query.vector) > 0 - and not isinstance(query.vector[0], float) - ): - if self._conn._request_thread_pool is None: - - def submit(name, q): - f = Future() - f.set_result(self._conn._client.query(name, q)) - return f - - else: - - def submit(name, q): - return self._conn._request_thread_pool.submit( - self._conn._client.query, name, q - ) - - results = [] - for v in query.vector: - v = list(v) - q = query.copy() - q.vector = v - results.append(submit(self.name, q)) - return pa.concat_tables( - [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)] - ).to_reader() - else: - result = self._conn._client.query(self.name, query) - return result.to_arrow().to_reader() + return self._loop.run_until_complete( + self._table._execute_query(query, batch_size=batch_size) + ) def merge_insert(self, on: Union[str, Iterable[str]]) -> 
LanceMergeInsertBuilder: """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] @@ -403,42 +361,8 @@ class RemoteTable(Table): on_bad_vectors: str, fill_value: float, ): - data, _ = _sanitize_data( - new_data, - self.schema, - metadata=None, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - payload = to_ipc_binary(data) - - params = {} - if len(merge._on) != 1: - raise ValueError( - "RemoteTable only supports a single on key in merge_insert" - ) - params["on"] = merge._on[0] - params["when_matched_update_all"] = str(merge._when_matched_update_all).lower() - if merge._when_matched_update_all_condition is not None: - params["when_matched_update_all_filt"] = ( - merge._when_matched_update_all_condition - ) - params["when_not_matched_insert_all"] = str( - merge._when_not_matched_insert_all - ).lower() - params["when_not_matched_by_source_delete"] = str( - merge._when_not_matched_by_source_delete - ).lower() - if merge._when_not_matched_by_source_condition is not None: - params["when_not_matched_by_source_delete_filt"] = ( - merge._when_not_matched_by_source_condition - ) - - self._conn._client.post( - f"/v1/table/{self.name}/merge_insert/", - data=payload, - params=params, - content_type=ARROW_STREAM_CONTENT_TYPE, + self._loop.run_until_complete( + self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) ) def delete(self, predicate: str): @@ -488,8 +412,7 @@ class RemoteTable(Table): x vector _distance # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ - payload = {"predicate": predicate} - self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload) + self._loop.run_until_complete(self._table.delete(predicate)) def update( self, @@ -539,18 +462,9 @@ class RemoteTable(Table): 2 2 [10.0, 10.0] # doctest: +SKIP """ - if values is not None and values_sql is not None: - raise ValueError("Only one of values or values_sql can be provided") - if values is None and values_sql is None: - raise 
ValueError("Either values or values_sql must be provided") - - if values is not None: - updates = [[k, value_to_sql(v)] for k, v in values.items()] - else: - updates = [[k, v] for k, v in values_sql.items()] - - payload = {"predicate": where, "updates": updates} - self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload) + self._loop.run_until_complete( + self._table.update(where=where, updates=values, updates_sql=values_sql) + ) def cleanup_old_versions(self, *_): """cleanup_old_versions() is not supported on the LanceDB cloud""" @@ -565,11 +479,7 @@ class RemoteTable(Table): ) def count_rows(self, filter: Optional[str] = None) -> int: - payload = {"predicate": filter} - resp = self._conn._client.post( - f"/v1/table/{self.name}/count_rows/", data=payload - ) - return resp + return self._loop.run_until_complete(self._table.count_rows(filter)) def add_columns(self, transforms: Dict[str, str]): raise NotImplementedError( diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 6be646bd..c44355cf 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -12,7 +12,6 @@ # limitations under the License. import os -import requests from functools import cached_property from typing import Union @@ -57,6 +56,8 @@ class JinaReranker(Reranker): @cached_property def _client(self): + import requests + if os.environ.get("JINA_API_KEY") is None and self.api_key is None: raise ValueError( "JINA_API_KEY not set. 
Either set it in your environment or \ diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 59dc4487..18e2c266 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -62,7 +62,7 @@ if TYPE_CHECKING: from lance.dataset import CleanupStats, ReaderLike from ._lancedb import Table as LanceDBTable, OptimizeStats from .db import LanceDBConnection - from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS + from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq pd = safe_import_pandas() pl = safe_import_polars() @@ -948,7 +948,9 @@ class Table(ABC): return _table_uri(self._conn.uri, self.name) def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]: - if get_uri_scheme(self._dataset_uri) != "file": + from .remote.table import RemoteTable + + if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file": return ("", None, False) path = join_uri(self._dataset_uri, "_indices", "fts") fs, path = fs_from_uri(path) @@ -2382,7 +2384,9 @@ class AsyncTable: column: str, *, replace: Optional[bool] = None, - config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None, + config: Optional[ + Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] + ] = None, ): """Create an index to speed up queries @@ -2535,7 +2539,44 @@ class AsyncTable: async def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - pass + # The sync remote table calls into this method, so we need to map the + # query to the async version of the query and run that here. This is only + # used for that code path right now. 
+ async_query = self.query().limit(query.k) + if query.offset > 0: + async_query = async_query.offset(query.offset) + if query.columns: + async_query = async_query.select(query.columns) + if query.filter: + async_query = async_query.where(query.filter) + if query.fast_search: + async_query = async_query.fast_search() + if query.with_row_id: + async_query = async_query.with_row_id() + + if query.vector: + async_query = ( + async_query.nearest_to(query.vector) + .distance_type(query.metric) + .nprobes(query.nprobes) + ) + if query.refine_factor: + async_query = async_query.refine_factor(query.refine_factor) + if query.vector_column: + async_query = async_query.column(query.vector_column) + + if not query.prefilter: + async_query = async_query.postfilter() + + if isinstance(query.full_text_query, str): + async_query = async_query.nearest_to_text(query.full_text_query) + elif isinstance(query.full_text_query, dict): + fts_query = query.full_text_query["query"] + fts_columns = query.full_text_query.get("columns", []) or [] + async_query = async_query.nearest_to_text(fts_query, columns=fts_columns) + + table = await async_query.to_arrow() + return table.to_reader() async def _do_merge( self, @@ -2781,7 +2822,7 @@ class AsyncTable: cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000) return await self._inner.optimize(cleanup_older_than, delete_unverified) - async def list_indices(self) -> IndexConfig: + async def list_indices(self) -> Iterable[IndexConfig]: """ List all indices that have been created with Self::create_index """ @@ -2865,3 +2906,8 @@ class IndexStatistics: ] distance_type: Optional[Literal["l2", "cosine", "dot"]] = None num_indices: Optional[int] = None + + # This exists for backwards compatibility with an older API, which returned + # a dictionary instead of a class. 
+ def __getitem__(self, key): + return getattr(self, key) diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 87b2e249..9e17ca66 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -18,7 +18,6 @@ import lancedb import numpy as np import pandas as pd import pytest -import requests from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector @@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path): @pytest.mark.slow def test_openclip(tmp_path): + import requests from PIL import Image db = lancedb.connect(tmp_path) diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index 1245997e..3268179b 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable): # Can recreate if replace=True await some_table.create_index("id", replace=True) indices = await some_table.list_indices() - assert str(indices) == '[Index(BTree, columns=["id"])]' + assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]' assert len(indices) == 1 assert indices[0].index_type == "BTree" assert indices[0].columns == ["id"] @@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable): async def test_create_bitmap_index(some_table: AsyncTable): await some_table.create_index("id", config=Bitmap()) indices = await some_table.list_indices() - assert str(indices) == '[Index(Bitmap, columns=["id"])]' + assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]' indices = await some_table.list_indices() assert len(indices) == 1 index_name = indices[0].name @@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable): async def test_create_label_list_index(some_table: AsyncTable): await some_table.create_index("tags", config=LabelList()) indices = await 
some_table.list_indices() - assert str(indices) == '[Index(LabelList, columns=["tags"])]' + assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]' @pytest.mark.asyncio diff --git a/python/python/tests/test_remote_client.py b/python/python/tests/test_remote_client.py deleted file mode 100644 index f5874953..00000000 --- a/python/python/tests/test_remote_client.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import attrs -import numpy as np -import pandas as pd -import pyarrow as pa -import pytest -from aiohttp import web -from lancedb.remote.client import RestfulLanceDBClient, VectorQuery - - -@attrs.define -class MockLanceDBServer: - runner: web.AppRunner = attrs.field(init=False) - site: web.TCPSite = attrs.field(init=False) - - async def query_handler(self, request: web.Request) -> web.Response: - table_name = request.match_info["table_name"] - assert table_name == "test_table" - - await request.json() - # TODO: do some matching - - vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector") - ids = pd.Series(range(10), name="id") - df = pd.DataFrame([vecs, ids]).T - - batch = pa.RecordBatch.from_pandas( - df, - schema=pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 128)), - pa.field("id", pa.int64()), - ] - ), - ) - - sink = pa.BufferOutputStream() - with pa.ipc.new_file(sink, batch.schema) as writer: - writer.write_batch(batch) - - return web.Response(body=sink.getvalue().to_pybytes()) - - async def setup(self): - app = web.Application() - app.add_routes([web.post("/table/{table_name}", self.query_handler)]) - self.runner = web.AppRunner(app) - await self.runner.setup() - self.site = web.TCPSite(self.runner, "localhost", 8111) - - async def start(self): - await self.site.start() - - async def stop(self): - await self.runner.cleanup() - - -@pytest.mark.skip(reason="flaky somehow, fix later") -@pytest.mark.asyncio -async def test_e2e_with_mock_server(): - mock_server = MockLanceDBServer() - await mock_server.setup() - await mock_server.start() - - try: - with RestfulLanceDBClient("lancedb+http://localhost:8111") as client: - df = ( - await client.query( - "test_table", - VectorQuery( - vector=np.random.rand(128).tolist(), - k=10, - _metric="L2", - columns=["id", "vector"], - ), - ) - ).to_pandas() - - assert "vector" in df.columns - assert "id" in df.columns - - assert client.closed - finally: - # make sure we don't leak resources - 
await mock_server.stop() diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index e03b6636..bc3a2783 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -2,91 +2,19 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import contextlib +from datetime import timedelta import http.server +import json import threading from unittest.mock import MagicMock import uuid import lancedb +from lancedb.conftest import MockTextEmbeddingFunction +from lancedb.remote import ClientConfig from lancedb.remote.errors import HttpError, RetryError -import pyarrow as pa -from lancedb.remote.client import VectorQuery, VectorQueryResult import pytest - - -class FakeLanceDBClient: - def close(self): - pass - - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - assert table_name == "test" - t = pa.schema([]).empty_table() - return VectorQueryResult(t) - - def post(self, path: str): - pass - - def mount_retry_adapter_for_table(self, table_name: str): - pass - - -def test_remote_db(): - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - setattr(conn, "_client", FakeLanceDBClient()) - - table = conn["test"] - table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) - table.search([1.0, 2.0]).to_pandas() - - -def test_create_empty_table(): - client = MagicMock() - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - - conn._client = client - - schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) - - client.post.return_value = {"status": "ok"} - table = conn.create_table("test", schema=schema) - assert table.name == "test" - assert client.post.call_args[0][0] == "/v1/table/test/create/" - - json_schema = { - "fields": [ - { - "name": "vector", - "nullable": True, - "type": { - "type": "fixed_size_list", - "fields": [ - {"name": "item", "nullable": True, "type": {"type": "float"}} - ], - "length": 2, - }, - 
}, - ] - } - client.post.return_value = {"schema": json_schema} - assert table.schema == schema - assert client.post.call_args[0][0] == "/v1/table/test/describe/" - - client.post.return_value = 0 - assert table.count_rows(None) == 0 - - -def test_create_table_with_recordbatches(): - client = MagicMock() - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - - conn._client = client - - batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"]) - - client.post.return_value = {"status": "ok"} - table = conn.create_table("test", [batch], schema=batch.schema) - assert table.name == "test" - assert client.post.call_args[0][0] == "/v1/table/test/create/" +import pyarrow as pa def make_mock_http_handler(handler): @@ -100,8 +28,35 @@ def make_mock_http_handler(handler): return MockLanceDBHandler +@contextlib.contextmanager +def mock_lancedb_connection(handler): + with http.server.HTTPServer( + ("localhost", 8080), make_mock_http_handler(handler) + ) as server: + handle = threading.Thread(target=server.serve_forever) + handle.start() + + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override="http://localhost:8080", + client_config={ + "retry_config": {"retries": 2}, + "timeout_config": { + "connect_timeout": 1, + }, + }, + ) + + try: + yield db + finally: + server.shutdown() + handle.join() + + @contextlib.asynccontextmanager -async def mock_lancedb_connection(handler): +async def mock_lancedb_connection_async(handler): with http.server.HTTPServer( ("localhost", 8080), make_mock_http_handler(handler) ) as server: @@ -143,7 +98,7 @@ async def test_async_remote_db(): request.end_headers() request.wfile.write(b'{"tables": []}') - async with mock_lancedb_connection(handler) as db: + async with mock_lancedb_connection_async(handler) as db: table_names = await db.table_names() assert table_names == [] @@ -159,12 +114,12 @@ async def test_http_error(): request.end_headers() request.wfile.write(b"Internal Server Error") - 
async with mock_lancedb_connection(handler) as db: - with pytest.raises(HttpError, match="Internal Server Error") as exc_info: + async with mock_lancedb_connection_async(handler) as db: + with pytest.raises(HttpError) as exc_info: await db.table_names() assert exc_info.value.request_id == request_id_holder["request_id"] - assert exc_info.value.status_code == 507 + assert "Internal Server Error" in str(exc_info.value) @pytest.mark.asyncio @@ -178,15 +133,225 @@ async def test_retry_error(): request.end_headers() request.wfile.write(b"Try again later") - async with mock_lancedb_connection(handler) as db: - with pytest.raises(RetryError, match="Hit retry limit") as exc_info: + async with mock_lancedb_connection_async(handler) as db: + with pytest.raises(RetryError) as exc_info: await db.table_names() assert exc_info.value.request_id == request_id_holder["request_id"] - assert exc_info.value.status_code == 429 cause = exc_info.value.__cause__ assert isinstance(cause, HttpError) assert "Try again later" in str(cause) assert cause.request_id == request_id_holder["request_id"] assert cause.status_code == 429 + + +@contextlib.contextmanager +def query_test_table(query_handler): + def handler(request): + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + elif request.path == "/v1/table/test/query/": + content_len = int(request.headers.get("Content-Length")) + body = request.rfile.read(content_len) + body = json.loads(body) + + data = query_handler(body) + + request.send_response(200) + request.send_header("Content-Type", "application/vnd.apache.arrow.file") + request.end_headers() + + with pa.ipc.new_file(request.wfile, schema=data.schema) as f: + f.write_table(data) + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + assert repr(db) == "RemoteConnect(name=dev)" + table = 
db.open_table("test") + assert repr(table) == "RemoteTable(dev.test)" + yield table + + +def test_query_sync_minimal(): + def handler(body): + assert body == { + "distance_type": "l2", + "k": 10, + "prefilter": False, + "refine_factor": None, + "vector": [1.0, 2.0, 3.0], + "nprobes": 20, + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + data = table.search([1, 2, 3]).to_list() + expected = [{"id": 1}, {"id": 2}, {"id": 3}] + assert data == expected + + +def test_query_sync_maximal(): + def handler(body): + assert body == { + "distance_type": "cosine", + "k": 42, + "prefilter": True, + "refine_factor": 10, + "vector": [1.0, 2.0, 3.0], + "nprobes": 5, + "filter": "id > 0", + "columns": ["id", "name"], + "vector_column": "vector2", + "fast_search": True, + "with_row_id": True, + } + + return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + + with query_test_table(handler) as table: + ( + table.search([1, 2, 3], vector_column_name="vector2", fast_search=True) + .metric("cosine") + .limit(42) + .refine_factor(10) + .nprobes(5) + .where("id > 0", prefilter=True) + .with_row_id(True) + .select(["id", "name"]) + .to_list() + ) + + +def test_query_sync_fts(): + def handler(body): + assert body == { + "full_text_query": { + "query": "puppy", + "columns": [], + }, + "k": 10, + "vector": [], + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + (table.search("puppy", query_type="fts").to_list()) + + def handler(body): + assert body == { + "full_text_query": { + "query": "puppy", + "columns": ["name", "description"], + }, + "k": 42, + "vector": [], + "with_row_id": True, + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + ( + table.search("puppy", query_type="fts", fts_columns=["name", "description"]) + .with_row_id(True) + .limit(42) + .to_list() + ) + + +def test_query_sync_hybrid(): + def handler(body): + if "full_text_query" in body: + # FTS query + assert 
body == { + "full_text_query": { + "query": "puppy", + "columns": [], + }, + "k": 42, + "vector": [], + "with_row_id": True, + } + return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]}) + else: + # Vector query + assert body == { + "distance_type": "l2", + "k": 42, + "prefilter": False, + "refine_factor": None, + "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "nprobes": 20, + "with_row_id": True, + } + return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]}) + + with query_test_table(handler) as table: + embedding_func = MockTextEmbeddingFunction() + embedding_config = MagicMock() + embedding_config.function = embedding_func + + embedding_funcs = MagicMock() + embedding_funcs.get = MagicMock(return_value=embedding_config) + table.embedding_functions = embedding_funcs + + (table.search("puppy", query_type="hybrid").limit(42).to_list()) + + +def test_create_client(): + mandatory_args = { + "uri": "db://dev", + "api_key": "fake-api-key", + "region": "us-east-1", + } + + db = lancedb.connect(**mandatory_args) + assert isinstance(db.client_config, ClientConfig) + + db = lancedb.connect(**mandatory_args, client_config={}) + assert isinstance(db.client_config, ClientConfig) + + db = lancedb.connect( + **mandatory_args, + client_config=ClientConfig(timeout_config={"connect_timeout": 42}), + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42) + + db = lancedb.connect( + **mandatory_args, + client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}}, + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42) + + db = lancedb.connect( + **mandatory_args, client_config=ClientConfig(retry_config={"retries": 42}) + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.retry_config.retries == 42 + + db = lancedb.connect( + 
**mandatory_args, client_config={"retry_config": {"retries": 42}} + ) + assert isinstance(db.client_config, ClientConfig) + assert db.client_config.retry_config.retries == 42 + + with pytest.warns(DeprecationWarning): + db = lancedb.connect(**mandatory_args, connection_timeout=42) + assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42) + + with pytest.warns(DeprecationWarning): + db = lancedb.connect(**mandatory_args, read_timeout=42) + assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42) + + with pytest.warns(DeprecationWarning): + lancedb.connect(**mandatory_args, request_thread_pool=10) diff --git a/python/src/connection.rs b/python/src/connection.rs index 200285a4..46e15cfb 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -170,6 +170,17 @@ impl Connection { }) } + pub fn rename_table( + self_: PyRef<'_, Self>, + old_name: String, + new_name: String, + ) -> PyResult> { + let inner = self_.get_inner()?.clone(); + future_into_py(self_.py(), async move { + inner.rename_table(old_name, new_name).await.infer_error() + }) + } + pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult> { let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { diff --git a/python/src/index.rs b/python/src/index.rs index 7510b7fe..4ea4c19f 100644 --- a/python/src/index.rs +++ b/python/src/index.rs @@ -24,8 +24,8 @@ use lancedb::{ DistanceType, }; use pyo3::{ - exceptions::{PyRuntimeError, PyValueError}, - pyclass, pymethods, PyResult, + exceptions::{PyKeyError, PyRuntimeError, PyValueError}, + pyclass, pymethods, IntoPy, PyObject, PyResult, Python, }; use crate::util::parse_distance_type; @@ -236,7 +236,21 @@ pub struct IndexConfig { #[pymethods] impl IndexConfig { pub fn __repr__(&self) -> String { - format!("Index({}, columns={:?})", self.index_type, self.columns) + format!( + "Index({}, columns={:?}, name=\"{}\")", + self.index_type, self.columns, self.name + ) + } + + // For 
backwards-compatibility with the old sync SDK, we also support getting + // attributes via __getitem__. + pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult { + match key.as_str() { + "index_type" => Ok(self.index_type.clone().into_py(py)), + "columns" => Ok(self.columns.clone().into_py(py)), + "name" | "index_name" => Ok(self.name.clone().into_py(py)), + _ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))), + } } } diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 44a6b443..40329b66 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -39,9 +39,6 @@ use crate::utils::validate_table_name; use crate::Table; pub use lance_encoding::version::LanceFileVersion; -#[cfg(feature = "remote")] -use log::warn; - pub const LANCE_FILE_EXTENSION: &str = "lance"; pub type TableBuilderCallback = Box OpenTableBuilder + Send>; @@ -719,8 +716,7 @@ impl ConnectBuilder { let api_key = self.api_key.ok_or_else(|| Error::InvalidInput { message: "An api_key is required when connecting to LanceDb Cloud".to_string(), })?; - // TODO: remove this warning when the remote client is ready - warn!("The rust implementation of the remote client is not yet ready for use."); + let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( &self.uri, &api_key,