From 533c81ce03c65c0396c4467f7ddc202eae021390 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Wed, 15 Apr 2026 14:11:36 +0800 Subject: [PATCH] feat(python): add permutation index snapshots --- python/python/lancedb/_lancedb.pyi | 1 + python/python/lancedb/permutation.py | 13 +++++ python/python/tests/test_permutation.py | 55 +++++++++++++++++++ python/src/permutation.rs | 9 +++ .../src/dataloader/permutation/reader.rs | 50 ++++++++++++++++- 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 76c08041b..c9613f440 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -446,3 +446,4 @@ def fts_query_to_json(query: Any) -> str: ... class PermutationReader: def __init__(self, base_table: Table, permutation_table: Table): ... + async def snapshot_indices(self) -> pa.RecordBatch: ... diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 724a0fd25..e2e8fadeb 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -779,6 +779,19 @@ class Permutation: batch = LOOP.run(do_getitems()) return self.transform_fn(batch) + def _snapshot_indices(self) -> pa.RecordBatch: + """ + Materialize the current permutation view as ordered row ids. + + This is an internal helper for dataset integrations and should not be + considered stable public API. + """ + + async def do_snapshot(): + return await self.reader.snapshot_indices() + + return LOOP.run(do_snapshot()) + @deprecated(details="Use with_skip instead") def skip(self, skip: int) -> "Permutation": """ diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index bb92ba0ba..9beda3152 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import pyarrow as pa +import pyarrow.compute as pc import math import pytest @@ -643,6 +644,60 @@ def test_limit_offset(some_permutation: Permutation): some_permutation.with_skip(500).with_take(500).num_rows +def test_snapshot_indices_identity(mem_db: DBConnection): + table = mem_db.create_table( + "identity_snapshot_table", + pa.table({"id": range(10), "value": range(10)}), + ) + + snapshot = Permutation.identity(table)._snapshot_indices() + + assert snapshot.schema == pa.schema( + [pa.field("row_id", pa.uint64(), nullable=False)] + ) + assert snapshot.column("row_id").to_pylist() == list(range(10)) + + +def test_snapshot_indices_split_respects_permutation_order( + some_table: Table, some_perm_table: Table +): + permutation = Permutation.from_tables(some_table, some_perm_table, "test") + snapshot = permutation._snapshot_indices() + row_ids = snapshot.column("row_id").to_pylist() + + assert snapshot.schema == pa.schema( + [pa.field("row_id", pa.uint64(), nullable=False)] + ) + assert len(row_ids) == permutation.num_rows == 50 + + permutation_rows = some_perm_table.to_arrow() + expected = permutation_rows.filter(pc.equal(permutation_rows["split_id"], 1))[ + "row_id" + ].to_pylist() + assert row_ids == expected + + +def test_snapshot_indices_tracks_skip_take(some_permutation: Permutation): + full_snapshot = some_permutation._snapshot_indices().column("row_id").to_pylist() + sliced = some_permutation.with_skip(100).with_take(25)._snapshot_indices() + + assert sliced.column("row_id").to_pylist() == full_snapshot[100:125] + + +def test_snapshot_indices_ignores_selection_changes(some_permutation: Permutation): + snapshot = some_permutation._snapshot_indices() + selected = ( + some_permutation.select_columns(["id"]) + .rename_column("id", "row_id_alias") + .with_batch_size(32) + ._snapshot_indices() + ) + + assert ( + selected.column("row_id").to_pylist() == snapshot.column("row_id").to_pylist() + ) + + def test_remove_columns(some_permutation: Permutation): assert some_permutation.remove_columns(["value"]).schema == pa.schema( [("id", pa.int64())] diff --git a/python/src/permutation.rs b/python/src/permutation.rs index 21b8c9c47..244f8db73 100644 --- a/python/src/permutation.rs +++ b/python/src/permutation.rs @@ -303,6 +303,15 @@ impl PyPermutationReader { slf.reader.count_rows() } + #[pyo3(signature = ())] + pub fn snapshot_indices<'py>(slf: PyRef<'py, Self>) -> PyResult> { + let reader = slf.reader.clone(); + future_into_py(slf.py(), async move { + let batch = reader.snapshot_indices().await.infer_error()?; + Ok(PyArrowType(batch)) + }) + } + #[pyo3(signature = (offset))] pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult> { let reader = slf.reader.as_ref().clone(); diff --git a/rust/lancedb/src/dataloader/permutation/reader.rs b/rust/lancedb/src/dataloader/permutation/reader.rs index 43c229a0a..484c5179f 100644 --- a/rust/lancedb/src/dataloader/permutation/reader.rs +++ b/rust/lancedb/src/dataloader/permutation/reader.rs @@ -20,7 +20,7 @@ use arrow::array::AsArray; use arrow::compute::concat_batches; use arrow::datatypes::UInt64Type; use arrow_array::{RecordBatch, UInt64Array}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use futures::{StreamExt, TryStreamExt}; use lance::dataset::scanner::DatasetRecordBatchStream; use lance::io::RecordBatchStream; @@ -409,6 +409,54 @@ impl PermutationReader { Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await } + pub async fn snapshot_indices(&self) -> Result { + let row_ids = if let Some(permutation_table) = &self.permutation_table { + permutation_table + .query( + &AnyQuery::Query(QueryRequest { + select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]), + filter: Some(QueryFilter::Sql(format!( + "{} = {}", + SPLIT_ID_COLUMN, self.split + ))), + offset: self.offset.map(|o| o as usize), + limit: self.limit.map(|l| l as usize), + ..Default::default() + }), + QueryExecutionOptions::default(), + ) + .await? + } else { + self.base_table + .query( + &AnyQuery::Query(QueryRequest { + select: Select::Columns(vec![ROW_ID.to_string()]), + offset: self.offset.map(|o| o as usize), + limit: self.limit.map(|l| l as usize), + ..Default::default() + }), + QueryExecutionOptions::default(), + ) + .await? + }; + + let batches = row_ids.try_collect::>().await?; + let schema = Arc::new(Schema::new(vec![Field::new( + SRC_ROW_ID_COL, + DataType::UInt64, + false, + )])); + if batches.is_empty() { + return Ok(RecordBatch::try_new( + schema, + vec![Arc::new(UInt64Array::from(Vec::::new()))], + )?); + } + + let batch = concat_batches(&batches[0].schema(), &batches)?; + Ok(RecordBatch::try_new(schema, vec![batch.column(0).clone()])?) + } + /// If we are going to use `take` then we load the offset -> row id map once for the split and cache it /// /// This method fetches the map with find-or-create semantics.