Mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-08 21:02:58 +00:00)
feat: a utility for creating "permutation views" (#2552)
I'm working on a LanceDB version of PyTorch data loading (and hopefully addressing https://github.com/lancedb/lance/issues/3727). However, rather than relying on PyTorch for everything, I'm moving some of the work PyTorch normally does into Rust. This gives us more control over data loading (e.g. using shards or a hash-based split) and it allows permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
  * The permutation can handle splits, filtering, shuffling, and sharding
* Create a Rust data loader that can read a permutation (one or more splits), or a subset of a permutation (for DDP)
* Create a Python data loader that delegates to the Rust data loader

Eventually, create integrations for other data loading libraries, including Rust & Node.
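To make the intended workflow concrete, here is a minimal usage sketch of the new binding from Python. Only the binding-level names (async_permutation_builder, the builder methods, and the `_inner` attribute it looks up) come from this diff; the high-level wrapper calls (lancedb.connect_async, open_table), the import path, and the table/column names are assumptions about how it will be wired up.

import asyncio

import lancedb  # high-level wrapper usage here is an assumption
from lancedb._lancedb import async_permutation_builder  # registered in lib.rs below


async def main():
    db = await lancedb.connect_async("./data")
    tbl = await db.open_table("images")  # the binding reads `tbl._inner`

    # Build a persistent permutation view of the table: split, shuffle, filter.
    builder = async_permutation_builder(tbl, "images_permutation")
    builder = builder.split_random(ratios=[0.8, 0.1, 0.1], seed=42)
    builder = builder.shuffle(7, None)  # seed=7, clump_size=None
    builder = builder.filter("label IS NOT NULL")

    # execute() consumes the builder and materializes the permutation as a new table.
    permutation_table = await builder.execute()


asyncio.run(main())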
python/src/connection.rs
@@ -4,7 +4,10 @@
 use std::{collections::HashMap, sync::Arc, time::Duration};
 
 use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
-use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
+use lancedb::{
+    connection::Connection as LanceConnection,
+    database::{CreateTableMode, ReadConsistency},
+};
 use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
     pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
         Self { inner: Some(inner) }
     }
 
-    fn get_inner(&self) -> PyResult<&LanceConnection> {
+    pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
         self.inner
             .as_ref()
             .ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
         self.get_inner().map(|inner| inner.uri().to_string())
     }
 
+    #[pyo3(signature = ())]
+    pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.get_inner()?.clone();
+        future_into_py(self_.py(), async move {
+            Ok(match inner.read_consistency().await.infer_error()? {
+                ReadConsistency::Manual => None,
+                ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
+                ReadConsistency::Strong => Some(0.0_f64),
+            })
+        })
+    }
+
     #[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
     pub fn table_names(
         self_: PyRef<'_, Self>,
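The new getter flattens the ReadConsistency enum into an optional number of seconds for Python callers. A small sketch of how a caller might interpret the result; the method itself comes from the binding above, while the `conn` handle and the surrounding wrapper are assumptions.

async def describe_consistency(conn):
    # Mapping used by the binding above:
    #   None  -> ReadConsistency::Manual   (versions are only refreshed explicitly)
    #   0.0   -> ReadConsistency::Strong   (latest version checked on every read)
    #   > 0.0 -> ReadConsistency::Eventual (refreshed roughly every N seconds)
    interval = await conn.get_read_consistency_interval()
    if interval is None:
        print("manual read consistency")
    elif interval == 0.0:
        print("strong read consistency")
    else:
        print(f"eventually consistent, refreshed every {interval:.1f}s")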
python/src/lib.rs
@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
 use connection::{connect, Connection};
 use env_logger::Env;
 use index::IndexConfig;
+use permutation::PyAsyncPermutationBuilder;
 use pyo3::{
     pymodule,
     types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
 pub mod error;
 pub mod header;
 pub mod index;
+pub mod permutation;
 pub mod query;
 pub mod session;
 pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<DeleteResult>()?;
     m.add_class::<DropColumnsResult>()?;
     m.add_class::<UpdateResult>()?;
+    m.add_class::<PyAsyncPermutationBuilder>()?;
     m.add_function(wrap_pyfunction!(connect, m)?)?;
+    m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
     m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
     Ok(())
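With the class and factory function registered on the extension module, both should be importable directly. A quick sketch; the lancedb._lancedb import path is an assumption based on the package layout, and `tbl` stands in for an open table wrapper exposing `_inner`.

# #[pyclass(name = "AsyncPermutationBuilder")] controls the Python-visible class name.
from lancedb._lancedb import AsyncPermutationBuilder, async_permutation_builder

builder = async_permutation_builder(tbl, "perm_view")
assert isinstance(builder, AsyncPermutationBuilder)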
python/src/permutation.rs (new file, 177 lines)
@@ -0,0 +1,177 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::{Arc, Mutex};

use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
    permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    split::{SplitSizes, SplitStrategy},
};
use pyo3::{
    exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
    PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;

/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
    table: Bound<'_, PyAny>,
    dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
    let table = table.getattr("_inner")?.downcast_into::<Table>()?;
    let inner_table = table.borrow().inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);

    Ok(PyAsyncPermutationBuilder {
        state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
            builder: Some(inner_builder),
            dest_table_name,
        })),
    })
}

struct PyAsyncPermutationBuilderState {
    builder: Option<LancePermutationBuilder>,
    dest_table_name: String,
}

#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
    state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}

impl PyAsyncPermutationBuilder {
    fn modify(
        &self,
        func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
    ) -> PyResult<Self> {
        let mut state = self.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
        state.builder = Some(func(builder));
        Ok(Self {
            state: self.state.clone(),
        })
    }
}

#[pymethods]
impl PyAsyncPermutationBuilder {
    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
    pub fn split_random(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
        seed: Option<u64>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
    }

    #[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
    pub fn split_hash(
        slf: PyRefMut<'_, Self>,
        columns: Vec<String>,
        split_weights: Vec<u64>,
        discard_weight: u64,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Hash {
                columns,
                split_weights,
                discard_weight,
            })
        })
    }

    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
    pub fn split_sequential(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
    }

    pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
        slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
    }

    pub fn shuffle(
        slf: PyRefMut<'_, Self>,
        seed: Option<u64>,
        clump_size: Option<u64>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
        })
    }

    pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
        slf.modify(|builder| builder.with_filter(filter))
    }

    pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let mut state = slf.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;

        let dest_table_name = std::mem::take(&mut state.dest_table_name);

        future_into_py(slf.py(), async move {
            let table = builder.build(&dest_table_name).await.infer_error()?;
            Ok(Table::new(table))
        })
    }
}
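A few behavioral notes fall out of the code above: split_random and split_sequential require exactly one of ratios, counts, or fixed; every builder method goes through modify(), which takes the Rust builder out of the shared state and puts the updated one back; and execute() takes it permanently, so later calls fail. A hedged Python sketch of that behavior (the import path and the `tbl` wrapper are assumptions):

from lancedb._lancedb import async_permutation_builder  # import path: assumption


async def build_permutation(tbl):
    # `tbl` is assumed to be an async table wrapper exposing the `_inner` pyclass.
    builder = async_permutation_builder(tbl, "train_test_view")

    # Exactly one of ratios / counts / fixed may be passed (keyword-only).
    try:
        builder.split_random(ratios=[0.5, 0.5], fixed=100)
    except ValueError as err:
        print(err)  # Exactly one of 'ratios', 'counts', or 'fixed' must be provided

    # Chained calls all update the same shared state via modify().
    builder = builder.split_random(ratios=[0.8, 0.2], seed=0).shuffle(0, None)

    table = await builder.execute()

    # execute() consumed the underlying Rust builder, so reusing the handle fails.
    try:
        builder.filter("label IS NOT NULL")
    except RuntimeError as err:
        print(err)  # Builder already consumed

    return table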
python/src/table.rs
@@ -3,6 +3,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use crate::{
+    connection::Connection,
     error::PythonErrorExt,
     index::{extract_index_params, IndexConfig},
     query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
 }
 
 impl Table {
-    fn inner_ref(&self) -> PyResult<&LanceDbTable> {
+    pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
         self.inner
             .as_ref()
             .ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
         self.inner.take();
     }
 
+    pub fn database(&self) -> PyResult<Connection> {
+        let inner = self.inner_ref()?.clone();
+        let inner_connection =
+            lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
+        Ok(Connection::new(inner_connection))
+    }
+
     pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {