Mirror of https://github.com/lancedb/lancedb.git, synced 2026-07-02 18:40:40 +00:00
## Summary
PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but threads do not survive `fork()`: the child process contains
only a copy of the thread that called it. LanceDB's Python bindings
drive async work through two threaded layers, both of which become inert
in a forked child:
- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3-async-runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork — and its runtime lives in
a `OnceLock` that cannot be replaced after first use.
As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.
## Approach
### Rust — new `python/src/runtime.rs`
Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs (L139)),
adapted for the async-bridge use case.
- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.
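A minimal, std-only sketch of the reset-and-leak pattern described above. It assumes a hypothetical `FakeRuntime` standing in for `tokio::runtime::Runtime` (so it compiles without tokio) and declares `pthread_atfork` directly instead of going through the `libc` crate; `runtime` and `register_fork_handler` are illustrative names, not the PR's actual code.

```rust
use std::os::raw::c_int;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};

/// Hypothetical stand-in for `tokio::runtime::Runtime`.
pub struct FakeRuntime {
    pub generation: usize,
}

// Counts how many runtimes have been built (for illustration only).
static BUILDS: AtomicUsize = AtomicUsize::new(0);
// Null means "no runtime yet" or "reset after fork".
static RUNTIME: AtomicPtr<FakeRuntime> = AtomicPtr::new(ptr::null_mut());

fn build_runtime() -> *mut FakeRuntime {
    let generation = BUILDS.fetch_add(1, Ordering::SeqCst) + 1;
    Box::into_raw(Box::new(FakeRuntime { generation }))
}

/// Returns the installed runtime, building a new one if the slot is null.
pub fn runtime() -> &'static FakeRuntime {
    let mut current = RUNTIME.load(Ordering::Acquire);
    if current.is_null() {
        let fresh = build_runtime();
        match RUNTIME.compare_exchange(
            ptr::null_mut(),
            fresh,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => current = fresh,
            Err(winner) => {
                // Another thread installed one first; free ours (never shared).
                unsafe { drop(Box::from_raw(fresh)) };
                current = winner;
            }
        }
    }
    // Installed runtimes are never freed (only leaked on reset),
    // so handing out a 'static reference is sound.
    unsafe { &*current }
}

/// Child-side fork handler: forget the inherited runtime without dropping it,
/// because Drop would try to join worker threads that fork() did not copy.
pub extern "C" fn after_in_child() {
    RUNTIME.store(ptr::null_mut(), Ordering::Release);
}

unsafe extern "C" {
    // POSIX: int pthread_atfork(void (*prepare)(void),
    //                           void (*parent)(void), void (*child)(void));
    fn pthread_atfork(
        prepare: Option<unsafe extern "C" fn()>,
        parent: Option<unsafe extern "C" fn()>,
        child: Option<unsafe extern "C" fn()>,
    ) -> c_int;
}

/// Registers `after_in_child`; returns 0 on success, like pthread_atfork.
pub fn register_fork_handler() -> c_int {
    let child: unsafe extern "C" fn() = after_in_child;
    unsafe { pthread_atfork(None, None, Some(child)) }
}
```

The leak in `after_in_child` is the trade-off the PR describes: the child gives up the old runtime's memory, but never blocks trying to join worker threads that exist only in the parent.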
### Python — `lancedb/background_loop.py`
- Moves the body of `BackgroundEventLoop.__init__` into a reusable
`_start()` method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.
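Why restarting in place beats rebinding can be shown with a small Rust analog (hypothetical types, not LanceDB code): a cloned `Arc` handle plays the role of `from .background_loop import LOOP`, `mark_dead_after_fork` simulates the loop thread vanishing, and `restart_in_place` plays the role of `_start()`.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the asyncio loop and its thread.
pub struct LoopState {
    pub alive: bool,
}

/// Analog of `BackgroundEventLoop`: clones share one underlying state.
#[derive(Clone)]
pub struct BackgroundLoop {
    state: Arc<Mutex<LoopState>>,
}

impl BackgroundLoop {
    pub fn new() -> Self {
        BackgroundLoop {
            state: Arc::new(Mutex::new(LoopState { alive: true })),
        }
    }

    /// Simulates fork(): the loop's thread no longer exists in the child.
    pub fn mark_dead_after_fork(&self) {
        self.state.lock().unwrap().alive = false;
    }

    /// Analog of `_start()`: install a fresh loop *inside* the shared object,
    /// so every existing handle sees it.
    pub fn restart_in_place(&self) {
        *self.state.lock().unwrap() = LoopState { alive: true };
    }

    pub fn is_alive(&self) -> bool {
        self.state.lock().unwrap().alive
    }
}

/// Returns what a handle taken *before* the fork observes afterwards:
/// (with in-place restart, with rebinding).
pub fn demo() -> (bool, bool) {
    // In place: the early "import" sees the restarted loop.
    let global = BackgroundLoop::new();
    let imported = global.clone();
    global.mark_dead_after_fork();
    global.restart_in_place();

    // Rebinding: the early "import" keeps pointing at the dead loop.
    let mut module_attr = BackgroundLoop::new();
    let imported_early = module_attr.clone();
    module_attr.mark_dead_after_fork();
    module_attr = BackgroundLoop::new();
    assert!(module_attr.is_alive());

    (imported.is_alive(), imported_early.is_alive())
}
```

`demo()` yields `(true, false)`: only the in-place restart reaches references that were taken before the reset.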
### Python — `lancedb/__init__.py`
Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`), since forking is now supported.
## Test plan
- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`. It passes in <1s, which validates that the runtime reset
is sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) plus `test_db.py` and `test_query.py` (157
combined, with one unrelated skip caused by a `sentence_transformers`
import): 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.
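The core of the fork test's hang detector, bounding how long the test waits for a result, can be sketched with a helper thread and `mpsc::Receiver::recv_timeout`. This is a simplified analog under that assumption: the actual test isolates the run in a spawned child process, which (unlike a thread) can be killed once the deadline passes. `run_with_hang_detector` is a hypothetical name.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Runs `work` on a helper thread and waits at most `limit` for its result.
/// Returns `Err` if the deadline passes, treating the run as a hang.
pub fn run_with_hang_detector<T, F>(work: F, limit: Duration) -> Result<T, String>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // Ignore the send error: the receiver is gone if we already timed out.
        let _ = tx.send(work());
    });
    match rx.recv_timeout(limit) {
        Ok(value) => Ok(value),
        Err(mpsc::RecvTimeoutError::Timeout) => Err(format!("timed out after {:?}", limit)),
        // The sender was dropped without sending, i.e. `work` panicked.
        Err(mpsc::RecvTimeoutError::Disconnected) => Err("worker panicked".to_string()),
    }
}
```

A timed-out worker thread keeps running in the background; that is acceptable for a test harness but is exactly why the real test prefers a killable child process.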
## Known limitation (follow-up)
This PR makes a **freshly built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries the parent's `reqwest::Client`, whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent. Using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.
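One way to apply "construct a new connection inside the worker" automatically, for a single cached connection, is to record the process ID that built it and rebuild when `std::process::id()` changes. This is a swapped-in technique (a PID check, not an atfork hook), and `Connection`, `connect`, and `with_connection` are hypothetical stand-ins, not LanceDB's API.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

/// Hypothetical stand-in for a connection that owns an HTTP connection pool.
pub struct Connection {
    pub owner_pid: u32,
    pub build_id: usize,
}

static BUILDS: AtomicUsize = AtomicUsize::new(0);
static CACHED: Mutex<Option<Connection>> = Mutex::new(None);

fn connect() -> Connection {
    Connection {
        owner_pid: std::process::id(),
        build_id: BUILDS.fetch_add(1, Ordering::SeqCst) + 1,
    }
}

/// Runs `f` with a connection that was built by the current process.
pub fn with_connection<R>(f: impl FnOnce(&Connection) -> R) -> R {
    let pid = std::process::id();
    let mut slot = CACHED.lock().unwrap();
    let inherited = slot.as_ref().map_or(false, |c| c.owner_pid != pid);
    if inherited {
        // Leak the parent's connection rather than dropping it: a real
        // client's Drop might write shutdown traffic (e.g. a TLS close_notify)
        // on sockets the parent is still using.
        std::mem::forget(slot.take());
    }
    let conn = slot.get_or_insert_with(connect);
    f(conn)
}
```

One caveat of this sketch: if another thread holds the `CACHED` lock at the moment of `fork()`, the child inherits a locked mutex with no thread to release it and would deadlock on first use.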
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
343 lines
11 KiB
Rust
```rust
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::{Arc, Mutex};

use crate::{
    arrow::RecordBatchStream, error::PythonErrorExt, runtime::future_into_py, table::Table,
};
use arrow::pyarrow::{PyArrowType, ToPyArrow};
use lancedb::{
    dataloader::permutation::{
        builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
        reader::PermutationReader,
        split::{SplitSizes, SplitStrategy},
    },
    query::Select,
};
use pyo3::{
    Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
    exceptions::PyRuntimeError,
    pyclass, pymethods,
    types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
};

fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
    if table.hasattr("_inner")? {
        Ok(table.getattr("_inner")?.cast_into::<Table>()?)
    } else if table.hasattr("_table")? {
        Ok(table
            .getattr("_table")?
            .getattr("_inner")?
            .cast_into::<Table>()?)
    } else {
        Err(PyRuntimeError::new_err(
            "Provided table does not appear to be a Table or RemoteTable instance",
        ))
    }
}

/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
    let table = table_from_py(table)?;
    let inner_table = table.borrow().inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);

    Ok(PyAsyncPermutationBuilder {
        state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
            builder: Some(inner_builder),
        })),
    })
}

struct PyAsyncPermutationBuilderState {
    builder: Option<LancePermutationBuilder>,
}

#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
    state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}

impl PyAsyncPermutationBuilder {
    fn modify(
        &self,
        func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
    ) -> PyResult<Self> {
        let mut state = self.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
        state.builder = Some(func(builder));
        Ok(Self {
            state: self.state.clone(),
        })
    }
}

#[pymethods]
impl PyAsyncPermutationBuilder {
    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
    pub fn split_random(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
        seed: Option<u64>,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
        })
    }

    #[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
    pub fn split_hash(
        slf: PyRefMut<'_, Self>,
        columns: Vec<String>,
        split_weights: Vec<u64>,
        discard_weight: u64,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_split_strategy(
                SplitStrategy::Hash {
                    columns,
                    split_weights,
                    discard_weight,
                },
                split_names,
            )
        })
    }

    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
    pub fn split_sequential(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
        })
    }

    pub fn split_calculated(
        slf: PyRefMut<'_, Self>,
        calculation: String,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
        })
    }

    pub fn shuffle(
        slf: PyRefMut<'_, Self>,
        seed: Option<u64>,
        clump_size: Option<u64>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
        })
    }

    pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
        slf.modify(|builder| builder.with_filter(filter))
    }

    pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let mut state = slf.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;

        future_into_py(slf.py(), async move {
            let table = builder.build().await.infer_error()?;
            Ok(Table::new(table))
        })
    }
}

#[pyclass(name = "PermutationReader")]
pub struct PyPermutationReader {
    reader: Arc<PermutationReader>,
}

impl PyPermutationReader {
    fn from_reader(reader: PermutationReader) -> Self {
        Self {
            reader: Arc::new(reader),
        }
    }

    fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
        let Some(selection) = selection else {
            return Ok(Select::All);
        };
        let selection = selection.cast_into::<PyDict>()?;
        let selection = selection
            .iter()
            .map(|(key, value)| {
                let key = key.extract::<String>()?;
                let value = value.extract::<String>()?;
                Ok((key, value))
            })
            .collect::<PyResult<Vec<_>>>()?;
        Ok(Select::dynamic(&selection))
    }
}

#[pymethods]
impl PyPermutationReader {
    #[classmethod]
    pub fn from_tables<'py>(
        cls: &Bound<'py, PyType>,
        base_table: Bound<'py, PyAny>,
        permutation_table: Option<Bound<'py, PyAny>>,
        split: u64,
    ) -> PyResult<Bound<'py, PyAny>> {
        let base_table = table_from_py(base_table)?;
        let permutation_table = permutation_table.map(table_from_py).transpose()?;

        let base_table = base_table.borrow().inner_ref()?.base_table().clone();
        let permutation_table = permutation_table
            .map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
            .transpose()?;

        future_into_py(cls.py(), async move {
            let reader = if let Some(permutation_table) = permutation_table {
                PermutationReader::try_from_tables(base_table, permutation_table, split)
                    .await
                    .infer_error()?
            } else {
                PermutationReader::identity(base_table).await
            };
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (selection=None))]
    pub fn output_schema<'py>(
        slf: PyRef<'py, Self>,
        selection: Option<Bound<'py, PyAny>>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let selection = Self::parse_selection(selection)?;
        let reader = slf.reader.clone();
        future_into_py(slf.py(), async move {
            let schema = reader.output_schema(selection).await.infer_error()?;
            Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
        })
    }

    #[pyo3(signature = ())]
    pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
        slf.reader.count_rows()
    }

    #[pyo3(signature = (offset))]
    pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
        let reader = slf.reader.as_ref().clone();
        future_into_py(slf.py(), async move {
            let reader = reader.with_offset(offset).await.infer_error()?;
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (limit))]
    pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
        let reader = slf.reader.as_ref().clone();
        future_into_py(slf.py(), async move {
            let reader = reader.with_limit(limit).await.infer_error()?;
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (selection=None, *, batch_size=None))]
    pub fn read<'py>(
        slf: PyRef<'py, Self>,
        selection: Option<Bound<'py, PyAny>>,
        batch_size: Option<u32>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let selection = Self::parse_selection(selection)?;
        let reader = slf.reader.clone();
        let batch_size = batch_size.unwrap_or(1024);
        future_into_py(slf.py(), async move {
            use lancedb::query::QueryExecutionOptions;
            let mut execution_options = QueryExecutionOptions::default();
            execution_options.max_batch_length = batch_size;
            let stream = reader
                .read(selection, execution_options)
                .await
                .infer_error()?;
            Ok(RecordBatchStream::new(stream))
        })
    }

    #[pyo3(signature = (indices, *, selection=None))]
    pub fn take_offsets<'py>(
        slf: PyRef<'py, Self>,
        indices: Vec<u64>,
        selection: Option<Bound<'py, PyAny>>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let selection = Self::parse_selection(selection)?;
        let reader = slf.reader.clone();
        future_into_py(slf.py(), async move {
            let batch = reader
                .take_offsets(&indices, selection)
                .await
                .infer_error()?;
            Ok(PyArrowType(batch))
        })
    }
}
```
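`split_random` and `split_sequential` both repeat the same "exactly one of `ratios`, `counts`, or `fixed`" count check, followed by an `if let` chain that ends in `unreachable!`. A hypothetical helper that matches on the tuple of options performs the validation and the conversion in one step; `SplitSizes` is redeclared locally here (mirroring only the three variants used above) so the sketch compiles without `lancedb`.

```rust
/// Local stand-in mirroring the three `SplitSizes` variants used above.
#[derive(Debug, PartialEq)]
pub enum SplitSizes {
    Percentages(Vec<f64>),
    Counts(Vec<u64>),
    Fixed(u64),
}

/// Hypothetical helper: converts three mutually exclusive keyword arguments
/// into a `SplitSizes`, rejecting calls that pass none or more than one.
pub fn parse_split_sizes(
    ratios: Option<Vec<f64>>,
    counts: Option<Vec<u64>>,
    fixed: Option<u64>,
) -> Result<SplitSizes, String> {
    match (ratios, counts, fixed) {
        (Some(r), None, None) => Ok(SplitSizes::Percentages(r)),
        (None, Some(c), None) => Ok(SplitSizes::Counts(c)),
        (None, None, Some(f)) => Ok(SplitSizes::Fixed(f)),
        // Zero or several arguments: the same message the bindings use.
        _ => Err("Exactly one of 'ratios', 'counts', or 'fixed' must be provided".to_string()),
    }
}
```

Inside the pyo3 methods, the `Err(String)` would be mapped with `PyValueError::new_err`, and the exhaustive `match` removes the need for the `unreachable!` branch.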
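`modify` and `execute` on `PyAsyncPermutationBuilder` share one `Arc<Mutex<...>>` slot holding an `Option` of the builder: `modify` takes the builder out, transforms it by value, puts it back, and returns a handle to the same slot, while `execute` takes it and leaves `None`. A std-only sketch of that pattern, with hypothetical `Builder` and `Handle` types standing in for `LancePermutationBuilder` and the pyclass:

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for `LancePermutationBuilder`: records applied steps.
#[derive(Default)]
pub struct Builder {
    pub steps: Vec<String>,
}

impl Builder {
    /// By-value transformation, like `with_filter` / `with_shuffle_strategy`.
    pub fn with_step(mut self, step: &str) -> Self {
        self.steps.push(step.to_string());
        self
    }
}

/// Mirrors `PyAsyncPermutationBuilder`: every handle shares one slot.
#[derive(Clone)]
pub struct Handle {
    state: Arc<Mutex<Option<Builder>>>,
}

impl Handle {
    pub fn new() -> Self {
        Handle {
            state: Arc::new(Mutex::new(Some(Builder::default()))),
        }
    }

    /// Same shape as `modify`: take, transform by value, put back.
    pub fn modify(&self, func: impl FnOnce(Builder) -> Builder) -> Result<Handle, String> {
        let mut slot = self.state.lock().unwrap();
        let builder = slot.take().ok_or("Builder already consumed")?;
        *slot = Some(func(builder));
        Ok(self.clone())
    }

    /// Same shape as `execute`: take the builder and leave the slot empty.
    pub fn execute(&self) -> Result<Vec<String>, String> {
        let builder = self
            .state
            .lock()
            .unwrap()
            .take()
            .ok_or("Builder already consumed")?;
        Ok(builder.steps)
    }
}
```

Because every handle points at the same slot, a step applied through one handle is visible through the others, and a second `execute` on any of them fails with "Builder already consumed".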