// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::{Arc, Mutex}; use crate::{ arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table, }; use arrow::pyarrow::{PyArrowType, ToPyArrow}; use lancedb::{ dataloader::permutation::{ builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, reader::PermutationReader, split::{SplitSizes, SplitStrategy}, }, query::Select, }; use pyo3::{ Bound, PyAny, PyRef, PyRefMut, PyResult, Python, exceptions::PyRuntimeError, pyclass, pymethods, types::{PyAnyMethods, PyDict, PyDictMethods, PyType}, }; use pyo3_async_runtimes::tokio::future_into_py; fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult> { if table.hasattr("_inner")? { Ok(table.getattr("_inner")?.downcast_into::()?) } else if table.hasattr("_table")? { Ok(table .getattr("_table")? .getattr("_inner")? .downcast_into::
()?) } else { Err(PyRuntimeError::new_err( "Provided table does not appear to be a Table or RemoteTable instance", )) } } /// Create a permutation builder for the given table #[pyo3::pyfunction] pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult { let table = table_from_py(table)?; let inner_table = table.borrow().inner_ref()?.clone(); let inner_builder = LancePermutationBuilder::new(inner_table); Ok(PyAsyncPermutationBuilder { state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState { builder: Some(inner_builder), })), }) } struct PyAsyncPermutationBuilderState { builder: Option, } #[pyclass(name = "AsyncPermutationBuilder")] pub struct PyAsyncPermutationBuilder { state: Arc>, } impl PyAsyncPermutationBuilder { fn modify( &self, func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder, ) -> PyResult { let mut state = self.state.lock().unwrap(); let builder = state .builder .take() .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; state.builder = Some(func(builder)); Ok(Self { state: self.state.clone(), }) } } #[pymethods] impl PyAsyncPermutationBuilder { #[pyo3(signature = (database, table_name))] pub fn persist( slf: PyRefMut<'_, Self>, database: Bound<'_, PyAny>, table_name: String, ) -> PyResult { let conn = if database.hasattr("_conn")? { database .getattr("_conn")? .getattr("_inner")? .downcast_into::()? } else { database.getattr("_inner")?.downcast_into::()? }; let database = conn.borrow().database()?; slf.modify(|builder| builder.persist(database, table_name)) } #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))] pub fn split_random( slf: PyRefMut<'_, Self>, ratios: Option>, counts: Option>, fixed: Option, seed: Option, split_names: Option>, ) -> PyResult { // Check that exactly one split type is provided let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(pyo3::exceptions::PyValueError::new_err( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = counts { SplitSizes::Counts(counts) } else if let Some(fixed) = fixed { SplitSizes::Fixed(fixed) } else { unreachable!("One of the split arguments must be provided"); }; slf.modify(|builder| { builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names) }) } #[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))] pub fn split_hash( slf: PyRefMut<'_, Self>, columns: Vec, split_weights: Vec, discard_weight: u64, split_names: Option>, ) -> PyResult { slf.modify(|builder| { builder.with_split_strategy( SplitStrategy::Hash { columns, split_weights, discard_weight, }, split_names, ) }) } #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))] pub fn split_sequential( slf: PyRefMut<'_, Self>, ratios: Option>, counts: Option>, fixed: Option, split_names: Option>, ) -> PyResult { // Check that exactly one split type is provided let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()] .iter() .filter(|&&x| x) .count(); if split_args_count != 1 { return Err(pyo3::exceptions::PyValueError::new_err( "Exactly one of 'ratios', 'counts', or 'fixed' must be provided", )); } let sizes = if let Some(ratios) = ratios { SplitSizes::Percentages(ratios) } else if let Some(counts) = counts { SplitSizes::Counts(counts) } else if let Some(fixed) = fixed { SplitSizes::Fixed(fixed) } else { unreachable!("One of the split arguments must be provided"); }; slf.modify(|builder| { builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names) }) } pub fn split_calculated( slf: PyRefMut<'_, Self>, calculation: String, split_names: Option>, ) -> PyResult { slf.modify(|builder| { builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names) }) } pub fn shuffle( slf: PyRefMut<'_, Self>, seed: Option, clump_size: Option, ) -> PyResult { slf.modify(|builder| { builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size }) }) } pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult { slf.modify(|builder| builder.with_filter(filter)) } pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult> { let mut state = slf.state.lock().unwrap(); let builder = state .builder .take() .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?; future_into_py(slf.py(), async move { let table = builder.build().await.infer_error()?; Ok(Table::new(table)) }) } } #[pyclass(name = "PermutationReader")] pub struct PyPermutationReader { reader: Arc, } impl PyPermutationReader { fn from_reader(reader: PermutationReader) -> Self { Self { reader: Arc::new(reader), } } fn parse_selection(selection: Option>) -> PyResult