mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-26 02:20:40 +00:00
feat: add a getitems implementation for the permutation (#3013)
This commit is contained in:
@@ -746,15 +746,20 @@ class Permutation:
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
Return a single row from the permutation
|
||||
|
||||
The output will always be a python dictionary regardless of the format.
|
||||
|
||||
This method is mostly useful for debugging and exploration. For actual
|
||||
processing use [iter](#iter) or a torch data loader to perform batched
|
||||
processing.
|
||||
Returns a single row from the permutation by offset
|
||||
"""
|
||||
pass
|
||||
return self.__getitems__([index])
|
||||
|
||||
def __getitems__(self, indices: list[int]) -> Any:
|
||||
"""
|
||||
Returns rows from the permutation by offset
|
||||
"""
|
||||
|
||||
async def do_getitems():
|
||||
return await self.reader.take_offsets(indices, selection=self.selection)
|
||||
|
||||
batch = LOOP.run(do_getitems())
|
||||
return self.transform_fn(batch)
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
|
||||
@@ -945,3 +945,112 @@ def test_custom_transform(mem_db):
|
||||
batch = batches[0]
|
||||
|
||||
assert batch == pa.record_batch([range(10)], ["id"])
|
||||
|
||||
|
||||
def test_getitems_basic(some_permutation: Permutation):
|
||||
"""Test __getitems__ returns correct rows by offset."""
|
||||
result = some_permutation.__getitems__([0, 1, 2])
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "value" in result
|
||||
assert len(result["id"]) == 3
|
||||
|
||||
|
||||
def test_getitems_single_index(some_permutation: Permutation):
|
||||
"""Test __getitems__ with a single index."""
|
||||
result = some_permutation.__getitems__([0])
|
||||
assert len(result["id"]) == 1
|
||||
assert len(result["value"]) == 1
|
||||
|
||||
|
||||
def test_getitems_preserves_order(some_permutation: Permutation):
|
||||
"""Test __getitems__ returns rows in the requested order."""
|
||||
# Get rows in forward order
|
||||
forward = some_permutation.__getitems__([0, 1, 2, 3, 4])
|
||||
# Get the same rows in reverse order
|
||||
reverse = some_permutation.__getitems__([4, 3, 2, 1, 0])
|
||||
|
||||
assert forward["id"] == list(reversed(reverse["id"]))
|
||||
assert forward["value"] == list(reversed(reverse["value"]))
|
||||
|
||||
|
||||
def test_getitems_non_contiguous(some_permutation: Permutation):
|
||||
"""Test __getitems__ with non-contiguous indices."""
|
||||
result = some_permutation.__getitems__([0, 10, 50, 100, 500])
|
||||
assert len(result["id"]) == 5
|
||||
|
||||
# Each id/value pair should match what we'd get individually
|
||||
for i, offset in enumerate([0, 10, 50, 100, 500]):
|
||||
single = some_permutation.__getitems__([offset])
|
||||
assert result["id"][i] == single["id"][0]
|
||||
assert result["value"][i] == single["value"][0]
|
||||
|
||||
|
||||
def test_getitems_with_column_selection(some_permutation: Permutation):
|
||||
"""Test __getitems__ respects column selection."""
|
||||
id_only = some_permutation.select_columns(["id"])
|
||||
result = id_only.__getitems__([0, 1, 2])
|
||||
assert "id" in result
|
||||
assert "value" not in result
|
||||
assert len(result["id"]) == 3
|
||||
|
||||
|
||||
def test_getitems_with_column_rename(some_permutation: Permutation):
|
||||
"""Test __getitems__ respects column renames."""
|
||||
renamed = some_permutation.rename_column("value", "data")
|
||||
result = renamed.__getitems__([0, 1])
|
||||
assert "data" in result
|
||||
assert "value" not in result
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
|
||||
def test_getitems_with_format(some_permutation: Permutation):
|
||||
"""Test __getitems__ applies the transform function."""
|
||||
arrow_perm = some_permutation.with_format("arrow")
|
||||
result = arrow_perm.__getitems__([0, 1, 2])
|
||||
assert isinstance(result, pa.RecordBatch)
|
||||
assert result.num_rows == 3
|
||||
|
||||
|
||||
def test_getitems_with_custom_transform(some_permutation: Permutation):
|
||||
"""Test __getitems__ with a custom transform."""
|
||||
|
||||
def transform(batch: pa.RecordBatch) -> list:
|
||||
return batch.column("id").to_pylist()
|
||||
|
||||
custom = some_permutation.with_transform(transform)
|
||||
result = custom.__getitems__([0, 1, 2])
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
def test_getitems_identity_permutation(mem_db):
|
||||
"""Test __getitems__ on an identity permutation."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
perm = Permutation.identity(tbl)
|
||||
|
||||
result = perm.__getitems__([0, 5, 9])
|
||||
assert result["id"] == [0, 5, 9]
|
||||
assert result["value"] == [0, 5, 9]
|
||||
|
||||
|
||||
def test_getitems_with_limit_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ on a permutation with skip/take applied."""
|
||||
limited = some_permutation.with_skip(100).with_take(200)
|
||||
|
||||
# Should be able to access offsets within the limited range
|
||||
result = limited.__getitems__([0, 1, 199])
|
||||
assert len(result["id"]) == 3
|
||||
|
||||
# The first item of the limited permutation should match offset 100 of original
|
||||
full_result = some_permutation.__getitems__([100])
|
||||
limited_result = limited.__getitems__([0])
|
||||
assert limited_result["id"][0] == full_result["id"][0]
|
||||
|
||||
|
||||
def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use arrow::pyarrow::{PyArrowType, ToPyArrow};
|
||||
use lancedb::{
|
||||
dataloader::permutation::{
|
||||
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
@@ -328,4 +328,21 @@ impl PyPermutationReader {
|
||||
Ok(RecordBatchStream::new(stream))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (indices, *, selection=None))]
|
||||
pub fn take_offsets<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
indices: Vec<u64>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let batch = reader
|
||||
.take_offsets(&indices, selection)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(PyArrowType(batch))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user