From 049a689a1c059dba0bb496d00a63f9841cc0b5a8 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia
Date: Wed, 29 Apr 2026 22:13:42 +0530
Subject: [PATCH] feat(python): add public Permutation.fetch(indices) API

Adds a public method that mirrors __getitems__ for batch index access,
so users do not have to call a dunder directly when implementing custom
torch datasets.

Closes lancedb/lancedb#3243
---
 python/python/lancedb/permutation.py    | 19 +++++++++++++++++++
 python/python/tests/test_permutation.py | 20 ++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py
index 724a0fd25..9e04a0c6d 100644
--- a/python/python/lancedb/permutation.py
+++ b/python/python/lancedb/permutation.py
@@ -779,6 +779,25 @@ class Permutation:
         batch = LOOP.run(do_getitems())
         return self.transform_fn(batch)
 
+    def fetch(self, indices: list[int]) -> Any:
+        """
+        Fetch rows from the permutation by offset.
+
+        This is the public batch-access API. It returns the rows for the given
+        offsets in the same shape as configured by
+        [with_format](#with_format) / [with_transform](#with_transform).
+
+        Examples
+        --------
+        >>> import lancedb
+        >>> db = lancedb.connect("memory:///")
+        >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(10)])
+        >>> perm = Permutation.identity(tbl)
+        >>> perm.fetch([0, 5, 9])
+        [{'x': 0}, {'x': 5}, {'x': 9}]
+        """
+        return self.__getitems__(indices)
+
     @deprecated(details="Use with_skip instead")
     def skip(self, skip: int) -> "Permutation":
         """
diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py
index bb92ba0ba..f9b20a228 100644
--- a/python/python/tests/test_permutation.py
+++ b/python/python/tests/test_permutation.py
@@ -1095,3 +1095,23 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
     """Test __getitems__ with an out-of-range offset raises an error."""
     with pytest.raises(Exception):
         some_permutation.__getitems__([999999])
+
+
+def test_fetch_matches_getitems(some_permutation: Permutation):
+    """Public fetch() should be equivalent to __getitems__."""
+    indices = [0, 1, 2, 10, 100]
+    assert some_permutation.fetch(indices) == some_permutation.__getitems__(indices)
+
+
+def test_fetch_respects_format(some_permutation: Permutation):
+    """fetch() applies the configured format/transform."""
+    arrow_perm = some_permutation.with_format("arrow")
+    result = arrow_perm.fetch([0, 1, 2])
+    assert isinstance(result, pa.RecordBatch)
+    assert result.num_rows == 3
+
+
+def test_fetch_invalid_offset(some_permutation: Permutation):
+    """fetch() with an out-of-range offset raises an error."""
+    with pytest.raises(Exception):
+        some_permutation.fetch([999999])