feat(python): add read_consistency_interval argument (#828)

This PR refactors how we handle read consistency: whether a `LanceTable`
instance always picks up modifications to the table made by other instances
or processes. Users have three options they can set at the connection
level:

1. (Default) `read_consistency_interval=None` means the table never checks
for updates automatically. Users can call `table.checkout_latest()` to
manually check for updates.
2. `read_consistency_interval=timedelta(0)` means **always** check for
updates, giving strong read consistency.
3. `read_consistency_interval=timedelta(seconds=20)` means check for
updates if more than 20 seconds have passed since the last check. This is
eventual consistency, a compromise between the two options above. (A usage
sketch follows below.)
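
A minimal sketch of the three settings against the new `connect` signature
(the local path is illustrative):

```python
from datetime import timedelta

import lancedb

# Default: never check automatically; refresh manually with checkout_latest().
db_manual = lancedb.connect("./.lancedb")

# Strong consistency: every read checks for updates from other writers.
db_strong = lancedb.connect("./.lancedb", read_consistency_interval=timedelta(0))

# Eventual consistency: re-check only once 20+ seconds have passed since
# the last check.
db_eventual = lancedb.connect(
    "./.lancedb", read_consistency_interval=timedelta(seconds=20)
)
```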

## Table reference state

There is now an explicit difference between a `LanceTable` that tracks the
latest version and one that is fixed at a historical version. We now enforce
that users cannot write if they have checked out an old version; they are
instructed to call `checkout_latest()` before calling the write methods, as
the sketch below shows.
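
A sketch of that guard (the sample row is hypothetical; the error message
mirrors the new `_LanceTimeTravelRef.dataset_mut` check):

```python
table.checkout(1)
table.add([{"vector": [1.1, 1.2]}])
# ValueError: Cannot mutate table reference fixed at version 1.
# Call checkout_latest() to get a mutable table reference.
```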

Since `conn.open_table()` doesn't have a parameter for version, users
will only get fixed references if they call `table.checkout()`.

The difference between these two states is visible in the repr: a table
fixed at a particular version shows a `version` field in its repr;
otherwise, the version is not shown.

```python
>>> table
LanceTable(connection=..., name="my_table")
>>> table.checkout(1)
>>> table
LanceTable(connection=..., name="my_table", version=1)
```
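
Calling `checkout_latest()` puts the table back into tracking mode, and the
`version` field disappears from the repr:

```python
>>> table.checkout_latest()
>>> table
LanceTable(connection=..., name="my_table")
```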

I decided not to create separate classes for these two states, because I
think we already have enough complexity with the Cloud vs. OSS table
references.

Based on #812
Author: Will Jones
Committed: 2024-02-05 08:12:19 -08:00 (via GitHub)
Parent: 738511c5f2
Commit: 57605a2d86
8 changed files with 322 additions and 101 deletions

View File

````diff
@@ -42,6 +42,12 @@ To run the unit tests:
 pytest
 ```
+
+To run the doc tests:
+```bash
+pytest --doctest-modules lancedb
+```
+
 To run linter and automatically fix all errors:
 ```bash
````

View File

```diff
@@ -13,6 +13,7 @@
 import importlib.metadata
 import os
+from datetime import timedelta
 from typing import Optional

 __version__ = importlib.metadata.version("lancedb")
@@ -30,6 +31,7 @@ def connect(
     api_key: Optional[str] = None,
     region: str = "us-east-1",
     host_override: Optional[str] = None,
+    read_consistency_interval: Optional[timedelta] = None,
 ) -> DBConnection:
     """Connect to a LanceDB database.
@@ -45,6 +47,18 @@ def connect(
         The region to use for LanceDB Cloud.
     host_override: str, optional
         The override url for LanceDB Cloud.
+    read_consistency_interval: timedelta, default None
+        (For LanceDB OSS only)
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.

     Examples
     --------
@@ -73,4 +87,4 @@ def connect(
         if api_key is None:
             raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
         return RemoteDBConnection(uri, api_key, region, host_override)
-    return LanceDBConnection(uri)
+    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
```

View File

```diff
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
 from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri

 if TYPE_CHECKING:
+    from datetime import timedelta
+
     from .common import DATA, URI
     from .embeddings import EmbeddingFunctionConfig
     from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
         >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
         ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
         >>> db.create_table("my_table", data)
-        LanceTable(my_table)
+        LanceTable(connection=..., name="my_table")
         >>> db["my_table"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
         ...     "long": [-122.7, -74.1]
         ... })
         >>> db.create_table("table2", data)
-        LanceTable(table2)
+        LanceTable(connection=..., name="table2")
         >>> db["table2"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("long", pa.float32())
         ... ])
         >>> db.create_table("table3", data, schema = custom_schema)
-        LanceTable(table3)
+        LanceTable(connection=..., name="table3")
         >>> db["table3"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("price", pa.float32()),
         ... ])
         >>> db.create_table("table4", make_batches(), schema=schema)
-        LanceTable(table4)
+        LanceTable(connection=..., name="table4")
         """
         raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
    ----------
    uri: str or Path
        The root uri of the database.
+    read_consistency_interval: timedelta, default None
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.

    Examples
    --------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
    >>> db = lancedb.connect("./.lancedb")
    >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
    ...                                   {"vector": [0.5, 1.3], "b": 4}])
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
    >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
-    LanceTable(another_table)
+    LanceTable(connection=..., name="another_table")
    >>> sorted(db.table_names())
    ['another_table', 'my_table']
    >>> len(db)
    2
    >>> db["my_table"]
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
    >>> "my_table" in db
    True
    >>> db.drop_table("my_table")
    >>> db.drop_table("another_table")
    """

-    def __init__(self, uri: URI):
+    def __init__(
+        self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
+    ):
         if not isinstance(uri, Path):
             scheme = get_uri_scheme(uri)
         is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
         self._uri = str(uri)
         self._entered = False
+        self.read_consistency_interval = read_consistency_interval
+
+    def __repr__(self) -> str:
+        val = f"{self.__class__.__name__}({self._uri}"
+        if self.read_consistency_interval is not None:
+            val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
+        val += ")"
+        return val

     @property
     def uri(self) -> str:
```

View File

```diff
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
    ...     name: str
    ...     vector: Vector(2)
    ...
-    >>> db = lancedb.connect("/tmp")
+    >>> db = lancedb.connect("./example")
    >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
    >>> table.add([
    ...     TestModel(name="test", vector=[1.0, 2.0])
```

View File

```diff
@@ -14,7 +14,10 @@
 from __future__ import annotations

 import inspect
+import time
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -41,8 +44,6 @@ from .util import (
 from .utils.events import register_event

 if TYPE_CHECKING:
-    from datetime import timedelta
-
     import PIL
     from lance.dataset import CleanupStats, ReaderLike
@@ -299,7 +300,7 @@ class Table(ABC):
         import lance

-        dataset = lance.dataset("/tmp/images.lance")
+        dataset = lance.dataset("./images.lance")
         dataset.create_scalar_index("category")
     """
     raise NotImplementedError
@@ -642,23 +643,145 @@ class Table(ABC):
     """


+class _LanceDatasetRef(ABC):
+    @property
+    @abstractmethod
+    def dataset(self) -> LanceDataset:
+        pass
+
+    @property
+    @abstractmethod
+    def dataset_mut(self) -> LanceDataset:
+        pass
+
+
+@dataclass
+class _LanceLatestDatasetRef(_LanceDatasetRef):
+    """Reference to the latest version of a LanceDataset."""
+
+    uri: str
+    read_consistency_interval: Optional[timedelta] = None
+    last_consistency_check: Optional[float] = None
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri)
+            self.last_consistency_check = time.monotonic()
+        elif self.read_consistency_interval is not None:
+            now = time.monotonic()
+            diff = timedelta(seconds=now - self.last_consistency_check)
+            if (
+                self.last_consistency_check is None
+                or diff > self.read_consistency_interval
+            ):
+                self._dataset = self._dataset.checkout_version(
+                    self._dataset.latest_version
+                )
+                self.last_consistency_check = time.monotonic()
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.last_consistency_check = time.monotonic()
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        return self.dataset
+
+
+@dataclass
+class _LanceTimeTravelRef(_LanceDatasetRef):
+    uri: str
+    version: int
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri, version=self.version)
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.version = value.version
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        raise ValueError(
+            "Cannot mutate table reference fixed at version "
+            f"{self.version}. Call checkout_latest() to get a mutable "
+            "table reference."
+        )
+
+
 class LanceTable(Table):
     """
     A table in a LanceDB database.
+
+    This can be opened in two modes: standard and time-travel.
+
+    Standard mode is the default. In this mode, the table is mutable and tracks
+    the latest version of the table. The level of read consistency is controlled
+    by the `read_consistency_interval` parameter on the connection.
+
+    Time-travel mode is activated by specifying a version number. In this mode,
+    the table is immutable and fixed to a specific version. This is useful for
+    querying historical versions of the table.
     """

-    def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
+    def __init__(
+        self,
+        connection: "LanceDBConnection",
+        name: str,
+        version: Optional[int] = None,
+    ):
         self._conn = connection
         self.name = name
-        self._version = version
-
-    def _reset_dataset(self, version=None):
-        try:
-            if "_dataset" in self.__dict__:
-                del self.__dict__["_dataset"]
-            self._version = version
-        except AttributeError:
-            pass
+
+        if version is not None:
+            self._ref = _LanceTimeTravelRef(
+                uri=self._dataset_uri,
+                version=version,
+            )
+        else:
+            self._ref = _LanceLatestDatasetRef(
+                uri=self._dataset_uri,
+                read_consistency_interval=connection.read_consistency_interval,
+            )
+
+    @classmethod
+    def open(cls, db, name, **kwargs):
+        tbl = cls(db, name, **kwargs)
+        fs, path = fs_from_uri(tbl._dataset_uri)
+        file_info = fs.get_file_info(path)
+        if file_info.type != pa.fs.FileType.Directory:
+            raise FileNotFoundError(
+                f"Table {name} does not exist."
+                f"Please first call db.create_table({name}, data)"
+            )
+        register_event("open_table")
+        return tbl
+
+    @property
+    def _dataset_uri(self) -> str:
+        return join_uri(self._conn.uri, f"{self.name}.lance")
+
+    @property
+    def _dataset(self) -> LanceDataset:
+        return self._ref.dataset
+
+    @property
+    def _dataset_mut(self) -> LanceDataset:
+        return self._ref.dataset_mut
+
+    def to_lance(self) -> LanceDataset:
+        """Return the LanceDataset backing this table."""
+        return self._dataset

     @property
     def schema(self) -> pa.Schema:
@@ -686,6 +809,9 @@ class LanceTable(Table):
         keep writing to the dataset starting from an old version, then use
         the `restore` function.

+        Calling this method will set the table into time-travel mode. If you
+        wish to return to standard mode, call `checkout_latest`.
+
         Parameters
         ----------
         version : int
@@ -710,15 +836,13 @@ class LanceTable(Table):
         vector type
         0  [1.1, 0.9]  vector
         """
-        max_ver = max([v["version"] for v in self._dataset.versions()])
+        max_ver = self._dataset.latest_version
         if version < 1 or version > max_ver:
             raise ValueError(f"Invalid version {version}")
-        self._reset_dataset(version=version)
         try:
-            # Accessing the property updates the cached value
-            _ = self._dataset
-        except Exception as e:
+            ds = self._dataset.checkout_version(version)
+        except IOError as e:
             if "not found" in str(e):
                 raise ValueError(
                     f"Version {version} no longer exists. Was it cleaned up?"
@@ -726,6 +850,27 @@ class LanceTable(Table):
             else:
                 raise e

+        self._ref = _LanceTimeTravelRef(
+            uri=self._dataset_uri,
+            version=version,
+        )
+        # We've already loaded the version so we can populate it directly.
+        self._ref.dataset = ds
+
+    def checkout_latest(self):
+        """Checkout the latest version of the table. This is an in-place operation.
+
+        The table will be set back into standard mode, and will track the latest
+        version of the table.
+        """
+        self.checkout(self._dataset.latest_version)
+        ds = self._ref.dataset
+        self._ref = _LanceLatestDatasetRef(
+            uri=self._dataset_uri,
+            read_consistency_interval=self._conn.read_consistency_interval,
+        )
+        self._ref.dataset = ds
+
     def restore(self, version: int = None):
         """Restore a version of the table. This is an in-place operation.
@@ -760,7 +905,7 @@ class LanceTable(Table):
         >>> len(table.list_versions())
         4
         """
-        max_ver = max([v["version"] for v in self._dataset.versions()])
+        max_ver = self._dataset.latest_version
         if version is None:
             version = self.version
         elif version < 1 or version > max_ver:
@@ -768,12 +913,17 @@ class LanceTable(Table):
         else:
             self.checkout(version)

-        if version == max_ver:
-            # no-op if restoring the latest version
-            return
+        ds = self._dataset

-        self._dataset.restore()
-        self._reset_dataset()
+        # no-op if restoring the latest version
+        if version != max_ver:
+            ds.restore()
+
+        self._ref = _LanceLatestDatasetRef(
+            uri=self._dataset_uri,
+            read_consistency_interval=self._conn.read_consistency_interval,
+        )
+        self._ref.dataset = ds

     def count_rows(self, filter: Optional[str] = None) -> int:
         """
@@ -790,7 +940,11 @@ class LanceTable(Table):
         return self.count_rows()

     def __repr__(self) -> str:
-        return f"LanceTable({self.name})"
+        val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
+        if isinstance(self._ref, _LanceTimeTravelRef):
+            val += f", version={self._ref.version}"
+        val += ")"
+        return val

     def __str__(self) -> str:
         return self.__repr__()
@@ -840,10 +994,6 @@ class LanceTable(Table):
             self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
         )

-    @property
-    def _dataset_uri(self) -> str:
-        return join_uri(self._conn.uri, f"{self.name}.lance")
-
     def create_index(
         self,
         metric="L2",
@@ -855,7 +1005,7 @@ class LanceTable(Table):
         index_cache_size: Optional[int] = None,
     ):
         """Create an index on the table."""
-        self._dataset.create_index(
+        self._dataset_mut.create_index(
             column=vector_column_name,
             index_type="IVF_PQ",
             metric=metric,
@@ -865,11 +1015,12 @@ class LanceTable(Table):
             accelerator=accelerator,
             index_cache_size=index_cache_size,
         )
-        self._reset_dataset()
         register_event("create_index")

     def create_scalar_index(self, column: str, *, replace: bool = True):
-        self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
+        self._dataset_mut.create_scalar_index(
+            column, index_type="BTREE", replace=replace
+        )

     def create_fts_index(
         self,
@@ -912,14 +1063,6 @@ class LanceTable(Table):
     def _get_fts_index_path(self):
         return join_uri(self._dataset_uri, "_indices", "tantivy")

-    @cached_property
-    def _dataset(self) -> LanceDataset:
-        return lance.dataset(self._dataset_uri, version=self._version)
-
-    def to_lance(self) -> LanceDataset:
-        """Return the LanceDataset backing this table."""
-        return self._dataset
-
     def add(
         self,
         data: DATA,
@@ -958,8 +1101,11 @@ class LanceTable(Table):
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
-        self._reset_dataset()
+        # Access the dataset_mut property to ensure that the dataset is mutable.
+        self._ref.dataset_mut
+        self._ref.dataset = lance.write_dataset(
+            data, self._dataset_uri, schema=self.schema, mode=mode
+        )
         register_event("add")

     def merge(
@@ -1020,10 +1166,9 @@ class LanceTable(Table):
             other_table = other_table.to_lance()
         if isinstance(other_table, LanceDataset):
             other_table = other_table.to_table()
-        self._dataset.merge(
+        self._ref.dataset = self._dataset_mut.merge(
             other_table, left_on=left_on, right_on=right_on, schema=schema
         )
-        self._reset_dataset()
         register_event("merge")

     @cached_property
@@ -1226,22 +1371,8 @@ class LanceTable(Table):
         register_event("create_table")
         return new_table

-    @classmethod
-    def open(cls, db, name):
-        tbl = cls(db, name)
-        fs, path = fs_from_uri(tbl._dataset_uri)
-        file_info = fs.get_file_info(path)
-        if file_info.type != pa.fs.FileType.Directory:
-            raise FileNotFoundError(
-                f"Table {name} does not exist."
-                f"Please first call db.create_table({name}, data)"
-            )
-        register_event("open_table")
-        return tbl
-
     def delete(self, where: str):
-        self._dataset.delete(where)
+        self._dataset_mut.delete(where)

     def update(
         self,
@@ -1295,8 +1426,7 @@ class LanceTable(Table):
         if values is not None:
             values_sql = {k: value_to_sql(v) for k, v in values.items()}

-        self.to_lance().update(values_sql, where)
-        self._reset_dataset()
+        self._dataset_mut.update(values_sql, where)
         register_event("update")

     def _execute_query(self, query: Query) -> pa.Table:
```

View File

```diff
@@ -48,7 +48,7 @@ classifiers = [
 repository = "https://github.com/lancedb/lancedb"

 [project.optional-dependencies]
-tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
+tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
 dev = ["ruff", "pre-commit"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
```

View File

```diff
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
     assert np.allclose(actual, expected)


+@pytest.mark.slow
 def test_embedding_function_rate_limit(tmp_path):
     def _get_schema_from_model(model):
         class Schema(LanceModel):
```

View File

```diff
@@ -12,8 +12,10 @@
 # limitations under the License.

 import functools
+from copy import copy
 from datetime import date, datetime, timedelta
 from pathlib import Path
+from time import sleep
 from typing import List
 from unittest.mock import PropertyMock, patch
@@ -25,6 +27,7 @@ import pyarrow as pa
 import pytest
 from pydantic import BaseModel

+import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction
 from lancedb.db import LanceDBConnection
 from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
 class MockDB:
     def __init__(self, uri: Path):
         self.uri = uri
+        self.read_consistency_interval = None

     @functools.cached_property
     def is_managed_remote(self) -> bool:
@@ -267,39 +271,38 @@ def test_versioning(db):

 def test_create_index_method():
-    with patch.object(LanceTable, "_reset_dataset", return_value=None):
-        with patch.object(
-            LanceTable, "_dataset", new_callable=PropertyMock
-        ) as mock_dataset:
-            # Setup mock responses
-            mock_dataset.return_value.create_index.return_value = None
-
-            # Create a LanceTable object
-            connection = LanceDBConnection(uri="mock.uri")
-            table = LanceTable(connection, "test_table")
-
-            # Call the create_index method
-            table.create_index(
-                metric="L2",
-                num_partitions=256,
-                num_sub_vectors=96,
-                vector_column_name="vector",
-                replace=True,
-                index_cache_size=256,
-            )
-
-            # Check that the _dataset.create_index method was called
-            # with the right parameters
-            mock_dataset.return_value.create_index.assert_called_once_with(
-                column="vector",
-                index_type="IVF_PQ",
-                metric="L2",
-                num_partitions=256,
-                num_sub_vectors=96,
-                replace=True,
-                accelerator=None,
-                index_cache_size=256,
-            )
+    with patch.object(
+        LanceTable, "_dataset_mut", new_callable=PropertyMock
+    ) as mock_dataset:
+        # Setup mock responses
+        mock_dataset.return_value.create_index.return_value = None
+
+        # Create a LanceTable object
+        connection = LanceDBConnection(uri="mock.uri")
+        table = LanceTable(connection, "test_table")
+
+        # Call the create_index method
+        table.create_index(
+            metric="L2",
+            num_partitions=256,
+            num_sub_vectors=96,
+            vector_column_name="vector",
+            replace=True,
+            index_cache_size=256,
+        )
+
+        # Check that the _dataset.create_index method was called
+        # with the right parameters
+        mock_dataset.return_value.create_index.assert_called_once_with(
+            column="vector",
+            index_type="IVF_PQ",
+            metric="L2",
+            num_partitions=256,
+            num_sub_vectors=96,
+            replace=True,
+            accelerator=None,
+            index_cache_size=256,
+        )


 def test_add_with_nans(db):
@@ -792,3 +795,48 @@ def test_hybrid_search(db):
         "Our father who art in heaven", query_type="hybrid"
     ).to_pydantic(MyTable)
     assert result1 == result3
+
+
+@pytest.mark.parametrize(
+    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
+)
+def test_consistency(tmp_path, consistency_interval):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    table.add([{"id": 1}])
+
+    if consistency_interval is None:
+        assert table2.version == table.version - 1
+        table2.checkout_latest()
+        assert table2.version == table.version
+    elif consistency_interval == timedelta(seconds=0):
+        assert table2.version == table.version
+    else:
+        # (consistency_interval == timedelta(seconds=0.1)
+        assert table2.version == table.version - 1
+        sleep(0.1)
+        assert table2.version == table.version
+
+
+def test_restore_consistency(tmp_path):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    # If we call checkout, it should lose consistency
+    table_fixed = copy(table2)
+    table_fixed.checkout(table.version)
+    # But if we call checkout_latest, it should be consistent again
+    table_ref_latest = copy(table_fixed)
+    table_ref_latest.checkout_latest()
+
+    table.add([{"id": 2}])
+    assert table_fixed.version == table.version - 1
+    assert table_ref_latest.version == table.version
```