mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -6,7 +6,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
database::{CreateTableMode, ReadConsistency},
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
@@ -42,6 +42,10 @@ impl Connection {
|
||||
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn database(&self) -> PyResult<Arc<dyn Database>> {
|
||||
Ok(self.get_inner()?.database().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
||||
@@ -5,7 +5,7 @@ use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::IndexConfig;
|
||||
use permutation::PyAsyncPermutationBuilder;
|
||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
@@ -52,6 +52,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<DropColumnsResult>()?;
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_class::<PyPermutationReader>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
|
||||
@@ -3,14 +3,23 @@
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
permutation::split::{SplitSizes, SplitStrategy},
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use lancedb::{
|
||||
dataloader::permutation::{
|
||||
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
reader::PermutationReader,
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
},
|
||||
query::Select,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
||||
PyResult,
|
||||
exceptions::PyRuntimeError,
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
@@ -56,13 +65,32 @@ impl PyAsyncPermutationBuilder {
|
||||
|
||||
#[pymethods]
|
||||
impl PyAsyncPermutationBuilder {
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
|
||||
#[pyo3(signature = (database, table_name))]
|
||||
pub fn persist(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
database: Bound<'_, PyAny>,
|
||||
table_name: String,
|
||||
) -> PyResult<Self> {
|
||||
let conn = if database.hasattr("_conn")? {
|
||||
database
|
||||
.getattr("_conn")?
|
||||
.getattr("_inner")?
|
||||
.downcast_into::<Connection>()?
|
||||
} else {
|
||||
database.getattr("_inner")?.downcast_into::<Connection>()?
|
||||
};
|
||||
let database = conn.borrow().database()?;
|
||||
slf.modify(|builder| builder.persist(database, table_name))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
|
||||
pub fn split_random(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
seed: Option<u64>,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
@@ -86,31 +114,38 @@ impl PyAsyncPermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
|
||||
pub fn split_hash(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
split_weights: Vec<u64>,
|
||||
discard_weight: u64,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
},
|
||||
split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
|
||||
pub fn split_sequential(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
@@ -134,11 +169,19 @@ impl PyAsyncPermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
|
||||
pub fn split_calculated(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
calculation: String,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn shuffle(
|
||||
@@ -168,3 +211,121 @@ impl PyAsyncPermutationBuilder {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(name = "PermutationReader")]
|
||||
pub struct PyPermutationReader {
|
||||
reader: Arc<PermutationReader>,
|
||||
}
|
||||
|
||||
impl PyPermutationReader {
|
||||
fn from_reader(reader: PermutationReader) -> Self {
|
||||
Self {
|
||||
reader: Arc::new(reader),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
|
||||
let Some(selection) = selection else {
|
||||
return Ok(Select::All);
|
||||
};
|
||||
let selection = selection.downcast_into::<PyDict>()?;
|
||||
let selection = selection
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let key = key.extract::<String>()?;
|
||||
let value = value.extract::<String>()?;
|
||||
Ok((key, value))
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
Ok(Select::dynamic(&selection))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyPermutationReader {
|
||||
#[classmethod]
|
||||
pub fn from_tables<'py>(
|
||||
cls: &Bound<'py, PyType>,
|
||||
base_table: Bound<'py, PyAny>,
|
||||
permutation_table: Option<Bound<'py, PyAny>>,
|
||||
split: u64,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let permutation_table = permutation_table
|
||||
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
|
||||
.transpose()?;
|
||||
|
||||
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
|
||||
let permutation_table = permutation_table
|
||||
.map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
|
||||
.transpose()?;
|
||||
|
||||
future_into_py(cls.py(), async move {
|
||||
let reader = if let Some(permutation_table) = permutation_table {
|
||||
PermutationReader::try_from_tables(base_table, permutation_table, split)
|
||||
.await
|
||||
.infer_error()?
|
||||
} else {
|
||||
PermutationReader::identity(base_table).await
|
||||
};
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (selection=None))]
|
||||
pub fn output_schema<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let schema = reader.output_schema(selection).await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
|
||||
slf.reader.count_rows()
|
||||
}
|
||||
|
||||
#[pyo3(signature = (offset))]
|
||||
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.as_ref().clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let reader = reader.with_offset(offset).await.infer_error()?;
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (limit))]
|
||||
pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.as_ref().clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let reader = reader.with_limit(limit).await.infer_error()?;
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (selection=None, *, batch_size=None))]
|
||||
pub fn read<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
batch_size: Option<u32>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
let batch_size = batch_size.unwrap_or(1024);
|
||||
future_into_py(slf.py(), async move {
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
let mut execution_options = QueryExecutionOptions::default();
|
||||
execution_options.max_batch_length = batch_size;
|
||||
let stream = reader
|
||||
.read(selection, execution_options)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(RecordBatchStream::new(stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user