feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)

This commit is contained in:
Weston Pace
2025-11-06 16:15:33 -08:00
committed by GitHub
parent 6ddd271627
commit aeac9c7644
24 changed files with 2071 additions and 126 deletions

View File

@@ -138,7 +138,9 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated("id % 2");
const builder = permutationBuilder(table).splitCalculated({
calculation: "id % 2",
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -224,4 +226,146 @@ describe("PermutationBuilder", () => {
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
test("should accept custom split names with random splits", async () => {
const builder = permutationBuilder(table).splitRandom({
ratios: [0.3, 0.7],
seed: 42,
splitNames: ["train", "test"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric (0, 1, etc.)
// The names are metadata that can be used by higher-level APIs
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should accept custom split names with hash splits", async () => {
const builder = permutationBuilder(table).splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
splitNames: ["set_a", "set_b"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should accept custom split names with sequential splits", async () => {
const builder = permutationBuilder(table).splitSequential({
ratios: [0.5, 0.5],
splitNames: ["first", "second"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should accept custom split names with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated({
calculation: "id % 2",
splitNames: ["even", "odd"],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Split names are provided but split_id is still numeric
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should persist permutation to a new table", async () => {
const db = await connect(tmpDir.name);
const builder = permutationBuilder(table)
.splitRandom({
ratios: [0.7, 0.3],
seed: 42,
splitNames: ["train", "validation"],
})
.persist(db, "my_permutation");
// Execute the builder which will persist the table
const permutationTable = await builder.execute();
// Verify the persisted table exists and can be opened
const persistedTable = await db.openTable("my_permutation");
expect(persistedTable).toBeDefined();
// Verify the persisted table has the correct number of rows
const rowCount = await persistedTable.countRows();
expect(rowCount).toBe(10);
// Verify splits exist (numeric split_id values)
const split0Count = await persistedTable.countRows("split_id = 0");
const split1Count = await persistedTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
// Verify the table returned by execute is the same as the persisted one
const executedRowCount = await permutationTable.countRows();
expect(executedRowCount).toBe(10);
});
test("should persist permutation with multiple operations", async () => {
const db = await connect(tmpDir.name);
const builder = permutationBuilder(table)
.filter("value > 30")
.splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
.shuffle({ seed: 456 })
.persist(db, "filtered_permutation");
// Execute the builder
const permutationTable = await builder.execute();
// Verify the persisted table
const persistedTable = await db.openTable("filtered_permutation");
const rowCount = await persistedTable.countRows();
expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100
// Verify splits exist (numeric split_id values)
const split0Count = await persistedTable.countRows("split_id = 0");
const split1Count = await persistedTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(7);
// Verify the executed table matches
const executedRowCount = await permutationTable.countRows();
expect(executedRowCount).toBe(7);
});
});

View File

@@ -43,6 +43,7 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitCalculatedOptions,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,

View File

@@ -1,10 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { Connection, LocalConnection } from "./connection.js";
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitCalculatedOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
@@ -29,6 +31,23 @@ export class PermutationBuilder {
this.inner = inner;
}
/**
* Configure the permutation to be persisted.
*
* @param connection - The connection to persist the permutation to
* @param tableName - The name of the table to create
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.persist(connection, "permutation_table");
* ```
*/
persist(connection: Connection, tableName: string): PermutationBuilder {
const localConnection = connection as LocalConnection;
const newInner = this.inner.persist(localConnection.inner, tableName);
return new PermutationBuilder(newInner);
}
/**
* Configure random splits for the permutation.
*
@@ -95,15 +114,15 @@ export class PermutationBuilder {
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @param options - Configuration for calculated splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
splitCalculated(options: SplitCalculatedOptions): PermutationBuilder {
const newInner = this.inner.splitCalculated(options);
return new PermutationBuilder(newInner);
}

View File

@@ -4,7 +4,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use lancedb::database::CreateTableMode;
use lancedb::database::{CreateTableMode, Database};
use napi::bindgen_prelude::*;
use napi_derive::*;
@@ -41,6 +41,10 @@ impl Connection {
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
}
}
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
Ok(self.get_inner()?.database().clone())
}
}
#[napi]

View File

@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
pub struct SplitCalculatedOptions {
pub calculation: String,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -76,6 +85,16 @@ impl PermutationBuilder {
#[napi]
impl PermutationBuilder {
#[napi]
pub fn persist(
&self,
connection: &crate::connection::Connection,
table_name: String,
) -> napi::Result<Self> {
let database = connection.database()?;
self.modify(|builder| builder.persist(database, table_name))
}
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
@@ -107,7 +126,12 @@ impl PermutationBuilder {
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
self.modify(|builder| {
builder.with_split_strategy(
SplitStrategy::Random { seed, sizes },
options.split_names.clone(),
)
})
}
/// Configure hash-based splits
@@ -120,12 +144,15 @@ impl PermutationBuilder {
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
},
options.split_names,
)
})
}
@@ -158,14 +185,21 @@ impl PermutationBuilder {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
self.modify(move |builder| {
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
})
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Calculated {
calculation: options.calculation,
},
options.split_names,
)
})
}