mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat(python): add read_consistency_interval argument (#828)
This PR refactors how we handle read consistency: whether the `LanceTable` class always picks up modifications to the table made by other instances or processes. Users have three options they can set at the connection level: 1. (Default) `read_consistency_interval=None` means it will not check at all. Users can call `table.checkout_latest()` to manually check for updates. 2. `read_consistency_interval=timedelta(0)` means **always** check for updates, giving strong read consistency. 3. `read_consistency_interval=timedelta(seconds=20)` means check for updates every 20 seconds. This is eventual consistency, a compromise between the two options above. There is now an explicit difference between a `LanceTable` that tracks the current version and one that is fixed at a historical version. We now enforce that users cannot write if they have checked out an old version. They are instructed to call `checkout_latest()` before calling the write methods. Since `conn.open_table()` doesn't have a parameter for version, users will only get fixed references if they call `table.checkout()`. The difference between these two can be seen in the repr: Tables that are fixed at a particular version will have a `version` displayed in the repr. Otherwise, the version will not be shown. ```python >>> table LanceTable(connection=..., name="my_table") >>> table.checkout(1) >>> table LanceTable(connection=..., name="my_table", version=1) ``` I decided to not create different classes for these states, because I think we already have enough complexity with the Cloud vs OSS table references. Based on #812
This commit is contained in:
@@ -42,6 +42,12 @@ To run the unit tests:
|
||||
pytest
|
||||
```
|
||||
|
||||
To run the doc tests:
|
||||
|
||||
```bash
|
||||
pytest --doctest-modules lancedb
|
||||
```
|
||||
|
||||
To run linter and automatically fix all errors:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
@@ -29,6 +30,7 @@ def connect(
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -44,6 +46,18 @@ def connect(
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -72,4 +86,4 @@ def connect(
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||
return LanceDBConnection(uri)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(table2)
|
||||
LanceTable(connection=..., name="table2")
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(table3)
|
||||
LanceTable(connection=..., name="table3")
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(table4)
|
||||
LanceTable(connection=..., name="table4")
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
|
||||
----------
|
||||
uri: str or Path
|
||||
The root uri of the database.
|
||||
read_consistency_interval: timedelta, default None
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4}])
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||
LanceTable(another_table)
|
||||
LanceTable(connection=..., name="another_table")
|
||||
>>> sorted(db.table_names())
|
||||
['another_table', 'my_table']
|
||||
>>> len(db)
|
||||
2
|
||||
>>> db["my_table"]
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> "my_table" in db
|
||||
True
|
||||
>>> db.drop_table("my_table")
|
||||
>>> db.drop_table("another_table")
|
||||
"""
|
||||
|
||||
def __init__(self, uri: URI):
|
||||
def __init__(
|
||||
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
|
||||
):
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
|
||||
self._uri = str(uri)
|
||||
|
||||
self._entered = False
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}({self._uri}"
|
||||
if self.read_consistency_interval is not None:
|
||||
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
|
||||
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
|
||||
... name: str
|
||||
... vector: Vector(2)
|
||||
...
|
||||
>>> db = lancedb.connect("/tmp")
|
||||
>>> db = lancedb.connect("./example")
|
||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||
>>> table.add([
|
||||
... TestModel(name="test", vector=[1.0, 2.0])
|
||||
|
||||
@@ -14,7 +14,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -40,8 +43,6 @@ from .util import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
import PIL
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
@@ -298,7 +299,7 @@ class Table(ABC):
|
||||
|
||||
import lance
|
||||
|
||||
dataset = lance.dataset("/tmp/images.lance")
|
||||
dataset = lance.dataset("./images.lance")
|
||||
dataset.create_scalar_index("category")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -641,23 +642,145 @@ class Table(ABC):
|
||||
"""
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
    """Abstract handle to the LanceDataset backing a table.

    Concrete subclasses decide which dataset is handed out for reads
    (``dataset``) and whether one is handed out for writes (``dataset_mut``).
    """

    @property
    @abstractmethod
    def dataset(self) -> "LanceDataset":
        """Dataset to use for read operations."""

    @property
    @abstractmethod
    def dataset_mut(self) -> "LanceDataset":
        """Dataset to use for write operations."""


@dataclass
class _LanceLatestDatasetRef(_LanceDatasetRef):
    """Reference to the latest version of a LanceDataset.

    The dataset is opened lazily on first access. Afterwards it is refreshed
    to ``latest_version`` according to ``read_consistency_interval``:

    * ``None``: never refreshed automatically.
    * ``timedelta(0)``: refreshed on every access.
    * any other interval: refreshed once at least that much time has passed
      since the last load or refresh.
    """

    uri: str
    read_consistency_interval: Optional[timedelta] = None
    # time.monotonic() reading taken when the dataset was last loaded,
    # refreshed or assigned; None if that has not happened yet.
    last_consistency_check: Optional[float] = None
    _dataset: Optional["LanceDataset"] = None

    @property
    def dataset(self) -> "LanceDataset":
        """Return the dataset, opening or refreshing it when required."""
        # Compare against None explicitly: relying on truthiness would
        # re-open the dataset whenever the object itself evaluates as falsy.
        if self._dataset is None:
            self._dataset = lance.dataset(self.uri)
            self.last_consistency_check = time.monotonic()
        elif self.read_consistency_interval is not None and self._is_stale():
            self._dataset = self._dataset.checkout_version(
                self._dataset.latest_version
            )
            self.last_consistency_check = time.monotonic()
        return self._dataset

    @dataset.setter
    def dataset(self, value: "LanceDataset"):
        # A dataset handed in by a caller counts as a fresh consistency check.
        self._dataset = value
        self.last_consistency_check = time.monotonic()

    @property
    def dataset_mut(self) -> "LanceDataset":
        """Return the dataset for writing; same object as ``dataset``."""
        return self.dataset

    def _is_stale(self) -> bool:
        """Tell whether the configured consistency interval has elapsed.

        Only meaningful when ``read_consistency_interval`` is not None.
        """
        # The None check must come before the subtraction; a reference
        # created with a pre-populated dataset has no timestamp yet.
        if self.last_consistency_check is None:
            return True
        elapsed = timedelta(seconds=time.monotonic() - self.last_consistency_check)
        # ">=" rather than ">": with a zero interval every access must
        # refresh, even if the monotonic clock has not advanced in between.
        return elapsed >= self.read_consistency_interval


@dataclass
class _LanceTimeTravelRef(_LanceDatasetRef):
    """Reference to a LanceDataset fixed at one version; reads only."""

    uri: str
    version: int
    _dataset: Optional["LanceDataset"] = None

    @property
    def dataset(self) -> "LanceDataset":
        """Return the dataset at ``version``, opening it on first access."""
        if self._dataset is None:
            self._dataset = lance.dataset(self.uri, version=self.version)
        return self._dataset

    @dataset.setter
    def dataset(self, value: "LanceDataset"):
        # Keep the recorded version in step with the dataset that was set.
        self._dataset = value
        self.version = value.version

    @property
    def dataset_mut(self) -> "LanceDataset":
        """Always raise ValueError: a fixed-version reference is read-only."""
        raise ValueError(
            "Cannot mutate table reference fixed at version "
            f"{self.version}. Call checkout_latest() to get a mutable "
            "table reference."
        )
|
||||
|
||||
|
||||
class LanceTable(Table):
|
||||
"""
|
||||
A table in a LanceDB database.
|
||||
|
||||
This can be opened in two modes: standard and time-travel.
|
||||
|
||||
Standard mode is the default. In this mode, the table is mutable and tracks
|
||||
the latest version of the table. The level of read consistency is controlled
|
||||
by the `read_consistency_interval` parameter on the connection.
|
||||
|
||||
Time-travel mode is activated by specifying a version number. In this mode,
|
||||
the table is immutable and fixed to a specific version. This is useful for
|
||||
querying historical versions of the table.
|
||||
"""
|
||||
|
||||
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
|
||||
def __init__(
|
||||
self,
|
||||
connection: "LanceDBConnection",
|
||||
name: str,
|
||||
version: Optional[int] = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self.name = name
|
||||
self._version = version
|
||||
|
||||
def _reset_dataset(self, version=None):
|
||||
try:
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
self._version = version
|
||||
except AttributeError:
|
||||
pass
|
||||
if version is not None:
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
else:
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=connection.read_consistency_interval,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, **kwargs):
|
||||
tbl = cls(db, name, **kwargs)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return self._ref.dataset
|
||||
|
||||
@property
|
||||
def _dataset_mut(self) -> LanceDataset:
|
||||
return self._ref.dataset_mut
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
@@ -685,6 +808,9 @@ class LanceTable(Table):
|
||||
keep writing to the dataset starting from an old version, then use
|
||||
the `restore` function.
|
||||
|
||||
Calling this method will set the table into time-travel mode. If you
|
||||
wish to return to standard mode, call `checkout_latest`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int
|
||||
@@ -709,15 +835,13 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version < 1 or version > max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
self._reset_dataset(version=version)
|
||||
|
||||
try:
|
||||
# Accessing the property updates the cached value
|
||||
_ = self._dataset
|
||||
except Exception as e:
|
||||
ds = self._dataset.checkout_version(version)
|
||||
except IOError as e:
|
||||
if "not found" in str(e):
|
||||
raise ValueError(
|
||||
f"Version {version} no longer exists. Was it cleaned up?"
|
||||
@@ -725,6 +849,27 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise e
|
||||
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
# We've already loaded the version so we can populate it directly.
|
||||
self._ref.dataset = ds
|
||||
|
||||
def checkout_latest(self):
|
||||
"""Checkout the latest version of the table. This is an in-place operation.
|
||||
|
||||
The table will be set back into standard mode, and will track the latest
|
||||
version of the table.
|
||||
"""
|
||||
self.checkout(self._dataset.latest_version)
|
||||
ds = self._ref.dataset
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
|
||||
|
||||
def restore(self, version: int = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
@@ -759,7 +904,7 @@ class LanceTable(Table):
|
||||
>>> len(table.list_versions())
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version is None:
|
||||
version = self.version
|
||||
elif version < 1 or version > max_ver:
|
||||
@@ -767,12 +912,17 @@ class LanceTable(Table):
|
||||
else:
|
||||
self.checkout(version)
|
||||
|
||||
if version == max_ver:
|
||||
# no-op if restoring the latest version
|
||||
return
|
||||
ds = self._dataset
|
||||
|
||||
self._dataset.restore()
|
||||
self._reset_dataset()
|
||||
# no-op if restoring the latest version
|
||||
if version != max_ver:
|
||||
ds.restore()
|
||||
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
@@ -789,7 +939,11 @@ class LanceTable(Table):
|
||||
return self.count_rows()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LanceTable({self.name})"
|
||||
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
|
||||
if isinstance(self._ref, _LanceTimeTravelRef):
|
||||
val += f", version={self._ref.version}"
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
@@ -839,10 +993,6 @@ class LanceTable(Table):
|
||||
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -854,7 +1004,7 @@ class LanceTable(Table):
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
self._dataset.create_index(
|
||||
self._dataset_mut.create_index(
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
metric=metric,
|
||||
@@ -864,10 +1014,11 @@ class LanceTable(Table):
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
def create_scalar_index(self, column: str, *, replace: bool = True):
|
||||
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
|
||||
self._dataset_mut.create_scalar_index(
|
||||
column, index_type="BTREE", replace=replace
|
||||
)
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
@@ -909,14 +1060,6 @@ class LanceTable(Table):
|
||||
def _get_fts_index_path(self):
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
@cached_property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return lance.dataset(self._dataset_uri, version=self._version)
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -955,8 +1098,11 @@ class LanceTable(Table):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
# Access the dataset_mut property to ensure that the dataset is mutable.
|
||||
self._ref.dataset_mut
|
||||
self._ref.dataset = lance.write_dataset(
|
||||
data, self._dataset_uri, schema=self.schema, mode=mode
|
||||
)
|
||||
|
||||
def merge(
|
||||
self,
|
||||
@@ -1016,10 +1162,9 @@ class LanceTable(Table):
|
||||
other_table = other_table.to_lance()
|
||||
if isinstance(other_table, LanceDataset):
|
||||
other_table = other_table.to_table()
|
||||
self._dataset.merge(
|
||||
self._ref.dataset = self._dataset_mut.merge(
|
||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
@@ -1219,22 +1364,8 @@ class LanceTable(Table):
|
||||
|
||||
return new_table
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
tbl = cls(db, name)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
self._dataset_mut.delete(where)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -1288,8 +1419,7 @@ class LanceTable(Table):
|
||||
if values is not None:
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
self._dataset_mut.update(values_sql, where)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
|
||||
@@ -45,7 +45,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
class Schema(LanceModel):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from copy import copy
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
@@ -25,6 +27,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
|
||||
class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
self.read_consistency_interval = None
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
@@ -267,39 +271,38 @@ def test_versioning(db):
|
||||
|
||||
|
||||
def test_create_index_method():
|
||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
with patch.object(
|
||||
LanceTable, "_dataset_mut", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
@@ -792,3 +795,48 @@ def test_hybrid_search(db):
|
||||
"Our father who art in heaven", query_type="hybrid"
|
||||
).to_pydantic(MyTable)
|
||||
assert result1 == result3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
    """Check when a second connection observes a write made through the first.

    One connection creates and appends to the table; a second connection,
    opened with the parametrized ``read_consistency_interval``, reads its
    version before and after the append.
    """
    writer_db = lancedb.connect(tmp_path)
    writer = LanceTable.create(writer_db, "my_table", data=[{"id": 0}])

    reader_db = lancedb.connect(
        tmp_path, read_consistency_interval=consistency_interval
    )
    reader = reader_db.open_table("my_table")
    assert reader.version == writer.version

    writer.add([{"id": 1}])

    if consistency_interval == timedelta(seconds=0):
        # Zero interval: the new version is visible immediately.
        assert reader.version == writer.version
        return

    # Either no interval or a 0.1 second interval: the reader is still one
    # version behind right after the write.
    assert reader.version == writer.version - 1
    if consistency_interval is None:
        # No interval: only an explicit checkout_latest() catches up.
        reader.checkout_latest()
    else:
        # consistency_interval == timedelta(seconds=0.1): wait it out.
        sleep(0.1)
    assert reader.version == writer.version
|
||||
|
||||
|
||||
def test_restore_consistency(tmp_path):
    """Check that checkout() pins a table and checkout_latest() unpins it.

    Uses a reader connection with a zero read consistency interval, then
    compares a copy pinned via ``checkout`` with a copy moved back via
    ``checkout_latest`` after another connection appends a row.
    """
    writer_db = lancedb.connect(tmp_path)
    writer = LanceTable.create(writer_db, "my_table", data=[{"id": 0}])

    reader_db = lancedb.connect(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )
    reader = reader_db.open_table("my_table")
    assert reader.version == writer.version

    # If we call checkout, it should lose consistency
    pinned = copy(reader)
    pinned.checkout(writer.version)

    # But if we call checkout_latest, it should be consistent again
    tracking = copy(pinned)
    tracking.checkout_latest()

    writer.add([{"id": 2}])

    assert pinned.version == writer.version - 1
    assert tracking.version == writer.version
|
||||
|
||||
Reference in New Issue
Block a user