feat: add table branch support to remote tables and Python/TS bindings

This commit is contained in:
Brendan Clement
2026-06-11 23:39:34 -07:00
parent dfbe5becaa
commit 4150e0b1c1
9 changed files with 1471 additions and 126 deletions

View File

@@ -396,13 +396,13 @@ class RemoteDBConnection(DBConnection):
The namespace to open the table from.
None or empty list represents root namespace.
branch: str, optional
Branching is not yet supported on remote tables, so only the
default branch is accepted (``None`` or ``"main"``); any other
value raises ``NotImplementedError``.
If provided, open a handle scoped to this branch instead of the
default branch. Reads and writes operate in the branch's context.
version: int, optional
If provided, open the table pinned to this version, producing a
read-only handle. Call ``checkout_latest`` to return to a writable
state.
read-only handle. Composes with ``branch``: when both are given,
opens that branch at the version; otherwise opens ``main`` at the
version. Call ``checkout_latest`` to return to a writable state.
Returns
-------
@@ -410,11 +410,6 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
# Remote supports version time-travel but not branches: reject a non-main
# branch, but allow a version-only open (or "main").
if branch is not None and branch != "main":
raise NotImplementedError("branching is not yet supported on remote tables")
if namespace_path is None:
namespace_path = []
if index_cache_size is not None:
@@ -430,7 +425,9 @@ class RemoteDBConnection(DBConnection):
connection_state=self.serialize,
namespace_path=namespace_path,
)
if version is not None:
if branch is not None:
tbl = tbl.branches.checkout(branch, version)
elif version is not None:
tbl.checkout(version)
return tbl

View File

@@ -56,7 +56,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.table import _normalize_progress
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
from ..table import AsyncTable, BlobMode, IndexStatistics, Query, Table, Tags
from ..table import AsyncTable, BlobMode, Branches, IndexStatistics, Query, Table, Tags
from ..types import BaseTokenizerType
@@ -75,6 +75,10 @@ class RemoteTable(Table):
self._connection_state = connection_state
self._namespace_path = list(namespace_path or [])
self._checkout_version: Optional[int] = None
# The branch this handle is scoped to (None == main). Persisted so the
# branch (and any pinned version) survives a fork/pickle reopen instead
# of silently reverting to main.
self._branch: Optional[str] = None
self._pid = os.getpid()
def _serialized_connection_state(self) -> str:
@@ -109,9 +113,15 @@ class RemoteTable(Table):
from lancedb import deserialize_conn
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
table = db.open_table(self._name, namespace_path=self._namespace_path)
if self._checkout_version is not None:
table.checkout(self._checkout_version)
# Reopen on the same branch and pinned version. open_table composes both
# (branch=None / version=None reproduce the plain main-latest open), so a
# branch handle no longer silently reopens on main.
table = db.open_table(
self._name,
namespace_path=self._namespace_path,
branch=self._branch,
version=self._checkout_version,
)
self._table_handle = table._table
self.db_name = table.db_name
@@ -124,6 +134,7 @@ class RemoteTable(Table):
"name": self.name,
"namespace_path": self._namespace_path,
"checkout_version": self._checkout_version,
"branch": self._branch,
}
def __setstate__(self, state: dict) -> None:
@@ -133,6 +144,7 @@ class RemoteTable(Table):
self._connection_state = state["connection_state"]
self._namespace_path = state["namespace_path"]
self._checkout_version = state["checkout_version"]
self._branch = state.get("branch")
self._pid = None
@property
@@ -160,6 +172,36 @@ class RemoteTable(Table):
def tags(self) -> Tags:
return Tags(self._table)
@property
def branches(self) -> Branches:
"""Branch management for the table.
``create``/``checkout`` return a new table handle scoped to the branch;
writes on it do not affect ``main``.
"""
return Branches(self)
def current_branch(self) -> Optional[str]:
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
return self._table.current_branch()
def _wrap_branch_handle(
self, async_table: AsyncTable, version: Optional[int] = None
) -> "RemoteTable":
# A branch handle stays a RemoteTable, carrying the same connection
# context so it can be reopened after a fork/pickle. Record the branch
# and any explicit version pin so the reopen targets the branch (and
# version), not main -- the in-memory handle alone does not survive.
handle = RemoteTable(
async_table,
self.db_name,
connection_state=self._connection_state,
namespace_path=self._namespace_path,
)
handle._branch = async_table.current_branch()
handle._checkout_version = version
return handle
@cached_property
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
"""

View File

@@ -796,6 +796,10 @@ class Table(ABC):
"""
raise NotImplementedError
def current_branch(self) -> Optional[str]:
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
raise NotImplementedError
def __len__(self) -> int:
"""The number of rows in this Table"""
return self.count_rows(None)
@@ -2223,6 +2227,21 @@ class LanceTable(Table):
"""The branch this table handle is scoped to, or ``None`` for ``main``."""
return self._table.current_branch()
def _wrap_branch_handle(
self, async_table: "AsyncTable", version: Optional[int] = None
) -> "LanceTable":
# version is unused locally: the pin already lives on async_table and a
# local handle is not reopened via a serialized connection.
return LanceTable(
self._conn,
async_table.name,
namespace_path=self._namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._pushdown_operations,
location=self._location,
_async=async_table,
)
def checkout(self, version: Union[int, str]):
"""Checkout a version of the table. This is an in-place operation.
@@ -5934,7 +5953,7 @@ class Branches:
name: str,
from_ref: Optional[str] = None,
from_version: Optional[int] = None,
) -> "LanceTable":
) -> "Table":
"""Create a branch and return a handle scoped to it.
Parameters
@@ -5951,7 +5970,7 @@ class Branches:
)
return self._wrap(async_table)
def checkout(self, name: str, version: Optional[int] = None) -> "LanceTable":
def checkout(self, name: str, version: Optional[int] = None) -> "Table":
"""Check out an existing branch and return a handle scoped to it.
Parameters
@@ -5964,25 +5983,20 @@ class Branches:
the branch's latest and stays writable.
"""
async_table = LOOP.run(self._table.branches.checkout(name, version))
return self._wrap(async_table)
return self._wrap(async_table, version)
def delete(self, name: str) -> None:
"""Delete a branch."""
LOOP.run(self._table.branches.delete(name))
def _wrap(self, async_table: "AsyncTable") -> "LanceTable":
# Reuse the parent's connection + namespace context; from_inner would drop
# it and break identity/query routing for namespace-backed tables.
parent = self._parent
return LanceTable(
parent._conn,
async_table.name,
namespace_path=parent._namespace_path,
namespace_client=parent._namespace_client,
pushdown_operations=parent._pushdown_operations,
location=parent._location,
_async=async_table,
)
def _wrap(
self, async_table: "AsyncTable", version: Optional[int] = None
) -> "Table":
# Delegate to the parent so a branch handle keeps the parent's concrete
# type (LanceTable for local, RemoteTable for remote) and its connection
# context. `version` is the explicit pin (if any) so a remote handle can
# restore branch+version after a fork/pickle.
return self._parent._wrap_branch_handle(async_table, version)
class AsyncTags:

View File

@@ -154,50 +154,120 @@ async def test_async_checkout():
assert await table.count_rows() == 300
def _branch_open_handler(request):
if "/branches/list" in request.path:
body = json.dumps(
{
"branches": {
"exp": {
"parentBranch": None,
"parentVersion": 1,
"createAt": 1,
"manifestSize": 1,
}
}
}
).encode()
else:
# describe (table open + version/branch validation)
body = json.dumps({"version": 2, "schema": {"fields": []}}).encode()
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(body)
def test_remote_open_table_branch_and_version():
with mock_lancedb_connection(_branch_open_handler) as db:
# version-only (and "main" + version) time-travels the main chain
assert db.open_table("test", version=2) is not None
assert db.open_table("test", branch="main", version=2).current_branch() is None
# a non-main branch opens a handle scoped to that branch, with or
# without a version
assert db.open_table("test", branch="exp").current_branch() == "exp"
assert db.open_table("test", branch="exp", version=2).current_branch() == "exp"
def test_remote_table_branches_sync():
# Branch CRUD + current_branch on the sync RemoteTable. The handle returned
# by create/checkout must stay a RemoteTable scoped to the branch.
from lancedb.remote.table import RemoteTable
def handler(request):
# describe (table open + version validation) always succeeds
if "/branches/list" in request.path:
body = json.dumps(
{
"branches": {
"exp": {
"parentBranch": None,
"parentVersion": 1,
"createAt": 1,
"manifestSize": 1,
}
}
}
).encode()
elif "/branches/create" in request.path or "/branches/delete" in request.path:
body = b"{}"
else:
# describe (table open + checkout validation)
body = json.dumps({"version": 1, "schema": {"fields": []}}).encode()
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(
json.dumps({"version": 2, "schema": {"fields": []}}).encode()
)
request.wfile.write(body)
with mock_lancedb_connection(handler) as db:
# version-only (and "main" + version) is allowed: remote supports
# version time-travel even though it has no branches
assert db.open_table("test", version=2) is not None
assert db.open_table("test", branch="main", version=2) is not None
table = db.open_table("test")
assert isinstance(table, RemoteTable)
# a handle opened normally tracks main
assert table.current_branch() is None
# a non-main branch is rejected, with or without a version
with pytest.raises(NotImplementedError, match="branching"):
db.open_table("test", branch="exp")
with pytest.raises(NotImplementedError, match="branching"):
db.open_table("test", branch="exp", version=2)
# create returns a RemoteTable handle scoped to the new branch
branch = table.branches.create("exp")
assert isinstance(branch, RemoteTable)
assert branch.current_branch() == "exp"
# list + checkout round trip; checkout also yields a branch-scoped handle
assert "exp" in table.branches.list()
checked = table.branches.checkout("exp")
assert isinstance(checked, RemoteTable)
assert checked.current_branch() == "exp"
table.branches.delete("exp")
@pytest.mark.asyncio
async def test_async_remote_open_table_branch_and_version():
def handler(request):
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(
json.dumps({"version": 2, "schema": {"fields": []}}).encode()
)
async with mock_lancedb_connection_async(handler) as db:
# version-only (and "main" + version) is allowed: "main" is the default
# branch, so it must not hit the unsupported remote branch path
async with mock_lancedb_connection_async(_branch_open_handler) as db:
# version-only (and "main" + version) time-travels the main chain
assert await db.open_table("test", version=2) is not None
assert await db.open_table("test", branch="main", version=2) is not None
main_v2 = await db.open_table("test", branch="main", version=2)
assert main_v2.current_branch() is None
# a non-main branch is rejected, with or without a version
with pytest.raises(NotImplementedError, match="branching"):
await db.open_table("test", branch="exp")
with pytest.raises(NotImplementedError, match="branching"):
await db.open_table("test", branch="exp", version=2)
# a non-main branch opens a handle scoped to that branch
exp = await db.open_table("test", branch="exp")
assert exp.current_branch() == "exp"
exp_v2 = await db.open_table("test", branch="exp", version=2)
assert exp_v2.current_branch() == "exp"
def test_remote_table_branch_survives_pickle():
# Regression: a branch-scoped handle (and a branch+version handle) must keep
# its branch across a pickle/fork round-trip. Before the fix, __getstate__
# dropped the branch and _ensure_open reopened on main, silently routing
# subsequent reads/writes to the wrong branch.
with mock_lancedb_connection(_branch_open_handler) as db:
branch = db.open_table("test", branch="exp")
assert branch.current_branch() == "exp"
restored = pickle.loads(pickle.dumps(branch))
assert restored.current_branch() == "exp"
# the pinned version is carried through as well
branch_v2 = db.open_table("test", branch="exp", version=2)
restored_v2 = pickle.loads(pickle.dumps(branch_v2))
assert restored_v2.current_branch() == "exp"
def test_table_len_sync():