feat: add a getitems implementation for the permutation (#3013)

This commit is contained in:
Weston Pace
2026-02-12 05:36:11 -08:00
committed by GitHub
parent 4323ca0147
commit 02783bf440
4 changed files with 445 additions and 12 deletions

View File

@@ -39,6 +39,9 @@ pub struct PermutationReader {
limit: Option<u64>,
available_rows: u64,
split: u64,
// Cached map of offset to row id for the split
#[allow(clippy::type_complexity)]
offset_map: Arc<tokio::sync::Mutex<Option<Arc<HashMap<u64, u64>>>>>,
}
impl std::fmt::Debug for PermutationReader {
@@ -72,6 +75,7 @@ impl PermutationReader {
limit: None,
available_rows: 0,
split,
offset_map: Arc::new(tokio::sync::Mutex::new(None)),
};
slf.validate().await?;
// Calculate the number of available rows
@@ -157,6 +161,7 @@ impl PermutationReader {
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
self.offset = Some(offset);
self.available_rows = available_rows;
self.offset_map = Arc::new(tokio::sync::Mutex::new(None));
Ok(self)
}
@@ -164,6 +169,7 @@ impl PermutationReader {
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
self.available_rows = available_rows;
self.limit = Some(limit);
self.offset_map = Arc::new(tokio::sync::Mutex::new(None));
Ok(self)
}
@@ -180,8 +186,9 @@ impl PermutationReader {
base_table: &Arc<dyn BaseTable>,
row_ids: RecordBatch,
selection: Select,
has_row_id: bool,
) -> Result<RecordBatch> {
let has_row_id = Self::has_row_id(&selection)?;
let num_rows = row_ids.num_rows();
let row_ids = row_ids
.column(0)
@@ -282,14 +289,13 @@ impl PermutationReader {
row_ids: DatasetRecordBatchStream,
selection: Select,
) -> Result<SendableRecordBatchStream> {
let has_row_id = Self::has_row_id(&selection)?;
let mut stream = row_ids
.map_err(Error::from)
.try_filter_map(move |batch| {
let selection = selection.clone();
let base_table = base_table.clone();
async move {
Self::load_batch(&base_table, batch, selection, has_row_id)
Self::load_batch(&base_table, batch, selection)
.await
.map(Some)
}
@@ -397,6 +403,82 @@ impl PermutationReader {
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
/// If we are going to use `take` then we load the offset -> row id map once for the split and cache it
///
/// This method fetches the map with find-or-create semantics.
async fn get_offset_map(
&self,
permutation_table: &Arc<dyn BaseTable>,
) -> Result<Arc<HashMap<u64, u64>>> {
let mut offset_map_ref = self.offset_map.lock().await;
if let Some(offset_map) = &*offset_map_ref {
return Ok(offset_map.clone());
}
let mut offset_map = HashMap::new();
let mut row_ids_query = Table::from(permutation_table.clone())
.query()
.select(Select::Columns(vec![SRC_ROW_ID_COL.to_string()]))
.only_if(format!("{} = {}", SPLIT_ID_COLUMN, self.split));
if let Some(offset) = self.offset {
row_ids_query = row_ids_query.offset(offset as usize);
}
if let Some(limit) = self.limit {
row_ids_query = row_ids_query.limit(limit as usize);
}
let mut row_ids = row_ids_query.execute().await?;
while let Some(batch) = row_ids.try_next().await? {
let row_ids = batch
.column(0)
.as_primitive::<UInt64Type>()
.values()
.to_vec();
for (i, row_id) in row_ids.iter().enumerate() {
offset_map.insert(i as u64, *row_id);
}
}
let offset_map = Arc::new(offset_map);
*offset_map_ref = Some(offset_map.clone());
Ok(offset_map)
}
pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result<RecordBatch> {
if let Some(permutation_table) = &self.permutation_table {
let offset_map = self.get_offset_map(permutation_table).await?;
let row_ids = offsets
.iter()
.map(|o| offset_map.get(o).copied().expect_ok().map_err(Error::from))
.collect::<Result<Vec<_>>>()?;
let row_ids = RecordBatch::try_new(
Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"row_id",
arrow_schema::DataType::UInt64,
false,
)])),
vec![Arc::new(UInt64Array::from(row_ids))],
)?;
Self::load_batch(&self.base_table, row_ids, selection).await
} else {
let table = Table::from(self.base_table.clone());
let batches = table
.take_offsets(offsets.to_vec())
.select(selection.clone())
.execute()
.await?
.try_collect::<Vec<_>>()
.await?;
if let Some(first_batch) = batches.first() {
let schema = first_batch.schema();
let batch = arrow::compute::concat_batches(&schema, &batches)?;
Ok(batch)
} else {
Ok(RecordBatch::try_new(
self.output_schema(selection).await?,
vec![],
)?)
}
}
}
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
let table = Table::from(self.base_table.clone());
table.query().select(selection).output_schema().await
@@ -543,4 +625,224 @@ mod tests {
check_batch(&mut stream, &row_ids[7..9]).await;
assert!(stream.try_next().await.unwrap().is_none());
}
/// Helper to create a base table and permutation table for take_offsets tests.
/// Returns (base_table, row_ids_table, shuffled_row_ids).
async fn setup_permutation_tables(num_rows: usize) -> (Table, Table, Vec<u64>) {
let base_table = lance_datagen::gen_batch()
.col("idx", lance_datagen::array::step::<Int32Type>())
.col("other_col", lance_datagen::array::step::<UInt64Type>())
.into_mem_table("tbl", RowCount::from(num_rows as u64), BatchCount::from(1))
.await;
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
row_ids.shuffle(&mut rand::rng());
let split_ids = UInt64Array::from_iter_values(std::iter::repeat_n(0u64, row_ids.len()));
let permutation_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("row_id", DataType::UInt64, false),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
])),
vec![
Arc::new(UInt64Array::from(row_ids.clone())),
Arc::new(split_ids),
],
)
.unwrap();
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
(base_table, row_ids_table, row_ids)
}
#[tokio::test]
async fn test_take_offsets_with_permutation_table() {
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// Take specific offsets and verify the returned rows match the permutation order
let offsets = vec![0, 2, 4];
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
assert_eq!(batch.num_rows(), 3);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
let expected: Vec<i32> = offsets
.iter()
.map(|&o| row_ids[o as usize] as i32)
.collect();
assert_eq!(idx_values, expected);
}
#[tokio::test]
async fn test_take_offsets_preserves_order() {
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// Take offsets in reverse order and verify returned rows match that order
let offsets = vec![5, 3, 1, 0];
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
assert_eq!(batch.num_rows(), 4);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
let expected: Vec<i32> = offsets
.iter()
.map(|&o| row_ids[o as usize] as i32)
.collect();
assert_eq!(idx_values, expected);
}
#[tokio::test]
async fn test_take_offsets_with_column_selection() {
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
let offsets = vec![1, 3];
let batch = reader
.take_offsets(&offsets, Select::Columns(vec!["idx".to_string()]))
.await
.unwrap();
assert_eq!(batch.num_rows(), 2);
assert_eq!(batch.num_columns(), 1);
assert_eq!(batch.schema().field(0).name(), "idx");
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
let expected: Vec<i32> = offsets
.iter()
.map(|&o| row_ids[o as usize] as i32)
.collect();
assert_eq!(idx_values, expected);
}
#[tokio::test]
async fn test_take_offsets_invalid_offset() {
let (base_table, row_ids_table, _) = setup_permutation_tables(5).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// Offset 999 doesn't exist in the offset map
let result = reader.take_offsets(&[0, 999], Select::All).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_take_offsets_identity_reader() {
let base_table = lance_datagen::gen_batch()
.col("idx", lance_datagen::array::step::<Int32Type>())
.into_mem_table("tbl", RowCount::from(10), BatchCount::from(1))
.await;
let reader = PermutationReader::identity(base_table.base_table().clone()).await;
// With no permutation table, take_offsets uses the base table directly
let offsets = vec![0, 2, 4, 6];
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
assert_eq!(batch.num_rows(), 4);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(idx_values, vec![0, 2, 4, 6]);
}
#[tokio::test]
async fn test_take_offsets_caches_offset_map() {
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
// First call populates the cache
let batch1 = reader.take_offsets(&[0, 1], Select::All).await.unwrap();
// Second call should use the cached offset map and produce consistent results
let batch2 = reader.take_offsets(&[0, 1], Select::All).await.unwrap();
let values1 = batch1
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
let values2 = batch2
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(values1, values2);
let expected: Vec<i32> = vec![row_ids[0] as i32, row_ids[1] as i32];
assert_eq!(values1, expected);
}
#[tokio::test]
async fn test_take_offsets_single_offset() {
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(5).await;
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
0,
)
.await
.unwrap();
let batch = reader.take_offsets(&[2], Select::All).await.unwrap();
assert_eq!(batch.num_rows(), 1);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(idx_values, vec![row_ids[2] as i32]);
}
}