diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 2ed0d0218..04adf38bc 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -746,15 +746,20 @@ class Permutation: def __getitem__(self, index: int) -> Any: """ - Return a single row from the permutation - - The output will always be a python dictionary regardless of the format. - - This method is mostly useful for debugging and exploration. For actual - processing use [iter](#iter) or a torch data loader to perform batched - processing. + Returns a single row from the permutation by offset """ - pass + return self.__getitems__([index]) + + def __getitems__(self, indices: list[int]) -> Any: + """ + Returns rows from the permutation by offset + """ + + async def do_getitems(): + return await self.reader.take_offsets(indices, selection=self.selection) + + batch = LOOP.run(do_getitems()) + return self.transform_fn(batch) @deprecated(details="Use with_skip instead") def skip(self, skip: int) -> "Permutation": diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index e75a02a53..1a4d4a5dd 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -945,3 +945,112 @@ def test_custom_transform(mem_db): batch = batches[0] assert batch == pa.record_batch([range(10)], ["id"]) + + +def test_getitems_basic(some_permutation: Permutation): + """Test __getitems__ returns correct rows by offset.""" + result = some_permutation.__getitems__([0, 1, 2]) + assert isinstance(result, dict) + assert "id" in result + assert "value" in result + assert len(result["id"]) == 3 + + +def test_getitems_single_index(some_permutation: Permutation): + """Test __getitems__ with a single index.""" + result = some_permutation.__getitems__([0]) + assert len(result["id"]) == 1 + assert len(result["value"]) == 1 + + +def test_getitems_preserves_order(some_permutation: Permutation): + """Test __getitems__ returns rows in the requested order.""" + # Get rows in forward order + forward = some_permutation.__getitems__([0, 1, 2, 3, 4]) + # Get the same rows in reverse order + reverse = some_permutation.__getitems__([4, 3, 2, 1, 0]) + + assert forward["id"] == list(reversed(reverse["id"])) + assert forward["value"] == list(reversed(reverse["value"])) + + +def test_getitems_non_contiguous(some_permutation: Permutation): + """Test __getitems__ with non-contiguous indices.""" + result = some_permutation.__getitems__([0, 10, 50, 100, 500]) + assert len(result["id"]) == 5 + + # Each id/value pair should match what we'd get individually + for i, offset in enumerate([0, 10, 50, 100, 500]): + single = some_permutation.__getitems__([offset]) + assert result["id"][i] == single["id"][0] + assert result["value"][i] == single["value"][0] + + +def test_getitems_with_column_selection(some_permutation: Permutation): + """Test __getitems__ respects column selection.""" + id_only = some_permutation.select_columns(["id"]) + result = id_only.__getitems__([0, 1, 2]) + assert "id" in result + assert "value" not in result + assert len(result["id"]) == 3 + + +def test_getitems_with_column_rename(some_permutation: Permutation): + """Test __getitems__ respects column renames.""" + renamed = some_permutation.rename_column("value", "data") + result = renamed.__getitems__([0, 1]) + assert "data" in result + assert "value" not in result + assert len(result["data"]) == 2 + + +def test_getitems_with_format(some_permutation: Permutation): + """Test __getitems__ applies the transform function.""" + arrow_perm = some_permutation.with_format("arrow") + result = arrow_perm.__getitems__([0, 1, 2]) + assert isinstance(result, pa.RecordBatch) + assert result.num_rows == 3 + + +def test_getitems_with_custom_transform(some_permutation: Permutation): + """Test __getitems__ with a custom transform.""" + + def transform(batch: pa.RecordBatch) -> list: + return batch.column("id").to_pylist() + + custom = some_permutation.with_transform(transform) + result = custom.__getitems__([0, 1, 2]) + assert isinstance(result, list) + assert len(result) == 3 + + +def test_getitems_identity_permutation(mem_db): + """Test __getitems__ on an identity permutation.""" + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(10), "value": range(10)}) + ) + perm = Permutation.identity(tbl) + + result = perm.__getitems__([0, 5, 9]) + assert result["id"] == [0, 5, 9] + assert result["value"] == [0, 5, 9] + + +def test_getitems_with_limit_offset(some_permutation: Permutation): + """Test __getitems__ on a permutation with skip/take applied.""" + limited = some_permutation.with_skip(100).with_take(200) + + # Should be able to access offsets within the limited range + result = limited.__getitems__([0, 1, 199]) + assert len(result["id"]) == 3 + + # The first item of the limited permutation should match offset 100 of original + full_result = some_permutation.__getitems__([100]) + limited_result = limited.__getitems__([0]) + assert limited_result["id"][0] == full_result["id"][0] + + +def test_getitems_invalid_offset(some_permutation: Permutation): + """Test __getitems__ with an out-of-range offset raises an error.""" + with pytest.raises(Exception): + some_permutation.__getitems__([999999]) diff --git a/python/src/permutation.rs b/python/src/permutation.rs index 2beb482f6..192cf70f5 100644 --- a/python/src/permutation.rs +++ b/python/src/permutation.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use crate::{ arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table, }; -use arrow::pyarrow::ToPyArrow; +use arrow::pyarrow::{PyArrowType, ToPyArrow}; use lancedb::{ dataloader::permutation::{ builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy}, @@ -328,4 +328,21 @@ impl PyPermutationReader { Ok(RecordBatchStream::new(stream)) }) } + + #[pyo3(signature = (indices, *, selection=None))] + pub fn take_offsets<'py>( + slf: PyRef<'py, Self>, + indices: Vec, + selection: Option>, + ) -> PyResult> { + let selection = Self::parse_selection(selection)?; + let reader = slf.reader.clone(); + future_into_py(slf.py(), async move { + let batch = reader + .take_offsets(&indices, selection) + .await + .infer_error()?; + Ok(PyArrowType(batch)) + }) + } } diff --git a/rust/lancedb/src/dataloader/permutation/reader.rs b/rust/lancedb/src/dataloader/permutation/reader.rs index ce453434a..e3397b25b 100644 --- a/rust/lancedb/src/dataloader/permutation/reader.rs +++ b/rust/lancedb/src/dataloader/permutation/reader.rs @@ -39,6 +39,9 @@ pub struct PermutationReader { limit: Option, available_rows: u64, split: u64, + // Cached map of offset to row id for the split + #[allow(clippy::type_complexity)] + offset_map: Arc>>>>, } impl std::fmt::Debug for PermutationReader { @@ -72,6 +75,7 @@ impl PermutationReader { limit: None, available_rows: 0, split, + offset_map: Arc::new(tokio::sync::Mutex::new(None)), }; slf.validate().await?; // Calculate the number of available rows @@ -157,6 +161,7 @@ impl PermutationReader { let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?; self.offset = Some(offset); self.available_rows = available_rows; + self.offset_map = Arc::new(tokio::sync::Mutex::new(None)); Ok(self) } @@ -164,6 +169,7 @@ impl PermutationReader { let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?; self.available_rows = available_rows; self.limit = Some(limit); + self.offset_map = Arc::new(tokio::sync::Mutex::new(None)); Ok(self) } @@ -180,8 +186,9 @@ impl PermutationReader { base_table: &Arc, row_ids: RecordBatch, selection: Select, - has_row_id: bool, ) -> Result { + let has_row_id = Self::has_row_id(&selection)?; + let num_rows = row_ids.num_rows(); let row_ids = row_ids .column(0) @@ -282,14 +289,13 @@ impl PermutationReader { row_ids: DatasetRecordBatchStream, selection: Select, ) -> Result { - let has_row_id = Self::has_row_id(&selection)?; let mut stream = row_ids .map_err(Error::from) .try_filter_map(move |batch| { let selection = selection.clone(); let base_table = base_table.clone(); async move { - Self::load_batch(&base_table, batch, selection, has_row_id) + Self::load_batch(&base_table, batch, selection) .await .map(Some) } @@ -397,6 +403,82 @@ impl PermutationReader { Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await } + /// If we are going to use `take` then we load the offset -> row id map once for the split and cache it + /// + /// This method fetches the map with find-or-create semantics. + async fn get_offset_map( + &self, + permutation_table: &Arc, + ) -> Result>> { + let mut offset_map_ref = self.offset_map.lock().await; + if let Some(offset_map) = &*offset_map_ref { + return Ok(offset_map.clone()); + } + let mut offset_map = HashMap::new(); + let mut row_ids_query = Table::from(permutation_table.clone()) + .query() + .select(Select::Columns(vec![SRC_ROW_ID_COL.to_string()])) + .only_if(format!("{} = {}", SPLIT_ID_COLUMN, self.split)); + if let Some(offset) = self.offset { + row_ids_query = row_ids_query.offset(offset as usize); + } + if let Some(limit) = self.limit { + row_ids_query = row_ids_query.limit(limit as usize); + } + let mut row_ids = row_ids_query.execute().await?; + while let Some(batch) = row_ids.try_next().await? { + let row_ids = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + for (i, row_id) in row_ids.iter().enumerate() { + offset_map.insert(i as u64, *row_id); + } + } + let offset_map = Arc::new(offset_map); + *offset_map_ref = Some(offset_map.clone()); + Ok(offset_map) + } + + pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result { + if let Some(permutation_table) = &self.permutation_table { + let offset_map = self.get_offset_map(permutation_table).await?; + let row_ids = offsets + .iter() + .map(|o| offset_map.get(o).copied().expect_ok().map_err(Error::from)) + .collect::>>()?; + let row_ids = RecordBatch::try_new( + Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "row_id", + arrow_schema::DataType::UInt64, + false, + )])), + vec![Arc::new(UInt64Array::from(row_ids))], + )?; + Self::load_batch(&self.base_table, row_ids, selection).await + } else { + let table = Table::from(self.base_table.clone()); + let batches = table + .take_offsets(offsets.to_vec()) + .select(selection.clone()) + .execute() + .await? + .try_collect::>() + .await?; + if let Some(first_batch) = batches.first() { + let schema = first_batch.schema(); + let batch = arrow::compute::concat_batches(&schema, &batches)?; + Ok(batch) + } else { + Ok(RecordBatch::try_new( + self.output_schema(selection).await?, + vec![], + )?) + } + } + } + pub async fn output_schema(&self, selection: Select) -> Result { let table = Table::from(self.base_table.clone()); table.query().select(selection).output_schema().await @@ -543,4 +625,224 @@ mod tests { check_batch(&mut stream, &row_ids[7..9]).await; assert!(stream.try_next().await.unwrap().is_none()); } + + /// Helper to create a base table and permutation table for take_offsets tests. + /// Returns (base_table, row_ids_table, shuffled_row_ids). + async fn setup_permutation_tables(num_rows: usize) -> (Table, Table, Vec) { + let base_table = lance_datagen::gen_batch() + .col("idx", lance_datagen::array::step::()) + .col("other_col", lance_datagen::array::step::()) + .into_mem_table("tbl", RowCount::from(num_rows as u64), BatchCount::from(1)) + .await; + + let mut row_ids = collect_column::(&base_table, "_rowid").await; + row_ids.shuffle(&mut rand::rng()); + + let split_ids = UInt64Array::from_iter_values(std::iter::repeat_n(0u64, row_ids.len())); + let permutation_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("row_id", DataType::UInt64, false), + Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false), + ])), + vec![ + Arc::new(UInt64Array::from(row_ids.clone())), + Arc::new(split_ids), + ], + ) + .unwrap(); + let row_ids_table = virtual_table("row_ids", &permutation_batch).await; + + (base_table, row_ids_table, row_ids) + } + + #[tokio::test] + async fn test_take_offsets_with_permutation_table() { + let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + // Take specific offsets and verify the returned rows match the permutation order + let offsets = vec![0, 2, 4]; + let batch = reader.take_offsets(&offsets, Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 3); + + let idx_values = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + let expected: Vec = offsets + .iter() + .map(|&o| row_ids[o as usize] as i32) + .collect(); + assert_eq!(idx_values, expected); + } + + #[tokio::test] + async fn test_take_offsets_preserves_order() { + let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + // Take offsets in reverse order and verify returned rows match that order + let offsets = vec![5, 3, 1, 0]; + let batch = reader.take_offsets(&offsets, Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 4); + + let idx_values = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + let expected: Vec = offsets + .iter() + .map(|&o| row_ids[o as usize] as i32) + .collect(); + assert_eq!(idx_values, expected); + } + + #[tokio::test] + async fn test_take_offsets_with_column_selection() { + let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let offsets = vec![1, 3]; + let batch = reader + .take_offsets(&offsets, Select::Columns(vec!["idx".to_string()])) + .await + .unwrap(); + + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "idx"); + + let idx_values = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + let expected: Vec = offsets + .iter() + .map(|&o| row_ids[o as usize] as i32) + .collect(); + assert_eq!(idx_values, expected); + } + + #[tokio::test] + async fn test_take_offsets_invalid_offset() { + let (base_table, row_ids_table, _) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + // Offset 999 doesn't exist in the offset map + let result = reader.take_offsets(&[0, 999], Select::All).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_take_offsets_identity_reader() { + let base_table = lance_datagen::gen_batch() + .col("idx", lance_datagen::array::step::()) + .into_mem_table("tbl", RowCount::from(10), BatchCount::from(1)) + .await; + + let reader = PermutationReader::identity(base_table.base_table().clone()).await; + + // With no permutation table, take_offsets uses the base table directly + let offsets = vec![0, 2, 4, 6]; + let batch = reader.take_offsets(&offsets, Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 4); + + let idx_values = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + assert_eq!(idx_values, vec![0, 2, 4, 6]); + } + + #[tokio::test] + async fn test_take_offsets_caches_offset_map() { + let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + // First call populates the cache + let batch1 = reader.take_offsets(&[0, 1], Select::All).await.unwrap(); + + // Second call should use the cached offset map and produce consistent results + let batch2 = reader.take_offsets(&[0, 1], Select::All).await.unwrap(); + + let values1 = batch1 + .column(0) + .as_primitive::() + .values() + .to_vec(); + let values2 = batch2 + .column(0) + .as_primitive::() + .values() + .to_vec(); + assert_eq!(values1, values2); + + let expected: Vec = vec![row_ids[0] as i32, row_ids[1] as i32]; + assert_eq!(values1, expected); + } + + #[tokio::test] + async fn test_take_offsets_single_offset() { + let (base_table, row_ids_table, row_ids) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let batch = reader.take_offsets(&[2], Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 1); + let idx_values = batch + .column(0) + .as_primitive::() + .values() + .to_vec(); + assert_eq!(idx_values, vec![row_ids[2] as i32]); + } }