mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
[python] Temporary restore feature (#428)
This adds LanceTable.restore as a temporary feature. It reads data from a previous version and creates a new snapshot version using that data. This makes the version writeable unlike checkout. This should be replaced once the feature is implemented in pylance. Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -268,10 +268,11 @@ class LanceTable(Table):
|
||||
self.name = name
|
||||
self._version = version
|
||||
|
||||
def _reset_dataset(self):
|
||||
def _reset_dataset(self, version=None):
|
||||
try:
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
self._version = version
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -297,7 +298,9 @@ class LanceTable(Table):
|
||||
def checkout(self, version: int):
|
||||
"""Checkout a version of the table. This is an in-place operation.
|
||||
|
||||
This allows viewing previous versions of the table.
|
||||
This allows viewing previous versions of the table. If you wish to
|
||||
keep writing to the dataset starting from an old version, then use
|
||||
the `restore` function instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -325,7 +328,49 @@ class LanceTable(Table):
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
if version < 1 or version > max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
self._version = version
|
||||
self._reset_dataset(version=version)
|
||||
|
||||
def restore(self, version: int):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
This creates a new version where the data is equivalent to the
|
||||
specified previous version. Note that this creates a new snapshot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int
|
||||
The version to restore.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
1
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.restore(1)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> len(table.list_versions())
|
||||
3
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
if version < 1 or version >= max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
if version == max_ver:
|
||||
self._reset_dataset()
|
||||
return
|
||||
self.checkout(version)
|
||||
data = self.to_arrow()
|
||||
self.checkout(max_ver)
|
||||
self.add(data, mode="overwrite")
|
||||
self._reset_dataset()
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -268,3 +268,15 @@ def test_add_with_nans(db):
|
||||
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_restore(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"vector": [1.1, 0.9], "type": "vector"}],
|
||||
)
|
||||
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
table.restore(1)
|
||||
assert len(table.list_versions()) == 3
|
||||
assert len(table) == 1
|
||||
|
||||
Reference in New Issue
Block a user