feat(python): add public take_offsets method on Permutation (#3375)

Closes #3243.

This PR exposes a new public API, `Permutation.take_offsets(offsets:
list[int])`. Previously, users had to call the `__getitems__` dunder
directly to batch-fetch rows by position.

The name matches the existing `Table.take_offsets` pattern, and the
dunder methods `__getitem__` and `__getitems__` now delegate to it.

This also fixes a parse error that occurred when
`PermutationReader::take_offsets` was given an empty list. It now returns
an empty `RecordBatch` with the correct schema instead. The fix is bundled
here because, without it, the new public API blows up on a perfectly
reasonable input.

`__getitems__` is preserved since PyTorch's batched DataLoader requires
it.

### Testing

- Added 3 new Rust tests for empty offsets, covering a permutation table
with `Select::All`, a permutation table with `Select::Columns`, and the
identity path
- Added 3 new Python tests for the public API: a happy-path case, and
empty input on both an identity permutation and a regular permutation

clippy, format, check all clean!

cc: @westonpace
This commit is contained in:
Drew Gallardo
2026-05-18 09:35:56 -07:00
committed by GitHub
parent 8df2fff75f
commit aac6c62459
3 changed files with 105 additions and 7 deletions

View File

@@ -968,22 +968,32 @@ class Permutation:
new.transform_fn = transform
return new
def take_offsets(self, offsets: list[int]) -> Any:
    """
    Fetch the rows at the given offsets of the permutation in one batch.

    The rows are read through the permutation's reader using the current
    selection, and the resulting batch is handed to the current transform
    function before being returned. As a result, `with_format` and
    `with_transform` affect this method in the same way they affect
    iteration.
    """

    async def _read_batch():
        return await self.reader.take_offsets(offsets, selection=self.selection)

    # The reader is async; drive it to completion on the shared loop, then
    # apply the transform to the raw batch.
    return self.transform_fn(LOOP.run(_read_batch()))
def __getitem__(self, index: int) -> Any:
    """
    Returns a single row from the permutation by offset

    Delegates to `take_offsets` with a one-element list, so the result is
    produced by the same code path as a batched fetch.
    """
    return self.take_offsets([index])
def __getitems__(self, indices: list[int]) -> Any:
    """
    Returns rows from the permutation by offset

    This is a thin alias for `take_offsets`. It is kept under the dunder
    name because PyTorch's batched DataLoader looks the method up as
    `__getitems__`.
    """
    # Single source of truth: the fetch + transform logic lives in
    # take_offsets rather than being duplicated here.
    return self.take_offsets(indices)
@deprecated(details="Use with_skip instead")
def skip(self, skip: int) -> "Permutation":

View File

@@ -1080,3 +1080,29 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
"""Test __getitems__ with an out-of-range offset raises an error."""
with pytest.raises(Exception):
some_permutation.__getitems__([999999])
def test_take_offsets(some_permutation: Permutation):
    """take_offsets returns a list of rows exposing the `id` and `value` keys."""
    rows = some_permutation.take_offsets([0, 1, 2])
    assert isinstance(rows, list)

    first_row = rows[0]
    for column in ("id", "value"):
        assert column in first_row

    # One row is expected per requested offset.
    assert len(rows) == 3
def test_take_offsets_empty_identity_permutation(mem_db):
    """An identity permutation returns an empty list when given no offsets."""
    data = pa.table({"id": range(10), "value": range(10)})
    tbl = mem_db.create_table("test_table", data)

    identity = Permutation.identity(tbl)

    assert identity.take_offsets([]) == []
def test_take_offsets_empty_permutation(some_permutation: Permutation):
    """A permutation returns an empty list when given no offsets."""
    no_offsets: list[int] = []
    assert some_permutation.take_offsets(no_offsets) == []

View File

@@ -450,6 +450,10 @@ impl PermutationReader {
}
pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result<RecordBatch> {
if offsets.is_empty() {
return Ok(RecordBatch::new_empty(self.output_schema(selection).await?));
}
if let Some(permutation_table) = &self.permutation_table {
let offset_map = self.get_offset_map(permutation_table).await?;
let row_ids = offsets
@@ -955,4 +959,62 @@ mod tests {
.to_vec();
assert_eq!(idx_values, &all_idx_values[4997..5000]);
}
/// An identity reader asked for zero offsets must yield a zero-row batch
/// that still exposes the single `idx` column.
#[tokio::test]
async fn test_take_offsets_empty_identity_reader() {
    let source = lance_datagen::gen_batch()
        .col("idx", lance_datagen::array::step::<Int32Type>())
        .into_mem_table("tbl", RowCount::from(10), BatchCount::from(1))
        .await;
    let identity = PermutationReader::identity(source.base_table().clone()).await;

    let empty = identity.take_offsets(&[], Select::All).await.unwrap();
    let schema = empty.schema();

    assert_eq!(empty.num_rows(), 0);
    assert_eq!(empty.num_columns(), 1);
    assert_eq!(schema.field(0).name(), "idx");
}
/// With a permutation table and `Select::All`, an empty offset list must
/// yield a zero-row batch whose schema is exactly `idx`, `other_col`.
#[tokio::test]
async fn test_take_offsets_empty_with_permutation_table() {
    let (base, row_ids, _) = setup_permutation_tables(5).await;
    let permuted = PermutationReader::try_from_tables(
        base.base_table().clone(),
        row_ids.base_table().clone(),
        0,
    )
    .await
    .unwrap();

    let empty = permuted.take_offsets(&[], Select::All).await.unwrap();
    assert_eq!(empty.num_rows(), 0);

    // Compare the full ordered list of column names in one assertion; this
    // covers both the field count and each field's name.
    let schema = empty.schema();
    let column_names: Vec<&str> = schema
        .fields()
        .iter()
        .map(|field| field.name().as_str())
        .collect();
    assert_eq!(column_names, ["idx", "other_col"]);
}
/// With a permutation table and an explicit column selection, an empty
/// offset list must yield a zero-row batch holding only the selected column.
#[tokio::test]
async fn test_take_offsets_empty_with_column_selection() {
    let (base, row_ids, _) = setup_permutation_tables(5).await;
    let permuted = PermutationReader::try_from_tables(
        base.base_table().clone(),
        row_ids.base_table().clone(),
        0,
    )
    .await
    .unwrap();

    let only_idx = Select::Columns(vec!["idx".to_string()]);
    let empty = permuted.take_offsets(&[], only_idx).await.unwrap();
    let schema = empty.schema();

    assert_eq!(empty.num_rows(), 0);
    assert_eq!(empty.num_columns(), 1);
    assert_eq!(schema.field(0).name(), "idx");
}
}