From c9ae1b17379d67d170c37abdc8df0de2c7eb174a Mon Sep 17 00:00:00 2001 From: LuQQiu Date: Tue, 6 May 2025 16:12:58 -0700 Subject: [PATCH] fix: add restore with tag in python and nodejs API (#2374) add restore with tag API in python and nodejs API and add tests to guard them ## Summary by CodeRabbit - **New Features** - The restore functionality now supports using version tags in addition to numeric version identifiers, allowing you to revert tables to a state marked by a tag. - **Bug Fixes** - Restoring with an unknown tag now properly raises an error. - **Documentation** - Updated documentation and examples to clarify that restore accepts both version numbers and tags. - **Tests** - Added new tests to verify restore behavior with version tags and error handling for unknown tags. - Added tests for checkout and restore operations involving tags. --- nodejs/__test__/table.test.ts | 26 +++++++++++++++ python/python/lancedb/_lancedb.pyi | 2 +- python/python/lancedb/remote/table.py | 2 +- python/python/lancedb/table.py | 28 +++++++++++----- python/python/tests/test_table.py | 23 +++++++++++++ python/src/table.rs | 47 ++++++++++++++------------- 6 files changed, 95 insertions(+), 33 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index dcc385b3..4b23a82e 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -1287,6 +1287,32 @@ describe("when dealing with tags", () => { await table.checkoutLatest(); expect(await table.version()).toBe(4); }); + + it("can checkout and restore tags", async () => { + const conn = await connect(tmpDir.name, { + readConsistencyInterval: 0, + }); + + const table = await conn.createTable("my_table", [ + { id: 1n, vector: [0.1, 0.2] }, + ]); + expect(await table.version()).toBe(1); + expect(await table.countRows()).toBe(1); + const tagsManager = await table.tags(); + const tag1 = "tag1"; + await tagsManager.create(tag1, 1); + await table.add([{ id: 2n, vector: [0.3, 0.4] }]); + 
const tag2 = "tag2"; + await tagsManager.create(tag2, 2); + expect(await table.version()).toBe(2); + await table.checkout(tag1); + expect(await table.version()).toBe(1); + await table.restore(); + expect(await table.version()).toBe(3); + expect(await table.countRows()).toBe(1); + await table.add([{ id: 3n, vector: [0.5, 0.6] }]); + expect(await table.countRows()).toBe(2); + }); }); describe("when optimizing a dataset", () => { diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 7e9934aa..00a507f7 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -51,7 +51,7 @@ class Table: async def version(self) -> int: ... async def checkout(self, version: Union[int, str]): ... async def checkout_latest(self): ... - async def restore(self, version: Optional[int] = None): ... + async def restore(self, version: Optional[Union[int, str]] = None): ... async def list_indices(self) -> list[IndexConfig]: ... async def delete(self, filter: str) -> DeleteResult: ... async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ... 
diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index ed6d14ea..d8aae374 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -100,7 +100,7 @@ class RemoteTable(Table): def checkout_latest(self): return LOOP.run(self._table.checkout_latest()) - def restore(self, version: Optional[int] = None): + def restore(self, version: Optional[Union[int, str]] = None): return LOOP.run(self._table.restore(version)) def list_indices(self) -> Iterable[IndexConfig]: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 91a8fea5..7bf43c9d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1470,7 +1470,7 @@ class Table(ABC): """ @abstractmethod - def restore(self, version: Optional[int] = None): + def restore(self, version: Optional[Union[int, str]] = None): """Restore a version of the table. This is an in-place operation. This creates a new version where the data is equivalent to the @@ -1478,9 +1478,10 @@ class Table(ABC): Parameters ---------- - version : int, default None - The version to restore. If unspecified then restores the currently - checked out version. If the currently checked out version is the + version : int or str, default None + The version number or version tag to restore. + If unspecified then restores the currently checked out version. + If the currently checked out version is the latest version then this is a no-op. """ @@ -1710,7 +1711,7 @@ class LanceTable(Table): """ LOOP.run(self._table.checkout_latest()) - def restore(self, version: Optional[int] = None): + def restore(self, version: Optional[Union[int, str]] = None): """Restore a version of the table. This is an in-place operation. This creates a new version where the data is equivalent to the @@ -1718,9 +1719,10 @@ class LanceTable(Table): Parameters ---------- - version : int, default None - The version to restore. 
If unspecified then restores the currently - checked out version. If the currently checked out version is the + version : int or str, default None + The version number or version tag to restore. + If unspecified then restores the currently checked out version. + If the currently checked out version is the latest version then this is a no-op. Examples @@ -1738,12 +1740,20 @@ class LanceTable(Table): AddResult(version=2) >>> table.version 2 + >>> table.tags.create("v2", 2) >>> table.restore(1) >>> table.to_pandas() vector type 0 [1.1, 0.9] vector >>> len(table.list_versions()) 3 + >>> table.restore("v2") + >>> table.to_pandas() + vector type + 0 [1.1, 0.9] vector + 1 [0.5, 0.2] vector + >>> len(table.list_versions()) + 4 """ if version is not None: LOOP.run(self._table.checkout(version)) @@ -3962,7 +3972,7 @@ class AsyncTable: """ await self._inner.checkout_latest() - async def restore(self, version: Optional[int] = None): + async def restore(self, version: Optional[Union[int, str]] = None): """ Restore the table to the currently checked out version diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index af412c5e..a8529cb3 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -769,6 +769,29 @@ def test_restore(mem_db: DBConnection): table.restore(0) +def test_restore_with_tags(mem_db: DBConnection): + table = mem_db.create_table( + "my_table", + data=[{"vector": [1.1, 0.9], "type": "vector"}], + ) + tag = "tag1" + table.tags.create(tag, 1) + table.add([{"vector": [0.5, 0.2], "type": "vector"}]) + table.restore(tag) + assert len(table.list_versions()) == 3 + assert len(table) == 1 + expected = table.to_arrow() + + table.add([{"vector": [0.3, 0.3], "type": "vector"}]) + table.checkout("tag1") + table.restore() + assert len(table.list_versions()) == 5 + assert table.to_arrow() == expected + + with pytest.raises(ValueError): + table.restore("tag_unknown") + + def test_merge(tmp_db: DBConnection, tmp_path): 
pytest.importorskip("lance") import lance diff --git a/python/src/table.rs b/python/src/table.rs index 820b22bd..e00ebd95 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -17,10 +17,10 @@ use lancedb::table::{ Table as LanceDbTable, }; use pyo3::{ - exceptions::{PyIOError, PyKeyError, PyRuntimeError, PyValueError}, + exceptions::{PyKeyError, PyRuntimeError, PyValueError}, pyclass, pymethods, - types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyInt, PyString}, - Bound, FromPyObject, PyAny, PyObject, PyRef, PyResult, Python, + types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods}, + Bound, FromPyObject, PyAny, PyRef, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; @@ -520,25 +520,15 @@ impl Table { }) } - pub fn checkout(self_: PyRef<'_, Self>, version: PyObject) -> PyResult> { + pub fn checkout(self_: PyRef<'_, Self>, version: LanceVersion) -> PyResult> { let inner = self_.inner_ref()?.clone(); let py = self_.py(); - let (is_int, int_value, string_value) = if let Ok(i) = version.downcast_bound::(py) { - let num: u64 = i.extract()?; - (true, num, String::new()) - } else if let Ok(s) = version.downcast_bound::(py) { - let str_value = s.to_string(); - (false, 0, str_value) - } else { - return Err(PyIOError::new_err( - "version must be an integer or a string.", - )); - }; future_into_py(py, async move { - if is_int { - inner.checkout(int_value).await.infer_error() - } else { - inner.checkout_tag(&string_value).await.infer_error() + match version { + LanceVersion::Version(version_num) => { + inner.checkout(version_num).await.infer_error() + } + LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error(), } }) } @@ -551,12 +541,19 @@ impl Table { } #[pyo3(signature = (version=None))] - pub fn restore(self_: PyRef<'_, Self>, version: Option) -> PyResult> { + pub fn restore( + self_: PyRef<'_, Self>, + version: Option, + ) -> PyResult> { let inner = self_.inner_ref()?.clone(); + let py = self_.py(); - 
future_into_py(self_.py(), async move { + future_into_py(py, async move { if let Some(version) = version { - inner.checkout(version).await.infer_error()?; + match version { + LanceVersion::Version(num) => inner.checkout(num).await.infer_error()?, + LanceVersion::Tag(tag) => inner.checkout_tag(&tag).await.infer_error()?, + } } inner.restore().await.infer_error() }) @@ -795,6 +792,12 @@ impl Table { } } +#[derive(FromPyObject)] +pub enum LanceVersion { + Version(u64), + Tag(String), +} + #[derive(FromPyObject)] #[pyo3(from_item_all)] pub struct MergeInsertParams {