// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::{Arc, Mutex}; use crate::{error::NapiErrorExt, table::Table}; use lancedb::dataloader::{ permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, permutation::split::{SplitSizes, SplitStrategy}, }; use napi_derive::napi; #[napi(object)] pub struct SplitRandomOptions { pub ratios: Option>, pub counts: Option>, pub fixed: Option, pub seed: Option, pub split_names: Option>, } #[napi(object)] pub struct SplitHashOptions { pub columns: Vec, pub split_weights: Vec, pub discard_weight: Option, pub split_names: Option>, } #[napi(object)] pub struct SplitSequentialOptions { pub ratios: Option>, pub counts: Option>, pub fixed: Option, pub split_names: Option>, } #[napi(object)] pub struct SplitCalculatedOptions { pub calculation: String, pub split_names: Option>, } #[napi(object)] pub struct ShuffleOptions { pub seed: Option, pub clump_size: Option, } pub struct PermutationBuilderState { pub builder: Option, } #[napi] pub struct PermutationBuilder { state: Arc>, } impl PermutationBuilder { pub fn new(builder: LancePermutationBuilder) -> Self { Self { state: Arc::new(Mutex::new(PermutationBuilderState { builder: Some(builder), })), } } } impl PermutationBuilder { fn modify( &self, func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder, ) -> napi::Result { let mut state = self.state.lock().unwrap(); let builder = state .builder .take() .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?; state.builder = Some(func(builder)); Ok(Self { state: self.state.clone(), }) } } #[napi] impl PermutationBuilder { #[napi] pub fn persist( &self, connection: &crate::connection::Connection, table_name: String, ) -> napi::Result { let database = connection.database()?; self.modify(|builder| builder.persist(database, table_name)) } /// Configure random splits #[napi] pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result { // Check that exactly one split type is provided let split_args_count = [ options.ratios.is_some(), options.counts.is_some(), options.fixed.is_some(), ] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(napi::Error::from_reason( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = options.ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = options.counts { SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect()) } else if let Some(fixed) = options.fixed { SplitSizes::Fixed(fixed as u64) } else { unreachable!("One of the split arguments must be provided"); }; let seed = options.seed.map(|s| s as u64); self.modify(|builder| { builder.with_split_strategy( SplitStrategy::Random { seed, sizes }, options.split_names.clone(), ) }) } /// Configure hash-based splits #[napi] pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result { let split_weights = options .split_weights .into_iter() .map(|w| w as u64) .collect(); let discard_weight = options.discard_weight.unwrap_or(0) as u64; self.modify(move |builder| { builder.with_split_strategy( SplitStrategy::Hash { columns: options.columns, split_weights, discard_weight, }, options.split_names, ) }) } /// Configure sequential splits #[napi] pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result { // Check that exactly one split type is provided let split_args_count = [ options.ratios.is_some(), options.counts.is_some(), options.fixed.is_some(), ] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(napi::Error::from_reason( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = options.ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = options.counts { SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect()) } else if let Some(fixed) = options.fixed { SplitSizes::Fixed(fixed as u64) } else { unreachable!("One of the split arguments must be provided"); }; self.modify(move |builder| { builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names) }) } /// Configure calculated splits #[napi] pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result { self.modify(move |builder| { builder.with_split_strategy( SplitStrategy::Calculated { calculation: options.calculation, }, options.split_names, ) }) } /// Configure shuffling #[napi] pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result { let seed = options.seed.map(|s| s as u64); let clump_size = options.clump_size.map(|c| c as u64); self.modify(|builder| { builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size }) }) } /// Configure filtering #[napi] pub fn filter(&self, filter: String) -> napi::Result { self.modify(|builder| builder.with_filter(filter)) } /// Execute the permutation builder and create the table #[napi] pub async fn execute(&self) -> napi::Result { let builder = { let mut state = self.state.lock().unwrap(); state .builder .take() .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))? }; let table = builder.build().await.default_error()?; Ok(Table::new(table)) } } /// Create a permutation builder for the given table #[napi] pub fn permutation_builder(table: &crate::table::Table) -> napi::Result { use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder; let inner_table = table.inner_ref()?.clone(); let inner_builder = LancePermutationBuilder::new(inner_table); Ok(PermutationBuilder::new(inner_builder)) }