Files
lancedb/python/python/tests/test_torch.py
Weston Pace 70cbee6293 feat: improve Permutation pytorch integration (#3016)
This changes the output format of `Permutation` in some breaking
ways, but I think the API is still new enough to be considered
experimental.

1. To align with the expectations of both Hugging Face's datasets and
torch, the default output format is now a list of dicts
(row-major) instead of a dict of lists (column-major). I've added a
`python_col` option, which returns the dict of lists.
2. To align with PyTorch's expectations, the `torch` format is
now a list of tensors (row-major) instead of a 2D tensor (column-major).
I've added a `torch_col` option, which returns the 2D tensor instead.

Added tests for torch integration with Permutation

~~Leaving draft until https://github.com/lancedb/lancedb/pull/3013
merges as this is built on top of that~~
2026-02-12 13:41:14 -08:00

43 lines
1.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pytest
from lancedb.util import tbl_to_tensor
from lancedb.permutation import Permutation
torch = pytest.importorskip("torch")
def test_table_dataloader(mem_db):
    """Feed a table straight into a torch ``DataLoader``.

    The table has a single column ``"a"`` and is batched with
    ``tbl_to_tensor`` as the collate function. Every batch is checked to be
    a 2D tensor whose first dimension is 1 and whose second dimension is the
    batch size.

    Args:
        mem_db: Database handle exposing ``create_table``; presumably a
            pytest fixture defined outside this file (e.g. in conftest).
    """
    num_rows = 1000
    batch_size = 10

    table = mem_db.create_table("test_table", pa.table({"a": range(num_rows)}))
    dataloader = torch.utils.data.DataLoader(
        table, collate_fn=tbl_to_tensor, batch_size=batch_size, shuffle=True
    )

    num_batches = 0
    for batch in dataloader:
        assert batch.size(0) == 1
        assert batch.size(1) == batch_size
        num_batches += 1

    # The shape checks above live inside the loop, so without this guard the
    # test would pass vacuously if the loader produced no batches at all.
    # num_rows is an exact multiple of batch_size, so every batch is full.
    assert num_batches == num_rows // batch_size
def test_permutation_dataloader(mem_db):
    """Iterate a ``Permutation`` through a torch ``DataLoader`` in three formats.

    The same single-column table is wrapped in an identity permutation and
    loaded with:

    1. the permutation's default format and torch's default collation, where
       each batch is indexed by column name;
    2. the ``"torch"`` format and torch's default collation, where each batch
       is asserted to have shape ``(batch_size, 1)``;
    3. the ``"torch_col"`` format with a pass-through ``collate_fn``, where
       each batch is asserted to have shape ``(1, batch_size)``.

    Args:
        mem_db: Database handle exposing ``create_table``; presumably a
            pytest fixture defined outside this file (e.g. in conftest).
    """
    num_rows = 1000
    batch_size = 10
    # num_rows is an exact multiple of batch_size, so every batch is full.
    expected_batches = num_rows // batch_size

    table = mem_db.create_table("test_table", pa.table({"a": range(num_rows)}))

    # --- Default format -----------------------------------------------------
    permutation = Permutation.identity(table)
    dataloader = torch.utils.data.DataLoader(
        permutation, batch_size=batch_size, shuffle=True
    )
    num_batches = 0
    for batch in dataloader:
        assert batch["a"].size(0) == batch_size
        num_batches += 1
    # Guard against a vacuous pass: the in-loop assertion never runs if the
    # loader yields nothing.
    assert num_batches == expected_batches

    # --- "torch" format -----------------------------------------------------
    permutation = permutation.with_format("torch")
    dataloader = torch.utils.data.DataLoader(
        permutation, batch_size=batch_size, shuffle=True
    )
    num_batches = 0
    for batch in dataloader:
        assert batch.size(0) == batch_size
        assert batch.size(1) == 1
        num_batches += 1
    assert num_batches == expected_batches

    # --- "torch_col" format -------------------------------------------------
    permutation = permutation.with_format("torch_col")
    # The identity collate_fn hands back whatever the dataset produced for
    # the batch, so the shape asserted below is the dataset's own output.
    dataloader = torch.utils.data.DataLoader(
        permutation, collate_fn=lambda x: x, batch_size=batch_size, shuffle=True
    )
    num_batches = 0
    for batch in dataloader:
        assert batch.size(0) == 1
        assert batch.size(1) == batch_size
        num_batches += 1
    assert num_batches == expected_batches