feat: a utility for creating "permutation views" (#2552)

I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and it allows
permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
* This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually create integrations for other data loading libraries,
including rust & node
This commit is contained in:
Weston Pace
2025-10-09 18:07:31 -07:00
committed by GitHub
parent 3dcec724b7
commit 5a19cf15a6
38 changed files with 3786 additions and 58 deletions

View File

@@ -0,0 +1,234 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { Table, connect, permutationBuilder } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
describe("PermutationBuilder", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
// Create test data
const data = makeArrowTable(
[
{ id: 1, value: 10 },
{ id: 2, value: 20 },
{ id: 3, value: 30 },
{ id: 4, value: 40 },
{ id: 5, value: 50 },
{ id: 6, value: 60 },
{ id: 7, value: 70 },
{ id: 8, value: 80 },
{ id: 9, value: 90 },
{ id: 10, value: 100 },
],
{ vectorColumns: {} },
);
table = await db.createTable("test_table", data);
});
afterEach(() => {
tmpDir.removeCallback();
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table, "permutation_table");
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table, "permutation_table");
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
expect(permutationTable.name).toBe("permutation_table");
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [1.0],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
counts: [3, 7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(3);
expect(split1Count).toBe(7);
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check that splits exist
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitSequential({ ratios: [0.5, 0.5] });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution - sequential should give exactly 5 and 5
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
clumpSize: 2,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table, "permutation_table").filter(
"value > 50",
);
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table, "permutation_table")
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(8);
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table, "permutation_table");
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
);
// Test multiple arguments provided
expect(() =>
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table, "permutation_table");
// Execute once
await builder.execute();
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
});

View File

@@ -43,6 +43,10 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
} from "./native.js";
export {
@@ -111,6 +115,7 @@ export {
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";
export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
SchemaLike,

View File

@@ -0,0 +1,188 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
permutationBuilder as nativePermutationBuilder,
} from "./native.js";
import { LocalTable, Table } from "./table";
/**
* A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
*
* This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
* offering methods to configure data splits, shuffling, and filtering before executing
* the permutation to create a new table.
*/
export class PermutationBuilder {
private inner: NativePermutationBuilder;
/**
* @hidden
*/
constructor(inner: NativePermutationBuilder) {
this.inner = inner;
}
/**
* Configure random splits for the permutation.
*
* @param options - Configuration for random splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
*
* // Split by counts
* builder.splitRandom({ counts: [1000, 500], seed: 42 });
*
* // Split with fixed size
* builder.splitRandom({ fixed: 100, seed: 42 });
* ```
*/
splitRandom(options: SplitRandomOptions): PermutationBuilder {
const newInner = this.inner.splitRandom(options);
return new PermutationBuilder(newInner);
}
/**
* Configure hash-based splits for the permutation.
*
* @param options - Configuration for hash-based splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitHash({
* columns: ["user_id"],
* splitWeights: [70, 30],
* discardWeight: 0
* });
* ```
*/
splitHash(options: SplitHashOptions): PermutationBuilder {
const newInner = this.inner.splitHash(options);
return new PermutationBuilder(newInner);
}
/**
* Configure sequential splits for the permutation.
*
* @param options - Configuration for sequential splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitSequential({ ratios: [0.8, 0.2] });
*
* // Split by counts
* builder.splitSequential({ counts: [800, 200] });
*
* // Split with fixed size
* builder.splitSequential({ fixed: 1000 });
* ```
*/
splitSequential(options: SplitSequentialOptions): PermutationBuilder {
const newInner = this.inner.splitSequential(options);
return new PermutationBuilder(newInner);
}
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
return new PermutationBuilder(newInner);
}
/**
* Configure shuffling for the permutation.
*
* @param options - Configuration for shuffling
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Basic shuffle
* builder.shuffle({ seed: 42 });
*
* // Shuffle with clump size
* builder.shuffle({ seed: 42, clumpSize: 10 });
* ```
*/
shuffle(options: ShuffleOptions): PermutationBuilder {
const newInner = this.inner.shuffle(options);
return new PermutationBuilder(newInner);
}
/**
* Configure filtering for the permutation.
*
* @param filter - SQL filter expression
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.filter("age > 18 AND status = 'active'");
* ```
*/
filter(filter: string): PermutationBuilder {
const newInner = this.inner.filter(filter);
return new PermutationBuilder(newInner);
}
/**
* Execute the permutation and create the destination table.
*
* @returns A Promise that resolves to the new Table instance
* @example
* ```ts
* const permutationTable = await builder.execute();
* console.log(`Created table: ${permutationTable.name}`);
* ```
*/
async execute(): Promise<Table> {
const nativeTable: NativeTable = await this.inner.execute();
return new LocalTable(nativeTable);
}
}
/**
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @param destTableName - The name for the destination permutation table
* @returns A PermutationBuilder instance
* @example
* ```ts
* const builder = permutationBuilder(sourceTable, "training_data")
* .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
* .shuffle({ seed: 123 });
*
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(
table: Table,
destTableName: string,
): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
destTableName,
);
return new PermutationBuilder(nativeBuilder);
}

View File

@@ -12,6 +12,7 @@ mod header;
mod index;
mod iterator;
pub mod merge;
pub mod permutation;
mod query;
pub mod remote;
mod rerankers;

222
nodejs/src/permutation.rs Normal file
View File

@@ -0,0 +1,222 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
#[napi(object)]
pub struct SplitRandomOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
}
#[napi(object)]
pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
}
#[napi(object)]
pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
}
#[napi(object)]
pub struct ShuffleOptions {
pub seed: Option<i64>,
pub clump_size: Option<i64>,
}
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
pub struct PermutationBuilder {
state: Arc<Mutex<PermutationBuilderState>>,
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
}
impl PermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> napi::Result<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[napi]
impl PermutationBuilder {
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
/// Configure hash-based splits
#[napi]
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
let split_weights = options
.split_weights
.into_iter()
.map(|w| w as u64)
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
})
}
/// Configure sequential splits
#[napi]
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
})
}
/// Configure shuffling
#[napi]
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
let seed = options.seed.map(|s| s as u64);
let clump_size = options.clump_size.map(|c| c as u64);
self.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
/// Configure filtering
#[napi]
pub fn filter(&self, filter: String) -> napi::Result<Self> {
self.modify(|builder| builder.with_filter(filter))
}
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
};
let table = builder.build(&dest_table_name).await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
}

View File

@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))