feat: add a permutation reader that can read a permutation view (#2712)

This adds a Rust permutation builder. A follow-up PR will add Python
bindings and integration with PyTorch.
This commit is contained in:
Weston Pace
2025-10-17 05:00:23 -07:00
committed by GitHub
parent a70ff04bc9
commit 4cfcd95320
24 changed files with 974 additions and 546 deletions

View File

@@ -38,23 +38,22 @@ describe("PermutationBuilder", () => {
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
expect(permutationTable.name).toBe("permutation_table");
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
ratios: [1.0],
seed: 42,
});
@@ -65,7 +64,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
@@ -84,7 +83,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
const builder = permutationBuilder(table).splitRandom({
counts: [3, 7],
seed: 42,
});
@@ -102,7 +101,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitHash({
const builder = permutationBuilder(table).splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
@@ -122,10 +121,9 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitSequential({ ratios: [0.5, 0.5] });
const builder = permutationBuilder(table).splitSequential({
ratios: [0.5, 0.5],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -140,10 +138,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitCalculated("id % 2");
const builder = permutationBuilder(table).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -159,7 +154,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
const builder = permutationBuilder(table).shuffle({
seed: 42,
});
@@ -169,7 +164,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
const builder = permutationBuilder(table).shuffle({
seed: 42,
clumpSize: 2,
});
@@ -180,9 +175,7 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table, "permutation_table").filter(
"value > 50",
);
const builder = permutationBuilder(table).filter("value > 50");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -190,7 +183,7 @@ describe("PermutationBuilder", () => {
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table, "permutation_table")
const builder = permutationBuilder(table)
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
@@ -209,7 +202,7 @@ describe("PermutationBuilder", () => {
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
@@ -223,7 +216,7 @@ describe("PermutationBuilder", () => {
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table, "permutation_table");
const builder = permutationBuilder(table);
// Execute once
await builder.execute();

View File

@@ -161,7 +161,6 @@ export class PermutationBuilder {
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @param destTableName - The name for the destination permutation table
* @returns A PermutationBuilder instance
* @example
* ```ts
@@ -172,17 +171,13 @@ export class PermutationBuilder {
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(
table: Table,
destTableName: string,
): PermutationBuilder {
export function permutationBuilder(table: Table): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
destTableName,
);
return new PermutationBuilder(nativeBuilder);
}

View File

@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
@@ -40,7 +40,6 @@ pub struct ShuffleOptions {
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
@@ -49,11 +48,10 @@ pub struct PermutationBuilder {
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
pub fn new(builder: LancePermutationBuilder) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
@@ -191,32 +189,26 @@ impl PermutationBuilder {
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let builder = {
let mut state = self.state.lock().unwrap();
let builder = state
state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
};
let table = builder.build(&dest_table_name).await.default_error()?;
let table = builder.build().await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
Ok(PermutationBuilder::new(inner_builder))
}