feat: a utility for creating "permutation views" (#2552)

I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and it allows
permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
* This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually create integrations for other data loading libraries,
including rust & node
This commit is contained in:
Weston Pace
2025-10-09 18:07:31 -07:00
committed by GitHub
parent 3dcec724b7
commit 5a19cf15a6
38 changed files with 3786 additions and 58 deletions

View File

@@ -12,6 +12,7 @@ mod header;
mod index;
mod iterator;
pub mod merge;
pub mod permutation;
mod query;
pub mod remote;
mod rerankers;

222
nodejs/src/permutation.rs Normal file
View File

@@ -0,0 +1,222 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
#[napi(object)]
pub struct SplitRandomOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
}
#[napi(object)]
pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
}
#[napi(object)]
pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
}
#[napi(object)]
pub struct ShuffleOptions {
pub seed: Option<i64>,
pub clump_size: Option<i64>,
}
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
pub struct PermutationBuilder {
state: Arc<Mutex<PermutationBuilderState>>,
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
}
impl PermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> napi::Result<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[napi]
impl PermutationBuilder {
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
/// Configure hash-based splits
#[napi]
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
let split_weights = options
.split_weights
.into_iter()
.map(|w| w as u64)
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
})
}
/// Configure sequential splits
#[napi]
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
})
}
/// Configure shuffling
#[napi]
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
let seed = options.seed.map(|s| s as u64);
let clump_size = options.clump_size.map(|c| c as u64);
self.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
/// Configure filtering
#[napi]
pub fn filter(&self, filter: String) -> napi::Result<Self> {
self.modify(|builder| builder.with_filter(filter))
}
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
};
let table = builder.build(&dest_table_name).await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
}

View File

@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))