mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-05 05:10:41 +00:00
Compare commits
1 Commits
ticket/324
...
ticket/324
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2fa4d04de7 |
@@ -9,7 +9,7 @@ import json
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from .util import batch_to_tensor, batch_to_tensor_dict, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -697,6 +697,7 @@ class Permutation:
|
||||
"pandas",
|
||||
"arrow",
|
||||
"torch",
|
||||
"torch_row",
|
||||
"torch_col",
|
||||
"polars",
|
||||
],
|
||||
@@ -712,8 +713,17 @@ class Permutation:
|
||||
- "python_col" - the batch will be a dict of lists (one entry per column)
|
||||
- "pandas" - the batch will be a pandas DataFrame
|
||||
- "arrow" - the batch will be a pyarrow RecordBatch
|
||||
- "torch" - the batch will be a list of tensors, one per row
|
||||
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
|
||||
- "torch" - a list of per-row dicts whose values are torch tensors. When
|
||||
used with ``torch.utils.data.DataLoader`` (default collate), each
|
||||
batch yielded by the loader is ``dict[str, Tensor]`` — one tensor per
|
||||
column, with column names preserved. This matches HuggingFace
|
||||
``dataset.set_format("torch")`` semantics.
|
||||
- "torch_row" - a list of 1-D torch tensors, one per row. Each tensor
|
||||
stacks all column values into a single row vector and column names
|
||||
are not preserved. (This was the previous "torch" behavior.)
|
||||
- "torch_col" - a 2-D torch tensor of shape ``(n_cols, n_rows)``. Column
|
||||
names are not preserved. Requires ``collate_fn=lambda x: x`` if used
|
||||
with ``DataLoader``.
|
||||
- "polars" - the batch will be a polars DataFrame
|
||||
|
||||
Conversion may or may not involve a data copy. Lance uses Arrow internally
|
||||
@@ -741,6 +751,8 @@ class Permutation:
|
||||
elif format == "arrow":
|
||||
return self.with_transform(Transforms.arrow2arrow)
|
||||
elif format == "torch":
|
||||
return self.with_transform(batch_to_tensor_dict)
|
||||
elif format == "torch_row":
|
||||
return self.with_transform(batch_to_tensor_rows)
|
||||
elif format == "torch_col":
|
||||
return self.with_transform(batch_to_tensor)
|
||||
@@ -762,35 +774,6 @@ class Permutation:
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
|
||||
def map(self, fn: Callable[[dict], dict]) -> "Permutation":
|
||||
"""
|
||||
Apply a function to each row of the permutation, like HuggingFace
|
||||
``dataset.map``.
|
||||
|
||||
``fn`` receives a single row as a ``dict[str, Any]`` and must return a
|
||||
``dict[str, Any]``. The transformed batch is exposed as a list of dicts
|
||||
(matching the default "python" format), so it works directly with
|
||||
the PyTorch DataLoader's default collate.
|
||||
|
||||
For column-oriented or zero-copy transforms, use
|
||||
[with_transform](#with_transform) which receives a ``pa.RecordBatch``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("memory:///")
|
||||
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(5)])
|
||||
>>> perm = Permutation.identity(tbl).map(lambda row: {"x": row["x"] * 2})
|
||||
>>> perm.fetch([0, 1, 2])
|
||||
[{'x': 0}, {'x': 2}, {'x': 4}]
|
||||
"""
|
||||
assert fn is not None, "fn is required"
|
||||
|
||||
def batch_transform(batch: pa.RecordBatch) -> list[dict]:
|
||||
return [fn(row) for row in batch.to_pylist()]
|
||||
|
||||
return self.with_transform(batch_transform)
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
Returns a single row from the permutation by offset
|
||||
|
||||
@@ -448,3 +448,29 @@ def batch_to_tensor_rows(batch: pa.RecordBatch):
|
||||
stacked = torch.tensor(numpy.column_stack(columns))
|
||||
rows = list(stacked.unbind(dim=0))
|
||||
return rows
|
||||
|
||||
|
||||
def batch_to_tensor_dict(batch: pa.RecordBatch):
|
||||
"""
|
||||
Convert a PyArrow RecordBatch into a list of per-row dicts whose values
|
||||
are PyTorch tensors.
|
||||
|
||||
Each column is converted to a tensor in one shot (zero-copy via DLPack
|
||||
when supported), then sliced per row. The result is shaped to work with
|
||||
PyTorch's default DataLoader collate, which stacks the per-row dicts
|
||||
into a single ``dict[str, Tensor]`` per batch — matching the
|
||||
HuggingFace ``dataset.set_format("torch")`` convention.
|
||||
|
||||
Fails if torch is not installed.
|
||||
Fails if a column's data type is not supported by PyTorch.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
columns: dict[str, "torch.Tensor"] = {}
|
||||
for i, name in enumerate(batch.schema.names):
|
||||
col = batch.column(i)
|
||||
try:
|
||||
columns[name] = torch.from_dlpack(col)
|
||||
except Exception:
|
||||
columns[name] = torch.tensor(col.to_numpy(zero_copy_only=False))
|
||||
n = batch.num_rows
|
||||
return [{name: t[i] for name, t in columns.items()} for i in range(n)]
|
||||
|
||||
@@ -950,14 +950,27 @@ def test_transform_fn(mem_db):
|
||||
try:
|
||||
import torch
|
||||
|
||||
# "torch" format: list of per-row dicts of tensors (HF-compatible).
|
||||
torch_result = list(
|
||||
permutation.with_format("torch").iter(10, skip_last_batch=False)
|
||||
)[0]
|
||||
assert isinstance(torch_result, list)
|
||||
assert len(torch_result) == 10
|
||||
assert isinstance(torch_result[0], torch.Tensor)
|
||||
assert torch_result[0].shape == (2,)
|
||||
assert torch_result[0].dtype == torch.int64
|
||||
assert isinstance(torch_result[0], dict)
|
||||
assert set(torch_result[0].keys()) == {"id", "value"}
|
||||
assert isinstance(torch_result[0]["id"], torch.Tensor)
|
||||
assert torch_result[0]["id"].shape == ()
|
||||
assert torch_result[0]["id"].dtype == torch.int64
|
||||
|
||||
# "torch_row" format: list of 1-D row tensors (previous "torch" behavior).
|
||||
torch_row_result = list(
|
||||
permutation.with_format("torch_row").iter(10, skip_last_batch=False)
|
||||
)[0]
|
||||
assert isinstance(torch_row_result, list)
|
||||
assert len(torch_row_result) == 10
|
||||
assert isinstance(torch_row_result[0], torch.Tensor)
|
||||
assert torch_row_result[0].shape == (2,)
|
||||
assert torch_row_result[0].dtype == torch.int64
|
||||
except ImportError:
|
||||
# Skip check if torch is not installed
|
||||
pass
|
||||
@@ -1095,43 +1108,3 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
|
||||
def test_map_basic(mem_db):
|
||||
"""map() applies fn per row and yields list[dict]."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(10), "y": range(10, 20)})
|
||||
)
|
||||
perm = Permutation.identity(tbl).map(lambda row: {"sum": row["x"] + row["y"]})
|
||||
|
||||
rows = perm.__getitems__([0, 1, 2])
|
||||
assert isinstance(rows, list)
|
||||
assert rows == [{"sum": 10}, {"sum": 12}, {"sum": 14}]
|
||||
|
||||
|
||||
def test_map_in_iter(mem_db):
|
||||
"""map() integrates with iter() and produces list-of-dicts batches."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10)}))
|
||||
perm = (
|
||||
Permutation.identity(tbl)
|
||||
.map(lambda row: {"y": row["x"] * 2})
|
||||
.with_batch_size(5)
|
||||
)
|
||||
|
||||
batches = list(perm.iter(5, skip_last_batch=False))
|
||||
assert len(batches) == 2
|
||||
assert batches[0] == [{"y": i * 2} for i in range(5)]
|
||||
assert batches[1] == [{"y": i * 2} for i in range(5, 10)]
|
||||
|
||||
|
||||
def test_map_can_add_columns(mem_db):
|
||||
"""map() can add or change keys in the row dict."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(3)}))
|
||||
perm = Permutation.identity(tbl).map(
|
||||
lambda row: {"x": row["x"], "doubled": row["x"] * 2}
|
||||
)
|
||||
assert perm.__getitems__([0, 1, 2]) == [
|
||||
{"x": 0, "doubled": 0},
|
||||
{"x": 1, "doubled": 2},
|
||||
{"x": 2, "doubled": 4},
|
||||
]
|
||||
|
||||
@@ -27,8 +27,18 @@ def test_permutation_dataloader(mem_db):
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
|
||||
# New "torch" format: per-row dicts of tensors, default collate yields
|
||||
# dict[str, Tensor] (HuggingFace style).
|
||||
permutation = permutation.with_format("torch")
|
||||
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, dict)
|
||||
assert "a" in batch
|
||||
assert batch["a"].size() == (10,)
|
||||
|
||||
# Previous "torch" semantics is preserved under the "torch_row" name.
|
||||
permutation = permutation.with_format("torch_row")
|
||||
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 10
|
||||
assert batch.size(1) == 1
|
||||
|
||||
Reference in New Issue
Block a user