mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
feat: improve Permutation pytorch integration (#3016)
This changes the output format of `Permutation` in some breaking ways, but I think the API is still new enough to be considered experimental. 1. To align with the expectations of both Hugging Face's `datasets` and torch, the default output format is now a list of dicts (row-major) instead of a dict of lists (column-major). I've added a `python_col` option which returns the dict of lists. 2. To align with PyTorch's expectations, the `torch` format is now a list of tensors (row-major) instead of a 2D tensor (column-major). I've added a `torch_col` option which returns the 2D tensor instead. Added tests for torch integration with `Permutation`. ~~Leaving draft until https://github.com/lancedb/lancedb/pull/3013 merges as this is built on top of that~~
This commit is contained in:
@@ -9,7 +9,7 @@ import json
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from .util import batch_to_tensor
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -333,7 +333,11 @@ class Transforms:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
|
||||
def arrow2python(batch: pa.RecordBatch) -> list[dict[str, Any]]:
|
||||
return batch.to_pylist()
|
||||
|
||||
@staticmethod
def arrow2pythoncol(batch: pa.RecordBatch) -> dict[str, list[Any]]:
    """Return the batch as a dict mapping column name to a list of values."""
    columns = batch.to_pydict()
    return columns
|
||||
|
||||
@staticmethod
|
||||
@@ -687,7 +691,17 @@ class Permutation:
|
||||
return
|
||||
|
||||
def with_format(
|
||||
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
|
||||
self,
|
||||
format: Literal[
|
||||
"numpy",
|
||||
"python",
|
||||
"python_col",
|
||||
"pandas",
|
||||
"arrow",
|
||||
"torch",
|
||||
"torch_col",
|
||||
"polars",
|
||||
],
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Set the format for batches
|
||||
@@ -696,16 +710,18 @@ class Permutation:
|
||||
|
||||
The format can be one of:
|
||||
- "numpy" - the batch will be a dict of numpy arrays (one per column)
|
||||
- "python" - the batch will be a dict of lists (one per column)
|
||||
- "python" - the batch will be a list of dicts (one per row)
|
||||
- "python_col" - the batch will be a dict of lists (one entry per column)
|
||||
- "pandas" - the batch will be a pandas DataFrame
|
||||
- "arrow" - the batch will be a pyarrow RecordBatch
|
||||
- "torch" - the batch will be a two dimensional torch tensor
|
||||
- "torch" - the batch will be a list of tensors, one per row
|
||||
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
|
||||
- "polars" - the batch will be a polars DataFrame
|
||||
|
||||
Conversion may or may not involve a data copy. Lance uses Arrow internally
|
||||
and so it is able to zero-copy to the arrow and polars.
|
||||
and so it is able to zero-copy to the arrow and polars formats.
|
||||
|
||||
Conversion to torch will be zero-copy but will only support a subset of data
|
||||
Conversion to torch_col will be zero-copy but will only support a subset of data
|
||||
types (numeric types).
|
||||
|
||||
Conversion to numpy and/or pandas will typically be zero-copy for numeric
|
||||
@@ -718,6 +734,8 @@ class Permutation:
|
||||
assert format is not None, "format is required"
|
||||
if format == "python":
|
||||
return self.with_transform(Transforms.arrow2python)
|
||||
if format == "python_col":
|
||||
return self.with_transform(Transforms.arrow2pythoncol)
|
||||
elif format == "numpy":
|
||||
return self.with_transform(Transforms.arrow2numpy)
|
||||
elif format == "pandas":
|
||||
@@ -725,6 +743,8 @@ class Permutation:
|
||||
elif format == "arrow":
|
||||
return self.with_transform(Transforms.arrow2arrow)
|
||||
elif format == "torch":
|
||||
return self.with_transform(batch_to_tensor_rows)
|
||||
elif format == "torch_col":
|
||||
return self.with_transform(batch_to_tensor)
|
||||
elif format == "polars":
|
||||
return self.with_transform(Transforms.arrow2polars())
|
||||
|
||||
@@ -419,3 +419,22 @@ def batch_to_tensor(batch: pa.RecordBatch):
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
return torch.stack([torch.from_dlpack(col) for col in batch.columns])
|
||||
|
||||
|
||||
def batch_to_tensor_rows(batch: pa.RecordBatch):
    """
    Convert a PyArrow RecordBatch to a list of PyTorch Tensors, one per row

    Each column is converted to a numpy array (copying when a zero-copy
    conversion is not possible) and the columns are stacked side by side
    into a single 2D array with one row per record. That array is copied
    into a 2D tensor, which is then split into a list of 1D tensors, one
    per row.

    This conversion is not zero-copy: both the numpy stacking step and the
    tensor construction allocate new storage. Because the columns are
    stacked into one array, every row tensor shares a single dtype.

    Fails if torch or numpy is not installed.
    Fails if a column's data type is not supported by PyTorch.
    """
    torch = attempt_import_or_raise("torch", "torch")
    numpy = attempt_import_or_raise("numpy", "numpy")
    # zero_copy_only=False lets pyarrow fall back to a copy for columns that
    # cannot be exposed as a numpy view.
    columns = [col.to_numpy(zero_copy_only=False) for col in batch.columns]
    # column_stack yields shape (num_rows, num_columns), so dim 0 indexes rows.
    stacked = torch.tensor(numpy.column_stack(columns))
    return list(stacked.unbind(dim=0))
|
||||
|
||||
@@ -664,23 +664,20 @@ def test_iter_basic(some_permutation: Permutation):
|
||||
expected_batches = (950 + batch_size - 1) // batch_size # ceiling division
|
||||
assert len(batches) == expected_batches
|
||||
|
||||
# Check that all batches are dicts (default python format)
|
||||
assert all(isinstance(batch, dict) for batch in batches)
|
||||
# Check that all batches are lists of dicts (default python format)
|
||||
assert all(isinstance(batch, list) for batch in batches)
|
||||
|
||||
# Check that batches have the correct structure
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "value" in batch
|
||||
assert isinstance(batch["id"], list)
|
||||
assert isinstance(batch["value"], list)
|
||||
assert "id" in batch[0]
|
||||
assert "value" in batch[0]
|
||||
|
||||
# Check that all batches except the last have the correct size
|
||||
for batch in batches[:-1]:
|
||||
assert len(batch["id"]) == batch_size
|
||||
assert len(batch["value"]) == batch_size
|
||||
assert len(batch) == batch_size
|
||||
|
||||
# Last batch might be smaller
|
||||
assert len(batches[-1]["id"]) <= batch_size
|
||||
assert len(batches[-1]) <= batch_size
|
||||
|
||||
|
||||
def test_iter_skip_last_batch(some_permutation: Permutation):
|
||||
@@ -699,11 +696,11 @@ def test_iter_skip_last_batch(some_permutation: Permutation):
|
||||
if 950 % batch_size != 0:
|
||||
assert len(batches_without_skip) == num_full_batches + 1
|
||||
# Last batch should be smaller
|
||||
assert len(batches_without_skip[-1]["id"]) == 950 % batch_size
|
||||
assert len(batches_without_skip[-1]) == 950 % batch_size
|
||||
|
||||
# All batches with skip_last_batch should be full size
|
||||
for batch in batches_with_skip:
|
||||
assert len(batch["id"]) == batch_size
|
||||
assert len(batch) == batch_size
|
||||
|
||||
|
||||
def test_iter_different_batch_sizes(some_permutation: Permutation):
|
||||
@@ -720,12 +717,12 @@ def test_iter_different_batch_sizes(some_permutation: Permutation):
|
||||
# Test with batch size equal to total rows
|
||||
single_batch = list(some_permutation.iter(950, skip_last_batch=False))
|
||||
assert len(single_batch) == 1
|
||||
assert len(single_batch[0]["id"]) == 950
|
||||
assert len(single_batch[0]) == 950
|
||||
|
||||
# Test with batch size larger than total rows
|
||||
oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
|
||||
assert len(oversized_batch) == 1
|
||||
assert len(oversized_batch[0]["id"]) == 950
|
||||
assert len(oversized_batch[0]) == 950
|
||||
|
||||
|
||||
def test_dunder_iter(some_permutation: Permutation):
|
||||
@@ -738,15 +735,13 @@ def test_dunder_iter(some_permutation: Permutation):
|
||||
|
||||
# All batches should be full size
|
||||
for batch in batches:
|
||||
assert len(batch["id"]) == 100
|
||||
assert len(batch["value"]) == 100
|
||||
assert len(batch) == 100
|
||||
|
||||
some_permutation = some_permutation.with_batch_size(400)
|
||||
batches = list(some_permutation)
|
||||
assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True
|
||||
for batch in batches:
|
||||
assert len(batch["id"]) == 400
|
||||
assert len(batch["value"]) == 400
|
||||
assert len(batch) == 400
|
||||
|
||||
|
||||
def test_iter_with_different_formats(some_permutation: Permutation):
|
||||
@@ -761,7 +756,7 @@ def test_iter_with_different_formats(some_permutation: Permutation):
|
||||
# Test with python format (default)
|
||||
python_perm = some_permutation.with_format("python")
|
||||
python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
|
||||
assert all(isinstance(batch, dict) for batch in python_batches)
|
||||
assert all(isinstance(batch, list) for batch in python_batches)
|
||||
|
||||
# Test with pandas format
|
||||
pandas_perm = some_permutation.with_format("pandas")
|
||||
@@ -780,8 +775,8 @@ def test_iter_with_column_selection(some_permutation: Permutation):
|
||||
|
||||
# Check that batches only contain the id column
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "value" not in batch
|
||||
assert "id" in batch[0]
|
||||
assert "value" not in batch[0]
|
||||
|
||||
|
||||
def test_iter_with_column_rename(some_permutation: Permutation):
|
||||
@@ -791,9 +786,9 @@ def test_iter_with_column_rename(some_permutation: Permutation):
|
||||
|
||||
# Check that batches have the renamed column
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "data" in batch
|
||||
assert "value" not in batch
|
||||
assert "id" in batch[0]
|
||||
assert "data" in batch[0]
|
||||
assert "value" not in batch[0]
|
||||
|
||||
|
||||
def test_iter_with_limit_offset(some_permutation: Permutation):
|
||||
@@ -812,14 +807,14 @@ def test_iter_with_limit_offset(some_permutation: Permutation):
|
||||
assert len(limit_batches) == 5
|
||||
|
||||
no_skip = some_permutation.iter(101, skip_last_batch=False)
|
||||
row_100 = next(no_skip)["id"][100]
|
||||
row_100 = next(no_skip)[100]["id"]
|
||||
|
||||
# Test with both limit and offset
|
||||
limited_perm = some_permutation.with_skip(100).with_take(300)
|
||||
limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
|
||||
# Should have 3 batches (300 / 100)
|
||||
assert len(limited_batches) == 3
|
||||
assert limited_batches[0]["id"][0] == row_100
|
||||
assert limited_batches[0][0]["id"] == row_100
|
||||
|
||||
|
||||
def test_iter_empty_permutation(mem_db):
|
||||
@@ -842,7 +837,7 @@ def test_iter_single_row(mem_db):
|
||||
# With skip_last_batch=False, should get one batch
|
||||
batches = list(perm.iter(10, skip_last_batch=False))
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]["id"]) == 1
|
||||
assert len(batches[0]) == 1
|
||||
|
||||
# With skip_last_batch=True, should skip the single row (since it's < batch_size)
|
||||
batches_skip = list(perm.iter(10, skip_last_batch=True))
|
||||
@@ -860,8 +855,7 @@ def test_identity_permutation(mem_db):
|
||||
|
||||
batches = list(permutation.iter(10, skip_last_batch=False))
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]["id"]) == 10
|
||||
assert len(batches[0]["value"]) == 10
|
||||
assert len(batches[0]) == 10
|
||||
|
||||
permutation = permutation.remove_columns(["value"])
|
||||
assert permutation.num_columns == 1
|
||||
@@ -904,10 +898,10 @@ def test_transform_fn(mem_db):
|
||||
py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
|
||||
0
|
||||
]
|
||||
assert len(py_result) == 2
|
||||
assert len(py_result["id"]) == 10
|
||||
assert len(py_result["value"]) == 10
|
||||
assert isinstance(py_result, dict)
|
||||
assert len(py_result) == 10
|
||||
assert "id" in py_result[0]
|
||||
assert "value" in py_result[0]
|
||||
assert isinstance(py_result, list)
|
||||
|
||||
try:
|
||||
import torch
|
||||
@@ -915,9 +909,11 @@ def test_transform_fn(mem_db):
|
||||
torch_result = list(
|
||||
permutation.with_format("torch").iter(10, skip_last_batch=False)
|
||||
)[0]
|
||||
assert torch_result.shape == (2, 10)
|
||||
assert torch_result.dtype == torch.int64
|
||||
assert isinstance(torch_result, torch.Tensor)
|
||||
assert isinstance(torch_result, list)
|
||||
assert len(torch_result) == 10
|
||||
assert isinstance(torch_result[0], torch.Tensor)
|
||||
assert torch_result[0].shape == (2,)
|
||||
assert torch_result[0].dtype == torch.int64
|
||||
except ImportError:
|
||||
# Skip check if torch is not installed
|
||||
pass
|
||||
@@ -950,17 +946,16 @@ def test_custom_transform(mem_db):
|
||||
def test_getitems_basic(some_permutation: Permutation):
|
||||
"""Test __getitems__ returns correct rows by offset."""
|
||||
result = some_permutation.__getitems__([0, 1, 2])
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "value" in result
|
||||
assert len(result["id"]) == 3
|
||||
assert isinstance(result, list)
|
||||
assert "id" in result[0]
|
||||
assert "value" in result[0]
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
def test_getitems_single_index(some_permutation: Permutation):
|
||||
"""Test __getitems__ with a single index."""
|
||||
result = some_permutation.__getitems__([0])
|
||||
assert len(result["id"]) == 1
|
||||
assert len(result["value"]) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_getitems_preserves_order(some_permutation: Permutation):
|
||||
@@ -970,38 +965,40 @@ def test_getitems_preserves_order(some_permutation: Permutation):
|
||||
# Get the same rows in reverse order
|
||||
reverse = some_permutation.__getitems__([4, 3, 2, 1, 0])
|
||||
|
||||
assert forward["id"] == list(reversed(reverse["id"]))
|
||||
assert forward["value"] == list(reversed(reverse["value"]))
|
||||
assert [r["id"] for r in forward] == list(reversed([r["id"] for r in reverse]))
|
||||
assert [r["value"] for r in forward] == list(
|
||||
reversed([r["value"] for r in reverse])
|
||||
)
|
||||
|
||||
|
||||
def test_getitems_non_contiguous(some_permutation: Permutation):
|
||||
"""Test __getitems__ with non-contiguous indices."""
|
||||
result = some_permutation.__getitems__([0, 10, 50, 100, 500])
|
||||
assert len(result["id"]) == 5
|
||||
assert len(result) == 5
|
||||
|
||||
# Each id/value pair should match what we'd get individually
|
||||
for i, offset in enumerate([0, 10, 50, 100, 500]):
|
||||
single = some_permutation.__getitems__([offset])
|
||||
assert result["id"][i] == single["id"][0]
|
||||
assert result["value"][i] == single["value"][0]
|
||||
assert result[i]["id"] == single[0]["id"]
|
||||
assert result[i]["value"] == single[0]["value"]
|
||||
|
||||
|
||||
def test_getitems_with_column_selection(some_permutation: Permutation):
|
||||
"""Test __getitems__ respects column selection."""
|
||||
id_only = some_permutation.select_columns(["id"])
|
||||
result = id_only.__getitems__([0, 1, 2])
|
||||
assert "id" in result
|
||||
assert "value" not in result
|
||||
assert len(result["id"]) == 3
|
||||
assert "id" in result[0]
|
||||
assert "value" not in result[0]
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
def test_getitems_with_column_rename(some_permutation: Permutation):
|
||||
"""Test __getitems__ respects column renames."""
|
||||
renamed = some_permutation.rename_column("value", "data")
|
||||
result = renamed.__getitems__([0, 1])
|
||||
assert "data" in result
|
||||
assert "value" not in result
|
||||
assert len(result["data"]) == 2
|
||||
assert "data" in result[0]
|
||||
assert "value" not in result[0]
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
def test_getitems_with_format(some_permutation: Permutation):
|
||||
@@ -1032,8 +1029,8 @@ def test_getitems_identity_permutation(mem_db):
|
||||
perm = Permutation.identity(tbl)
|
||||
|
||||
result = perm.__getitems__([0, 5, 9])
|
||||
assert result["id"] == [0, 5, 9]
|
||||
assert result["value"] == [0, 5, 9]
|
||||
assert [r["id"] for r in result] == [0, 5, 9]
|
||||
assert [r["value"] for r in result] == [0, 5, 9]
|
||||
|
||||
|
||||
def test_getitems_with_limit_offset(some_permutation: Permutation):
|
||||
@@ -1042,12 +1039,12 @@ def test_getitems_with_limit_offset(some_permutation: Permutation):
|
||||
|
||||
# Should be able to access offsets within the limited range
|
||||
result = limited.__getitems__([0, 1, 199])
|
||||
assert len(result["id"]) == 3
|
||||
assert len(result) == 3
|
||||
|
||||
# The first item of the limited permutation should match offset 100 of original
|
||||
full_result = some_permutation.__getitems__([100])
|
||||
limited_result = limited.__getitems__([0])
|
||||
assert limited_result["id"][0] == full_result["id"][0]
|
||||
assert limited_result[0]["id"] == full_result[0]["id"]
|
||||
|
||||
|
||||
def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.util import tbl_to_tensor
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
@@ -16,3 +17,26 @@ def test_table_dataloader(mem_db):
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
|
||||
|
||||
def test_permutation_dataloader(mem_db):
    """Feed a ``Permutation`` to ``torch.utils.data.DataLoader`` in three formats.

    Checks batch shapes for the default format, the "torch" format and the
    "torch_col" format.
    """
    source = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
    default_perm = Permutation.identity(source)

    # Default format: the collated batch is indexed by column name.
    loader = torch.utils.data.DataLoader(default_perm, batch_size=10, shuffle=True)
    for collated in loader:
        assert collated["a"].size(0) == 10

    # "torch" format: the collated batch is expected to be (rows, columns).
    row_perm = default_perm.with_format("torch")
    loader = torch.utils.data.DataLoader(row_perm, batch_size=10, shuffle=True)
    for collated in loader:
        assert collated.size(0) == 10
        assert collated.size(1) == 1

    # "torch_col" format: an identity collate_fn hands the batch through
    # untouched, and it is expected to be (columns, rows).
    col_perm = row_perm.with_format("torch_col")
    loader = torch.utils.data.DataLoader(
        col_perm, collate_fn=lambda x: x, batch_size=10, shuffle=True
    )
    for collated in loader:
        assert collated.size(0) == 1
        assert collated.size(1) == 10
|
||||
|
||||
Reference in New Issue
Block a user