mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-28 19:40:39 +00:00
This PR migrates all Rust crates in the workspace to Rust 2024 edition and addresses the resulting compatibility updates. It also fixes all clippy warnings surfaced by the workspace checks so the codebase remains warning-free under the current lint configuration. Context: - Scope: workspace edition bump (`2021` -> `2024`) plus follow-up refactors required by new edition and clippy rules. - Validation: `cargo fmt --all` and `cargo clippy --quiet --features remote --tests --examples -- -D warnings` both pass.
362 lines
12 KiB
Rust
// SPDX-License-Identifier: Apache-2.0
|
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use crate::{
|
|
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
|
};
|
|
use arrow::pyarrow::{PyArrowType, ToPyArrow};
|
|
use lancedb::{
|
|
dataloader::permutation::{
|
|
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
|
reader::PermutationReader,
|
|
split::{SplitSizes, SplitStrategy},
|
|
},
|
|
query::Select,
|
|
};
|
|
use pyo3::{
|
|
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
|
exceptions::PyRuntimeError,
|
|
pyclass, pymethods,
|
|
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
|
};
|
|
use pyo3_async_runtimes::tokio::future_into_py;
|
|
|
|
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
|
|
if table.hasattr("_inner")? {
|
|
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
|
|
} else if table.hasattr("_table")? {
|
|
Ok(table
|
|
.getattr("_table")?
|
|
.getattr("_inner")?
|
|
.downcast_into::<Table>()?)
|
|
} else {
|
|
Err(PyRuntimeError::new_err(
|
|
"Provided table does not appear to be a Table or RemoteTable instance",
|
|
))
|
|
}
|
|
}
|
|
|
|
/// Create a permutation builder for the given table
|
|
#[pyo3::pyfunction]
|
|
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
|
|
let table = table_from_py(table)?;
|
|
let inner_table = table.borrow().inner_ref()?.clone();
|
|
let inner_builder = LancePermutationBuilder::new(inner_table);
|
|
|
|
Ok(PyAsyncPermutationBuilder {
|
|
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
|
|
builder: Some(inner_builder),
|
|
})),
|
|
})
|
|
}
|
|
|
|
/// Mutable state shared by all Python handles to one permutation builder.
///
/// The builder sits in an `Option` so it can be moved out with `take()` when
/// a method consumes it; `None` means it was already consumed (e.g. by
/// `execute`).
struct PyAsyncPermutationBuilderState {
    builder: Option<LancePermutationBuilder>,
}
|
|
|
|
/// Python-facing permutation builder (`AsyncPermutationBuilder`).
///
/// Wraps the native builder behind `Arc<Mutex<..>>` so builder-style methods
/// can return new Python handles that share the same underlying state.
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
    state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
|
|
|
|
impl PyAsyncPermutationBuilder {
|
|
fn modify(
|
|
&self,
|
|
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
|
|
) -> PyResult<Self> {
|
|
let mut state = self.state.lock().unwrap();
|
|
let builder = state
|
|
.builder
|
|
.take()
|
|
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
|
state.builder = Some(func(builder));
|
|
Ok(Self {
|
|
state: self.state.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[pymethods]
|
|
impl PyAsyncPermutationBuilder {
|
|
#[pyo3(signature = (database, table_name))]
|
|
pub fn persist(
|
|
slf: PyRefMut<'_, Self>,
|
|
database: Bound<'_, PyAny>,
|
|
table_name: String,
|
|
) -> PyResult<Self> {
|
|
let conn = if database.hasattr("_conn")? {
|
|
database
|
|
.getattr("_conn")?
|
|
.getattr("_inner")?
|
|
.downcast_into::<Connection>()?
|
|
} else {
|
|
database.getattr("_inner")?.downcast_into::<Connection>()?
|
|
};
|
|
let database = conn.borrow().database()?;
|
|
slf.modify(|builder| builder.persist(database, table_name))
|
|
}
|
|
|
|
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
|
|
pub fn split_random(
|
|
slf: PyRefMut<'_, Self>,
|
|
ratios: Option<Vec<f64>>,
|
|
counts: Option<Vec<u64>>,
|
|
fixed: Option<u64>,
|
|
seed: Option<u64>,
|
|
split_names: Option<Vec<String>>,
|
|
) -> PyResult<Self> {
|
|
// Check that exactly one split type is provided
|
|
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
|
.iter()
|
|
.filter(|&&x| x)
|
|
.count();
|
|
|
|
if split_args_count != 1 {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
|
));
|
|
}
|
|
|
|
let sizes = if let Some(ratios) = ratios {
|
|
SplitSizes::Percentages(ratios)
|
|
} else if let Some(counts) = counts {
|
|
SplitSizes::Counts(counts)
|
|
} else if let Some(fixed) = fixed {
|
|
SplitSizes::Fixed(fixed)
|
|
} else {
|
|
unreachable!("One of the split arguments must be provided");
|
|
};
|
|
|
|
slf.modify(|builder| {
|
|
builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
|
|
pub fn split_hash(
|
|
slf: PyRefMut<'_, Self>,
|
|
columns: Vec<String>,
|
|
split_weights: Vec<u64>,
|
|
discard_weight: u64,
|
|
split_names: Option<Vec<String>>,
|
|
) -> PyResult<Self> {
|
|
slf.modify(|builder| {
|
|
builder.with_split_strategy(
|
|
SplitStrategy::Hash {
|
|
columns,
|
|
split_weights,
|
|
discard_weight,
|
|
},
|
|
split_names,
|
|
)
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
|
|
pub fn split_sequential(
|
|
slf: PyRefMut<'_, Self>,
|
|
ratios: Option<Vec<f64>>,
|
|
counts: Option<Vec<u64>>,
|
|
fixed: Option<u64>,
|
|
split_names: Option<Vec<String>>,
|
|
) -> PyResult<Self> {
|
|
// Check that exactly one split type is provided
|
|
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
|
.iter()
|
|
.filter(|&&x| x)
|
|
.count();
|
|
|
|
if split_args_count != 1 {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
|
));
|
|
}
|
|
|
|
let sizes = if let Some(ratios) = ratios {
|
|
SplitSizes::Percentages(ratios)
|
|
} else if let Some(counts) = counts {
|
|
SplitSizes::Counts(counts)
|
|
} else if let Some(fixed) = fixed {
|
|
SplitSizes::Fixed(fixed)
|
|
} else {
|
|
unreachable!("One of the split arguments must be provided");
|
|
};
|
|
|
|
slf.modify(|builder| {
|
|
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
|
|
})
|
|
}
|
|
|
|
pub fn split_calculated(
|
|
slf: PyRefMut<'_, Self>,
|
|
calculation: String,
|
|
split_names: Option<Vec<String>>,
|
|
) -> PyResult<Self> {
|
|
slf.modify(|builder| {
|
|
builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
|
|
})
|
|
}
|
|
|
|
pub fn shuffle(
|
|
slf: PyRefMut<'_, Self>,
|
|
seed: Option<u64>,
|
|
clump_size: Option<u64>,
|
|
) -> PyResult<Self> {
|
|
slf.modify(|builder| {
|
|
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
|
|
})
|
|
}
|
|
|
|
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
|
|
slf.modify(|builder| builder.with_filter(filter))
|
|
}
|
|
|
|
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
|
let mut state = slf.state.lock().unwrap();
|
|
let builder = state
|
|
.builder
|
|
.take()
|
|
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
|
|
|
future_into_py(slf.py(), async move {
|
|
let table = builder.build().await.infer_error()?;
|
|
Ok(Table::new(table))
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Python-facing reader (`PermutationReader`) over a permutation of a table.
#[pyclass(name = "PermutationReader")]
pub struct PyPermutationReader {
    // Arc so async method bodies can own a handle beyond `self`'s lifetime.
    reader: Arc<PermutationReader>,
}
|
|
|
|
impl PyPermutationReader {
|
|
fn from_reader(reader: PermutationReader) -> Self {
|
|
Self {
|
|
reader: Arc::new(reader),
|
|
}
|
|
}
|
|
|
|
fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
|
|
let Some(selection) = selection else {
|
|
return Ok(Select::All);
|
|
};
|
|
let selection = selection.downcast_into::<PyDict>()?;
|
|
let selection = selection
|
|
.iter()
|
|
.map(|(key, value)| {
|
|
let key = key.extract::<String>()?;
|
|
let value = value.extract::<String>()?;
|
|
Ok((key, value))
|
|
})
|
|
.collect::<PyResult<Vec<_>>>()?;
|
|
Ok(Select::dynamic(&selection))
|
|
}
|
|
}
|
|
|
|
#[pymethods]
|
|
impl PyPermutationReader {
|
|
#[classmethod]
|
|
pub fn from_tables<'py>(
|
|
cls: &Bound<'py, PyType>,
|
|
base_table: Bound<'py, PyAny>,
|
|
permutation_table: Option<Bound<'py, PyAny>>,
|
|
split: u64,
|
|
) -> PyResult<Bound<'py, PyAny>> {
|
|
let base_table = table_from_py(base_table)?;
|
|
let permutation_table = permutation_table.map(table_from_py).transpose()?;
|
|
|
|
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
|
|
let permutation_table = permutation_table
|
|
.map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
|
|
.transpose()?;
|
|
|
|
future_into_py(cls.py(), async move {
|
|
let reader = if let Some(permutation_table) = permutation_table {
|
|
PermutationReader::try_from_tables(base_table, permutation_table, split)
|
|
.await
|
|
.infer_error()?
|
|
} else {
|
|
PermutationReader::identity(base_table).await
|
|
};
|
|
Ok(Self::from_reader(reader))
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (selection=None))]
|
|
pub fn output_schema<'py>(
|
|
slf: PyRef<'py, Self>,
|
|
selection: Option<Bound<'py, PyAny>>,
|
|
) -> PyResult<Bound<'py, PyAny>> {
|
|
let selection = Self::parse_selection(selection)?;
|
|
let reader = slf.reader.clone();
|
|
future_into_py(slf.py(), async move {
|
|
let schema = reader.output_schema(selection).await.infer_error()?;
|
|
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = ())]
|
|
pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
|
|
slf.reader.count_rows()
|
|
}
|
|
|
|
#[pyo3(signature = (offset))]
|
|
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
|
|
let reader = slf.reader.as_ref().clone();
|
|
future_into_py(slf.py(), async move {
|
|
let reader = reader.with_offset(offset).await.infer_error()?;
|
|
Ok(Self::from_reader(reader))
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (limit))]
|
|
pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
|
|
let reader = slf.reader.as_ref().clone();
|
|
future_into_py(slf.py(), async move {
|
|
let reader = reader.with_limit(limit).await.infer_error()?;
|
|
Ok(Self::from_reader(reader))
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (selection=None, *, batch_size=None))]
|
|
pub fn read<'py>(
|
|
slf: PyRef<'py, Self>,
|
|
selection: Option<Bound<'py, PyAny>>,
|
|
batch_size: Option<u32>,
|
|
) -> PyResult<Bound<'py, PyAny>> {
|
|
let selection = Self::parse_selection(selection)?;
|
|
let reader = slf.reader.clone();
|
|
let batch_size = batch_size.unwrap_or(1024);
|
|
future_into_py(slf.py(), async move {
|
|
use lancedb::query::QueryExecutionOptions;
|
|
let mut execution_options = QueryExecutionOptions::default();
|
|
execution_options.max_batch_length = batch_size;
|
|
let stream = reader
|
|
.read(selection, execution_options)
|
|
.await
|
|
.infer_error()?;
|
|
Ok(RecordBatchStream::new(stream))
|
|
})
|
|
}
|
|
|
|
#[pyo3(signature = (indices, *, selection=None))]
|
|
pub fn take_offsets<'py>(
|
|
slf: PyRef<'py, Self>,
|
|
indices: Vec<u64>,
|
|
selection: Option<Bound<'py, PyAny>>,
|
|
) -> PyResult<Bound<'py, PyAny>> {
|
|
let selection = Self::parse_selection(selection)?;
|
|
let reader = slf.reader.clone();
|
|
future_into_py(slf.py(), async move {
|
|
let batch = reader
|
|
.take_offsets(&indices, selection)
|
|
.await
|
|
.infer_error()?;
|
|
Ok(PyArrowType(batch))
|
|
})
|
|
}
|
|
}
|