diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index 0d4c46c72..7b3091747 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -1,21 +1,98 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +import json +import pickle +from datetime import timedelta +from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union + +import pyarrow as pa from deprecation import deprecated from lancedb import AsyncConnection, DBConnection -import pyarrow as pa -import json from ._lancedb import async_permutation_builder, PermutationReader -from .table import LanceTable from .background_loop import LOOP +from .table import LanceTable from .util import batch_to_tensor, batch_to_tensor_rows -from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union if TYPE_CHECKING: from lancedb.dependencies import pandas as pd, numpy as np, polars as pl +def _builtin_transform(format: str) -> Callable[[pa.RecordBatch], Any]: + if format == "python": + return Transforms.arrow2python + if format == "python_col": + return Transforms.arrow2pythoncol + if format == "numpy": + return Transforms.arrow2numpy + if format == "pandas": + return Transforms.arrow2pandas + if format == "arrow": + return Transforms.arrow2arrow + if format == "torch": + return batch_to_tensor_rows + if format == "torch_col": + return batch_to_tensor + if format == "polars": + return Transforms.arrow2polars() + raise ValueError(f"Invalid format: {format}") + + +def _table_to_state( + table: Union[LanceTable, dict[str, Any]], +) -> dict[str, Any]: + if isinstance(table, dict): + return table + if not isinstance(table, LanceTable): + raise pickle.PicklingError( + "Permutation pickling only supports LanceTable-backed permutations" + ) + if table._namespace_client is not None: + raise pickle.PicklingError( + "Permutation pickling does not yet support namespace-backed tables" + ) + if table._conn.uri.startswith("memory://"): + raise pickle.PicklingError( + "Permutation pickling does not support in-memory databases" + ) + + try: + read_consistency_interval = table._conn.read_consistency_interval + except Exception: + read_consistency_interval = None + return { + "uri": table._conn.uri, + "name": table.name, + "version": table.version, + "storage_options": table.initial_storage_options(), + "read_consistency_interval_secs": ( + read_consistency_interval.total_seconds() + if read_consistency_interval is not None + else None + ), + "namespace_path": list(table.namespace), + } + + +def _table_from_state(state: dict[str, Any]) -> LanceTable: + from . import connect + + read_consistency_interval = ( + timedelta(seconds=state["read_consistency_interval_secs"]) + if state["read_consistency_interval_secs"] is not None + else None + ) + db = connect( + state["uri"], + read_consistency_interval=read_consistency_interval, + storage_options=state["storage_options"], + ) + table = db.open_table(state["name"], namespace_path=state["namespace_path"]) + table.checkout(state["version"]) + return table + + class PermutationBuilder: """ A utility for creating a "permutation table" which is a table that defines an @@ -386,11 +463,12 @@ class Permutation: batch_size: int, transform_fn: Callable[pa.RecordBatch, Any], *, - base_table: LanceTable, - permutation_table: Optional[LanceTable], + base_table: Union[LanceTable, dict[str, Any]], + permutation_table: Optional[Union[LanceTable, dict[str, Any]]], split: int, offset: Optional[int] = None, limit: Optional[int] = None, + transform_spec: Optional[str] = None, ): """ Internal constructor. Use [from_tables](#from_tables) instead. @@ -401,6 +479,7 @@ class Permutation: self.selection = selection self.transform_fn = transform_fn self.batch_size = batch_size + self._transform_spec = transform_spec # These fields are used to reconstruct the permutation in a new process. self._base_table = base_table self._permutation_table = permutation_table @@ -415,8 +494,79 @@ class Permutation: "split": self._split, "offset": self._offset, "limit": self._limit, + "transform_spec": self._transform_spec, } + def __getstate__(self) -> dict[str, Any]: + if self._transform_spec is not None: + transform_state = { + "kind": "builtin", + "format": self._transform_spec, + } + else: + transform_state = { + "kind": "callable", + "transform_fn": self.transform_fn, + } + + return { + "selection": self.selection, + "batch_size": self.batch_size, + "transform": transform_state, + "reopen": { + **self._reopen_metadata(), + # Store reopen state instead of live LanceTable handles. + "base_table": _table_to_state(self._base_table), + "permutation_table": ( + _table_to_state(self._permutation_table) + if self._permutation_table is not None + else None + ), + }, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + reopen = state["reopen"] + base_table = _table_from_state(reopen["base_table"]) + permutation_table_state = reopen["permutation_table"] + permutation_table = ( + _table_from_state(permutation_table_state) + if permutation_table_state is not None + else None + ) + split = reopen["split"] + offset = reopen["offset"] + limit = reopen["limit"] + + async def do_reopen(): + reader = await PermutationReader.from_tables( + base_table, permutation_table, split + ) + if offset is not None: + reader = await reader.with_offset(offset) + if limit is not None: + reader = await reader.with_limit(limit) + return reader + + transform = state["transform"] + if transform["kind"] == "builtin": + transform_spec = transform["format"] + transform_fn = _builtin_transform(transform_spec) + else: + transform_spec = None + transform_fn = transform["transform_fn"] + + self.reader = LOOP.run(do_reopen()) + self.selection = state["selection"] + self.batch_size = state["batch_size"] + self.transform_fn = transform_fn + self._transform_spec = transform_spec + self._base_table = reopen["base_table"] + self._permutation_table = permutation_table_state + self._split = split + self._offset = offset + self._limit = limit + def _with_selection(self, selection: dict[str, str]) -> "Permutation": """ Creates a new permutation with the given selection @@ -537,6 +687,7 @@ class Permutation: base_table=base_table, permutation_table=permutation_table, split=split, + transform_spec="python", ) return LOOP.run(do_from_tables()) @@ -777,24 +928,16 @@ class Permutation: this method. """ assert format is not None, "format is required" - if format == "python": - return self.with_transform(Transforms.arrow2python) - if format == "python_col": - return self.with_transform(Transforms.arrow2pythoncol) - elif format == "numpy": - return self.with_transform(Transforms.arrow2numpy) - elif format == "pandas": - return self.with_transform(Transforms.arrow2pandas) - elif format == "arrow": - return self.with_transform(Transforms.arrow2arrow) - elif format == "torch": - return self.with_transform(batch_to_tensor_rows) - elif format == "torch_col": - return self.with_transform(batch_to_tensor) - elif format == "polars": - return self.with_transform(Transforms.arrow2polars()) - else: - raise ValueError(f"Invalid format: {format}") + return Permutation( + self.reader, + self.selection, + self.batch_size, + _builtin_transform(format), + **{ + **self._reopen_metadata(), + "transform_spec": format, + }, + ) def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation": """ @@ -812,7 +955,10 @@ class Permutation: self.selection, self.batch_size, transform, - **self._reopen_metadata(), + **{ + **self._reopen_metadata(), + "transform_spec": None, + }, ) def __getitem__(self, index: int) -> Any: @@ -856,11 +1002,10 @@ class Permutation: self.selection, self.batch_size, self.transform_fn, - base_table=self._base_table, - permutation_table=self._permutation_table, - split=self._split, - offset=skip, - limit=self._limit, + **{ + **self._reopen_metadata(), + "offset": skip, + }, ) return LOOP.run(do_with_skip()) @@ -889,11 +1034,10 @@ class Permutation: self.selection, self.batch_size, self.transform_fn, - base_table=self._base_table, - permutation_table=self._permutation_table, - split=self._split, - offset=self._offset, - limit=limit, + **{ + **self._reopen_metadata(), + "limit": limit, + }, ) return LOOP.run(do_with_take()) diff --git a/python/python/tests/test_permutation.py b/python/python/tests/test_permutation.py index 628542fde..dc44e08e7 100644 --- a/python/python/tests/test_permutation.py +++ b/python/python/tests/test_permutation.py @@ -3,6 +3,7 @@ import pyarrow as pa import math +import pickle import pytest from lancedb import DBConnection, Table, connect @@ -562,22 +563,6 @@ def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation: return Permutation.from_tables(some_table, some_perm_table) -def assert_reopen_metadata( - permutation: Permutation, - *, - base_table: Table, - permutation_table: Table | None, - split: int, - offset: int | None = None, - limit: int | None = None, -): - assert permutation._base_table is base_table - assert permutation._permutation_table is permutation_table - assert permutation._split == split - assert permutation._offset == offset - assert permutation._limit == limit - - def test_num_rows(some_permutation: Permutation): assert some_permutation.num_rows == 950 @@ -615,91 +600,88 @@ def test_limit_offset(some_permutation: Permutation): some_permutation.with_skip(500).with_take(500).num_rows -def test_reopen_metadata_identity(mem_db: DBConnection): + +def test_permutation_pickle_rejects_in_memory_tables(mem_db: DBConnection): table = mem_db.create_table("identity_table", pa.table({"id": range(10)})) permutation = Permutation.identity(table) - assert_reopen_metadata( - permutation, - base_table=table, - permutation_table=None, - split=0, + with pytest.raises( + pickle.PicklingError, + match="in-memory databases", + ): + pickle.dumps(permutation) + + +def test_identity_permutation_pickle_roundtrip_preserves_table_version(tmp_path): + db = connect(tmp_path) + table = db.create_table( + "identity_table", + pa.table({"id": range(10), "value": range(10)}), + ) + permutation = ( + Permutation.identity(table) + .with_skip(2) + .with_take(3) + .with_format("python_col") ) + payload = pickle.dumps(permutation) + table.add(pa.table({"id": [10], "value": [10]})) -def test_reopen_metadata_split_name_resolution( - some_table: Table, some_perm_table: Table -): - permutation = Permutation.from_tables(some_table, some_perm_table, "test") + restored = pickle.loads(payload) + assert restored.num_rows == 3 + batches = list(restored.iter(10, skip_last_batch=False)) + assert batches == [{"id": [2, 3, 4], "value": [2, 3, 4]}] - assert permutation.num_rows == 50 - assert_reopen_metadata( - permutation, - base_table=some_table, - permutation_table=some_perm_table, - split=1, + +def test_permutation_pickle_roundtrip_with_persisted_permutation_table(tmp_path): + db = connect(tmp_path) + table = db.create_table( + "base_table", + pa.table({"id": range(1000), "value": range(1000)}), ) - - -def test_reopen_metadata_propagates_through_derived_permutations( - some_permutation: Permutation, some_table: Table, some_perm_table: Table -): - derived = ( - some_permutation.select_columns(["id"]) + permutation_table = ( + permutation_builder(table) + .split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"]) + .shuffle(seed=42) + .persist(db, "persisted_permutation") + .execute() + ) + permutation = ( + Permutation.from_tables(table, permutation_table, "test") + .select_columns(["id"]) .rename_column("id", "row_id") .with_batch_size(32) + .with_skip(5) + .with_take(10) .with_format("arrow") ) - assert_reopen_metadata( - derived, - base_table=some_table, - permutation_table=some_perm_table, - split=0, - ) + restored = pickle.loads(pickle.dumps(permutation)) + + assert restored.batch_size == 32 + assert restored.column_names == ["row_id"] + assert restored.num_rows == 10 + assert restored.__getitems__([0, 1, 2]).to_pylist() == permutation.__getitems__( + [0, 1, 2] + ).to_pylist() -def test_reopen_metadata_tracks_skip_and_take( - some_permutation: Permutation, some_table: Table, some_perm_table: Table -): - skipped = some_permutation.with_skip(100) - assert_reopen_metadata( - skipped, - base_table=some_table, - permutation_table=some_perm_table, - split=0, - offset=100, - ) +def test_permutation_pickle_roundtrip_preserves_builtin_polars_format(tmp_path): + pl = pytest.importorskip("polars") - limited = skipped.with_take(200) - assert_reopen_metadata( - limited, - base_table=some_table, - permutation_table=some_perm_table, - split=0, - offset=100, - limit=200, + db = connect(tmp_path) + table = db.create_table( + "polars_table", + pa.table({"id": range(5), "value": range(5)}), ) + permutation = Permutation.identity(table).with_take(2).with_format("polars") - reskipped = limited.with_skip(25) - assert_reopen_metadata( - reskipped, - base_table=some_table, - permutation_table=some_perm_table, - split=0, - offset=25, - limit=200, - ) + restored = pickle.loads(pickle.dumps(permutation)) + batch = restored.__getitems__([0, 1]) - retaken = limited.with_take(50) - assert_reopen_metadata( - retaken, - base_table=some_table, - permutation_table=some_perm_table, - split=0, - offset=100, - limit=50, - ) + assert isinstance(batch, pl.DataFrame) + assert batch.to_dict(as_series=False) == {"id": [0, 1], "value": [0, 1]} def test_remove_columns(some_permutation: Permutation):