// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::{Arc, Mutex}; use crate::{error::PythonErrorExt, table::Table}; use lancedb::dataloader::{ permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, permutation::split::{SplitSizes, SplitStrategy}, }; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut, PyResult, }; use pyo3_async_runtimes::tokio::future_into_py; /// Create a permutation builder for the given table #[pyo3::pyfunction] pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult { let table = table.getattr("_inner")?.downcast_into::()?; let inner_table = table.borrow().inner_ref()?.clone(); let inner_builder = LancePermutationBuilder::new(inner_table); Ok(PyAsyncPermutationBuilder { state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState { builder: Some(inner_builder), })), }) } struct PyAsyncPermutationBuilderState { builder: Option, } #[pyclass(name = "AsyncPermutationBuilder")] pub struct PyAsyncPermutationBuilder { state: Arc>, } impl PyAsyncPermutationBuilder { fn modify( &self, func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder, ) -> PyResult { let mut state = self.state.lock().unwrap(); let builder = state .builder .take() .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; state.builder = Some(func(builder)); Ok(Self { state: self.state.clone(), }) } } #[pymethods] impl PyAsyncPermutationBuilder { #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))] pub fn split_random( slf: PyRefMut<'_, Self>, ratios: Option>, counts: Option>, fixed: Option, seed: Option, ) -> PyResult { // Check that exactly one split type is provided let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(pyo3::exceptions::PyValueError::new_err( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = counts { SplitSizes::Counts(counts) } else if let Some(fixed) = fixed { SplitSizes::Fixed(fixed) } else { unreachable!("One of the split arguments must be provided"); }; slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes })) } #[pyo3(signature = (columns, split_weights, *, discard_weight=0))] pub fn split_hash( slf: PyRefMut<'_, Self>, columns: Vec, split_weights: Vec, discard_weight: u64, ) -> PyResult { slf.modify(|builder| { builder.with_split_strategy(SplitStrategy::Hash { columns, split_weights, discard_weight, }) }) } #[pyo3(signature = (*, ratios=None, counts=None, fixed=None))] pub fn split_sequential( slf: PyRefMut<'_, Self>, ratios: Option>, counts: Option>, fixed: Option, ) -> PyResult { // Check that exactly one split type is provided let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(pyo3::exceptions::PyValueError::new_err( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = counts { SplitSizes::Counts(counts) } else if let Some(fixed) = fixed { SplitSizes::Fixed(fixed) } else { unreachable!("One of the split arguments must be provided"); }; slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes })) } pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult { slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation })) } pub fn shuffle( slf: PyRefMut<'_, Self>, seed: Option, clump_size: Option, ) -> PyResult { slf.modify(|builder| { builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size }) }) } pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult { slf.modify(|builder| builder.with_filter(filter)) } pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult> { let mut state = slf.state.lock().unwrap(); let builder = state .builder .take() .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; future_into_py(slf.py(), async move { let table = builder.build().await.infer_error()?; Ok(Table::new(table)) }) } }