Compare commits

...

1 Commit

Author SHA1 Message Date
Xuanwo
533c81ce03 feat(python): add permutation index snapshots 2026-04-15 14:11:36 +08:00
5 changed files with 127 additions and 1 deletion

View File

@@ -446,3 +446,4 @@ def fts_query_to_json(query: Any) -> str: ...
class PermutationReader:
    # Typing stub for the native (pyo3) reader that backs `Permutation`.
    # Combines a base data table with an optional permutation/ordering table.
    def __init__(self, base_table: Table, permutation_table: Table): ...
    # Materializes the reader's current view as a RecordBatch of ordered row ids.
    async def snapshot_indices(self) -> pa.RecordBatch: ...

View File

@@ -779,6 +779,19 @@ class Permutation:
batch = LOOP.run(do_getitems())
return self.transform_fn(batch)
def _snapshot_indices(self) -> pa.RecordBatch:
    """
    Materialize the current permutation view as ordered row ids.

    This is an internal helper for dataset integrations and should not be
    considered stable public API.
    """

    async def _collect() -> pa.RecordBatch:
        # Delegate to the native reader; it produces the ordered row-id batch.
        return await self.reader.snapshot_indices()

    # Drive the coroutine to completion on the shared background loop.
    return LOOP.run(_collect())
@deprecated(details="Use with_skip instead")
def skip(self, skip: int) -> "Permutation":
"""

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pyarrow.compute as pc
import math
import pytest
@@ -643,6 +644,60 @@ def test_limit_offset(some_permutation: Permutation):
some_permutation.with_skip(500).with_take(500).num_rows
def test_snapshot_indices_identity(mem_db: DBConnection):
    # An identity permutation over a 10-row table must snapshot to 0..9.
    data = pa.table({"id": range(10), "value": range(10)})
    table = mem_db.create_table("identity_snapshot_table", data)

    snapshot = Permutation.identity(table)._snapshot_indices()

    expected_schema = pa.schema([pa.field("row_id", pa.uint64(), nullable=False)])
    assert snapshot.schema == expected_schema
    assert snapshot.column("row_id").to_pylist() == list(range(10))
def test_snapshot_indices_split_respects_permutation_order(
    some_table: Table, some_perm_table: Table
):
    permutation = Permutation.from_tables(some_table, some_perm_table, "test")

    snapshot = permutation._snapshot_indices()
    assert snapshot.schema == pa.schema(
        [pa.field("row_id", pa.uint64(), nullable=False)]
    )

    actual_ids = snapshot.column("row_id").to_pylist()
    assert len(actual_ids) == permutation.num_rows == 50

    # Expected: the row ids stored for split 1, in permutation-table order.
    perm_rows = some_perm_table.to_arrow()
    split_mask = pc.equal(perm_rows["split_id"], 1)
    expected_ids = perm_rows.filter(split_mask)["row_id"].to_pylist()
    assert actual_ids == expected_ids
def test_snapshot_indices_tracks_skip_take(some_permutation: Permutation):
    # A skip/take window must match slicing the unwindowed snapshot directly.
    baseline = some_permutation._snapshot_indices().column("row_id").to_pylist()
    window = some_permutation.with_skip(100).with_take(25)._snapshot_indices()
    assert window.column("row_id").to_pylist() == baseline[100:125]
def test_snapshot_indices_ignores_selection_changes(some_permutation: Permutation):
    baseline = some_permutation._snapshot_indices()

    # Projection, column renames and batch sizing reshape the data view only;
    # the snapshot of row ids must be unaffected.
    reshaped = some_permutation.select_columns(["id"])
    reshaped = reshaped.rename_column("id", "row_id_alias")
    reshaped = reshaped.with_batch_size(32)
    modified = reshaped._snapshot_indices()

    assert (
        modified.column("row_id").to_pylist()
        == baseline.column("row_id").to_pylist()
    )
def test_remove_columns(some_permutation: Permutation):
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
[("id", pa.int64())]

View File

@@ -303,6 +303,15 @@ impl PyPermutationReader {
slf.reader.count_rows()
}
#[pyo3(signature = ())]
/// Async binding: resolve to the reader's current row-id snapshot as a
/// pyarrow-compatible RecordBatch.
pub fn snapshot_indices<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
    // Clone the handle so the spawned future owns it independently of `slf`.
    let inner = slf.reader.clone();
    future_into_py(slf.py(), async move {
        let snapshot = inner.snapshot_indices().await.infer_error()?;
        Ok(PyArrowType(snapshot))
    })
}
#[pyo3(signature = (offset))]
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
let reader = slf.reader.as_ref().clone();

View File

@@ -20,7 +20,7 @@ use arrow::array::AsArray;
use arrow::compute::concat_batches;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::SchemaRef;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use futures::{StreamExt, TryStreamExt};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::io::RecordBatchStream;
@@ -409,6 +409,54 @@ impl PermutationReader {
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
/// Materialize the reader's current view as one batch of ordered row ids.
///
/// The result always has a single non-nullable `UInt64` column named
/// `SRC_ROW_ID_COL`. With a permutation table the ids come from that table,
/// restricted to this reader's split; without one they are read straight from
/// the base table. The reader's offset/limit window applies in both cases.
pub async fn snapshot_indices(&self) -> Result<RecordBatch> {
    // The window is identical for both query shapes, so compute it once.
    let offset = self.offset.map(|o| o as usize);
    let limit = self.limit.map(|l| l as usize);

    let stream = match &self.permutation_table {
        Some(permutation_table) => {
            // Source row ids for this split, in permutation-table order.
            permutation_table
                .query(
                    &AnyQuery::Query(QueryRequest {
                        select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
                        filter: Some(QueryFilter::Sql(format!(
                            "{} = {}",
                            SPLIT_ID_COLUMN, self.split
                        ))),
                        offset,
                        limit,
                        ..Default::default()
                    }),
                    QueryExecutionOptions::default(),
                )
                .await?
        }
        None => {
            // No permutation table: take the base table's own row ids.
            self.base_table
                .query(
                    &AnyQuery::Query(QueryRequest {
                        select: Select::Columns(vec![ROW_ID.to_string()]),
                        offset,
                        limit,
                        ..Default::default()
                    }),
                    QueryExecutionOptions::default(),
                )
                .await?
        }
    };

    let batches = stream.try_collect::<Vec<_>>().await?;

    // The output schema is fixed regardless of which source produced the ids.
    let schema = Arc::new(Schema::new(vec![Field::new(
        SRC_ROW_ID_COL,
        DataType::UInt64,
        false,
    )]));

    if batches.is_empty() {
        let empty_ids = UInt64Array::from(Vec::<u64>::new());
        return Ok(RecordBatch::try_new(schema, vec![Arc::new(empty_ids)])?);
    }

    // Concatenate under the source schema, then rebind the single column to
    // the canonical output field (the batch takes its name from `schema`).
    let combined = concat_batches(&batches[0].schema(), &batches)?;
    Ok(RecordBatch::try_new(
        schema,
        vec![combined.column(0).clone()],
    )?)
}
/// If we are going to use `take` then we load the offset -> row id map once for the split and cache it
///
/// This method fetches the map with find-or-create semantics.