From 0d4cb346f98854a26e43f0247391f286c9afd208 Mon Sep 17 00:00:00 2001
From: Brendan Clement
Date: Tue, 2 Jun 2026 16:35:53 -0700
Subject: [PATCH] feat: add table branch support to the Rust core

---
 rust/lancedb/src/remote/table.rs  |  32 ++++++
 rust/lancedb/src/table.rs         | 158 +++++++++++++++++++++++++++++-
 rust/lancedb/src/table/dataset.rs |  53 +++++++++-
 3 files changed, 240 insertions(+), 3 deletions(-)

diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 6dab22590..948dce594 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -1383,6 +1383,38 @@ impl BaseTable for RemoteTable {
             .map_err(unwrap_shared_error)
     }
 
+    async fn create_branch(
+        &self,
+        _name: &str,
+        _from: lance::dataset::refs::Ref,
+    ) -> Result> {
+        Err(Error::NotSupported {
+            message: "branching is not yet supported on remote tables".into(),
+        })
+    }
+
+    async fn checkout_branch(&self, _name: &str) -> Result> {
+        Err(Error::NotSupported {
+            message: "branching is not yet supported on remote tables".into(),
+        })
+    }
+
+    async fn list_branches(&self) -> Result> {
+        Err(Error::NotSupported {
+            message: "branching is not yet supported on remote tables".into(),
+        })
+    }
+
+    async fn delete_branch(&self, _name: &str) -> Result<()> {
+        Err(Error::NotSupported {
+            message: "branching is not yet supported on remote tables".into(),
+        })
+    }
+
+    fn current_branch(&self) -> Option {
+        None // create/checkout are rejected above, so a remote handle is never branch-scoped
+    }
+
     async fn count_rows(&self, filter: Option) -> Result {
         let mut request = self.post_read(&format!("/v1/table/{}/count_rows/", self.identifier));
 
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index 355483f0c..a3a418913 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -86,7 +86,7 @@ pub use add_data::{AddDataBuilder, AddDataMode, AddResult, NaNVectorBehavior};
 pub use chrono::Duration;
 pub use delete::DeleteResult;
 use futures::future::join_all;
-pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
+pub use lance::dataset::refs::{BranchContents, TagContents, Tags as LanceTags};
 pub use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::statistics::DatasetStatisticsExt;
 pub use lance_index::optimize::OptimizeOptions;
@@ -625,6 +625,20 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
     async fn restore(&self) -> Result<()>;
     /// List the versions of the table.
     async fn list_versions(&self) -> Result>;
+    /// Create branch `name` starting from `from` and return a handle scoped to it.
+    async fn create_branch(
+        &self,
+        name: &str,
+        from: lance::dataset::refs::Ref,
+    ) -> Result>;
+    /// Check out the existing branch `name` and return a handle scoped to it.
+    async fn checkout_branch(&self, name: &str) -> Result>;
+    /// List the table's branches, keyed by branch name.
+    async fn list_branches(&self) -> Result>;
+    /// Delete the branch `name`.
+    async fn delete_branch(&self, name: &str) -> Result<()>;
+    /// Name of the branch this handle is scoped to, or `None` when it is on `main`.
+    fn current_branch(&self) -> Option;
     /// Get the table definition.
     async fn table_definition(&self) -> Result;
     /// Get the table URI (storage location)
@@ -1625,6 +1639,46 @@ impl Table {
         self.inner.tags().await
     }
 
+    /// Create branch `name` from `from` (a version, tag, or branch) and return a
+    /// writable, isolated handle scoped to it. `self` is unaffected.
+    pub async fn create_branch(
+        &self,
+        name: &str,
+        from: impl Into,
+    ) -> Result {
+        let inner = self.inner.create_branch(name, from.into()).await?;
+        Ok(Self {
+            inner,
+            database: self.database.clone(),
+            embedding_registry: self.embedding_registry.clone(),
+        })
+    }
+
+    /// Check out existing branch `name`, returning a handle scoped to it; `self` is unchanged.
+    pub async fn checkout_branch(&self, name: &str) -> Result {
+        let inner = self.inner.checkout_branch(name).await?;
+        Ok(Self {
+            inner,
+            database: self.database.clone(),
+            embedding_registry: self.embedding_registry.clone(),
+        })
+    }
+
+    /// List the table's branches, keyed by branch name.
+    pub async fn list_branches(&self) -> Result> {
+        self.inner.list_branches().await
+    }
+
+    /// Delete the branch `name`.
+    pub async fn delete_branch(&self, name: &str) -> Result<()> {
+        self.inner.delete_branch(name).await
+    }
+
+    /// Name of the branch this handle is scoped to, or `None` when it is on `main`.
+    pub fn current_branch(&self) -> Option {
+        self.inner.current_branch()
+    }
+
     /// Retrieve statistics on the table
     pub async fn stats(&self) -> Result {
         self.inner.stats().await
@@ -1861,6 +1915,21 @@ impl NativeTable {
         self
     }
 
+    /// Build a sibling `NativeTable` that copies this table's identity and settings but
+    /// wraps the given (independent) `dataset` — used to hand out branch-scoped handles.
+    fn with_dataset(&self, dataset: dataset::DatasetConsistencyWrapper) -> Self {
+        Self {
+            name: self.name.clone(),
+            namespace: self.namespace.clone(),
+            id: self.id.clone(),
+            uri: self.uri.clone(),
+            dataset,
+            read_consistency_interval: self.read_consistency_interval,
+            namespace_client: self.namespace_client.clone(),
+            pushdown_operations: self.pushdown_operations.clone(),
+        }
+    }
+
     /// Opens an existing Table using a namespace client.
     ///
     /// This method uses `DatasetBuilder::from_namespace` to open the table, which
@@ -2652,6 +2721,43 @@ impl BaseTable for NativeTable {
         self.dataset.reload().await
     }
 
+    async fn create_branch(
+        &self,
+        name: &str,
+        from: lance::dataset::refs::Ref,
+    ) -> Result> {
+        let mut ds = (*self.dataset.get().await?).clone();
+        let branch_ds = ds.create_branch(name, from, None).await?;
+        let dataset = dataset::DatasetConsistencyWrapper::new_latest(
+            branch_ds,
+            self.read_consistency_interval,
+        );
+        Ok(Arc::new(self.with_dataset(dataset)))
+    }
+
+    async fn checkout_branch(&self, name: &str) -> Result> {
+        let branch_ds = self.dataset.get().await?.checkout_branch(name).await?;
+        let dataset = dataset::DatasetConsistencyWrapper::new_latest(
+            branch_ds,
+            self.read_consistency_interval,
+        );
+        Ok(Arc::new(self.with_dataset(dataset)))
+    }
+
+    async fn list_branches(&self) -> Result> {
+        Ok(self.dataset.get().await?.list_branches().await?)
+    }
+
+    async fn delete_branch(&self, name: &str) -> Result<()> {
+        let mut ds = (*self.dataset.get().await?).clone();
+        ds.delete_branch(name).await?; // runs on the clone above; self's wrapper is not replaced
+        Ok(())
+    }
+
+    fn current_branch(&self) -> Option {
+        self.dataset.current_branch()
+    }
+
     async fn list_versions(&self) -> Result> {
         Ok(self.dataset.get().await?.versions().await?)
     }
@@ -3378,6 +3484,56 @@ mod tests {
         assert_eq!(table.version().await.unwrap(), 4);
     }
 
+    #[tokio::test]
+    async fn test_branches() {
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+
+        let conn = ConnectBuilder::new(uri)
+            .read_consistency_interval(Duration::from_secs(0))
+            .execute()
+            .await
+            .unwrap();
+
+        // main starts with a single row and reports no branch
+        let table = conn
+            .create_table("my_table", some_sample_data())
+            .execute()
+            .await
+            .unwrap();
+        assert_eq!(table.count_rows(None).await.unwrap(), 1);
+        assert_eq!(table.current_branch(), None);
+        let main_version = table.version().await.unwrap();
+
+        // branching from main's current version yields a handle that starts with main's data
+        let branch = table.create_branch("exp", main_version).await.unwrap();
+        assert_eq!(branch.current_branch().as_deref(), Some("exp"));
+        assert_eq!(branch.count_rows(None).await.unwrap(), 1);
+
+        // a write through the branch handle must not show up on main
+        branch.add(some_sample_data()).execute().await.unwrap();
+        assert_eq!(branch.count_rows(None).await.unwrap(), 2);
+        assert_eq!(
+            table.count_rows(None).await.unwrap(),
+            1,
+            "main must be untouched by branch writes"
+        );
+
+        // the new branch is visible in the listing taken from the main handle
+        let branches = table.list_branches().await.unwrap();
+        assert!(branches.contains_key("exp"));
+
+        // checking out the branch through the main handle sees the branch's latest data
+        let checked_out = table.checkout_branch("exp").await.unwrap();
+        assert_eq!(checked_out.current_branch().as_deref(), Some("exp"));
+        assert_eq!(checked_out.count_rows(None).await.unwrap(), 2);
+
+        // deleting the branch removes it from the listing
+        table.delete_branch("exp").await.unwrap();
+        let branches = table.list_branches().await.unwrap();
+        assert!(!branches.contains_key("exp"));
+    }
+
     #[tokio::test]
     async fn test_create_index() {
         use arrow_array::RecordBatch;
diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs
index b4673d876..1ff11198e 100644
--- a/rust/lancedb/src/table/dataset.rs
+++ b/rust/lancedb/src/table/dataset.rs
@@ -144,8 +144,19 @@ impl DatasetConsistencyWrapper {
     }
 
     /// Checkout a branch and track its HEAD for new versions.
-    pub async fn as_branch(&self, _branch: impl Into) -> Result<()> {
-        todo!("Branch support not yet implemented")
+    pub async fn as_branch(&self, branch: impl Into) -> Result<()> {
+        let branch = branch.into();
+        let dataset = { self.state.lock()?.dataset.clone() }; // lock released before the await
+        let new_dataset = dataset.checkout_branch(&branch).await?;
+
+        let mut state = self.state.lock()?;
+        state.dataset = Arc::new(new_dataset);
+        state.pinned_version = None; // no pinned version: the branch is tracked, not time-travelled
+        drop(state); // release the lock before invalidating the background cache
+        if let ConsistencyMode::Eventual(bg_cache) = &self.consistency {
+            bg_cache.invalidate();
+        }
+        Ok(())
     }
 
     /// Check that the dataset is in a mutable mode (Latest).
@@ -161,6 +172,17 @@ impl DatasetConsistencyWrapper {
         }
     }
 
+    /// Branch recorded in the wrapped dataset's manifest, or `None` when tracking `main`.
+    pub fn current_branch(&self) -> Option {
+        self.state
+            .lock()
+            .unwrap_or_else(|e| e.into_inner()) // reuse the guard even if the lock is poisoned
+            .dataset
+            .manifest()
+            .branch
+            .clone()
+    }
+
     /// Returns the version, if in time travel mode, or None otherwise.
     pub fn time_travel_version(&self) -> Option {
         self.state
@@ -737,4 +759,31 @@ mod tests {
         let result = wrapper.reload().await;
         assert!(result.is_err());
     }
+
+    #[tokio::test]
+    async fn test_as_branch_is_writable_and_tracked() {
+        let dir = tempfile::tempdir().unwrap();
+        let uri = dir.path().to_str().unwrap();
+
+        // create the dataset on main, then create branch "exp" from its current version
+        let mut ds = create_test_dataset(uri).await;
+        let v1 = ds.version().version;
+        ds.create_branch("exp", v1, None).await.unwrap();
+
+        // a fresh wrapper over the main dataset reports no branch
+        let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
+        assert_eq!(wrapper.current_branch(), None);
+
+        // switching the same wrapper to the branch changes what it reports
+        wrapper.as_branch("exp").await.unwrap();
+        assert_eq!(wrapper.current_branch().as_deref(), Some("exp"));
+
+        // after the switch the wrapper is still mutable and has no time-travel version
+        wrapper.ensure_mutable().unwrap();
+        assert_eq!(wrapper.time_travel_version(), None);
+
+        // get() hands back a dataset whose manifest names the branch
+        let on_branch = wrapper.get().await.unwrap();
+        assert_eq!(on_branch.manifest().branch.as_deref(), Some("exp"));
+    }
 }