mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-04 04:40:40 +00:00
fix(python): make Permutation picklable for PyTorch multiprocessing (#3335)
## Summary When pytorch is used with multiprocessing and the mp mode is spawn then the Permutation needs to be pickled. It could not be pickled because `Table` and `Connection` are not serializable. This PR adds pickle support to Permutation without adding general pickle support to `Table` or `Connection`. To add general support we probably need to start by adding serialization in the namespace client. In the meantime this PR enable pickling by adding special cases for: * In-memory tables (just serialize as Arrow IPC) * Native tables (serialize the URI) If a user is not using one of the above cases (e.g. using a remote connection) then they will need to provide a connection factory that can be pickled. ## Breaking change `PermutationBuilder.persist(...)` is removed from the Python bindings; the permutation table is now always in-memory. The underlying Rust `PermutationBuilder::persist` API is untouched and can be re-exposed later if needed. It probably won't make sense to do that until we have a way to serialize `Table` and `Connection`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,14 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import functools
|
||||
import pickle
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.util import tbl_to_tensor
|
||||
from lancedb.permutation import Permutation
|
||||
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def _open_native_table(uri: str, table_name: str):
|
||||
"""Top-level connection factory used by the explicit-factory pickle test.
|
||||
|
||||
Defined at module scope so that pickle can resolve it by name in the
|
||||
worker / unpickling process.
|
||||
"""
|
||||
return lancedb.connect(uri).open_table(table_name)
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -40,3 +53,96 @@ def test_permutation_dataloader(mem_db):
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
|
||||
|
||||
def test_permutation_is_picklable(tmp_db):
|
||||
"""A Permutation must be picklable so it can be used with PyTorch's
|
||||
DataLoader when num_workers > 0 (which uses multiprocessing and pickles
|
||||
the dataset to pass it to worker processes)."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
pickled = pickle.dumps(permutation)
|
||||
restored = pickle.loads(pickled)
|
||||
|
||||
assert len(restored) == 1000
|
||||
rows = restored.__getitems__([0, 1, 2])
|
||||
assert rows == [{"a": 0}, {"a": 1}, {"a": 2}]
|
||||
|
||||
|
||||
def test_permutation_with_memory_base_is_picklable(mem_db):
|
||||
"""An in-memory base table is inlined into the pickle as Arrow IPC bytes
|
||||
and rebuilt on the other side as an in-memory LanceTable, so the
|
||||
Permutation round-trips even though the original database can't be
|
||||
reopened across processes."""
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(50)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == 50
|
||||
assert restored.__getitems__([0, 10, 49]) == [{"a": 0}, {"a": 10}, {"a": 49}]
|
||||
|
||||
|
||||
def test_permutation_dataloader_multiprocessing(tmp_db):
|
||||
"""Using a Permutation with a PyTorch DataLoader that has num_workers > 0
|
||||
must work end-to-end. Each worker process gets a pickled copy of the
|
||||
dataset and reads batches from it."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
shuffle=True,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
seen += batch["a"].size(0)
|
||||
assert seen == 1000
|
||||
|
||||
|
||||
def test_permutation_pickle_with_connection_factory(tmp_path):
|
||||
"""When the user provides a connection_factory, pickling should round-trip
|
||||
through that factory rather than introspecting the connection URI. Useful
|
||||
for remote / cloud connections where the URI alone isn't reopenable."""
|
||||
db = lancedb.connect(tmp_path)
|
||||
db.create_table("test_table", pa.table({"a": range(50)}))
|
||||
|
||||
factory = functools.partial(_open_native_table, str(tmp_path))
|
||||
permutation = Permutation.identity(factory("test_table")).with_connection_factory(
|
||||
factory
|
||||
)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == 50
|
||||
# The factory survives pickling and is what powered base-table reopen.
|
||||
assert restored.connection_factory is not None
|
||||
assert restored.connection_factory.func is _open_native_table
|
||||
assert restored.__getitems__([0, 1, 2]) == [{"a": 0}, {"a": 1}, {"a": 2}]
|
||||
|
||||
|
||||
def test_permutation_with_builder_is_picklable(tmp_db):
|
||||
"""A Permutation built from a non-identity permutation table must round-trip
|
||||
through pickle while preserving the row order defined by the permutation."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(100)}))
|
||||
perm_tbl = (
|
||||
permutation_builder(table)
|
||||
.split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"])
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
permutations = Permutations(table, perm_tbl)
|
||||
permutation = permutations["train"]
|
||||
|
||||
indices = list(range(len(permutation)))
|
||||
expected = permutation.__getitems__(indices)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == len(permutation)
|
||||
assert restored.__getitems__(indices) == expected
|
||||
|
||||
Reference in New Issue
Block a user