diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 91532f0a7..fdcebc69e 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -968,22 +968,32 @@ class Permutation: new.transform_fn = transform return new + def take_offsets(self, offsets: list[int]) -> Any: + """ + Take rows from the permutation by offset + + The returned value is passed through the permutation's current transform, + so `with_format` and `with_transform` affect this method in the same way + they affect iteration. + """ + + async def do_take_offsets(): + return await self.reader.take_offsets(offsets, selection=self.selection) + + batch = LOOP.run(do_take_offsets()) + return self.transform_fn(batch) + def __getitem__(self, index: int) -> Any: """ Returns a single row from the permutation by offset """ - return self.__getitems__([index]) + return self.take_offsets([index]) def __getitems__(self, indices: list[int]) -> Any: """ Returns rows from the permutation by offset """ - - async def do_getitems(): - return await self.reader.take_offsets(indices, selection=self.selection) - - batch = LOOP.run(do_getitems()) - return self.transform_fn(batch) + return self.take_offsets(indices) @deprecated(details="Use with_skip instead") def skip(self, skip: int) -> "Permutation": diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 96d77f9d1..e58a49624 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -1080,3 +1080,29 @@ def test_getitems_invalid_offset(some_permutation: Permutation): """Test __getitems__ with an out-of-range offset raises an error.""" with pytest.raises(Exception): some_permutation.__getitems__([999999]) + + +def test_take_offsets(some_permutation: Permutation): + result = some_permutation.take_offsets([0, 1, 2]) + + assert isinstance(result, list) + assert "id" in result[0] + assert "value" in result[0] + assert len(result) == 3 + + +def test_take_offsets_empty_identity_permutation(mem_db): + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(10), "value": range(10)}) + ) + permutation = Permutation.identity(tbl) + + result = permutation.take_offsets([]) + + assert result == [] + + +def test_take_offsets_empty_permutation(some_permutation: Permutation): + result = some_permutation.take_offsets([]) + + assert result == [] diff --git a/rust/lancedb/src/dataloader/permutation/reader.rs b/rust/lancedb/src/dataloader/permutation/reader.rs index 43c229a0a..65d065db7 100644 --- a/rust/lancedb/src/dataloader/permutation/reader.rs +++ b/rust/lancedb/src/dataloader/permutation/reader.rs @@ -450,6 +450,10 @@ impl PermutationReader { } pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result { + if offsets.is_empty() { + return Ok(RecordBatch::new_empty(self.output_schema(selection).await?)); + } + if let Some(permutation_table) = &self.permutation_table { let offset_map = self.get_offset_map(permutation_table).await?; let row_ids = offsets @@ -955,4 +959,62 @@ mod tests { .to_vec(); assert_eq!(idx_values, &all_idx_values[4997..5000]); } + + #[tokio::test] + async fn test_take_offsets_empty_identity_reader() { + let base_table = lance_datagen::gen_batch() + .col("idx", lance_datagen::array::step::()) + .into_mem_table("tbl", RowCount::from(10), BatchCount::from(1)) + .await; + + let reader = PermutationReader::identity(base_table.base_table().clone()).await; + + let batch = reader.take_offsets(&[], Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "idx"); + } + + #[tokio::test] + async fn test_take_offsets_empty_with_permutation_table() { + let (base_table, row_ids_table, _) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let batch = reader.take_offsets(&[], Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.schema().fields().len(), 2); + assert_eq!(batch.schema().field(0).name(), "idx"); + assert_eq!(batch.schema().field(1).name(), "other_col"); + } + + #[tokio::test] + async fn test_take_offsets_empty_with_column_selection() { + let (base_table, row_ids_table, _) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let batch = reader + .take_offsets(&[], Select::Columns(vec!["idx".to_string()])) + .await + .unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "idx"); + } }