feat(python): make Permutation fork-safe for PyTorch DataLoader workers (#3339)

## Summary

PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but threads do not survive `fork()`. LanceDB's Python bindings
drive async work through two threaded layers, both of which become inert
in a forked child:

- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3-async-runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork — and its runtime lives in
a `OnceLock` that cannot be replaced after first use.

As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.

## Approach

### Rust — new `python/src/runtime.rs`

Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs (L139)),
adapted for the async-bridge use case.

- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.

### Python — `lancedb/background_loop.py`

- Refactors `BackgroundEventLoop.__init__` to a reusable `_start()`
method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.

### Python — `lancedb/__init__.py`

Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`). Fork is supported.

## Test plan

- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`; passes in <1s, validates the runtime reset is
sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) + `test_db.py` + `test_query.py` (157, minus
one unrelated `sentence_transformers` import skip) — 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.

## Known limitation (follow-up)

This PR makes a **freshly-built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries an inherited `reqwest::Client` whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent — using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Weston Pace
2026-05-05 13:44:10 -07:00
committed by GitHub
parent 1fc23e5473
commit a17c241e86
13 changed files with 339 additions and 32 deletions

142
python/src/runtime.rs Normal file
View File

@@ -0,0 +1,142 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Fork-safe wrapper around tokio + pyo3-async-runtimes.
//!
//! `pyo3_async_runtimes::tokio` keeps its multi-threaded runtime in a
//! `OnceLock` that can never be replaced. Tokio's worker threads do not
//! survive `fork()`, so once a child inherits a "frozen" runtime, every
//! `future_into_py` call hangs forever.
//!
//! We sidestep the global by routing every future through our own
//! [`LanceRuntime`] (a [`pyo3_async_runtimes::generic::Runtime`] impl) backed
//! by an [`AtomicPtr`] to a tokio runtime that we own. A `pthread_atfork`
//! child handler nulls the pointer; the next `spawn` rebuilds the runtime in
//! the child. This mirrors the pattern used in the Lance Python bindings.
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use pyo3::{Bound, PyAny, PyResult, Python, conversion::IntoPyObject};
use pyo3_async_runtimes::{
TaskLocals,
generic::{ContextExt, JoinError, Runtime},
};
use tokio::{runtime, task};
static RUNTIME: AtomicPtr<runtime::Runtime> = AtomicPtr::new(std::ptr::null_mut());
static RUNTIME_INSTALLING: AtomicBool = AtomicBool::new(false);
static ATFORK_INSTALLED: AtomicBool = AtomicBool::new(false);
fn create_runtime() -> runtime::Runtime {
runtime::Builder::new_multi_thread()
.enable_all()
.thread_name("lancedb-tokio-worker")
.build()
.expect("Failed to build tokio runtime")
}
fn get_runtime() -> &'static runtime::Runtime {
loop {
let ptr = RUNTIME.load(Ordering::SeqCst);
if !ptr.is_null() {
return unsafe { &*ptr };
}
if !RUNTIME_INSTALLING.fetch_or(true, Ordering::SeqCst) {
break;
}
std::thread::yield_now();
}
if !ATFORK_INSTALLED.fetch_or(true, Ordering::SeqCst) {
install_atfork();
}
let new_ptr = Box::into_raw(Box::new(create_runtime()));
RUNTIME.store(new_ptr, Ordering::SeqCst);
unsafe { &*new_ptr }
}
/// Runs in async-signal context after `fork()` in the child. We can only
/// touch atomics here; we deliberately leak the previous runtime because
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
/// threads and hang.
extern "C" fn atfork_child() {
RUNTIME.store(std::ptr::null_mut(), Ordering::SeqCst);
RUNTIME_INSTALLING.store(false, Ordering::SeqCst);
}
#[cfg(not(windows))]
fn install_atfork() {
unsafe { libc::pthread_atfork(None, None, Some(atfork_child)) };
}
#[cfg(windows)]
fn install_atfork() {}
/// Marker type implementing [`Runtime`] over our fork-safe runtime slot.
pub struct LanceRuntime;
/// Newtype wrapper around `tokio::task::JoinError` so we can implement the
/// foreign [`JoinError`] trait without violating orphan rules.
pub struct LanceJoinError(task::JoinError);
impl JoinError for LanceJoinError {
fn is_panic(&self) -> bool {
self.0.is_panic()
}
fn into_panic(self) -> Box<dyn std::any::Any + Send + 'static> {
self.0.into_panic()
}
}
impl Runtime for LanceRuntime {
type JoinError = LanceJoinError;
type JoinHandle = Pin<Box<dyn Future<Output = Result<(), Self::JoinError>> + Send>>;
fn spawn<F>(fut: F) -> Self::JoinHandle
where
F: Future<Output = ()> + Send + 'static,
{
let handle = get_runtime().spawn(fut);
Box::pin(async move { handle.await.map_err(LanceJoinError) })
}
fn spawn_blocking<F>(f: F) -> Self::JoinHandle
where
F: FnOnce() + Send + 'static,
{
let handle = get_runtime().spawn_blocking(f);
Box::pin(async move { handle.await.map_err(LanceJoinError) })
}
}
tokio::task_local! {
static TASK_LOCALS: std::cell::OnceCell<TaskLocals>;
}
impl ContextExt for LanceRuntime {
fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
where
F: Future<Output = R> + Send + 'static,
{
let cell = std::cell::OnceCell::new();
cell.set(locals).unwrap();
Box::pin(TASK_LOCALS.scope(cell, fut))
}
fn get_task_locals() -> Option<TaskLocals> {
TASK_LOCALS
.try_with(|c| c.get().cloned())
.unwrap_or_default()
}
}
/// Drop-in replacement for `pyo3_async_runtimes::tokio::future_into_py` that
/// uses our fork-safe runtime.
pub fn future_into_py<F, T>(py: Python<'_>, fut: F) -> PyResult<Bound<'_, PyAny>>
where
F: Future<Output = PyResult<T>> + Send + 'static,
T: for<'py> IntoPyObject<'py> + Send + 'static,
{
pyo3_async_runtimes::generic::future_into_py::<LanceRuntime, _, T>(py, fut)
}