mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 11:00:41 +00:00
feat: add a permutation reader that can read a permutation view (#2712)
This adds a rust permutation builder. In the next PR I will have python bindings and integration with pytorch.
This commit is contained in:
@@ -8,8 +8,8 @@ from typing import Optional
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
def __init__(self, table: LanceTable, dest_table_name: str):
|
||||
self._async = async_permutation_builder(table, dest_table_name)
|
||||
def __init__(self, table: LanceTable):
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
|
||||
self._async.select(projections)
|
||||
@@ -68,5 +68,5 @@ class PermutationBuilder:
|
||||
return LOOP.run(do_execute())
|
||||
|
||||
|
||||
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
|
||||
return PermutationBuilder(table, dest_table_name)
|
||||
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
||||
return PermutationBuilder(table)
|
||||
|
||||
@@ -12,11 +12,7 @@ def test_split_random_ratios(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(ratios=[0.3, 0.7])
|
||||
.execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
|
||||
|
||||
# Check that the table was created and has data
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
@@ -38,11 +34,7 @@ def test_split_random_counts(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(counts=[20, 30])
|
||||
.execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
|
||||
|
||||
# Check that we have exactly the requested counts
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
@@ -58,9 +50,7 @@ def test_split_random_fixed(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
|
||||
|
||||
# Check that we have 4 splits with 25 rows each
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
@@ -78,17 +68,9 @@ def test_split_random_with_seed(mem_db):
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
|
||||
|
||||
# Create two identical permutations with same seed
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
# Results should be identical
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -112,7 +94,7 @@ def test_split_hash(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
@@ -133,7 +115,7 @@ def test_split_hash(mem_db):
|
||||
# Hash splits should be deterministic - same category should go to same split
|
||||
# Let's verify by creating another permutation and checking consistency
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "test_permutation2")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
@@ -150,7 +132,7 @@ def test_split_hash_with_discard(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
|
||||
.execute()
|
||||
)
|
||||
@@ -168,9 +150,7 @@ def test_split_sequential(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_sequential(counts=[30, 40])
|
||||
.execute()
|
||||
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 70
|
||||
@@ -194,7 +174,7 @@ def test_split_calculated(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_calculated("id % 3") # Split based on id modulo 3
|
||||
.execute()
|
||||
)
|
||||
@@ -216,23 +196,21 @@ def test_split_error_cases(mem_db):
|
||||
|
||||
# Test split_random with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error1").split_random().execute()
|
||||
permutation_builder(tbl).split_random().execute()
|
||||
|
||||
# Test split_random with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error2").split_random(
|
||||
permutation_builder(tbl).split_random(
|
||||
ratios=[0.5, 0.5], counts=[5, 5]
|
||||
).execute()
|
||||
|
||||
# Test split_sequential with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error3").split_sequential().execute()
|
||||
permutation_builder(tbl).split_sequential().execute()
|
||||
|
||||
# Test split_sequential with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error4").split_sequential(
|
||||
ratios=[0.5, 0.5], fixed=2
|
||||
).execute()
|
||||
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
||||
|
||||
|
||||
def test_shuffle_no_seed(mem_db):
|
||||
@@ -242,7 +220,7 @@ def test_shuffle_no_seed(mem_db):
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling (no seed)
|
||||
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
|
||||
permutation_tbl = permutation_builder(tbl).shuffle().execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
@@ -262,9 +240,9 @@ def test_shuffle_with_seed(mem_db):
|
||||
)
|
||||
|
||||
# Create two identical permutations with same shuffle seed
|
||||
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
|
||||
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
|
||||
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
# Results should be identical due to same seed
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -282,7 +260,7 @@ def test_shuffle_with_clump_size(mem_db):
|
||||
|
||||
# Create a permutation with shuffling using clumps
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.shuffle(clump_size=10) # 10-row clumps
|
||||
.execute()
|
||||
)
|
||||
@@ -304,19 +282,9 @@ def test_shuffle_different_seeds(mem_db):
|
||||
)
|
||||
|
||||
# Create two permutations with different shuffle seeds
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=123)
|
||||
.execute()
|
||||
)
|
||||
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
|
||||
|
||||
# Results should be different due to different seeds
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -341,7 +309,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with random splits
|
||||
perm_random = (
|
||||
permutation_builder(tbl, "perm_random")
|
||||
permutation_builder(tbl)
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.shuffle(seed=123, clump_size=None)
|
||||
.execute()
|
||||
@@ -349,7 +317,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with hash splits
|
||||
perm_hash = (
|
||||
permutation_builder(tbl, "perm_hash")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.shuffle(seed=456, clump_size=5)
|
||||
.execute()
|
||||
@@ -357,7 +325,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with sequential splits
|
||||
perm_sequential = (
|
||||
permutation_builder(tbl, "perm_sequential")
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[40, 35])
|
||||
.shuffle(seed=789, clump_size=None)
|
||||
.execute()
|
||||
@@ -384,7 +352,7 @@ def test_no_shuffle_maintains_order(mem_db):
|
||||
|
||||
# Create permutation without shuffle (should maintain some order)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[25, 25]) # Sequential maintains order
|
||||
.execute()
|
||||
)
|
||||
@@ -405,9 +373,7 @@ def test_filter_basic(mem_db):
|
||||
)
|
||||
|
||||
# Filter to only include rows where id < 50
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
@@ -433,7 +399,7 @@ def test_filter_with_splits(mem_db):
|
||||
|
||||
# Filter to only category A and B, then split
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'B')")
|
||||
.split_random(ratios=[0.5, 0.5])
|
||||
.execute()
|
||||
@@ -465,7 +431,7 @@ def test_filter_with_shuffle(mem_db):
|
||||
|
||||
# Filter and shuffle
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'C')")
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
@@ -488,7 +454,7 @@ def test_filter_empty_result(mem_db):
|
||||
|
||||
# Filter that matches nothing
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("value > 100") # No values > 100 in our data
|
||||
.execute()
|
||||
)
|
||||
|
||||
@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
permutation::split::{SplitSizes, SplitStrategy},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
||||
@@ -16,10 +16,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[pyo3::pyfunction]
|
||||
pub fn async_permutation_builder(
|
||||
table: Bound<'_, PyAny>,
|
||||
dest_table_name: String,
|
||||
) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let inner_table = table.borrow().inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
@@ -27,14 +24,12 @@ pub fn async_permutation_builder(
|
||||
Ok(PyAsyncPermutationBuilder {
|
||||
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
|
||||
builder: Some(inner_builder),
|
||||
dest_table_name,
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
struct PyAsyncPermutationBuilderState {
|
||||
builder: Option<LancePermutationBuilder>,
|
||||
dest_table_name: String,
|
||||
}
|
||||
|
||||
#[pyclass(name = "AsyncPermutationBuilder")]
|
||||
@@ -167,10 +162,8 @@ impl PyAsyncPermutationBuilder {
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
|
||||
let dest_table_name = std::mem::take(&mut state.dest_table_name);
|
||||
|
||||
future_into_py(slf.py(), async move {
|
||||
let table = builder.build(&dest_table_name).await.infer_error()?;
|
||||
let table = builder.build().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user