feat(python): support pickling permutations

This commit is contained in:
Xuanwo
2026-04-09 00:32:48 +08:00
parent 2d380d1669
commit 768d84845c
2 changed files with 243 additions and 117 deletions

View File

@@ -1,21 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
import pickle
from datetime import timedelta
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
import pyarrow as pa
from deprecation import deprecated
from lancedb import AsyncConnection, DBConnection
import pyarrow as pa
import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .table import LanceTable
from .util import batch_to_tensor, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
def _builtin_transform(format: str) -> Callable[[pa.RecordBatch], Any]:
if format == "python":
return Transforms.arrow2python
if format == "python_col":
return Transforms.arrow2pythoncol
if format == "numpy":
return Transforms.arrow2numpy
if format == "pandas":
return Transforms.arrow2pandas
if format == "arrow":
return Transforms.arrow2arrow
if format == "torch":
return batch_to_tensor_rows
if format == "torch_col":
return batch_to_tensor
if format == "polars":
return Transforms.arrow2polars()
raise ValueError(f"Invalid format: {format}")
def _table_to_state(
table: Union[LanceTable, dict[str, Any]],
) -> dict[str, Any]:
if isinstance(table, dict):
return table
if not isinstance(table, LanceTable):
raise pickle.PicklingError(
"Permutation pickling only supports LanceTable-backed permutations"
)
if table._namespace_client is not None:
raise pickle.PicklingError(
"Permutation pickling does not yet support namespace-backed tables"
)
if table._conn.uri.startswith("memory://"):
raise pickle.PicklingError(
"Permutation pickling does not support in-memory databases"
)
try:
read_consistency_interval = table._conn.read_consistency_interval
except Exception:
read_consistency_interval = None
return {
"uri": table._conn.uri,
"name": table.name,
"version": table.version,
"storage_options": table.initial_storage_options(),
"read_consistency_interval_secs": (
read_consistency_interval.total_seconds()
if read_consistency_interval is not None
else None
),
"namespace_path": list(table.namespace),
}
def _table_from_state(state: dict[str, Any]) -> LanceTable:
    """
    Reopen a table from a state dictionary.

    Parameters
    ----------
    state: dict[str, Any]
        Must contain the keys "uri", "name", "version", "storage_options",
        "read_consistency_interval_secs" and "namespace_path".

    Returns
    -------
    LanceTable
        The table opened through a new connection and checked out at
        `state["version"]`.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import with
    # the package root - confirm before moving it to module level.
    from . import connect

    # The interval is stored as seconds (or None) and is turned back into a
    # timedelta before it is handed to connect().
    interval_secs = state["read_consistency_interval_secs"]
    if interval_secs is None:
        read_consistency_interval = None
    else:
        read_consistency_interval = timedelta(seconds=interval_secs)

    db = connect(
        state["uri"],
        read_consistency_interval=read_consistency_interval,
        storage_options=state["storage_options"],
    )
    reopened = db.open_table(
        state["name"],
        namespace_path=state["namespace_path"],
    )
    # Pin the table to the recorded version.
    reopened.checkout(state["version"])
    return reopened
class PermutationBuilder:
"""
A utility for creating a "permutation table" which is a table that defines an
@@ -386,11 +463,12 @@ class Permutation:
batch_size: int,
transform_fn: Callable[pa.RecordBatch, Any],
*,
base_table: LanceTable,
permutation_table: Optional[LanceTable],
base_table: Union[LanceTable, dict[str, Any]],
permutation_table: Optional[Union[LanceTable, dict[str, Any]]],
split: int,
offset: Optional[int] = None,
limit: Optional[int] = None,
transform_spec: Optional[str] = None,
):
"""
Internal constructor. Use [from_tables](#from_tables) instead.
@@ -401,6 +479,7 @@ class Permutation:
self.selection = selection
self.transform_fn = transform_fn
self.batch_size = batch_size
self._transform_spec = transform_spec
# These fields are used to reconstruct the permutation in a new process.
self._base_table = base_table
self._permutation_table = permutation_table
@@ -415,8 +494,79 @@ class Permutation:
"split": self._split,
"offset": self._offset,
"limit": self._limit,
"transform_spec": self._transform_spec,
}
def __getstate__(self) -> dict[str, Any]:
    """
    Build the state dictionary that is pickled for this permutation.

    Returns
    -------
    dict[str, Any]
        A dictionary with the keys "selection", "batch_size", "transform" and
        "reopen". "transform" records either the built-in format name or the
        transform callable itself. "reopen" holds the reopen metadata with the
        tables replaced by their state dictionaries.
    """
    spec = self._transform_spec
    if spec is None:
        # No built-in format is recorded, so the callable itself is stored.
        transform_state = {
            "kind": "callable",
            "transform_fn": self.transform_fn,
        }
    else:
        # A built-in transform is stored by its format name only.
        transform_state = {
            "kind": "builtin",
            "format": spec,
        }

    # Store reopen state instead of live LanceTable handles.
    reopen_state = dict(self._reopen_metadata())
    reopen_state["base_table"] = _table_to_state(self._base_table)
    if self._permutation_table is None:
        reopen_state["permutation_table"] = None
    else:
        reopen_state["permutation_table"] = _table_to_state(
            self._permutation_table
        )

    return {
        "selection": self.selection,
        "batch_size": self.batch_size,
        "transform": transform_state,
        "reopen": reopen_state,
    }
def __setstate__(self, state: dict[str, Any]) -> None:
    """
    Restore this permutation from a state dictionary with the layout that
    `__getstate__` returns.

    The tables are reopened from their state dictionaries, a new reader is
    created for them, and the recorded offset and limit are applied to it.

    Parameters
    ----------
    state: dict[str, Any]
        Must contain the keys "selection", "batch_size", "transform" and
        "reopen".
    """
    reopen = state["reopen"]

    # Reopen the tables first; the permutation table may be absent.
    base_table_state = reopen["base_table"]
    base_table = _table_from_state(base_table_state)
    permutation_table_state = reopen["permutation_table"]
    if permutation_table_state is None:
        permutation_table = None
    else:
        permutation_table = _table_from_state(permutation_table_state)

    split = reopen["split"]
    offset = reopen["offset"]
    limit = reopen["limit"]

    # A built-in transform is looked up again by its format name; any other
    # transform was stored as the callable itself.
    transform = state["transform"]
    if transform["kind"] != "builtin":
        transform_spec = None
        transform_fn = transform["transform_fn"]
    else:
        transform_spec = transform["format"]
        transform_fn = _builtin_transform(transform_spec)

    async def reopen_reader():
        new_reader = await PermutationReader.from_tables(
            base_table, permutation_table, split
        )
        if offset is not None:
            new_reader = await new_reader.with_offset(offset)
        if limit is not None:
            new_reader = await new_reader.with_limit(limit)
        return new_reader

    self.reader = LOOP.run(reopen_reader())
    self.selection = state["selection"]
    self.batch_size = state["batch_size"]
    self.transform_fn = transform_fn
    self._transform_spec = transform_spec
    # The table fields are kept as state dictionaries, not live table handles.
    self._base_table = base_table_state
    self._permutation_table = permutation_table_state
    self._split = split
    self._offset = offset
    self._limit = limit
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
"""
Creates a new permutation with the given selection
@@ -537,6 +687,7 @@ class Permutation:
base_table=base_table,
permutation_table=permutation_table,
split=split,
transform_spec="python",
)
return LOOP.run(do_from_tables())
@@ -777,24 +928,16 @@ class Permutation:
this method.
"""
assert format is not None, "format is required"
if format == "python":
return self.with_transform(Transforms.arrow2python)
if format == "python_col":
return self.with_transform(Transforms.arrow2pythoncol)
elif format == "numpy":
return self.with_transform(Transforms.arrow2numpy)
elif format == "pandas":
return self.with_transform(Transforms.arrow2pandas)
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)
elif format == "polars":
return self.with_transform(Transforms.arrow2polars())
else:
raise ValueError(f"Invalid format: {format}")
return Permutation(
self.reader,
self.selection,
self.batch_size,
_builtin_transform(format),
**{
**self._reopen_metadata(),
"transform_spec": format,
},
)
def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation":
"""
@@ -812,7 +955,10 @@ class Permutation:
self.selection,
self.batch_size,
transform,
**self._reopen_metadata(),
**{
**self._reopen_metadata(),
"transform_spec": None,
},
)
def __getitem__(self, index: int) -> Any:
@@ -856,11 +1002,10 @@ class Permutation:
self.selection,
self.batch_size,
self.transform_fn,
base_table=self._base_table,
permutation_table=self._permutation_table,
split=self._split,
offset=skip,
limit=self._limit,
**{
**self._reopen_metadata(),
"offset": skip,
},
)
return LOOP.run(do_with_skip())
@@ -889,11 +1034,10 @@ class Permutation:
self.selection,
self.batch_size,
self.transform_fn,
base_table=self._base_table,
permutation_table=self._permutation_table,
split=self._split,
offset=self._offset,
limit=limit,
**{
**self._reopen_metadata(),
"limit": limit,
},
)
return LOOP.run(do_with_take())

View File

@@ -3,6 +3,7 @@
import pyarrow as pa
import math
import pickle
import pytest
from lancedb import DBConnection, Table, connect
@@ -562,22 +563,6 @@ def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
return Permutation.from_tables(some_table, some_perm_table)
def assert_reopen_metadata(
permutation: Permutation,
*,
base_table: Table,
permutation_table: Table | None,
split: int,
offset: int | None = None,
limit: int | None = None,
):
assert permutation._base_table is base_table
assert permutation._permutation_table is permutation_table
assert permutation._split == split
assert permutation._offset == offset
assert permutation._limit == limit
def test_num_rows(some_permutation: Permutation):
assert some_permutation.num_rows == 950
@@ -615,91 +600,88 @@ def test_limit_offset(some_permutation: Permutation):
some_permutation.with_skip(500).with_take(500).num_rows
def test_reopen_metadata_identity(mem_db: DBConnection):
def test_permutation_pickle_rejects_in_memory_tables(mem_db: DBConnection):
table = mem_db.create_table("identity_table", pa.table({"id": range(10)}))
permutation = Permutation.identity(table)
assert_reopen_metadata(
permutation,
base_table=table,
permutation_table=None,
split=0,
with pytest.raises(
pickle.PicklingError,
match="in-memory databases",
):
pickle.dumps(permutation)
def test_identity_permutation_pickle_roundtrip_preserves_table_version(tmp_path):
db = connect(tmp_path)
table = db.create_table(
"identity_table",
pa.table({"id": range(10), "value": range(10)}),
)
permutation = (
Permutation.identity(table)
.with_skip(2)
.with_take(3)
.with_format("python_col")
)
payload = pickle.dumps(permutation)
table.add(pa.table({"id": [10], "value": [10]}))
def test_reopen_metadata_split_name_resolution(
some_table: Table, some_perm_table: Table
):
permutation = Permutation.from_tables(some_table, some_perm_table, "test")
restored = pickle.loads(payload)
assert restored.num_rows == 3
batches = list(restored.iter(10, skip_last_batch=False))
assert batches == [{"id": [2, 3, 4], "value": [2, 3, 4]}]
assert permutation.num_rows == 50
assert_reopen_metadata(
permutation,
base_table=some_table,
permutation_table=some_perm_table,
split=1,
def test_permutation_pickle_roundtrip_with_persisted_permutation_table(tmp_path):
db = connect(tmp_path)
table = db.create_table(
"base_table",
pa.table({"id": range(1000), "value": range(1000)}),
)
def test_reopen_metadata_propagates_through_derived_permutations(
some_permutation: Permutation, some_table: Table, some_perm_table: Table
):
derived = (
some_permutation.select_columns(["id"])
permutation_table = (
permutation_builder(table)
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
.shuffle(seed=42)
.persist(db, "persisted_permutation")
.execute()
)
permutation = (
Permutation.from_tables(table, permutation_table, "test")
.select_columns(["id"])
.rename_column("id", "row_id")
.with_batch_size(32)
.with_skip(5)
.with_take(10)
.with_format("arrow")
)
assert_reopen_metadata(
derived,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
)
restored = pickle.loads(pickle.dumps(permutation))
assert restored.batch_size == 32
assert restored.column_names == ["row_id"]
assert restored.num_rows == 10
assert restored.__getitems__([0, 1, 2]).to_pylist() == permutation.__getitems__(
[0, 1, 2]
).to_pylist()
def test_reopen_metadata_tracks_skip_and_take(
some_permutation: Permutation, some_table: Table, some_perm_table: Table
):
skipped = some_permutation.with_skip(100)
assert_reopen_metadata(
skipped,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
)
def test_permutation_pickle_roundtrip_preserves_builtin_polars_format(tmp_path):
pl = pytest.importorskip("polars")
limited = skipped.with_take(200)
assert_reopen_metadata(
limited,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
limit=200,
db = connect(tmp_path)
table = db.create_table(
"polars_table",
pa.table({"id": range(5), "value": range(5)}),
)
permutation = Permutation.identity(table).with_take(2).with_format("polars")
reskipped = limited.with_skip(25)
assert_reopen_metadata(
reskipped,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=25,
limit=200,
)
restored = pickle.loads(pickle.dumps(permutation))
batch = restored.__getitems__([0, 1])
retaken = limited.with_take(50)
assert_reopen_metadata(
retaken,
base_table=some_table,
permutation_table=some_perm_table,
split=0,
offset=100,
limit=50,
)
assert isinstance(batch, pl.DataFrame)
assert batch.to_dict(as_series=False) == {"id": [0, 1], "value": [0, 1]}
def test_remove_columns(some_permutation: Permutation):