Track permutation reopen metadata

This commit is contained in:
Xuanwo
2026-04-08 17:34:05 +08:00
parent a898dc81c2
commit 2d380d1669
2 changed files with 181 additions and 7 deletions

View File

@@ -385,6 +385,12 @@ class Permutation:
selection: dict[str, str],
batch_size: int,
transform_fn: Callable[pa.RecordBatch, Any],
*,
base_table: LanceTable,
permutation_table: Optional[LanceTable],
split: int,
offset: Optional[int] = None,
limit: Optional[int] = None,
):
"""
Internal constructor. Use [from_tables](#from_tables) instead.
@@ -395,6 +401,21 @@ class Permutation:
self.selection = selection
self.transform_fn = transform_fn
self.batch_size = batch_size
# These fields are used to reconstruct the permutation in a new process.
self._base_table = base_table
self._permutation_table = permutation_table
self._split = split
self._offset = offset
self._limit = limit
def _reopen_metadata(self) -> dict[str, Any]:
return {
"base_table": self._base_table,
"permutation_table": self._permutation_table,
"split": self._split,
"offset": self._offset,
"limit": self._limit,
}
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
"""
@@ -403,7 +424,13 @@ class Permutation:
Does not validation of the selection and it replaces it entirely. This is not
intended for public use.
"""
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
return Permutation(
self.reader,
selection,
self.batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
def _with_reader(self, reader: PermutationReader) -> "Permutation":
"""
@@ -411,13 +438,25 @@ class Permutation:
This is an internal method and should not be used directly.
"""
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
def with_batch_size(self, batch_size: int) -> "Permutation":
"""
Creates a new permutation with the given batch size
"""
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
return Permutation(
self.reader,
self.selection,
batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
@classmethod
def identity(cls, table: LanceTable) -> "Permutation":
@@ -491,7 +530,13 @@ class Permutation:
schema = await reader.output_schema(None)
initial_selection = {name: name for name in schema.names}
return cls(
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
reader,
initial_selection,
DEFAULT_BATCH_SIZE,
Transforms.arrow2python,
base_table=base_table,
permutation_table=permutation_table,
split=split,
)
return LOOP.run(do_from_tables())
@@ -762,7 +807,13 @@ class Permutation:
for expensive operations such as image decoding.
"""
assert transform is not None, "transform is required"
return Permutation(self.reader, self.selection, self.batch_size, transform)
return Permutation(
self.reader,
self.selection,
self.batch_size,
transform,
**self._reopen_metadata(),
)
def __getitem__(self, index: int) -> Any:
"""
@@ -800,7 +851,17 @@ class Permutation:
async def do_with_skip():
reader = await self.reader.with_offset(skip)
return self._with_reader(reader)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
base_table=self._base_table,
permutation_table=self._permutation_table,
split=self._split,
offset=skip,
limit=self._limit,
)
return LOOP.run(do_with_skip())
@@ -823,7 +884,17 @@ class Permutation:
async def do_with_take():
reader = await self.reader.with_limit(limit)
return self._with_reader(reader)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
base_table=self._base_table,
permutation_table=self._permutation_table,
split=self._split,
offset=self._offset,
limit=limit,
)
return LOOP.run(do_with_take())

View File

@@ -562,6 +562,22 @@ def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
return Permutation.from_tables(some_table, some_perm_table)
def assert_reopen_metadata(
permutation: Permutation,
*,
base_table: Table,
permutation_table: Table | None,
split: int,
offset: int | None = None,
limit: int | None = None,
):
assert permutation._base_table is base_table
assert permutation._permutation_table is permutation_table
assert permutation._split == split
assert permutation._offset == offset
assert permutation._limit == limit
def test_num_rows(some_permutation: Permutation):
assert some_permutation.num_rows == 950
@@ -599,6 +615,93 @@ def test_limit_offset(some_permutation: Permutation):
some_permutation.with_skip(500).with_take(500).num_rows
def test_reopen_metadata_identity(mem_db: DBConnection):
table = mem_db.create_table("identity_table", pa.table({"id": range(10)}))
permutation = Permutation.identity(table)
assert_reopen_metadata(
permutation,
base_table=table,
permutation_table=None,
split=0,
)
def test_reopen_metadata_split_name_resolution(
some_table: Table, some_perm_table: Table
):
permutation = Permutation.from_tables(some_table, some_perm_table, "test")
assert permutation.num_rows == 50
assert_reopen_metadata(
permutation,
base_table=some_table,
permutation_table=some_perm_table,
split=1,
)
def test_reopen_metadata_propagates_through_derived_permutations(
some_permutation: Permutation, some_table: Table, some_perm_table: Table
):
derived = (
some_permutation.select_columns(["id"])
.rename_column("id", "row_id")
.with_batch_size(32)
.with_format("arrow")
)
assert_reopen_metadata(
derived,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
)
def test_reopen_metadata_tracks_skip_and_take(
some_permutation: Permutation, some_table: Table, some_perm_table: Table
):
skipped = some_permutation.with_skip(100)
assert_reopen_metadata(
skipped,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
)
limited = skipped.with_take(200)
assert_reopen_metadata(
limited,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
limit=200,
)
reskipped = limited.with_skip(25)
assert_reopen_metadata(
reskipped,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=25,
limit=200,
)
retaken = limited.with_take(50)
assert_reopen_metadata(
retaken,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
limit=50,
)
def test_remove_columns(some_permutation: Permutation):
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
[("id", pa.int64())]