feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)

This commit is contained in:
Weston Pace
2025-11-06 16:15:33 -08:00
committed by GitHub
parent 6ddd271627
commit aeac9c7644
24 changed files with 2071 additions and 126 deletions

View File

@@ -4,7 +4,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use lancedb::database::CreateTableMode;
use lancedb::database::{CreateTableMode, Database};
use napi::bindgen_prelude::*;
use napi_derive::*;
@@ -41,6 +41,10 @@ impl Connection {
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
}
}
/// Returns a shared handle to the database behind this connection.
///
/// The handle is a clone of the one held by the inner connection, so the
/// caller gets its own `Arc` to keep.
///
/// # Errors
///
/// Propagates whatever error `get_inner` reports.
// NOTE(review): presumably `get_inner` fails once the connection has been closed; confirm.
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
    let inner = self.get_inner()?;
    let handle = inner.database().clone();
    Ok(handle)
}
}
#[napi]

View File

@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub split_names: Option<Vec<String>>,
}
/// Options accepted by `PermutationBuilder::split_calculated`.
#[napi(object)]
pub struct SplitCalculatedOptions {
/// Value forwarded unchanged as the `calculation` of `SplitStrategy::Calculated`.
// NOTE(review): presumably an expression evaluated per row to pick a split; confirm against lancedb.
pub calculation: String,
/// Optional split names, handed to `with_split_strategy` together with the strategy.
pub split_names: Option<Vec<String>>,
}
#[napi(object)]
@@ -76,6 +85,16 @@ impl PermutationBuilder {
#[napi]
impl PermutationBuilder {
/// Forwards a persist request to the underlying permutation builder.
///
/// The database handle is taken from `connection` and passed, together with
/// `table_name`, to the inner builder's `persist`; the result of `modify`
/// is returned.
///
/// # Errors
///
/// Returns the error from `connection.database()` if the database handle
/// cannot be obtained, or whatever error `modify` reports.
// NOTE(review): presumably this stores the permutation as table `table_name`
// in that database instead of keeping it in memory; confirm against lancedb.
#[napi]
pub fn persist(
    &self,
    connection: &crate::connection::Connection,
    table_name: String,
) -> napi::Result<Self> {
    let target_db = connection.database()?;
    self.modify(|inner| inner.persist(target_db, table_name))
}
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
@@ -107,7 +126,12 @@ impl PermutationBuilder {
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
self.modify(|builder| {
builder.with_split_strategy(
SplitStrategy::Random { seed, sizes },
options.split_names.clone(),
)
})
}
/// Configure hash-based splits
@@ -120,12 +144,15 @@ impl PermutationBuilder {
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
},
options.split_names,
)
})
}
@@ -158,14 +185,21 @@ impl PermutationBuilder {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
self.modify(move |builder| {
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
})
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
self.modify(move |builder| {
builder.with_split_strategy(
SplitStrategy::Calculated {
calculation: options.calculation,
},
options.split_names,
)
})
}