diff --git a/docs/src/js/classes/Table.md b/docs/src/js/classes/Table.md index a7a4dfafc..646fe48c4 100644 --- a/docs/src/js/classes/Table.md +++ b/docs/src/js/classes/Table.md @@ -295,6 +295,23 @@ await table.createIndex("my_float_col"); *** +### currentBranch() + +```ts +abstract currentBranch(): null | string +``` + +The branch this table handle is scoped to, or `null` for the main branch. + +A handle returned by [Branches.create](Branches.md#create) or [Branches.checkout](Branches.md#checkout) +reports the branch it targets; a handle opened normally reports `null`. + +#### Returns + +`null` \| `string` + +*** + ### delete() ```ts diff --git a/nodejs/__test__/remote.test.ts b/nodejs/__test__/remote.test.ts index a2a67014d..d665fe6a9 100644 --- a/nodejs/__test__/remote.test.ts +++ b/nodejs/__test__/remote.test.ts @@ -191,30 +191,36 @@ describe("remote connection", () => { ); }); - it("allows version on remote but rejects a non-main branch", async () => { + it("supports version time-travel and branches on remote", async () => { await withMockDatabase( - (_req, res) => { - // describe (table open + version validation) always succeeds - const body = JSON.stringify({ - name: "t", - version: 2, - schema: { fields: [] }, - }); + (req, res) => { + const body = req.url?.includes("/branches/list") + ? JSON.stringify({ + branches: { + exp: { parentVersion: 1, createAt: 1, manifestSize: 1 }, + }, + }) + : JSON.stringify({ name: "t", version: 2, schema: { fields: [] } }); res.writeHead(200, { "Content-Type": "application/json" }).end(body); }, async (db) => { - // version-only (and "main" + version) is allowed: remote supports - // version time-travel even though it has no branches - await db.openTable("t", undefined, { version: 2 }); - await db.openTable("t", undefined, { branch: "main", version: 2 }); + // version-only (and "main" + version) time-travel the main chain + const v2 = await db.openTable("t", undefined, { version: 2 }); + expect(v2.currentBranch()).toBeNull(); + const mainV2 = await db.openTable("t", undefined, { + branch: "main", + version: 2, + }); + expect(mainV2.currentBranch()).toBeNull(); - // a non-main branch is rejected, with or without a version - await expect( - db.openTable("t", undefined, { branch: "exp" }), - ).rejects.toThrow(/branching/); - await expect( - db.openTable("t", undefined, { branch: "exp", version: 2 }), - ).rejects.toThrow(/branching/); + // a non-main branch opens a handle scoped to that branch + const exp = await db.openTable("t", undefined, { branch: "exp" }); + expect(exp.currentBranch()).toBe("exp"); + const expV2 = await db.openTable("t", undefined, { + branch: "exp", + version: 2, + }); + expect(expV2.currentBranch()).toBe("exp"); }, ); }); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index c808f1679..f4ba37361 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -89,8 +89,11 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( await table.add([{ id: 1 }]); expect(await table.countRows()).toBe(1); + expect(table.currentBranch()).toBeNull(); + // fork an isolated, writable branch from main const branch = await (await table.branches()).create("exp"); + expect(branch.currentBranch()).toBe("exp"); expect(await branch.countRows()).toBe(1); await branch.add([{ id: 2 }]); expect(await branch.countRows()).toBe(2); @@ -109,6 +112,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( // checkout returns a handle scoped to the branch's latest const checkedOut = await (await table.branches()).checkout("exp"); + expect(checkedOut.currentBranch()).toBe("exp"); expect(await checkedOut.countRows()).toBe(2); // delete removes it diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 497d809ea..e8d70e9f2 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -663,6 +663,14 @@ export abstract class Table { */ abstract branches(): Promise; + /** + * The branch this table handle is scoped to, or `null` for the main branch. + * + * A handle returned by {@link Branches.create} or {@link Branches.checkout} + * reports the branch it targets; a handle opened normally reports `null`. + */ + abstract currentBranch(): string | null; + /** * Restore the table to the currently checked out version * @@ -1122,6 +1130,10 @@ export class LocalTable extends Table { return new Branches(await this.inner.branches()); } + currentBranch(): string | null { + return this.inner.currentBranch() ?? null; + } + async optimize(options?: Partial): Promise { let cleanupOlderThanMs; if ( diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index b53be573a..7558bb4b6 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -487,6 +487,12 @@ impl Table { }) } + /// The branch this handle is scoped to, or `null` for the main branch. + #[napi] + pub fn current_branch(&self) -> napi::Result> { + Ok(self.inner_ref()?.current_branch()) + } + #[napi(catch_unwind)] pub async fn optimize( &self, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 4e5655fae..e57958958 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -396,13 +396,13 @@ class RemoteDBConnection(DBConnection): The namespace to open the table from. None or empty list represents root namespace. branch: str, optional - Branching is not yet supported on remote tables, so only the - default branch is accepted (``None`` or ``"main"``); any other - value raises ``NotImplementedError``. + If provided, open a handle scoped to this branch instead of the + default branch. Reads and writes operate in the branch's context. version: int, optional If provided, open the table pinned to this version, producing a - read-only handle. Call ``checkout_latest`` to return to a writable - state. + read-only handle. Composes with ``branch``: when both are given, + opens that branch at the version; otherwise opens ``main`` at the + version. Call ``checkout_latest`` to return to a writable state. Returns ------- @@ -410,11 +410,6 @@ class RemoteDBConnection(DBConnection): """ from .table import RemoteTable - # Remote supports version time-travel but not branches: reject a non-main - # branch, but allow a version-only open (or "main"). - if branch is not None and branch != "main": - raise NotImplementedError("branching is not yet supported on remote tables") - if namespace_path is None: namespace_path = [] if index_cache_size is not None: @@ -430,7 +425,9 @@ class RemoteDBConnection(DBConnection): connection_state=self.serialize, namespace_path=namespace_path, ) - if version is not None: + if branch is not None: + tbl = tbl.branches.checkout(branch, version) + elif version is not None: tbl.checkout(version) return tbl diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index b578f829a..5a3b5eb19 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -56,7 +56,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.table import _normalize_progress from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder -from ..table import AsyncTable, BlobMode, IndexStatistics, Query, Table, Tags +from ..table import AsyncTable, BlobMode, Branches, IndexStatistics, Query, Table, Tags from ..types import BaseTokenizerType @@ -75,6 +75,9 @@ class RemoteTable(Table): self._connection_state = connection_state self._namespace_path = list(namespace_path or []) self._checkout_version: Optional[int] = None + # The branch this handle is scoped to (None == main). Persisted so a + # fork/pickle reopen restores the branch instead of reverting to main. + self._branch: Optional[str] = None self._pid = os.getpid() def _serialized_connection_state(self) -> str: @@ -109,9 +112,14 @@ class RemoteTable(Table): from lancedb import deserialize_conn db = deserialize_conn(self._serialized_connection_state(), for_worker=True) - table = db.open_table(self._name, namespace_path=self._namespace_path) - if self._checkout_version is not None: - table.checkout(self._checkout_version) + # Reopen on the same branch and pinned version (branch=None / version=None + # reproduce the plain main-latest open). + table = db.open_table( + self._name, + namespace_path=self._namespace_path, + branch=self._branch, + version=self._checkout_version, + ) self._table_handle = table._table self.db_name = table.db_name @@ -124,6 +132,7 @@ class RemoteTable(Table): "name": self.name, "namespace_path": self._namespace_path, "checkout_version": self._checkout_version, + "branch": self._branch, } def __setstate__(self, state: dict) -> None: @@ -133,6 +142,7 @@ class RemoteTable(Table): self._connection_state = state["connection_state"] self._namespace_path = state["namespace_path"] self._checkout_version = state["checkout_version"] + self._branch = state.get("branch") self._pid = None @property @@ -160,6 +170,34 @@ class RemoteTable(Table): def tags(self) -> Tags: return Tags(self._table) + @property + def branches(self) -> Branches: + """Branch management for the table. + + ``create``/``checkout`` return a new table handle scoped to the branch; + writes on it do not affect ``main``. + """ + return Branches(self) + + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + return self._table.current_branch() + + def _wrap_branch_handle( + self, async_table: AsyncTable, version: Optional[int] = None + ) -> "RemoteTable": + # A branch handle stays a RemoteTable with the same connection context. + # Record the branch and version pin so a fork/pickle reopen restores both. + handle = RemoteTable( + async_table, + self.db_name, + connection_state=self._connection_state, + namespace_path=self._namespace_path, + ) + handle._branch = async_table.current_branch() + handle._checkout_version = version + return handle + @cached_property def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: """ diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 64a05b3dd..c124309c3 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -796,6 +796,10 @@ class Table(ABC): """ raise NotImplementedError + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + raise NotImplementedError + def __len__(self) -> int: """The number of rows in this Table""" return self.count_rows(None) @@ -2223,6 +2227,21 @@ class LanceTable(Table): """The branch this table handle is scoped to, or ``None`` for ``main``.""" return self._table.current_branch() + def _wrap_branch_handle( + self, async_table: "AsyncTable", version: Optional[int] = None + ) -> "LanceTable": + # version is unused locally: the pin already lives on async_table and a + # local handle is not reopened via a serialized connection. + return LanceTable( + self._conn, + async_table.name, + namespace_path=self._namespace_path, + namespace_client=self._namespace_client, + pushdown_operations=self._pushdown_operations, + location=self._location, + _async=async_table, + ) + def checkout(self, version: Union[int, str]): """Checkout a version of the table. This is an in-place operation. @@ -5934,7 +5953,7 @@ class Branches: name: str, from_ref: Optional[str] = None, from_version: Optional[int] = None, - ) -> "LanceTable": + ) -> "Table": """Create a branch and return a handle scoped to it. Parameters @@ -5951,7 +5970,7 @@ class Branches: ) return self._wrap(async_table) - def checkout(self, name: str, version: Optional[int] = None) -> "LanceTable": + def checkout(self, name: str, version: Optional[int] = None) -> "Table": """Check out an existing branch and return a handle scoped to it. Parameters @@ -5964,25 +5983,19 @@ class Branches: the branch's latest and stays writable. """ async_table = LOOP.run(self._table.branches.checkout(name, version)) - return self._wrap(async_table) + return self._wrap(async_table, version) def delete(self, name: str) -> None: """Delete a branch.""" LOOP.run(self._table.branches.delete(name)) - def _wrap(self, async_table: "AsyncTable") -> "LanceTable": - # Reuse the parent's connection + namespace context; from_inner would drop - # it and break identity/query routing for namespace-backed tables. - parent = self._parent - return LanceTable( - parent._conn, - async_table.name, - namespace_path=parent._namespace_path, - namespace_client=parent._namespace_client, - pushdown_operations=parent._pushdown_operations, - location=parent._location, - _async=async_table, - ) + def _wrap( + self, async_table: "AsyncTable", version: Optional[int] = None + ) -> "Table": + # Delegate to the parent so the branch handle keeps its concrete type + # (LanceTable / RemoteTable) and connection context; `version` is the + # explicit pin so a remote handle can restore branch+version on reopen. + return self._parent._wrap_branch_handle(async_table, version) class AsyncTags: diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index da2956428..1c1b26295 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -154,50 +154,116 @@ async def test_async_checkout(): assert await table.count_rows() == 300 +def _branch_open_handler(request): + if "/branches/list" in request.path: + body = json.dumps( + { + "branches": { + "exp": { + "parentBranch": None, + "parentVersion": 1, + "createAt": 1, + "manifestSize": 1, + } + } + } + ).encode() + else: + # describe (table open + version/branch validation) + body = json.dumps({"version": 2, "schema": {"fields": []}}).encode() + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(body) + + def test_remote_open_table_branch_and_version(): + with mock_lancedb_connection(_branch_open_handler) as db: + # version-only (and "main" + version) time-travels the main chain + assert db.open_table("test", version=2) is not None + assert db.open_table("test", branch="main", version=2).current_branch() is None + + # a non-main branch opens a handle scoped to that branch, with or + # without a version + assert db.open_table("test", branch="exp").current_branch() == "exp" + assert db.open_table("test", branch="exp", version=2).current_branch() == "exp" + + +def test_remote_table_branches_sync(): + # Branch CRUD + current_branch on the sync RemoteTable. The handle returned + # by create/checkout must stay a RemoteTable scoped to the branch. + from lancedb.remote.table import RemoteTable + def handler(request): - # describe (table open + version validation) always succeeds + if "/branches/list" in request.path: + body = json.dumps( + { + "branches": { + "exp": { + "parentBranch": None, + "parentVersion": 1, + "createAt": 1, + "manifestSize": 1, + } + } + } + ).encode() + elif "/branches/create" in request.path or "/branches/delete" in request.path: + body = b"{}" + else: + # describe (table open + checkout validation) + body = json.dumps({"version": 1, "schema": {"fields": []}}).encode() request.send_response(200) request.send_header("Content-Type", "application/json") request.end_headers() - request.wfile.write( - json.dumps({"version": 2, "schema": {"fields": []}}).encode() - ) + request.wfile.write(body) with mock_lancedb_connection(handler) as db: - # version-only (and "main" + version) is allowed: remote supports - # version time-travel even though it has no branches - assert db.open_table("test", version=2) is not None - assert db.open_table("test", branch="main", version=2) is not None + table = db.open_table("test") + assert isinstance(table, RemoteTable) + assert table.current_branch() is None - # a non-main branch is rejected, with or without a version - with pytest.raises(NotImplementedError, match="branching"): - db.open_table("test", branch="exp") - with pytest.raises(NotImplementedError, match="branching"): - db.open_table("test", branch="exp", version=2) + branch = table.branches.create("exp") + assert isinstance(branch, RemoteTable) + assert branch.current_branch() == "exp" + + # list + checkout round trip; checkout also yields a branch-scoped handle + assert "exp" in table.branches.list() + checked = table.branches.checkout("exp") + assert isinstance(checked, RemoteTable) + assert checked.current_branch() == "exp" + + table.branches.delete("exp") @pytest.mark.asyncio async def test_async_remote_open_table_branch_and_version(): - def handler(request): - request.send_response(200) - request.send_header("Content-Type", "application/json") - request.end_headers() - request.wfile.write( - json.dumps({"version": 2, "schema": {"fields": []}}).encode() - ) - - async with mock_lancedb_connection_async(handler) as db: - # version-only (and "main" + version) is allowed: "main" is the default - # branch, so it must not hit the unsupported remote branch path + async with mock_lancedb_connection_async(_branch_open_handler) as db: + # version-only (and "main" + version) time-travels the main chain assert await db.open_table("test", version=2) is not None - assert await db.open_table("test", branch="main", version=2) is not None + main_v2 = await db.open_table("test", branch="main", version=2) + assert main_v2.current_branch() is None - # a non-main branch is rejected, with or without a version - with pytest.raises(NotImplementedError, match="branching"): - await db.open_table("test", branch="exp") - with pytest.raises(NotImplementedError, match="branching"): - await db.open_table("test", branch="exp", version=2) + # a non-main branch opens a handle scoped to that branch + exp = await db.open_table("test", branch="exp") + assert exp.current_branch() == "exp" + exp_v2 = await db.open_table("test", branch="exp", version=2) + assert exp_v2.current_branch() == "exp" + + +def test_remote_table_branch_survives_pickle(): + # Regression: a branch-scoped handle must keep its branch across a + # pickle/fork round-trip (it used to reopen on main). + with mock_lancedb_connection(_branch_open_handler) as db: + branch = db.open_table("test", branch="exp") + assert branch.current_branch() == "exp" + restored = pickle.loads(pickle.dumps(branch)) + assert restored.current_branch() == "exp" + + # the pinned version is carried through as well + branch_v2 = db.open_table("test", branch="exp", version=2) + restored_v2 = pickle.loads(pickle.dumps(branch_v2)) + assert restored_v2.current_branch() == "exp" def test_table_len_sync(): diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index e83014d5d..7c163e0bc 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -985,45 +985,42 @@ mod tests { #[tokio::test] async fn test_open_table_branch_and_version() { - // Remote supports version time-travel but not branches. A version-only - // open (or one on the default "main" branch) must succeed; a non-main - // branch must be rejected, with or without a version. let conn = Connection::new_with_handler(|request| { - assert_eq!(request.url().path(), "/v1/table/t/describe/"); - http::Response::builder() - .status(200) - .body( - r#"{"table": "t", "version": 2, "schema": {"fields": [ - {"name": "a", "type": { "type": "int32" }, "nullable": false} - ]}}"#, - ) - .unwrap() + let body = if request.url().path() == "/v1/table/t/branches/list/" { + // checkout_branch validates the branch exists via list_branches. + r#"{"branches":{"exp":{"parentVersion":1,"createAt":1,"manifestSize":1}}}"# + } else { + // describe (table open + version/branch validation) + r#"{"table": "t", "version": 2, "schema": {"fields": [ + {"name": "a", "type": { "type": "int32" }, "nullable": false} + ]}}"# + }; + http::Response::builder().status(200).body(body).unwrap() }); - // version-only: allowed (open + checkout(version) both round-trip) - conn.open_table("t").version(2).execute().await.unwrap(); - - // "main" is the default branch, so it counts as no branch - conn.open_table("t") + // version-only (and "main" + version) time-travel the main chain + let v2 = conn.open_table("t").version(2).execute().await.unwrap(); + assert_eq!(v2.current_branch(), None); + let main_v2 = conn + .open_table("t") .branch("main") .version(2) .execute() .await .unwrap(); + assert_eq!(main_v2.current_branch(), None); - // a non-main branch is rejected, with or without a version - assert!(matches!( - conn.open_table("t").branch("exp").execute().await, - Err(Error::NotSupported { .. }) - )); - assert!(matches!( - conn.open_table("t") - .branch("exp") - .version(2) - .execute() - .await, - Err(Error::NotSupported { .. }) - )); + // a non-main branch opens a handle scoped to that branch + let exp = conn.open_table("t").branch("exp").execute().await.unwrap(); + assert_eq!(exp.current_branch(), Some("exp".to_string())); + let exp_v2 = conn + .open_table("t") + .branch("exp") + .version(2) + .execute() + .await + .unwrap(); + assert_eq!(exp_v2.current_branch(), Some("exp".to_string())); } #[tokio::test] diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 83f296a69..f11f13957 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -134,6 +134,14 @@ fn compute_min_timestamp( } } +/// Normalize a branch selector: trim whitespace and treat `""` or `"main"` as +/// the (absent) main branch, matching the server's convention. +fn normalize_branch(branch: Option) -> Option { + branch + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty() && value != "main") +} + pub struct RemoteTags<'a, S: HttpSend = Sender> { inner: &'a RemoteTable, } @@ -184,14 +192,16 @@ impl Tags for RemoteTags<'_, S> { } async fn create(&mut self, tag: &str, version: u64) -> Result<()> { + let mut body = serde_json::json!({ + "tag": tag, + "version": version + }); + self.inner.apply_branch_body(&mut body); let request = self .inner .client .post(&format!("/v1/table/{}/tags/create/", self.inner.identifier)) - .json(&serde_json::json!({ - "tag": tag, - "version": version - })); + .json(&body); let (request_id, response) = self.inner.send(request, true).await?; self.inner @@ -215,14 +225,16 @@ impl Tags for RemoteTags<'_, S> { } async fn update(&mut self, tag: &str, version: u64) -> Result<()> { + let mut body = serde_json::json!({ + "tag": tag, + "version": version + }); + self.inner.apply_branch_body(&mut body); let request = self .inner .client .post(&format!("/v1/table/{}/tags/update/", self.inner.identifier)) - .json(&serde_json::json!({ - "tag": tag, - "version": version - })); + .json(&body); let (request_id, response) = self.inner.send(request, true).await?; self.inner @@ -243,6 +255,10 @@ pub struct RemoteTable { location: RwLock>, schema_cache: BackgroundCache, freshness: Mutex, + /// The branch this handle is scoped to, or `None` for the main branch. + /// Stamped onto every branch-accepting request so reads and writes resolve + /// on the branch's own version chain rather than main's. + branch: Option, } impl std::fmt::Debug for RemoteTable { @@ -272,6 +288,43 @@ impl RemoteTable { location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), freshness: Mutex::new(FreshnessState::default()), + branch: None, + } + } + + /// Return a new handle scoped to `branch`, sharing the client but with fresh + /// caches and version/freshness state (the branch tracks its own latest). + /// Mirrors `NativeTable`'s handle-per-branch model. + fn with_branch(&self, branch: Option) -> Self { + Self { + client: self.client.clone(), + name: self.name.clone(), + namespace: self.namespace.clone(), + identifier: self.identifier.clone(), + server_version: self.server_version.clone(), + version: RwLock::new(None), + location: RwLock::new(None), + schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), + freshness: Mutex::new(FreshnessState::default()), + branch, + } + } + + /// Stamp the branch onto a request as a `?branch=` query param (used for + /// Arrow-body / query-only ops). `None` (main) leaves the request unchanged, + /// keeping it byte-identical to the non-branch path. + fn apply_branch_query(&self, request: RequestBuilder) -> RequestBuilder { + match &self.branch { + Some(branch) => request.query(&[("branch", branch.as_str())]), + None => request, + } + } + + /// Stamp the branch into a JSON request body under `"branch"` (used for JSON + /// ops). `None` (main) leaves the body unchanged. + fn apply_branch_body(&self, body: &mut serde_json::Value) { + if let Some(branch) = &self.branch { + body["branch"] = serde_json::Value::String(branch.clone()); } } @@ -324,12 +377,43 @@ impl RemoteTable { } } + /// Resolve a tag to its `(branch, version)` coordinate via the `tags/version` + /// endpoint, since the `/branches/create` contract accepts no `from_tag`. + async fn resolve_tag_ref(&self, tag: &str) -> Result<(Option, u64)> { + let request = self + .client + .post(&format!("/v1/table/{}/tags/version/", self.identifier)) + .json(&serde_json::json!({ "tag": tag })); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + let value: serde_json::Value = serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse tag version: {}", e).into(), + request_id: request_id.clone(), + status_code: None, + })?; + let version = value + .get("version") + .and_then(|v| v.as_u64()) + .ok_or_else(|| Error::Http { + source: format!("Invalid tag version response: {}", body).into(), + request_id, + status_code: None, + })?; + let branch = value + .get("branch") + .and_then(|v| v.as_str()) + .map(String::from); + Ok((normalize_branch(branch), version)) + } + async fn describe_with_request( &self, request: RequestBuilder, version: Option, ) -> Result { - let body = serde_json::json!({ "version": version }); + let mut body = serde_json::json!({ "version": version }); + self.apply_branch_body(&mut body); let request = request.json(&body); let (request_id, response) = self.send(request, true).await?; @@ -707,10 +791,10 @@ impl RemoteTable { } async fn create_multipart_write(&self) -> Result { - let request = self.client.post(&format!( + let request = self.apply_branch_query(self.client.post(&format!( "/v1/table/{}/multipart_write/create", self.identifier - )); + ))); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -730,13 +814,14 @@ impl RemoteTable { } async fn complete_multipart_write(&self, upload_id: &str) -> Result { - let request = self - .client - .post(&format!( - "/v1/table/{}/multipart_write/complete", - self.identifier - )) - .query(&[("upload_id", upload_id)]); + let request = self.apply_branch_query( + self.client + .post(&format!( + "/v1/table/{}/multipart_write/complete", + self.identifier + )) + .query(&[("upload_id", upload_id)]), + ); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -754,13 +839,14 @@ impl RemoteTable { } async fn abort_multipart_write(&self, upload_id: &str) -> Result<()> { - let request = self - .client - .post(&format!( - "/v1/table/{}/multipart_write/abort", - self.identifier - )) - .query(&[("upload_id", upload_id)]); + let request = self.apply_branch_query( + self.client + .post(&format!( + "/v1/table/{}/multipart_write/abort", + self.identifier + )) + .query(&[("upload_id", upload_id)]), + ); let (request_id, response) = self.send(request, true).await?; self.check_table_response(&request_id, response).await?; Ok(()) @@ -865,7 +951,8 @@ impl RemoteTable { async fn prepare_query_bodies(&self, query: &AnyQuery) -> Result> { let version = self.current_version().await; - let base_body = serde_json::json!({ "version": version }); + let mut base_body = serde_json::json!({ "version": version }); + self.apply_branch_body(&mut base_body); match query { AnyQuery::Query(query) => { @@ -927,11 +1014,16 @@ async fn fetch_schema( identifier: &str, table_name: &str, version: Option, + branch: Option, freshness_headers: FreshnessHeaders, ) -> Result { + let mut body = serde_json::json!({ "version": version }); + if let Some(branch) = &branch { + body["branch"] = serde_json::Value::String(branch.clone()); + } let request = freshness_headers .apply(client.post(&format!("/v1/table/{}/describe/", identifier))) - .json(&serde_json::json!({ "version": version })); + .json(&body); let (request_id, response) = client.send_with_retry(request, None, true).await?; @@ -999,6 +1091,7 @@ mod test_utils { location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), freshness: Mutex::new(FreshnessState::default()), + branch: None, } } @@ -1022,6 +1115,7 @@ mod test_utils { location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), freshness: Mutex::new(FreshnessState::default()), + branch: None, } } @@ -1054,6 +1148,7 @@ mod test_utils { location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), freshness: Mutex::new(FreshnessState::default()), + branch: None, } } } @@ -1097,6 +1192,7 @@ impl RemoteTable { output.plan, output.overwrite, output.tracker.clone(), + self.branch.clone(), )); let mut retry_counter = @@ -1219,6 +1315,7 @@ impl RemoteTable { output.overwrite, upload_id.to_string(), output.tracker.clone(), + self.branch.clone(), )); let task_ctx = Arc::new(datafusion_execution::TaskContext::default()); @@ -1433,7 +1530,8 @@ impl BaseTable for RemoteTable { .client .post(&format!("/v1/table/{}/restore/", self.identifier)); let version = self.current_version().await; - let body = serde_json::json!({ "version": version }); + let mut body = serde_json::json!({ "version": version }); + self.apply_branch_body(&mut body); request = request.json(&body); let (request_id, response) = self.send(request, true).await?; @@ -1443,7 +1541,9 @@ impl BaseTable for RemoteTable { } async fn list_versions(&self) -> Result> { - let request = self.post_read(&format!("/v1/table/{}/version/list/", self.identifier)); + let request = self.apply_branch_query( + self.post_read(&format!("/v1/table/{}/version/list/", self.identifier)), + ); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -1476,6 +1576,7 @@ impl BaseTable for RemoteTable { let client = self.client.clone(); let identifier = self.identifier.clone(); let table_name = self.name.clone(); + let branch = self.branch.clone(); let freshness_headers = self.snapshot_freshness_headers(); self.schema_cache @@ -1485,6 +1586,7 @@ impl BaseTable for RemoteTable { &identifier, &table_name, version, + branch, freshness_headers, ) .await @@ -1495,34 +1597,151 @@ impl BaseTable for RemoteTable { async fn create_branch( &self, - _name: &str, - _from: lance::dataset::refs::Ref, + name: &str, + from: lance::dataset::refs::Ref, ) -> Result> { - Err(Error::NotSupported { - message: "branching is not yet supported on remote tables".into(), - }) + use lance::dataset::refs::Ref; + + if name.trim().is_empty() { + return Err(Error::InvalidInput { + message: "branch name must be a non-empty string".into(), + }); + } + + // Translate the source ref into the `from_branch` / `from_version` the + // `/branches/create` contract accepts (it has no `from_tag`). + let (from_branch, from_version) = match from { + Ref::Version(branch, version) => (normalize_branch(branch), version), + Ref::VersionNumber(version) => (normalize_branch(self.branch.clone()), Some(version)), + Ref::Tag(tag) => { + let (branch, version) = self.resolve_tag_ref(&tag).await?; + (branch, Some(version)) + } + }; + + let mut body = serde_json::json!({ "name": name }); + if let Some(from_branch) = &from_branch { + body["from_branch"] = serde_json::Value::String(from_branch.clone()); + } + if let Some(from_version) = from_version { + body["from_version"] = serde_json::json!(from_version); + } + + let request = self + .client + .post(&format!("/v1/table/{}/branches/create/", self.identifier)) + .json(&body); + + // Send without retry so the expected 409 (branch already exists) is + // surfaced as a response we can map, rather than being retried. + let (request_id, response) = self.send(request, false).await?; + match response.status() { + StatusCode::CONFLICT => { + return Err(Error::TableAlreadyExists { + name: format!("{} (branch: {})", self.name, name), + }); + } + StatusCode::BAD_REQUEST => { + let body = response.text().await.unwrap_or_default(); + return Err(Error::InvalidInput { + message: format!("invalid create_branch request: {}", body), + }); + } + StatusCode::NOT_FOUND => { + // 404 covers both a missing table and a missing source ref; name + // the source coordinate so the error isn't misattributed to the table. + let body = response.text().await.unwrap_or_default(); + let source_desc = match (&from_branch, from_version) { + (Some(b), Some(v)) => format!(" (source: branch '{b}' version {v})"), + (Some(b), None) => format!(" (source: branch '{b}')"), + (None, Some(v)) => format!(" (source: version {v})"), + (None, None) => String::new(), + }; + return Err(Error::TableNotFound { + name: format!("{}{}", self.name, source_desc), + source: Box::new(Error::Http { + source: body.into(), + request_id, + status_code: Some(StatusCode::NOT_FOUND), + }), + }); + } + _ => {} + } + self.check_table_response(&request_id, response).await?; + + Ok(Arc::new(self.with_branch(Some(name.to_string())))) } - async fn checkout_branch(&self, _name: &str) -> Result> { - Err(Error::NotSupported { - message: "branching is not yet supported on remote tables".into(), - }) + async fn checkout_branch(&self, name: &str) -> Result> { + // `main` / empty normalizes to the main-branch handle. + let Some(branch) = normalize_branch(Some(name.to_string())) else { + return Ok(Arc::new(self.with_branch(None))); + }; + + // Validate via listing -- the cheapest check that distinguishes a missing + // branch from a missing table. + let branches = self.list_branches().await?; + if !branches.contains_key(&branch) { + return Err(Error::TableNotFound { + name: format!("{} (branch: {})", self.name, branch), + source: format!("branch '{}' does not exist", branch).into(), + }); + } + + Ok(Arc::new(self.with_branch(Some(branch)))) } async fn list_branches(&self) -> Result> { - Err(Error::NotSupported { - message: "branching is not yet supported on remote tables".into(), - }) + use lance::dataset::refs::BranchContents; + + let request = self.post_read(&format!("/v1/table/{}/branches/list/", self.identifier)); + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + let body = response.text().await.err_to_http(request_id.clone())?; + + #[derive(Deserialize)] + struct ListBranchesResponse { + branches: HashMap, + } + + let parsed: ListBranchesResponse = + serde_json::from_str(&body).map_err(|err| Error::Http { + source: format!( + "Failed to parse list_branches response: {}, body: {}", + err, body + ) + .into(), + request_id, + status_code: None, + })?; + + Ok(parsed.branches) } - async fn delete_branch(&self, _name: &str) -> Result<()> { - Err(Error::NotSupported { - message: "branching is not yet supported on remote tables".into(), - }) + async fn delete_branch(&self, name: &str) -> Result<()> { + if name.trim().is_empty() { + return Err(Error::InvalidInput { + message: "branch name must be a non-empty string".into(), + }); + } + let request = self + .client + .post(&format!("/v1/table/{}/branches/delete/", self.identifier)) + .json(&serde_json::json!({ "name": name })); + let (request_id, response) = self.send(request, true).await?; + if response.status() == StatusCode::NOT_FOUND { + return Err(Error::TableNotFound { + name: format!("{} (branch: {})", self.name, name), + source: format!("branch '{}' does not exist", name).into(), + }); + } + self.check_table_response(&request_id, response).await?; + Ok(()) } fn current_branch(&self) -> Option { - None + self.branch.clone() } async fn count_rows(&self, filter: Option) -> Result { @@ -1530,17 +1749,17 @@ impl BaseTable for RemoteTable { let version = self.current_version().await; - if let Some(filter) = filter { + let mut body = if let Some(filter) = filter { let filter_sql = match filter { Filter::Sql(sql) => sql.clone(), Filter::Datafusion(expr) => expr_to_sql_string(&expr)?, }; - request = - request.json(&serde_json::json!({ "predicate": filter_sql, "version": version })); + serde_json::json!({ "predicate": filter_sql, "version": version }) } else { - let body = serde_json::json!({ "version": version }); - request = request.json(&body); - } + serde_json::json!({ "version": version }) + }; + self.apply_branch_body(&mut body); + request = request.json(&body); let (request_id, response) = match self.send(request, true).await { Ok((id, resp)) => { @@ -1745,10 +1964,12 @@ impl BaseTable for RemoteTable { updates.push(vec![column, expression]); } - let request = request.json(&serde_json::json!({ + let mut body = serde_json::json!({ "updates": updates, "predicate": update.filter, - })); + }); + self.apply_branch_body(&mut body); + let request = request.json(&body); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -1779,7 +2000,8 @@ impl BaseTable for RemoteTable { Predicate::String(s) => s.to_string(), Predicate::Expr(expr) => expr_to_sql_string(expr)?, }; - let body = serde_json::json!({ "predicate": predicate_sql }); + let mut body = serde_json::json!({ "predicate": predicate_sql }); + self.apply_branch_body(&mut body); let request = self .client .post(&format!("/v1/table/{}/delete/", self.identifier)) @@ -1891,6 +2113,7 @@ impl BaseTable for RemoteTable { body[key] = value.clone(); } } + self.apply_branch_body(&mut body); let request = request.json(&body); @@ -1927,6 +2150,7 @@ impl BaseTable for RemoteTable { .post(&format!("/v1/table/{}/merge_insert/", self.identifier)) .query(&query) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); + request = self.apply_branch_query(request); if let Some(timeout) = timeout { // (If it doesn't fit into u64, it's not worth sending anyways.) @@ -2070,7 +2294,8 @@ impl BaseTable for RemoteTable { }) }) .collect::>(); - let body = serde_json::json!({ "new_columns": body }); + let mut body = serde_json::json!({ "new_columns": body }); + self.apply_branch_body(&mut body); let request = self .client .post(&format!("/v1/table/{}/add_columns/", self.identifier)) @@ -2126,7 +2351,8 @@ impl BaseTable for RemoteTable { value }) .collect::>(); - let body = serde_json::json!({ "alterations": body }); + let mut body = serde_json::json!({ "alterations": body }); + self.apply_branch_body(&mut body); let request = self .client .post(&format!("/v1/table/{}/alter_columns/", self.identifier)) @@ -2157,7 +2383,8 @@ impl BaseTable for RemoteTable { updates: &[FieldMetadataUpdate], ) -> Result { self.check_mutable().await?; - let body = serde_json::json!({ "updates": updates }); + let mut body = serde_json::json!({ "updates": updates }); + self.apply_branch_body(&mut body); let request = self .client .post(&format!( @@ -2183,7 +2410,8 @@ impl BaseTable for RemoteTable { async fn drop_columns(&self, columns: &[&str]) -> Result { self.check_mutable().await?; - let body = serde_json::json!({ "columns": columns }); + let mut body = serde_json::json!({ "columns": columns }); + self.apply_branch_body(&mut body); let request = self .client .post(&format!("/v1/table/{}/drop_columns/", self.identifier)) @@ -2212,7 +2440,9 @@ impl BaseTable for RemoteTable { async fn list_indices(&self) -> Result> { let mut request = self.post_read(&format!("/v1/table/{}/index/list/", self.identifier)); let version = self.current_version().await; - request = request.json(&serde_json::json!({ "version": version })); + let mut body = serde_json::json!({ "version": version }); + self.apply_branch_body(&mut body); + request = request.json(&body); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -2229,7 +2459,8 @@ impl BaseTable for RemoteTable { self.identifier, index_name )); let version = self.current_version().await; - let body = serde_json::json!({ "version": version }); + let mut body = serde_json::json!({ "version": version }); + self.apply_branch_body(&mut body); request = request.json(&body); let (request_id, response) = self.send(request, true).await?; @@ -2252,10 +2483,10 @@ impl BaseTable for RemoteTable { } async fn drop_index(&self, index_name: &str) -> Result<()> { - let request = self.client.post(&format!( + let request = self.apply_branch_query(self.client.post(&format!( "/v1/table/{}/index/{}/drop/", self.identifier, index_name - )); + ))); let (request_id, response) = self.send(request, true).await?; if response.status() == StatusCode::NOT_FOUND { return Err(Error::IndexNotFound { @@ -2336,7 +2567,10 @@ impl BaseTable for RemoteTable { } async fn stats(&self) -> Result { - let request = self.post_read(&format!("/v1/table/{}/stats/", self.identifier)); + let mut request = self.post_read(&format!("/v1/table/{}/stats/", self.identifier)); + if let Some(branch) = &self.branch { + request = request.json(&serde_json::json!({ "branch": branch })); + } let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -2362,6 +2596,7 @@ impl BaseTable for RemoteTable { input, overwrite, None, + self.branch.clone(), ))) } } @@ -7067,4 +7302,927 @@ mod tests { .unwrap(); assert_eq!(result.version, 7); } + + // ----- Branch support ----- + + /// Parse a request's in-memory JSON body. Only valid for JSON-body ops + /// (not Arrow-stream inserts, whose body is a stream). + fn request_body_json(request: &reqwest::Request) -> serde_json::Value { + let bytes = request + .body() + .expect("request has a body") + .as_bytes() + .expect("body is in-memory"); + serde_json::from_slice(bytes).expect("body is valid JSON") + } + + #[tokio::test] + async fn test_create_branch_default_source() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/branches/create/"); + let body = request_body_json(&request); + assert_eq!(body["name"], "exp"); + assert!( + body.get("from_branch").is_none(), + "a main source omits from_branch" + ); + assert!( + body.get("from_version").is_none(), + "a latest source omits from_version" + ); + http::Response::builder().status(200).body("{}").unwrap() + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + assert_eq!(branch.current_branch(), Some("exp".to_string())); + assert_eq!(table.current_branch(), None); + } + + #[tokio::test] + async fn test_create_branch_from_branch_and_version() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| { + let body = request_body_json(&request); + assert_eq!(body["name"], "exp"); + assert_eq!(body["from_branch"], "base"); + assert_eq!(body["from_version"], 3); + http::Response::builder().status(200).body("{}").unwrap() + }); + table + .create_branch("exp", Ref::Version(Some("base".into()), Some(3))) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_branch_from_main_normalizes_to_none() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| { + let body = request_body_json(&request); + assert!( + body.get("from_branch").is_none(), + "\"main\" normalizes to an absent from_branch" + ); + assert_eq!(body["from_version"], 7); + http::Response::builder().status(200).body("{}").unwrap() + }); + table + .create_branch("exp", Ref::Version(Some("main".into()), Some(7))) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_branch_from_version_number_on_main() { + use lance::dataset::refs::Ref; + // A bare version number on a main handle resolves to (main, version). + let table = Table::new_with_handler("my_table", |request| { + let body = request_body_json(&request); + assert!(body.get("from_branch").is_none()); + assert_eq!(body["from_version"], 5); + http::Response::builder().status(200).body("{}").unwrap() + }); + table + .create_branch("exp", Ref::VersionNumber(5)) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_branch_from_tag_resolves_via_tags_endpoint() { + use lance::dataset::refs::Ref; + // A tag source has no from_tag in the create contract; it is resolved to + // its (branch, version) via the tags/version endpoint first. + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/tags/version/" => { + assert_eq!(request_body_json(&request)["tag"], "t"); + http::Response::builder() + .status(200) + .body(r#"{"version":3,"branch":"base"}"#.to_string()) + .unwrap() + } + "/v1/table/my_table/branches/create/" => { + let body = request_body_json(&request); + assert_eq!(body["name"], "exp"); + assert_eq!(body["from_branch"], "base"); + assert_eq!(body["from_version"], 3); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + table + .create_branch("exp", Ref::Tag("t".into())) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_branch_from_tag_on_main_normalizes() { + use lance::dataset::refs::Ref; + // A tag resolving to the main branch collapses from_branch to absent. + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/tags/version/" => http::Response::builder() + .status(200) + .body(r#"{"version":4,"branch":"main"}"#.to_string()) + .unwrap(), + "/v1/table/my_table/branches/create/" => { + let body = request_body_json(&request); + assert!( + body.get("from_branch").is_none(), + "a resolved \"main\" normalizes to an absent from_branch" + ); + assert_eq!(body["from_version"], 4); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + table + .create_branch("exp", Ref::Tag("t".into())) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_branch_invalid_request_maps_to_invalid_input() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |_| { + http::Response::builder() + .status(400) + .body("unsafe branch name") + .unwrap() + }); + let err = table + .create_branch("../evil", Ref::Version(None, None)) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn test_create_branch_conflict_maps_to_already_exists() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |_| { + http::Response::builder() + .status(409) + .body("branch already exists") + .unwrap() + }); + let err = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap_err(); + assert!( + matches!(err, Error::TableAlreadyExists { .. }), + "409 should map to AlreadyExists, got {err:?}" + ); + } + + #[tokio::test] + async fn test_create_branch_empty_name_rejected_client_side() { + use lance::dataset::refs::Ref; + // The empty name is rejected before any request is sent. + let table = Table::new_with_handler("my_table", |request| -> http::Response { + panic!("unexpected request: {}", request.url().path()) + }); + let err = table + .create_branch("", Ref::Version(None, None)) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn test_list_branches() { + use lance::dataset::refs::BranchIdentifier; + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/branches/list/"); + // A branch forked off main: the server omits `parentBranch` entirely + // (skip_serializing_if), not `null`, so this mirrors the real wire. + http::Response::builder() + .status(200) + .body( + r#"{"branches":{"exp":{"parentVersion":2,"createAt":1234,"manifestSize":4096}}}"#, + ) + .unwrap() + }); + let branches = table.list_branches().await.unwrap(); + let exp = branches.get("exp").expect("exp present"); + assert_eq!(exp.parent_version, 2); + assert_eq!(exp.create_at, 1234); + assert_eq!(exp.manifest_size, 4096); + assert_eq!(exp.parent_branch, None); + assert!(exp.metadata.is_empty()); + // The server omits the internal lineage token; it defaults to the sentinel. + assert_eq!( + exp.identifier, + BranchIdentifier::missing_identifier_sentinel() + ); + } + + #[tokio::test] + async fn test_delete_branch() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/branches/delete/"); + let body = request_body_json(&request); + assert_eq!(body["name"], "exp"); + http::Response::builder().status(200).body("{}").unwrap() + }); + table.delete_branch("exp").await.unwrap(); + } + + #[tokio::test] + async fn test_delete_branch_not_found() { + let table = Table::new_with_handler("my_table", |_| { + http::Response::builder() + .status(404) + .body("no such branch") + .unwrap() + }); + let err = table.delete_branch("ghost").await.unwrap_err(); + assert!(matches!(err, Error::TableNotFound { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn test_checkout_branch_validates_via_list() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.url().path(), "/v1/table/my_table/branches/list/"); + http::Response::builder() + .status(200) + .body( + r#"{"branches":{"exp":{"parentBranch":null,"parentVersion":1,"createAt":1,"manifestSize":1}}}"#, + ) + .unwrap() + }); + let branch = table.checkout_branch("exp", None).await.unwrap(); + assert_eq!(branch.current_branch(), Some("exp".to_string())); + } + + #[tokio::test] + async fn test_checkout_branch_missing() { + let table = Table::new_with_handler("my_table", |_| { + http::Response::builder() + .status(200) + .body(r#"{"branches":{}}"#) + .unwrap() + }); + let err = table.checkout_branch("ghost", None).await.unwrap_err(); + assert!(matches!(err, Error::TableNotFound { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn test_checkout_main_returns_main_handle() { + // "main" yields a main-scoped handle without any validation request. + let table = Table::new_with_handler("my_table", |request| -> http::Response { + panic!("unexpected request: {}", request.url().path()) + }); + let main = table.checkout_branch("main", None).await.unwrap(); + assert_eq!(main.current_branch(), None); + } + + #[tokio::test] + async fn test_branch_count_rows_carries_branch_in_body() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/count_rows/" => { + let body = request_body_json(&request); + assert_eq!(body["branch"], "exp"); + http::Response::builder() + .status(200) + .body("7".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + assert_eq!(branch.count_rows(None).await.unwrap(), 7); + } + + #[tokio::test] + async fn test_main_handle_omits_branch_in_body() { + // A main handle must not send a branch field (byte-compatible with + // pre-branch servers and the existing wire format). + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.url().path(), "/v1/table/my_table/count_rows/"); + let body = request_body_json(&request); + assert!( + body.get("branch").is_none(), + "main handle must not send a branch field" + ); + http::Response::builder().status(200).body("0").unwrap() + }); + table.count_rows(None).await.unwrap(); + } + + #[tokio::test] + async fn test_branch_update_and_delete_carry_branch() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/update/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(r#"{"version":5}"#.to_string()) + .unwrap() + } + "/v1/table/my_table/delete/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(r#"{"version":6,"num_deleted_rows":1}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch + .update() + .column("a", "a + 1") + .execute() + .await + .unwrap(); + branch.delete("a > 1").await.unwrap(); + } + + #[tokio::test] + async fn test_branch_list_indices_carries_branch_in_body() { + use lance::dataset::refs::Ref; + // list_indices posts to index/list and then fetches the schema (describe) + // to resolve column names; both must carry the branch. + let describe_body = + describe_response(&Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/index/list/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(r#"{"indexes":[]}"#.to_string()) + .unwrap() + } + "/v1/table/my_table/describe/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + assert!(branch.list_indices().await.unwrap().is_empty()); + } + + #[tokio::test] + async fn test_branch_list_versions_carries_query_param() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/version/list/" => { + assert_eq!( + request + .url() + .query_pairs() + .find(|(k, _)| k == "branch") + .map(|(_, v)| v.into_owned()), + Some("exp".to_string()), + "version/list must carry ?branch=exp" + ); + http::Response::builder() + .status(200) + .body(r#"{"versions":[]}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + assert!(branch.list_versions().await.unwrap().is_empty()); + } + + #[tokio::test] + async fn test_branch_drop_index_carries_query_param() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/index/my_idx/drop/" => { + assert_eq!( + request + .url() + .query_pairs() + .find(|(k, _)| k == "branch") + .map(|(_, v)| v.into_owned()), + Some("exp".to_string()) + ); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.drop_index("my_idx").await.unwrap(); + } + + #[tokio::test] + async fn test_branch_insert_carries_query_param() { + use lance::dataset::refs::Ref; + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let describe_body = describe_response(&data.schema()); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/describe/" => { + // schema() fetch on the branch handle carries branch in the body. + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap() + } + "/v1/table/my_table/insert/" => { + assert_eq!( + request + .url() + .query_pairs() + .find(|(k, _)| k == "branch") + .map(|(_, v)| v.into_owned()), + Some("exp".to_string()), + "insert must carry ?branch=exp" + ); + http::Response::builder() + .status(200) + .body(r#"{"version":2}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.add(data.clone()).execute().await.unwrap(); + } + + #[tokio::test] + async fn test_branch_tag_create_carries_branch_in_body() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/tags/create/" => { + let body = request_body_json(&request); + assert_eq!(body["branch"], "exp"); + assert_eq!(body["tag"], "v1"); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.tags().await.unwrap().create("v1", 1).await.unwrap(); + } + + #[tokio::test] + async fn test_checkout_branch_version_forwards_branch() { + // Branch versions overlap main's, so the version must resolve on the + // branch's own chain -- the validating describe and reads carry both. + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let describe_body = describe_response(&schema); + let table = Table::new_with_handler("my_table", move |request| { + match request.url().path() { + "/v1/table/my_table/branches/list/" => http::Response::builder() + .status(200) + .body( + r#"{"branches":{"exp":{"parentBranch":null,"parentVersion":1,"createAt":1,"manifestSize":1}}}"# + .to_string(), + ) + .unwrap(), + "/v1/table/my_table/describe/" => { + let body = request_body_json(&request); + assert_eq!(body["branch"], "exp", "checkout validate carries branch"); + assert_eq!(body["version"], 2, "checkout validate carries version"); + http::Response::builder().status(200).body(describe_body.clone()).unwrap() + } + "/v1/table/my_table/count_rows/" => { + let body = request_body_json(&request); + assert_eq!(body["branch"], "exp"); + assert_eq!(body["version"], 2, "overlapping version resolves on the branch chain"); + http::Response::builder().status(200).body("3".to_string()).unwrap() + } + path => panic!("unexpected request path: {path}"), + } + }); + let branch = table.checkout_branch("exp", Some(2)).await.unwrap(); + assert_eq!(branch.current_branch(), Some("exp".to_string())); + assert_eq!(branch.count_rows(None).await.unwrap(), 3); + } + + fn branch_query_param(request: &reqwest::Request) -> Option { + request + .url() + .query_pairs() + .find(|(k, _)| k == "branch") + .map(|(_, v)| v.into_owned()) + } + + #[tokio::test] + async fn test_branch_query_carries_branch_in_body() { + use lance::dataset::refs::Ref; + // /query/ is the hot read path; the branch must ride in the JSON body. + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data_ref = data.clone(); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body(b"{}".to_vec()) + .unwrap(), + "/v1/table/my_table/query/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(write_ipc_file(&data_ref)) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + let rows: usize = branch + .query() + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum(); + assert_eq!(rows, 3); + } + + #[tokio::test] + async fn test_branch_merge_insert_carries_query_param() { + use lance::dataset::refs::Ref; + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data: Box = Box::new(RecordBatchIterator::new( + [Ok(batch.clone())], + batch.schema(), + )); + let table = Table::new_with_handler("my_table", move |request| { + match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/merge_insert/" => { + assert_eq!( + branch_query_param(&request).as_deref(), + Some("exp"), + "merge_insert must carry ?branch=exp" + ); + http::Response::builder() + .status(200) + .body( + r#"{"version":2,"num_deleted_rows":0,"num_inserted_rows":3,"num_updated_rows":0}"# + .to_string(), + ) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + } + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.merge_insert(&["id"]).execute(data).await.unwrap(); + } + + #[tokio::test] + async fn test_branch_multipart_write_carries_query_param() { + use lance::dataset::refs::Ref; + // The multipart path (create -> insert parts -> complete) must carry + // ?branch= on every leg; an old server version forces it. + let table = Table::new_with_handler_version( + "my_table", + semver::Version::new(0, 4, 0), + move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/describe/" => simple_describe_response(), + "/v1/table/my_table/multipart_write/create" => { + assert_eq!(branch_query_param(&request).as_deref(), Some("exp")); + http::Response::builder() + .status(200) + .body(r#"{"upload_id": "u1"}"#.to_string()) + .unwrap() + } + "/v1/table/my_table/insert/" => { + assert_eq!( + branch_query_param(&request).as_deref(), + Some("exp"), + "multipart insert must carry ?branch=exp" + ); + http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#.to_string()) + .unwrap() + } + "/v1/table/my_table/multipart_write/complete" => { + assert_eq!(branch_query_param(&request).as_deref(), Some("exp")); + http::Response::builder() + .status(200) + .body(r#"{"version": 5}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }, + ); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + branch + .add(vec![batch]) + .write_parallelism(2) + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_branch_restore_carries_branch_in_body() { + use lance::dataset::refs::Ref; + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/restore/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(r#"{"version":1}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.restore().await.unwrap(); + } + + #[tokio::test] + async fn test_branch_create_index_carries_branch_in_body() { + use lance::dataset::refs::Ref; + // create_index fetches the schema (describe) to resolve the column and + // then posts create_index; both must carry the branch in the body. + let describe_body = + describe_response(&Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/describe/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(describe_body.clone()) + .unwrap() + } + "/v1/table/my_table/create_index/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch + .create_index(&["a"], Index::BTree(Default::default())) + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_branch_column_ops_carry_branch_in_body() { + use lance::dataset::refs::Ref; + // add_columns / alter_columns / drop_columns all stamp the branch into + // the JSON body via apply_branch_body. + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/add_columns/" + | "/v1/table/my_table/alter_columns/" + | "/v1/table/my_table/drop_columns/" => { + assert_eq!( + request_body_json(&request)["branch"], + "exp", + "{} must carry the branch", + request.url().path() + ); + http::Response::builder() + .status(200) + .body(r#"{"version":43}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch + .add_columns( + NewColumnTransform::SqlExpressions(vec![("b".into(), "a + 1".into())]), + None, + ) + .await + .unwrap(); + branch + .alter_columns(&[ColumnAlteration::new("a".into()).rename("b".into())]) + .await + .unwrap(); + branch.drop_columns(&["a"]).await.unwrap(); + } + + #[tokio::test] + async fn test_branch_update_field_metadata_carries_branch_in_body() { + use lance::dataset::refs::Ref; + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/update_field_metadata/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body(r#"{"version":7,"fields":{}}"#.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch + .update_field_metadata(&[FieldMetadataUpdate::new("category").set("unit", "label")]) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_branch_index_stats_carries_branch_in_body() { + use lance::dataset::refs::Ref; + let table = Table::new_with_handler("my_table", move |request| { + match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body("{}".to_string()) + .unwrap(), + "/v1/table/my_table/index/my_index/stats/" => { + assert_eq!(request_body_json(&request)["branch"], "exp"); + http::Response::builder() + .status(200) + .body( + r#"{"num_indexed_rows":1,"num_unindexed_rows":0,"index_type":"IVF_PQ","distance_type":"l2"}"# + .to_string(), + ) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + } + }); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + assert!(branch.index_stats("my_index").await.unwrap().is_some()); + } + + #[tokio::test] + async fn test_branch_stats_attaches_body_while_main_omits_it() { + use lance::dataset::refs::Ref; + // stats has a bespoke conditional body: a main handle stays a bodyless + // POST, while a branch handle attaches {"branch": ...}. + let stats_body = r#"{"total_bytes":1,"num_rows":3,"num_indices":0,"fragment_stats":{"num_fragments":1,"num_small_fragments":0,"lengths":{"min":3,"max":3,"mean":3,"p25":3,"p50":3,"p75":3,"p99":3}}}"#; + let table = + Table::new_with_handler("my_table", move |request| match request.url().path() { + "/v1/table/my_table/branches/create/" => http::Response::builder() + .status(200) + .body(stats_body.to_string()) + .unwrap(), + "/v1/table/my_table/stats/" => { + match request.body() { + // main handle: byte-identical to the pre-branch wire format. + None => {} + // branch handle: branch travels in the body. + Some(_) => assert_eq!(request_body_json(&request)["branch"], "exp"), + } + http::Response::builder() + .status(200) + .body(stats_body.to_string()) + .unwrap() + } + path => panic!("unexpected request path: {path}"), + }); + table.stats().await.unwrap(); + let branch = table + .create_branch("exp", Ref::Version(None, None)) + .await + .unwrap(); + branch.stats().await.unwrap(); + } } diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index 49ebb2015..014119ff9 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -48,6 +48,8 @@ pub struct RemoteInsertExec { metrics: ExecutionPlanMetricsSet, upload_id: Option, tracker: Option>, + /// Branch to write to via `?branch=`. `None` targets the main branch. + branch: Option, } impl RemoteInsertExec { @@ -59,9 +61,10 @@ impl RemoteInsertExec { input: Arc, overwrite: bool, tracker: Option>, + branch: Option, ) -> Self { Self::new_inner( - table_name, identifier, client, input, overwrite, None, tracker, + table_name, identifier, client, input, overwrite, None, tracker, branch, ) } @@ -70,6 +73,7 @@ impl RemoteInsertExec { /// Each partition's insert is staged under the given `upload_id` without /// committing. The caller is responsible for calling the complete (or abort) /// endpoint after all partitions finish. + #[allow(clippy::too_many_arguments)] pub fn new_multipart( table_name: String, identifier: String, @@ -78,6 +82,7 @@ impl RemoteInsertExec { overwrite: bool, upload_id: String, tracker: Option>, + branch: Option, ) -> Self { Self::new_inner( table_name, @@ -87,9 +92,11 @@ impl RemoteInsertExec { overwrite, Some(upload_id), tracker, + branch, ) } + #[allow(clippy::too_many_arguments)] fn new_inner( table_name: String, identifier: String, @@ -98,6 +105,7 @@ impl RemoteInsertExec { overwrite: bool, upload_id: Option, tracker: Option>, + branch: Option, ) -> Self { let num_partitions = if upload_id.is_some() { input.output_partitioning().partition_count() @@ -123,6 +131,7 @@ impl RemoteInsertExec { metrics: ExecutionPlanMetricsSet::new(), upload_id, tracker, + branch, } } @@ -273,6 +282,7 @@ impl ExecutionPlan for RemoteInsertExec { self.overwrite, self.upload_id.clone(), self.tracker.clone(), + self.branch.clone(), ))) } @@ -304,6 +314,7 @@ impl ExecutionPlan for RemoteInsertExec { let table_name = self.table_name.clone(); let upload_id = self.upload_id.clone(); let tracker = self.tracker.clone(); + let branch = self.branch.clone(); let stream = futures::stream::once(async move { let mut request = client @@ -316,6 +327,9 @@ impl ExecutionPlan for RemoteInsertExec { if let Some(ref uid) = upload_id { request = request.query(&[("upload_id", uid.as_str())]); } + if let Some(ref b) = branch { + request = request.query(&[("branch", b.as_str())]); + } let (error_tx, mut error_rx) = tokio::sync::oneshot::channel(); let body = Self::stream_as_http_body(input_stream, error_tx, tracker)?;