mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-03 04:10:41 +00:00
feat(python): make Permutation fork-safe for PyTorch DataLoader workers (#3339)
## Summary
PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but threads do not survive `fork()`. LanceDB's Python bindings
drive async work through two threaded layers, both of which become inert
in a forked child:
- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3-async-runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork — and its runtime lives in
a `OnceLock` that cannot be replaced after first use.
As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.
## Approach
### Rust — new `python/src/runtime.rs`
Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs (L139)),
adapted for the async-bridge use case.
- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.
### Python — `lancedb/background_loop.py`
- Refactors `BackgroundEventLoop.__init__` to a reusable `_start()`
method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.
### Python — `lancedb/__init__.py`
Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`). Fork is supported.
## Test plan
- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`; passes in <1s, validates the runtime reset is
sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) + `test_db.py` + `test_query.py` (157, minus
one unrelated `sentence_transformers` import skip) — 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.
## Known limitation (follow-up)
This PR makes a **freshly-built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries an inherited `reqwest::Client` whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent — using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,8 @@ import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -1230,3 +1232,82 @@ def test_background_loop_cancellation(exception):
|
||||
with pytest.raises(exception):
|
||||
loop.run(None)
|
||||
mock_future.cancel.assert_called_once()
|
||||
|
||||
|
||||
def _remote_fork_child(port: int, queue) -> None:
|
||||
# Build a fresh Connection in the child so we exercise the at-fork-child
|
||||
# tokio runtime reset rather than relying on an inherited reqwest client.
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
queue.put(db.table_names())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_remote_connection_after_fork():
|
||||
"""A freshly-built remote Connection in a forked child should not hang.
|
||||
|
||||
The pyo3-async-runtimes tokio runtime would otherwise be inherited from
|
||||
the parent with dead worker threads; the at-fork-child handler in our
|
||||
runtime module rebuilds it on first use in the child.
|
||||
"""
|
||||
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
|
||||
port = server.server_address[1]
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.start()
|
||||
try:
|
||||
# Hit the server in the parent first so the runtime + LOOP are warm
|
||||
# before fork; a fresh child must still succeed.
|
||||
parent_db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
assert parent_db.table_names() == []
|
||||
|
||||
ctx = mp.get_context("fork")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_remote_fork_child, args=(port, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=15)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Remote connection hung after fork")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no result"
|
||||
assert queue.get() == []
|
||||
|
||||
# Parent connection must still be usable after the child returned.
|
||||
assert parent_db.table_names() == []
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
Reference in New Issue
Block a user