mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-26 18:40:42 +00:00
This changes the output format of `Permutation` in some breaking ways, but I think the API is still new enough to be considered experimental. 1. To align with the expectations of both Hugging Face datasets and torch, the default output format is now a list of dicts (row-major) instead of a dict of lists (column-major). I've added a `python_col` option which will return the dict of lists. 2. To align with PyTorch's expectations, the `torch` format is now a list of tensors (row-major) instead of a 2D tensor (column-major). I've added a `torch_col` option which will return the 2D tensor instead. Added tests for torch integration with `Permutation`. ~~Leaving draft until https://github.com/lancedb/lancedb/pull/3013 merges as this is built on top of that~~
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
import pyarrow as pa
|
|
import pytest
|
|
from lancedb.util import tbl_to_tensor
|
|
from lancedb.permutation import Permutation
|
|
|
|
# torch is an optional dependency: importorskip returns the imported module,
# or skips every test in this file when torch cannot be imported.
torch = pytest.importorskip("torch")
|
|
|
|
|
|
def test_table_dataloader(mem_db):
    """Feed a LanceDB table directly into a torch ``DataLoader``.

    A 1000-row table with a single column ``a`` is batched ten rows at a
    time, using ``tbl_to_tensor`` as the collate function. Every batch
    must report size 1 along dim 0 and size 10 along dim 1.
    """
    source = pa.table({"a": range(1000)})
    tbl = mem_db.create_table("test_table", source)

    loader_options = {
        "collate_fn": tbl_to_tensor,
        "batch_size": 10,
        "shuffle": True,
    }
    loader = torch.utils.data.DataLoader(tbl, **loader_options)

    for collated in loader:
        assert collated.size(0) == 1
        assert collated.size(1) == 10
|
|
|
|
|
|
def test_permutation_dataloader(mem_db):
    """Drive a torch ``DataLoader`` from a ``Permutation`` in three formats.

    The same 1000-row, single-column table is read in batches of ten:

    * default format -- the collated batch is indexable by column name and
      ``batch["a"]`` has size 10 along dim 0;
    * ``"torch"`` -- the collated batch has size 10 along dim 0 and size 1
      along dim 1;
    * ``"torch_col"`` -- with a pass-through collate function the batch has
      size 1 along dim 0 and size 10 along dim 1.
    """
    DataLoader = torch.utils.data.DataLoader

    def passthrough(items):
        # Hand the fetched batch back untouched instead of collating it.
        return items

    tbl = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

    # Default output format, collated by the DataLoader's default collate_fn.
    perm = Permutation.identity(tbl)
    for collated in DataLoader(perm, batch_size=10, shuffle=True):
        assert collated["a"].size(0) == 10

    # "torch" format, still using the default collate_fn.
    perm = perm.with_format("torch")
    for collated in DataLoader(perm, batch_size=10, shuffle=True):
        assert collated.size(0) == 10
        assert collated.size(1) == 1

    # "torch_col" format, with collation bypassed.
    perm = perm.with_format("torch_col")
    for collated in DataLoader(
        perm, collate_fn=passthrough, batch_size=10, shuffle=True
    ):
        assert collated.size(0) == 1
        assert collated.size(1) == 10
|