mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-04 19:40:39 +00:00
feat: add table branch support to remote tables and Python/TS bindings
This commit is contained in:
@@ -295,6 +295,23 @@ await table.createIndex("my_float_col");
|
||||
|
||||
***
|
||||
|
||||
### currentBranch()
|
||||
|
||||
```ts
|
||||
abstract currentBranch(): null | string
|
||||
```
|
||||
|
||||
The branch this table handle is scoped to, or `null` for the main branch.
|
||||
|
||||
A handle returned by [Branches.create](Branches.md#create) or [Branches.checkout](Branches.md#checkout)
|
||||
reports the branch it targets; a handle opened normally reports `null`.
|
||||
|
||||
#### Returns
|
||||
|
||||
`null` \| `string`
|
||||
|
||||
***
|
||||
|
||||
### delete()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -191,30 +191,36 @@ describe("remote connection", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("allows version on remote but rejects a non-main branch", async () => {
|
||||
it("supports version time-travel and branches on remote", async () => {
|
||||
await withMockDatabase(
|
||||
(_req, res) => {
|
||||
// describe (table open + version validation) always succeeds
|
||||
const body = JSON.stringify({
|
||||
name: "t",
|
||||
version: 2,
|
||||
schema: { fields: [] },
|
||||
});
|
||||
(req, res) => {
|
||||
const body = req.url?.includes("/branches/list")
|
||||
? JSON.stringify({
|
||||
branches: {
|
||||
exp: { parentVersion: 1, createAt: 1, manifestSize: 1 },
|
||||
},
|
||||
})
|
||||
: JSON.stringify({ name: "t", version: 2, schema: { fields: [] } });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async (db) => {
|
||||
// version-only (and "main" + version) is allowed: remote supports
|
||||
// version time-travel even though it has no branches
|
||||
await db.openTable("t", undefined, { version: 2 });
|
||||
await db.openTable("t", undefined, { branch: "main", version: 2 });
|
||||
// version-only (and "main" + version) time-travel the main chain
|
||||
const v2 = await db.openTable("t", undefined, { version: 2 });
|
||||
expect(v2.currentBranch()).toBeNull();
|
||||
const mainV2 = await db.openTable("t", undefined, {
|
||||
branch: "main",
|
||||
version: 2,
|
||||
});
|
||||
expect(mainV2.currentBranch()).toBeNull();
|
||||
|
||||
// a non-main branch is rejected, with or without a version
|
||||
await expect(
|
||||
db.openTable("t", undefined, { branch: "exp" }),
|
||||
).rejects.toThrow(/branching/);
|
||||
await expect(
|
||||
db.openTable("t", undefined, { branch: "exp", version: 2 }),
|
||||
).rejects.toThrow(/branching/);
|
||||
// a non-main branch opens a handle scoped to that branch
|
||||
const exp = await db.openTable("t", undefined, { branch: "exp" });
|
||||
expect(exp.currentBranch()).toBe("exp");
|
||||
const expV2 = await db.openTable("t", undefined, {
|
||||
branch: "exp",
|
||||
version: 2,
|
||||
});
|
||||
expect(expV2.currentBranch()).toBe("exp");
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -89,8 +89,11 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
await table.add([{ id: 1 }]);
|
||||
expect(await table.countRows()).toBe(1);
|
||||
|
||||
expect(table.currentBranch()).toBeNull();
|
||||
|
||||
// fork an isolated, writable branch from main
|
||||
const branch = await (await table.branches()).create("exp");
|
||||
expect(branch.currentBranch()).toBe("exp");
|
||||
expect(await branch.countRows()).toBe(1);
|
||||
await branch.add([{ id: 2 }]);
|
||||
expect(await branch.countRows()).toBe(2);
|
||||
@@ -109,6 +112,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
|
||||
// checkout returns a handle scoped to the branch's latest
|
||||
const checkedOut = await (await table.branches()).checkout("exp");
|
||||
expect(checkedOut.currentBranch()).toBe("exp");
|
||||
expect(await checkedOut.countRows()).toBe(2);
|
||||
|
||||
// delete removes it
|
||||
|
||||
@@ -663,6 +663,14 @@ export abstract class Table {
|
||||
*/
|
||||
abstract branches(): Promise<Branches>;
|
||||
|
||||
/**
|
||||
* The branch this table handle is scoped to, or `null` for the main branch.
|
||||
*
|
||||
* A handle returned by {@link Branches.create} or {@link Branches.checkout}
|
||||
* reports the branch it targets; a handle opened normally reports `null`.
|
||||
*/
|
||||
abstract currentBranch(): string | null;
|
||||
|
||||
/**
|
||||
* Restore the table to the currently checked out version
|
||||
*
|
||||
@@ -1122,6 +1130,10 @@ export class LocalTable extends Table {
|
||||
return new Branches(await this.inner.branches());
|
||||
}
|
||||
|
||||
currentBranch(): string | null {
|
||||
return this.inner.currentBranch() ?? null;
|
||||
}
|
||||
|
||||
async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> {
|
||||
let cleanupOlderThanMs;
|
||||
if (
|
||||
|
||||
@@ -487,6 +487,12 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
/// The branch this handle is scoped to, or `null` for the main branch.
|
||||
#[napi]
|
||||
pub fn current_branch(&self) -> napi::Result<Option<String>> {
|
||||
Ok(self.inner_ref()?.current_branch())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn optimize(
|
||||
&self,
|
||||
|
||||
@@ -396,13 +396,13 @@ class RemoteDBConnection(DBConnection):
|
||||
The namespace to open the table from.
|
||||
None or empty list represents root namespace.
|
||||
branch: str, optional
|
||||
Branching is not yet supported on remote tables, so only the
|
||||
default branch is accepted (``None`` or ``"main"``); any other
|
||||
value raises ``NotImplementedError``.
|
||||
If provided, open a handle scoped to this branch instead of the
|
||||
default branch. Reads and writes operate in the branch's context.
|
||||
version: int, optional
|
||||
If provided, open the table pinned to this version, producing a
|
||||
read-only handle. Call ``checkout_latest`` to return to a writable
|
||||
state.
|
||||
read-only handle. Composes with ``branch``: when both are given,
|
||||
opens that branch at the version; otherwise opens ``main`` at the
|
||||
version. Call ``checkout_latest`` to return to a writable state.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -410,11 +410,6 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
# Remote supports version time-travel but not branches: reject a non-main
|
||||
# branch, but allow a version-only open (or "main").
|
||||
if branch is not None and branch != "main":
|
||||
raise NotImplementedError("branching is not yet supported on remote tables")
|
||||
|
||||
if namespace_path is None:
|
||||
namespace_path = []
|
||||
if index_cache_size is not None:
|
||||
@@ -430,7 +425,9 @@ class RemoteDBConnection(DBConnection):
|
||||
connection_state=self.serialize,
|
||||
namespace_path=namespace_path,
|
||||
)
|
||||
if version is not None:
|
||||
if branch is not None:
|
||||
tbl = tbl.branches.checkout(branch, version)
|
||||
elif version is not None:
|
||||
tbl.checkout(version)
|
||||
return tbl
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.table import _normalize_progress
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
||||
from ..table import AsyncTable, BlobMode, IndexStatistics, Query, Table, Tags
|
||||
from ..table import AsyncTable, BlobMode, Branches, IndexStatistics, Query, Table, Tags
|
||||
from ..types import BaseTokenizerType
|
||||
|
||||
|
||||
@@ -75,6 +75,9 @@ class RemoteTable(Table):
|
||||
self._connection_state = connection_state
|
||||
self._namespace_path = list(namespace_path or [])
|
||||
self._checkout_version: Optional[int] = None
|
||||
# The branch this handle is scoped to (None == main). Persisted so a
|
||||
# fork/pickle reopen restores the branch instead of reverting to main.
|
||||
self._branch: Optional[str] = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _serialized_connection_state(self) -> str:
|
||||
@@ -109,9 +112,14 @@ class RemoteTable(Table):
|
||||
from lancedb import deserialize_conn
|
||||
|
||||
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
|
||||
table = db.open_table(self._name, namespace_path=self._namespace_path)
|
||||
if self._checkout_version is not None:
|
||||
table.checkout(self._checkout_version)
|
||||
# Reopen on the same branch and pinned version (branch=None / version=None
|
||||
# reproduce the plain main-latest open).
|
||||
table = db.open_table(
|
||||
self._name,
|
||||
namespace_path=self._namespace_path,
|
||||
branch=self._branch,
|
||||
version=self._checkout_version,
|
||||
)
|
||||
|
||||
self._table_handle = table._table
|
||||
self.db_name = table.db_name
|
||||
@@ -124,6 +132,7 @@ class RemoteTable(Table):
|
||||
"name": self.name,
|
||||
"namespace_path": self._namespace_path,
|
||||
"checkout_version": self._checkout_version,
|
||||
"branch": self._branch,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
@@ -133,6 +142,7 @@ class RemoteTable(Table):
|
||||
self._connection_state = state["connection_state"]
|
||||
self._namespace_path = state["namespace_path"]
|
||||
self._checkout_version = state["checkout_version"]
|
||||
self._branch = state.get("branch")
|
||||
self._pid = None
|
||||
|
||||
@property
|
||||
@@ -160,6 +170,34 @@ class RemoteTable(Table):
|
||||
def tags(self) -> Tags:
|
||||
return Tags(self._table)
|
||||
|
||||
@property
|
||||
def branches(self) -> Branches:
|
||||
"""Branch management for the table.
|
||||
|
||||
``create``/``checkout`` return a new table handle scoped to the branch;
|
||||
writes on it do not affect ``main``.
|
||||
"""
|
||||
return Branches(self)
|
||||
|
||||
def current_branch(self) -> Optional[str]:
|
||||
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
|
||||
return self._table.current_branch()
|
||||
|
||||
def _wrap_branch_handle(
|
||||
self, async_table: AsyncTable, version: Optional[int] = None
|
||||
) -> "RemoteTable":
|
||||
# A branch handle stays a RemoteTable with the same connection context.
|
||||
# Record the branch and version pin so a fork/pickle reopen restores both.
|
||||
handle = RemoteTable(
|
||||
async_table,
|
||||
self.db_name,
|
||||
connection_state=self._connection_state,
|
||||
namespace_path=self._namespace_path,
|
||||
)
|
||||
handle._branch = async_table.current_branch()
|
||||
handle._checkout_version = version
|
||||
return handle
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
|
||||
"""
|
||||
|
||||
@@ -796,6 +796,10 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def current_branch(self) -> Optional[str]:
|
||||
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of rows in this Table"""
|
||||
return self.count_rows(None)
|
||||
@@ -2223,6 +2227,21 @@ class LanceTable(Table):
|
||||
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
|
||||
return self._table.current_branch()
|
||||
|
||||
def _wrap_branch_handle(
|
||||
self, async_table: "AsyncTable", version: Optional[int] = None
|
||||
) -> "LanceTable":
|
||||
# version is unused locally: the pin already lives on async_table and a
|
||||
# local handle is not reopened via a serialized connection.
|
||||
return LanceTable(
|
||||
self._conn,
|
||||
async_table.name,
|
||||
namespace_path=self._namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._pushdown_operations,
|
||||
location=self._location,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
def checkout(self, version: Union[int, str]):
|
||||
"""Checkout a version of the table. This is an in-place operation.
|
||||
|
||||
@@ -5934,7 +5953,7 @@ class Branches:
|
||||
name: str,
|
||||
from_ref: Optional[str] = None,
|
||||
from_version: Optional[int] = None,
|
||||
) -> "LanceTable":
|
||||
) -> "Table":
|
||||
"""Create a branch and return a handle scoped to it.
|
||||
|
||||
Parameters
|
||||
@@ -5951,7 +5970,7 @@ class Branches:
|
||||
)
|
||||
return self._wrap(async_table)
|
||||
|
||||
def checkout(self, name: str, version: Optional[int] = None) -> "LanceTable":
|
||||
def checkout(self, name: str, version: Optional[int] = None) -> "Table":
|
||||
"""Check out an existing branch and return a handle scoped to it.
|
||||
|
||||
Parameters
|
||||
@@ -5964,25 +5983,19 @@ class Branches:
|
||||
the branch's latest and stays writable.
|
||||
"""
|
||||
async_table = LOOP.run(self._table.branches.checkout(name, version))
|
||||
return self._wrap(async_table)
|
||||
return self._wrap(async_table, version)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
"""Delete a branch."""
|
||||
LOOP.run(self._table.branches.delete(name))
|
||||
|
||||
def _wrap(self, async_table: "AsyncTable") -> "LanceTable":
|
||||
# Reuse the parent's connection + namespace context; from_inner would drop
|
||||
# it and break identity/query routing for namespace-backed tables.
|
||||
parent = self._parent
|
||||
return LanceTable(
|
||||
parent._conn,
|
||||
async_table.name,
|
||||
namespace_path=parent._namespace_path,
|
||||
namespace_client=parent._namespace_client,
|
||||
pushdown_operations=parent._pushdown_operations,
|
||||
location=parent._location,
|
||||
_async=async_table,
|
||||
)
|
||||
def _wrap(
|
||||
self, async_table: "AsyncTable", version: Optional[int] = None
|
||||
) -> "Table":
|
||||
# Delegate to the parent so the branch handle keeps its concrete type
|
||||
# (LanceTable / RemoteTable) and connection context; `version` is the
|
||||
# explicit pin so a remote handle can restore branch+version on reopen.
|
||||
return self._parent._wrap_branch_handle(async_table, version)
|
||||
|
||||
|
||||
class AsyncTags:
|
||||
|
||||
@@ -154,50 +154,116 @@ async def test_async_checkout():
|
||||
assert await table.count_rows() == 300
|
||||
|
||||
|
||||
def _branch_open_handler(request):
|
||||
if "/branches/list" in request.path:
|
||||
body = json.dumps(
|
||||
{
|
||||
"branches": {
|
||||
"exp": {
|
||||
"parentBranch": None,
|
||||
"parentVersion": 1,
|
||||
"createAt": 1,
|
||||
"manifestSize": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
).encode()
|
||||
else:
|
||||
# describe (table open + version/branch validation)
|
||||
body = json.dumps({"version": 2, "schema": {"fields": []}}).encode()
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(body)
|
||||
|
||||
|
||||
def test_remote_open_table_branch_and_version():
|
||||
with mock_lancedb_connection(_branch_open_handler) as db:
|
||||
# version-only (and "main" + version) time-travels the main chain
|
||||
assert db.open_table("test", version=2) is not None
|
||||
assert db.open_table("test", branch="main", version=2).current_branch() is None
|
||||
|
||||
# a non-main branch opens a handle scoped to that branch, with or
|
||||
# without a version
|
||||
assert db.open_table("test", branch="exp").current_branch() == "exp"
|
||||
assert db.open_table("test", branch="exp", version=2).current_branch() == "exp"
|
||||
|
||||
|
||||
def test_remote_table_branches_sync():
|
||||
# Branch CRUD + current_branch on the sync RemoteTable. The handle returned
|
||||
# by create/checkout must stay a RemoteTable scoped to the branch.
|
||||
from lancedb.remote.table import RemoteTable
|
||||
|
||||
def handler(request):
|
||||
# describe (table open + version validation) always succeeds
|
||||
if "/branches/list" in request.path:
|
||||
body = json.dumps(
|
||||
{
|
||||
"branches": {
|
||||
"exp": {
|
||||
"parentBranch": None,
|
||||
"parentVersion": 1,
|
||||
"createAt": 1,
|
||||
"manifestSize": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
).encode()
|
||||
elif "/branches/create" in request.path or "/branches/delete" in request.path:
|
||||
body = b"{}"
|
||||
else:
|
||||
# describe (table open + checkout validation)
|
||||
body = json.dumps({"version": 1, "schema": {"fields": []}}).encode()
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(
|
||||
json.dumps({"version": 2, "schema": {"fields": []}}).encode()
|
||||
)
|
||||
request.wfile.write(body)
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
# version-only (and "main" + version) is allowed: remote supports
|
||||
# version time-travel even though it has no branches
|
||||
assert db.open_table("test", version=2) is not None
|
||||
assert db.open_table("test", branch="main", version=2) is not None
|
||||
table = db.open_table("test")
|
||||
assert isinstance(table, RemoteTable)
|
||||
assert table.current_branch() is None
|
||||
|
||||
# a non-main branch is rejected, with or without a version
|
||||
with pytest.raises(NotImplementedError, match="branching"):
|
||||
db.open_table("test", branch="exp")
|
||||
with pytest.raises(NotImplementedError, match="branching"):
|
||||
db.open_table("test", branch="exp", version=2)
|
||||
branch = table.branches.create("exp")
|
||||
assert isinstance(branch, RemoteTable)
|
||||
assert branch.current_branch() == "exp"
|
||||
|
||||
# list + checkout round trip; checkout also yields a branch-scoped handle
|
||||
assert "exp" in table.branches.list()
|
||||
checked = table.branches.checkout("exp")
|
||||
assert isinstance(checked, RemoteTable)
|
||||
assert checked.current_branch() == "exp"
|
||||
|
||||
table.branches.delete("exp")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_remote_open_table_branch_and_version():
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(
|
||||
json.dumps({"version": 2, "schema": {"fields": []}}).encode()
|
||||
)
|
||||
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
# version-only (and "main" + version) is allowed: "main" is the default
|
||||
# branch, so it must not hit the unsupported remote branch path
|
||||
async with mock_lancedb_connection_async(_branch_open_handler) as db:
|
||||
# version-only (and "main" + version) time-travels the main chain
|
||||
assert await db.open_table("test", version=2) is not None
|
||||
assert await db.open_table("test", branch="main", version=2) is not None
|
||||
main_v2 = await db.open_table("test", branch="main", version=2)
|
||||
assert main_v2.current_branch() is None
|
||||
|
||||
# a non-main branch is rejected, with or without a version
|
||||
with pytest.raises(NotImplementedError, match="branching"):
|
||||
await db.open_table("test", branch="exp")
|
||||
with pytest.raises(NotImplementedError, match="branching"):
|
||||
await db.open_table("test", branch="exp", version=2)
|
||||
# a non-main branch opens a handle scoped to that branch
|
||||
exp = await db.open_table("test", branch="exp")
|
||||
assert exp.current_branch() == "exp"
|
||||
exp_v2 = await db.open_table("test", branch="exp", version=2)
|
||||
assert exp_v2.current_branch() == "exp"
|
||||
|
||||
|
||||
def test_remote_table_branch_survives_pickle():
|
||||
# Regression: a branch-scoped handle must keep its branch across a
|
||||
# pickle/fork round-trip (it used to reopen on main).
|
||||
with mock_lancedb_connection(_branch_open_handler) as db:
|
||||
branch = db.open_table("test", branch="exp")
|
||||
assert branch.current_branch() == "exp"
|
||||
restored = pickle.loads(pickle.dumps(branch))
|
||||
assert restored.current_branch() == "exp"
|
||||
|
||||
# the pinned version is carried through as well
|
||||
branch_v2 = db.open_table("test", branch="exp", version=2)
|
||||
restored_v2 = pickle.loads(pickle.dumps(branch_v2))
|
||||
assert restored_v2.current_branch() == "exp"
|
||||
|
||||
|
||||
def test_table_len_sync():
|
||||
|
||||
@@ -985,45 +985,42 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_open_table_branch_and_version() {
|
||||
// Remote supports version time-travel but not branches. A version-only
|
||||
// open (or one on the default "main" branch) must succeed; a non-main
|
||||
// branch must be rejected, with or without a version.
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/table/t/describe/");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"table": "t", "version": 2, "schema": {"fields": [
|
||||
{"name": "a", "type": { "type": "int32" }, "nullable": false}
|
||||
]}}"#,
|
||||
)
|
||||
.unwrap()
|
||||
let body = if request.url().path() == "/v1/table/t/branches/list/" {
|
||||
// checkout_branch validates the branch exists via list_branches.
|
||||
r#"{"branches":{"exp":{"parentVersion":1,"createAt":1,"manifestSize":1}}}"#
|
||||
} else {
|
||||
// describe (table open + version/branch validation)
|
||||
r#"{"table": "t", "version": 2, "schema": {"fields": [
|
||||
{"name": "a", "type": { "type": "int32" }, "nullable": false}
|
||||
]}}"#
|
||||
};
|
||||
http::Response::builder().status(200).body(body).unwrap()
|
||||
});
|
||||
|
||||
// version-only: allowed (open + checkout(version) both round-trip)
|
||||
conn.open_table("t").version(2).execute().await.unwrap();
|
||||
|
||||
// "main" is the default branch, so it counts as no branch
|
||||
conn.open_table("t")
|
||||
// version-only (and "main" + version) time-travel the main chain
|
||||
let v2 = conn.open_table("t").version(2).execute().await.unwrap();
|
||||
assert_eq!(v2.current_branch(), None);
|
||||
let main_v2 = conn
|
||||
.open_table("t")
|
||||
.branch("main")
|
||||
.version(2)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(main_v2.current_branch(), None);
|
||||
|
||||
// a non-main branch is rejected, with or without a version
|
||||
assert!(matches!(
|
||||
conn.open_table("t").branch("exp").execute().await,
|
||||
Err(Error::NotSupported { .. })
|
||||
));
|
||||
assert!(matches!(
|
||||
conn.open_table("t")
|
||||
.branch("exp")
|
||||
.version(2)
|
||||
.execute()
|
||||
.await,
|
||||
Err(Error::NotSupported { .. })
|
||||
));
|
||||
// a non-main branch opens a handle scoped to that branch
|
||||
let exp = conn.open_table("t").branch("exp").execute().await.unwrap();
|
||||
assert_eq!(exp.current_branch(), Some("exp".to_string()));
|
||||
let exp_v2 = conn
|
||||
.open_table("t")
|
||||
.branch("exp")
|
||||
.version(2)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(exp_v2.current_branch(), Some("exp".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -48,6 +48,8 @@ pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
upload_id: Option<String>,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
/// Branch to write to via `?branch=`. `None` targets the main branch.
|
||||
branch: Option<String>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
@@ -59,9 +61,10 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
overwrite: bool,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
branch: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_inner(
|
||||
table_name, identifier, client, input, overwrite, None, tracker,
|
||||
table_name, identifier, client, input, overwrite, None, tracker, branch,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -70,6 +73,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
/// Each partition's insert is staged under the given `upload_id` without
|
||||
/// committing. The caller is responsible for calling the complete (or abort)
|
||||
/// endpoint after all partitions finish.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new_multipart(
|
||||
table_name: String,
|
||||
identifier: String,
|
||||
@@ -78,6 +82,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
overwrite: bool,
|
||||
upload_id: String,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
branch: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_inner(
|
||||
table_name,
|
||||
@@ -87,9 +92,11 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
overwrite,
|
||||
Some(upload_id),
|
||||
tracker,
|
||||
branch,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new_inner(
|
||||
table_name: String,
|
||||
identifier: String,
|
||||
@@ -98,6 +105,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
overwrite: bool,
|
||||
upload_id: Option<String>,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
branch: Option<String>,
|
||||
) -> Self {
|
||||
let num_partitions = if upload_id.is_some() {
|
||||
input.output_partitioning().partition_count()
|
||||
@@ -123,6 +131,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
upload_id,
|
||||
tracker,
|
||||
branch,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,6 +282,7 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
||||
self.overwrite,
|
||||
self.upload_id.clone(),
|
||||
self.tracker.clone(),
|
||||
self.branch.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -304,6 +314,7 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
||||
let table_name = self.table_name.clone();
|
||||
let upload_id = self.upload_id.clone();
|
||||
let tracker = self.tracker.clone();
|
||||
let branch = self.branch.clone();
|
||||
|
||||
let stream = futures::stream::once(async move {
|
||||
let mut request = client
|
||||
@@ -316,6 +327,9 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
||||
if let Some(ref uid) = upload_id {
|
||||
request = request.query(&[("upload_id", uid.as_str())]);
|
||||
}
|
||||
if let Some(ref b) = branch {
|
||||
request = request.query(&[("branch", b.as_str())]);
|
||||
}
|
||||
|
||||
let (error_tx, mut error_rx) = tokio::sync::oneshot::channel();
|
||||
let body = Self::stream_as_http_body(input_stream, error_tx, tracker)?;
|
||||
|
||||
Reference in New Issue
Block a user