diff --git a/docs/src/js/interfaces/OpenTableOptions.md b/docs/src/js/interfaces/OpenTableOptions.md index 07190aa16..3c7f1e817 100644 --- a/docs/src/js/interfaces/OpenTableOptions.md +++ b/docs/src/js/interfaces/OpenTableOptions.md @@ -8,6 +8,18 @@ ## Properties +### branch? + +```ts +optional branch: string; +``` + +Open the table scoped to this branch instead of the default branch. + +Reads and writes on the returned table operate in the branch's context. + +*** + ### ~~indexCacheSize?~~ ```ts diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 2361d68c4..a066a100f 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -118,6 +118,21 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( expect(Object.keys(after)).not.toContain("exp"); }); + it("should open a branch via open_table", async () => { + const db = await connect(tmpDir.name); + await table.add([{ id: 1 }]); + const branch = await (await table.branches()).create("exp"); + await branch.add([{ id: 2 }]); + + // open_table(..., { branch }) returns a handle scoped to the branch + const opened = await db.openTable("some_table", undefined, { + branch: "exp", + }); + expect(await opened.countRows()).toBe(2); + // opening without branch still tracks main + expect(await (await db.openTable("some_table")).countRows()).toBe(1); + }); + it("should show table stats", async () => { await table.add([{ id: 1 }, { id: 2 }]); await table.add([{ id: 1 }]); diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index f6eacb20f..5ce155789 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -84,6 +84,12 @@ export interface CreateTableOptions { } export interface OpenTableOptions { + /** + * Open the table scoped to this branch instead of the default branch. + * + * Reads and writes on the returned table operate in the branch's context. + */ + branch?: string; /** * Configuration for object storage. * @@ -483,7 +489,11 @@ export class LocalConnection extends Connection { options?.indexCacheSize, ); - return new LocalTable(innerTable); + const table = new LocalTable(innerTable); + if (options?.branch != null) { + return (await table.branches()).checkout(options.branch); + } + return table; } async cloneTable( diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 76b3b4aeb..d7ec073ed 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -416,6 +416,7 @@ class DBConnection(EnforceOverrides): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> Table: """Open a Lance Table in the database. @@ -444,6 +445,9 @@ class DBConnection(EnforceOverrides): connection will be inherited by the table, but can be overridden here. See available options at + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. Reads and writes operate in the branch's context. Returns ------- @@ -958,6 +962,7 @@ class LanceDBConnection(DBConnection): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> LanceTable: """Open a table in the database. @@ -968,6 +973,9 @@ class LanceDBConnection(DBConnection): namespace_path: List[str], optional The namespace to open the table from. When non-empty, the table is resolved through the directory namespace client. + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. Reads and writes operate in the branch's context. Returns ------- @@ -987,20 +995,24 @@ class LanceDBConnection(DBConnection): ) if namespace_path: - return self._namespace_conn().open_table( + tbl = self._namespace_conn().open_table( + name, + namespace_path=namespace_path, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) + else: + tbl = LanceTable.open( + self, name, namespace_path=namespace_path, storage_options=storage_options, index_cache_size=index_cache_size, ) - return LanceTable.open( - self, - name, - namespace_path=namespace_path, - storage_options=storage_options, - index_cache_size=index_cache_size, - ) + if branch is not None: + return tbl.branches.checkout(branch) + return tbl def clone_table( self, @@ -1641,6 +1653,7 @@ class AsyncConnection(object): location: Optional[str] = None, namespace_client: Optional[Any] = None, managed_versioning: Optional[bool] = None, + branch: Optional[str] = None, ) -> AsyncTable: """Open a Lance Table in the database. @@ -1676,6 +1689,9 @@ class AsyncConnection(object): managed_versioning: bool, optional Whether managed versioning is enabled for this table. If provided, avoids a redundant describe_table call when namespace_client is set. + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. Reads and writes operate in the branch's context. Returns ------- @@ -1692,7 +1708,10 @@ class AsyncConnection(object): namespace_client=namespace_client, managed_versioning=managed_versioning, ) - return AsyncTable(table) + tbl = AsyncTable(table) + if branch is not None: + return await tbl.branches.checkout(branch) + return tbl async def clone_table( self, diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 931f8c759..8548f2e27 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -952,6 +952,23 @@ def test_branches_preserve_namespace(tmp_path): assert branch.id == table.id +def test_open_table_with_branch(tmp_path): + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"i": 1}]) + table.branches.create("exp").add([{"i": 2}]) + + # open_table(branch=...) returns a handle scoped to the branch + assert db.open_table("t", branch="exp").count_rows() == 2 + # opening without branch still tracks main + assert db.open_table("t").count_rows() == 1 + + # with a namespace, the opened branch handle preserves namespace identity + nt = db.create_table("ns_t", [{"i": 1}], namespace_path=["ns1"]) + nt.branches.create("exp") + opened = db.open_table("ns_t", namespace_path=["ns1"], branch="exp") + assert opened.namespace == ["ns1"] + + @pytest.mark.asyncio async def test_async_branches(tmp_path): db = await lancedb.connect_async(tmp_path) diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index c1d475c9c..a478abfea 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -119,6 +119,7 @@ pub struct OpenTableBuilder { parent: Arc, request: OpenTableRequest, embedding_registry: Arc, + branch: Option, } impl OpenTableBuilder { @@ -139,6 +140,7 @@ impl OpenTableBuilder { managed_versioning: None, }, embedding_registry, + branch: None, } } @@ -259,14 +261,22 @@ impl OpenTableBuilder { self } + /// Open the table scoped to the given branch instead of the default branch. + /// + /// Reads and writes on the returned table operate in the branch's context. + pub fn branch(mut self, branch: impl Into) -> Self { + self.branch = Some(branch.into()); + self + } + /// Open the table pub async fn execute(self) -> Result { let table = self.parent.open_table(self.request).await?; - Ok(Table::new_with_embedding_registry( - table, - self.parent, - self.embedding_registry, - )) + let table = Table::new_with_embedding_registry(table, self.parent, self.embedding_registry); + match self.branch { + Some(branch) => table.checkout_branch(&branch).await, + None => Ok(table), + } } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index b0df27450..c7a584203 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -3528,6 +3528,16 @@ mod tests { assert_eq!(checked_out.current_branch().as_deref(), Some("exp")); assert_eq!(checked_out.count_rows(None).await.unwrap(), 2); + // open_table(...).branch(...) opens directly onto the branch + let opened = conn + .open_table("my_table") + .branch("exp") + .execute() + .await + .unwrap(); + assert_eq!(opened.current_branch().as_deref(), Some("exp")); + assert_eq!(opened.count_rows(None).await.unwrap(), 2); + // delete removes it from the listing table.delete_branch("exp").await.unwrap(); let branches = table.list_branches().await.unwrap();