mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-28 17:30:42 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||
@@ -25,6 +25,8 @@ use crate::{
|
||||
|
||||
pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||
|
||||
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
||||
|
||||
/// Where to store the permutation table
|
||||
#[derive(Debug, Clone, Default)]
|
||||
enum PermutationDestination {
|
||||
@@ -40,6 +42,8 @@ enum PermutationDestination {
|
||||
pub struct PermutationConfig {
|
||||
/// Splitting configuration
|
||||
split_strategy: SplitStrategy,
|
||||
/// Optional names for the splits
|
||||
split_names: Option<Vec<String>>,
|
||||
/// Shuffle strategy
|
||||
shuffle_strategy: ShuffleStrategy,
|
||||
/// Optional filter to apply to the base table
|
||||
@@ -112,8 +116,16 @@ impl PermutationBuilder {
|
||||
/// multiple processes and multiple nodes.
|
||||
///
|
||||
/// The default is a single split that contains all rows.
|
||||
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
|
||||
///
|
||||
/// An optional list of names can be provided for the splits. This is for convenience and the names
|
||||
/// will be stored in the permutation table's config metadata.
|
||||
pub fn with_split_strategy(
|
||||
mut self,
|
||||
split_strategy: SplitStrategy,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> Self {
|
||||
self.config.split_strategy = split_strategy;
|
||||
self.config.split_names = split_names;
|
||||
self
|
||||
}
|
||||
|
||||
@@ -193,6 +205,30 @@ impl PermutationBuilder {
|
||||
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
||||
}
|
||||
|
||||
fn add_split_names(
|
||||
data: SendableRecordBatchStream,
|
||||
split_names: &[String],
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let schema = data
|
||||
.schema()
|
||||
.as_ref()
|
||||
.clone()
|
||||
.with_metadata(HashMap::from([(
|
||||
SPLIT_NAMES_CONFIG_KEY.to_string(),
|
||||
serde_json::to_string(split_names).map_err(|e| Error::Other {
|
||||
message: format!("Failed to serialize split names: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?,
|
||||
)]));
|
||||
let schema = Arc::new(schema);
|
||||
let schema_clone = schema.clone();
|
||||
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
|
||||
Ok(Box::pin(SimpleRecordBatchStream {
|
||||
schema: schema_clone,
|
||||
stream,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Builds the permutation table and stores it in the given database.
|
||||
pub async fn build(self) -> Result<Table> {
|
||||
// First pass, apply filter and load row ids
|
||||
@@ -249,6 +285,12 @@ impl PermutationBuilder {
|
||||
// Rename _rowid to row_id
|
||||
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
|
||||
|
||||
let streaming_data = if let Some(split_names) = &self.config.split_names {
|
||||
Self::add_split_names(renamed, split_names)?
|
||||
} else {
|
||||
renamed
|
||||
};
|
||||
|
||||
let (name, database) = match &self.config.destination {
|
||||
PermutationDestination::Permanent(database, table_name) => {
|
||||
(table_name.as_str(), database.clone())
|
||||
@@ -259,10 +301,13 @@ impl PermutationBuilder {
|
||||
}
|
||||
};
|
||||
|
||||
let create_table_request =
|
||||
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
|
||||
let create_table_request = CreateTableRequest::new(
|
||||
name.to_string(),
|
||||
CreateTableData::StreamingData(streaming_data),
|
||||
);
|
||||
|
||||
let table = database.create_table(create_table_request).await?;
|
||||
|
||||
Ok(Table::new(table, database))
|
||||
}
|
||||
}
|
||||
@@ -296,10 +341,13 @@ mod tests {
|
||||
|
||||
let permutation_table = PermutationBuilder::new(data_table.clone())
|
||||
.with_filter("some_value > 57".to_string())
|
||||
.with_split_strategy(SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||
})
|
||||
.with_split_strategy(
|
||||
SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -11,14 +11,19 @@ use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
|
||||
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
|
||||
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
|
||||
use crate::error::Error;
|
||||
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
|
||||
use crate::table::{AnyQuery, BaseTable};
|
||||
use crate::Result;
|
||||
use crate::query::{
|
||||
ExecutableQuery, QueryBase, QueryExecutionOptions, QueryFilter, QueryRequest, Select,
|
||||
};
|
||||
use crate::table::{AnyQuery, BaseTable, Filter};
|
||||
use crate::{Result, Table};
|
||||
use arrow::array::AsArray;
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow::datatypes::UInt64Type;
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_schema::SchemaRef;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::io::RecordBatchStream;
|
||||
use lance_arrow::RecordBatchExt;
|
||||
use lance_core::error::LanceOptionExt;
|
||||
use lance_core::ROW_ID;
|
||||
@@ -26,43 +31,140 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Reads a permutation of a source table based on row IDs stored in a separate table
|
||||
#[derive(Clone)]
|
||||
pub struct PermutationReader {
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||
offset: Option<u64>,
|
||||
limit: Option<u64>,
|
||||
available_rows: u64,
|
||||
split: u64,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PermutationReader {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"PermutationReader(base={}, permutation={})",
|
||||
"PermutationReader(base={}, permutation={}, split={}, offset={:?}, limit={:?})",
|
||||
self.base_table.name(),
|
||||
self.permutation_table.name(),
|
||||
self.permutation_table
|
||||
.as_ref()
|
||||
.map(|t| t.name())
|
||||
.unwrap_or("--"),
|
||||
self.split,
|
||||
self.offset,
|
||||
self.limit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PermutationReader {
|
||||
/// Create a new PermutationReader
|
||||
pub async fn try_new(
|
||||
pub async fn inner_new(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||
split: u64,
|
||||
) -> Result<Self> {
|
||||
let schema = permutation_table.schema().await?;
|
||||
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named row_id".to_string(),
|
||||
});
|
||||
}
|
||||
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named split_id".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(Self {
|
||||
let mut slf = Self {
|
||||
base_table,
|
||||
permutation_table,
|
||||
})
|
||||
offset: None,
|
||||
limit: None,
|
||||
available_rows: 0,
|
||||
split,
|
||||
};
|
||||
slf.validate().await?;
|
||||
// Calculate the number of available rows
|
||||
slf.available_rows = slf.verify_limit_offset(None, None).await?;
|
||||
if slf.available_rows == 0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "No rows found in the permutation table for the given split".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(slf)
|
||||
}
|
||||
|
||||
pub async fn try_from_tables(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
split: u64,
|
||||
) -> Result<Self> {
|
||||
Self::inner_new(base_table, Some(permutation_table), split).await
|
||||
}
|
||||
|
||||
pub async fn identity(base_table: Arc<dyn BaseTable>) -> Self {
|
||||
Self::inner_new(base_table, None, 0).await.unwrap()
|
||||
}
|
||||
|
||||
/// Validates the limit and offset and returns the number of rows that will be read
|
||||
fn validate_limit_offset(
|
||||
limit: Option<u64>,
|
||||
offset: Option<u64>,
|
||||
available_rows: u64,
|
||||
) -> Result<u64> {
|
||||
match (limit, offset) {
|
||||
(Some(limit), Some(offset)) => {
|
||||
if offset + limit > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message: "Offset + limit is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(limit)
|
||||
}
|
||||
}
|
||||
(None, Some(offset)) => {
|
||||
if offset > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message:
|
||||
"Offset is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(available_rows - offset)
|
||||
}
|
||||
}
|
||||
(Some(limit), None) => {
|
||||
if limit > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message:
|
||||
"Limit is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(limit)
|
||||
}
|
||||
}
|
||||
(None, None) => Ok(available_rows),
|
||||
}
|
||||
}
|
||||
|
||||
async fn verify_limit_offset(&self, limit: Option<u64>, offset: Option<u64>) -> Result<u64> {
|
||||
let available_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table
|
||||
.count_rows(Some(Filter::Sql(format!(
|
||||
"{} = {}",
|
||||
SPLIT_ID_COLUMN, self.split
|
||||
))))
|
||||
.await? as u64
|
||||
} else {
|
||||
self.base_table.count_rows(None).await? as u64
|
||||
};
|
||||
Self::validate_limit_offset(limit, offset, available_rows)
|
||||
}
|
||||
|
||||
/// Sets the number of leading rows to skip.
///
/// The offset is checked together with any limit already set. On success the
/// stored row count is updated to the number of rows that will be read.
pub async fn with_offset(mut self, offset: u64) -> Result<Self> {
    self.available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
    self.offset = Some(offset);
    Ok(self)
}
|
||||
|
||||
/// Sets the maximum number of rows to read.
///
/// The limit is checked together with any offset already set. On success the
/// stored row count is updated to the number of rows that will be read.
pub async fn with_limit(mut self, limit: u64) -> Result<Self> {
    self.available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
    self.limit = Some(limit);
    Ok(self)
}
|
||||
|
||||
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
|
||||
@@ -103,7 +205,7 @@ impl PermutationReader {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut data = base_table
|
||||
let data = base_table
|
||||
.query(
|
||||
&AnyQuery::Query(base_query),
|
||||
QueryExecutionOptions {
|
||||
@@ -112,25 +214,29 @@ impl PermutationReader {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let schema = data.schema();
|
||||
|
||||
let Some(batch) = data.try_next().await? else {
|
||||
let batches = data.try_collect::<Vec<_>>().await?;
|
||||
|
||||
if batches.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned no batches".to_string(),
|
||||
});
|
||||
};
|
||||
if data.try_next().await?.is_some() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned more than one batch".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if batch.num_rows() != num_rows {
|
||||
if batches.iter().map(|b| b.num_rows()).sum::<usize>() != num_rows {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned different number of rows than the number of row IDs"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let batch = if batches.len() == 1 {
|
||||
batches.into_iter().next().unwrap()
|
||||
} else {
|
||||
concat_batches(&schema, &batches)?
|
||||
};
|
||||
|
||||
// There is no guarantee the result order will match the order provided
|
||||
// so may need to restore order
|
||||
let actual_row_ids = batch
|
||||
@@ -230,26 +336,75 @@ impl PermutationReader {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_split(
|
||||
async fn validate(&self) -> Result<()> {
|
||||
if let Some(permutation_table) = &self.permutation_table {
|
||||
let schema = permutation_table.schema().await?;
|
||||
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named row_id".to_string(),
|
||||
});
|
||||
}
|
||||
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named split_id".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
let avail_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table.count_rows(None).await? as u64
|
||||
} else {
|
||||
self.base_table.count_rows(None).await? as u64
|
||||
};
|
||||
Self::validate_limit_offset(self.limit, self.offset, avail_rows)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read(
|
||||
&self,
|
||||
split: u64,
|
||||
selection: Select,
|
||||
execution_options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let row_ids = self
|
||||
.permutation_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Note: this relies on the row ids query here being returned in consistent order
|
||||
let row_ids = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||
filter: Some(QueryFilter::Sql(format!(
|
||||
"{} = {}",
|
||||
SPLIT_ID_COLUMN, self.split
|
||||
))),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.base_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![ROW_ID.to_string()]),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||
}
|
||||
|
||||
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
|
||||
let table = Table::from(self.base_table.clone());
|
||||
table.query().select(selection).output_schema().await
|
||||
}
|
||||
|
||||
pub fn count_rows(&self) -> u64 {
|
||||
self.available_rows
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -321,17 +476,17 @@ mod tests {
|
||||
.unwrap();
|
||||
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
||||
|
||||
let reader = PermutationReader::try_new(
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Read split 0
|
||||
let mut stream = reader
|
||||
.read_split(
|
||||
0,
|
||||
.read(
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
@@ -366,9 +521,16 @@ mod tests {
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
|
||||
// Read split 1
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
1,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut stream = reader
|
||||
.read_split(
|
||||
1,
|
||||
.read(
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
|
||||
@@ -34,7 +34,7 @@ pub(crate) const DEFAULT_TOP_K: usize = 10;
|
||||
/// Which columns should be retrieved from the database
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Select {
|
||||
/// Select all columns
|
||||
/// Select all non-system columns
|
||||
///
|
||||
/// Warning: This will always be slower than selecting only the columns you need.
|
||||
All,
|
||||
|
||||
@@ -620,7 +620,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Table {
|
||||
inner: Arc<dyn BaseTable>,
|
||||
database: Arc<dyn Database>,
|
||||
database: Option<Arc<dyn Database>>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
}
|
||||
|
||||
@@ -644,7 +644,7 @@ mod test_utils {
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -666,7 +666,7 @@ mod test_utils {
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -680,11 +680,21 @@ impl std::fmt::Display for Table {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<dyn BaseTable>> for Table {
|
||||
fn from(inner: Arc<dyn BaseTable>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database: None,
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Table {
|
||||
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
@@ -694,7 +704,7 @@ impl Table {
|
||||
}
|
||||
|
||||
pub fn database(&self) -> &Arc<dyn Database> {
|
||||
&self.database
|
||||
self.database.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
|
||||
@@ -708,7 +718,7 @@ impl Table {
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
embedding_registry,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user