feat: a utility for creating "permutation views" (#2552)

I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and it allows
permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
* This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually create integrations for other data loading libraries,
including rust & node
This commit is contained in:
Weston Pace
2025-10-09 18:07:31 -07:00
committed by GitHub
parent 3dcec724b7
commit 5a19cf15a6
38 changed files with 3786 additions and 58 deletions

View File

@@ -4,7 +4,10 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
Self { inner: Some(inner) }
}
fn get_inner(&self) -> PyResult<&LanceConnection> {
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = ())]
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
Ok(match inner.read_consistency().await.infer_error()? {
ReadConsistency::Manual => None,
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
ReadConsistency::Strong => Some(0.0_f64),
})
})
}
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,

View File

@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DeleteResult>()?;
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

177
python/src/permutation.rs Normal file
View File

@@ -0,0 +1,177 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
table: Bound<'_, PyAny>,
dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
dest_table_name,
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
dest_table_name: String,
}
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
impl PyAsyncPermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> PyResult<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
}
pub fn shuffle(
slf: PyRefMut<'_, Self>,
seed: Option<u64>,
clump_size: Option<u64>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_filter(filter))
}
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let mut state = slf.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
future_into_py(slf.py(), async move {
let table = builder.build(&dest_table_name).await.infer_error()?;
Ok(Table::new(table))
})
}
}

View File

@@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
}
impl Table {
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
self.inner.take();
}
pub fn database(&self) -> PyResult<Connection> {
let inner = self.inner_ref()?.clone();
let inner_connection =
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
Ok(Connection::new(inner_connection))
}
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {