mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-01 03:10:43 +00:00
feat: add a getitems implementation for the permutation (#3013)
This commit is contained in:
@@ -746,15 +746,20 @@ class Permutation:
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
    """
    Return a single row from the permutation by offset.

    This delegates to ``__getitems__`` with a one-element list, so the
    result is whatever ``__getitems__`` produces for a single offset; it is
    not forced to a python dictionary.

    This method is mostly useful for debugging and exploration.  For actual
    processing use [iter](#iter) or a torch data loader to perform batched
    processing.
    """
    return self.__getitems__([index])
|
||||
|
||||
def __getitems__(self, indices: list[int]) -> Any:
    """
    Return the rows at the given permutation offsets.

    The offsets and the current column selection are handed to the reader's
    ``take_offsets``; the batch it yields is passed through ``transform_fn``
    before being returned.
    """

    async def _fetch_rows():
        return await self.reader.take_offsets(indices, selection=self.selection)

    # The reader call is async; LOOP.run drives it to completion from this sync method.
    return self.transform_fn(LOOP.run(_fetch_rows()))
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
|
||||
@@ -945,3 +945,112 @@ def test_custom_transform(mem_db):
|
||||
batch = batches[0]
|
||||
|
||||
assert batch == pa.record_batch([range(10)], ["id"])
|
||||
|
||||
|
||||
def test_getitems_basic(some_permutation: Permutation):
    """Fetching three offsets yields a dict holding both columns, three ids long."""
    rows = some_permutation.__getitems__([0, 1, 2])
    assert isinstance(rows, dict)
    for column in ("id", "value"):
        assert column in rows
    assert len(rows["id"]) == 3
|
||||
|
||||
|
||||
def test_getitems_single_index(some_permutation: Permutation):
    """A one-element offset list yields exactly one entry in each column."""
    only_row = some_permutation.__getitems__([0])
    for column in ("id", "value"):
        assert len(only_row[column]) == 1
|
||||
|
||||
|
||||
def test_getitems_preserves_order(some_permutation: Permutation):
    """Rows come back in the order in which their offsets were requested."""
    # Same five offsets, once ascending and once descending
    forward = some_permutation.__getitems__([0, 1, 2, 3, 4])
    reverse = some_permutation.__getitems__([4, 3, 2, 1, 0])

    # Flipping the descending result must reproduce the ascending one
    for column in ("id", "value"):
        assert forward[column] == list(reversed(reverse[column]))
|
||||
|
||||
|
||||
def test_getitems_non_contiguous(some_permutation: Permutation):
    """Scattered offsets return the same rows as fetching each offset alone."""
    offsets = [0, 10, 50, 100, 500]
    batch = some_permutation.__getitems__(offsets)
    assert len(batch["id"]) == 5

    # Compare every position of the batched fetch against a one-offset fetch
    for position, offset in enumerate(offsets):
        alone = some_permutation.__getitems__([offset])
        assert batch["id"][position] == alone["id"][0]
        assert batch["value"][position] == alone["value"][0]
|
||||
|
||||
|
||||
def test_getitems_with_column_selection(some_permutation: Permutation):
    """Only the selected column shows up in the fetched rows."""
    projected = some_permutation.select_columns(["id"])
    rows = projected.__getitems__([0, 1, 2])
    assert "id" in rows
    assert "value" not in rows
    assert len(rows["id"]) == 3
|
||||
|
||||
|
||||
def test_getitems_with_column_rename(some_permutation: Permutation):
    """A renamed column is exposed under its new name only."""
    aliased = some_permutation.rename_column("value", "data")
    rows = aliased.__getitems__([0, 1])
    assert "data" in rows
    assert "value" not in rows
    assert len(rows["data"]) == 2
|
||||
|
||||
|
||||
def test_getitems_with_format(some_permutation: Permutation):
    """With the "arrow" format the fetched rows arrive as a pyarrow RecordBatch."""
    as_arrow = some_permutation.with_format("arrow")
    fetched = as_arrow.__getitems__([0, 1, 2])
    assert isinstance(fetched, pa.RecordBatch)
    assert fetched.num_rows == 3
|
||||
|
||||
|
||||
def test_getitems_with_custom_transform(some_permutation: Permutation):
    """A user-supplied transform decides the shape of the fetched result."""

    def ids_as_list(batch: pa.RecordBatch) -> list:
        # Collapse the batch to a plain python list of its "id" values
        return batch.column("id").to_pylist()

    transformed = some_permutation.with_transform(ids_as_list)
    fetched = transformed.__getitems__([0, 1, 2])
    assert isinstance(fetched, list)
    assert len(fetched) == 3
|
||||
|
||||
|
||||
def test_getitems_identity_permutation(mem_db):
    """On an identity permutation, offsets map straight onto table rows."""
    source = pa.table({"id": range(10), "value": range(10)})
    tbl = mem_db.create_table("test_table", source)
    identity = Permutation.identity(tbl)

    rows = identity.__getitems__([0, 5, 9])
    for column in ("id", "value"):
        assert rows[column] == [0, 5, 9]
|
||||
|
||||
|
||||
def test_getitems_with_limit_offset(some_permutation: Permutation):
    """Offsets are interpreted relative to the skip/take window."""
    window = some_permutation.with_skip(100).with_take(200)

    # First, second and last offsets of the 200-row window are all reachable
    inside = window.__getitems__([0, 1, 199])
    assert len(inside["id"]) == 3

    # Offset 0 of the window is offset 100 of the unrestricted permutation
    from_full = some_permutation.__getitems__([100])
    from_window = window.__getitems__([0])
    assert from_window["id"][0] == from_full["id"][0]
|
||||
|
||||
|
||||
def test_getitems_invalid_offset(some_permutation: Permutation):
    """Requesting an offset far outside the permutation raises."""
    out_of_range = [999999]
    with pytest.raises(Exception):
        some_permutation.__getitems__(out_of_range)
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use arrow::pyarrow::{PyArrowType, ToPyArrow};
|
||||
use lancedb::{
|
||||
dataloader::permutation::{
|
||||
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
@@ -328,4 +328,21 @@ impl PyPermutationReader {
|
||||
Ok(RecordBatchStream::new(stream))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (indices, *, selection=None))]
|
||||
pub fn take_offsets<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
indices: Vec<u64>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let batch = reader
|
||||
.take_offsets(&indices, selection)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(PyArrowType(batch))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ pub struct PermutationReader {
|
||||
limit: Option<u64>,
|
||||
available_rows: u64,
|
||||
split: u64,
|
||||
// Cached map of offset to row id for the split
|
||||
#[allow(clippy::type_complexity)]
|
||||
offset_map: Arc<tokio::sync::Mutex<Option<Arc<HashMap<u64, u64>>>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PermutationReader {
|
||||
@@ -72,6 +75,7 @@ impl PermutationReader {
|
||||
limit: None,
|
||||
available_rows: 0,
|
||||
split,
|
||||
offset_map: Arc::new(tokio::sync::Mutex::new(None)),
|
||||
};
|
||||
slf.validate().await?;
|
||||
// Calculate the number of available rows
|
||||
@@ -157,6 +161,7 @@ impl PermutationReader {
|
||||
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
|
||||
self.offset = Some(offset);
|
||||
self.available_rows = available_rows;
|
||||
self.offset_map = Arc::new(tokio::sync::Mutex::new(None));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -164,6 +169,7 @@ impl PermutationReader {
|
||||
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
|
||||
self.available_rows = available_rows;
|
||||
self.limit = Some(limit);
|
||||
self.offset_map = Arc::new(tokio::sync::Mutex::new(None));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -180,8 +186,9 @@ impl PermutationReader {
|
||||
base_table: &Arc<dyn BaseTable>,
|
||||
row_ids: RecordBatch,
|
||||
selection: Select,
|
||||
has_row_id: bool,
|
||||
) -> Result<RecordBatch> {
|
||||
let has_row_id = Self::has_row_id(&selection)?;
|
||||
|
||||
let num_rows = row_ids.num_rows();
|
||||
let row_ids = row_ids
|
||||
.column(0)
|
||||
@@ -282,14 +289,13 @@ impl PermutationReader {
|
||||
row_ids: DatasetRecordBatchStream,
|
||||
selection: Select,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let has_row_id = Self::has_row_id(&selection)?;
|
||||
let mut stream = row_ids
|
||||
.map_err(Error::from)
|
||||
.try_filter_map(move |batch| {
|
||||
let selection = selection.clone();
|
||||
let base_table = base_table.clone();
|
||||
async move {
|
||||
Self::load_batch(&base_table, batch, selection, has_row_id)
|
||||
Self::load_batch(&base_table, batch, selection)
|
||||
.await
|
||||
.map(Some)
|
||||
}
|
||||
@@ -397,6 +403,82 @@ impl PermutationReader {
|
||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||
}
|
||||
|
||||
/// If we are going to use `take` then we load the offset -> row id map once for the split and cache it
|
||||
///
|
||||
/// This method fetches the map with find-or-create semantics.
|
||||
async fn get_offset_map(
|
||||
&self,
|
||||
permutation_table: &Arc<dyn BaseTable>,
|
||||
) -> Result<Arc<HashMap<u64, u64>>> {
|
||||
let mut offset_map_ref = self.offset_map.lock().await;
|
||||
if let Some(offset_map) = &*offset_map_ref {
|
||||
return Ok(offset_map.clone());
|
||||
}
|
||||
let mut offset_map = HashMap::new();
|
||||
let mut row_ids_query = Table::from(permutation_table.clone())
|
||||
.query()
|
||||
.select(Select::Columns(vec![SRC_ROW_ID_COL.to_string()]))
|
||||
.only_if(format!("{} = {}", SPLIT_ID_COLUMN, self.split));
|
||||
if let Some(offset) = self.offset {
|
||||
row_ids_query = row_ids_query.offset(offset as usize);
|
||||
}
|
||||
if let Some(limit) = self.limit {
|
||||
row_ids_query = row_ids_query.limit(limit as usize);
|
||||
}
|
||||
let mut row_ids = row_ids_query.execute().await?;
|
||||
while let Some(batch) = row_ids.try_next().await? {
|
||||
let row_ids = batch
|
||||
.column(0)
|
||||
.as_primitive::<UInt64Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
for (i, row_id) in row_ids.iter().enumerate() {
|
||||
offset_map.insert(i as u64, *row_id);
|
||||
}
|
||||
}
|
||||
let offset_map = Arc::new(offset_map);
|
||||
*offset_map_ref = Some(offset_map.clone());
|
||||
Ok(offset_map)
|
||||
}
|
||||
|
||||
pub async fn take_offsets(&self, offsets: &[u64], selection: Select) -> Result<RecordBatch> {
|
||||
if let Some(permutation_table) = &self.permutation_table {
|
||||
let offset_map = self.get_offset_map(permutation_table).await?;
|
||||
let row_ids = offsets
|
||||
.iter()
|
||||
.map(|o| offset_map.get(o).copied().expect_ok().map_err(Error::from))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let row_ids = RecordBatch::try_new(
|
||||
Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
|
||||
"row_id",
|
||||
arrow_schema::DataType::UInt64,
|
||||
false,
|
||||
)])),
|
||||
vec![Arc::new(UInt64Array::from(row_ids))],
|
||||
)?;
|
||||
Self::load_batch(&self.base_table, row_ids, selection).await
|
||||
} else {
|
||||
let table = Table::from(self.base_table.clone());
|
||||
let batches = table
|
||||
.take_offsets(offsets.to_vec())
|
||||
.select(selection.clone())
|
||||
.execute()
|
||||
.await?
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
if let Some(first_batch) = batches.first() {
|
||||
let schema = first_batch.schema();
|
||||
let batch = arrow::compute::concat_batches(&schema, &batches)?;
|
||||
Ok(batch)
|
||||
} else {
|
||||
Ok(RecordBatch::try_new(
|
||||
self.output_schema(selection).await?,
|
||||
vec![],
|
||||
)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
|
||||
let table = Table::from(self.base_table.clone());
|
||||
table.query().select(selection).output_schema().await
|
||||
@@ -543,4 +625,224 @@ mod tests {
|
||||
check_batch(&mut stream, &row_ids[7..9]).await;
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
}
|
||||
|
||||
/// Helper to create a base table and permutation table for take_offsets tests.
|
||||
/// Returns (base_table, row_ids_table, shuffled_row_ids).
|
||||
async fn setup_permutation_tables(num_rows: usize) -> (Table, Table, Vec<u64>) {
|
||||
let base_table = lance_datagen::gen_batch()
|
||||
.col("idx", lance_datagen::array::step::<Int32Type>())
|
||||
.col("other_col", lance_datagen::array::step::<UInt64Type>())
|
||||
.into_mem_table("tbl", RowCount::from(num_rows as u64), BatchCount::from(1))
|
||||
.await;
|
||||
|
||||
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
|
||||
row_ids.shuffle(&mut rand::rng());
|
||||
|
||||
let split_ids = UInt64Array::from_iter_values(std::iter::repeat_n(0u64, row_ids.len()));
|
||||
let permutation_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("row_id", DataType::UInt64, false),
|
||||
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(UInt64Array::from(row_ids.clone())),
|
||||
Arc::new(split_ids),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
||||
|
||||
(base_table, row_ids_table, row_ids)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_with_permutation_table() {
|
||||
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Take specific offsets and verify the returned rows match the permutation order
|
||||
let offsets = vec![0, 2, 4];
|
||||
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
|
||||
|
||||
assert_eq!(batch.num_rows(), 3);
|
||||
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
let expected: Vec<i32> = offsets
|
||||
.iter()
|
||||
.map(|&o| row_ids[o as usize] as i32)
|
||||
.collect();
|
||||
assert_eq!(idx_values, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_preserves_order() {
|
||||
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Take offsets in reverse order and verify returned rows match that order
|
||||
let offsets = vec![5, 3, 1, 0];
|
||||
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
|
||||
|
||||
assert_eq!(batch.num_rows(), 4);
|
||||
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
let expected: Vec<i32> = offsets
|
||||
.iter()
|
||||
.map(|&o| row_ids[o as usize] as i32)
|
||||
.collect();
|
||||
assert_eq!(idx_values, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_with_column_selection() {
|
||||
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let offsets = vec![1, 3];
|
||||
let batch = reader
|
||||
.take_offsets(&offsets, Select::Columns(vec!["idx".to_string()]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(batch.num_rows(), 2);
|
||||
assert_eq!(batch.num_columns(), 1);
|
||||
assert_eq!(batch.schema().field(0).name(), "idx");
|
||||
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
let expected: Vec<i32> = offsets
|
||||
.iter()
|
||||
.map(|&o| row_ids[o as usize] as i32)
|
||||
.collect();
|
||||
assert_eq!(idx_values, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_invalid_offset() {
|
||||
let (base_table, row_ids_table, _) = setup_permutation_tables(5).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Offset 999 doesn't exist in the offset map
|
||||
let result = reader.take_offsets(&[0, 999], Select::All).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_identity_reader() {
|
||||
let base_table = lance_datagen::gen_batch()
|
||||
.col("idx", lance_datagen::array::step::<Int32Type>())
|
||||
.into_mem_table("tbl", RowCount::from(10), BatchCount::from(1))
|
||||
.await;
|
||||
|
||||
let reader = PermutationReader::identity(base_table.base_table().clone()).await;
|
||||
|
||||
// With no permutation table, take_offsets uses the base table directly
|
||||
let offsets = vec![0, 2, 4, 6];
|
||||
let batch = reader.take_offsets(&offsets, Select::All).await.unwrap();
|
||||
|
||||
assert_eq!(batch.num_rows(), 4);
|
||||
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, vec![0, 2, 4, 6]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_caches_offset_map() {
|
||||
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(10).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// First call populates the cache
|
||||
let batch1 = reader.take_offsets(&[0, 1], Select::All).await.unwrap();
|
||||
|
||||
// Second call should use the cached offset map and produce consistent results
|
||||
let batch2 = reader.take_offsets(&[0, 1], Select::All).await.unwrap();
|
||||
|
||||
let values1 = batch1
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
let values2 = batch2
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(values1, values2);
|
||||
|
||||
let expected: Vec<i32> = vec![row_ids[0] as i32, row_ids[1] as i32];
|
||||
assert_eq!(values1, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_offsets_single_offset() {
|
||||
let (base_table, row_ids_table, row_ids) = setup_permutation_tables(5).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = reader.take_offsets(&[2], Select::All).await.unwrap();
|
||||
|
||||
assert_eq!(batch.num_rows(), 1);
|
||||
let idx_values = batch
|
||||
.column(0)
|
||||
.as_primitive::<Int32Type>()
|
||||
.values()
|
||||
.to_vec();
|
||||
assert_eq!(idx_values, vec![row_ids[2] as i32]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user