feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)

This commit is contained in:
Weston Pace
2025-11-06 16:15:33 -08:00
committed by GitHub
parent 6ddd271627
commit aeac9c7644
24 changed files with 2071 additions and 126 deletions

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
@@ -25,6 +25,8 @@ use crate::{
pub const SRC_ROW_ID_COL: &str = "row_id";
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
@@ -40,6 +42,8 @@ enum PermutationDestination {
pub struct PermutationConfig {
/// Splitting configuration
split_strategy: SplitStrategy,
/// Optional names for the splits
split_names: Option<Vec<String>>,
/// Shuffle strategy
shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
@@ -112,8 +116,16 @@ impl PermutationBuilder {
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
///
/// An optional list of names can be provided for the splits. This is for convenience and the names
/// will be stored in the permutation table's config metadata.
pub fn with_split_strategy(
mut self,
split_strategy: SplitStrategy,
split_names: Option<Vec<String>>,
) -> Self {
self.config.split_strategy = split_strategy;
self.config.split_names = split_names;
self
}
@@ -193,6 +205,30 @@ impl PermutationBuilder {
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
fn add_split_names(
data: SendableRecordBatchStream,
split_names: &[String],
) -> Result<SendableRecordBatchStream> {
let schema = data
.schema()
.as_ref()
.clone()
.with_metadata(HashMap::from([(
SPLIT_NAMES_CONFIG_KEY.to_string(),
serde_json::to_string(split_names).map_err(|e| Error::Other {
message: format!("Failed to serialize split names: {}", e),
source: Some(e.into()),
})?,
)]));
let schema = Arc::new(schema);
let schema_clone = schema.clone();
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
Ok(Box::pin(SimpleRecordBatchStream {
schema: schema_clone,
stream,
}))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
@@ -249,6 +285,12 @@ impl PermutationBuilder {
// Rename _rowid to row_id
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
let streaming_data = if let Some(split_names) = &self.config.split_names {
Self::add_split_names(renamed, split_names)?
} else {
renamed
};
let (name, database) = match &self.config.destination {
PermutationDestination::Permanent(database, table_name) => {
(table_name.as_str(), database.clone())
@@ -259,10 +301,13 @@ impl PermutationBuilder {
}
};
let create_table_request =
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
let create_table_request = CreateTableRequest::new(
name.to_string(),
CreateTableData::StreamingData(streaming_data),
);
let table = database.create_table(create_table_request).await?;
Ok(Table::new(table, database))
}
}
@@ -296,10 +341,13 @@ mod tests {
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.with_split_strategy(
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
},
None,
)
.build()
.await
.unwrap();

View File

@@ -11,14 +11,19 @@ use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
use crate::error::Error;
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
use crate::table::{AnyQuery, BaseTable};
use crate::Result;
use crate::query::{
ExecutableQuery, QueryBase, QueryExecutionOptions, QueryFilter, QueryRequest, Select,
};
use crate::table::{AnyQuery, BaseTable, Filter};
use crate::{Result, Table};
use arrow::array::AsArray;
use arrow::compute::concat_batches;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use futures::{StreamExt, TryStreamExt};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::io::RecordBatchStream;
use lance_arrow::RecordBatchExt;
use lance_core::error::LanceOptionExt;
use lance_core::ROW_ID;
@@ -26,43 +31,140 @@ use std::collections::HashMap;
use std::sync::Arc;
/// Reads a permutation of a source table based on row IDs stored in a separate table
#[derive(Clone)]
pub struct PermutationReader {
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
permutation_table: Option<Arc<dyn BaseTable>>,
offset: Option<u64>,
limit: Option<u64>,
available_rows: u64,
split: u64,
}
impl std::fmt::Debug for PermutationReader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PermutationReader(base={}, permutation={})",
"PermutationReader(base={}, permutation={}, split={}, offset={:?}, limit={:?})",
self.base_table.name(),
self.permutation_table.name(),
self.permutation_table
.as_ref()
.map(|t| t.name())
.unwrap_or("--"),
self.split,
self.offset,
self.limit,
)
}
}
impl PermutationReader {
/// Create a new PermutationReader
pub async fn try_new(
pub async fn inner_new(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
permutation_table: Option<Arc<dyn BaseTable>>,
split: u64,
) -> Result<Self> {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
Ok(Self {
let mut slf = Self {
base_table,
permutation_table,
})
offset: None,
limit: None,
available_rows: 0,
split,
};
slf.validate().await?;
// Calculate the number of available rows
slf.available_rows = slf.verify_limit_offset(None, None).await?;
if slf.available_rows == 0 {
return Err(Error::InvalidInput {
message: "No rows found in the permutation table for the given split".to_string(),
});
}
Ok(slf)
}
pub async fn try_from_tables(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
split: u64,
) -> Result<Self> {
Self::inner_new(base_table, Some(permutation_table), split).await
}
pub async fn identity(base_table: Arc<dyn BaseTable>) -> Self {
Self::inner_new(base_table, None, 0).await.unwrap()
}
/// Validates the limit and offset and returns the number of rows that will be read
fn validate_limit_offset(
limit: Option<u64>,
offset: Option<u64>,
available_rows: u64,
) -> Result<u64> {
match (limit, offset) {
(Some(limit), Some(offset)) => {
if offset + limit > available_rows {
Err(Error::InvalidInput {
message: "Offset + limit is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(limit)
}
}
(None, Some(offset)) => {
if offset > available_rows {
Err(Error::InvalidInput {
message:
"Offset is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(available_rows - offset)
}
}
(Some(limit), None) => {
if limit > available_rows {
Err(Error::InvalidInput {
message:
"Limit is greater than the number of rows in the permutation table"
.to_string(),
})
} else {
Ok(limit)
}
}
(None, None) => Ok(available_rows),
}
}
async fn verify_limit_offset(&self, limit: Option<u64>, offset: Option<u64>) -> Result<u64> {
let available_rows = if let Some(permutation_table) = &self.permutation_table {
permutation_table
.count_rows(Some(Filter::Sql(format!(
"{} = {}",
SPLIT_ID_COLUMN, self.split
))))
.await? as u64
} else {
self.base_table.count_rows(None).await? as u64
};
Self::validate_limit_offset(limit, offset, available_rows)
}
pub async fn with_offset(mut self, offset: u64) -> Result<Self> {
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
self.offset = Some(offset);
self.available_rows = available_rows;
Ok(self)
}
pub async fn with_limit(mut self, limit: u64) -> Result<Self> {
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
self.available_rows = available_rows;
self.limit = Some(limit);
Ok(self)
}
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
@@ -103,7 +205,7 @@ impl PermutationReader {
..Default::default()
};
let mut data = base_table
let data = base_table
.query(
&AnyQuery::Query(base_query),
QueryExecutionOptions {
@@ -112,25 +214,29 @@ impl PermutationReader {
},
)
.await?;
let schema = data.schema();
let Some(batch) = data.try_next().await? else {
let batches = data.try_collect::<Vec<_>>().await?;
if batches.is_empty() {
return Err(Error::InvalidInput {
message: "Base table returned no batches".to_string(),
});
};
if data.try_next().await?.is_some() {
return Err(Error::InvalidInput {
message: "Base table returned more than one batch".to_string(),
});
}
if batch.num_rows() != num_rows {
if batches.iter().map(|b| b.num_rows()).sum::<usize>() != num_rows {
return Err(Error::InvalidInput {
message: "Base table returned different number of rows than the number of row IDs"
.to_string(),
});
}
let batch = if batches.len() == 1 {
batches.into_iter().next().unwrap()
} else {
concat_batches(&schema, &batches)?
};
// There is no guarantee the result order will match the order provided
// so may need to restore order
let actual_row_ids = batch
@@ -230,26 +336,75 @@ impl PermutationReader {
}
}
pub async fn read_split(
async fn validate(&self) -> Result<()> {
if let Some(permutation_table) = &self.permutation_table {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
}
let avail_rows = if let Some(permutation_table) = &self.permutation_table {
permutation_table.count_rows(None).await? as u64
} else {
self.base_table.count_rows(None).await? as u64
};
Self::validate_limit_offset(self.limit, self.offset, avail_rows)?;
Ok(())
}
pub async fn read(
&self,
split: u64,
selection: Select,
execution_options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let row_ids = self
.permutation_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
..Default::default()
}),
execution_options,
)
.await?;
// Note: this relies on the row ids query here being returned in consistent order
let row_ids = if let Some(permutation_table) = &self.permutation_table {
permutation_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
filter: Some(QueryFilter::Sql(format!(
"{} = {}",
SPLIT_ID_COLUMN, self.split
))),
offset: self.offset.map(|o| o as usize),
limit: self.limit.map(|l| l as usize),
..Default::default()
}),
execution_options,
)
.await?
} else {
self.base_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![ROW_ID.to_string()]),
offset: self.offset.map(|o| o as usize),
limit: self.limit.map(|l| l as usize),
..Default::default()
}),
execution_options,
)
.await?
};
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
let table = Table::from(self.base_table.clone());
table.query().select(selection).output_schema().await
}
pub fn count_rows(&self) -> u64 {
self.available_rows
}
}
#[cfg(test)]
@@ -321,17 +476,17 @@ mod tests {
.unwrap();
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
let reader = PermutationReader::try_new(
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// Read split 0
let mut stream = reader
.read_split(
0,
.read(
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
@@ -366,9 +521,16 @@ mod tests {
assert!(stream.try_next().await.unwrap().is_none());
// Read split 1
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
1,
)
.await
.unwrap();
let mut stream = reader
.read_split(
1,
.read(
Select::All,
QueryExecutionOptions {
max_batch_length: 3,

View File

@@ -34,7 +34,7 @@ pub(crate) const DEFAULT_TOP_K: usize = 10;
/// Which columns should be retrieved from the database
#[derive(Debug, Clone)]
pub enum Select {
/// Select all columns
/// Select all non-system columns
///
/// Warning: This will always be slower than selecting only the columns you need.
All,

View File

@@ -620,7 +620,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[derive(Clone, Debug)]
pub struct Table {
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
database: Option<Arc<dyn Database>>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -644,7 +644,7 @@ mod test_utils {
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
database: Some(database),
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -666,7 +666,7 @@ mod test_utils {
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
database: Some(database),
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -680,11 +680,21 @@ impl std::fmt::Display for Table {
}
}
impl From<Arc<dyn BaseTable>> for Table {
fn from(inner: Arc<dyn BaseTable>) -> Self {
Self {
inner,
database: None,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
}
impl Table {
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
Self {
inner,
database,
database: Some(database),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -694,7 +704,7 @@ impl Table {
}
pub fn database(&self) -> &Arc<dyn Database> {
&self.database
self.database.as_ref().unwrap()
}
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
@@ -708,7 +718,7 @@ impl Table {
) -> Self {
Self {
inner,
database,
database: Some(database),
embedding_registry,
}
}