feat(python): add read_consistency_interval argument (#828)

This PR refactors how we handle read consistency: whether the `LanceTable`
class always picks up modifications to the table made by other instances
or processes. Users have three options they can set at the connection
level:

1. (Default) `read_consistency_interval=None` means it will not check at
all. Users can call `table.checkout_latest()` to manually check for
updates.
2. `read_consistency_interval=timedelta(0)` means **always** check for
updates, giving strong read consistency.
3. `read_consistency_interval=timedelta(seconds=20)` means check for
updates every 20 seconds. This is eventual consistency, a compromise
between the two options above.

There is now an explicit difference between a `LanceTable` that tracks
the current version and one that is fixed at a historical version. We
now enforce that users cannot write if they have checked out an old
version. They are instructed to call `checkout_latest()` before calling
the write methods.

Since `conn.open_table()` doesn't have a parameter for version, users
will only get fixed references if they call `table.checkout()`.

The difference between these two can be seen in the repr: tables that are
fixed at a particular version will have a `version` displayed in the
repr. Otherwise, the version will not be shown.

```python
>>> table
LanceTable(connection=..., name="my_table")
>>> table.checkout(1)
>>> table
LanceTable(connection=..., name="my_table", version=1)
```

I decided to not create different classes for these states, because I
think we already have enough complexity with the Cloud vs OSS table
references.

Based on #812
This commit is contained in:
Will Jones
2024-02-05 08:12:19 -08:00
committed by Weston Pace
parent 0f00cd0097
commit 39cc2fd62b
8 changed files with 322 additions and 101 deletions

View File

@@ -42,6 +42,12 @@ To run the unit tests:
pytest
```
To run the doc tests:
```bash
pytest --doctest-modules lancedb
```
To run linter and automatically fix all errors:
```bash

View File

@@ -13,6 +13,7 @@
import importlib.metadata
import os
from datetime import timedelta
from typing import Optional
__version__ = importlib.metadata.version("lancedb")
@@ -29,6 +30,7 @@ def connect(
api_key: Optional[str] = None,
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -44,6 +46,18 @@ def connect(
The region to use for LanceDB Cloud.
host_override: str, optional
The override url for LanceDB Cloud.
read_consistency_interval: timedelta, default None
(For LanceDB OSS only)
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples
--------
@@ -72,4 +86,4 @@ def connect(
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override)
return LanceDBConnection(uri)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

View File

@@ -26,6 +26,8 @@ from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING:
from datetime import timedelta
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(table2)
LanceTable(connection=..., name="table2")
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(table3)
LanceTable(connection=..., name="table3")
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(table4)
LanceTable(connection=..., name="table4")
"""
raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
----------
uri: str or Path
The root uri of the database.
read_consistency_interval: timedelta, default None
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples
--------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
>>> db = lancedb.connect("./.lancedb")
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4}])
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
LanceTable(another_table)
LanceTable(connection=..., name="another_table")
>>> sorted(db.table_names())
['another_table', 'my_table']
>>> len(db)
2
>>> db["my_table"]
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> "my_table" in db
True
>>> db.drop_table("my_table")
>>> db.drop_table("another_table")
"""
def __init__(self, uri: URI):
def __init__(
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
):
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
def __repr__(self) -> str:
val = f"{self.__class__.__name__}({self._uri}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
return val
@property
def uri(self) -> str:

View File

@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
... name: str
... vector: Vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> db = lancedb.connect("./example")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
>>> table.add([
... TestModel(name="test", vector=[1.0, 2.0])

View File

@@ -14,7 +14,10 @@
from __future__ import annotations
import inspect
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -40,8 +43,6 @@ from .util import (
)
if TYPE_CHECKING:
from datetime import timedelta
import PIL
from lance.dataset import CleanupStats, ReaderLike
@@ -298,7 +299,7 @@ class Table(ABC):
import lance
dataset = lance.dataset("/tmp/images.lance")
dataset = lance.dataset("./images.lance")
dataset.create_scalar_index("category")
"""
raise NotImplementedError
@@ -641,23 +642,145 @@ class Table(ABC):
"""
class _LanceDatasetRef(ABC):
@property
@abstractmethod
def dataset(self) -> LanceDataset:
pass
@property
@abstractmethod
def dataset_mut(self) -> LanceDataset:
pass
@dataclass
class _LanceLatestDatasetRef(_LanceDatasetRef):
"""Reference to the latest version of a LanceDataset."""
uri: str
read_consistency_interval: Optional[timedelta] = None
last_consistency_check: Optional[float] = None
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri)
self.last_consistency_check = time.monotonic()
elif self.read_consistency_interval is not None:
now = time.monotonic()
diff = timedelta(seconds=now - self.last_consistency_check)
if (
self.last_consistency_check is None
or diff > self.read_consistency_interval
):
self._dataset = self._dataset.checkout_version(
self._dataset.latest_version
)
self.last_consistency_check = time.monotonic()
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.last_consistency_check = time.monotonic()
@property
def dataset_mut(self) -> LanceDataset:
return self.dataset
@dataclass
class _LanceTimeTravelRef(_LanceDatasetRef):
uri: str
version: int
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri, version=self.version)
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.version = value.version
@property
def dataset_mut(self) -> LanceDataset:
raise ValueError(
"Cannot mutate table reference fixed at version "
f"{self.version}. Call checkout_latest() to get a mutable "
"table reference."
)
class LanceTable(Table):
"""
A table in a LanceDB database.
This can be opened in two modes: standard and time-travel.
Standard mode is the default. In this mode, the table is mutable and tracks
the latest version of the table. The level of read consistency is controlled
by the `read_consistency_interval` parameter on the connection.
Time-travel mode is activated by specifying a version number. In this mode,
the table is immutable and fixed to a specific version. This is useful for
querying historical versions of the table.
"""
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
def __init__(
self,
connection: "LanceDBConnection",
name: str,
version: Optional[int] = None,
):
self._conn = connection
self.name = name
self._version = version
def _reset_dataset(self, version=None):
try:
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
self._version = version
except AttributeError:
pass
if version is not None:
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
)
else:
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=connection.read_consistency_interval,
)
@classmethod
def open(cls, db, name, **kwargs):
    """Open an existing table in the database.

    Parameters
    ----------
    db : LanceDBConnection
        The connection that owns the table.
    name : str
        The name of the table to open.
    **kwargs
        Forwarded to the constructor (e.g. ``version=`` to open the table
        in time-travel mode).

    Raises
    ------
    FileNotFoundError
        If no dataset directory exists for ``name``.
    """
    tbl = cls(db, name, **kwargs)
    fs, path = fs_from_uri(tbl._dataset_uri)
    file_info = fs.get_file_info(path)
    if file_info.type != pa.fs.FileType.Directory:
        # Note the trailing space: without it the two f-string halves ran
        # together as "...does not exist.Please first call...".
        raise FileNotFoundError(
            f"Table {name} does not exist. "
            f"Please first call db.create_table({name}, data)"
        )
    register_event("open_table")
    return tbl
@property
def _dataset_uri(self) -> str:
    # Location of the backing Lance dataset: <connection uri>/<name>.lance
    return join_uri(self._conn.uri, f"{self.name}.lance")
@property
def _dataset(self) -> LanceDataset:
    # Read view of the table, delegated to the current dataset reference
    # (latest-tracking or pinned time-travel).
    return self._ref.dataset
@property
def _dataset_mut(self) -> LanceDataset:
    # Write view of the table; the reference may raise if the table is
    # pinned to a historical version (time-travel mode).
    return self._ref.dataset_mut
def to_lance(self) -> LanceDataset:
    """Return the LanceDataset backing this table.

    This is the read view; it does not grant write access.
    """
    return self._dataset
@property
def schema(self) -> pa.Schema:
@@ -685,6 +808,9 @@ class LanceTable(Table):
keep writing to the dataset starting from an old version, then use
the `restore` function.
Calling this method will set the table into time-travel mode. If you
wish to return to standard mode, call `checkout_latest`.
Parameters
----------
version : int
@@ -709,15 +835,13 @@ class LanceTable(Table):
vector type
0 [1.1, 0.9] vector
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = self._dataset.latest_version
if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version)
try:
# Accessing the property updates the cached value
_ = self._dataset
except Exception as e:
ds = self._dataset.checkout_version(version)
except IOError as e:
if "not found" in str(e):
raise ValueError(
f"Version {version} no longer exists. Was it cleaned up?"
@@ -725,6 +849,27 @@ class LanceTable(Table):
else:
raise e
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
)
# We've already loaded the version so we can populate it directly.
self._ref.dataset = ds
def checkout_latest(self):
"""Checkout the latest version of the table. This is an in-place operation.
The table will be set back into standard mode, and will track the latest
version of the table.
"""
self.checkout(self._dataset.latest_version)
ds = self._ref.dataset
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def restore(self, version: int = None):
"""Restore a version of the table. This is an in-place operation.
@@ -759,7 +904,7 @@ class LanceTable(Table):
>>> len(table.list_versions())
4
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = self._dataset.latest_version
if version is None:
version = self.version
elif version < 1 or version > max_ver:
@@ -767,12 +912,17 @@ class LanceTable(Table):
else:
self.checkout(version)
if version == max_ver:
# no-op if restoring the latest version
return
ds = self._dataset
self._dataset.restore()
self._reset_dataset()
# no-op if restoring the latest version
if version != max_ver:
ds.restore()
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def count_rows(self, filter: Optional[str] = None) -> int:
"""
@@ -789,7 +939,11 @@ class LanceTable(Table):
return self.count_rows()
def __repr__(self) -> str:
return f"LanceTable({self.name})"
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
if isinstance(self._ref, _LanceTimeTravelRef):
val += f", version={self._ref.version}"
val += ")"
return val
def __str__(self) -> str:
return self.__repr__()
@@ -839,10 +993,6 @@ class LanceTable(Table):
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
)
@property
def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance")
def create_index(
self,
metric="L2",
@@ -854,7 +1004,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None,
):
"""Create an index on the table."""
self._dataset.create_index(
self._dataset_mut.create_index(
column=vector_column_name,
index_type="IVF_PQ",
metric=metric,
@@ -864,10 +1014,11 @@ class LanceTable(Table):
accelerator=accelerator,
index_cache_size=index_cache_size,
)
self._reset_dataset()
def create_scalar_index(self, column: str, *, replace: bool = True):
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
self._dataset_mut.create_scalar_index(
column, index_type="BTREE", replace=replace
)
def create_fts_index(
self,
@@ -909,14 +1060,6 @@ class LanceTable(Table):
def _get_fts_index_path(self):
return join_uri(self._dataset_uri, "_indices", "tantivy")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri, version=self._version)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
def add(
self,
data: DATA,
@@ -955,8 +1098,11 @@ class LanceTable(Table):
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset()
# Access the dataset_mut property to ensure that the dataset is mutable.
self._ref.dataset_mut
self._ref.dataset = lance.write_dataset(
data, self._dataset_uri, schema=self.schema, mode=mode
)
def merge(
self,
@@ -1016,10 +1162,9 @@ class LanceTable(Table):
other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset):
other_table = other_table.to_table()
self._dataset.merge(
self._ref.dataset = self._dataset_mut.merge(
other_table, left_on=left_on, right_on=right_on, schema=schema
)
self._reset_dataset()
@cached_property
def embedding_functions(self) -> dict:
@@ -1219,22 +1364,8 @@ class LanceTable(Table):
return new_table
@classmethod
def open(cls, db, name):
tbl = cls(db, name)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
register_event("open_table")
return tbl
def delete(self, where: str):
self._dataset.delete(where)
self._dataset_mut.delete(where)
def update(
self,
@@ -1288,8 +1419,7 @@ class LanceTable(Table):
if values is not None:
values_sql = {k: value_to_sql(v) for k, v in values.items()}
self.to_lance().update(values_sql, where)
self._reset_dataset()
self._dataset_mut.update(values_sql, where)
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()

View File

@@ -45,7 +45,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]

View File

@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected)
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
class Schema(LanceModel):

View File

@@ -12,8 +12,10 @@
# limitations under the License.
import functools
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
from time import sleep
from typing import List
from unittest.mock import PropertyMock, patch
@@ -25,6 +27,7 @@ import pyarrow as pa
import pytest
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
class MockDB:
def __init__(self, uri: Path):
self.uri = uri
self.read_consistency_interval = None
@functools.cached_property
def is_managed_remote(self) -> bool:
@@ -267,39 +271,38 @@ def test_versioning(db):
def test_create_index_method():
with patch.object(LanceTable, "_reset_dataset", return_value=None):
with patch.object(
LanceTable, "_dataset", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
with patch.object(
LanceTable, "_dataset_mut", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
def test_add_with_nans(db):
@@ -792,3 +795,48 @@ def test_hybrid_search(db):
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3
@pytest.mark.parametrize(
    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
    """A second connection sees writes from the first according to its
    read_consistency_interval setting."""
    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
    db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
    table2 = db2.open_table("my_table")
    assert table2.version == table.version
    table.add([{"id": 1}])
    if consistency_interval is None:
        # No automatic checks: stale until an explicit checkout_latest().
        assert table2.version == table.version - 1
        table2.checkout_latest()
        assert table2.version == table.version
    elif consistency_interval == timedelta(seconds=0):
        # Strong consistency: every read re-checks for updates.
        assert table2.version == table.version
    else:
        # consistency_interval == timedelta(seconds=0.1): eventual
        # consistency, the update becomes visible once the interval elapses.
        assert table2.version == table.version - 1
        sleep(0.1)
        assert table2.version == table.version
def test_restore_consistency(tmp_path):
    """checkout() pins a copy at a fixed version, while checkout_latest()
    restores strong-consistency tracking on a copy."""
    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
    db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
    table2 = db2.open_table("my_table")
    assert table2.version == table.version
    # If we call checkout, it should lose consistency
    table_fixed = copy(table2)
    table_fixed.checkout(table.version)
    # But if we call checkout_latest, it should be consistent again
    table_ref_latest = copy(table_fixed)
    table_ref_latest.checkout_latest()
    table.add([{"id": 2}])
    assert table_fixed.version == table.version - 1
    assert table_ref_latest.version == table.version