From 2fa4d04de73477e7206641dba6c855419c90edb0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 29 Apr 2026 22:15:40 +0530 Subject: [PATCH] feat(python): align with_format("torch") with HuggingFace semantics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New "torch" format yields list-of-per-row-dicts whose values are torch tensors. PyTorch DataLoader's default collate stacks these into a dict[str, Tensor] per batch — matching HuggingFace dataset.set_format("torch") and removing the need for a custom collate_fn. - The previous "torch" behavior (list of 1-D row tensors without column names) moves to a new "torch_row" format. - "torch_col" is unchanged (2-D tensor, requires custom collate). Closes lancedb/lancedb#3245 --- python/python/lancedb/permutation.py | 18 ++++++++++++++--- python/python/lancedb/util.py | 26 +++++++++++++++++++++++++ python/python/tests/test_permutation.py | 19 +++++++++++++++--- python/python/tests/test_torch.py | 10 ++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 724a0fd25..02c051fd5 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -9,7 +9,7 @@ import json from ._lancedb import async_permutation_builder, PermutationReader from .table import LanceTable from .background_loop import LOOP -from .util import batch_to_tensor, batch_to_tensor_rows +from .util import batch_to_tensor, batch_to_tensor_dict, batch_to_tensor_rows from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -697,6 +697,7 @@ class Permutation: "pandas", "arrow", "torch", + "torch_row", "torch_col", "polars", ], @@ -712,8 +713,17 @@ class Permutation: - "python_col" - the batch will be a dict of lists (one entry per column) - "pandas" - the batch will be a pandas DataFrame - "arrow" - the batch will be a pyarrow RecordBatch - 
def batch_to_tensor_dict(batch: pa.RecordBatch):
    """
    Turn a PyArrow RecordBatch into a list with one dict per row, where each
    dict maps a column name to that row's value as a PyTorch tensor.

    Every column is first converted to a single tensor: ``torch.from_dlpack``
    is tried first, and if that raises for any reason the column is copied
    through numpy (``to_numpy(zero_copy_only=False)``) and wrapped with
    ``torch.tensor``. The per-row dicts are then produced by indexing each
    column tensor at the row position.

    The list-of-dicts shape is intended for PyTorch's default DataLoader
    collate, which stacks per-row dicts into one ``dict[str, Tensor]`` per
    batch -- the HuggingFace ``dataset.set_format("torch")`` convention.

    Fails if torch is not installed.
    Fails if a column's data type is not supported by PyTorch.
    """
    torch = attempt_import_or_raise("torch", "torch")

    def _column_as_tensor(array):
        # Prefer the DLPack route; on any failure fall back to a numpy-based
        # conversion, which is allowed to copy the data.
        try:
            return torch.from_dlpack(array)
        except Exception:
            return torch.tensor(array.to_numpy(zero_copy_only=False))

    # One tensor per column, keyed by column name in schema order.
    tensors_by_name = {
        col_name: _column_as_tensor(batch.column(position))
        for position, col_name in enumerate(batch.schema.names)
    }

    # Slice every column tensor at each row index to build the per-row dicts.
    return [
        {col_name: tensor[row] for col_name, tensor in tensors_by_name.items()}
        for row in range(batch.num_rows)
    ]
+ torch_row_result = list( + permutation.with_format("torch_row").iter(10, skip_last_batch=False) + )[0] + assert isinstance(torch_row_result, list) + assert len(torch_row_result) == 10 + assert isinstance(torch_row_result[0], torch.Tensor) + assert torch_row_result[0].shape == (2,) + assert torch_row_result[0].dtype == torch.int64 except ImportError: # Skip check if torch is not installed pass diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py index ef1c5e73b..74423da27 100644 --- a/python/python/tests/test_torch.py +++ b/python/python/tests/test_torch.py @@ -27,8 +27,18 @@ def test_permutation_dataloader(mem_db): for batch in dataloader: assert batch["a"].size(0) == 10 + # New "torch" format: per-row dicts of tensors, default collate yields + # dict[str, Tensor] (HuggingFace style). permutation = permutation.with_format("torch") dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True) + for batch in dataloader: + assert isinstance(batch, dict) + assert "a" in batch + assert batch["a"].size() == (10,) + + # Previous "torch" semantics is preserved under the "torch_row" name. + permutation = permutation.with_format("torch_row") + dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True) for batch in dataloader: assert batch.size(0) == 10 assert batch.size(1) == 1