mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-16 12:50:39 +00:00
Compare commits
1 Commits
rust-neste
...
xuanwo/per
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
533c81ce03 |
@@ -446,3 +446,4 @@ def fts_query_to_json(query: Any) -> str: ...
|
||||
|
||||
class PermutationReader:
|
||||
def __init__(self, base_table: Table, permutation_table: Table): ...
|
||||
async def snapshot_indices(self) -> pa.RecordBatch: ...
|
||||
|
||||
@@ -779,6 +779,19 @@ class Permutation:
|
||||
batch = LOOP.run(do_getitems())
|
||||
return self.transform_fn(batch)
|
||||
|
||||
def _snapshot_indices(self) -> pa.RecordBatch:
|
||||
"""
|
||||
Materialize the current permutation view as ordered row ids.
|
||||
|
||||
This is an internal helper for dataset integrations and should not be
|
||||
considered stable public API.
|
||||
"""
|
||||
|
||||
async def do_snapshot():
|
||||
return await self.reader.snapshot_indices()
|
||||
|
||||
return LOOP.run(do_snapshot())
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import math
|
||||
import pytest
|
||||
|
||||
@@ -643,6 +644,60 @@ def test_limit_offset(some_permutation: Permutation):
|
||||
some_permutation.with_skip(500).with_take(500).num_rows
|
||||
|
||||
|
||||
def test_snapshot_indices_identity(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"identity_snapshot_table",
|
||||
pa.table({"id": range(10), "value": range(10)}),
|
||||
)
|
||||
|
||||
snapshot = Permutation.identity(table)._snapshot_indices()
|
||||
|
||||
assert snapshot.schema == pa.schema(
|
||||
[pa.field("row_id", pa.uint64(), nullable=False)]
|
||||
)
|
||||
assert snapshot.column("row_id").to_pylist() == list(range(10))
|
||||
|
||||
|
||||
def test_snapshot_indices_split_respects_permutation_order(
|
||||
some_table: Table, some_perm_table: Table
|
||||
):
|
||||
permutation = Permutation.from_tables(some_table, some_perm_table, "test")
|
||||
snapshot = permutation._snapshot_indices()
|
||||
row_ids = snapshot.column("row_id").to_pylist()
|
||||
|
||||
assert snapshot.schema == pa.schema(
|
||||
[pa.field("row_id", pa.uint64(), nullable=False)]
|
||||
)
|
||||
assert len(row_ids) == permutation.num_rows == 50
|
||||
|
||||
permutation_rows = some_perm_table.to_arrow()
|
||||
expected = permutation_rows.filter(pc.equal(permutation_rows["split_id"], 1))[
|
||||
"row_id"
|
||||
].to_pylist()
|
||||
assert row_ids == expected
|
||||
|
||||
|
||||
def test_snapshot_indices_tracks_skip_take(some_permutation: Permutation):
|
||||
full_snapshot = some_permutation._snapshot_indices().column("row_id").to_pylist()
|
||||
sliced = some_permutation.with_skip(100).with_take(25)._snapshot_indices()
|
||||
|
||||
assert sliced.column("row_id").to_pylist() == full_snapshot[100:125]
|
||||
|
||||
|
||||
def test_snapshot_indices_ignores_selection_changes(some_permutation: Permutation):
|
||||
snapshot = some_permutation._snapshot_indices()
|
||||
selected = (
|
||||
some_permutation.select_columns(["id"])
|
||||
.rename_column("id", "row_id_alias")
|
||||
.with_batch_size(32)
|
||||
._snapshot_indices()
|
||||
)
|
||||
|
||||
assert (
|
||||
selected.column("row_id").to_pylist() == snapshot.column("row_id").to_pylist()
|
||||
)
|
||||
|
||||
|
||||
def test_remove_columns(some_permutation: Permutation):
|
||||
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
|
||||
[("id", pa.int64())]
|
||||
|
||||
@@ -303,6 +303,15 @@ impl PyPermutationReader {
|
||||
slf.reader.count_rows()
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn snapshot_indices<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let batch = reader.snapshot_indices().await.infer_error()?;
|
||||
Ok(PyArrowType(batch))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (offset))]
|
||||
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.as_ref().clone();
|
||||
|
||||
@@ -20,7 +20,7 @@ use arrow::array::AsArray;
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow::datatypes::UInt64Type;
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_schema::SchemaRef;
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::io::RecordBatchStream;
|
||||
@@ -409,6 +409,54 @@ impl PermutationReader {
|
||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||
}
|
||||
|
||||
pub async fn snapshot_indices(&self) -> Result<RecordBatch> {
|
||||
let row_ids = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||
filter: Some(QueryFilter::Sql(format!(
|
||||
"{} = {}",
|
||||
SPLIT_ID_COLUMN, self.split
|
||||
))),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
QueryExecutionOptions::default(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.base_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![ROW_ID.to_string()]),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
QueryExecutionOptions::default(),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
let batches = row_ids.try_collect::<Vec<_>>().await?;
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
SRC_ROW_ID_COL,
|
||||
DataType::UInt64,
|
||||
false,
|
||||
)]));
|
||||
if batches.is_empty() {
|
||||
return Ok(RecordBatch::try_new(
|
||||
schema,
|
||||
vec![Arc::new(UInt64Array::from(Vec::<u64>::new()))],
|
||||
)?);
|
||||
}
|
||||
|
||||
let batch = concat_batches(&batches[0].schema(), &batches)?;
|
||||
Ok(RecordBatch::try_new(schema, vec![batch.column(0).clone()])?)
|
||||
}
|
||||
|
||||
/// If we are going to use `take` then we load the offset -> row id map once for the split and cache it
|
||||
///
|
||||
/// This method fetches the map with find-or-create semantics.
|
||||
|
||||
Reference in New Issue
Block a user