Add branch management (list/create/checkout/delete) to the Python bindings.

Covers the type stubs, the sync and async Python wrappers, the PyO3
`Branches` class, tests for both APIs, and a `Ref` re-export in the Rust crate.

diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index afbd62086..4e36bfbfb 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -226,6 +226,8 @@ class Table:
     async def close_lsm_writers(self) -> None: ...
     @property
     def tags(self) -> Tags: ...
+    @property
+    def branches(self) -> Branches: ...
     def query(self) -> Query: ...
     def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
     def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
@@ -238,6 +240,17 @@ class Tags:
     async def delete(self, tag: str): ...
     async def update(self, tag: str, version: int): ...
 
+class Branches:
+    async def list(self) -> Dict[str, Any]: ...
+    async def create(
+        self,
+        name: str,
+        from_ref: Optional[str] = None,
+        from_version: Optional[int] = None,
+    ) -> Table: ...
+    async def checkout(self, name: str) -> Table: ...
+    async def delete(self, name: str) -> None: ...
+
 class IndexConfig:
     name: str
     index_type: str
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 7adc2cc54..6f585ad76 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -758,6 +758,15 @@ class Table(ABC):
         """
         raise NotImplementedError
 
+    @property
+    def branches(self) -> "Branches":
+        """Branch management for the table.
+
+        Branches are isolated, writable lines of history forked from another
+        branch (or version). Writes on a branch do not affect ``main``.
+        """
+        raise NotImplementedError
+
     def __len__(self) -> int:
         """The number of rows in this Table"""
         return self.count_rows(None)
@@ -2167,6 +2176,15 @@ class LanceTable(Table):
         """
         return Tags(self._table)
 
+    @property
+    def branches(self) -> "Branches":
+        """Branch management for the table.
+
+        ``create``/``checkout`` return a new table handle scoped to the branch;
+        writes on it do not affect ``main``.
+        """
+        return Branches(self._table)
+
     def checkout(self, version: Union[int, str]):
         """Checkout a version of the table. This is an in-place operation.
 
@@ -5442,6 +5460,15 @@ class AsyncTable:
         """
         return AsyncTags(self._inner)
 
+    @property
+    def branches(self) -> AsyncBranches:
+        """Branch management for the table.
+
+        Branches are isolated, writable lines of history forked from another
+        branch (or version). Writes on a branch do not affect ``main``.
+        """
+        return AsyncBranches(self._inner)
+
     async def optimize(
         self,
         *,
@@ -5777,6 +5804,50 @@ class Tags:
         LOOP.run(self._table.tags.update(tag, version))
 
 
+class Branches:
+    """
+    Table branch manager.
+    """
+
+    def __init__(self, table):
+        self._table = table
+
+    def list(self) -> Dict[str, Any]:
+        """List all branches, mapping name to branch metadata."""
+        return LOOP.run(self._table.branches.list())
+
+    def create(
+        self,
+        name: str,
+        from_ref: Optional[str] = None,
+        from_version: Optional[int] = None,
+    ) -> "LanceTable":
+        """Create a branch and return a handle scoped to it.
+
+        Parameters
+        ----------
+        name: str
+            Name of the new branch.
+        from_ref: str, optional
+            Source branch to fork from. Defaults to ``main``.
+        from_version: int, optional
+            A specific version on ``from_ref`` to fork from. Defaults to latest.
+        """
+        async_table = LOOP.run(
+            self._table.branches.create(name, from_ref, from_version)
+        )
+        return LanceTable.from_inner(async_table._inner)
+
+    def checkout(self, name: str) -> "LanceTable":
+        """Check out an existing branch and return a handle scoped to it."""
+        async_table = LOOP.run(self._table.branches.checkout(name))
+        return LanceTable.from_inner(async_table._inner)
+
+    def delete(self, name: str) -> None:
+        """Delete a branch."""
+        LOOP.run(self._table.branches.delete(name))
+
+
 class AsyncTags:
     """
     Async table tag manager.
@@ -5844,3 +5915,47 @@ class AsyncTags:
             The new table version to tag.
         """
         await self._table.tags.update(tag, version)
+
+
+class AsyncBranches:
+    """Async table branch manager."""
+
+    def __init__(self, table):
+        self._table = table
+
+    async def list(self) -> Dict[str, Any]:
+        """List all branches, mapping name to branch metadata."""
+        return await self._table.branches.list()
+
+    async def create(
+        self,
+        name: str,
+        from_ref: Optional[str] = None,
+        from_version: Optional[int] = None,
+    ) -> "AsyncTable":
+        """Create a branch and return a handle scoped to it.
+
+        Parameters
+        ----------
+        name: str
+            Name of the new branch.
+        from_ref: str, optional
+            Source branch to fork from. Defaults to ``main``.
+        from_version: int, optional
+            A specific version on ``from_ref`` to fork from. Defaults to latest.
+        """
+        # "main" and None are two spellings of the root branch in lance; normalize
+        # so from_ref="main" behaves identically to the default.
+        if from_ref == "main":
+            from_ref = None
+        inner = await self._table.branches.create(name, from_ref, from_version)
+        return AsyncTable(inner)
+
+    async def checkout(self, name: str) -> "AsyncTable":
+        """Check out an existing branch and return a handle scoped to it."""
+        inner = await self._table.branches.checkout(name)
+        return AsyncTable(inner)
+
+    async def delete(self, name: str) -> None:
+        """Delete a branch."""
+        await self._table.branches.delete(name)
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index 964f6b904..b230969b9 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -903,6 +903,81 @@ async def test_async_tags(mem_db_async: AsyncConnection):
     )
 
 
+def test_branches(tmp_path):
+    db = lancedb.connect(tmp_path)
+    table = db.create_table(
+        "test",
+        data=[
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
+        ],
+    )
+    assert table.count_rows() == 2
+
+    # fork an isolated, writable branch from main
+    branch = table.branches.create("exp")
+    assert branch.count_rows() == 2
+    branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
+
+    # writes on the branch do not touch main
+    assert branch.count_rows() == 3
+    assert table.count_rows() == 2
+
+    # the branch is listed, with main (None) as its parent
+    branches = table.branches.list()
+    assert "exp" in branches
+    assert branches["exp"]["parent_branch"] is None
+
+    # from_ref="main" is equivalent to the default
+    table.branches.create("exp2", from_ref="main")
+    assert table.branches.list()["exp2"]["parent_branch"] is None
+
+    # checkout returns a handle scoped to the branch's latest
+    checked_out = table.branches.checkout("exp")
+    assert checked_out.count_rows() == 3
+
+    # delete removes both branches
+    table.branches.delete("exp")
+    table.branches.delete("exp2")
+    assert "exp" not in table.branches.list()
+    assert "exp2" not in table.branches.list()
+
+
+@pytest.mark.asyncio
+async def test_async_branches(tmp_path):
+    db = await lancedb.connect_async(tmp_path)
+    table = await db.create_table(
+        "test",
+        data=[
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
+        ],
+    )
+    assert await table.count_rows() == 2
+
+    branch = await table.branches.create("exp")
+    assert await branch.count_rows() == 2
+    await branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
+
+    assert await branch.count_rows() == 3
+    assert await table.count_rows() == 2
+
+    branches = await table.branches.list()
+    assert "exp" in branches
+    assert branches["exp"]["parent_branch"] is None
+
+    await table.branches.create("exp2", from_ref="main")
+    assert (await table.branches.list())["exp2"]["parent_branch"] is None
+
+    checked_out = await table.branches.checkout("exp")
+    assert await checked_out.count_rows() == 3
+
+    await table.branches.delete("exp")
+    await table.branches.delete("exp2")
+    assert "exp" not in await table.branches.list()
+    assert "exp2" not in await table.branches.list()
+
+
 @patch("lancedb.table.AsyncTable.create_index")
 def test_create_index_method(mock_create_index, mem_db: DBConnection):
     table = mem_db.create_table(
diff --git a/python/src/table.rs b/python/src/table.rs
index dc5f5ec0c..5993e035b 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -17,7 +17,7 @@ use arrow::{
 };
 use lancedb::table::{
     AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
-    OptimizeAction, OptimizeOptions, Table as LanceDbTable,
+    OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
 };
 use pyo3::{
     Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -864,6 +864,11 @@ impl Table {
         Ok(Tags::new(self.inner_ref()?.clone()))
     }
 
+    #[getter]
+    pub fn branches(&self) -> PyResult<Branches> {
+        Ok(Branches::new(self.inner_ref()?.clone()))
+    }
+
     #[pyo3(signature = (offsets))]
     pub fn take_offsets(self_: PyRef<'_, Self>, offsets: Vec<u64>) -> PyResult<TakeQuery> {
         Ok(TakeQuery::new(
@@ -1265,3 +1270,73 @@ impl Tags {
         })
     }
 }
+
+/// Exposes the branch operations of a LanceDB table to Python.
+#[pyclass]
+pub struct Branches {
+    inner: LanceDbTable,
+}
+
+impl Branches {
+    pub fn new(table: LanceDbTable) -> Self {
+        Self { inner: table }
+    }
+}
+
+#[pymethods]
+impl Branches {
+    /// Resolve to a dict mapping each branch name to a dict with the keys
+    /// `parent_branch`, `parent_version` and `manifest_size`.
+    pub fn list(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let res = inner.list_branches().await.infer_error()?;
+            Python::attach(|py| {
+                let py_dict = PyDict::new(py);
+                for (name, contents) in res {
+                    let value = PyDict::new(py);
+                    value.set_item("parent_branch", contents.parent_branch)?;
+                    value.set_item("parent_version", contents.parent_version)?;
+                    value.set_item("manifest_size", contents.manifest_size)?;
+                    py_dict.set_item(name, value)?;
+                }
+                Ok(py_dict.unbind())
+            })
+        })
+    }
+
+    /// Create branch `name` from `Ref::Version(from_ref, from_version)` and
+    /// resolve to a `Table` wrapping the handle returned by `create_branch`.
+    #[pyo3(signature = (name, from_ref=None, from_version=None))]
+    pub fn create(
+        self_: PyRef<'_, Self>,
+        name: String,
+        from_ref: Option<String>,
+        from_version: Option<u64>,
+    ) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let from = Ref::Version(from_ref, from_version);
+            let table = inner.create_branch(&name, from).await.infer_error()?;
+            Ok(Table::new(table))
+        })
+    }
+
+    /// Resolve to a `Table` wrapping the handle returned by `checkout_branch`.
+    pub fn checkout(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let table = inner.checkout_branch(&name).await.infer_error()?;
+            Ok(Table::new(table))
+        })
+    }
+
+    /// Delete branch `name`; resolves to `None` on success.
+    pub fn delete(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            inner.delete_branch(&name).await.infer_error()?;
+            Ok(())
+        })
+    }
+}
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index a3a418913..207449241 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -86,7 +86,7 @@ pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
 pub use chrono::Duration;
 pub use delete::DeleteResult;
 use futures::future::join_all;
-pub use lance::dataset::refs::{BranchContents, TagContents, Tags as LanceTags};
+pub use lance::dataset::refs::{BranchContents, Ref, TagContents, Tags as LanceTags};
 pub use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::statistics::DatasetStatisticsExt;
 pub use lance_index::optimize::OptimizeOptions;