diff --git a/python/README.md b/python/README.md
index 290447cd..94e27d6a 100644
--- a/python/README.md
+++ b/python/README.md
@@ -42,6 +42,12 @@ To run the unit tests:
 pytest
 ```
 
+To run the doc tests:
+
+```bash
+pytest --doctest-modules lancedb
+```
+
 To run linter and automatically fix all errors:
 
 ```bash
diff --git a/python/lancedb/__init__.py b/python/lancedb/__init__.py
index a87fffd4..7c04d865 100644
--- a/python/lancedb/__init__.py
+++ b/python/lancedb/__init__.py
@@ -13,6 +13,7 @@
 
 import importlib.metadata
 import os
+from datetime import timedelta
 from typing import Optional
 
 __version__ = importlib.metadata.version("lancedb")
@@ -29,6 +30,7 @@ def connect(
     api_key: Optional[str] = None,
     region: str = "us-east-1",
     host_override: Optional[str] = None,
+    read_consistency_interval: Optional[timedelta] = None,
 ) -> DBConnection:
     """Connect to a LanceDB database.
 
@@ -44,6 +46,18 @@ def connect(
         The region to use for LanceDB Cloud.
     host_override: str, optional
         The override url for LanceDB Cloud.
+    read_consistency_interval: timedelta, default None
+        (For LanceDB OSS only)
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.
+
 
     Examples
     --------
@@ -72,4 +86,4 @@ def connect(
         if api_key is None:
             raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
         return RemoteDBConnection(uri, api_key, region, host_override)
-    return LanceDBConnection(uri)
+    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
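The new `read_consistency_interval` parameter is easiest to see at the call site. A minimal sketch of the three settings described in the docstring above (the `data/sample-lancedb` path is illustrative):

```python
from datetime import timedelta

import lancedb

# Default: never check for other processes' writes (fastest reads).
db = lancedb.connect("data/sample-lancedb")

# Strong consistency: every read checks for a newer table version.
db_strong = lancedb.connect(
    "data/sample-lancedb", read_consistency_interval=timedelta(seconds=0)
)

# Eventual consistency: re-check at most once every 5 seconds.
db_eventual = lancedb.connect(
    "data/sample-lancedb", read_consistency_interval=timedelta(seconds=5)
)
```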
diff --git a/python/lancedb/db.py b/python/lancedb/db.py
index eb09a538..41b56494 100644
--- a/python/lancedb/db.py
+++ b/python/lancedb/db.py
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
 from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
 
 if TYPE_CHECKING:
+    from datetime import timedelta
+
     from .common import DATA, URI
     from .embeddings import EmbeddingFunctionConfig
     from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
         >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
         ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
         >>> db.create_table("my_table", data)
-        LanceTable(my_table)
+        LanceTable(connection=..., name="my_table")
         >>> db["my_table"].head()
         pyarrow.Table
         vector: fixed_size_list[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
         ...    "long": [-122.7, -74.1]
         ... })
         >>> db.create_table("table2", data)
-        LanceTable(table2)
+        LanceTable(connection=..., name="table2")
         >>> db["table2"].head()
         pyarrow.Table
         vector: fixed_size_list[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("long", pa.float32())
         ... ])
         >>> db.create_table("table3", data, schema = custom_schema)
-        LanceTable(table3)
+        LanceTable(connection=..., name="table3")
         >>> db["table3"].head()
         pyarrow.Table
         vector: fixed_size_list[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("price", pa.float32()),
         ... ])
         >>> db.create_table("table4", make_batches(), schema=schema)
-        LanceTable(table4)
+        LanceTable(connection=..., name="table4")
 
         """
         raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
     ----------
     uri: str or Path
         The root uri of the database.
+    read_consistency_interval: timedelta, default None
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.
 
     Examples
     --------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
     >>> db = lancedb.connect("./.lancedb")
     >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
     ...                                   {"vector": [0.5, 1.3], "b": 4}])
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
     >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
-    LanceTable(another_table)
+    LanceTable(connection=..., name="another_table")
     >>> sorted(db.table_names())
     ['another_table', 'my_table']
     >>> len(db)
     2
     >>> db["my_table"]
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
     >>> "my_table" in db
     True
     >>> db.drop_table("my_table")
     >>> db.drop_table("another_table")
     """
 
-    def __init__(self, uri: URI):
+    def __init__(
+        self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
+    ):
         if not isinstance(uri, Path):
             scheme = get_uri_scheme(uri)
             is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
         self._uri = str(uri)
 
         self._entered = False
+        self.read_consistency_interval = read_consistency_interval
+
+    def __repr__(self) -> str:
+        val = f"{self.__class__.__name__}({self._uri}"
+        if self.read_consistency_interval is not None:
+            val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
+        val += ")"
+        return val
 
     @property
     def uri(self) -> str:
diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py
index 2a550032..89bf08d6 100644
--- a/python/lancedb/pydantic.py
+++ b/python/lancedb/pydantic.py
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
         ...     name: str
         ...     vector: Vector(2)
         ...
-        >>> db = lancedb.connect("/tmp")
+        >>> db = lancedb.connect("./example")
         >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
         >>> table.add([
         ...     TestModel(name="test", vector=[1.0, 2.0])
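As the `LanceDBConnection` docstring above notes, the interval only governs reads; a writer always sees its own writes. A sketch of what two connections to the same database observe (path and table name illustrative):

```python
from datetime import timedelta

import lancedb

writer = lancedb.connect("data/sample-lancedb")
reader = lancedb.connect(
    "data/sample-lancedb", read_consistency_interval=timedelta(seconds=0)
)

tbl = writer.create_table("t", data=[{"id": 0}])
view = reader.open_table("t")

tbl.add([{"id": 1}])
# With a zero interval the reader picks up the new version on its next read.
assert view.version == tbl.version
```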
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 5bd6db16..be74c40d 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -14,7 +14,10 @@
 from __future__ import annotations
 
 import inspect
+import time
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -40,8 +43,6 @@ from .util import (
 )
 
 if TYPE_CHECKING:
-    from datetime import timedelta
-
     import PIL
     from lance.dataset import CleanupStats, ReaderLike
 
@@ -298,7 +299,7 @@ class Table(ABC):
 
         import lance
 
-        dataset = lance.dataset("/tmp/images.lance")
+        dataset = lance.dataset("./images.lance")
         dataset.create_scalar_index("category")
         """
         raise NotImplementedError
@@ -641,23 +642,145 @@ class Table(ABC):
         """
 
 
+class _LanceDatasetRef(ABC):
+    @property
+    @abstractmethod
+    def dataset(self) -> LanceDataset:
+        pass
+
+    @property
+    @abstractmethod
+    def dataset_mut(self) -> LanceDataset:
+        pass
+
+
+@dataclass
+class _LanceLatestDatasetRef(_LanceDatasetRef):
+    """Reference to the latest version of a LanceDataset."""
+
+    uri: str
+    read_consistency_interval: Optional[timedelta] = None
+    last_consistency_check: Optional[float] = None
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri)
+            self.last_consistency_check = time.monotonic()
+        elif self.read_consistency_interval is not None:
+            now = time.monotonic()
+            # Guard against a missing timestamp before computing the elapsed time.
+            if self.last_consistency_check is None or (
+                timedelta(seconds=now - self.last_consistency_check)
+                > self.read_consistency_interval
+            ):
+                self._dataset = self._dataset.checkout_version(
+                    self._dataset.latest_version
+                )
+                self.last_consistency_check = time.monotonic()
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.last_consistency_check = time.monotonic()
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        return self.dataset
+
+
+@dataclass
+class _LanceTimeTravelRef(_LanceDatasetRef):
+    uri: str
+    version: int
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri, version=self.version)
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.version = value.version
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        raise ValueError(
+            "Cannot mutate table reference fixed at version "
+            f"{self.version}. Call checkout_latest() to get a mutable "
+            "table reference."
+        )
+
+
 class LanceTable(Table):
     """
     A table in a LanceDB database.
+
+    This can be opened in two modes: standard and time-travel.
+
+    Standard mode is the default. In this mode, the table is mutable and tracks
+    the latest version of the table. The level of read consistency is controlled
+    by the `read_consistency_interval` parameter on the connection.
+
+    Time-travel mode is activated by specifying a version number. In this mode,
+    the table is immutable and fixed to a specific version. This is useful for
+    querying historical versions of the table.
     """
 
-    def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
+    def __init__(
+        self,
+        connection: "LanceDBConnection",
+        name: str,
+        version: Optional[int] = None,
+    ):
         self._conn = connection
         self.name = name
-        self._version = version
 
-    def _reset_dataset(self, version=None):
-        try:
-            if "_dataset" in self.__dict__:
-                del self.__dict__["_dataset"]
-            self._version = version
-        except AttributeError:
-            pass
+        if version is not None:
+            self._ref = _LanceTimeTravelRef(
+                uri=self._dataset_uri,
+                version=version,
+            )
+        else:
+            self._ref = _LanceLatestDatasetRef(
+                uri=self._dataset_uri,
+                read_consistency_interval=connection.read_consistency_interval,
+            )
+
+    @classmethod
+    def open(cls, db, name, **kwargs):
+        tbl = cls(db, name, **kwargs)
+        fs, path = fs_from_uri(tbl._dataset_uri)
+        file_info = fs.get_file_info(path)
+        if file_info.type != pa.fs.FileType.Directory:
+            raise FileNotFoundError(
+                f"Table {name} does not exist. "
+                f"Please first call db.create_table({name}, data)"
+            )
+        register_event("open_table")
+
+        return tbl
+
+    @property
+    def _dataset_uri(self) -> str:
+        return join_uri(self._conn.uri, f"{self.name}.lance")
+
+    @property
+    def _dataset(self) -> LanceDataset:
+        return self._ref.dataset
+
+    @property
+    def _dataset_mut(self) -> LanceDataset:
+        return self._ref.dataset_mut
+
+    def to_lance(self) -> LanceDataset:
+        """Return the LanceDataset backing this table."""
+        return self._dataset
 
     @property
     def schema(self) -> pa.Schema:
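`_LanceLatestDatasetRef.dataset` drives its staleness check off `time.monotonic()` rather than wall-clock time, so it is robust to system clock adjustments. The same pattern in isolation (a standalone sketch; these names are illustrative, not part of the patch):

```python
import time
from datetime import timedelta
from typing import Optional


class StalenessGate:
    """Tracks whether `interval` has elapsed since the last refresh."""

    def __init__(self, interval: Optional[timedelta]):
        self.interval = interval
        self.last_check: Optional[float] = None

    def should_refresh(self) -> bool:
        if self.interval is None:
            return False  # consistency is never checked (the default)
        now = time.monotonic()
        if self.last_check is None or (
            timedelta(seconds=now - self.last_check) > self.interval
        ):
            self.last_check = now
            return True
        return False
```

With `interval=timedelta(0)`, any elapsed time exceeds the interval, so every read refreshes; that is how the zero-second setting yields strong consistency.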
""" - def __init__(self, connection: "LanceDBConnection", name: str, version: int = None): + def __init__( + self, + connection: "LanceDBConnection", + name: str, + version: Optional[int] = None, + ): self._conn = connection self.name = name - self._version = version - def _reset_dataset(self, version=None): - try: - if "_dataset" in self.__dict__: - del self.__dict__["_dataset"] - self._version = version - except AttributeError: - pass + if version is not None: + self._ref = _LanceTimeTravelRef( + uri=self._dataset_uri, + version=version, + ) + else: + self._ref = _LanceLatestDatasetRef( + uri=self._dataset_uri, + read_consistency_interval=connection.read_consistency_interval, + ) + + @classmethod + def open(cls, db, name, **kwargs): + tbl = cls(db, name, **kwargs) + fs, path = fs_from_uri(tbl._dataset_uri) + file_info = fs.get_file_info(path) + if file_info.type != pa.fs.FileType.Directory: + raise FileNotFoundError( + f"Table {name} does not exist." + f"Please first call db.create_table({name}, data)" + ) + register_event("open_table") + + return tbl + + @property + def _dataset_uri(self) -> str: + return join_uri(self._conn.uri, f"{self.name}.lance") + + @property + def _dataset(self) -> LanceDataset: + return self._ref.dataset + + @property + def _dataset_mut(self) -> LanceDataset: + return self._ref.dataset_mut + + def to_lance(self) -> LanceDataset: + """Return the LanceDataset backing this table.""" + return self._dataset @property def schema(self) -> pa.Schema: @@ -685,6 +808,9 @@ class LanceTable(Table): keep writing to the dataset starting from an old version, then use the `restore` function. + Calling this method will set the table into time-travel mode. If you + wish to return to standard mode, call `checkout_latest`. + Parameters ---------- version : int @@ -709,15 +835,13 @@ class LanceTable(Table): vector type 0 [1.1, 0.9] vector """ - max_ver = max([v["version"] for v in self._dataset.versions()]) + max_ver = self._dataset.latest_version if version < 1 or version > max_ver: raise ValueError(f"Invalid version {version}") - self._reset_dataset(version=version) try: - # Accessing the property updates the cached value - _ = self._dataset - except Exception as e: + ds = self._dataset.checkout_version(version) + except IOError as e: if "not found" in str(e): raise ValueError( f"Version {version} no longer exists. Was it cleaned up?" @@ -725,6 +849,27 @@ class LanceTable(Table): else: raise e + self._ref = _LanceTimeTravelRef( + uri=self._dataset_uri, + version=version, + ) + # We've already loaded the version so we can populate it directly. + self._ref.dataset = ds + + def checkout_latest(self): + """Checkout the latest version of the table. This is an in-place operation. + + The table will be set back into standard mode, and will track the latest + version of the table. + """ + self.checkout(self._dataset.latest_version) + ds = self._ref.dataset + self._ref = _LanceLatestDatasetRef( + uri=self._dataset_uri, + read_consistency_interval=self._conn.read_consistency_interval, + ) + self._ref.dataset = ds + def restore(self, version: int = None): """Restore a version of the table. This is an in-place operation. 
@@ -759,7 +904,7 @@ class LanceTable(Table):
         >>> len(table.list_versions())
         4
         """
-        max_ver = max([v["version"] for v in self._dataset.versions()])
+        max_ver = self._dataset.latest_version
         if version is None:
             version = self.version
         elif version < 1 or version > max_ver:
@@ -767,12 +912,17 @@ class LanceTable(Table):
         else:
             self.checkout(version)
 
-        if version == max_ver:
-            # no-op if restoring the latest version
-            return
+        ds = self._dataset
 
-        self._dataset.restore()
-        self._reset_dataset()
+        # no-op if restoring the latest version
+        if version != max_ver:
+            ds.restore()
+
+        self._ref = _LanceLatestDatasetRef(
+            uri=self._dataset_uri,
+            read_consistency_interval=self._conn.read_consistency_interval,
+        )
+        self._ref.dataset = ds
 
     def count_rows(self, filter: Optional[str] = None) -> int:
         """
@@ -789,7 +939,11 @@ class LanceTable(Table):
         return self.count_rows()
 
     def __repr__(self) -> str:
-        return f"LanceTable({self.name})"
+        val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
+        if isinstance(self._ref, _LanceTimeTravelRef):
+            val += f", version={self._ref.version}"
+        val += ")"
+        return val
 
     def __str__(self) -> str:
         return self.__repr__()
@@ -839,10 +993,6 @@ class LanceTable(Table):
             self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
         )
 
-    @property
-    def _dataset_uri(self) -> str:
-        return join_uri(self._conn.uri, f"{self.name}.lance")
-
     def create_index(
         self,
         metric="L2",
@@ -854,7 +1004,7 @@ class LanceTable(Table):
         index_cache_size: Optional[int] = None,
     ):
         """Create an index on the table."""
-        self._dataset.create_index(
+        self._dataset_mut.create_index(
             column=vector_column_name,
             index_type="IVF_PQ",
             metric=metric,
@@ -864,10 +1014,11 @@ class LanceTable(Table):
             accelerator=accelerator,
             index_cache_size=index_cache_size,
         )
-        self._reset_dataset()
 
     def create_scalar_index(self, column: str, *, replace: bool = True):
-        self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
+        self._dataset_mut.create_scalar_index(
+            column, index_type="BTREE", replace=replace
+        )
 
     def create_fts_index(
         self,
@@ -909,14 +1060,6 @@ class LanceTable(Table):
     def _get_fts_index_path(self):
         return join_uri(self._dataset_uri, "_indices", "tantivy")
 
-    @cached_property
-    def _dataset(self) -> LanceDataset:
-        return lance.dataset(self._dataset_uri, version=self._version)
-
-    def to_lance(self) -> LanceDataset:
-        """Return the LanceDataset backing this table."""
-        return self._dataset
-
     def add(
         self,
         data: DATA,
@@ -955,8 +1098,11 @@ class LanceTable(Table):
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
-        self._reset_dataset()
+        # Access the dataset_mut property to ensure that the dataset is mutable.
+        self._ref.dataset_mut
+        self._ref.dataset = lance.write_dataset(
+            data, self._dataset_uri, schema=self.schema, mode=mode
+        )
 
     def merge(
         self,
@@ -1016,10 +1162,9 @@ class LanceTable(Table):
             other_table = other_table.to_lance()
         if isinstance(other_table, LanceDataset):
             other_table = other_table.to_table()
-        self._dataset.merge(
+        self._ref.dataset = self._dataset_mut.merge(
             other_table, left_on=left_on, right_on=right_on, schema=schema
         )
-        self._reset_dataset()
 
     @cached_property
     def embedding_functions(self) -> dict:
@@ -1219,22 +1364,8 @@ class LanceTable(Table):
 
         return new_table
 
-    @classmethod
-    def open(cls, db, name):
-        tbl = cls(db, name)
-        fs, path = fs_from_uri(tbl._dataset_uri)
-        file_info = fs.get_file_info(path)
-        if file_info.type != pa.fs.FileType.Directory:
-            raise FileNotFoundError(
-                f"Table {name} does not exist."
-                f"Please first call db.create_table({name}, data)"
-            )
-        register_event("open_table")
-
-        return tbl
-
     def delete(self, where: str):
-        self._dataset.delete(where)
+        self._dataset_mut.delete(where)
 
     def update(
         self,
@@ -1288,8 +1419,7 @@ class LanceTable(Table):
         if values is not None:
             values_sql = {k: value_to_sql(v) for k, v in values.items()}
 
-        self.to_lance().update(values_sql, where)
-        self._reset_dataset()
+        self._dataset_mut.update(values_sql, where)
 
     def _execute_query(self, query: Query) -> pa.Table:
         ds = self.to_lance()
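All mutating operations (`add`, `merge`, `delete`, `update`, and index creation) now route through `_dataset_mut`, so a version-pinned table fails loudly instead of silently writing from a stale base. A sketch of the guard in action (table contents illustrative):

```python
table.checkout(1)  # pin to version 1
try:
    table.add([{"id": 99}])
except ValueError as err:
    # "Cannot mutate table reference fixed at version 1. Call
    # checkout_latest() to get a mutable table reference."
    print(err)
```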
diff --git a/python/pyproject.toml b/python/pyproject.toml
index b9a35386..1141ede6 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -45,7 +45,7 @@ classifiers = [
 repository = "https://github.com/lancedb/lancedb"
 
 [project.optional-dependencies]
-tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
+tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
 dev = ["ruff", "pre-commit"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py
index a4c84fb0..32142a57 100644
--- a/python/tests/test_embeddings.py
+++ b/python/tests/test_embeddings.py
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
     assert np.allclose(actual, expected)
 
 
+@pytest.mark.slow
 def test_embedding_function_rate_limit(tmp_path):
     def _get_schema_from_model(model):
         class Schema(LanceModel):
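Since `test_embedding_function_rate_limit` is now tagged `@pytest.mark.slow`, quick local runs can deselect it with `pytest -m "not slow" python/tests`. This assumes the `slow` marker is registered in the project's pytest configuration so that strict-marker runs don't warn or fail.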
diff --git a/python/tests/test_table.py b/python/tests/test_table.py
index 41d72da7..b7b8e00a 100644
--- a/python/tests/test_table.py
+++ b/python/tests/test_table.py
@@ -12,8 +12,10 @@
 # limitations under the License.
 
 import functools
+from copy import copy
 from datetime import date, datetime, timedelta
 from pathlib import Path
+from time import sleep
 from typing import List
 from unittest.mock import PropertyMock, patch
 
@@ -25,6 +27,7 @@ import pyarrow as pa
 import pytest
 from pydantic import BaseModel
 
+import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction
 from lancedb.db import LanceDBConnection
 from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
 class MockDB:
     def __init__(self, uri: Path):
         self.uri = uri
+        self.read_consistency_interval = None
 
     @functools.cached_property
     def is_managed_remote(self) -> bool:
@@ -267,39 +271,38 @@ def test_versioning(db):
 
 
 def test_create_index_method():
-    with patch.object(LanceTable, "_reset_dataset", return_value=None):
-        with patch.object(
-            LanceTable, "_dataset", new_callable=PropertyMock
-        ) as mock_dataset:
-            # Setup mock responses
-            mock_dataset.return_value.create_index.return_value = None
+    with patch.object(
+        LanceTable, "_dataset_mut", new_callable=PropertyMock
+    ) as mock_dataset:
+        # Setup mock responses
+        mock_dataset.return_value.create_index.return_value = None
 
-            # Create a LanceTable object
-            connection = LanceDBConnection(uri="mock.uri")
-            table = LanceTable(connection, "test_table")
+        # Create a LanceTable object
+        connection = LanceDBConnection(uri="mock.uri")
+        table = LanceTable(connection, "test_table")
 
-            # Call the create_index method
-            table.create_index(
-                metric="L2",
-                num_partitions=256,
-                num_sub_vectors=96,
-                vector_column_name="vector",
-                replace=True,
-                index_cache_size=256,
-            )
+        # Call the create_index method
+        table.create_index(
+            metric="L2",
+            num_partitions=256,
+            num_sub_vectors=96,
+            vector_column_name="vector",
+            replace=True,
+            index_cache_size=256,
+        )
 
-            # Check that the _dataset.create_index method was called
-            # with the right parameters
-            mock_dataset.return_value.create_index.assert_called_once_with(
-                column="vector",
-                index_type="IVF_PQ",
-                metric="L2",
-                num_partitions=256,
-                num_sub_vectors=96,
-                replace=True,
-                accelerator=None,
-                index_cache_size=256,
-            )
+        # Check that the _dataset.create_index method was called
+        # with the right parameters
+        mock_dataset.return_value.create_index.assert_called_once_with(
+            column="vector",
+            index_type="IVF_PQ",
+            metric="L2",
+            num_partitions=256,
+            num_sub_vectors=96,
+            replace=True,
+            accelerator=None,
+            index_cache_size=256,
+        )
 
 
 def test_add_with_nans(db):
@@ -792,3 +795,48 @@ def test_hybrid_search(db):
         "Our father who art in heaven", query_type="hybrid"
     ).to_pydantic(MyTable)
     assert result1 == result3
+
+
+@pytest.mark.parametrize(
+    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
+)
+def test_consistency(tmp_path, consistency_interval):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    table.add([{"id": 1}])
+
+    if consistency_interval is None:
+        assert table2.version == table.version - 1
+        table2.checkout_latest()
+        assert table2.version == table.version
+    elif consistency_interval == timedelta(seconds=0):
+        assert table2.version == table.version
+    else:
+        # consistency_interval == timedelta(seconds=0.1)
+        assert table2.version == table.version - 1
+        sleep(0.1)
+        assert table2.version == table.version
+
+
+def test_restore_consistency(tmp_path):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    # If we call checkout, it should lose consistency
+    table_fixed = copy(table2)
+    table_fixed.checkout(table.version)
+    # But if we call checkout_latest, it should be consistent again
+    table_ref_latest = copy(table_fixed)
+    table_ref_latest.checkout_latest()
+    table.add([{"id": 2}])
+    assert table_fixed.version == table.version - 1
+    assert table_ref_latest.version == table.version
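A closing note on the `copy()` calls in `test_restore_consistency`: `checkout` and `checkout_latest` replace the table's internal `_ref` wholesale, so a shallow copy taken beforehand keeps its own reference and tracks versions independently. Sketch (path and table name illustrative):

```python
from copy import copy
from datetime import timedelta

import lancedb

db = lancedb.connect("data/sample-lancedb", read_consistency_interval=timedelta(0))
table = db.open_table("my_table")

pinned = copy(table)
pinned.checkout(1)  # only `pinned` enters time-travel mode
assert table.version >= pinned.version
```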