From aac6c624591d2ea19492273e6382e87cf60a9a42 Mon Sep 17 00:00:00 2001 From: Drew Gallardo Date: Mon, 18 May 2026 09:35:56 -0700 Subject: [PATCH] feat(python): add public take_offsets method on Permutation (#3375) Closes #3243. This PR exposes a new public API, `Permutation.take_offsets(offsets: list[int])`, since users previously had to call `__getitems__` directly to batch-fetch rows by position. The name matches the existing `Table.take_offsets` pattern, and the dunder methods `__getitem__` and `__getitems__` now delegate to it. Also fixes a parse error when `PermutationReader::take_offsets` gets an empty list: it now returns an empty `RecordBatch` with the correct schema instead. This fix is bundled here because without it the new public API blows up on a perfectly reasonable input. `__getitems__` is preserved since PyTorch's batched DataLoader requires it. ### Testing - Added 3 new Rust tests for empty offsets, covering a permutation table with Select::All, Select::Columns, and the identity path - Added 3 new Python tests for the public API, covering a happy case and empty input on both the identity and permutation paths clippy, format, check all clean! cc: @westonpace --- python/python/lancedb/permutation.py | 24 ++++--- python/python/tests/test_permutation.py | 26 ++++++++ .../src/dataloader/permutation/reader.rs | 62 +++++++++++++++++++ 3 files changed, 105 insertions(+), 7 deletions(-) diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 91532f0a7..fdcebc69e 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -968,22 +968,32 @@ class Permutation: new.transform_fn = transform return new + def take_offsets(self, offsets: list[int]) -> Any: + """ + Take rows from the permutation by offset + + The returned value is passed through the permutation's current transform, + so `with_format` and `with_transform` affect this method in the same way + they affect iteration. 
+ """ + + async def do_take_offsets(): + return await self.reader.take_offsets(offsets, selection=self.selection) + + batch = LOOP.run(do_take_offsets()) + return self.transform_fn(batch) + def __getitem__(self, index: int) -> Any: """ Returns a single row from the permutation by offset """ - return self.__getitems__([index]) + return self.take_offsets([index]) def __getitems__(self, indices: list[int]) -> Any: """ Returns rows from the permutation by offset """ - - async def do_getitems(): - return await self.reader.take_offsets(indices, selection=self.selection) - - batch = LOOP.run(do_getitems()) - return self.transform_fn(batch) + return self.take_offsets(indices) @deprecated(details="Use with_skip instead") def skip(self, skip: int) -> "Permutation": diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 96d77f9d1..e58a49624 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -1080,3 +1080,29 @@ def test_getitems_invalid_offset(some_permutation: Permutation): """Test __getitems__ with an out-of-range offset raises an error.""" with pytest.raises(Exception): some_permutation.__getitems__([999999]) + + +def test_take_offsets(some_permutation: Permutation): + result = some_permutation.take_offsets([0, 1, 2]) + + assert isinstance(result, list) + assert "id" in result[0] + assert "value" in result[0] + assert len(result) == 3 + + +def test_take_offsets_empty_identity_permutation(mem_db): + tbl = mem_db.create_table( + "test_table", pa.table({"id": range(10), "value": range(10)}) + ) + permutation = Permutation.identity(tbl) + + result = permutation.take_offsets([]) + + assert result == [] + + +def test_take_offsets_empty_permutation(some_permutation: Permutation): + result = some_permutation.take_offsets([]) + + assert result == [] diff --git a/rust/lancedb/src/dataloader/permutation/reader.rs b/rust/lancedb/src/dataloader/permutation/reader.rs index 
43c229a0a..65d065db7 100644 --- a/rust/lancedb/src/dataloader/permutation/reader.rs +++ b/rust/lancedb/src/dataloader/permutation/reader.rs @@ -450,6 +450,10 @@ impl PermutationReader { } pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result<RecordBatch> { + if offsets.is_empty() { + return Ok(RecordBatch::new_empty(self.output_schema(selection).await?)); + } + if let Some(permutation_table) = &self.permutation_table { let offset_map = self.get_offset_map(permutation_table).await?; let row_ids = offsets @@ -955,4 +959,62 @@ mod tests { .to_vec(); assert_eq!(idx_values, &all_idx_values[4997..5000]); } + + #[tokio::test] + async fn test_take_offsets_empty_identity_reader() { + let base_table = lance_datagen::gen_batch() + .col("idx", lance_datagen::array::step::<Int32Type>()) + .into_mem_table("tbl", RowCount::from(10), BatchCount::from(1)) + .await; + + let reader = PermutationReader::identity(base_table.base_table().clone()).await; + + let batch = reader.take_offsets(&[], Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "idx"); + } + + #[tokio::test] + async fn test_take_offsets_empty_with_permutation_table() { + let (base_table, row_ids_table, _) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), + row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let batch = reader.take_offsets(&[], Select::All).await.unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.schema().fields().len(), 2); + assert_eq!(batch.schema().field(0).name(), "idx"); + assert_eq!(batch.schema().field(1).name(), "other_col"); + } + + #[tokio::test] + async fn test_take_offsets_empty_with_column_selection() { + let (base_table, row_ids_table, _) = setup_permutation_tables(5).await; + + let reader = PermutationReader::try_from_tables( + base_table.base_table().clone(), 
row_ids_table.base_table().clone(), + 0, + ) + .await + .unwrap(); + + let batch = reader + .take_offsets(&[], Select::Columns(vec!["idx".to_string()])) + .await + .unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.schema().field(0).name(), "idx"); + } }