From 59824ab438cf30bacb57ca12eb552c2a4a086940 Mon Sep 17 00:00:00 2001 From: Brendan Clement Date: Wed, 3 Jun 2026 11:18:14 -0700 Subject: [PATCH] fix: address review comments on branch support --- python/python/lancedb/namespace.py | 6 ++- python/python/tests/test_table.py | 13 ++++++ rust/lancedb/src/database/namespace.rs | 58 ++++++++++++++++++++++++++ rust/lancedb/src/table/query.rs | 5 ++- 4 files changed, 80 insertions(+), 2 deletions(-) diff --git a/python/python/lancedb/namespace.py b/python/python/lancedb/namespace.py index bc69c087a..ba6749787 100644 --- a/python/python/lancedb/namespace.py +++ b/python/python/lancedb/namespace.py @@ -978,12 +978,13 @@ class AsyncLanceNamespaceDBConnection: namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> AsyncTable: """Open an existing table from the namespace.""" if namespace_path is None: namespace_path = [] try: - return await self._inner.open_table( + tbl = await self._inner.open_table( name, namespace_path=namespace_path, storage_options=storage_options, @@ -994,6 +995,9 @@ class AsyncLanceNamespaceDBConnection: table_id = namespace_path + [name] raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}") raise + if branch is not None: + return await tbl.branches.checkout(branch) + return tbl async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None): """Drop a table from the namespace.""" diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 2dd550566..d838580c9 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -982,6 +982,19 @@ def test_open_table_with_branch(tmp_path): assert db.open_table("t").count_rows() == 1 +@pytest.mark.asyncio +async def test_async_namespace_open_table_with_branch(tmp_path): + db = lancedb.connect_namespace_async("dir", {"root": str(tmp_path)}) + await db.create_namespace(["ns1"]) + table = await db.create_table("t", [{"id": 1}], namespace_path=["ns1"]) + branch = await table.branches.create("exp") + await branch.add([{"id": 2}]) + + # open_table(branch=...) on the async namespace connection must work + opened = await db.open_table("t", namespace_path=["ns1"], branch="exp") + assert await opened.count_rows() == 2 + + def test_branch_to_lance_targets_branch(tmp_path): pytest.importorskip("lance") db = lancedb.connect(tmp_path) diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 4a0b09e05..fda665a31 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -740,6 +740,64 @@ mod tests { assert!(table_names.contains(&"test_table".to_string())); } + #[tokio::test] + async fn test_namespace_branch_query_under_pushdown_stays_local() { + // With QueryTable pushdown enabled, a query on the main branch routes to + // the namespace server, but a branch handle must run locally: the + // server-side request carries no branch and would return main's rows. + let tmp_dir = tempdir().unwrap(); + let root_path = tmp_dir.path().to_str().unwrap().to_string(); + + let mut properties = HashMap::new(); + properties.insert("root".to_string(), root_path); + + let conn = connect_namespace("dir", properties) + .pushdown_operation(NamespaceClientPushdownOperation::QueryTable) + .execute() + .await + .expect("Failed to connect to namespace"); + + conn.create_namespace(CreateNamespaceRequest { + id: Some(vec!["test_ns".into()]), + ..Default::default() + }) + .await + .expect("Failed to create namespace"); + + // main has 5 rows + let table = conn + .create_table("ref_test", create_test_data()) + .namespace(vec!["test_ns".into()]) + .execute() + .await + .expect("Failed to create table"); + let main_version = table.version().await.unwrap(); + + // fork a branch off main, then add 5 more rows so it differs from main + let branch = table + .create_branch("exp", main_version) + .await + .expect("Failed to create branch"); + branch + .add(create_test_data()) + .execute() + .await + .expect("Failed to append to branch"); + + // the branch query must run locally and see the branch's 10 rows -- + // not get routed to the server (which carries no branch) and see main's 5 + let results = branch + .query() + .execute() + .await + .expect("Failed to query branch") + .try_collect::>() + .await + .expect("Failed to collect results"); + let count: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(count, 10); + } + #[tokio::test] async fn test_namespace_describe_table() { // Setup: Create a temporary directory for the namespace diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index cc9312a0f..b136de2cd 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -41,11 +41,14 @@ pub async fn execute_query( query: &AnyQuery, options: QueryExecutionOptions, ) -> Result { - // If QueryTable pushdown is enabled and namespace client is configured, use server-side query execution + // QueryTable pushdown runs the query server-side, but only on the main + // branch: the namespace request carries no branch yet, so a branch handle + // must fall through to local execution. if table .pushdown_operations .contains(&NamespaceClientPushdownOperation::QueryTable) && let Some(ref namespace_client) = table.namespace_client + && table.dataset.current_branch().is_none() { return execute_namespace_query(table, namespace_client.clone(), query, options).await; }