Merge pull request #35 from lancedb/changhiskhan/table-versioning

expose methods to work with versioning in tables
This commit is contained in:
Chang She
2023-04-19 21:09:32 -07:00
committed by GitHub
2 changed files with 45 additions and 2 deletions

View File

@@ -46,9 +46,12 @@ class LanceTable:
A table in a LanceDB database.
"""
def __init__(
    self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None
):
    """Bind a table handle to a connection, a table name and a version.

    Parameters
    ----------
    connection : lancedb.db.LanceDBConnection
        The database connection this table belongs to.
    name : str
        The name of the table.
    version : int, optional
        The table version to pin this handle to. The value is stored as-is
        and later handed to ``lance.dataset`` when the dataset is opened.
        NOTE(review): ``None`` presumably selects the latest version --
        confirm against the lance documentation.
    """
    self._conn = connection
    self.name = name
    self._version = version
def _reset_dataset(self):
try:
@@ -61,6 +64,23 @@ class LanceTable:
"""Return the schema of the table."""
return self._dataset.schema
def list_versions(self):
    """Return every version recorded for the dataset behind this table.

    The result of the dataset's ``versions()`` call is passed through
    unchanged.
    """
    dataset = self._dataset
    return dataset.versions()
@property
def version(self):
    """The version of the dataset currently backing this table."""
    dataset = self._dataset
    return dataset.version
def checkout(self, version: int):
    """Switch this table to a previously written version.

    Parameters
    ----------
    version : int
        The version number to check out. It must lie between 1 and the
        highest version reported by the dataset, inclusive.

    Raises
    ------
    ValueError
        If ``version`` falls outside that range.
    """
    # A generator avoids building a throwaway list. ``default=0`` makes an
    # empty version listing fail the range check below with the usual
    # "Invalid version" error instead of max()'s "empty sequence" error.
    max_ver = max((v["version"] for v in self._dataset.versions()), default=0)
    if version < 1 or version > max_ver:
        raise ValueError(f"Invalid version {version}")
    self._version = version
    # NOTE(review): _reset_dataset presumably drops the cached dataset so the
    # next access reopens it at the new version -- its body is not visible here.
    self._reset_dataset()
def __len__(self):
    """Return the row count reported by the underlying dataset."""
    dataset = self._dataset
    return dataset.count_rows()
@@ -108,7 +128,7 @@ class LanceTable:
@cached_property
def _dataset(self) -> LanceDataset:
    """Open the Lance dataset backing this table.

    The dataset is opened from ``self._dataset_uri`` at the version held in
    ``self._version``, so a handle pinned to a specific version reads that
    version. The ``cached_property`` decorator stores the result on the
    instance after the first access.
    """
    return lance.dataset(self._dataset_uri, version=self._version)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""

View File

@@ -114,3 +114,26 @@ def test_add(db):
),
)
assert expected == table.to_arrow()
def test_versioning(db):
    """Exercise version listing, version growth on add, and checkout."""
    initial_rows = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    table = LanceTable.create(db, "test", data=initial_rows)

    # A freshly created table has exactly one version, and it is current.
    versions = table.list_versions()
    assert len(versions) == 1
    assert table.version == 1

    # Adding a row creates a second version that holds all three rows.
    extra_row = {"vector": [6.3, 100.5], "item": "new", "price": 30.0}
    table.add([extra_row])
    assert len(table.list_versions()) == 2
    assert table.version == 2
    assert len(table) == 3

    # Checking out version 1 exposes only the two original rows again.
    table.checkout(1)
    assert table.version == 1
    assert len(table) == 2