From 2f1f9f6338fa82cb868190ccafffa362b11b2e0b Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Thu, 24 Aug 2023 11:00:34 -0700 Subject: [PATCH] [python] improve restore functionality (#451) Previously the temporary restore feature required copying data. The new feature in pylance does not. --------- Co-authored-by: Chang She Co-authored-by: Weston Pace --- python/lancedb/table.py | 27 ++++++++++++++++----------- python/tests/test_table.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 89510f90..c3c9adc8 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -311,7 +311,7 @@ class LanceTable(Table): This allows viewing previous versions of the table. If you wish to keep writing to the dataset starting from an old version, then use - the `restore` function instead. + the `restore` function. Parameters ---------- @@ -341,16 +341,18 @@ class LanceTable(Table): raise ValueError(f"Invalid version {version}") self._reset_dataset(version=version) - def restore(self, version: int): + def restore(self, version: int = None): """Restore a version of the table. This is an in-place operation. This creates a new version where the data is equivalent to the - specified previous version. Note that this creates a new snapshot. + specified previous version. Data is not copied (as of python-v0.2.1). Parameters ---------- - version : int - The version to restore. + version : int, default None + The version to restore. If unspecified then restores the currently + checked out version. If the currently checked out version is the + latest version then this is a no-op. Examples -------- @@ -373,15 +375,18 @@ class LanceTable(Table): 3 """ max_ver = max([v["version"] for v in self._dataset.versions()]) - if version < 1 or version >= max_ver: + if version is None: + version = self.version + elif version < 1 or version > max_ver: raise ValueError(f"Invalid version {version}") + else: + self.checkout(version) + if version == max_ver: - self._reset_dataset() + # no-op if restoring the latest version return - self.checkout(version) - data = self.to_arrow() - self.checkout(max_ver) - self.add(data, mode="overwrite") + + self._dataset.restore() self._reset_dataset() def __len__(self): diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 7d399644..defd4568 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -280,3 +280,18 @@ def test_restore(db): table.restore(1) assert len(table.list_versions()) == 3 assert len(table) == 1 + + expected = table.to_arrow() + table.checkout(1) + table.restore() + assert len(table.list_versions()) == 4 + assert table.to_arrow() == expected + + table.restore(4) # latest version should be no-op + assert len(table.list_versions()) == 4 + + with pytest.raises(ValueError): + table.restore(5) + + with pytest.raises(ValueError): + table.restore(0)