feat(python): align with_format("torch") with HuggingFace semantics

- New "torch" format yields list-of-per-row-dicts whose values are torch
  tensors. PyTorch DataLoader's default collate stacks these into a
  dict[str, Tensor] per batch — matching HuggingFace
  dataset.set_format("torch") and removing the need for a custom collate_fn.
- The previous "torch" behavior (list of 1-D row tensors without column
  names) moves to a new "torch_row" format.
- "torch_col" is unchanged (2-D tensor, requires custom collate).

Closes lancedb/lancedb#3245
This commit is contained in:
Ayush Chaurasia
2026-04-29 22:15:40 +05:30
parent 25dfe2cfd4
commit 2fa4d04de7
4 changed files with 67 additions and 6 deletions

View File

@@ -9,7 +9,7 @@ import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .util import batch_to_tensor, batch_to_tensor_rows
from .util import batch_to_tensor, batch_to_tensor_dict, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@@ -697,6 +697,7 @@ class Permutation:
"pandas",
"arrow",
"torch",
"torch_row",
"torch_col",
"polars",
],
@@ -712,8 +713,17 @@ class Permutation:
- "python_col" - the batch will be a dict of lists (one entry per column)
- "pandas" - the batch will be a pandas DataFrame
- "arrow" - the batch will be a pyarrow RecordBatch
- "torch" - the batch will be a list of tensors, one per row
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
- "torch" - a list of per-row dicts whose values are torch tensors. When
used with ``torch.utils.data.DataLoader`` (default collate), each
batch yielded by the loader is ``dict[str, Tensor]`` — one tensor per
column, with column names preserved. This matches HuggingFace
``dataset.set_format("torch")`` semantics.
- "torch_row" - a list of 1-D torch tensors, one per row. Each tensor
stacks all column values into a single row vector and column names
are not preserved. (This was the previous "torch" behavior.)
- "torch_col" - a 2-D torch tensor of shape ``(n_cols, n_rows)``. Column
names are not preserved. Requires ``collate_fn=lambda x: x`` if used
with ``DataLoader``.
- "polars" - the batch will be a polars DataFrame
Conversion may or may not involve a data copy. Lance uses Arrow internally
@@ -741,6 +751,8 @@ class Permutation:
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_dict)
elif format == "torch_row":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)

View File

@@ -448,3 +448,29 @@ def batch_to_tensor_rows(batch: pa.RecordBatch):
stacked = torch.tensor(numpy.column_stack(columns))
rows = list(stacked.unbind(dim=0))
return rows
def batch_to_tensor_dict(batch: pa.RecordBatch):
    """
    Convert a PyArrow RecordBatch into a list of per-row dicts whose values
    are PyTorch tensors.

    Each column is converted to a tensor in one shot (zero-copy via DLPack
    when supported), then sliced per row. The result is shaped to work with
    PyTorch's default DataLoader collate, which stacks the per-row dicts
    into a single ``dict[str, Tensor]`` per batch — matching the
    HuggingFace ``dataset.set_format("torch")`` convention.

    Fails if torch is not installed.
    Fails if a column's data type is not supported by PyTorch.
    """
    torch = attempt_import_or_raise("torch", "torch")

    # One tensor per column; the per-row dicts below are views into these.
    col_tensors: dict[str, "torch.Tensor"] = {}
    for idx, col_name in enumerate(batch.schema.names):
        arr = batch.column(idx)
        try:
            # Fast path: zero-copy handoff via the DLPack protocol.
            tensor = torch.from_dlpack(arr)
        except Exception:
            # Fallback for types without DLPack support: go through numpy
            # (may copy, e.g. for arrays with nulls).
            tensor = torch.tensor(arr.to_numpy(zero_copy_only=False))
        col_tensors[col_name] = tensor

    return [
        {col_name: tensor[row] for col_name, tensor in col_tensors.items()}
        for row in range(batch.num_rows)
    ]

View File

@@ -950,14 +950,27 @@ def test_transform_fn(mem_db):
try:
import torch
# "torch" format: list of per-row dicts of tensors (HF-compatible).
torch_result = list(
permutation.with_format("torch").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_result, list)
assert len(torch_result) == 10
assert isinstance(torch_result[0], torch.Tensor)
assert torch_result[0].shape == (2,)
assert torch_result[0].dtype == torch.int64
assert isinstance(torch_result[0], dict)
assert set(torch_result[0].keys()) == {"id", "value"}
assert isinstance(torch_result[0]["id"], torch.Tensor)
assert torch_result[0]["id"].shape == ()
assert torch_result[0]["id"].dtype == torch.int64
# "torch_row" format: list of 1-D row tensors (previous "torch" behavior).
torch_row_result = list(
permutation.with_format("torch_row").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_row_result, list)
assert len(torch_row_result) == 10
assert isinstance(torch_row_result[0], torch.Tensor)
assert torch_row_result[0].shape == (2,)
assert torch_row_result[0].dtype == torch.int64
except ImportError:
# Skip check if torch is not installed
pass

View File

@@ -27,8 +27,18 @@ def test_permutation_dataloader(mem_db):
for batch in dataloader:
assert batch["a"].size(0) == 10
# New "torch" format: per-row dicts of tensors, default collate yields
# dict[str, Tensor] (HuggingFace style).
permutation = permutation.with_format("torch")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert isinstance(batch, dict)
assert "a" in batch
assert batch["a"].size() == (10,)
# Previous "torch" semantics is preserved under the "torch_row" name.
permutation = permutation.with_format("torch_row")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert batch.size(0) == 10
assert batch.size(1) == 1