fix(rust): handle Mutex lock poisoning gracefully across codebase (#3196)

Replace ~30 production `lock().unwrap()` calls that would cascade-panic
on a poisoned Mutex. Functions returning `Result` now propagate the
poison as an error via `?` (leveraging the existing `From<PoisonError>`
impl). Functions without a `Result` return recover via
`unwrap_or_else(|e| e.into_inner())`, which is safe because the guarded
data (counters, caches, RNG state) remains logically valid after a
panic.
This commit is contained in:
lennylxx
2026-03-30 09:25:18 -07:00
committed by GitHub
parent cca6a7c989
commit 74f457a0f2
7 changed files with 309 additions and 29 deletions

View File

@@ -240,7 +240,7 @@ impl Shuffler {
.await?;
// Need to read the entire file in a single batch for in-memory shuffling
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
let mut rng = rng.lock().unwrap();
let mut rng = rng.lock().unwrap_or_else(|e| e.into_inner());
Self::shuffle_batch(&batch, &mut rng, clump_size)
}
})

View File

@@ -66,13 +66,13 @@ impl IoTrackingStore {
}
fn record_read(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.read_iops += 1;
stats.read_bytes += num_bytes;
}
fn record_write(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.write_iops += 1;
stats.write_bytes += num_bytes;
}
@@ -229,10 +229,63 @@ impl MultipartUpload for IoTrackingMultipartUpload {
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
{
let mut stats = self.stats.lock().unwrap();
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.write_iops += 1;
stats.write_bytes += payload.content_length() as u64;
}
self.target.put_part(payload)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Helper: poison a Mutex<IoStats> by panicking while holding the lock.
fn poison_stats(stats: &Arc<Mutex<IoStats>>) {
let stats_clone = stats.clone();
let handle = std::thread::spawn(move || {
let _guard = stats_clone.lock().unwrap();
panic!("intentional panic to poison stats mutex");
});
let _ = handle.join();
assert!(stats.lock().is_err(), "mutex should be poisoned");
}
#[test]
fn test_record_read_recovers_from_poisoned_lock() {
let stats = Arc::new(Mutex::new(IoStats::default()));
let store = IoTrackingStore {
target: Arc::new(object_store::memory::InMemory::new()),
stats: stats.clone(),
};
poison_stats(&stats);
// record_read should not panic
store.record_read(1024);
// Verify the stats were updated despite poisoning
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(s.read_iops, 1);
assert_eq!(s.read_bytes, 1024);
}
#[test]
fn test_record_write_recovers_from_poisoned_lock() {
let stats = Arc::new(Mutex::new(IoStats::default()));
let store = IoTrackingStore {
target: Arc::new(object_store::memory::InMemory::new()),
stats: stats.clone(),
};
poison_stats(&stats);
// record_write should not panic
store.record_write(2048);
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(s.write_iops, 1);
assert_eq!(s.write_bytes, 2048);
}
}

View File

@@ -130,7 +130,10 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
// TODO: this will be used when we wire this up to Table::add().
#[allow(dead_code)]
pub fn add_result(&self) -> Option<AddResult> {
self.add_result.lock().unwrap().clone()
self.add_result
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any

View File

@@ -204,7 +204,9 @@ impl ExecutionPlan for InsertExec {
let to_commit = {
// Don't hold the lock over an await point.
let mut txns = partial_transactions.lock().unwrap();
let mut txns = partial_transactions
.lock()
.unwrap_or_else(|e| e.into_inner());
txns.push(transaction);
if txns.len() == total_partitions {
Some(std::mem::take(&mut *txns))

View File

@@ -82,7 +82,7 @@ impl DatasetConsistencyWrapper {
/// pinned dataset regardless of consistency mode.
pub async fn get(&self) -> Result<Arc<Dataset>> {
{
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
if state.pinned_version.is_some() {
return Ok(state.dataset.clone());
}
@@ -101,7 +101,7 @@ impl DatasetConsistencyWrapper {
}
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
ConsistencyMode::Lazy => {
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
Ok(state.dataset.clone())
}
}
@@ -116,7 +116,7 @@ impl DatasetConsistencyWrapper {
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
/// is silently ignored — the write already committed to storage.
pub fn update(&self, dataset: Dataset) {
let mut state = self.state.lock().unwrap();
let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
if state.pinned_version.is_some() {
// A concurrent as_time_travel() beat us here. The write succeeded
// in storage, but since we're now pinned we don't advance the
@@ -139,7 +139,7 @@ impl DatasetConsistencyWrapper {
/// Check that the dataset is in a mutable mode (Latest).
pub fn ensure_mutable(&self) -> Result<()> {
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
if state.pinned_version.is_some() {
Err(crate::Error::InvalidInput {
message: "table cannot be modified when a specific version is checked out"
@@ -152,13 +152,16 @@ impl DatasetConsistencyWrapper {
/// Returns the version, if in time travel mode, or None otherwise.
pub fn time_travel_version(&self) -> Option<u64> {
self.state.lock().unwrap().pinned_version
self.state
.lock()
.unwrap_or_else(|e| e.into_inner())
.pinned_version
}
/// Convert into a wrapper in latest version mode.
pub async fn as_latest(&self) -> Result<()> {
let dataset = {
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
if state.pinned_version.is_none() {
return Ok(());
}
@@ -168,7 +171,7 @@ impl DatasetConsistencyWrapper {
let latest_version = dataset.latest_version_id().await?;
let new_dataset = dataset.checkout_version(latest_version).await?;
let mut state = self.state.lock().unwrap();
let mut state = self.state.lock()?;
if state.pinned_version.is_some() {
state.dataset = Arc::new(new_dataset);
state.pinned_version = None;
@@ -184,7 +187,7 @@ impl DatasetConsistencyWrapper {
let target_ref = target_version.into();
let (should_checkout, dataset) = {
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
let should = match state.pinned_version {
None => true,
Some(version) => match &target_ref {
@@ -204,7 +207,7 @@ impl DatasetConsistencyWrapper {
let new_dataset = dataset.checkout_version(target_ref).await?;
let version_value = new_dataset.version().version;
let mut state = self.state.lock().unwrap();
let mut state = self.state.lock()?;
state.dataset = Arc::new(new_dataset);
state.pinned_version = Some(version_value);
Ok(())
@@ -212,7 +215,7 @@ impl DatasetConsistencyWrapper {
pub async fn reload(&self) -> Result<()> {
let (dataset, pinned_version) = {
let state = self.state.lock().unwrap();
let state = self.state.lock()?;
(state.dataset.clone(), state.pinned_version)
};
@@ -230,7 +233,7 @@ impl DatasetConsistencyWrapper {
let new_dataset = dataset.checkout_version(version).await?;
let mut state = self.state.lock().unwrap();
let mut state = self.state.lock()?;
if state.pinned_version == Some(version) {
state.dataset = Arc::new(new_dataset);
}
@@ -242,14 +245,14 @@ impl DatasetConsistencyWrapper {
}
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
let dataset = { state.lock().unwrap().dataset.clone() };
let dataset = { state.lock()?.dataset.clone() };
let mut ds = (*dataset).clone();
ds.checkout_latest().await?;
let new_arc = Arc::new(ds);
{
let mut state = state.lock().unwrap();
let mut state = state.lock()?;
if state.pinned_version.is_none()
&& new_arc.manifest().version >= state.dataset.manifest().version
{
@@ -612,4 +615,108 @@ mod tests {
let s = io_stats.incremental_stats();
assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed());
}
/// Helper: poison the mutex inside a DatasetConsistencyWrapper.
fn poison_state(wrapper: &DatasetConsistencyWrapper) {
let state = wrapper.state.clone();
let handle = std::thread::spawn(move || {
let _guard = state.lock().unwrap();
panic!("intentional panic to poison mutex");
});
let _ = handle.join(); // join collects the panic
assert!(wrapper.state.lock().is_err(), "mutex should be poisoned");
}
#[tokio::test]
async fn test_get_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// get() should return Err, not panic
let result = wrapper.get().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_ensure_mutable_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.ensure_mutable();
assert!(result.is_err());
}
#[tokio::test]
async fn test_update_recovers_from_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let ds_v2 = append_to_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// update() returns (), should not panic
wrapper.update(ds_v2);
}
#[tokio::test]
async fn test_time_travel_version_recovers_from_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// Should not panic, returns whatever was in the mutex
let _version = wrapper.time_travel_version();
}
#[tokio::test]
async fn test_as_latest_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.as_latest().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_as_time_travel_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.as_time_travel(1u64).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_reload_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.reload().await;
assert!(result.is_err());
}
}

View File

@@ -130,8 +130,11 @@ impl WriteProgressTracker {
pub fn record_batch(&self, rows: usize, bytes: usize) {
// Lock order: callback first, then rows_and_bytes. This is the only
// order used anywhere, so deadlocks cannot occur.
let mut cb = self.callback.lock().unwrap();
let mut guard = self.rows_and_bytes.lock().unwrap();
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let mut guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
guard.0 += rows;
guard.1 += bytes;
let progress = self.snapshot(guard.0, guard.1, false);
@@ -151,8 +154,11 @@ impl WriteProgressTracker {
/// `total_rows` is always `Some` on the final callback: it uses the known
/// total if available, or falls back to the number of rows actually written.
pub fn finish(&self) {
let mut cb = self.callback.lock().unwrap();
let guard = self.rows_and_bytes.lock().unwrap();
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
let mut snap = self.snapshot(guard.0, guard.1, true);
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
drop(guard);
@@ -376,4 +382,50 @@ mod tests {
}
}
}
#[test]
fn test_record_batch_recovers_from_poisoned_callback_lock() {
use super::{ProgressCallback, WriteProgressTracker};
use std::sync::Mutex;
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
// Poison the callback mutex
let cb_clone = callback.clone();
let handle = std::thread::spawn(move || {
let _guard = cb_clone.lock().unwrap();
panic!("intentional panic to poison callback mutex");
});
let _ = handle.join();
assert!(
callback.lock().is_err(),
"callback mutex should be poisoned"
);
let tracker = WriteProgressTracker::new(callback, Some(100));
// record_batch should not panic
tracker.record_batch(10, 1024);
}
#[test]
fn test_finish_recovers_from_poisoned_callback_lock() {
use super::{ProgressCallback, WriteProgressTracker};
use std::sync::Mutex;
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
// Poison the callback mutex
let cb_clone = callback.clone();
let handle = std::thread::spawn(move || {
let _guard = cb_clone.lock().unwrap();
panic!("intentional panic to poison callback mutex");
});
let _ = handle.join();
let tracker = WriteProgressTracker::new(callback, Some(100));
// finish should not panic
tracker.finish();
}
}

View File

@@ -122,7 +122,7 @@ where
/// This is a cheap synchronous check useful as a fast path before
/// constructing a fetch closure for [`get()`](Self::get).
pub fn try_get(&self) -> Option<V> {
let cache = self.inner.lock().unwrap();
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state.fresh_value(self.ttl, self.refresh_window)
}
@@ -138,7 +138,7 @@ where
{
// Fast path: check if cache is fresh
{
let cache = self.inner.lock().unwrap();
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
return Ok(value);
}
@@ -147,7 +147,7 @@ where
// Slow path
let mut fetch = Some(fetch);
let action = {
let mut cache = self.inner.lock().unwrap();
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
self.determine_action(&mut cache, &mut fetch)
};
@@ -161,7 +161,7 @@ where
///
/// This avoids a blocking fetch on the first [`get()`](Self::get) call.
pub fn seed(&self, value: V) {
let mut cache = self.inner.lock().unwrap();
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state = State::Current(value, clock::now());
}
@@ -170,7 +170,7 @@ where
/// Any in-flight background fetch from before this call will not update the
/// cache (the generation counter prevents stale writes).
pub fn invalidate(&self) {
let mut cache = self.inner.lock().unwrap();
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state = State::Empty;
cache.generation += 1;
}
@@ -267,7 +267,7 @@ where
let fut_for_spawn = shared.clone();
tokio::spawn(async move {
let result = fut_for_spawn.await;
let mut cache = inner.lock().unwrap();
let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner());
// Only update if no invalidation has happened since we started
if cache.generation != generation {
return;
@@ -590,4 +590,67 @@ mod tests {
let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
assert_eq!(v, "fresh");
}
/// Helper: poison the inner mutex of a BackgroundCache.
fn poison_cache(cache: &BackgroundCache<String, TestError>) {
let inner = cache.inner.clone();
let handle = std::thread::spawn(move || {
let _guard = inner.lock().unwrap();
panic!("intentional panic to poison mutex");
});
let _ = handle.join();
assert!(cache.inner.lock().is_err(), "mutex should be poisoned");
}
#[tokio::test]
async fn test_try_get_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
// Seed a value first
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek
poison_cache(&cache);
// try_get() should not panic — it recovers via unwrap_or_else
let result = cache.try_get();
// The value may or may not be fresh depending on timing, but it must not panic
let _ = result;
}
#[tokio::test]
async fn test_get_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
poison_cache(&cache);
// get() should not panic — it recovers and can still fetch
let result = cache.get(ok_fetcher(count.clone(), "recovered")).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "recovered");
}
#[tokio::test]
async fn test_seed_recovers_from_poisoned_lock() {
let cache = new_cache();
poison_cache(&cache);
// seed() should not panic
cache.seed("seeded".to_string());
}
#[tokio::test]
async fn test_invalidate_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
poison_cache(&cache);
// invalidate() should not panic
cache.invalidate();
}
}