fix: add restore with tag in python and nodejs API (#2374)

add restore with tag API in python and nodejs API and add tests to guard
them

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- The restore functionality now supports using version tags in addition
to numeric version identifiers, allowing you to revert tables to a state
marked by a tag.
- **Bug Fixes**
  - Restoring with an unknown tag now properly raises an error.
- **Documentation**
- Updated documentation and examples to clarify that restore accepts
both version numbers and tags.
- **Tests**
- Added new tests to verify restore behavior with version tags and error
handling for unknown tags.
  - Added tests for checkout and restore operations involving tags.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
LuQQiu
2025-05-06 16:12:58 -07:00
committed by GitHub
parent 89dc80c42a
commit c9ae1b1737
6 changed files with 95 additions and 33 deletions

View File

@@ -51,7 +51,7 @@ class Table:
async def version(self) -> int: ...
async def checkout(self, version: Union[int, str]): ...
async def checkout_latest(self): ...
async def restore(self, version: Optional[int] = None): ...
async def restore(self, version: Optional[Union[int, str]] = None): ...
async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str) -> DeleteResult: ...
async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ...

View File

@@ -100,7 +100,7 @@ class RemoteTable(Table):
def checkout_latest(self):
return LOOP.run(self._table.checkout_latest())
def restore(self, version: Optional[int] = None):
def restore(self, version: Optional[Union[int, str]] = None):
return LOOP.run(self._table.restore(version))
def list_indices(self) -> Iterable[IndexConfig]:

View File

@@ -1470,7 +1470,7 @@ class Table(ABC):
"""
@abstractmethod
def restore(self, version: Optional[int] = None):
def restore(self, version: Optional[Union[int, str]] = None):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
@@ -1478,9 +1478,10 @@ class Table(ABC):
Parameters
----------
version : int, default None
The version to restore. If unspecified then restores the currently
checked out version. If the currently checked out version is the
version : int or str, default None
The version number or version tag to restore.
If unspecified then restores the currently checked out version.
If the currently checked out version is the
latest version then this is a no-op.
"""
@@ -1710,7 +1711,7 @@ class LanceTable(Table):
"""
LOOP.run(self._table.checkout_latest())
def restore(self, version: Optional[int] = None):
def restore(self, version: Optional[Union[int, str]] = None):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
@@ -1718,9 +1719,10 @@ class LanceTable(Table):
Parameters
----------
version : int, default None
The version to restore. If unspecified then restores the currently
checked out version. If the currently checked out version is the
version : int or str, default None
The version number or version tag to restore.
If unspecified then restores the currently checked out version.
If the currently checked out version is the
latest version then this is a no-op.
Examples
@@ -1738,12 +1740,20 @@ class LanceTable(Table):
AddResult(version=2)
>>> table.version
2
>>> table.tags.create("v2", 2)
>>> table.restore(1)
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
>>> len(table.list_versions())
3
>>> table.restore("v2")
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
1 [0.5, 0.2] vector
>>> len(table.list_versions())
4
"""
if version is not None:
LOOP.run(self._table.checkout(version))
@@ -3962,7 +3972,7 @@ class AsyncTable:
"""
await self._inner.checkout_latest()
async def restore(self, version: Optional[int] = None):
async def restore(self, version: Optional[int | str] = None):
"""
Restore the table to the currently checked out version

View File

@@ -769,6 +769,29 @@ def test_restore(mem_db: DBConnection):
table.restore(0)
def test_restore_with_tags(mem_db: DBConnection):
table = mem_db.create_table(
"my_table",
data=[{"vector": [1.1, 0.9], "type": "vector"}],
)
tag = "tag1"
table.tags.create(tag, 1)
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(tag)
assert len(table.list_versions()) == 3
assert len(table) == 1
expected = table.to_arrow()
table.add([{"vector": [0.3, 0.3], "type": "vector"}])
table.checkout("tag1")
table.restore()
assert len(table.list_versions()) == 5
assert table.to_arrow() == expected
with pytest.raises(ValueError):
table.restore("tag_unknown")
def test_merge(tmp_db: DBConnection, tmp_path):
pytest.importorskip("lance")
import lance

View File

@@ -17,10 +17,10 @@ use lancedb::table::{
Table as LanceDbTable,
};
use pyo3::{
exceptions::{PyIOError, PyKeyError, PyRuntimeError, PyValueError},
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyInt, PyString},
Bound, FromPyObject, PyAny, PyObject, PyRef, PyResult, Python,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -520,25 +520,15 @@ impl Table {
})
}
pub fn checkout(self_: PyRef<'_, Self>, version: PyObject) -> PyResult<Bound<'_, PyAny>> {
pub fn checkout(self_: PyRef<'_, Self>, version: LanceVersion) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
let py = self_.py();
let (is_int, int_value, string_value) = if let Ok(i) = version.downcast_bound::<PyInt>(py) {
let num: u64 = i.extract()?;
(true, num, String::new())
} else if let Ok(s) = version.downcast_bound::<PyString>(py) {
let str_value = s.to_string();
(false, 0, str_value)
} else {
return Err(PyIOError::new_err(
"version must be an integer or a string.",
));
};
future_into_py(py, async move {
if is_int {
inner.checkout(int_value).await.infer_error()
} else {
inner.checkout_tag(&string_value).await.infer_error()
match version {
LanceVersion::Version(version_num) => {
inner.checkout(version_num).await.infer_error()
}
LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error(),
}
})
}
@@ -551,12 +541,19 @@ impl Table {
}
#[pyo3(signature = (version=None))]
pub fn restore(self_: PyRef<'_, Self>, version: Option<u64>) -> PyResult<Bound<'_, PyAny>> {
pub fn restore(
self_: PyRef<'_, Self>,
version: Option<LanceVersion>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
let py = self_.py();
future_into_py(self_.py(), async move {
future_into_py(py, async move {
if let Some(version) = version {
inner.checkout(version).await.infer_error()?;
match version {
LanceVersion::Version(num) => inner.checkout(num).await.infer_error()?,
LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error()?,
}
}
inner.restore().await.infer_error()
})
@@ -795,6 +792,12 @@ impl Table {
}
}
#[derive(FromPyObject)]
pub enum LanceVersion {
Version(u64),
Tag(String),
}
#[derive(FromPyObject)]
#[pyo3(from_item_all)]
pub struct MergeInsertParams {