From 53517b3aaaefc2a65872ef9fb39b46d91d677b64 Mon Sep 17 00:00:00 2001 From: Brendan Clement Date: Mon, 8 Jun 2026 16:26:46 -0700 Subject: [PATCH] feat: add table branch support (#3490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Adds first-class support for table branches across the Rust core and the Python and TypeScript SDKs. Rust ```rust use lance::dataset::refs::Ref; // Create a branch from main and write to it — main is untouched. let exp = table.create_branch("exp", Ref::Version(None, None)).await?; exp.add(batches).await?; // Reopen the branch later: check out from a table, or open it directly. let exp = table.checkout_branch("exp").await?; let exp = db.open_table("items").branch("exp").execute().await?; let branches = table.list_branches().await?; table.delete_branch("exp").await?; ``` Python ```python # Create a branch from main and write to it branch = await table.branches.create("exp", from_ref="main") await branch.add(data) # Reopen the branch later: check out from a table, or open it directly. branch = await table.branches.checkout("exp") branch = await db.open_table("items", branch="exp") await table.branches.list() await table.branches.delete("exp") ``` TypeScript ```typescript const branches = await table.branches(); // Create a branch from main and write to it const branch = await branches.create("exp"); await branch.add(data); // Reopen the branch later: check out from a table, or open it directly. 
const checkedOut = await branches.checkout("exp"); const opened = await db.openTable("items", undefined, { branch: "exp" }); await branches.list(); await branches.delete("exp"); ``` ### Testing - Added unit tests - ran smoke tests against python and typescript sdks on local machine ### Next steps - Add RemoteTable support - Add Branch Comparison support - Merge Branching support --- docs/src/js/classes/BranchContents.md | 43 +++ docs/src/js/classes/Branches.md | 90 +++++++ docs/src/js/classes/Table.md | 17 ++ docs/src/js/globals.md | 2 + docs/src/js/interfaces/OpenTableOptions.md | 12 + nodejs/__test__/table.test.ts | 58 +++++ nodejs/lancedb/connection.ts | 12 +- nodejs/lancedb/index.ts | 2 + nodejs/lancedb/table.ts | 62 +++++ nodejs/src/table.rs | 78 +++++- python/python/lancedb/_lancedb.pyi | 14 + python/python/lancedb/db.py | 37 ++- python/python/lancedb/namespace.py | 9 +- python/python/lancedb/remote/db.py | 4 + python/python/lancedb/table.py | 183 ++++++++++++- python/python/tests/test_namespace.py | 4 + python/python/tests/test_table.py | 154 +++++++++++ python/src/table.rs | 74 +++++- rust/lancedb/src/connection.rs | 20 +- rust/lancedb/src/database/namespace.rs | 58 +++++ rust/lancedb/src/remote/table.rs | 32 +++ rust/lancedb/src/table.rs | 287 ++++++++++++++++++++- rust/lancedb/src/table/dataset.rs | 53 +++- rust/lancedb/src/table/query.rs | 5 +- 24 files changed, 1275 insertions(+), 35 deletions(-) create mode 100644 docs/src/js/classes/BranchContents.md create mode 100644 docs/src/js/classes/Branches.md diff --git a/docs/src/js/classes/BranchContents.md b/docs/src/js/classes/BranchContents.md new file mode 100644 index 000000000..47339d7ab --- /dev/null +++ b/docs/src/js/classes/BranchContents.md @@ -0,0 +1,43 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / BranchContents + +# Class: BranchContents + +## Constructors + +### new BranchContents() + +```ts +new BranchContents(): BranchContents +``` + +#### 
Returns + +[`BranchContents`](BranchContents.md) + +## Properties + +### manifestSize + +```ts +manifestSize: number; +``` + +*** + +### parentBranch? + +```ts +optional parentBranch: string; +``` + +*** + +### parentVersion + +```ts +parentVersion: number; +``` diff --git a/docs/src/js/classes/Branches.md b/docs/src/js/classes/Branches.md new file mode 100644 index 000000000..ea4d72f07 --- /dev/null +++ b/docs/src/js/classes/Branches.md @@ -0,0 +1,90 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / Branches + +# Class: Branches + +Branch manager for a [Table](Table.md). + +Unlike tags, `create` and `checkout` return a new [Table](Table.md) handle scoped +to the branch; writes on it do not affect `main`. + +## Methods + +### checkout() + +```ts +checkout(name): Promise<Table> +``` + +Check out an existing branch and return a handle scoped to it. + +#### Parameters + +* **name**: `string` + +#### Returns + +`Promise`<[`Table`](Table.md)> + +*** + +### create() + +```ts +create( + name, + fromRef?, + fromVersion?): Promise<Table>
+``` + +Create a branch and return a handle scoped to it. + +#### Parameters + +* **name**: `string` + Name of the new branch. + +* **fromRef?**: `string` + Source branch to fork from. Defaults to `main`. + +* **fromVersion?**: `number` + A specific version on `fromRef`. Defaults to latest. + +#### Returns + +`Promise`<[`Table`](Table.md)> + +*** + +### delete() + +```ts +delete(name): Promise<void> +``` + +Delete a branch. + +#### Parameters + +* **name**: `string` + +#### Returns + +`Promise`<`void`> + +*** + +### list() + +```ts +list(): Promise<Record<string, BranchContents>> +``` + +List all branches, mapping name to branch metadata. + +#### Returns + +`Promise`<`Record`<`string`, [`BranchContents`](BranchContents.md)>> diff --git a/docs/src/js/classes/Table.md b/docs/src/js/classes/Table.md index 1675f2c93..a7a4dfafc 100644 --- a/docs/src/js/classes/Table.md +++ b/docs/src/js/classes/Table.md @@ -110,6 +110,23 @@ containing the new version number of the table after altering the columns. *** +### branches() + +```ts +abstract branches(): Promise<Branches> +``` + +Get the branch manager for this table. + +Branches are isolated, writable lines of history forked from another +branch (or version). Writes on a branch do not affect `main`. 
+ +#### Returns + +`Promise`<[`Branches`](Branches.md)> + +*** + ### checkout() ```ts diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 7d26fd75b..3efa3a360 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -19,6 +19,8 @@ - [BooleanQuery](classes/BooleanQuery.md) - [BoostQuery](classes/BoostQuery.md) +- [BranchContents](classes/BranchContents.md) +- [Branches](classes/Branches.md) - [Connection](classes/Connection.md) - [HeaderProvider](classes/HeaderProvider.md) - [Index](classes/Index.md) diff --git a/docs/src/js/interfaces/OpenTableOptions.md b/docs/src/js/interfaces/OpenTableOptions.md index 07190aa16..3c7f1e817 100644 --- a/docs/src/js/interfaces/OpenTableOptions.md +++ b/docs/src/js/interfaces/OpenTableOptions.md @@ -8,6 +8,18 @@ ## Properties +### branch? + +```ts +optional branch: string; +``` + +Open the table scoped to this branch instead of the default branch. + +Reads and writes on the returned table operate in the branch's context. + +*** + ### ~~indexCacheSize?~~ ```ts diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index d87fa34d4..3ca570aad 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -85,6 +85,64 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( await expect(table.countRows()).resolves.toBe(3); }); + it("should support branches", async () => { + await table.add([{ id: 1 }]); + expect(await table.countRows()).toBe(1); + + // fork an isolated, writable branch from main + const branch = await (await table.branches()).create("exp"); + expect(await branch.countRows()).toBe(1); + await branch.add([{ id: 2 }]); + expect(await branch.countRows()).toBe(2); + // main is untouched by branch writes + expect(await table.countRows()).toBe(1); + + // listed, with main (null) as the parent + const list = await (await table.branches()).list(); + expect(Object.keys(list)).toContain("exp"); + expect(list["exp"].parentBranch).toBeNull(); + + // fromRef="main" is 
equivalent to the default + await (await table.branches()).create("exp2", "main"); + const list2 = await (await table.branches()).list(); + expect(list2["exp2"].parentBranch).toBeNull(); + + // checkout returns a handle scoped to the branch's latest + const checkedOut = await (await table.branches()).checkout("exp"); + expect(await checkedOut.countRows()).toBe(2); + + // delete removes it + await (await table.branches()).delete("exp"); + await (await table.branches()).delete("exp2"); + const after = await (await table.branches()).list(); + expect(Object.keys(after)).not.toContain("exp"); + }); + + it("should open a branch via open_table", async () => { + const db = await connect(tmpDir.name); + await table.add([{ id: 1 }]); + const branch = await (await table.branches()).create("exp"); + await branch.add([{ id: 2 }]); + + // open_table(..., { branch }) returns a handle scoped to the branch + const opened = await db.openTable("some_table", undefined, { + branch: "exp", + }); + expect(await opened.countRows()).toBe(2); + // opening without branch still tracks main + expect(await (await db.openTable("some_table")).countRows()).toBe(1); + }); + + it("rejects invalid branch inputs", async () => { + const branches = await table.branches(); + await expect(branches.create("")).rejects.toThrow("non-empty"); + await expect(branches.checkout("")).rejects.toThrow("non-empty"); + await expect(branches.delete("")).rejects.toThrow("non-empty"); + await expect(branches.create("bad", "main", -1)).rejects.toThrow( + "non-negative", + ); + }); + it("should show table stats", async () => { await table.add([{ id: 1 }, { id: 2 }]); await table.add([{ id: 1 }]); diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index f6eacb20f..5ce155789 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -84,6 +84,12 @@ export interface CreateTableOptions { } export interface OpenTableOptions { + /** + * Open the table scoped to this branch instead of 
the default branch. + * + * Reads and writes on the returned table operate in the branch's context. + */ + branch?: string; /** * Configuration for object storage. * @@ -483,7 +489,11 @@ export class LocalConnection extends Connection { options?.indexCacheSize, ); - return new LocalTable(innerTable); + const table = new LocalTable(innerTable); + if (options?.branch != null) { + return (await table.branches()).checkout(options.branch); + } + return table; } async cloneTable( diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index f4f724e8a..c74cf1caa 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -38,6 +38,7 @@ export { FragmentSummaryStats, Tags, TagContents, + BranchContents, MergeResult, AddResult, AddColumnsResult, @@ -111,6 +112,7 @@ export { export { Table, + Branches, AddDataOptions, UpdateOptions, OptimizeOptions, diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index e3821bd81..fce768f74 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -25,10 +25,12 @@ import { AddColumnsSql, AddResult, AlterColumnsResult, + BranchContents, DeleteResult, DropColumnsResult, IndexConfig, IndexStatistics, + Branches as NativeBranches, OptimizeStats, TableStatistics, Tags, @@ -653,6 +655,14 @@ export abstract class Table { */ abstract tags(): Promise; + /** + * Get the branch manager for this table. + * + * Branches are isolated, writable lines of history forked from another + * branch (or version). Writes on a branch do not affect `main`. 
+ */ + abstract branches(): Promise; + /** * Restore the table to the currently checked out version * @@ -1108,6 +1118,10 @@ export class LocalTable extends Table { return await this.inner.tags(); } + async branches(): Promise { + return new Branches(await this.inner.branches()); + } + async optimize(options?: Partial): Promise { let cleanupOlderThanMs; if ( @@ -1238,3 +1252,51 @@ export interface FieldMetadataUpdate { /** If true, replace the field's entire metadata map instead of merging. */ replace?: boolean; } + +/** + * Branch manager for a {@link Table}. + * + * Unlike tags, `create` and `checkout` return a new {@link Table} handle scoped + * to the branch; writes on it do not affect `main`. + */ +export class Branches { + #inner: NativeBranches; + + /** + * Construct a Branches manager. Internal use only. + * @hidden + */ + constructor(inner: NativeBranches) { + this.#inner = inner; + } + + /** List all branches, mapping name to branch metadata. */ + async list(): Promise> { + return await this.#inner.list(); + } + + /** + * Create a branch and return a handle scoped to it. + * + * @param name Name of the new branch. + * @param fromRef Source branch to fork from. Defaults to `main`. + * @param fromVersion A specific version on `fromRef`. Defaults to latest. + */ + async create( + name: string, + fromRef?: string, + fromVersion?: number, + ): Promise
{ + return new LocalTable(await this.#inner.create(name, fromRef, fromVersion)); + } + + /** Check out an existing branch and return a handle scoped to it. */ + async checkout(name: string): Promise<Table>
{ + return new LocalTable(await this.#inner.checkout(name)); + } + + /** Delete a branch. */ + async delete(name: string): Promise { + return await this.#inner.delete(name); + } +} diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 5f8be3244..5e2bbb545 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -7,7 +7,7 @@ use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema}; use lancedb::table::{ AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, FieldMetadataUpdate as LanceFieldMetadataUpdate, NewColumnTransform, OptimizeAction, - OptimizeOptions, Table as LanceDbTable, + OptimizeOptions, Ref, Table as LanceDbTable, }; use napi::bindgen_prelude::*; use napi::threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}; @@ -478,6 +478,13 @@ impl Table { }) } + #[napi(catch_unwind)] + pub async fn branches(&self) -> napi::Result { + Ok(Branches { + inner: self.inner_ref()?.clone(), + }) + } + #[napi(catch_unwind)] pub async fn optimize( &self, @@ -1056,6 +1063,13 @@ pub struct TagContents { pub manifest_size: i64, } +#[napi] +pub struct BranchContents { + pub parent_branch: Option, + pub parent_version: i64, + pub manifest_size: i64, +} + #[napi] pub struct Tags { inner: LanceDbTable, @@ -1124,3 +1138,65 @@ impl Tags { .default_error() } } + +#[napi] +pub struct Branches { + inner: LanceDbTable, +} + +#[napi] +impl Branches { + #[napi] + pub async fn list(&self) -> napi::Result> { + let branches = self.inner.list_branches().await.default_error()?; + let result = branches + .into_iter() + .map(|(k, v)| { + ( + k, + BranchContents { + parent_branch: v.parent_branch, + parent_version: v.parent_version as i64, + manifest_size: v.manifest_size as i64, + }, + ) + }) + .collect(); + Ok(result) + } + + #[napi] + pub async fn create( + &self, + name: String, + from_ref: Option, + from_version: Option, + ) -> napi::Result
{ + let from_ref = from_ref.filter(|b| b != "main"); + let from_version = from_version + .map(|v| { + u64::try_from(v).map_err(|_| { + napi::Error::from_reason("from_version must be a non-negative integer") + }) + }) + .transpose()?; + let from = Ref::Version(from_ref, from_version); + let table = self + .inner + .create_branch(&name, from) + .await + .default_error()?; + Ok(Table::new(table)) + } + + #[napi] + pub async fn checkout(&self, name: String) -> napi::Result<Table>
{ + let table = self.inner.checkout_branch(&name).await.default_error()?; + Ok(Table::new(table)) + } + + #[napi] + pub async fn delete(&self, name: String) -> napi::Result<()> { + self.inner.delete_branch(&name).await.default_error() + } +} diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index afbd62086..db9afda35 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -226,6 +226,9 @@ class Table: async def close_lsm_writers(self) -> None: ... @property def tags(self) -> Tags: ... + @property + def branches(self) -> Branches: ... + def current_branch(self) -> Optional[str]: ... def query(self) -> Query: ... def take_offsets(self, offsets: list[int]) -> TakeQuery: ... def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ... @@ -238,6 +241,17 @@ class Tags: async def delete(self, tag: str): ... async def update(self, tag: str, version: int): ... +class Branches: + async def list(self) -> Dict[str, Any]: ... + async def create( + self, + name: str, + from_ref: Optional[str] = None, + from_version: Optional[int] = None, + ) -> Table: ... + async def checkout(self, name: str) -> Table: ... + async def delete(self, name: str) -> None: ... + class IndexConfig: name: str index_type: str diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 76b3b4aeb..d7ec073ed 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -416,6 +416,7 @@ class DBConnection(EnforceOverrides): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> Table: """Open a Lance Table in the database. @@ -444,6 +445,9 @@ class DBConnection(EnforceOverrides): connection will be inherited by the table, but can be overridden here. See available options at + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. 
Reads and writes operate in the branch's context. Returns ------- @@ -958,6 +962,7 @@ class LanceDBConnection(DBConnection): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> LanceTable: """Open a table in the database. @@ -968,6 +973,9 @@ class LanceDBConnection(DBConnection): namespace_path: List[str], optional The namespace to open the table from. When non-empty, the table is resolved through the directory namespace client. + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. Reads and writes operate in the branch's context. Returns ------- @@ -987,20 +995,24 @@ class LanceDBConnection(DBConnection): ) if namespace_path: - return self._namespace_conn().open_table( + tbl = self._namespace_conn().open_table( + name, + namespace_path=namespace_path, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) + else: + tbl = LanceTable.open( + self, name, namespace_path=namespace_path, storage_options=storage_options, index_cache_size=index_cache_size, ) - return LanceTable.open( - self, - name, - namespace_path=namespace_path, - storage_options=storage_options, - index_cache_size=index_cache_size, - ) + if branch is not None: + return tbl.branches.checkout(branch) + return tbl def clone_table( self, @@ -1641,6 +1653,7 @@ class AsyncConnection(object): location: Optional[str] = None, namespace_client: Optional[Any] = None, managed_versioning: Optional[bool] = None, + branch: Optional[str] = None, ) -> AsyncTable: """Open a Lance Table in the database. @@ -1676,6 +1689,9 @@ class AsyncConnection(object): managed_versioning: bool, optional Whether managed versioning is enabled for this table. If provided, avoids a redundant describe_table call when namespace_client is set. + branch: str, optional + If provided, open a handle scoped to this branch instead of the + default branch. 
Reads and writes operate in the branch's context. Returns ------- @@ -1692,7 +1708,10 @@ class AsyncConnection(object): namespace_client=namespace_client, managed_versioning=managed_versioning, ) - return AsyncTable(table) + tbl = AsyncTable(table) + if branch is not None: + return await tbl.branches.checkout(branch) + return tbl async def clone_table( self, diff --git a/python/python/lancedb/namespace.py b/python/python/lancedb/namespace.py index 8784bc19b..5cc1e5185 100644 --- a/python/python/lancedb/namespace.py +++ b/python/python/lancedb/namespace.py @@ -549,6 +549,7 @@ class LanceNamespaceDBConnection(DBConnection): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> Table: if namespace_path is None: namespace_path = [] @@ -567,7 +568,7 @@ class LanceNamespaceDBConnection(DBConnection): raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}") raise - return LanceTable( + tbl = LanceTable( self, name, namespace_path=namespace_path, @@ -575,6 +576,9 @@ class LanceNamespaceDBConnection(DBConnection): pushdown_operations=self._namespace_client_pushdown_operations, _async=async_table, ) + if branch is not None: + return tbl.branches.checkout(branch) + return tbl @override def drop_table(self, name: str, namespace_path: Optional[List[str]] = None): @@ -984,6 +988,7 @@ class AsyncLanceNamespaceDBConnection: namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> AsyncTable: """Open an existing table from the namespace.""" if namespace_path is None: @@ -1000,6 +1005,8 @@ class AsyncLanceNamespaceDBConnection: table_id = namespace_path + [name] raise TableNotFoundError(f"Table not found: {'$'.join(table_id)}") raise + if branch is not None: + table = await table.branches.checkout(branch) return table._set_namespace_context( 
namespace_path=namespace_path, namespace_client=self._namespace_client, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 4421b057c..02fb5942f 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -383,6 +383,7 @@ class RemoteDBConnection(DBConnection): namespace_path: Optional[List[str]] = None, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, + branch: Optional[str] = None, ) -> Table: """Open a Lance Table in the database. @@ -400,6 +401,9 @@ class RemoteDBConnection(DBConnection): """ from .table import RemoteTable + if branch is not None: + raise NotImplementedError("branching is not yet supported on remote tables") + if namespace_path is None: namespace_path = [] if index_cache_size is not None: diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index e57774ca2..468ac240a 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -784,6 +784,15 @@ class Table(ABC): """ raise NotImplementedError + @property + def branches(self) -> "Branches": + """Branch management for the table. + + Branches are isolated, writable lines of history forked from another + branch (or version). Writes on a branch do not affect ``main``. + """ + raise NotImplementedError + def __len__(self) -> int: """The number of rows in this Table""" return self.count_rows(None) @@ -2112,22 +2121,27 @@ class LanceTable(Table): "Please install with `pip install pylance`." 
) + branch = self.current_branch() + version = None if branch is not None else self.version if self._namespace_client is not None: table_id = self._namespace_path + [self.name] - return lance.dataset( - version=self.version, + ds = lance.dataset( + version=version, storage_options=self._conn.storage_options, namespace_client=self._namespace_client, table_id=table_id, **kwargs, ) - - return lance.dataset( - self._dataset_path, - version=self.version, - storage_options=self._conn.storage_options, - **kwargs, - ) + else: + ds = lance.dataset( + self._dataset_path, + version=version, + storage_options=self._conn.storage_options, + **kwargs, + ) + if branch is not None: + ds = ds.checkout_version((branch, self.version)) + return ds @property def schema(self) -> pa.Schema: @@ -2193,6 +2207,19 @@ class LanceTable(Table): """ return Tags(self._table) + @property + def branches(self) -> "Branches": + """Branch management for the table. + + ``create``/``checkout`` return a new table handle scoped to the branch; + writes on it do not affect ``main``. + """ + return Branches(self) + + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + return self._table.current_branch() + def checkout(self, version: Union[int, str]): """Checkout a version of the table. This is an in-place operation. @@ -3457,8 +3484,14 @@ class LanceTable(Table): batch_size: Optional[int] = None, timeout: Optional[timedelta] = None, ) -> pa.RecordBatchReader: - if _should_push_down_query_table( - self._namespace_client, self._pushdown_operations + # Branch queries run locally: the server-side query protocol can't + # carry a branch yet. + # TODO: push down server-side once it can (with remote table support). 
+ if ( + _should_push_down_query_table( + self._namespace_client, self._pushdown_operations + ) + and self.current_branch() is None ): from lancedb.namespace import _execute_server_side_query @@ -4385,12 +4418,20 @@ class AsyncTable: "Please install with `pip install pylance`." ) - return lance.dataset( + # lance.dataset() can't open a branch directly, so open the base table + # and check out the branch ref (a None branch resolves to main). + branch = self.current_branch() + table_version = await self.version() + version = None if branch is not None else table_version + ds = lance.dataset( await self.uri(), - version=await self.version(), + version=version, storage_options=await self.latest_storage_options(), **kwargs, ) + if branch is not None: + ds = ds.checkout_version((branch, table_version)) + return ds async def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame": """Return the table as a pandas DataFrame. @@ -5521,6 +5562,19 @@ class AsyncTable: """ return AsyncTags(self._inner) + @property + def branches(self) -> AsyncBranches: + """Branch management for the table. + + Branches are isolated, writable lines of history forked from another + branch (or version). Writes on a branch do not affect ``main``. + """ + return AsyncBranches(self._inner) + + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + return self._inner.current_branch() + async def optimize( self, *, @@ -5853,6 +5907,65 @@ class Tags: LOOP.run(self._table.tags.update(tag, version)) +class Branches: + """ + Table branch manager. 
+ """ + + def __init__(self, parent: "LanceTable"): + self._parent = parent + self._table = parent._table + + def list(self) -> Dict[str, Any]: + """List all branches, mapping name to branch metadata.""" + return LOOP.run(self._table.branches.list()) + + def create( + self, + name: str, + from_ref: Optional[str] = None, + from_version: Optional[int] = None, + ) -> "LanceTable": + """Create a branch and return a handle scoped to it. + + Parameters + ---------- + name: str + Name of the new branch. + from_ref: str, optional + Source branch to fork from. Defaults to ``main``. + from_version: int, optional + A specific version on ``from_ref`` to fork from. Defaults to latest. + """ + async_table = LOOP.run( + self._table.branches.create(name, from_ref, from_version) + ) + return self._wrap(async_table) + + def checkout(self, name: str) -> "LanceTable": + """Check out an existing branch and return a handle scoped to it.""" + async_table = LOOP.run(self._table.branches.checkout(name)) + return self._wrap(async_table) + + def delete(self, name: str) -> None: + """Delete a branch.""" + LOOP.run(self._table.branches.delete(name)) + + def _wrap(self, async_table: "AsyncTable") -> "LanceTable": + # Reuse the parent's connection + namespace context; from_inner would drop + # it and break identity/query routing for namespace-backed tables. + parent = self._parent + return LanceTable( + parent._conn, + async_table.name, + namespace_path=parent._namespace_path, + namespace_client=parent._namespace_client, + pushdown_operations=parent._pushdown_operations, + location=parent._location, + _async=async_table, + ) + + class AsyncTags: """ Async table tag manager. @@ -5920,3 +6033,47 @@ class AsyncTags: The new table version to tag. 
""" await self._table.tags.update(tag, version) + + +class AsyncBranches: + """Async table branch manager.""" + + def __init__(self, table): + self._table = table + + async def list(self) -> Dict[str, Any]: + """List all branches, mapping name to branch metadata.""" + return await self._table.branches.list() + + async def create( + self, + name: str, + from_ref: Optional[str] = None, + from_version: Optional[int] = None, + ) -> "AsyncTable": + """Create a branch and return a handle scoped to it. + + Parameters + ---------- + name: str + Name of the new branch. + from_ref: str, optional + Source branch to fork from. Defaults to ``main``. + from_version: int, optional + A specific version on ``from_ref`` to fork from. Defaults to latest. + """ + # "main" and None are two spellings of the root branch in lance; normalize + # so from_ref="main" behaves identically to the default. + if from_ref == "main": + from_ref = None + inner = await self._table.branches.create(name, from_ref, from_version) + return AsyncTable(inner) + + async def checkout(self, name: str) -> "AsyncTable": + """Check out an existing branch and return a handle scoped to it.""" + inner = await self._table.branches.checkout(name) + return AsyncTable(inner) + + async def delete(self, name: str) -> None: + """Delete a branch.""" + await self._table.branches.delete(name) diff --git a/python/python/tests/test_namespace.py b/python/python/tests/test_namespace.py index 8417e7384..f51690365 100644 --- a/python/python/tests/test_namespace.py +++ b/python/python/tests/test_namespace.py @@ -28,6 +28,10 @@ def _ipc_file(table: pa.Table = PUSHDOWN_DATA) -> bytes: class _FailingSyncInner: name = "hist" + def current_branch(self): + # The pushdown gate only routes server-side when on the default branch. 
+ return None + async def schema(self): return PUSHDOWN_DATA.schema diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 58c085f4e..880b77b04 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -928,6 +928,160 @@ async def test_async_tags(mem_db_async: AsyncConnection): ) +def test_branches(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0)) + table = db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + assert table.count_rows() == 2 + + # fork an isolated, writable branch from main + branch = table.branches.create("exp") + assert branch.count_rows() == 2 + branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}]) + + # writes on the branch do not touch main + assert branch.count_rows() == 3 + assert table.count_rows() == 2 + + # the branch is listed, with main (None) as its parent + branches = table.branches.list() + assert "exp" in branches + assert branches["exp"]["parent_branch"] is None + + # from_ref="main" is equivalent to the default + table.branches.create("exp2", from_ref="main") + assert table.branches.list()["exp2"]["parent_branch"] is None + + # checkout returns a handle scoped to the branch's latest + checked_out = table.branches.checkout("exp") + assert checked_out.count_rows() == 3 + + # delete removes it + table.branches.delete("exp") + table.branches.delete("exp2") + assert "exp" not in table.branches.list() + + +def test_branch_handle_tracks_concurrent_writes(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(0)) + table = db.create_table("t", [{"id": 1}]) + + # two independent handles on the same branch + writer = table.branches.create("exp") + reader = db.open_table("t", branch="exp") + assert reader.count_rows() == 1 + + # a concurrent write on the branch is visible to the other handle + 
writer.add([{"id": 2}]) + assert reader.count_rows() == 2 + # main is unaffected + assert table.count_rows() == 1 + + +def test_branch_name_validation(tmp_path): + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"id": 1}]) + + with pytest.raises(ValueError, match="non-empty"): + table.branches.create("") + with pytest.raises(ValueError, match="non-empty"): + table.branches.checkout("") + with pytest.raises(ValueError, match="non-empty"): + table.branches.delete("") + + +def test_branches_preserve_namespace(tmp_path): + pytest.importorskip( + "lance" + ) # namespace_path routes through lance's DirectoryNamespace + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"id": 1}], namespace_path=["ns1"]) + assert table.namespace == ["ns1"] + + branch = table.branches.create("exp") + assert branch.namespace == ["ns1"] + assert branch.id == table.id + + # opening the branch directly also preserves namespace identity + opened = db.open_table("t", namespace_path=["ns1"], branch="exp") + assert opened.namespace == ["ns1"] + + +def test_open_table_with_branch(tmp_path): + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"i": 1}]) + table.branches.create("exp").add([{"i": 2}]) + + # open_table(branch=...) returns a handle scoped to the branch + assert db.open_table("t", branch="exp").count_rows() == 2 + # opening without branch still tracks main + assert db.open_table("t").count_rows() == 1 + + +@pytest.mark.asyncio +async def test_async_namespace_open_table_with_branch(tmp_path): + pytest.importorskip("lance") # "dir" impl is lance.namespace.DirectoryNamespace + db = lancedb.connect_namespace_async("dir", {"root": str(tmp_path)}) + await db.create_namespace(["ns1"]) + table = await db.create_table("t", [{"id": 1}], namespace_path=["ns1"]) + branch = await table.branches.create("exp") + await branch.add([{"id": 2}]) + + # open_table(branch=...) 
on the async namespace connection must work + opened = await db.open_table("t", namespace_path=["ns1"], branch="exp") + assert await opened.count_rows() == 2 + + +def test_branch_to_lance_targets_branch(tmp_path): + pytest.importorskip("lance") + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"i": 1}]) + branch = table.branches.create("exp") + branch.add([{"i": 2}]) # branch: 2 rows, main: 1 row + + assert branch.to_lance().count_rows() == 2 + assert table.to_lance().count_rows() == 1 + + +@pytest.mark.asyncio +async def test_async_branches(tmp_path): + db = await lancedb.connect_async(tmp_path) + table = await db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) + assert await table.count_rows() == 2 + + branch = await table.branches.create("exp") + assert await branch.count_rows() == 2 + await branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}]) + + assert await branch.count_rows() == 3 + assert await table.count_rows() == 2 + + branches = await table.branches.list() + assert "exp" in branches + assert branches["exp"]["parent_branch"] is None + + await table.branches.create("exp2", from_ref="main") + assert (await table.branches.list())["exp2"]["parent_branch"] is None + + checked_out = await table.branches.checkout("exp") + assert await checked_out.count_rows() == 3 + + await table.branches.delete("exp") + await table.branches.delete("exp2") + assert "exp" not in await table.branches.list() + + @patch("lancedb.table.AsyncTable.create_index") def test_create_index_method(mock_create_index, mem_db: DBConnection): table = mem_db.create_table( diff --git a/python/src/table.rs b/python/src/table.rs index ea98d447e..4cc6864b1 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -17,7 +17,7 @@ use arrow::{ }; use lancedb::table::{ AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform, - 
OptimizeAction, OptimizeOptions, Table as LanceDbTable, + OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable, }; use pyo3::{ Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python, @@ -860,6 +860,15 @@ impl Table { Ok(Tags::new(self.inner_ref()?.clone())) } + pub fn current_branch(&self) -> PyResult> { + Ok(self.inner_ref()?.current_branch()) + } + + #[getter] + pub fn branches(&self) -> PyResult { + Ok(Branches::new(self.inner_ref()?.clone())) + } + #[pyo3(signature = (offsets))] pub fn take_offsets(self_: PyRef<'_, Self>, offsets: Vec) -> PyResult { Ok(TakeQuery::new( @@ -1261,3 +1270,66 @@ impl Tags { }) } } + +#[pyclass] +pub struct Branches { + inner: LanceDbTable, +} + +impl Branches { + pub fn new(table: LanceDbTable) -> Self { + Self { inner: table } + } +} + +#[pymethods] +impl Branches { + pub fn list(self_: PyRef<'_, Self>) -> PyResult> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let res = inner.list_branches().await.infer_error()?; + Python::attach(|py| { + let py_dict = PyDict::new(py); + for (name, contents) in res { + let value = PyDict::new(py); + value.set_item("parent_branch", contents.parent_branch)?; + value.set_item("parent_version", contents.parent_version)?; + value.set_item("manifest_size", contents.manifest_size)?; + py_dict.set_item(name, value)?; + } + Ok(py_dict.unbind()) + }) + }) + } + + #[pyo3(signature = (name, from_ref=None, from_version=None))] + pub fn create( + self_: PyRef<'_, Self>, + name: String, + from_ref: Option, + from_version: Option, + ) -> PyResult> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let from = Ref::Version(from_ref, from_version); + let table = inner.create_branch(&name, from).await.infer_error()?; + Ok(Table::new(table)) + }) + } + + pub fn checkout(self_: PyRef<'_, Self>, name: String) -> PyResult> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + let table = 
inner.checkout_branch(&name).await.infer_error()?; + Ok(Table::new(table)) + }) + } + + pub fn delete(self_: PyRef<'_, Self>, name: String) -> PyResult> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + inner.delete_branch(&name).await.infer_error()?; + Ok(()) + }) + } +} diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index c1d475c9c..a478abfea 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -119,6 +119,7 @@ pub struct OpenTableBuilder { parent: Arc, request: OpenTableRequest, embedding_registry: Arc, + branch: Option, } impl OpenTableBuilder { @@ -139,6 +140,7 @@ impl OpenTableBuilder { managed_versioning: None, }, embedding_registry, + branch: None, } } @@ -259,14 +261,22 @@ impl OpenTableBuilder { self } + /// Open the table scoped to the given branch instead of the default branch. + /// + /// Reads and writes on the returned table operate in the branch's context. + pub fn branch(mut self, branch: impl Into) -> Self { + self.branch = Some(branch.into()); + self + } + /// Open the table pub async fn execute(self) -> Result
{ let table = self.parent.open_table(self.request).await?; - Ok(Table::new_with_embedding_registry( - table, - self.parent, - self.embedding_registry, - )) + let table = Table::new_with_embedding_registry(table, self.parent, self.embedding_registry); + match self.branch { + Some(branch) => table.checkout_branch(&branch).await, + None => Ok(table), + } } } diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 4a0b09e05..fda665a31 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -740,6 +740,64 @@ mod tests { assert!(table_names.contains(&"test_table".to_string())); } + #[tokio::test] + async fn test_namespace_branch_query_under_pushdown_stays_local() { + // With QueryTable pushdown enabled, a query on the main branch routes to + // the namespace server, but a branch handle must run locally: the + // server-side request carries no branch and would return main's rows. + let tmp_dir = tempdir().unwrap(); + let root_path = tmp_dir.path().to_str().unwrap().to_string(); + + let mut properties = HashMap::new(); + properties.insert("root".to_string(), root_path); + + let conn = connect_namespace("dir", properties) + .pushdown_operation(NamespaceClientPushdownOperation::QueryTable) + .execute() + .await + .expect("Failed to connect to namespace"); + + conn.create_namespace(CreateNamespaceRequest { + id: Some(vec!["test_ns".into()]), + ..Default::default() + }) + .await + .expect("Failed to create namespace"); + + // main has 5 rows + let table = conn + .create_table("ref_test", create_test_data()) + .namespace(vec!["test_ns".into()]) + .execute() + .await + .expect("Failed to create table"); + let main_version = table.version().await.unwrap(); + + // fork a branch off main, then add 5 more rows so it differs from main + let branch = table + .create_branch("exp", main_version) + .await + .expect("Failed to create branch"); + branch + .add(create_test_data()) + .execute() + .await + 
.expect("Failed to append to branch"); + + // the branch query must run locally and see the branch's 10 rows -- + // not get routed to the server (which carries no branch) and see main's 5 + let results = branch + .query() + .execute() + .await + .expect("Failed to query branch") + .try_collect::>() + .await + .expect("Failed to collect results"); + let count: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(count, 10); + } + #[tokio::test] async fn test_namespace_describe_table() { // Setup: Create a temporary directory for the namespace diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index fa1355254..1ccbcbb1d 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1384,6 +1384,38 @@ impl BaseTable for RemoteTable { .map_err(unwrap_shared_error) } + async fn create_branch( + &self, + _name: &str, + _from: lance::dataset::refs::Ref, + ) -> Result> { + Err(Error::NotSupported { + message: "branching is not yet supported on remote tables".into(), + }) + } + + async fn checkout_branch(&self, _name: &str) -> Result> { + Err(Error::NotSupported { + message: "branching is not yet supported on remote tables".into(), + }) + } + + async fn list_branches(&self) -> Result> { + Err(Error::NotSupported { + message: "branching is not yet supported on remote tables".into(), + }) + } + + async fn delete_branch(&self, _name: &str) -> Result<()> { + Err(Error::NotSupported { + message: "branching is not yet supported on remote tables".into(), + }) + } + + fn current_branch(&self) -> Option { + None + } + async fn count_rows(&self, filter: Option) -> Result { let mut request = self.post_read(&format!("/v1/table/{}/count_rows/", self.identifier)); diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index cdd8edfcf..9dc1dc1e9 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -86,7 +86,7 @@ pub use add_data::{AddDataBuilder, AddDataMode, AddResult, 
NaNVectorBehavior}; pub use chrono::Duration; pub use delete::DeleteResult; use futures::future::join_all; -pub use lance::dataset::refs::{TagContents, Tags as LanceTags}; +pub use lance::dataset::refs::{BranchContents, Ref, TagContents, Tags as LanceTags}; pub use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::statistics::DatasetStatisticsExt; pub use lance_index::optimize::OptimizeOptions; @@ -625,6 +625,20 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { async fn restore(&self) -> Result<()>; /// List the versions of the table. async fn list_versions(&self) -> Result>; + /// Create a new branch from `from` and return a handle scoped to it. + async fn create_branch( + &self, + name: &str, + from: lance::dataset::refs::Ref, + ) -> Result>; + /// Check out an existing branch and return a handle scoped to it. + async fn checkout_branch(&self, name: &str) -> Result>; + /// List the branches of the table. + async fn list_branches(&self) -> Result>; + /// Delete a branch. + async fn delete_branch(&self, name: &str) -> Result<()>; + /// The branch this handle is scoped to, or `None` for `main`. + fn current_branch(&self) -> Option; /// Get the table definition. async fn table_definition(&self) -> Result; /// Get the table URI (storage location) @@ -1625,6 +1639,45 @@ impl Table { self.inner.tags().await } + /// Create a new branch from `from` (a version, tag, or branch) + pub async fn create_branch( + &self, + name: &str, + from: impl Into, + ) -> Result { + let inner = self.inner.create_branch(name, from.into()).await?; + Ok(Self { + inner, + database: self.database.clone(), + embedding_registry: self.embedding_registry.clone(), + }) + } + + /// Check out an existing branch and return a handle scoped to it. 
+ pub async fn checkout_branch(&self, name: &str) -> Result { + let inner = self.inner.checkout_branch(name).await?; + Ok(Self { + inner, + database: self.database.clone(), + embedding_registry: self.embedding_registry.clone(), + }) + } + + /// List the branches of the table. + pub async fn list_branches(&self) -> Result> { + self.inner.list_branches().await + } + + /// Delete a branch. + pub async fn delete_branch(&self, name: &str) -> Result<()> { + self.inner.delete_branch(name).await + } + + /// The branch this handle is scoped to, or `None` for `main`. + pub fn current_branch(&self) -> Option { + self.inner.current_branch() + } + /// Retrieve statistics on the table pub async fn stats(&self) -> Result { self.inner.stats().await @@ -1861,6 +1914,30 @@ impl NativeTable { self } + /// Build a sibling `NativeTable` with the same identity but a different + /// (independent) dataset wrapper — used to hand out branch-scoped handles. + fn with_dataset(&self, dataset: dataset::DatasetConsistencyWrapper) -> Self { + Self { + name: self.name.clone(), + namespace: self.namespace.clone(), + id: self.id.clone(), + uri: self.uri.clone(), + dataset, + read_consistency_interval: self.read_consistency_interval, + namespace_client: self.namespace_client.clone(), + pushdown_operations: self.pushdown_operations.clone(), + } + } + + fn validate_branch_name(name: &str, field: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::InvalidInput { + message: format!("{field} must be a non-empty string"), + }); + } + Ok(()) + } + /// Opens an existing Table using a namespace client. 
/// /// This method uses `DatasetBuilder::from_namespace` to open the table, which @@ -2652,6 +2729,49 @@ impl BaseTable for NativeTable { self.dataset.reload().await } + async fn create_branch( + &self, + name: &str, + from: lance::dataset::refs::Ref, + ) -> Result> { + Self::validate_branch_name(name, "branch name")?; + if let lance::dataset::refs::Ref::Version(Some(from_branch), _) = &from { + Self::validate_branch_name(from_branch, "from_ref")?; + } + let mut ds = (*self.dataset.get().await?).clone(); + let branch_ds = ds.create_branch(name, from, None).await?; + let dataset = dataset::DatasetConsistencyWrapper::new_latest( + branch_ds, + self.read_consistency_interval, + ); + Ok(Arc::new(self.with_dataset(dataset))) + } + + async fn checkout_branch(&self, name: &str) -> Result> { + Self::validate_branch_name(name, "branch name")?; + let branch_ds = self.dataset.get().await?.checkout_branch(name).await?; + let dataset = dataset::DatasetConsistencyWrapper::new_latest( + branch_ds, + self.read_consistency_interval, + ); + Ok(Arc::new(self.with_dataset(dataset))) + } + + async fn list_branches(&self) -> Result> { + Ok(self.dataset.get().await?.list_branches().await?) + } + + async fn delete_branch(&self, name: &str) -> Result<()> { + Self::validate_branch_name(name, "branch name")?; + let mut ds = (*self.dataset.get().await?).clone(); + ds.delete_branch(name).await?; + Ok(()) + } + + fn current_branch(&self) -> Option { + self.dataset.current_branch() + } + async fn list_versions(&self) -> Result> { Ok(self.dataset.get().await?.versions().await?) 
} @@ -3370,6 +3490,171 @@ mod tests { assert_eq!(table.version().await.unwrap(), 4); } + #[tokio::test] + async fn test_branches() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let conn = ConnectBuilder::new(uri) + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + + // main: one row at v1 + let table = conn + .create_table("my_table", some_sample_data()) + .execute() + .await + .unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 1); + assert_eq!(table.current_branch(), None); + let main_version = table.version().await.unwrap(); + + // branch off main's current version; it starts with main's data + let branch = table.create_branch("exp", main_version).await.unwrap(); + assert_eq!(branch.current_branch().as_deref(), Some("exp")); + assert_eq!(branch.count_rows(None).await.unwrap(), 1); + + // writes on the branch are isolated from main + branch.add(some_sample_data()).execute().await.unwrap(); + assert_eq!(branch.count_rows(None).await.unwrap(), 2); + assert_eq!( + table.count_rows(None).await.unwrap(), + 1, + "main must be untouched by branch writes" + ); + + // the branch shows up in the listing + let branches = table.list_branches().await.unwrap(); + assert!(branches.contains_key("exp")); + + // checking out the branch from the main handle sees the branch's latest data + let checked_out = table.checkout_branch("exp").await.unwrap(); + assert_eq!(checked_out.current_branch().as_deref(), Some("exp")); + assert_eq!(checked_out.count_rows(None).await.unwrap(), 2); + + // open_table(...).branch(...) 
opens directly onto the branch + let opened = conn + .open_table("my_table") + .branch("exp") + .execute() + .await + .unwrap(); + assert_eq!(opened.current_branch().as_deref(), Some("exp")); + assert_eq!(opened.count_rows(None).await.unwrap(), 2); + + // delete removes it from the listing + table.delete_branch("exp").await.unwrap(); + let branches = table.list_branches().await.unwrap(); + assert!(!branches.contains_key("exp")); + } + + #[tokio::test] + async fn test_branch_name_validation() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let conn = ConnectBuilder::new(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", some_sample_data()) + .execute() + .await + .unwrap(); + + // every entry point rejects an empty name instead of passing it down + assert!(matches!( + table.create_branch("", 1u64).await, + Err(Error::InvalidInput { .. }) + )); + assert!(matches!( + table.checkout_branch("").await, + Err(Error::InvalidInput { .. }) + )); + assert!(matches!( + table.delete_branch("").await, + Err(Error::InvalidInput { .. }) + )); + // an empty source branch is rejected too + assert!(matches!( + table + .create_branch( + "ok", + lance::dataset::refs::Ref::Version(Some(String::new()), None) + ) + .await, + Err(Error::InvalidInput { .. 
}) + )); + } + + #[tokio::test] + async fn test_branch_handle_tracks_concurrent_writes() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + // interval = 0 so every read checks storage for new commits + let conn = ConnectBuilder::new(uri) + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + let table = conn + .create_table("my_table", some_sample_data()) + .execute() + .await + .unwrap(); + let v1 = table.version().await.unwrap(); + + // two independent handles on the same branch + let writer = table.create_branch("exp", v1).await.unwrap(); + let reader = conn + .open_table("my_table") + .branch("exp") + .execute() + .await + .unwrap(); + assert_eq!(reader.count_rows(None).await.unwrap(), 1); + + // a concurrent write on the branch is visible to the other handle, which + // tracks the branch's HEAD (not main's) + writer.add(some_sample_data()).execute().await.unwrap(); + assert_eq!(reader.count_rows(None).await.unwrap(), 2); + // main is untouched + assert_eq!(table.count_rows(None).await.unwrap(), 1); + } + + #[tokio::test] + async fn test_branch_handle_without_consistency_interval_is_pinned() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + // default interval (None): handles do not auto-refresh + let conn = ConnectBuilder::new(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", some_sample_data()) + .execute() + .await + .unwrap(); + let v1 = table.version().await.unwrap(); + + let writer = table.create_branch("exp", v1).await.unwrap(); + let reader = conn + .open_table("my_table") + .branch("exp") + .execute() + .await + .unwrap(); + assert_eq!(reader.count_rows(None).await.unwrap(), 1); + + // without a consistency interval the reader stays on the version it + // opened, exactly like a main-branch handle... 
+ writer.add(some_sample_data()).execute().await.unwrap(); + assert_eq!(reader.count_rows(None).await.unwrap(), 1); + + // ...until it explicitly refreshes + reader.checkout_latest().await.unwrap(); + assert_eq!(reader.count_rows(None).await.unwrap(), 2); + } + #[tokio::test] async fn test_create_index() { use arrow_array::RecordBatch; diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs index b4673d876..1ff11198e 100644 --- a/rust/lancedb/src/table/dataset.rs +++ b/rust/lancedb/src/table/dataset.rs @@ -144,8 +144,19 @@ impl DatasetConsistencyWrapper { } /// Checkout a branch and track its HEAD for new versions. - pub async fn as_branch(&self, _branch: impl Into) -> Result<()> { - todo!("Branch support not yet implemented") + pub async fn as_branch(&self, branch: impl Into) -> Result<()> { + let branch = branch.into(); + let dataset = { self.state.lock()?.dataset.clone() }; + let new_dataset = dataset.checkout_branch(&branch).await?; + + let mut state = self.state.lock()?; + state.dataset = Arc::new(new_dataset); + state.pinned_version = None; + drop(state); + if let ConsistencyMode::Eventual(bg_cache) = &self.consistency { + bg_cache.invalidate(); + } + Ok(()) } /// Check that the dataset is in a mutable mode (Latest). @@ -161,6 +172,17 @@ impl DatasetConsistencyWrapper { } } + /// The branch this wrapper is currently tracking, or `None` for `main`. + pub fn current_branch(&self) -> Option { + self.state + .lock() + .unwrap_or_else(|e| e.into_inner()) + .dataset + .manifest() + .branch + .clone() + } + /// Returns the version, if in time travel mode, or None otherwise. 
pub fn time_travel_version(&self) -> Option { self.state @@ -737,4 +759,31 @@ mod tests { let result = wrapper.reload().await; assert!(result.is_err()); } + + #[tokio::test] + async fn test_as_branch_is_writable_and_tracked() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + + // v1 on main, then shallow-clone a branch off it + let mut ds = create_test_dataset(uri).await; + let v1 = ds.version().version; + ds.create_branch("exp", v1, None).await.unwrap(); + + // wrapper starts on main: latest, writable, no branch + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + assert_eq!(wrapper.current_branch(), None); + + // switch to the branch + wrapper.as_branch("exp").await.unwrap(); + assert_eq!(wrapper.current_branch().as_deref(), Some("exp")); + + // a branch is writable (unlike a pinned/time-travel checkout) + wrapper.ensure_mutable().unwrap(); + assert_eq!(wrapper.time_travel_version(), None); + + // get() returns the branch dataset + let on_branch = wrapper.get().await.unwrap(); + assert_eq!(on_branch.manifest().branch.as_deref(), Some("exp")); + } } diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index cc9312a0f..b136de2cd 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -41,11 +41,14 @@ pub async fn execute_query( query: &AnyQuery, options: QueryExecutionOptions, ) -> Result { - // If QueryTable pushdown is enabled and namespace client is configured, use server-side query execution + // QueryTable pushdown runs the query server-side, but only on the main + // branch: the namespace request carries no branch yet, so a branch handle + // must fall through to local execution. 
if table .pushdown_operations .contains(&NamespaceClientPushdownOperation::QueryTable) && let Some(ref namespace_client) = table.namespace_client + && table.dataset.current_branch().is_none() { return execute_namespace_query(table, namespace_client.clone(), query, options).await; }