diff --git a/Cargo.lock b/Cargo.lock index 2c2c3fd3c..ce8f0bdb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4698,6 +4698,7 @@ dependencies = [ "lance-namespace", "lance-namespace-impls", "lancedb", + "libc", "log", "pin-project", "pyo3", diff --git a/python/Cargo.toml b/python/Cargo.toml index 0811264e6..fce27e65a 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -35,7 +35,8 @@ futures.workspace = true serde = "1" serde_json = "1" snafu.workspace = true -tokio = { version = "1.40", features = ["sync"] } +tokio = { version = "1.40", features = ["sync", "rt-multi-thread"] } +libc = "0.2" [build-dependencies] pyo3-build-config = { version = "0.28", features = [ diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 9e8ee0dd8..efeed258f 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -7,7 +7,6 @@ import os from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from typing import Dict, Optional, Union, Any, List -import warnings __version__ = importlib.metadata.version("lancedb") @@ -438,13 +437,3 @@ __all__ = [ "Table", "__version__", ] - - -def __warn_on_fork(): - warnings.warn( - "lance is not fork-safe. If you are using multiprocessing, use spawn instead.", - ) - - -if hasattr(os, "register_at_fork"): - os.register_at_fork(before=__warn_on_fork) # type: ignore[attr-defined] diff --git a/python/python/lancedb/background_loop.py b/python/python/lancedb/background_loop.py index d132dd82d..b39da229d 100644 --- a/python/python/lancedb/background_loop.py +++ b/python/python/lancedb/background_loop.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import asyncio +import os import threading +import warnings class BackgroundEventLoop: @@ -13,6 +15,9 @@ class BackgroundEventLoop: """ def __init__(self): + self._start() + + def _start(self): self.loop = asyncio.new_event_loop() self.thread = threading.Thread( target=self.loop.run_forever, @@ -31,3 +36,30 @@ class BackgroundEventLoop: LOOP = BackgroundEventLoop() + +_FORK_WARNED = False + + +def _reset_after_fork(): + # Threads do not survive fork(), so the asyncio loop in LOOP.thread is + # dead in the child. Re-initialize the singleton in place so existing + # `from .background_loop import LOOP` references in other modules see + # the new state. The Rust-side tokio runtime is reset analogously by a + # pthread_atfork hook installed in the _lancedb extension. + LOOP._start() + global _FORK_WARNED + if not _FORK_WARNED: + _FORK_WARNED = True + warnings.warn( + "lancedb fork support is experimental: the internal async " + "runtime has been reset in the forked child, but a small chance " + "of deadlock remains if other state was mid-operation at fork " + "time. The 'forkserver' or 'spawn' multiprocessing start method " + "is likely a safer alternative.", + RuntimeWarning, + stacklevel=2, + ) + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=_reset_after_fork) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 0dd880cc0..a499275c5 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -6,6 +6,8 @@ import contextlib from datetime import timedelta import http.server import json +import multiprocessing as mp +import sys import threading import time from unittest.mock import MagicMock, patch @@ -1230,3 +1232,82 @@ def test_background_loop_cancellation(exception): with pytest.raises(exception): loop.run(None) mock_future.cancel.assert_called_once() + + +def _remote_fork_child(port: int, queue) -> None: + # Build a fresh Connection in the child so we exercise the at-fork-child + # tokio runtime reset rather than relying on an inherited reqwest client. + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + }, + ) + queue.put(db.table_names()) + + +@pytest.mark.skipif( + sys.platform != "linux", + reason=( + "fork() is unavailable on Windows and unsafe on macOS " + "(Apple frameworks/TLS are not fork-safe)" + ), +) +def test_remote_connection_after_fork(): + """A freshly-built remote Connection in a forked child should not hang. + + The pyo3-async-runtimes tokio runtime would otherwise be inherited from + the parent with dead worker threads; the at-fork-child handler in our + runtime module rebuilds it on first use in the child. + """ + + def handler(request): + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"tables": []}') + + server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler)) + port = server.server_address[1] + server_thread = threading.Thread(target=server.serve_forever) + server_thread.start() + try: + # Hit the server in the parent first so the runtime + LOOP are warm + # before fork; a fresh child must still succeed. + parent_db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + }, + ) + assert parent_db.table_names() == [] + + ctx = mp.get_context("fork") + queue = ctx.Queue() + proc = ctx.Process(target=_remote_fork_child, args=(port, queue)) + proc.start() + proc.join(timeout=15) + + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + proc.join() + pytest.fail("Remote connection hung after fork") + + assert proc.exitcode == 0, f"child exited with code {proc.exitcode}" + assert not queue.empty(), "child produced no result" + assert queue.get() == [] + + # Parent connection must still be usable after the child returned. + assert parent_db.table_names() == [] + finally: + server.shutdown() + server_thread.join() diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py index 0ca1de3e8..d17e60bbd 100644 --- a/python/python/tests/test_torch.py +++ b/python/python/tests/test_torch.py @@ -2,13 +2,15 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import functools +import multiprocessing as mp import pickle +import sys import lancedb import pyarrow as pa import pytest -from lancedb.util import tbl_to_tensor from lancedb.permutation import Permutation, Permutations, permutation_builder +from lancedb.util import tbl_to_tensor torch = pytest.importorskip("torch") @@ -146,3 +148,63 @@ def test_permutation_with_builder_is_picklable(tmp_db): assert len(restored) == len(permutation) assert restored.__getitems__(indices) == expected + + +def _multiworker_dataloader_target(db_uri: str, result_queue): + import lancedb + from lancedb.permutation import Permutation + + db = lancedb.connect(db_uri) + table = db.open_table("test_table") + permutation = Permutation.identity(table) + + dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=10, + num_workers=2, + multiprocessing_context="fork", + ) + count = 0 + for batch in dataloader: + assert batch["a"].size(0) == 10 + count += 1 + result_queue.put(count) + + +@pytest.mark.skipif( + sys.platform != "linux", + reason=( + "fork() is unavailable on Windows and unsafe on macOS " + "(Apple frameworks/TLS are not fork-safe)" + ), +) +def test_permutation_dataloader_fork_workers(tmp_path): + """A Permutation used by a fork-based DataLoader should not hang. + + PyTorch's DataLoader uses fork-based multiprocessing by default on Linux. + LanceDB drives async work through a background asyncio thread that does + not survive a fork, so any LOOP.run() in a worker blocks forever. + """ + import lancedb + + db_uri = str(tmp_path / "db") + db = lancedb.connect(db_uri) + db.create_table("test_table", pa.table({"a": list(range(1000))})) + + ctx = mp.get_context("spawn") + queue = ctx.Queue() + proc = ctx.Process(target=_multiworker_dataloader_target, args=(db_uri, queue)) + proc.start() + proc.join(timeout=30) + + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + proc.join() + pytest.fail("Permutation hung when iterated in a fork-based DataLoader worker") + + assert proc.exitcode == 0, f"child exited with code {proc.exitcode}" + assert not queue.empty(), "child produced no batches" + assert queue.get() == 100 diff --git a/python/src/arrow.rs b/python/src/arrow.rs index fd3a05964..f0b4bceed 100644 --- a/python/src/arrow.rs +++ b/python/src/arrow.rs @@ -3,6 +3,8 @@ use std::sync::Arc; +use crate::error::PythonErrorExt; +use crate::runtime::future_into_py; use arrow::{ datatypes::SchemaRef, pyarrow::{IntoPyArrow, ToPyArrow}, @@ -12,9 +14,6 @@ use lancedb::arrow::SendableRecordBatchStream; use pyo3::{ Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods, }; -use pyo3_async_runtimes::tokio::future_into_py; - -use crate::error::PythonErrorExt; #[pyclass] pub struct RecordBatchStream { diff --git a/python/src/connection.rs b/python/src/connection.rs index 1b12c33ab..703b44424 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -7,6 +7,12 @@ use std::{ time::Duration, }; +use crate::{ + error::PythonErrorExt, + namespace::{create_namespace_storage_options_provider, extract_namespace_arc}, + runtime::future_into_py, + table::Table, +}; use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow}; use lancedb::{ connection::Connection as LanceConnection, @@ -20,13 +26,6 @@ use pyo3::{ pyclass, pyfunction, pymethods, types::{PyDict, PyDictMethods}, }; -use pyo3_async_runtimes::tokio::future_into_py; - -use crate::{ - error::PythonErrorExt, - namespace::{create_namespace_storage_options_provider, extract_namespace_arc}, - table::Table, -}; #[pyclass] pub struct Connection { diff --git a/python/src/lib.rs b/python/src/lib.rs index 7dd52bdc2..d0e933dba 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -28,6 +28,7 @@ pub mod index; pub mod namespace; pub mod permutation; pub mod query; +pub mod runtime; pub mod session; pub mod table; pub mod util; diff --git a/python/src/permutation.rs b/python/src/permutation.rs index 114825938..75e1fe1b7 100644 --- a/python/src/permutation.rs +++ b/python/src/permutation.rs @@ -3,7 +3,9 @@ use std::sync::{Arc, Mutex}; -use crate::{arrow::RecordBatchStream, error::PythonErrorExt, table::Table}; +use crate::{ + arrow::RecordBatchStream, error::PythonErrorExt, runtime::future_into_py, table::Table, +}; use arrow::pyarrow::{PyArrowType, ToPyArrow}; use lancedb::{ dataloader::permutation::{ @@ -19,7 +21,6 @@ use pyo3::{ pyclass, pymethods, types::{PyAnyMethods, PyDict, PyDictMethods, PyType}, }; -use pyo3_async_runtimes::tokio::future_into_py; fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult> { if table.hasattr("_inner")? { diff --git a/python/src/query.rs b/python/src/query.rs index 1b64b5eaa..1dc4f08db 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -4,6 +4,11 @@ use std::sync::Arc; use std::time::Duration; +use crate::expr::PyExpr; +use crate::runtime::future_into_py; +use crate::util::parse_distance_type; +use crate::{arrow::RecordBatchStream, util::PyLanceDB}; +use crate::{error::PythonErrorExt, index::class_name}; use arrow::array::Array; use arrow::array::ArrayData; use arrow::array::make_array; @@ -36,12 +41,6 @@ use pyo3::types::{PyDict, PyString}; use pyo3::{Borrowed, FromPyObject, exceptions::PyRuntimeError}; use pyo3::{PyErr, pyclass}; use pyo3::{exceptions::PyValueError, intern}; -use pyo3_async_runtimes::tokio::future_into_py; - -use crate::expr::PyExpr; -use crate::util::parse_distance_type; -use crate::{arrow::RecordBatchStream, util::PyLanceDB}; -use crate::{error::PythonErrorExt, index::class_name}; impl<'a, 'py> FromPyObject<'a, 'py> for PyLanceDB { type Error = PyErr; diff --git a/python/src/runtime.rs b/python/src/runtime.rs new file mode 100644 index 000000000..39ebfdaa8 --- /dev/null +++ b/python/src/runtime.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Fork-safe wrapper around tokio + pyo3-async-runtimes. +//! +//! `pyo3_async_runtimes::tokio` keeps its multi-threaded runtime in a +//! `OnceLock` that can never be replaced. Tokio's worker threads do not +//! survive `fork()`, so once a child inherits a "frozen" runtime, every +//! `future_into_py` call hangs forever. +//! +//! We sidestep the global by routing every future through our own +//! [`LanceRuntime`] (a [`pyo3_async_runtimes::generic::Runtime`] impl) backed +//! by an [`AtomicPtr`] to a tokio runtime that we own. A `pthread_atfork` +//! child handler nulls the pointer; the next `spawn` rebuilds the runtime in +//! the child. This mirrors the pattern used in the Lance Python bindings. + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; + +use pyo3::{Bound, PyAny, PyResult, Python, conversion::IntoPyObject}; +use pyo3_async_runtimes::{ + TaskLocals, + generic::{ContextExt, JoinError, Runtime}, +}; +use tokio::{runtime, task}; + +static RUNTIME: AtomicPtr = AtomicPtr::new(std::ptr::null_mut()); +static RUNTIME_INSTALLING: AtomicBool = AtomicBool::new(false); +static ATFORK_INSTALLED: AtomicBool = AtomicBool::new(false); + +fn create_runtime() -> runtime::Runtime { + runtime::Builder::new_multi_thread() + .enable_all() + .thread_name("lancedb-tokio-worker") + .build() + .expect("Failed to build tokio runtime") +} + +fn get_runtime() -> &'static runtime::Runtime { + loop { + let ptr = RUNTIME.load(Ordering::SeqCst); + if !ptr.is_null() { + return unsafe { &*ptr }; + } + if !RUNTIME_INSTALLING.fetch_or(true, Ordering::SeqCst) { + break; + } + std::thread::yield_now(); + } + if !ATFORK_INSTALLED.fetch_or(true, Ordering::SeqCst) { + install_atfork(); + } + let new_ptr = Box::into_raw(Box::new(create_runtime())); + RUNTIME.store(new_ptr, Ordering::SeqCst); + unsafe { &*new_ptr } +} + +/// Runs in async-signal context after `fork()` in the child. We can only +/// touch atomics here; we deliberately leak the previous runtime because +/// dropping a tokio `Runtime` would try to join its (now-dead) worker +/// threads and hang. +extern "C" fn atfork_child() { + RUNTIME.store(std::ptr::null_mut(), Ordering::SeqCst); + RUNTIME_INSTALLING.store(false, Ordering::SeqCst); +} + +#[cfg(not(windows))] +fn install_atfork() { + unsafe { libc::pthread_atfork(None, None, Some(atfork_child)) }; +} + +#[cfg(windows)] +fn install_atfork() {} + +/// Marker type implementing [`Runtime`] over our fork-safe runtime slot. +pub struct LanceRuntime; + +/// Newtype wrapper around `tokio::task::JoinError` so we can implement the +/// foreign [`JoinError`] trait without violating orphan rules. +pub struct LanceJoinError(task::JoinError); + +impl JoinError for LanceJoinError { + fn is_panic(&self) -> bool { + self.0.is_panic() + } + fn into_panic(self) -> Box { + self.0.into_panic() + } +} + +impl Runtime for LanceRuntime { + type JoinError = LanceJoinError; + type JoinHandle = Pin> + Send>>; + + fn spawn(fut: F) -> Self::JoinHandle + where + F: Future + Send + 'static, + { + let handle = get_runtime().spawn(fut); + Box::pin(async move { handle.await.map_err(LanceJoinError) }) + } + + fn spawn_blocking(f: F) -> Self::JoinHandle + where + F: FnOnce() + Send + 'static, + { + let handle = get_runtime().spawn_blocking(f); + Box::pin(async move { handle.await.map_err(LanceJoinError) }) + } +} + +tokio::task_local! { + static TASK_LOCALS: std::cell::OnceCell; +} + +impl ContextExt for LanceRuntime { + fn scope(locals: TaskLocals, fut: F) -> Pin + Send>> + where + F: Future + Send + 'static, + { + let cell = std::cell::OnceCell::new(); + cell.set(locals).unwrap(); + Box::pin(TASK_LOCALS.scope(cell, fut)) + } + + fn get_task_locals() -> Option { + TASK_LOCALS + .try_with(|c| c.get().cloned()) + .unwrap_or_default() + } +} + +/// Drop-in replacement for `pyo3_async_runtimes::tokio::future_into_py` that +/// uses our fork-safe runtime. +pub fn future_into_py(py: Python<'_>, fut: F) -> PyResult> +where + F: Future> + Send + 'static, + T: for<'py> IntoPyObject<'py> + Send + 'static, +{ + pyo3_async_runtimes::generic::future_into_py::(py, fut) +} diff --git a/python/src/table.rs b/python/src/table.rs index 715ac79cc..9ac5af807 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::{collections::HashMap, sync::Arc}; +use crate::runtime::future_into_py; use crate::{ connection::Connection, error::PythonErrorExt, @@ -24,7 +25,6 @@ use pyo3::{ pyclass, pymethods, types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods}, }; -use pyo3_async_runtimes::tokio::future_into_py; mod scannable;