diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 4e36bfbfb..db9afda35 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -228,6 +228,7 @@ class Table: def tags(self) -> Tags: ... @property def branches(self) -> Branches: ... + def current_branch(self) -> Optional[str]: ... def query(self) -> Query: ... def take_offsets(self, offsets: list[int]) -> TakeQuery: ... def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ... diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 5f93bb7dd..49c00e1ff 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2095,22 +2095,27 @@ class LanceTable(Table): "Please install with `pip install pylance`." ) + branch = self.current_branch() + version = None if branch is not None else self.version if self._namespace_client is not None: table_id = self._namespace_path + [self.name] - return lance.dataset( - version=self.version, + ds = lance.dataset( + version=version, storage_options=self._conn.storage_options, namespace_client=self._namespace_client, table_id=table_id, **kwargs, ) - - return lance.dataset( - self._dataset_path, - version=self.version, - storage_options=self._conn.storage_options, - **kwargs, - ) + else: + ds = lance.dataset( + self._dataset_path, + version=version, + storage_options=self._conn.storage_options, + **kwargs, + ) + if branch is not None: + ds = ds.checkout_version((branch, self.version)) + return ds @property def schema(self) -> pa.Schema: @@ -2185,6 +2190,10 @@ class LanceTable(Table): """ return Branches(self) + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + return self._table.current_branch() + def checkout(self, version: Union[int, str]): """Checkout a version of the table. This is an in-place operation. 
@@ -4346,12 +4355,20 @@ class AsyncTable: "Please install with `pip install pylance`." ) - return lance.dataset( + # lance.dataset() can't open a branch directly, so open the base table + # and check out the branch ref (a None branch resolves to main). + branch = self.current_branch() + table_version = await self.version() + version = None if branch is not None else table_version + ds = lance.dataset( await self.uri(), - version=await self.version(), + version=version, storage_options=await self.latest_storage_options(), **kwargs, ) + if branch is not None: + ds = ds.checkout_version((branch, table_version)) + return ds async def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame": """Return the table as a pandas DataFrame. @@ -5469,6 +5486,10 @@ class AsyncTable: """ return AsyncBranches(self._inner) + def current_branch(self) -> Optional[str]: + """The branch this table handle is scoped to, or ``None`` for ``main``.""" + return self._inner.current_branch() + async def optimize( self, *, diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 8548f2e27..084a077ef 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -969,6 +969,16 @@ def test_open_table_with_branch(tmp_path): assert opened.namespace == ["ns1"] +def test_branch_to_lance_targets_branch(tmp_path): + db = lancedb.connect(tmp_path) + table = db.create_table("t", [{"i": 1}]) + branch = table.branches.create("exp") + branch.add([{"i": 2}]) # branch: 2 rows, main: 1 row + + assert branch.to_lance().count_rows() == 2 + assert table.to_lance().count_rows() == 1 + + @pytest.mark.asyncio async def test_async_branches(tmp_path): db = await lancedb.connect_async(tmp_path) diff --git a/python/src/table.rs b/python/src/table.rs index 5993e035b..cd85d9ad6 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -864,6 +864,10 @@ impl Table { Ok(Tags::new(self.inner_ref()?.clone())) } + pub fn current_branch(&self) -> 
PyResult<Option<String>> { + Ok(self.inner_ref()?.current_branch()) + } + #[getter] pub fn branches(&self) -> PyResult<Branches> { Ok(Branches::new(self.inner_ref()?.clone()))