diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 5a79306c..50db93d5 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -46,9 +46,12 @@ class LanceTable: A table in a LanceDB database. """ - def __init__(self, connection: "lancedb.db.LanceDBConnection", name: str): + def __init__( + self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None + ): self._conn = connection self.name = name + self._version = version def _reset_dataset(self): try: @@ -61,6 +64,23 @@ class LanceTable: """Return the schema of the table.""" return self._dataset.schema + def list_versions(self): + """List all versions of the table""" + return self._dataset.versions() + + @property + def version(self): + """Get the current version of the table""" + return self._dataset.version + + def checkout(self, version: int): + """Checkout a version of the table""" + max_ver = max([v["version"] for v in self._dataset.versions()]) + if version < 1 or version > max_ver: + raise ValueError(f"Invalid version {version}") + self._version = version + self._reset_dataset() + def __len__(self): return self._dataset.count_rows() @@ -108,7 +128,7 @@ class LanceTable: @cached_property def _dataset(self) -> LanceDataset: - return lance.dataset(self._dataset_uri) + return lance.dataset(self._dataset_uri, version=self._version) def to_lance(self) -> LanceDataset: """Return the LanceDataset backing this table.""" diff --git a/python/tests/test_table.py b/python/tests/test_table.py index d5699faa..d66ac817 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -114,3 +114,26 @@ def test_add(db): ), ) assert expected == table.to_arrow() + + +def test_versioning(db): + table = LanceTable.create( + db, + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + + assert len(table.list_versions()) == 1 + assert table.version == 1 + + table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}]) + assert len(table.list_versions()) == 2 + assert table.version == 2 + assert len(table) == 3 + + table.checkout(1) + assert table.version == 1 + assert len(table) == 2