diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 5d133c309..0d4c46c72 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -385,6 +385,12 @@ class Permutation: selection: dict[str, str], batch_size: int, transform_fn: Callable[pa.RecordBatch, Any], + *, + base_table: LanceTable, + permutation_table: Optional[LanceTable], + split: int, + offset: Optional[int] = None, + limit: Optional[int] = None, ): """ Internal constructor. Use [from_tables](#from_tables) instead. @@ -395,6 +401,21 @@ class Permutation: self.selection = selection self.transform_fn = transform_fn self.batch_size = batch_size + # These fields are used to reconstruct the permutation in a new process. + self._base_table = base_table + self._permutation_table = permutation_table + self._split = split + self._offset = offset + self._limit = limit + + def _reopen_metadata(self) -> dict[str, Any]: + return { + "base_table": self._base_table, + "permutation_table": self._permutation_table, + "split": self._split, + "offset": self._offset, + "limit": self._limit, + } def _with_selection(self, selection: dict[str, str]) -> "Permutation": """ @@ -403,7 +424,13 @@ class Permutation: Does not validation of the selection and it replaces it entirely. This is not intended for public use. """ - return Permutation(self.reader, selection, self.batch_size, self.transform_fn) + return Permutation( + self.reader, + selection, + self.batch_size, + self.transform_fn, + **self._reopen_metadata(), + ) def _with_reader(self, reader: PermutationReader) -> "Permutation": """ @@ -411,13 +438,25 @@ class Permutation: This is an internal method and should not be used directly. """ - return Permutation(reader, self.selection, self.batch_size, self.transform_fn) + return Permutation( + reader, + self.selection, + self.batch_size, + self.transform_fn, + **self._reopen_metadata(), + ) def with_batch_size(self, batch_size: int) -> "Permutation": """ Creates a new permutation with the given batch size """ - return Permutation(self.reader, self.selection, batch_size, self.transform_fn) + return Permutation( + self.reader, + self.selection, + batch_size, + self.transform_fn, + **self._reopen_metadata(), + ) @classmethod def identity(cls, table: LanceTable) -> "Permutation": @@ -491,7 +530,13 @@ class Permutation: schema = await reader.output_schema(None) initial_selection = {name: name for name in schema.names} return cls( - reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python + reader, + initial_selection, + DEFAULT_BATCH_SIZE, + Transforms.arrow2python, + base_table=base_table, + permutation_table=permutation_table, + split=split, ) return LOOP.run(do_from_tables()) @@ -762,7 +807,13 @@ class Permutation: for expensive operations such as image decoding. """ assert transform is not None, "transform is required" - return Permutation(self.reader, self.selection, self.batch_size, transform) + return Permutation( + self.reader, + self.selection, + self.batch_size, + transform, + **self._reopen_metadata(), + ) def __getitem__(self, index: int) -> Any: """ @@ -800,7 +851,17 @@ class Permutation: async def do_with_skip(): reader = await self.reader.with_offset(skip) - return self._with_reader(reader) + return Permutation( + reader, + self.selection, + self.batch_size, + self.transform_fn, + base_table=self._base_table, + permutation_table=self._permutation_table, + split=self._split, + offset=skip, + limit=self._limit, + ) return LOOP.run(do_with_skip()) @@ -823,7 +884,17 @@ class Permutation: async def do_with_take(): reader = await self.reader.with_limit(limit) - return self._with_reader(reader) + return Permutation( + reader, + self.selection, + self.batch_size, + self.transform_fn, + base_table=self._base_table, + permutation_table=self._permutation_table, + split=self._split, + offset=self._offset, + limit=limit, + ) return LOOP.run(do_with_take()) diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 0223b829c..628542fde 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -562,6 +562,22 @@ def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation: return Permutation.from_tables(some_table, some_perm_table) +def assert_reopen_metadata( + permutation: Permutation, + *, + base_table: Table, + permutation_table: Table | None, + split: int, + offset: int | None = None, + limit: int | None = None, +): + assert permutation._base_table is base_table + assert permutation._permutation_table is permutation_table + assert permutation._split == split + assert permutation._offset == offset + assert permutation._limit == limit + + def test_num_rows(some_permutation: Permutation): assert some_permutation.num_rows == 950 @@ -599,6 +615,93 @@ def test_limit_offset(some_permutation: Permutation): some_permutation.with_skip(500).with_take(500).num_rows +def test_reopen_metadata_identity(mem_db: DBConnection): + table = mem_db.create_table("identity_table", pa.table({"id": range(10)})) + permutation = Permutation.identity(table) + + assert_reopen_metadata( + permutation, + base_table=table, + permutation_table=None, + split=0, + ) + + +def test_reopen_metadata_split_name_resolution( + some_table: Table, some_perm_table: Table +): + permutation = Permutation.from_tables(some_table, some_perm_table, "test") + + assert permutation.num_rows == 50 + assert_reopen_metadata( + permutation, + base_table=some_table, + permutation_table=some_perm_table, + split=1, + ) + + +def test_reopen_metadata_propagates_through_derived_permutations( + some_permutation: Permutation, some_table: Table, some_perm_table: Table +): + derived = ( + some_permutation.select_columns(["id"]) + .rename_column("id", "row_id") + .with_batch_size(32) + .with_format("arrow") + ) + + assert_reopen_metadata( + derived, + base_table=some_table, + permutation_table=some_perm_table, + split=0, + ) + + +def test_reopen_metadata_tracks_skip_and_take( + some_permutation: Permutation, some_table: Table, some_perm_table: Table +): + skipped = some_permutation.with_skip(100) + assert_reopen_metadata( + skipped, + base_table=some_table, + permutation_table=some_perm_table, + split=0, + offset=100, + ) + + limited = skipped.with_take(200) + assert_reopen_metadata( + limited, + base_table=some_table, + permutation_table=some_perm_table, + split=0, + offset=100, + limit=200, + ) + + reskipped = limited.with_skip(25) + assert_reopen_metadata( + reskipped, + base_table=some_table, + permutation_table=some_perm_table, + split=0, + offset=25, + limit=200, + ) + + retaken = limited.with_take(50) + assert_reopen_metadata( + retaken, + base_table=some_table, + permutation_table=some_perm_table, + split=0, + offset=100, + limit=50, + ) + + def test_remove_columns(some_permutation: Permutation): assert some_permutation.remove_columns(["value"]).schema == pa.schema( [("id", pa.int64())]