feat: add a permutation reader that can read a permutation view (#2712)

This adds a rust permutation builder. In the next PR I will have python
bindings and integration with pytorch.
This commit is contained in:
Weston Pace
2025-10-17 05:00:23 -07:00
committed by GitHub
parent a70ff04bc9
commit 4cfcd95320
24 changed files with 974 additions and 546 deletions

View File

@@ -8,8 +8,8 @@ from typing import Optional
class PermutationBuilder:
def __init__(self, table: LanceTable, dest_table_name: str):
self._async = async_permutation_builder(table, dest_table_name)
def __init__(self, table: LanceTable):
self._async = async_permutation_builder(table)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
@@ -68,5 +68,5 @@ class PermutationBuilder:
return LOOP.run(do_execute())
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
return PermutationBuilder(table, dest_table_name)
def permutation_builder(table: LanceTable) -> PermutationBuilder:
return PermutationBuilder(table)

View File

@@ -12,11 +12,7 @@ def test_split_random_ratios(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(ratios=[0.3, 0.7])
.execute()
)
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
# Check that the table was created and has data
assert permutation_tbl.count_rows() == 100
@@ -38,11 +34,7 @@ def test_split_random_counts(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(counts=[20, 30])
.execute()
)
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
# Check that we have exactly the requested counts
assert permutation_tbl.count_rows() == 50
@@ -58,9 +50,7 @@ def test_split_random_fixed(mem_db):
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
)
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
# Check that we have 4 splits with 25 rows each
assert permutation_tbl.count_rows() == 100
@@ -78,17 +68,9 @@ def test_split_random_with_seed(mem_db):
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
# Create two identical permutations with same seed
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
# Results should be identical
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -112,7 +94,7 @@ def test_split_hash(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
@@ -133,7 +115,7 @@ def test_split_hash(mem_db):
# Hash splits should be deterministic - same category should go to same split
# Let's verify by creating another permutation and checking consistency
perm2 = (
permutation_builder(tbl, "test_permutation2")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
@@ -150,7 +132,7 @@ def test_split_hash_with_discard(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
.execute()
)
@@ -168,9 +150,7 @@ def test_split_sequential(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[30, 40])
.execute()
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
)
assert permutation_tbl.count_rows() == 70
@@ -194,7 +174,7 @@ def test_split_calculated(mem_db):
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_calculated("id % 3") # Split based on id modulo 3
.execute()
)
@@ -216,23 +196,21 @@ def test_split_error_cases(mem_db):
# Test split_random with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error1").split_random().execute()
permutation_builder(tbl).split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error2").split_random(
permutation_builder(tbl).split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error3").split_sequential().execute()
permutation_builder(tbl).split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error4").split_sequential(
ratios=[0.5, 0.5], fixed=2
).execute()
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
def test_shuffle_no_seed(mem_db):
@@ -242,7 +220,7 @@ def test_shuffle_no_seed(mem_db):
)
# Create a permutation with shuffling (no seed)
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
permutation_tbl = permutation_builder(tbl).shuffle().execute()
assert permutation_tbl.count_rows() == 100
@@ -262,9 +240,9 @@ def test_shuffle_with_seed(mem_db):
)
# Create two identical permutations with same shuffle seed
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
# Results should be identical due to same seed
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -282,7 +260,7 @@ def test_shuffle_with_clump_size(mem_db):
# Create a permutation with shuffling using clumps
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.shuffle(clump_size=10) # 10-row clumps
.execute()
)
@@ -304,19 +282,9 @@ def test_shuffle_different_seeds(mem_db):
)
# Create two permutations with different shuffle seeds
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(fixed=2)
.shuffle(seed=42)
.execute()
)
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(fixed=2)
.shuffle(seed=123)
.execute()
)
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
# Results should be different due to different seeds
data1 = perm1.search(None).to_arrow().to_pydict()
@@ -341,7 +309,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with random splits
perm_random = (
permutation_builder(tbl, "perm_random")
permutation_builder(tbl)
.split_random(ratios=[0.6, 0.4], seed=42)
.shuffle(seed=123, clump_size=None)
.execute()
@@ -349,7 +317,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with hash splits
perm_hash = (
permutation_builder(tbl, "perm_hash")
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.shuffle(seed=456, clump_size=5)
.execute()
@@ -357,7 +325,7 @@ def test_shuffle_combined_with_splits(mem_db):
# Test shuffle with sequential splits
perm_sequential = (
permutation_builder(tbl, "perm_sequential")
permutation_builder(tbl)
.split_sequential(counts=[40, 35])
.shuffle(seed=789, clump_size=None)
.execute()
@@ -384,7 +352,7 @@ def test_no_shuffle_maintains_order(mem_db):
# Create permutation without shuffle (should maintain some order)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.split_sequential(counts=[25, 25]) # Sequential maintains order
.execute()
)
@@ -405,9 +373,7 @@ def test_filter_basic(mem_db):
)
# Filter to only include rows where id < 50
permutation_tbl = (
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
)
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
assert permutation_tbl.count_rows() == 50
@@ -433,7 +399,7 @@ def test_filter_with_splits(mem_db):
# Filter to only category A and B, then split
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("category IN ('A', 'B')")
.split_random(ratios=[0.5, 0.5])
.execute()
@@ -465,7 +431,7 @@ def test_filter_with_shuffle(mem_db):
# Filter and shuffle
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("category IN ('A', 'C')")
.shuffle(seed=42)
.execute()
@@ -488,7 +454,7 @@ def test_filter_empty_result(mem_db):
# Filter that matches nothing
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
permutation_builder(tbl)
.filter("value > 100") # No values > 100 in our data
.execute()
)