diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 04adf38bc..5d133c309 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -9,7 +9,7 @@ import json from ._lancedb import async_permutation_builder, PermutationReader from .table import LanceTable from .background_loop import LOOP -from .util import batch_to_tensor +from .util import batch_to_tensor, batch_to_tensor_rows from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -333,7 +333,11 @@ class Transforms: """ @staticmethod - def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]: + def arrow2python(batch: pa.RecordBatch) -> list[dict[str, Any]]: + return batch.to_pylist() + + @staticmethod + def arrow2pythoncol(batch: pa.RecordBatch) -> dict[str, list[Any]]: return batch.to_pydict() @staticmethod @@ -687,7 +691,17 @@ class Permutation: return def with_format( - self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"] + self, + format: Literal[ + "numpy", + "python", + "python_col", + "pandas", + "arrow", + "torch", + "torch_col", + "polars", + ], ) -> "Permutation": """ Set the format for batches @@ -696,16 +710,18 @@ class Permutation: The format can be one of: - "numpy" - the batch will be a dict of numpy arrays (one per column) - - "python" - the batch will be a dict of lists (one per column) + - "python" - the batch will be a list of dicts (one per row) + - "python_col" - the batch will be a dict of lists (one entry per column) - "pandas" - the batch will be a pandas DataFrame - "arrow" - the batch will be a pyarrow RecordBatch - - "torch" - the batch will be a two dimensional torch tensor + - "torch" - the batch will be a list of tensors, one per row + - "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns) - "polars" - the batch will be a polars DataFrame Conversion may or may not involve a data copy. Lance uses Arrow internally - and so it is able to zero-copy to the arrow and polars. + and so it is able to zero-copy to the arrow and polars formats. - Conversion to torch will be zero-copy but will only support a subset of data + Conversion to torch_col will be zero-copy but will only support a subset of data types (numeric types). Conversion to numpy and/or pandas will typically be zero-copy for numeric @@ -718,6 +734,8 @@ class Permutation: assert format is not None, "format is required" if format == "python": return self.with_transform(Transforms.arrow2python) + if format == "python_col": + return self.with_transform(Transforms.arrow2pythoncol) elif format == "numpy": return self.with_transform(Transforms.arrow2numpy) elif format == "pandas": @@ -725,6 +743,8 @@ class Permutation: elif format == "arrow": return self.with_transform(Transforms.arrow2arrow) elif format == "torch": + return self.with_transform(batch_to_tensor_rows) + elif format == "torch_col": return self.with_transform(batch_to_tensor) elif format == "polars": return self.with_transform(Transforms.arrow2polars()) diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 8084cbd1b..a3666c75c 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -419,3 +419,22 @@ def batch_to_tensor(batch: pa.RecordBatch): """ torch = attempt_import_or_raise("torch", "torch") return torch.stack([torch.from_dlpack(col) for col in batch.columns]) + + +def batch_to_tensor_rows(batch: pa.RecordBatch): + """ + Convert a PyArrow RecordBatch to a list of PyTorch Tensor, one per row + + Each column is converted to a tensor (using zero-copy via DLPack) + and the columns are then stacked into a single tensor. The 2D tensor + is then converted to a list of tensors, one per row + + Fails if torch or numpy is not installed. + Fails if a column's data type is not supported by PyTorch. + """ + torch = attempt_import_or_raise("torch", "torch") + numpy = attempt_import_or_raise("numpy", "numpy") + columns = [col.to_numpy(zero_copy_only=False) for col in batch.columns] + stacked = torch.tensor(numpy.column_stack(columns)) + rows = list(stacked.unbind(dim=0)) + return rows diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 1a4d4a5dd..0223b829c 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -664,23 +664,20 @@ def test_iter_basic(some_permutation: Permutation): expected_batches = (950 + batch_size - 1) // batch_size # ceiling division assert len(batches) == expected_batches - # Check that all batches are dicts (default python format) - assert all(isinstance(batch, dict) for batch in batches) + # Check that all batches are lists of dicts (default python format) + assert all(isinstance(batch, list) for batch in batches) # Check that batches have the correct structure for batch in batches: - assert "id" in batch - assert "value" in batch - assert isinstance(batch["id"], list) - assert isinstance(batch["value"], list) + assert "id" in batch[0] + assert "value" in batch[0] # Check that all batches except the last have the correct size for batch in batches[:-1]: - assert len(batch["id"]) == batch_size - assert len(batch["value"]) == batch_size + assert len(batch) == batch_size # Last batch might be smaller - assert len(batches[-1]["id"]) <= batch_size + assert len(batches[-1]) <= batch_size def test_iter_skip_last_batch(some_permutation: Permutation): @@ -699,11 +696,11 @@ def test_iter_skip_last_batch(some_permutation: Permutation): if 950 % batch_size != 0: assert len(batches_without_skip) == num_full_batches + 1 # Last batch should be smaller - assert len(batches_without_skip[-1]["id"]) == 950 % batch_size + assert len(batches_without_skip[-1]) == 950 % batch_size # All batches with skip_last_batch should be full size for batch in batches_with_skip: - assert len(batch["id"]) == batch_size + assert len(batch) == batch_size def test_iter_different_batch_sizes(some_permutation: Permutation): @@ -720,12 +717,12 @@ def test_iter_different_batch_sizes(some_permutation: Permutation): # Test with batch size equal to total rows single_batch = list(some_permutation.iter(950, skip_last_batch=False)) assert len(single_batch) == 1 - assert len(single_batch[0]["id"]) == 950 + assert len(single_batch[0]) == 950 # Test with batch size larger than total rows oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False)) assert len(oversized_batch) == 1 - assert len(oversized_batch[0]["id"]) == 950 + assert len(oversized_batch[0]) == 950 def test_dunder_iter(some_permutation: Permutation): @@ -738,15 +735,13 @@ def test_dunder_iter(some_permutation: Permutation): # All batches should be full size for batch in batches: - assert len(batch["id"]) == 100 - assert len(batch["value"]) == 100 + assert len(batch) == 100 some_permutation = some_permutation.with_batch_size(400) batches = list(some_permutation) assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True for batch in batches: - assert len(batch["id"]) == 400 - assert len(batch["value"]) == 400 + assert len(batch) == 400 def test_iter_with_different_formats(some_permutation: Permutation): @@ -761,7 +756,7 @@ def test_iter_with_different_formats(some_permutation: Permutation): # Test with python format (default) python_perm = some_permutation.with_format("python") python_batches = list(python_perm.iter(batch_size, skip_last_batch=False)) - assert all(isinstance(batch, dict) for batch in python_batches) + assert all(isinstance(batch, list) for batch in python_batches) # Test with pandas format pandas_perm = some_permutation.with_format("pandas") @@ -780,8 +775,8 @@ def test_iter_with_column_selection(some_permutation: Permutation): # Check that batches only contain the id column for batch in batches: - assert "id" in batch - assert "value" not in batch + assert "id" in batch[0] + assert "value" not in batch[0] def test_iter_with_column_rename(some_permutation: Permutation): @@ -791,9 +786,9 @@ def test_iter_with_column_rename(some_permutation: Permutation): # Check that batches have the renamed column for batch in batches: - assert "id" in batch - assert "data" in batch - assert "value" not in batch + assert "id" in batch[0] + assert "data" in batch[0] + assert "value" not in batch[0] def test_iter_with_limit_offset(some_permutation: Permutation): @@ -812,14 +807,14 @@ def test_iter_with_limit_offset(some_permutation: Permutation): assert len(limit_batches) == 5 no_skip = some_permutation.iter(101, skip_last_batch=False) - row_100 = next(no_skip)["id"][100] + row_100 = next(no_skip)[100]["id"] # Test with both limit and offset limited_perm = some_permutation.with_skip(100).with_take(300) limited_batches = list(limited_perm.iter(100, skip_last_batch=False)) # Should have 3 batches (300 / 100) assert len(limited_batches) == 3 - assert limited_batches[0]["id"][0] == row_100 + assert limited_batches[0][0]["id"] == row_100 def test_iter_empty_permutation(mem_db): @@ -842,7 +837,7 @@ def test_iter_single_row(mem_db): # With skip_last_batch=False, should get one batch batches = list(perm.iter(10, skip_last_batch=False)) assert len(batches) == 1 - assert len(batches[0]["id"]) == 1 + assert len(batches[0]) == 1 # With skip_last_batch=True, should skip the single row (since it's < batch_size) batches_skip = list(perm.iter(10, skip_last_batch=True)) @@ -860,8 +855,7 @@ def test_identity_permutation(mem_db): batches = list(permutation.iter(10, skip_last_batch=False)) assert len(batches) == 1 - assert len(batches[0]["id"]) == 10 - assert len(batches[0]["value"]) == 10 + assert len(batches[0]) == 10 permutation = permutation.remove_columns(["value"]) assert permutation.num_columns == 1 @@ -904,10 +898,10 @@ def test_transform_fn(mem_db): py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[ 0 ] - assert len(py_result) == 2 - assert len(py_result["id"]) == 10 - assert len(py_result["value"]) == 10 - assert isinstance(py_result, dict) + assert len(py_result) == 10 + assert "id" in py_result[0] + assert "value" in py_result[0] + assert isinstance(py_result, list) try: import torch @@ -915,9 +909,11 @@ def test_transform_fn(mem_db): torch_result = list( permutation.with_format("torch").iter(10, skip_last_batch=False) )[0] - assert torch_result.shape == (2, 10) - assert torch_result.dtype == torch.int64 - assert isinstance(torch_result, torch.Tensor) + assert isinstance(torch_result, list) + assert len(torch_result) == 10 + assert isinstance(torch_result[0], torch.Tensor) + assert torch_result[0].shape == (2,) + assert torch_result[0].dtype == torch.int64 except ImportError: # Skip check if torch is not installed pass @@ -950,17 +946,16 @@ def test_custom_transform(mem_db): def test_getitems_basic(some_permutation: Permutation): """Test __getitems__ returns correct rows by offset.""" result = some_permutation.__getitems__([0, 1, 2]) - assert isinstance(result, dict) - assert "id" in result - assert "value" in result - assert len(result["id"]) == 3 + assert isinstance(result, list) + assert "id" in result[0] + assert "value" in result[0] + assert len(result) == 3 def test_getitems_single_index(some_permutation: Permutation): """Test __getitems__ with a single index.""" result = some_permutation.__getitems__([0]) - assert len(result["id"]) == 1 - assert len(result["value"]) == 1 + assert len(result) == 1 def test_getitems_preserves_order(some_permutation: Permutation): @@ -970,38 +965,40 @@ def test_getitems_preserves_order(some_permutation: Permutation): # Get the same rows in reverse order reverse = some_permutation.__getitems__([4, 3, 2, 1, 0]) - assert forward["id"] == list(reversed(reverse["id"])) - assert forward["value"] == list(reversed(reverse["value"])) + assert [r["id"] for r in forward] == list(reversed([r["id"] for r in reverse])) + assert [r["value"] for r in forward] == list( + reversed([r["value"] for r in reverse]) + ) def test_getitems_non_contiguous(some_permutation: Permutation): """Test __getitems__ with non-contiguous indices.""" result = some_permutation.__getitems__([0, 10, 50, 100, 500]) - assert len(result["id"]) == 5 + assert len(result) == 5 # Each id/value pair should match what we'd get individually for i, offset in enumerate([0, 10, 50, 100, 500]): single = some_permutation.__getitems__([offset]) - assert result["id"][i] == single["id"][0] - assert result["value"][i] == single["value"][0] + assert result[i]["id"] == single[0]["id"] + assert result[i]["value"] == single[0]["value"] def test_getitems_with_column_selection(some_permutation: Permutation): """Test __getitems__ respects column selection.""" id_only = some_permutation.select_columns(["id"]) result = id_only.__getitems__([0, 1, 2]) - assert "id" in result - assert "value" not in result - assert len(result["id"]) == 3 + assert "id" in result[0] + assert "value" not in result[0] + assert len(result) == 3 def test_getitems_with_column_rename(some_permutation: Permutation): """Test __getitems__ respects column renames.""" renamed = some_permutation.rename_column("value", "data") result = renamed.__getitems__([0, 1]) - assert "data" in result - assert "value" not in result - assert len(result["data"]) == 2 + assert "data" in result[0] + assert "value" not in result[0] + assert len(result) == 2 def test_getitems_with_format(some_permutation: Permutation): @@ -1032,8 +1029,8 @@ def test_getitems_identity_permutation(mem_db): perm = Permutation.identity(tbl) result = perm.__getitems__([0, 5, 9]) - assert result["id"] == [0, 5, 9] - assert result["value"] == [0, 5, 9] + assert [r["id"] for r in result] == [0, 5, 9] + assert [r["value"] for r in result] == [0, 5, 9] def test_getitems_with_limit_offset(some_permutation: Permutation): @@ -1042,12 +1039,12 @@ def test_getitems_with_limit_offset(some_permutation: Permutation): # Should be able to access offsets within the limited range result = limited.__getitems__([0, 1, 199]) - assert len(result["id"]) == 3 + assert len(result) == 3 # The first item of the limited permutation should match offset 100 of original full_result = some_permutation.__getitems__([100]) limited_result = limited.__getitems__([0]) - assert limited_result["id"][0] == full_result["id"][0] + assert limited_result[0]["id"] == full_result[0]["id"] def test_getitems_invalid_offset(some_permutation: Permutation): diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py index b883fdf96..ef1c5e73b 100644 --- a/python/python/tests/test_torch.py +++ b/python/python/tests/test_torch.py @@ -4,6 +4,7 @@ import pyarrow as pa import pytest from lancedb.util import tbl_to_tensor +from lancedb.permutation import Permutation torch = pytest.importorskip("torch") @@ -16,3 +17,26 @@ def test_table_dataloader(mem_db): for batch in dataloader: assert batch.size(0) == 1 assert batch.size(1) == 10 + + +def test_permutation_dataloader(mem_db): + table = mem_db.create_table("test_table", pa.table({"a": range(1000)})) + + permutation = Permutation.identity(table) + dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True) + for batch in dataloader: + assert batch["a"].size(0) == 10 + + permutation = permutation.with_format("torch") + dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True) + for batch in dataloader: + assert batch.size(0) == 10 + assert batch.size(1) == 1 + + permutation = permutation.with_format("torch_col") + dataloader = torch.utils.data.DataLoader( + permutation, collate_fn=lambda x: x, batch_size=10, shuffle=True + ) + for batch in dataloader: + assert batch.size(0) == 1 + assert batch.size(1) == 10