mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
feat(python): make Permutation fork-safe for PyTorch DataLoader workers (#3339)
## Summary
PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but threads do not survive `fork()`. LanceDB's Python bindings
drive async work through two threaded layers, both of which become inert
in a forked child:
- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3-async-runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork — and its runtime lives in
a `OnceLock` that cannot be replaced after first use.
As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.
## Approach
### Rust — new `python/src/runtime.rs`
Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs (L139)),
adapted for the async-bridge use case.
- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.
### Python — `lancedb/background_loop.py`
- Refactors `BackgroundEventLoop.__init__` to a reusable `_start()`
method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.
### Python — `lancedb/__init__.py`
Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`). Fork is supported.
## Test plan
- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`; passes in <1s, validates the runtime reset is
sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) + `test_db.py` + `test_query.py` (157, minus
one unrelated `sentence_transformers` import skip) — 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.
## Known limitation (follow-up)
This PR makes a **freshly-built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries an inherited `reqwest::Client` whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent — using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -35,7 +35,8 @@ futures.workspace = true
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
snafu.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
tokio = { version = "1.40", features = ["sync", "rt-multi-thread"] }
|
||||
libc = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.28", features = [
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Optional, Union, Any, List
|
||||
import warnings
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -438,13 +437,3 @@ __all__ = [
|
||||
"Table",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
def __warn_on_fork():
|
||||
warnings.warn(
|
||||
"lance is not fork-safe. If you are using multiprocessing, use spawn instead.",
|
||||
)
|
||||
|
||||
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(before=__warn_on_fork) # type: ignore[attr-defined]
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
|
||||
class BackgroundEventLoop:
|
||||
@@ -13,6 +15,9 @@ class BackgroundEventLoop:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._start()
|
||||
|
||||
def _start(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.thread = threading.Thread(
|
||||
target=self.loop.run_forever,
|
||||
@@ -31,3 +36,30 @@ class BackgroundEventLoop:
|
||||
|
||||
|
||||
LOOP = BackgroundEventLoop()
|
||||
|
||||
_FORK_WARNED = False
|
||||
|
||||
|
||||
def _reset_after_fork():
|
||||
# Threads do not survive fork(), so the asyncio loop in LOOP.thread is
|
||||
# dead in the child. Re-initialize the singleton in place so existing
|
||||
# `from .background_loop import LOOP` references in other modules see
|
||||
# the new state. The Rust-side tokio runtime is reset analogously by a
|
||||
# pthread_atfork hook installed in the _lancedb extension.
|
||||
LOOP._start()
|
||||
global _FORK_WARNED
|
||||
if not _FORK_WARNED:
|
||||
_FORK_WARNED = True
|
||||
warnings.warn(
|
||||
"lancedb fork support is experimental: the internal async "
|
||||
"runtime has been reset in the forked child, but a small chance "
|
||||
"of deadlock remains if other state was mid-operation at fork "
|
||||
"time. The 'forkserver' or 'spawn' multiprocessing start method "
|
||||
"is likely a safer alternative.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(after_in_child=_reset_after_fork)
|
||||
|
||||
@@ -6,6 +6,8 @@ import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -1230,3 +1232,82 @@ def test_background_loop_cancellation(exception):
|
||||
with pytest.raises(exception):
|
||||
loop.run(None)
|
||||
mock_future.cancel.assert_called_once()
|
||||
|
||||
|
||||
def _remote_fork_child(port: int, queue) -> None:
|
||||
# Build a fresh Connection in the child so we exercise the at-fork-child
|
||||
# tokio runtime reset rather than relying on an inherited reqwest client.
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
queue.put(db.table_names())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_remote_connection_after_fork():
|
||||
"""A freshly-built remote Connection in a forked child should not hang.
|
||||
|
||||
The pyo3-async-runtimes tokio runtime would otherwise be inherited from
|
||||
the parent with dead worker threads; the at-fork-child handler in our
|
||||
runtime module rebuilds it on first use in the child.
|
||||
"""
|
||||
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
|
||||
port = server.server_address[1]
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.start()
|
||||
try:
|
||||
# Hit the server in the parent first so the runtime + LOOP are warm
|
||||
# before fork; a fresh child must still succeed.
|
||||
parent_db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
assert parent_db.table_names() == []
|
||||
|
||||
ctx = mp.get_context("fork")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_remote_fork_child, args=(port, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=15)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Remote connection hung after fork")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no result"
|
||||
assert queue.get() == []
|
||||
|
||||
# Parent connection must still be usable after the child returned.
|
||||
assert parent_db.table_names() == []
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import functools
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.util import tbl_to_tensor
|
||||
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||
from lancedb.util import tbl_to_tensor
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
@@ -146,3 +148,63 @@ def test_permutation_with_builder_is_picklable(tmp_db):
|
||||
|
||||
assert len(restored) == len(permutation)
|
||||
assert restored.__getitems__(indices) == expected
|
||||
|
||||
|
||||
def _multiworker_dataloader_target(db_uri: str, result_queue):
|
||||
import lancedb
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
db = lancedb.connect(db_uri)
|
||||
table = db.open_table("test_table")
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="fork",
|
||||
)
|
||||
count = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
count += 1
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_permutation_dataloader_fork_workers(tmp_path):
|
||||
"""A Permutation used by a fork-based DataLoader should not hang.
|
||||
|
||||
PyTorch's DataLoader uses fork-based multiprocessing by default on Linux.
|
||||
LanceDB drives async work through a background asyncio thread that does
|
||||
not survive a fork, so any LOOP.run() in a worker blocks forever.
|
||||
"""
|
||||
import lancedb
|
||||
|
||||
db_uri = str(tmp_path / "db")
|
||||
db = lancedb.connect(db_uri)
|
||||
db.create_table("test_table", pa.table({"a": list(range(1000))}))
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_multiworker_dataloader_target, args=(db_uri, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=30)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Permutation hung when iterated in a fork-based DataLoader worker")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 100
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::runtime::future_into_py;
|
||||
use arrow::{
|
||||
datatypes::SchemaRef,
|
||||
pyarrow::{IntoPyArrow, ToPyArrow},
|
||||
@@ -12,9 +14,6 @@ use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct RecordBatchStream {
|
||||
|
||||
@@ -7,6 +7,12 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
namespace::{create_namespace_storage_options_provider, extract_namespace_arc},
|
||||
runtime::future_into_py,
|
||||
table::Table,
|
||||
};
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
@@ -20,13 +26,6 @@ use pyo3::{
|
||||
pyclass, pyfunction, pymethods,
|
||||
types::{PyDict, PyDictMethods},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
namespace::{create_namespace_storage_options_provider, extract_namespace_arc},
|
||||
table::Table,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
|
||||
@@ -28,6 +28,7 @@ pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
pub mod session;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{arrow::RecordBatchStream, error::PythonErrorExt, table::Table};
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, error::PythonErrorExt, runtime::future_into_py, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::{PyArrowType, ToPyArrow};
|
||||
use lancedb::{
|
||||
dataloader::permutation::{
|
||||
@@ -19,7 +21,6 @@ use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
|
||||
if table.hasattr("_inner")? {
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::runtime::future_into_py;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::array::make_array;
|
||||
@@ -36,12 +41,6 @@ use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{Borrowed, FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{exceptions::PyValueError, intern};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
|
||||
impl<'a, 'py> FromPyObject<'a, 'py> for PyLanceDB<FtsQuery> {
|
||||
type Error = PyErr;
|
||||
|
||||
142
python/src/runtime.rs
Normal file
142
python/src/runtime.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Fork-safe wrapper around tokio + pyo3-async-runtimes.
|
||||
//!
|
||||
//! `pyo3_async_runtimes::tokio` keeps its multi-threaded runtime in a
|
||||
//! `OnceLock` that can never be replaced. Tokio's worker threads do not
|
||||
//! survive `fork()`, so once a child inherits a "frozen" runtime, every
|
||||
//! `future_into_py` call hangs forever.
|
||||
//!
|
||||
//! We sidestep the global by routing every future through our own
|
||||
//! [`LanceRuntime`] (a [`pyo3_async_runtimes::generic::Runtime`] impl) backed
|
||||
//! by an [`AtomicPtr`] to a tokio runtime that we own. A `pthread_atfork`
|
||||
//! child handler nulls the pointer; the next `spawn` rebuilds the runtime in
|
||||
//! the child. This mirrors the pattern used in the Lance Python bindings.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
|
||||
|
||||
use pyo3::{Bound, PyAny, PyResult, Python, conversion::IntoPyObject};
|
||||
use pyo3_async_runtimes::{
|
||||
TaskLocals,
|
||||
generic::{ContextExt, JoinError, Runtime},
|
||||
};
|
||||
use tokio::{runtime, task};
|
||||
|
||||
static RUNTIME: AtomicPtr<runtime::Runtime> = AtomicPtr::new(std::ptr::null_mut());
|
||||
static RUNTIME_INSTALLING: AtomicBool = AtomicBool::new(false);
|
||||
static ATFORK_INSTALLED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
fn create_runtime() -> runtime::Runtime {
|
||||
runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.thread_name("lancedb-tokio-worker")
|
||||
.build()
|
||||
.expect("Failed to build tokio runtime")
|
||||
}
|
||||
|
||||
fn get_runtime() -> &'static runtime::Runtime {
|
||||
loop {
|
||||
let ptr = RUNTIME.load(Ordering::SeqCst);
|
||||
if !ptr.is_null() {
|
||||
return unsafe { &*ptr };
|
||||
}
|
||||
if !RUNTIME_INSTALLING.fetch_or(true, Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
std::thread::yield_now();
|
||||
}
|
||||
if !ATFORK_INSTALLED.fetch_or(true, Ordering::SeqCst) {
|
||||
install_atfork();
|
||||
}
|
||||
let new_ptr = Box::into_raw(Box::new(create_runtime()));
|
||||
RUNTIME.store(new_ptr, Ordering::SeqCst);
|
||||
unsafe { &*new_ptr }
|
||||
}
|
||||
|
||||
/// Runs in async-signal context after `fork()` in the child. We can only
|
||||
/// touch atomics here; we deliberately leak the previous runtime because
|
||||
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
||||
/// threads and hang.
|
||||
extern "C" fn atfork_child() {
|
||||
RUNTIME.store(std::ptr::null_mut(), Ordering::SeqCst);
|
||||
RUNTIME_INSTALLING.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn install_atfork() {
|
||||
unsafe { libc::pthread_atfork(None, None, Some(atfork_child)) };
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn install_atfork() {}
|
||||
|
||||
/// Marker type implementing [`Runtime`] over our fork-safe runtime slot.
|
||||
pub struct LanceRuntime;
|
||||
|
||||
/// Newtype wrapper around `tokio::task::JoinError` so we can implement the
|
||||
/// foreign [`JoinError`] trait without violating orphan rules.
|
||||
pub struct LanceJoinError(task::JoinError);
|
||||
|
||||
impl JoinError for LanceJoinError {
|
||||
fn is_panic(&self) -> bool {
|
||||
self.0.is_panic()
|
||||
}
|
||||
fn into_panic(self) -> Box<dyn std::any::Any + Send + 'static> {
|
||||
self.0.into_panic()
|
||||
}
|
||||
}
|
||||
|
||||
impl Runtime for LanceRuntime {
|
||||
type JoinError = LanceJoinError;
|
||||
type JoinHandle = Pin<Box<dyn Future<Output = Result<(), Self::JoinError>> + Send>>;
|
||||
|
||||
fn spawn<F>(fut: F) -> Self::JoinHandle
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let handle = get_runtime().spawn(fut);
|
||||
Box::pin(async move { handle.await.map_err(LanceJoinError) })
|
||||
}
|
||||
|
||||
fn spawn_blocking<F>(f: F) -> Self::JoinHandle
|
||||
where
|
||||
F: FnOnce() + Send + 'static,
|
||||
{
|
||||
let handle = get_runtime().spawn_blocking(f);
|
||||
Box::pin(async move { handle.await.map_err(LanceJoinError) })
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task_local! {
|
||||
static TASK_LOCALS: std::cell::OnceCell<TaskLocals>;
|
||||
}
|
||||
|
||||
impl ContextExt for LanceRuntime {
|
||||
fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
where
|
||||
F: Future<Output = R> + Send + 'static,
|
||||
{
|
||||
let cell = std::cell::OnceCell::new();
|
||||
cell.set(locals).unwrap();
|
||||
Box::pin(TASK_LOCALS.scope(cell, fut))
|
||||
}
|
||||
|
||||
fn get_task_locals() -> Option<TaskLocals> {
|
||||
TASK_LOCALS
|
||||
.try_with(|c| c.get().cloned())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop-in replacement for `pyo3_async_runtimes::tokio::future_into_py` that
|
||||
/// uses our fork-safe runtime.
|
||||
pub fn future_into_py<F, T>(py: Python<'_>, fut: F) -> PyResult<Bound<'_, PyAny>>
|
||||
where
|
||||
F: Future<Output = PyResult<T>> + Send + 'static,
|
||||
T: for<'py> IntoPyObject<'py> + Send + 'static,
|
||||
{
|
||||
pyo3_async_runtimes::generic::future_into_py::<LanceRuntime, _, T>(py, fut)
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::runtime::future_into_py;
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
error::PythonErrorExt,
|
||||
@@ -24,7 +25,6 @@ use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
mod scannable;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user