mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 18:40:39 +00:00
feat(python): add Permutation.map() for row-level transforms
Adds a HuggingFace-style `map(fn)` method that applies fn to each row dict. This complements `with_transform` (which operates on `pa.RecordBatch`) by offering the more familiar per-row API for AI engineers. Closes lancedb/lancedb#3246
This commit is contained in:
@@ -762,6 +762,35 @@ class Permutation:
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
|
||||
def map(self, fn: Callable[[dict], dict]) -> "Permutation":
|
||||
"""
|
||||
Apply a function to each row of the permutation, like HuggingFace
|
||||
``dataset.map``.
|
||||
|
||||
``fn`` receives a single row as a ``dict[str, Any]`` and must return a
|
||||
``dict[str, Any]``. The transformed batch is exposed as a list of dicts
|
||||
(matching the default "python" format), so it works directly with
|
||||
the PyTorch DataLoader's default collate.
|
||||
|
||||
For column-oriented or zero-copy transforms, use
|
||||
[with_transform](#with_transform) which receives a ``pa.RecordBatch``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("memory:///")
|
||||
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(5)])
|
||||
>>> perm = Permutation.identity(tbl).map(lambda row: {"x": row["x"] * 2})
|
||||
>>> perm.fetch([0, 1, 2])
|
||||
[{'x': 0}, {'x': 2}, {'x': 4}]
|
||||
"""
|
||||
assert fn is not None, "fn is required"
|
||||
|
||||
def batch_transform(batch: pa.RecordBatch) -> list[dict]:
|
||||
return [fn(row) for row in batch.to_pylist()]
|
||||
|
||||
return self.with_transform(batch_transform)
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
Returns a single row from the permutation by offset
|
||||
|
||||
@@ -1095,3 +1095,43 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
|
||||
def test_map_basic(mem_db):
|
||||
"""map() applies fn per row and yields list[dict]."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(10), "y": range(10, 20)})
|
||||
)
|
||||
perm = Permutation.identity(tbl).map(lambda row: {"sum": row["x"] + row["y"]})
|
||||
|
||||
rows = perm.__getitems__([0, 1, 2])
|
||||
assert isinstance(rows, list)
|
||||
assert rows == [{"sum": 10}, {"sum": 12}, {"sum": 14}]
|
||||
|
||||
|
||||
def test_map_in_iter(mem_db):
|
||||
"""map() integrates with iter() and produces list-of-dicts batches."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10)}))
|
||||
perm = (
|
||||
Permutation.identity(tbl)
|
||||
.map(lambda row: {"y": row["x"] * 2})
|
||||
.with_batch_size(5)
|
||||
)
|
||||
|
||||
batches = list(perm.iter(5, skip_last_batch=False))
|
||||
assert len(batches) == 2
|
||||
assert batches[0] == [{"y": i * 2} for i in range(5)]
|
||||
assert batches[1] == [{"y": i * 2} for i in range(5, 10)]
|
||||
|
||||
|
||||
def test_map_can_add_columns(mem_db):
|
||||
"""map() can add or change keys in the row dict."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(3)}))
|
||||
perm = Permutation.identity(tbl).map(
|
||||
lambda row: {"x": row["x"], "doubled": row["x"] * 2}
|
||||
)
|
||||
assert perm.__getitems__([0, 1, 2]) == [
|
||||
{"x": 0, "doubled": 0},
|
||||
{"x": 1, "doubled": 2},
|
||||
{"x": 2, "doubled": 4},
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user