diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 89510f90..c3c9adc8 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -311,7 +311,7 @@ class LanceTable(Table): This allows viewing previous versions of the table. If you wish to keep writing to the dataset starting from an old version, then use - the `restore` function instead. + the `restore` function. Parameters ---------- @@ -341,16 +341,18 @@ class LanceTable(Table): raise ValueError(f"Invalid version {version}") self._reset_dataset(version=version) - def restore(self, version: int): + def restore(self, version: int = None): """Restore a version of the table. This is an in-place operation. This creates a new version where the data is equivalent to the - specified previous version. Note that this creates a new snapshot. + specified previous version. Data is not copied (as of python-v0.2.1). Parameters ---------- - version : int - The version to restore. + version : int, default None + The version to restore. If unspecified then restores the currently + checked out version. If the currently checked out version is the + latest version then this is a no-op. Examples -------- @@ -373,15 +375,18 @@ class LanceTable(Table): 3 """ max_ver = max([v["version"] for v in self._dataset.versions()]) - if version < 1 or version >= max_ver: + if version is None: + version = self.version + elif version < 1 or version > max_ver: raise ValueError(f"Invalid version {version}") + else: + self.checkout(version) + if version == max_ver: - self._reset_dataset() + # no-op if restoring the latest version return - self.checkout(version) - data = self.to_arrow() - self.checkout(max_ver) - self.add(data, mode="overwrite") + + self._dataset.restore() self._reset_dataset() def __len__(self): diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 7d399644..defd4568 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -280,3 +280,18 @@ def test_restore(db): table.restore(1) assert len(table.list_versions()) == 3 assert len(table) == 1 + + expected = table.to_arrow() + table.checkout(1) + table.restore() + assert len(table.list_versions()) == 4 + assert table.to_arrow() == expected + + table.restore(4) # latest version should be no-op + assert len(table.list_versions()) == 4 + + with pytest.raises(ValueError): + table.restore(5) + + with pytest.raises(ValueError): + table.restore(0)