mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00
fix: add restore with tag in python and nodejs API (#2374)
add restore with tag API in python and nodejs API and add tests to guard them <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - The restore functionality now supports using version tags in addition to numeric version identifiers, allowing you to revert tables to a state marked by a tag. - **Bug Fixes** - Restoring with an unknown tag now properly raises an error. - **Documentation** - Updated documentation and examples to clarify that restore accepts both version numbers and tags. - **Tests** - Added new tests to verify restore behavior with version tags and error handling for unknown tags. - Added tests for checkout and restore operations involving tags. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -51,7 +51,7 @@ class Table:
|
||||
async def version(self) -> int: ...
|
||||
async def checkout(self, version: Union[int, str]): ...
|
||||
async def checkout_latest(self): ...
|
||||
async def restore(self, version: Optional[int] = None): ...
|
||||
async def restore(self, version: Optional[Union[int, str]] = None): ...
|
||||
async def list_indices(self) -> list[IndexConfig]: ...
|
||||
async def delete(self, filter: str) -> DeleteResult: ...
|
||||
async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ...
|
||||
|
||||
@@ -100,7 +100,7 @@ class RemoteTable(Table):
|
||||
def checkout_latest(self):
|
||||
return LOOP.run(self._table.checkout_latest())
|
||||
|
||||
def restore(self, version: Optional[int] = None):
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
return LOOP.run(self._table.restore(version))
|
||||
|
||||
def list_indices(self) -> Iterable[IndexConfig]:
|
||||
|
||||
@@ -1470,7 +1470,7 @@ class Table(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def restore(self, version: Optional[int] = None):
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
This creates a new version where the data is equivalent to the
|
||||
@@ -1478,9 +1478,10 @@ class Table(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int, default None
|
||||
The version to restore. If unspecified then restores the currently
|
||||
checked out version. If the currently checked out version is the
|
||||
version : int or str, default None
|
||||
The version number or version tag to restore.
|
||||
If unspecified then restores the currently checked out version.
|
||||
If the currently checked out version is the
|
||||
latest version then this is a no-op.
|
||||
"""
|
||||
|
||||
@@ -1710,7 +1711,7 @@ class LanceTable(Table):
|
||||
"""
|
||||
LOOP.run(self._table.checkout_latest())
|
||||
|
||||
def restore(self, version: Optional[int] = None):
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
This creates a new version where the data is equivalent to the
|
||||
@@ -1718,9 +1719,10 @@ class LanceTable(Table):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int, default None
|
||||
The version to restore. If unspecified then restores the currently
|
||||
checked out version. If the currently checked out version is the
|
||||
version : int or str, default None
|
||||
The version number or version tag to restore.
|
||||
If unspecified then restores the currently checked out version.
|
||||
If the currently checked out version is the
|
||||
latest version then this is a no-op.
|
||||
|
||||
Examples
|
||||
@@ -1738,12 +1740,20 @@ class LanceTable(Table):
|
||||
AddResult(version=2)
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.tags.create("v2", 2)
|
||||
>>> table.restore(1)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> len(table.list_versions())
|
||||
3
|
||||
>>> table.restore("v2")
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
1 [0.5, 0.2] vector
|
||||
>>> len(table.list_versions())
|
||||
4
|
||||
"""
|
||||
if version is not None:
|
||||
LOOP.run(self._table.checkout(version))
|
||||
@@ -3962,7 +3972,7 @@ class AsyncTable:
|
||||
"""
|
||||
await self._inner.checkout_latest()
|
||||
|
||||
async def restore(self, version: Optional[int] = None):
|
||||
async def restore(self, version: Optional[int | str] = None):
|
||||
"""
|
||||
Restore the table to the currently checked out version
|
||||
|
||||
|
||||
@@ -769,6 +769,29 @@ def test_restore(mem_db: DBConnection):
|
||||
table.restore(0)
|
||||
|
||||
|
||||
def test_restore_with_tags(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
data=[{"vector": [1.1, 0.9], "type": "vector"}],
|
||||
)
|
||||
tag = "tag1"
|
||||
table.tags.create(tag, 1)
|
||||
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
table.restore(tag)
|
||||
assert len(table.list_versions()) == 3
|
||||
assert len(table) == 1
|
||||
expected = table.to_arrow()
|
||||
|
||||
table.add([{"vector": [0.3, 0.3], "type": "vector"}])
|
||||
table.checkout("tag1")
|
||||
table.restore()
|
||||
assert len(table.list_versions()) == 5
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore("tag_unknown")
|
||||
|
||||
|
||||
def test_merge(tmp_db: DBConnection, tmp_path):
|
||||
pytest.importorskip("lance")
|
||||
import lance
|
||||
|
||||
@@ -17,10 +17,10 @@ use lancedb::table::{
|
||||
Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyIOError, PyKeyError, PyRuntimeError, PyValueError},
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyInt, PyString},
|
||||
Bound, FromPyObject, PyAny, PyObject, PyRef, PyResult, Python,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
@@ -520,25 +520,15 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn checkout(self_: PyRef<'_, Self>, version: PyObject) -> PyResult<Bound<'_, PyAny>> {
|
||||
pub fn checkout(self_: PyRef<'_, Self>, version: LanceVersion) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
let py = self_.py();
|
||||
let (is_int, int_value, string_value) = if let Ok(i) = version.downcast_bound::<PyInt>(py) {
|
||||
let num: u64 = i.extract()?;
|
||||
(true, num, String::new())
|
||||
} else if let Ok(s) = version.downcast_bound::<PyString>(py) {
|
||||
let str_value = s.to_string();
|
||||
(false, 0, str_value)
|
||||
} else {
|
||||
return Err(PyIOError::new_err(
|
||||
"version must be an integer or a string.",
|
||||
));
|
||||
};
|
||||
future_into_py(py, async move {
|
||||
if is_int {
|
||||
inner.checkout(int_value).await.infer_error()
|
||||
} else {
|
||||
inner.checkout_tag(&string_value).await.infer_error()
|
||||
match version {
|
||||
LanceVersion::Version(version_num) => {
|
||||
inner.checkout(version_num).await.infer_error()
|
||||
}
|
||||
LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error(),
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -551,12 +541,19 @@ impl Table {
|
||||
}
|
||||
|
||||
#[pyo3(signature = (version=None))]
|
||||
pub fn restore(self_: PyRef<'_, Self>, version: Option<u64>) -> PyResult<Bound<'_, PyAny>> {
|
||||
pub fn restore(
|
||||
self_: PyRef<'_, Self>,
|
||||
version: Option<LanceVersion>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
let py = self_.py();
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
future_into_py(py, async move {
|
||||
if let Some(version) = version {
|
||||
inner.checkout(version).await.infer_error()?;
|
||||
match version {
|
||||
LanceVersion::Version(num) => inner.checkout(num).await.infer_error()?,
|
||||
LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error()?,
|
||||
}
|
||||
}
|
||||
inner.restore().await.infer_error()
|
||||
})
|
||||
@@ -795,6 +792,12 @@ impl Table {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
pub enum LanceVersion {
|
||||
Version(u64),
|
||||
Tag(String),
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
#[pyo3(from_item_all)]
|
||||
pub struct MergeInsertParams {
|
||||
|
||||
Reference in New Issue
Block a user