mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::database::CreateTableMode;
|
||||
use lancedb::database::{CreateTableMode, Database};
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::*;
|
||||
|
||||
@@ -41,6 +41,10 @@ impl Connection {
|
||||
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
|
||||
Ok(self.get_inner()?.database().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
|
||||
@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
pub seed: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
|
||||
pub columns: Vec<String>,
|
||||
pub split_weights: Vec<i64>,
|
||||
pub discard_weight: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
|
||||
pub ratios: Option<Vec<f64>>,
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct SplitCalculatedOptions {
|
||||
pub calculation: String,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -76,6 +85,16 @@ impl PermutationBuilder {
|
||||
|
||||
#[napi]
|
||||
impl PermutationBuilder {
|
||||
#[napi]
|
||||
pub fn persist(
|
||||
&self,
|
||||
connection: &crate::connection::Connection,
|
||||
table_name: String,
|
||||
) -> napi::Result<Self> {
|
||||
let database = connection.database()?;
|
||||
self.modify(|builder| builder.persist(database, table_name))
|
||||
}
|
||||
|
||||
/// Configure random splits
|
||||
#[napi]
|
||||
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
|
||||
@@ -107,7 +126,12 @@ impl PermutationBuilder {
|
||||
|
||||
let seed = options.seed.map(|s| s as u64);
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Random { seed, sizes },
|
||||
options.split_names.clone(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure hash-based splits
|
||||
@@ -120,12 +144,15 @@ impl PermutationBuilder {
|
||||
.collect();
|
||||
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
|
||||
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns: options.columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Hash {
|
||||
columns: options.columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
},
|
||||
options.split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -158,14 +185,21 @@ impl PermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure calculated splits
|
||||
#[napi]
|
||||
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
|
||||
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Calculated {
|
||||
calculation: options.calculation,
|
||||
},
|
||||
options.split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user