[python] Temporary restore feature (#428)

This adds LanceTable.restore as a temporary feature. It reads data from
a previous version and creates
a new snapshot version using that data. This makes the version writeable
unlike checkout. This should be replaced once the feature is implemented
in pylance.

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-08-14 20:10:29 -07:00
committed by GitHub
parent 1fcc67fd2c
commit e3061d4cb4
2 changed files with 60 additions and 3 deletions

View File

@@ -268,10 +268,11 @@ class LanceTable(Table):
self.name = name
self._version = version
def _reset_dataset(self):
def _reset_dataset(self, version=None):
try:
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
self._version = version
except AttributeError:
pass
@@ -297,7 +298,9 @@ class LanceTable(Table):
def checkout(self, version: int):
"""Checkout a version of the table. This is an in-place operation.
This allows viewing previous versions of the table.
This allows viewing previous versions of the table. If you wish to
keep writing to the dataset starting from an old version, then use
the `restore` function instead.
Parameters
----------
@@ -325,7 +328,49 @@ class LanceTable(Table):
max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}")
self._version = version
self._reset_dataset(version=version)
def restore(self, version: int):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
specified previous version. Note that this creates a new snapshot.
Parameters
----------
version : int
The version to restore.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
1
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version
2
>>> table.restore(1)
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> len(table.list_versions())
3
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version >= max_ver:
raise ValueError(f"Invalid version {version}")
if version == max_ver:
self._reset_dataset()
return
self.checkout(version)
data = self.to_arrow()
self.checkout(max_ver)
self.add(data, mode="overwrite")
self._reset_dataset()
def __len__(self):