mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
feat(python): add Permutation.from_table() with chainable HF/torch-style API
Hides the two-step `permutation_builder(t).shuffle().execute()` / `Permutation.from_tables(t, perm_tbl)` dance behind a single chainable entry point that ML engineers already expect from HuggingFace and PyTorch: `perm = Permutation.from_table(table).shuffle(seed=42)`. Builder operations (`shuffle`, `filter`, `split_*`) are accumulated lazily, and the underlying permutation table is built only once, on first read. Subsequent read / transform / format calls are forwarded transparently to the materialized Permutation. Closes lancedb/lancedb#3244
This commit is contained in:
@@ -425,6 +425,35 @@ class Permutation:
|
||||
"""
|
||||
return Permutation.from_tables(table, None, None)
|
||||
|
||||
@classmethod
def from_table(cls, table: LanceTable) -> "_PermutationFromTable":
    """
    Wrap a base table in a lazy, chainable permutation.

    The returned object offers HuggingFace / PyTorch-style chaining of the
    builder operations ``shuffle``, ``filter``, and ``split_*``, replacing
    the explicit two-step sequence of
    ``permutation_builder(table).shuffle().execute()`` followed by
    ``Permutation.from_tables(table, perm_tbl)``.

    Builder calls are only recorded. The permutation table is built on the
    first read, so a chain of builder calls does not trigger an
    ``execute()`` per step. From that read onwards, ``Permutation`` methods
    such as ``select_columns``, ``with_format``, ``map``, ``__iter__``,
    ``fetch``, and ``num_rows`` are forwarded to the built permutation.

    Examples
    --------
    >>> import lancedb
    >>> db = lancedb.connect("memory:///")
    >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(100)])
    >>> perm = Permutation.from_table(tbl).shuffle(seed=42)
    >>> perm.num_rows
    100
    """
    lazy_permutation = _PermutationFromTable(table)
    return lazy_permutation
|
||||
|
||||
@classmethod
|
||||
def from_tables(
|
||||
cls,
|
||||
@@ -842,3 +871,85 @@ class Permutation:
|
||||
Repeat the permutation `times` times
|
||||
"""
|
||||
raise Exception("with_repeat is not yet implemented")
|
||||
|
||||
|
||||
class _PermutationFromTable:
    """
    Lazy, chainable wrapper returned by [Permutation.from_table](#from_table).

    Every builder call (``shuffle``, ``filter``, ``split_*``) returns a new
    wrapper carrying one more recorded operation; nothing is executed at
    that point. The first read replays the recorded operations through
    ``permutation_builder`` and caches the resulting [Permutation] on this
    wrapper. Reads and transforms (``select_columns``, ``with_format``,
    ``map``, ``__iter__``, ``fetch``, ``num_rows``, ...) are then delegated
    to that cached permutation.
    """

    __slots__ = ("_base_table", "_pending_ops", "_materialized")

    def __init__(
        self,
        base_table: LanceTable,
        _pending_ops: Optional[list[tuple[str, tuple, dict]]] = None,
    ):
        # Copy the incoming operations so wrappers never share one list.
        recorded: list[tuple[str, tuple, dict]] = []
        if _pending_ops is not None:
            recorded.extend(_pending_ops)

        self._base_table = base_table
        self._pending_ops = recorded
        # Filled in by _materialize() on the first read.
        self._materialized: Optional[Permutation] = None

    def _with_op(
        self, name: str, args: tuple, kwargs: dict
    ) -> "_PermutationFromTable":
        """Return a new wrapper whose op list ends with ``(name, args, kwargs)``."""
        extended_ops = [*self._pending_ops, (name, args, kwargs)]
        return _PermutationFromTable(self._base_table, _pending_ops=extended_ops)

    def shuffle(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``shuffle`` builder call; arguments are replayed unchanged."""
        return self._with_op("shuffle", args, kwargs)

    def filter(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``filter`` builder call; arguments are replayed unchanged."""
        return self._with_op("filter", args, kwargs)

    def split_random(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``split_random`` builder call; arguments are replayed unchanged."""
        return self._with_op("split_random", args, kwargs)

    def split_sequential(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``split_sequential`` builder call; arguments are replayed unchanged."""
        return self._with_op("split_sequential", args, kwargs)

    def split_hash(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``split_hash`` builder call; arguments are replayed unchanged."""
        return self._with_op("split_hash", args, kwargs)

    def split_calculated(self, *args, **kwargs) -> "_PermutationFromTable":
        """Record a ``split_calculated`` builder call; arguments are replayed unchanged."""
        return self._with_op("split_calculated", args, kwargs)

    def _materialize(self) -> Permutation:
        """Build the underlying Permutation on first use and return the cached one after."""
        if self._materialized is not None:
            return self._materialized

        if not self._pending_ops:
            # Nothing was recorded: use the identity permutation of the table.
            self._materialized = Permutation.identity(self._base_table)
            return self._materialized

        # Replay the recorded calls in order; each call's result becomes the
        # builder for the next one.
        builder = permutation_builder(self._base_table)
        for op_name, op_args, op_kwargs in self._pending_ops:
            builder_method = getattr(builder, op_name)
            builder = builder_method(*op_args, **op_kwargs)
        perm_tbl = builder.execute()
        self._materialized = Permutation.from_tables(self._base_table, perm_tbl)
        return self._materialized

    def __getattr__(self, name: str) -> Any:
        # Only public names are delegated. Underscore-prefixed lookups that
        # reach this fallback fail immediately instead of materializing,
        # which also keeps private/dunder lookups from recursing.
        if not name.startswith("_"):
            return getattr(self._materialize(), name)
        raise AttributeError(name)

    def __iter__(self):
        """Iterate over the materialized permutation."""
        materialized = self._materialize()
        return iter(materialized)

    def __len__(self) -> int:
        """Return ``len()`` of the materialized permutation."""
        materialized = self._materialize()
        return len(materialized)

    def __getitem__(self, index):
        """Index into the materialized permutation."""
        materialized = self._materialize()
        return materialized[index]

    def __getitems__(self, indices):
        """Forward a batched lookup to the materialized permutation."""
        materialized = self._materialize()
        return materialized.__getitems__(indices)
|
||||
|
||||
@@ -1095,3 +1095,60 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
|
||||
def test_from_table_identity(mem_db):
    """With no builder ops, from_table() reports all 10 rows and the lone column."""
    source = mem_db.create_table("tbl", pa.table({"x": range(10)}))
    lazy = Permutation.from_table(source)

    assert lazy.num_rows == 10
    assert lazy.column_names == ["x"]
|
||||
|
||||
|
||||
def test_from_table_shuffle_seeded(mem_db):
    """A seeded shuffle keeps every row, changes the order, and is repeatable."""
    tbl = mem_db.create_table("tbl", pa.table({"x": range(100)}))

    def read_shuffled():
        # Build a fresh wrapper for every read so each call materializes
        # its own permutation from the same seed.
        lazy = Permutation.from_table(tbl).shuffle(seed=42)
        fetched = lazy.__getitems__(list(range(100)))
        return [row["x"] for row in fetched]

    first_pass = read_shuffled()
    assert sorted(first_pass) == list(range(100))
    assert first_pass != list(range(100))

    # Same seed → same order
    second_pass = read_shuffled()
    assert first_pass == second_pass
|
||||
|
||||
|
||||
def test_from_table_filter(mem_db):
    """Filtering on ``x < 25`` leaves exactly 25 of the 100 rows."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    filtered = Permutation.from_table(source).filter("x < 25")

    assert filtered.num_rows == 25
|
||||
|
||||
|
||||
def test_from_table_chained_ops(mem_db):
    """filter() then shuffle() in one chain yields exactly the values 50..99."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    chained = Permutation.from_table(source).filter("x >= 50").shuffle(seed=7)

    assert chained.num_rows == 50

    fetched = chained.__getitems__(list(range(50)))
    values = [row["x"] for row in fetched]
    assert sorted(values) == list(range(50, 100))
|
||||
|
||||
|
||||
def test_from_table_forwards_read_methods(mem_db):
    """select_columns() called on the from_table() result narrows the columns."""
    two_column_data = pa.table({"x": range(10), "y": range(10)})
    source = mem_db.create_table("tbl", two_column_data)

    projected = Permutation.from_table(source).select_columns(["x"])

    assert projected.column_names == ["x"]
|
||||
|
||||
|
||||
def test_from_table_split_random(mem_db):
    """A 0.3 / 0.7 random split read through from_table() exposes about 30 rows."""
    source = mem_db.create_table("tbl", pa.table({"x": range(100)}))
    split = Permutation.from_table(source).split_random(ratios=[0.3, 0.7], seed=1)

    # Default split is 0 — ratio 0.3 → ~30 rows
    row_count = split.num_rows
    assert 25 <= row_count <= 35
|
||||
|
||||
Reference in New Issue
Block a user