feat: improve Permutation pytorch integration (#3016)

This changes around the output format of `Permutation` in some breaking
ways but I think the API is still new enough to be considered
experimental.

1. In order to align with both huggingface's dataset and torch's
expectations the default output format is now a list of dicts
(row-major) instead of a dict of lists (column-major). I've added a
python_col option which will return the dict of lists.
2. In order to align with pytorch's expectation the `torch` format is
now a list of tensors (row-major) instead of a 2D tensor (column-major).
I've added a torch_col option which will return the 2D tensor instead.

Added tests for torch integration with Permutation

~~Leaving draft until https://github.com/lancedb/lancedb/pull/3013
merges as this is built on top of that~~
This commit is contained in:
Weston Pace
2026-02-12 13:41:14 -08:00
committed by GitHub
parent 02783bf440
commit 70cbee6293
4 changed files with 123 additions and 63 deletions

View File

@@ -9,7 +9,7 @@ import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .util import batch_to_tensor
from .util import batch_to_tensor, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@@ -333,7 +333,11 @@ class Transforms:
"""
@staticmethod
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
def arrow2python(batch: pa.RecordBatch) -> list[dict[str, Any]]:
return batch.to_pylist()
@staticmethod
def arrow2pythoncol(batch: pa.RecordBatch) -> dict[str, list[Any]]:
return batch.to_pydict()
@staticmethod
@@ -687,7 +691,17 @@ class Permutation:
return
def with_format(
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
self,
format: Literal[
"numpy",
"python",
"python_col",
"pandas",
"arrow",
"torch",
"torch_col",
"polars",
],
) -> "Permutation":
"""
Set the format for batches
@@ -696,16 +710,18 @@ class Permutation:
The format can be one of:
- "numpy" - the batch will be a dict of numpy arrays (one per column)
- "python" - the batch will be a dict of lists (one per column)
- "python" - the batch will be a list of dicts (one per row)
- "python_col" - the batch will be a dict of lists (one entry per column)
- "pandas" - the batch will be a pandas DataFrame
- "arrow" - the batch will be a pyarrow RecordBatch
- "torch" - the batch will be a two dimensional torch tensor
- "torch" - the batch will be a list of tensors, one per row
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
- "polars" - the batch will be a polars DataFrame
Conversion may or may not involve a data copy. Lance uses Arrow internally
and so it is able to zero-copy to the arrow and polars.
and so it is able to zero-copy to the arrow and polars formats.
Conversion to torch will be zero-copy but will only support a subset of data
Conversion to torch_col will be zero-copy but will only support a subset of data
types (numeric types).
Conversion to numpy and/or pandas will typically be zero-copy for numeric
@@ -718,6 +734,8 @@ class Permutation:
assert format is not None, "format is required"
if format == "python":
return self.with_transform(Transforms.arrow2python)
if format == "python_col":
return self.with_transform(Transforms.arrow2pythoncol)
elif format == "numpy":
return self.with_transform(Transforms.arrow2numpy)
elif format == "pandas":
@@ -725,6 +743,8 @@ class Permutation:
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)
elif format == "polars":
return self.with_transform(Transforms.arrow2polars())

View File

@@ -419,3 +419,22 @@ def batch_to_tensor(batch: pa.RecordBatch):
"""
torch = attempt_import_or_raise("torch", "torch")
return torch.stack([torch.from_dlpack(col) for col in batch.columns])
def batch_to_tensor_rows(batch: pa.RecordBatch):
"""
Convert a PyArrow RecordBatch to a list of PyTorch Tensor, one per row
Each column is converted to a tensor (using zero-copy via DLPack)
and the columns are then stacked into a single tensor. The 2D tensor
is then converted to a list of tensors, one per row
Fails if torch or numpy is not installed.
Fails if a column's data type is not supported by PyTorch.
"""
torch = attempt_import_or_raise("torch", "torch")
numpy = attempt_import_or_raise("numpy", "numpy")
columns = [col.to_numpy(zero_copy_only=False) for col in batch.columns]
stacked = torch.tensor(numpy.column_stack(columns))
rows = list(stacked.unbind(dim=0))
return rows

View File

@@ -664,23 +664,20 @@ def test_iter_basic(some_permutation: Permutation):
expected_batches = (950 + batch_size - 1) // batch_size # ceiling division
assert len(batches) == expected_batches
# Check that all batches are dicts (default python format)
assert all(isinstance(batch, dict) for batch in batches)
# Check that all batches are lists of dicts (default python format)
assert all(isinstance(batch, list) for batch in batches)
# Check that batches have the correct structure
for batch in batches:
assert "id" in batch
assert "value" in batch
assert isinstance(batch["id"], list)
assert isinstance(batch["value"], list)
assert "id" in batch[0]
assert "value" in batch[0]
# Check that all batches except the last have the correct size
for batch in batches[:-1]:
assert len(batch["id"]) == batch_size
assert len(batch["value"]) == batch_size
assert len(batch) == batch_size
# Last batch might be smaller
assert len(batches[-1]["id"]) <= batch_size
assert len(batches[-1]) <= batch_size
def test_iter_skip_last_batch(some_permutation: Permutation):
@@ -699,11 +696,11 @@ def test_iter_skip_last_batch(some_permutation: Permutation):
if 950 % batch_size != 0:
assert len(batches_without_skip) == num_full_batches + 1
# Last batch should be smaller
assert len(batches_without_skip[-1]["id"]) == 950 % batch_size
assert len(batches_without_skip[-1]) == 950 % batch_size
# All batches with skip_last_batch should be full size
for batch in batches_with_skip:
assert len(batch["id"]) == batch_size
assert len(batch) == batch_size
def test_iter_different_batch_sizes(some_permutation: Permutation):
@@ -720,12 +717,12 @@ def test_iter_different_batch_sizes(some_permutation: Permutation):
# Test with batch size equal to total rows
single_batch = list(some_permutation.iter(950, skip_last_batch=False))
assert len(single_batch) == 1
assert len(single_batch[0]["id"]) == 950
assert len(single_batch[0]) == 950
# Test with batch size larger than total rows
oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
assert len(oversized_batch) == 1
assert len(oversized_batch[0]["id"]) == 950
assert len(oversized_batch[0]) == 950
def test_dunder_iter(some_permutation: Permutation):
@@ -738,15 +735,13 @@ def test_dunder_iter(some_permutation: Permutation):
# All batches should be full size
for batch in batches:
assert len(batch["id"]) == 100
assert len(batch["value"]) == 100
assert len(batch) == 100
some_permutation = some_permutation.with_batch_size(400)
batches = list(some_permutation)
assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True
for batch in batches:
assert len(batch["id"]) == 400
assert len(batch["value"]) == 400
assert len(batch) == 400
def test_iter_with_different_formats(some_permutation: Permutation):
@@ -761,7 +756,7 @@ def test_iter_with_different_formats(some_permutation: Permutation):
# Test with python format (default)
python_perm = some_permutation.with_format("python")
python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
assert all(isinstance(batch, dict) for batch in python_batches)
assert all(isinstance(batch, list) for batch in python_batches)
# Test with pandas format
pandas_perm = some_permutation.with_format("pandas")
@@ -780,8 +775,8 @@ def test_iter_with_column_selection(some_permutation: Permutation):
# Check that batches only contain the id column
for batch in batches:
assert "id" in batch
assert "value" not in batch
assert "id" in batch[0]
assert "value" not in batch[0]
def test_iter_with_column_rename(some_permutation: Permutation):
@@ -791,9 +786,9 @@ def test_iter_with_column_rename(some_permutation: Permutation):
# Check that batches have the renamed column
for batch in batches:
assert "id" in batch
assert "data" in batch
assert "value" not in batch
assert "id" in batch[0]
assert "data" in batch[0]
assert "value" not in batch[0]
def test_iter_with_limit_offset(some_permutation: Permutation):
@@ -812,14 +807,14 @@ def test_iter_with_limit_offset(some_permutation: Permutation):
assert len(limit_batches) == 5
no_skip = some_permutation.iter(101, skip_last_batch=False)
row_100 = next(no_skip)["id"][100]
row_100 = next(no_skip)[100]["id"]
# Test with both limit and offset
limited_perm = some_permutation.with_skip(100).with_take(300)
limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
# Should have 3 batches (300 / 100)
assert len(limited_batches) == 3
assert limited_batches[0]["id"][0] == row_100
assert limited_batches[0][0]["id"] == row_100
def test_iter_empty_permutation(mem_db):
@@ -842,7 +837,7 @@ def test_iter_single_row(mem_db):
# With skip_last_batch=False, should get one batch
batches = list(perm.iter(10, skip_last_batch=False))
assert len(batches) == 1
assert len(batches[0]["id"]) == 1
assert len(batches[0]) == 1
# With skip_last_batch=True, should skip the single row (since it's < batch_size)
batches_skip = list(perm.iter(10, skip_last_batch=True))
@@ -860,8 +855,7 @@ def test_identity_permutation(mem_db):
batches = list(permutation.iter(10, skip_last_batch=False))
assert len(batches) == 1
assert len(batches[0]["id"]) == 10
assert len(batches[0]["value"]) == 10
assert len(batches[0]) == 10
permutation = permutation.remove_columns(["value"])
assert permutation.num_columns == 1
@@ -904,10 +898,10 @@ def test_transform_fn(mem_db):
py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
0
]
assert len(py_result) == 2
assert len(py_result["id"]) == 10
assert len(py_result["value"]) == 10
assert isinstance(py_result, dict)
assert len(py_result) == 10
assert "id" in py_result[0]
assert "value" in py_result[0]
assert isinstance(py_result, list)
try:
import torch
@@ -915,9 +909,11 @@ def test_transform_fn(mem_db):
torch_result = list(
permutation.with_format("torch").iter(10, skip_last_batch=False)
)[0]
assert torch_result.shape == (2, 10)
assert torch_result.dtype == torch.int64
assert isinstance(torch_result, torch.Tensor)
assert isinstance(torch_result, list)
assert len(torch_result) == 10
assert isinstance(torch_result[0], torch.Tensor)
assert torch_result[0].shape == (2,)
assert torch_result[0].dtype == torch.int64
except ImportError:
# Skip check if torch is not installed
pass
@@ -950,17 +946,16 @@ def test_custom_transform(mem_db):
def test_getitems_basic(some_permutation: Permutation):
"""Test __getitems__ returns correct rows by offset."""
result = some_permutation.__getitems__([0, 1, 2])
assert isinstance(result, dict)
assert "id" in result
assert "value" in result
assert len(result["id"]) == 3
assert isinstance(result, list)
assert "id" in result[0]
assert "value" in result[0]
assert len(result) == 3
def test_getitems_single_index(some_permutation: Permutation):
"""Test __getitems__ with a single index."""
result = some_permutation.__getitems__([0])
assert len(result["id"]) == 1
assert len(result["value"]) == 1
assert len(result) == 1
def test_getitems_preserves_order(some_permutation: Permutation):
@@ -970,38 +965,40 @@ def test_getitems_preserves_order(some_permutation: Permutation):
# Get the same rows in reverse order
reverse = some_permutation.__getitems__([4, 3, 2, 1, 0])
assert forward["id"] == list(reversed(reverse["id"]))
assert forward["value"] == list(reversed(reverse["value"]))
assert [r["id"] for r in forward] == list(reversed([r["id"] for r in reverse]))
assert [r["value"] for r in forward] == list(
reversed([r["value"] for r in reverse])
)
def test_getitems_non_contiguous(some_permutation: Permutation):
"""Test __getitems__ with non-contiguous indices."""
result = some_permutation.__getitems__([0, 10, 50, 100, 500])
assert len(result["id"]) == 5
assert len(result) == 5
# Each id/value pair should match what we'd get individually
for i, offset in enumerate([0, 10, 50, 100, 500]):
single = some_permutation.__getitems__([offset])
assert result["id"][i] == single["id"][0]
assert result["value"][i] == single["value"][0]
assert result[i]["id"] == single[0]["id"]
assert result[i]["value"] == single[0]["value"]
def test_getitems_with_column_selection(some_permutation: Permutation):
"""Test __getitems__ respects column selection."""
id_only = some_permutation.select_columns(["id"])
result = id_only.__getitems__([0, 1, 2])
assert "id" in result
assert "value" not in result
assert len(result["id"]) == 3
assert "id" in result[0]
assert "value" not in result[0]
assert len(result) == 3
def test_getitems_with_column_rename(some_permutation: Permutation):
"""Test __getitems__ respects column renames."""
renamed = some_permutation.rename_column("value", "data")
result = renamed.__getitems__([0, 1])
assert "data" in result
assert "value" not in result
assert len(result["data"]) == 2
assert "data" in result[0]
assert "value" not in result[0]
assert len(result) == 2
def test_getitems_with_format(some_permutation: Permutation):
@@ -1032,8 +1029,8 @@ def test_getitems_identity_permutation(mem_db):
perm = Permutation.identity(tbl)
result = perm.__getitems__([0, 5, 9])
assert result["id"] == [0, 5, 9]
assert result["value"] == [0, 5, 9]
assert [r["id"] for r in result] == [0, 5, 9]
assert [r["value"] for r in result] == [0, 5, 9]
def test_getitems_with_limit_offset(some_permutation: Permutation):
@@ -1042,12 +1039,12 @@ def test_getitems_with_limit_offset(some_permutation: Permutation):
# Should be able to access offsets within the limited range
result = limited.__getitems__([0, 1, 199])
assert len(result["id"]) == 3
assert len(result) == 3
# The first item of the limited permutation should match offset 100 of original
full_result = some_permutation.__getitems__([100])
limited_result = limited.__getitems__([0])
assert limited_result["id"][0] == full_result["id"][0]
assert limited_result[0]["id"] == full_result[0]["id"]
def test_getitems_invalid_offset(some_permutation: Permutation):

View File

@@ -4,6 +4,7 @@
import pyarrow as pa
import pytest
from lancedb.util import tbl_to_tensor
from lancedb.permutation import Permutation
torch = pytest.importorskip("torch")
@@ -16,3 +17,26 @@ def test_table_dataloader(mem_db):
for batch in dataloader:
assert batch.size(0) == 1
assert batch.size(1) == 10
def test_permutation_dataloader(mem_db):
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert batch["a"].size(0) == 10
permutation = permutation.with_format("torch")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert batch.size(0) == 10
assert batch.size(1) == 1
permutation = permutation.with_format("torch_col")
dataloader = torch.utils.data.DataLoader(
permutation, collate_fn=lambda x: x, batch_size=10, shuffle=True
)
for batch in dataloader:
assert batch.size(0) == 1
assert batch.size(1) == 10