feat(python): add Permutation.from_table() with chainable HF/torch-style API

Hides the two-step `permutation_builder(t).shuffle().execute()` /
`Permutation.from_tables(t, perm_tbl)` dance behind a single chainable
entry point that ML engineers already expect from HuggingFace and
PyTorch:

  perm = Permutation.from_table(table).shuffle(seed=42)

Builder operations (shuffle, filter, split_*) are accumulated lazily and
the underlying permutation table is built only once on first read.
Subsequent read / transform / format calls are forwarded transparently
to the materialized Permutation.

Closes lancedb/lancedb#3244
This commit is contained in:
Ayush Chaurasia
2026-04-29 22:17:01 +05:30
parent 25dfe2cfd4
commit cb57b7655e
2 changed files with 168 additions and 0 deletions

View File

@@ -425,6 +425,35 @@ class Permutation:
"""
return Permutation.from_tables(table, None, None)
@classmethod
def from_table(cls, table: LanceTable) -> "_PermutationFromTable":
    """
    Start a lazily-built permutation over ``table``.

    The returned wrapper supports HuggingFace / PyTorch-style chaining of
    the builder operations ``shuffle``, ``filter`` and ``split_*``. It
    replaces the explicit two-step sequence
    ``permutation_builder(table).shuffle().execute()`` followed by
    ``Permutation.from_tables(table, perm_tbl)``.

    Builder calls are only recorded; nothing is executed until the first
    read (any non-builder attribute access, iteration, ``len()`` or
    indexing). A chain of builder calls therefore costs a single
    ``execute()`` rather than one per step. From then on the usual
    ``Permutation`` methods (``select_columns``, ``with_format``, ``map``,
    ``__iter__``, ``fetch``, ``num_rows``, ...) are forwarded to the
    materialized permutation.

    Parameters
    ----------
    table
        The base table the permutation is built over.

    Returns
    -------
    _PermutationFromTable
        A wrapper with no pending operations recorded yet.

    Examples
    --------
    >>> import lancedb
    >>> db = lancedb.connect("memory:///")
    >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(100)])
    >>> perm = Permutation.from_table(tbl).shuffle(seed=42)
    >>> perm.num_rows
    100
    """
    lazy_permutation = _PermutationFromTable(table)
    return lazy_permutation
@classmethod
def from_tables(
cls,
@@ -842,3 +871,85 @@ class Permutation:
Repeat the permutation `times` times
"""
raise Exception("with_repeat is not yet implemented")
class _PermutationFromTable:
"""
Result of [Permutation.from_table](#from_table).
Records pending builder operations (``shuffle``, ``filter``, ``split_*``)
and lazily executes them on first read. After materialization all
Permutation reads / transforms (``select_columns``, ``with_format``,
``map``, ``__iter__``, ``fetch``, ``num_rows``, ...) are forwarded to the
underlying [Permutation].
"""
__slots__ = ("_base_table", "_pending_ops", "_materialized")
def __init__(
self,
base_table: LanceTable,
_pending_ops: Optional[list[tuple[str, tuple, dict]]] = None,
):
self._base_table = base_table
self._pending_ops: list[tuple[str, tuple, dict]] = (
list(_pending_ops) if _pending_ops is not None else []
)
self._materialized: Optional[Permutation] = None
def _with_op(
self, name: str, args: tuple, kwargs: dict
) -> "_PermutationFromTable":
return _PermutationFromTable(
self._base_table, _pending_ops=self._pending_ops + [(name, args, kwargs)]
)
def shuffle(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("shuffle", args, kwargs)
def filter(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("filter", args, kwargs)
def split_random(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("split_random", args, kwargs)
def split_sequential(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("split_sequential", args, kwargs)
def split_hash(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("split_hash", args, kwargs)
def split_calculated(self, *args, **kwargs) -> "_PermutationFromTable":
return self._with_op("split_calculated", args, kwargs)
def _materialize(self) -> Permutation:
if self._materialized is None:
if self._pending_ops:
builder = permutation_builder(self._base_table)
for name, args, kwargs in self._pending_ops:
builder = getattr(builder, name)(*args, **kwargs)
perm_tbl = builder.execute()
self._materialized = Permutation.from_tables(
self._base_table, perm_tbl
)
else:
self._materialized = Permutation.identity(self._base_table)
return self._materialized
def __getattr__(self, name: str) -> Any:
# Avoid recursion on dunder/private state.
if name.startswith("_"):
raise AttributeError(name)
return getattr(self._materialize(), name)
def __iter__(self):
return iter(self._materialize())
def __len__(self) -> int:
return len(self._materialize())
def __getitem__(self, index):
return self._materialize()[index]
def __getitems__(self, indices):
return self._materialize().__getitems__(indices)

View File

@@ -1095,3 +1095,60 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
"""Test __getitems__ with an out-of-range offset raises an error."""
with pytest.raises(Exception):
some_permutation.__getitems__([999999])
def test_from_table_identity(mem_db):
    """With no builder ops queued, from_table exposes every row and column."""
    source = mem_db.create_table("tbl", pa.table({"x": range(10)}))
    lazy_perm = Permutation.from_table(source)
    # Both reads are forwarded to the materialized permutation.
    assert lazy_perm.num_rows == 10
    assert lazy_perm.column_names == ["x"]
def test_from_table_shuffle_seeded(mem_db):
    """A seeded shuffle reorders the rows and repeats for the same seed."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))

    first = Permutation.from_table(source).shuffle(seed=42)
    first_order = [row["x"] for row in first.__getitems__(list(range(100)))]
    # Shuffling keeps exactly the same values ...
    assert sorted(first_order) == list(range(100))
    # ... but not in their original order.
    assert first_order != list(range(100))

    # A fresh permutation built with the same seed must give the same order.
    second = Permutation.from_table(source).shuffle(seed=42)
    second_order = [row["x"] for row in second.__getitems__(list(range(100)))]
    assert first_order == second_order
def test_from_table_filter(mem_db):
    """A filter predicate restricts the permutation to the matching rows."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    filtered = Permutation.from_table(source).filter("x < 25")
    # x covers 0..99, so exactly 25 values satisfy the predicate.
    assert filtered.num_rows == 25
def test_from_table_chained_ops(mem_db):
    """A filter followed by a seeded shuffle keeps exactly the filtered rows."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    lazy_perm = Permutation.from_table(source)
    lazy_perm = lazy_perm.filter("x >= 50")
    lazy_perm = lazy_perm.shuffle(seed=7)
    assert lazy_perm.num_rows == 50
    values = [row["x"] for row in lazy_perm.__getitems__(list(range(50)))]
    # Order is shuffled, so compare the sorted values against 50..99.
    assert sorted(values) == list(range(50, 100))
def test_from_table_forwards_read_methods(mem_db):
    """Non-builder methods such as select_columns are forwarded by from_table()."""
    columns = {"x": range(10), "y": range(10)}
    source = mem_db.create_table("tbl", pa.table(columns))
    projected = Permutation.from_table(source).select_columns(["x"])
    assert projected.column_names == ["x"]
def test_from_table_split_random(mem_db):
    """After split_random, reading yields roughly the first ratio's share of rows."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    lazy_perm = Permutation.from_table(source).split_random(
        ratios=[0.3, 0.7], seed=1
    )
    # Relies on split 0 being the one read by default: 0.3 of 100 is ~30 rows.
    row_count = lazy_perm.num_rows
    assert 25 <= row_count <= 35