From 239f725b320467fff24610c4653bb307e36fb14e Mon Sep 17 00:00:00 2001 From: Bert Date: Thu, 5 Dec 2024 14:54:39 -0500 Subject: [PATCH] feat(python)!: async-sync feature parity on Connections (#1905) Closes #1791 Closes #1764 Closes #1897 (Makes this unnecessary) BREAKING CHANGE: when using azure connection string `az://...` the call to connect will fail if the azure storage credentials are not set. this is breaking from the previous behaviour where the call would fail after connect, when user invokes methods on the connection. --- docs/src/guides/storage.md | 8 +- python/python/lancedb/__init__.py | 10 ++- .../lancedb/{remote => }/background_loop.py | 0 python/python/lancedb/db.py | 83 +++++++++---------- python/python/lancedb/remote/db.py | 5 +- python/python/lancedb/table.py | 42 +++++++--- python/python/tests/test_table.py | 1 + 7 files changed, 83 insertions(+), 66 deletions(-) rename python/python/lancedb/{remote => }/background_loop.py (100%) diff --git a/docs/src/guides/storage.md b/docs/src/guides/storage.md index f4a7904b..88cef2df 100644 --- a/docs/src/guides/storage.md +++ b/docs/src/guides/storage.md @@ -27,10 +27,13 @@ LanceDB OSS supports object stores such as AWS S3 (and compatible stores), Azure Azure Blob Storage: + ```python import lancedb db = lancedb.connect("az://bucket/path") ``` + Note that for Azure, storage credentials must be configured. See [below](#azure-blob-storage) for more details. + === "TypeScript" @@ -87,11 +90,6 @@ In most cases, when running in the respective cloud and permissions are set up c export TIMEOUT=60s ``` -!!! note "`storage_options` availability" - - The `storage_options` parameter is only available in Python *async* API and JavaScript API. - It is not yet supported in the Python synchronous API. - If you only want this to apply to one particular connection, you can pass the `storage_options` argument when opening the connection: === "Python" diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 2c5e521d..fb266a01 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -36,6 +36,7 @@ def connect( read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, client_config: Union[ClientConfig, Dict[str, Any], None] = None, + storage_options: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> DBConnection: """Connect to a LanceDB database. @@ -67,6 +68,9 @@ def connect( Configuration options for the LanceDB Cloud HTTP client. If a dict, then the keys are the attributes of the ClientConfig class. If None, then the default configuration is used. + storage_options: dict, optional + Additional options for the storage backend. See available options at + https://lancedb.github.io/lancedb/guides/storage/ Examples -------- @@ -111,7 +115,11 @@ def connect( if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") - return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) + return LanceDBConnection( + uri, + read_consistency_interval=read_consistency_interval, + storage_options=storage_options, + ) async def connect_async( diff --git a/python/python/lancedb/remote/background_loop.py b/python/python/lancedb/background_loop.py similarity index 100% rename from python/python/lancedb/remote/background_loop.py rename to python/python/lancedb/background_loop.py diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 0a9e27d8..ab75ec1f 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -13,34 +13,29 @@ from __future__ import annotations -import asyncio -import os from abc import abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union -import pyarrow as pa from overrides import EnforceOverrides, override -from pyarrow import fs -from lancedb.common import data_to_reader, validate_schema +from lancedb.common import data_to_reader, sanitize_uri, validate_schema +from lancedb.background_loop import BackgroundEventLoop from ._lancedb import connect as lancedb_connect from .table import ( AsyncTable, LanceTable, Table, - _table_path, sanitize_create_table, ) from .util import ( - fs_from_uri, - get_uri_location, get_uri_scheme, validate_table_name, ) if TYPE_CHECKING: + import pyarrow as pa from .pydantic import LanceModel from datetime import timedelta @@ -48,6 +43,8 @@ if TYPE_CHECKING: from .common import DATA, URI from .embeddings import EmbeddingFunctionConfig +LOOP = BackgroundEventLoop() + class DBConnection(EnforceOverrides): """An active LanceDB connection interface.""" @@ -180,6 +177,7 @@ class DBConnection(EnforceOverrides): control over how data is saved, either provide the PyArrow schema to convert to or else provide a [PyArrow Table](pyarrow.Table) directly. + >>> import pyarrow as pa >>> custom_schema = pa.schema([ ... pa.field("vector", pa.list_(pa.float32(), 2)), ... pa.field("lat", pa.float32()), @@ -327,7 +325,11 @@ class LanceDBConnection(DBConnection): """ def __init__( - self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None + self, + uri: URI, + *, + read_consistency_interval: Optional[timedelta] = None, + storage_options: Optional[Dict[str, str]] = None, ): if not isinstance(uri, Path): scheme = get_uri_scheme(uri) @@ -338,9 +340,27 @@ class LanceDBConnection(DBConnection): uri = uri.expanduser().absolute() Path(uri).mkdir(parents=True, exist_ok=True) self._uri = str(uri) - self._entered = False self.read_consistency_interval = read_consistency_interval + self.storage_options = storage_options + + if read_consistency_interval is not None: + read_consistency_interval_secs = read_consistency_interval.total_seconds() + else: + read_consistency_interval_secs = None + + async def do_connect(): + return await lancedb_connect( + sanitize_uri(uri), + None, + None, + None, + read_consistency_interval_secs, + None, + storage_options, + ) + + self._conn = AsyncConnection(LOOP.run(do_connect())) def __repr__(self) -> str: val = f"{self.__class__.__name__}({self._uri}" @@ -364,32 +384,7 @@ class LanceDBConnection(DBConnection): Iterator of str. A list of table names. """ - try: - asyncio.get_running_loop() - # User application is async. Soon we will just tell them to use the - # async version. Until then fallback to the old sync implementation. - try: - filesystem = fs_from_uri(self.uri)[0] - except pa.ArrowInvalid: - raise NotImplementedError("Unsupported scheme: " + self.uri) - - try: - loc = get_uri_location(self.uri) - paths = filesystem.get_file_info(fs.FileSelector(loc)) - except FileNotFoundError: - # It is ok if the file does not exist since it will be created - paths = [] - tables = [ - os.path.splitext(file_info.base_name)[0] - for file_info in paths - if file_info.extension == "lance" - ] - tables.sort() - return tables - except RuntimeError: - # User application is sync. It is safe to use the async implementation - # under the hood. - return asyncio.run(self._async_get_table_names(page_token, limit)) + return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit)) def __len__(self) -> int: return len(self.table_names()) @@ -461,19 +456,16 @@ class LanceDBConnection(DBConnection): If True, ignore if the table does not exist. """ try: - table_uri = _table_path(self.uri, name) - filesystem, path = fs_from_uri(table_uri) - filesystem.delete_dir(path) - except FileNotFoundError: + LOOP.run(self._conn.drop_table(name)) + except ValueError as e: if not ignore_missing: - raise + raise e + if f"Table '{name}' was not found" not in str(e): + raise e @override def drop_database(self): - dummy_table_uri = _table_path(self.uri, "dummy") - uri = dummy_table_uri.removesuffix("dummy.lance") - filesystem, path = fs_from_uri(uri) - filesystem.delete_dir(path) + LOOP.run(self._conn.drop_database()) class AsyncConnection(object): @@ -689,6 +681,7 @@ class AsyncConnection(object): control over how data is saved, either provide the PyArrow schema to convert to or else provide a [PyArrow Table](pyarrow.Table) directly. + >>> import pyarrow as pa >>> custom_schema = pa.schema([ ... pa.field("vector", pa.list_(pa.float32(), 2)), ... pa.field("lat", pa.float32()), diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index a1281739..d79aacd9 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -20,19 +20,16 @@ import warnings from lancedb import connect_async from lancedb.remote import ClientConfig -from lancedb.remote.background_loop import BackgroundEventLoop import pyarrow as pa from overrides import override from ..common import DATA -from ..db import DBConnection +from ..db import DBConnection, LOOP from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel from ..table import Table from ..util import validate_table_name -LOOP = BackgroundEventLoop() - class RemoteDBConnection(DBConnection): """A connection to a remote LanceDB database.""" diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 4edb4aa3..45388ab8 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1077,13 +1077,16 @@ class _LanceLatestDatasetRef(_LanceDatasetRef): index_cache_size: Optional[int] = None read_consistency_interval: Optional[timedelta] = None last_consistency_check: Optional[float] = None + storage_options: Optional[Dict[str, str]] = None _dataset: Optional[LanceDataset] = None @property def dataset(self) -> LanceDataset: if not self._dataset: self._dataset = lance.dataset( - self.uri, index_cache_size=self.index_cache_size + self.uri, + index_cache_size=self.index_cache_size, + storage_options=self.storage_options, ) self.last_consistency_check = time.monotonic() elif self.read_consistency_interval is not None: @@ -1114,13 +1117,17 @@ class _LanceTimeTravelRef(_LanceDatasetRef): uri: str version: int index_cache_size: Optional[int] = None + storage_options: Optional[Dict[str, str]] = None _dataset: Optional[LanceDataset] = None @property def dataset(self) -> LanceDataset: if not self._dataset: self._dataset = lance.dataset( - self.uri, version=self.version, index_cache_size=self.index_cache_size + self.uri, + version=self.version, + index_cache_size=self.index_cache_size, + storage_options=self.storage_options, ) return self._dataset @@ -1169,24 +1176,27 @@ class LanceTable(Table): uri=self._dataset_uri, version=version, index_cache_size=index_cache_size, + storage_options=connection.storage_options, ) else: self._ref = _LanceLatestDatasetRef( uri=self._dataset_uri, read_consistency_interval=connection.read_consistency_interval, index_cache_size=index_cache_size, + storage_options=connection.storage_options, ) @classmethod def open(cls, db, name, **kwargs): tbl = cls(db, name, **kwargs) - fs, path = fs_from_uri(tbl._dataset_path) - file_info = fs.get_file_info(path) - if file_info.type != pa.fs.FileType.Directory: - raise FileNotFoundError( - f"Table {name} does not exist." - f"Please first call db.create_table({name}, data)" - ) + + # check the dataset exists + try: + tbl.version + except ValueError as e: + if "Not found:" in str(e): + raise FileNotFoundError(f"Table {name} does not exist") + raise e return tbl @@ -1617,7 +1627,11 @@ class LanceTable(Table): # Access the dataset_mut property to ensure that the dataset is mutable. self._ref.dataset_mut self._ref.dataset = lance.write_dataset( - data, self._dataset_uri, schema=self.schema, mode=mode + data, + self._dataset_uri, + schema=self.schema, + mode=mode, + storage_options=self._ref.storage_options, ) def merge( @@ -1902,7 +1916,13 @@ class LanceTable(Table): empty = pa.Table.from_batches([], schema=schema) try: - lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode) + lance.write_dataset( + empty, + tbl._dataset_uri, + schema=schema, + mode=mode, + storage_options=db.storage_options, + ) except OSError as err: if "Dataset already exists" in str(err) and exist_ok: if tbl.schema != schema: diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index a4dd2e9d..d1b44c50 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -30,6 +30,7 @@ class MockDB: def __init__(self, uri: Path): self.uri = str(uri) self.read_consistency_interval = None + self.storage_options = None @functools.cached_property def is_managed_remote(self) -> bool: