mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
feat(python): add public Permutation.fetch(indices) API
Adds a public method that mirrors `__getitems__` for batch index access, so users do not have to call a dunder method directly when implementing custom torch datasets. Closes lancedb/lancedb#3243.
This commit is contained in:
@@ -779,6 +779,25 @@ class Permutation:
|
||||
batch = LOOP.run(do_getitems())
|
||||
return self.transform_fn(batch)
|
||||
|
||||
def fetch(self, indices: list[int]) -> Any:
    """
    Return the rows located at the given offsets of this permutation.

    Public counterpart of ``__getitems__``: the call is forwarded unchanged,
    so the result comes back in the shape configured through
    [with_format](#with_format) / [with_transform](#with_transform).

    Parameters
    ----------
    indices: list[int]
        Offsets into the permutation identifying the rows to fetch.

    Returns
    -------
    Any
        Whatever ``__getitems__`` produces for ``indices``.

    Examples
    --------
    >>> import lancedb
    >>> db = lancedb.connect("memory:///")
    >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(10)])
    >>> perm = Permutation.identity(tbl)
    >>> perm.fetch([0, 5, 9])
    [{'x': 0}, {'x': 5}, {'x': 9}]
    """
    # Delegate to the dunder so both entry points share a single code path.
    rows = self.__getitems__(indices)
    return rows
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
"""
|
||||
|
||||
@@ -1095,3 +1095,23 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
|
||||
def test_fetch_matches_getitems(some_permutation: Permutation):
    """fetch() must return exactly what __getitems__ returns for the same offsets."""
    offsets = [0, 1, 2, 10, 100]
    via_public_api = some_permutation.fetch(offsets)
    via_dunder = some_permutation.__getitems__(offsets)
    assert via_public_api == via_dunder
|
||||
|
||||
|
||||
def test_fetch_respects_format(some_permutation: Permutation):
    """A permutation configured with the arrow format yields a RecordBatch from fetch()."""
    batch = some_permutation.with_format("arrow").fetch([0, 1, 2])
    assert isinstance(batch, pa.RecordBatch)
    # One row per requested offset.
    assert batch.num_rows == 3
|
||||
|
||||
|
||||
def test_fetch_invalid_offset(some_permutation: Permutation):
    """Requesting an out-of-range offset through fetch() must raise."""
    out_of_range = [999999]
    with pytest.raises(Exception):
        some_permutation.fetch(out_of_range)
|
||||
|
||||
Reference in New Issue
Block a user