lancedb/python/python/tests/test_torch.py
Weston Pace a17c241e86 feat(python): make Permutation fork-safe for PyTorch DataLoader workers (#3339)
## Summary

PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but only the thread that calls `fork()` survives in the child; every
other thread is gone. LanceDB's Python bindings drive async work through
two threaded layers, both of which become inert in a forked child:

- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3_async_runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork, and that runtime lives in
a `OnceLock` that cannot be replaced after first use.

As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.
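The failure mode can be reproduced without LanceDB. The sketch below is POSIX-only and uses a plain `threading.Thread` running an asyncio loop as a stand-in for `BackgroundEventLoop`; the helper names are hypothetical, not LanceDB APIs. A coroutine submitted to the inherited loop from a forked child never runs, because no thread is left in the child to drive that loop.

```python
import asyncio
import os
import threading
import warnings
from concurrent.futures import TimeoutError as FutureTimeoutError


def start_loop_thread() -> asyncio.AbstractEventLoop:
    """Run a new asyncio loop forever on a daemon thread (stand-in for BackgroundEventLoop)."""
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


async def answer() -> int:
    return 42


def submit_from_forked_child(loop: asyncio.AbstractEventLoop, timeout: float = 2.0) -> str:
    """Fork, submit a coroutine to the inherited loop from the child, report the outcome."""
    read_fd, write_fd = os.pipe()
    with warnings.catch_warnings():
        # Python 3.12+ warns when a multi-threaded process calls fork().
        warnings.simplefilter("ignore", DeprecationWarning)
        pid = os.fork()
    if pid == 0:  # child: only the forking thread exists here
        try:
            os.close(read_fd)
            future = asyncio.run_coroutine_threadsafe(answer(), loop)
            try:
                outcome = f"ok:{future.result(timeout=timeout)}"
            except FutureTimeoutError:
                outcome = "timeout"  # nothing is left to run the loop
            os.write(write_fd, outcome.encode())
        finally:
            os._exit(0)
    os.close(write_fd)
    with os.fdopen(read_fd, "rb") as reader:
        outcome = reader.read().decode()
    os.waitpid(pid, 0)
    return outcome
```

In the parent, the same submission completes immediately with `42`; in the child it reports `"timeout"`. That is the hang described above, bounded here by a timeout instead of waiting forever.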

## Approach

### Rust — new `python/src/runtime.rs`

Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs (L139)),
adapted for the async-bridge use case.

- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.
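The Rust pattern (clear the handle in the child, rebuild lazily on the next use, never tear down the inherited runtime) can be illustrated with a Python analogue. This is not the PR's code: a `ThreadPoolExecutor` stands in for the tokio runtime, `os.register_at_fork(after_in_child=...)` stands in for `pthread_atfork`, and the `ForkSafeExecutor` class is hypothetical.

```python
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class ForkSafeExecutor:
    """Hypothetical holder for a worker pool that is rebuilt lazily after fork()."""

    def __init__(self, max_workers: int = 2) -> None:
        self._max_workers = max_workers
        self._pool: Optional[ThreadPoolExecutor] = None
        self._lock = threading.Lock()
        if hasattr(os, "register_at_fork"):
            os.register_at_fork(after_in_child=self._reset_in_child)

    def _reset_in_child(self) -> None:
        # Forget the inherited pool without calling shutdown(): its worker
        # threads do not exist in the child. This mirrors nulling the
        # AtomicPtr and leaking the old tokio runtime instead of dropping it.
        self._pool = None
        # Another parent thread may have held the lock at the moment of fork.
        self._lock = threading.Lock()

    def get(self) -> ThreadPoolExecutor:
        with self._lock:
            if self._pool is None:  # first use, or first use after a fork
                self._pool = ThreadPoolExecutor(max_workers=self._max_workers)
            return self._pool

    def run(self, fn: Callable[[], T]) -> T:
        return self.get().submit(fn).result()
```

As in the Rust design, the at-fork hook only clears state and the rebuild happens on the next `get()`, so a child that never submits work never pays for a new pool.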

### Python — `lancedb/background_loop.py`

- Refactors `BackgroundEventLoop.__init__` to a reusable `_start()`
method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.
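A minimal sketch of the in-place restart, assuming a `BackgroundEventLoop` shaped roughly like LanceDB's (an asyncio loop on a daemon thread plus a `run()` helper). The attribute names `loop` and `thread` and the method bodies are illustrative assumptions, not copied from `lancedb/background_loop.py`. The at-fork hook calls `_start()` on the existing singleton, so any module holding a reference to `LOOP` sees the fresh loop.

```python
import asyncio
import os
import threading
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")


class BackgroundEventLoop:
    """Runs an asyncio loop on a daemon thread (illustrative sketch)."""

    def __init__(self) -> None:
        self._start()

    def _start(self) -> None:
        # Build a brand-new loop and thread. Called at construction and again
        # in a forked child. The inherited loop still reports is_running() in
        # the child (close() would raise), so it is simply abandoned.
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread.start()

    def run(self, coro: Coroutine[Any, Any, T]) -> T:
        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()


LOOP = BackgroundEventLoop()

if hasattr(os, "register_at_fork"):
    # Mutate the existing object instead of rebinding the module-level name,
    # so references obtained via `from background_loop import LOOP` stay valid.
    os.register_at_fork(after_in_child=LOOP._start)
```

Rebinding (`background_loop.LOOP = BackgroundEventLoop()`) would update only the module attribute; every earlier `from ... import LOOP` would still point at the object whose thread died in the fork.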

### Python — `lancedb/__init__.py`

Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`), since fork is now supported.

## Test plan

- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`. It passes in <1s, confirming that the runtime reset is
sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) + `test_db.py` + `test_query.py` (157, minus
one unrelated `sentence_transformers` import skip) — 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.

## Known limitation (follow-up)

This PR makes a **freshly-built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries an inherited `reqwest::Client` whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent — using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.
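One way to follow the recommended pattern is to open the connection lazily and reopen it whenever the current process ID differs from the one that opened it. The `PerProcessConnection` helper below is hypothetical. In practice `factory` would be something like `lambda: lancedb.connect(uri)` (or a module-level function, if it must be picklable), and in the usage below a plain `object()` stands in for a connection so the sketch runs without LanceDB.

```python
import os
from typing import Callable, Generic, Optional, TypeVar

C = TypeVar("C")


class PerProcessConnection(Generic[C]):
    """Hypothetical helper: caches one connection per process ID."""

    def __init__(self, factory: Callable[[], C]) -> None:
        self._factory = factory
        self._conn: Optional[C] = None
        self._pid: Optional[int] = None

    def get(self) -> C:
        pid = os.getpid()
        if self._conn is None or self._pid != pid:
            # First use, or we are in a forked child that inherited the
            # parent's connection: never reuse it, build a fresh one here.
            self._conn = self._factory()
            self._pid = pid
        return self._conn

    def __getstate__(self) -> dict:
        # When pickled for spawn-based workers, ship only the factory.
        return {"_factory": self._factory, "_conn": None, "_pid": None}
```

Keying on `os.getpid()` rather than an at-fork hook keeps the helper independent of how the worker was started: a forked child sees a new PID and reconnects, and a spawned child receives only the factory.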

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 13:44:10 -07:00

211 lines
7.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import functools
import multiprocessing as mp
import pickle
import sys

import lancedb
import pyarrow as pa
import pytest
from lancedb.permutation import Permutation, Permutations, permutation_builder
from lancedb.util import tbl_to_tensor

torch = pytest.importorskip("torch")


def _open_native_table(uri: str, table_name: str):
    """Top-level connection factory used by the explicit-factory pickle test.

    Defined at module scope so that pickle can resolve it by name in the
    worker / unpickling process.
    """
    return lancedb.connect(uri).open_table(table_name)


def test_table_dataloader(mem_db):
    table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
    dataloader = torch.utils.data.DataLoader(
        table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
    )
    for batch in dataloader:
        assert batch.size(0) == 1
        assert batch.size(1) == 10


def test_permutation_dataloader(mem_db):
    table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

    permutation = Permutation.identity(table)
    dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
    for batch in dataloader:
        assert batch["a"].size(0) == 10

    permutation = permutation.with_format("torch")
    dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
    for batch in dataloader:
        assert batch.size(0) == 10
        assert batch.size(1) == 1

    permutation = permutation.with_format("torch_col")
    dataloader = torch.utils.data.DataLoader(
        permutation, collate_fn=lambda x: x, batch_size=10, shuffle=True
    )
    for batch in dataloader:
        assert batch.size(0) == 1
        assert batch.size(1) == 10


def test_permutation_is_picklable(tmp_db):
    """A Permutation must be picklable so it can be used with PyTorch's
    DataLoader when num_workers > 0 (which uses multiprocessing and pickles
    the dataset to pass it to worker processes)."""
    table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
    permutation = Permutation.identity(table)

    pickled = pickle.dumps(permutation)
    restored = pickle.loads(pickled)

    assert len(restored) == 1000
    rows = restored.__getitems__([0, 1, 2])
    assert rows == [{"a": 0}, {"a": 1}, {"a": 2}]


def test_permutation_with_memory_base_is_picklable(mem_db):
    """An in-memory base table is inlined into the pickle as Arrow IPC bytes
    and rebuilt on the other side as an in-memory LanceTable, so the
    Permutation round-trips even though the original database can't be
    reopened across processes."""
    table = mem_db.create_table("test_table", pa.table({"a": range(50)}))
    permutation = Permutation.identity(table)

    restored = pickle.loads(pickle.dumps(permutation))
    assert len(restored) == 50
    assert restored.__getitems__([0, 10, 49]) == [{"a": 0}, {"a": 10}, {"a": 49}]


def test_permutation_dataloader_multiprocessing(tmp_db):
    """Using a Permutation with a PyTorch DataLoader that has num_workers > 0
    must work end-to-end. Each worker process gets a pickled copy of the
    dataset and reads batches from it."""
    table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
    permutation = Permutation.identity(table)

    dataloader = torch.utils.data.DataLoader(
        permutation,
        batch_size=10,
        shuffle=True,
        num_workers=2,
        multiprocessing_context="spawn",
    )
    seen = 0
    for batch in dataloader:
        assert batch["a"].size(0) == 10
        seen += batch["a"].size(0)
    assert seen == 1000


def test_permutation_pickle_with_connection_factory(tmp_path):
    """When the user provides a connection_factory, pickling should round-trip
    through that factory rather than introspecting the connection URI. Useful
    for remote / cloud connections where the URI alone isn't reopenable."""
    db = lancedb.connect(tmp_path)
    db.create_table("test_table", pa.table({"a": range(50)}))

    factory = functools.partial(_open_native_table, str(tmp_path))
    permutation = Permutation.identity(factory("test_table")).with_connection_factory(
        factory
    )

    restored = pickle.loads(pickle.dumps(permutation))
    assert len(restored) == 50
    # The factory survives pickling and is what powered base-table reopen.
    assert restored.connection_factory is not None
    assert restored.connection_factory.func is _open_native_table
    assert restored.__getitems__([0, 1, 2]) == [{"a": 0}, {"a": 1}, {"a": 2}]


def test_permutation_with_builder_is_picklable(tmp_db):
    """A Permutation built from a non-identity permutation table must round-trip
    through pickle while preserving the row order defined by the permutation."""
    table = tmp_db.create_table("test_table", pa.table({"a": range(100)}))
    perm_tbl = (
        permutation_builder(table)
        .split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"])
        .shuffle(seed=42)
        .execute()
    )
    permutations = Permutations(table, perm_tbl)
    permutation = permutations["train"]

    indices = list(range(len(permutation)))
    expected = permutation.__getitems__(indices)

    restored = pickle.loads(pickle.dumps(permutation))
    assert len(restored) == len(permutation)
    assert restored.__getitems__(indices) == expected


def _multiworker_dataloader_target(db_uri: str, result_queue):
    import lancedb
    from lancedb.permutation import Permutation

    db = lancedb.connect(db_uri)
    table = db.open_table("test_table")
    permutation = Permutation.identity(table)

    dataloader = torch.utils.data.DataLoader(
        permutation,
        batch_size=10,
        num_workers=2,
        multiprocessing_context="fork",
    )
    count = 0
    for batch in dataloader:
        assert batch["a"].size(0) == 10
        count += 1
    result_queue.put(count)


@pytest.mark.skipif(
    sys.platform != "linux",
    reason=(
        "fork() is unavailable on Windows and unsafe on macOS "
        "(Apple frameworks/TLS are not fork-safe)"
    ),
)
def test_permutation_dataloader_fork_workers(tmp_path):
    """A Permutation used by a fork-based DataLoader should not hang.

    PyTorch's DataLoader uses fork-based multiprocessing by default on Linux.
    LanceDB drives async work through a background asyncio thread (and a tokio
    runtime) whose threads do not survive a fork; without the at-fork handlers
    that restart them, any LOOP.run() in a worker blocks forever.
    """
    import lancedb

    db_uri = str(tmp_path / "db")
    db = lancedb.connect(db_uri)
    db.create_table("test_table", pa.table({"a": list(range(1000))}))

    # Run the DataLoader in a spawn-started child so a hang can be detected
    # and killed without wedging the pytest process itself.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=_multiworker_dataloader_target, args=(db_uri, queue))
    proc.start()
    proc.join(timeout=30)

    if proc.is_alive():
        proc.terminate()
        proc.join(timeout=5)
        if proc.is_alive():
            proc.kill()
            proc.join()
        pytest.fail("Permutation hung when iterated in a fork-based DataLoader worker")

    assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
    assert not queue.empty(), "child produced no batches"
    assert queue.get() == 100