From e3061d4cb4a79915a4a47c8359c9a5121bcd0310 Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Mon, 14 Aug 2023 20:10:29 -0700
Subject: [PATCH] [python] Temporary restore feature (#428)

This adds LanceTable.restore as a temporary feature. It reads data from
a previous version and creates a new snapshot version using that data.
This makes the version writeable unlike checkout. This should be
replaced once the feature is implemented in pylance.

Co-authored-by: Chang She
---
 python/lancedb/table.py    | 59 ++++++++++++++++++++++++++++++++++++--
 python/tests/test_table.py | 12 ++++++++
 2 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 10366849..39eb94b6 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -268,10 +268,11 @@ class LanceTable(Table):
         self.name = name
         self._version = version
 
-    def _reset_dataset(self):
+    def _reset_dataset(self, version=None):
         try:
             if "_dataset" in self.__dict__:
                 del self.__dict__["_dataset"]
+            self._version = version
         except AttributeError:
             pass
 
@@ -297,7 +298,9 @@ class LanceTable(Table):
     def checkout(self, version: int):
         """Checkout a version of the table. This is an in-place operation.
 
-        This allows viewing previous versions of the table.
+        This allows viewing previous versions of the table. If you wish to
+        keep writing to the dataset starting from an old version, then use
+        the `restore` function instead.
 
         Parameters
         ----------
@@ -325,7 +328,57 @@ class LanceTable(Table):
         max_ver = max([v["version"] for v in self._dataset.versions()])
         if version < 1 or version > max_ver:
             raise ValueError(f"Invalid version {version}")
-        self._version = version
+        self._reset_dataset(version=version)
+
+    def restore(self, version: int):
+        """Restore a version of the table. This is an in-place operation.
+
+        This creates a new version where the data is equivalent to the
+        specified previous version. Note that this creates a new snapshot.
+
+        Parameters
+        ----------
+        version : int
+            The version to restore.
+
+        Raises
+        ------
+        ValueError
+            If `version` is less than 1 or greater than the latest version.
+
+        Examples
+        --------
+        >>> import lancedb
+        >>> db = lancedb.connect("./.lancedb")
+        >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
+        >>> table.version
+        1
+        >>> table.to_pandas()
+               vector    type
+        0  [1.1, 0.9]  vector
+        >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
+        >>> table.version
+        2
+        >>> table.restore(1)
+        >>> table.to_pandas()
+               vector    type
+        0  [1.1, 0.9]  vector
+        >>> len(table.list_versions())
+        3
+        """
+        max_ver = max([v["version"] for v in self._dataset.versions()])
+        if version < 1 or version > max_ver:
+            raise ValueError(f"Invalid version {version}")
+        if version == max_ver:
+            # Restoring the latest version is a no-op; just clear any checkout.
+            self._reset_dataset()
+            return
+        # Read the old snapshot, then overwrite on top of the latest version
+        # so the restored data is committed as a new version.
+        self.checkout(version)
+        data = self.to_arrow()
+        self.checkout(max_ver)
+        self.add(data, mode="overwrite")
         self._reset_dataset()
 
     def __len__(self):
diff --git a/python/tests/test_table.py b/python/tests/test_table.py
index 8b892f75..7d399644 100644
--- a/python/tests/test_table.py
+++ b/python/tests/test_table.py
@@ -268,3 +268,15 @@ def test_add_with_nans(db):
     arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
     v = arrow_tbl["vector"].to_pylist()[0]
     assert np.allclose(v, np.array([0.0, 0.0]))
+
+
+def test_restore(db):
+    table = LanceTable.create(
+        db,
+        "my_table",
+        data=[{"vector": [1.1, 0.9], "type": "vector"}],
+    )
+    table.add([{"vector": [0.5, 0.2], "type": "vector"}])
+    table.restore(1)
+    assert len(table.list_versions()) == 3
+    assert len(table) == 1