From f08a9c685c7fa70a9d1944881406853fe5c9b855 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 29 Apr 2026 22:14:19 +0530 Subject: [PATCH] feat(python): add Permutation.map() for row-level transforms Adds a HuggingFace-style `map(fn)` method that applies fn to each row dict. This complements `with_transform` (which operates on `pa.RecordBatch`) by offering the more familiar per-row API for AI engineers. Closes lancedb/lancedb#3246 --- python/python/lancedb/permutation.py | 29 ++++++++++++++++++ python/python/tests/test_permutation.py | 40 +++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 724a0fd25..cf9d64778 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -762,6 +762,35 @@ class Permutation: assert transform is not None, "transform is required" return Permutation(self.reader, self.selection, self.batch_size, transform) + def map(self, fn: Callable[[dict], dict]) -> "Permutation": + """ + Apply a function to each row of the permutation, like HuggingFace + ``dataset.map``. + + ``fn`` receives a single row as a ``dict[str, Any]`` and must return a + ``dict[str, Any]``. The transformed batch is exposed as a list of dicts + (matching the default "python" format), so it works directly with + the PyTorch DataLoader's default collate. + + For column-oriented or zero-copy transforms, use + [with_transform](#with_transform) which receives a ``pa.RecordBatch``. + + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("memory:///") + >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(5)]) + >>> perm = Permutation.identity(tbl).map(lambda row: {"x": row["x"] * 2}) + >>> perm.fetch([0, 1, 2]) + [{'x': 0}, {'x': 2}, {'x': 4}] + """ + assert fn is not None, "fn is required" + + def batch_transform(batch: pa.RecordBatch) -> list[dict]: + return [fn(row) for row in batch.to_pylist()] + + return self.with_transform(batch_transform) + def __getitem__(self, index: int) -> Any: """ Returns a single row from the permutation by offset diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index bb92ba0ba..273367a50 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -1095,3 +1095,43 @@ def test_getitems_invalid_offset(some_permutation: Permutation): """Test __getitems__ with an out-of-range offset raises an error.""" with pytest.raises(Exception): some_permutation.__getitems__([999999]) + + +def test_map_basic(mem_db): + """map() applies fn per row and yields list[dict].""" + tbl = mem_db.create_table( + "test_table", pa.table({"x": range(10), "y": range(10, 20)}) + ) + perm = Permutation.identity(tbl).map(lambda row: {"sum": row["x"] + row["y"]}) + + rows = perm.__getitems__([0, 1, 2]) + assert isinstance(rows, list) + assert rows == [{"sum": 10}, {"sum": 12}, {"sum": 14}] + + +def test_map_in_iter(mem_db): + """map() integrates with iter() and produces list-of-dicts batches.""" + tbl = mem_db.create_table("test_table", pa.table({"x": range(10)})) + perm = ( + Permutation.identity(tbl) + .map(lambda row: {"y": row["x"] * 2}) + .with_batch_size(5) + ) + + batches = list(perm.iter(5, skip_last_batch=False)) + assert len(batches) == 2 + assert batches[0] == [{"y": i * 2} for i in range(5)] + assert batches[1] == [{"y": i * 2} for i in range(5, 10)] + + +def test_map_can_add_columns(mem_db): + """map() can add or change keys in the row dict.""" + tbl = mem_db.create_table("test_table", pa.table({"x": range(3)})) + perm = Permutation.identity(tbl).map( + lambda row: {"x": row["x"], "doubled": row["x"] * 2} + ) + assert perm.__getitems__([0, 1, 2]) == [ + {"x": 0, "doubled": 0}, + {"x": 1, "doubled": 2}, + {"x": 2, "doubled": 4}, + ]