From 74f457a0f2732e848b2102734865d563ba4c1b61 Mon Sep 17 00:00:00 2001 From: lennylxx Date: Mon, 30 Mar 2026 09:25:18 -0700 Subject: [PATCH] fix(rust): handle Mutex lock poisoning gracefully across codebase (#3196) Replace ~30 production `lock().unwrap()` calls that would cascade-panic on a poisoned Mutex. Functions returning `Result` now propagate the poison as an error via `?` (leveraging the existing `From<PoisonError<T>>` impl). Functions without a `Result` return recover via `unwrap_or_else(|e| e.into_inner())`, which is safe because the guarded data (counters, caches, RNG state) remains logically valid after a panic. --- .../src/dataloader/permutation/shuffle.rs | 2 +- .../src/io/object_store/io_tracking.rs | 59 +++++++- rust/lancedb/src/remote/table/insert.rs | 5 +- rust/lancedb/src/table/datafusion/insert.rs | 4 +- rust/lancedb/src/table/dataset.rs | 133 ++++++++++++++++-- rust/lancedb/src/table/write_progress.rs | 60 +++++++- rust/lancedb/src/utils/background_cache.rs | 75 +++++++++- 7 files changed, 309 insertions(+), 29 deletions(-) diff --git a/rust/lancedb/src/dataloader/permutation/shuffle.rs b/rust/lancedb/src/dataloader/permutation/shuffle.rs index 05f98eb5a..7cd27e342 100644 --- a/rust/lancedb/src/dataloader/permutation/shuffle.rs +++ b/rust/lancedb/src/dataloader/permutation/shuffle.rs @@ -240,7 +240,7 @@ impl Shuffler { .await?; // Need to read the entire file in a single batch for in-memory shuffling let batch = reader.read_record_batch(0, reader.num_rows()).await?; - let mut rng = rng.lock().unwrap(); + let mut rng = rng.lock().unwrap_or_else(|e| e.into_inner()); Self::shuffle_batch(&batch, &mut rng, clump_size) } }) diff --git a/rust/lancedb/src/io/object_store/io_tracking.rs b/rust/lancedb/src/io/object_store/io_tracking.rs index 882ef51fa..20f0a020a 100644 --- a/rust/lancedb/src/io/object_store/io_tracking.rs +++ b/rust/lancedb/src/io/object_store/io_tracking.rs @@ -66,13 +66,13 @@ impl IoTrackingStore { } fn record_read(&self, num_bytes: u64) { - let mut 
stats = self.stats.lock().unwrap(); + let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner()); stats.read_iops += 1; stats.read_bytes += num_bytes; } fn record_write(&self, num_bytes: u64) { - let mut stats = self.stats.lock().unwrap(); + let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner()); stats.write_iops += 1; stats.write_bytes += num_bytes; } @@ -229,10 +229,63 @@ impl MultipartUpload for IoTrackingMultipartUpload { fn put_part(&mut self, payload: PutPayload) -> UploadPart { { - let mut stats = self.stats.lock().unwrap(); + let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner()); stats.write_iops += 1; stats.write_bytes += payload.content_length() as u64; } self.target.put_part(payload) } } + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: poison a Mutex by panicking while holding the lock. + fn poison_stats(stats: &Arc>) { + let stats_clone = stats.clone(); + let handle = std::thread::spawn(move || { + let _guard = stats_clone.lock().unwrap(); + panic!("intentional panic to poison stats mutex"); + }); + let _ = handle.join(); + assert!(stats.lock().is_err(), "mutex should be poisoned"); + } + + #[test] + fn test_record_read_recovers_from_poisoned_lock() { + let stats = Arc::new(Mutex::new(IoStats::default())); + let store = IoTrackingStore { + target: Arc::new(object_store::memory::InMemory::new()), + stats: stats.clone(), + }; + + poison_stats(&stats); + + // record_read should not panic + store.record_read(1024); + + // Verify the stats were updated despite poisoning + let s = stats.lock().unwrap_or_else(|e| e.into_inner()); + assert_eq!(s.read_iops, 1); + assert_eq!(s.read_bytes, 1024); + } + + #[test] + fn test_record_write_recovers_from_poisoned_lock() { + let stats = Arc::new(Mutex::new(IoStats::default())); + let store = IoTrackingStore { + target: Arc::new(object_store::memory::InMemory::new()), + stats: stats.clone(), + }; + + poison_stats(&stats); + + // record_write should not panic + 
store.record_write(2048); + + let s = stats.lock().unwrap_or_else(|e| e.into_inner()); + assert_eq!(s.write_iops, 1); + assert_eq!(s.write_bytes, 2048); + } +} diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs index d7d30a680..8aec28609 100644 --- a/rust/lancedb/src/remote/table/insert.rs +++ b/rust/lancedb/src/remote/table/insert.rs @@ -130,7 +130,10 @@ impl RemoteInsertExec { // TODO: this will be used when we wire this up to Table::add(). #[allow(dead_code)] pub fn add_result(&self) -> Option { - self.add_result.lock().unwrap().clone() + self.add_result + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone() } /// Stream the input into an HTTP body as an Arrow IPC stream, capturing any diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs index 4dce78788..51be4abb8 100644 --- a/rust/lancedb/src/table/datafusion/insert.rs +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -204,7 +204,9 @@ impl ExecutionPlan for InsertExec { let to_commit = { // Don't hold the lock over an await point. - let mut txns = partial_transactions.lock().unwrap(); + let mut txns = partial_transactions + .lock() + .unwrap_or_else(|e| e.into_inner()); txns.push(transaction); if txns.len() == total_partitions { Some(std::mem::take(&mut *txns)) diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs index 89fcf55dd..54c4ba691 100644 --- a/rust/lancedb/src/table/dataset.rs +++ b/rust/lancedb/src/table/dataset.rs @@ -82,7 +82,7 @@ impl DatasetConsistencyWrapper { /// pinned dataset regardless of consistency mode. 
pub async fn get(&self) -> Result> { { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; if state.pinned_version.is_some() { return Ok(state.dataset.clone()); } @@ -101,7 +101,7 @@ impl DatasetConsistencyWrapper { } ConsistencyMode::Strong => refresh_latest(self.state.clone()).await, ConsistencyMode::Lazy => { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; Ok(state.dataset.clone()) } } @@ -116,7 +116,7 @@ impl DatasetConsistencyWrapper { /// concurrent [`as_time_travel`](Self::as_time_travel) call), the update /// is silently ignored — the write already committed to storage. pub fn update(&self, dataset: Dataset) { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner()); if state.pinned_version.is_some() { // A concurrent as_time_travel() beat us here. The write succeeded // in storage, but since we're now pinned we don't advance the @@ -139,7 +139,7 @@ impl DatasetConsistencyWrapper { /// Check that the dataset is in a mutable mode (Latest). pub fn ensure_mutable(&self) -> Result<()> { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; if state.pinned_version.is_some() { Err(crate::Error::InvalidInput { message: "table cannot be modified when a specific version is checked out" @@ -152,13 +152,16 @@ impl DatasetConsistencyWrapper { /// Returns the version, if in time travel mode, or None otherwise. pub fn time_travel_version(&self) -> Option { - self.state.lock().unwrap().pinned_version + self.state + .lock() + .unwrap_or_else(|e| e.into_inner()) + .pinned_version } /// Convert into a wrapper in latest version mode. 
pub async fn as_latest(&self) -> Result<()> { let dataset = { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; if state.pinned_version.is_none() { return Ok(()); } @@ -168,7 +171,7 @@ impl DatasetConsistencyWrapper { let latest_version = dataset.latest_version_id().await?; let new_dataset = dataset.checkout_version(latest_version).await?; - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock()?; if state.pinned_version.is_some() { state.dataset = Arc::new(new_dataset); state.pinned_version = None; @@ -184,7 +187,7 @@ impl DatasetConsistencyWrapper { let target_ref = target_version.into(); let (should_checkout, dataset) = { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; let should = match state.pinned_version { None => true, Some(version) => match &target_ref { @@ -204,7 +207,7 @@ impl DatasetConsistencyWrapper { let new_dataset = dataset.checkout_version(target_ref).await?; let version_value = new_dataset.version().version; - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock()?; state.dataset = Arc::new(new_dataset); state.pinned_version = Some(version_value); Ok(()) @@ -212,7 +215,7 @@ impl DatasetConsistencyWrapper { pub async fn reload(&self) -> Result<()> { let (dataset, pinned_version) = { - let state = self.state.lock().unwrap(); + let state = self.state.lock()?; (state.dataset.clone(), state.pinned_version) }; @@ -230,7 +233,7 @@ impl DatasetConsistencyWrapper { let new_dataset = dataset.checkout_version(version).await?; - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock()?; if state.pinned_version == Some(version) { state.dataset = Arc::new(new_dataset); } @@ -242,14 +245,14 @@ impl DatasetConsistencyWrapper { } async fn refresh_latest(state: Arc>) -> Result> { - let dataset = { state.lock().unwrap().dataset.clone() }; + let dataset = { state.lock()?.dataset.clone() }; let mut ds = (*dataset).clone(); 
ds.checkout_latest().await?; let new_arc = Arc::new(ds); { - let mut state = state.lock().unwrap(); + let mut state = state.lock()?; if state.pinned_version.is_none() && new_arc.manifest().version >= state.dataset.manifest().version { @@ -612,4 +615,108 @@ mod tests { let s = io_stats.incremental_stats(); assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed()); } + + /// Helper: poison the mutex inside a DatasetConsistencyWrapper. + fn poison_state(wrapper: &DatasetConsistencyWrapper) { + let state = wrapper.state.clone(); + let handle = std::thread::spawn(move || { + let _guard = state.lock().unwrap(); + panic!("intentional panic to poison mutex"); + }); + let _ = handle.join(); // join collects the panic + assert!(wrapper.state.lock().is_err(), "mutex should be poisoned"); + } + + #[tokio::test] + async fn test_get_returns_error_on_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + // get() should return Err, not panic + let result = wrapper.get().await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_ensure_mutable_returns_error_on_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + let result = wrapper.ensure_mutable(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_update_recovers_from_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + let ds_v2 = append_to_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + // update() returns (), should not panic + wrapper.update(ds_v2); + } + 
+ #[tokio::test] + async fn test_time_travel_version_recovers_from_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + // Should not panic, returns whatever was in the mutex + let _version = wrapper.time_travel_version(); + } + + #[tokio::test] + async fn test_as_latest_returns_error_on_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + let result = wrapper.as_latest().await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_as_time_travel_returns_error_on_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + let result = wrapper.as_time_travel(1u64).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_reload_returns_error_on_poisoned_lock() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + poison_state(&wrapper); + + let result = wrapper.reload().await; + assert!(result.is_err()); + } } diff --git a/rust/lancedb/src/table/write_progress.rs b/rust/lancedb/src/table/write_progress.rs index bf1b513ae..7a5c30008 100644 --- a/rust/lancedb/src/table/write_progress.rs +++ b/rust/lancedb/src/table/write_progress.rs @@ -130,8 +130,11 @@ impl WriteProgressTracker { pub fn record_batch(&self, rows: usize, bytes: usize) { // Lock order: callback first, then rows_and_bytes. 
This is the only // order used anywhere, so deadlocks cannot occur. - let mut cb = self.callback.lock().unwrap(); - let mut guard = self.rows_and_bytes.lock().unwrap(); + let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner()); + let mut guard = self + .rows_and_bytes + .lock() + .unwrap_or_else(|e| e.into_inner()); guard.0 += rows; guard.1 += bytes; let progress = self.snapshot(guard.0, guard.1, false); @@ -151,8 +154,11 @@ impl WriteProgressTracker { /// `total_rows` is always `Some` on the final callback: it uses the known /// total if available, or falls back to the number of rows actually written. pub fn finish(&self) { - let mut cb = self.callback.lock().unwrap(); - let guard = self.rows_and_bytes.lock().unwrap(); + let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner()); + let guard = self + .rows_and_bytes + .lock() + .unwrap_or_else(|e| e.into_inner()); let mut snap = self.snapshot(guard.0, guard.1, true); snap.total_rows = Some(self.total_rows.unwrap_or(guard.0)); drop(guard); @@ -376,4 +382,50 @@ mod tests { } } } + + #[test] + fn test_record_batch_recovers_from_poisoned_callback_lock() { + use super::{ProgressCallback, WriteProgressTracker}; + use std::sync::Mutex; + + let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {})); + + // Poison the callback mutex + let cb_clone = callback.clone(); + let handle = std::thread::spawn(move || { + let _guard = cb_clone.lock().unwrap(); + panic!("intentional panic to poison callback mutex"); + }); + let _ = handle.join(); + assert!( + callback.lock().is_err(), + "callback mutex should be poisoned" + ); + + let tracker = WriteProgressTracker::new(callback, Some(100)); + + // record_batch should not panic + tracker.record_batch(10, 1024); + } + + #[test] + fn test_finish_recovers_from_poisoned_callback_lock() { + use super::{ProgressCallback, WriteProgressTracker}; + use std::sync::Mutex; + + let callback: ProgressCallback = Arc::new(Mutex::new(|_: 
&super::WriteProgress| {})); + + // Poison the callback mutex + let cb_clone = callback.clone(); + let handle = std::thread::spawn(move || { + let _guard = cb_clone.lock().unwrap(); + panic!("intentional panic to poison callback mutex"); + }); + let _ = handle.join(); + + let tracker = WriteProgressTracker::new(callback, Some(100)); + + // finish should not panic + tracker.finish(); + } } diff --git a/rust/lancedb/src/utils/background_cache.rs b/rust/lancedb/src/utils/background_cache.rs index 211630556..851f495f4 100644 --- a/rust/lancedb/src/utils/background_cache.rs +++ b/rust/lancedb/src/utils/background_cache.rs @@ -122,7 +122,7 @@ where /// This is a cheap synchronous check useful as a fast path before /// constructing a fetch closure for [`get()`](Self::get). pub fn try_get(&self) -> Option { - let cache = self.inner.lock().unwrap(); + let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner()); cache.state.fresh_value(self.ttl, self.refresh_window) } @@ -138,7 +138,7 @@ where { // Fast path: check if cache is fresh { - let cache = self.inner.lock().unwrap(); + let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) { return Ok(value); } @@ -147,7 +147,7 @@ where // Slow path let mut fetch = Some(fetch); let action = { - let mut cache = self.inner.lock().unwrap(); + let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner()); self.determine_action(&mut cache, &mut fetch) }; @@ -161,7 +161,7 @@ where /// /// This avoids a blocking fetch on the first [`get()`](Self::get) call. pub fn seed(&self, value: V) { - let mut cache = self.inner.lock().unwrap(); + let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner()); cache.state = State::Current(value, clock::now()); } @@ -170,7 +170,7 @@ where /// Any in-flight background fetch from before this call will not update the /// cache (the generation counter prevents stale writes). 
pub fn invalidate(&self) { - let mut cache = self.inner.lock().unwrap(); + let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner()); cache.state = State::Empty; cache.generation += 1; } @@ -267,7 +267,7 @@ where let fut_for_spawn = shared.clone(); tokio::spawn(async move { let result = fut_for_spawn.await; - let mut cache = inner.lock().unwrap(); + let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner()); // Only update if no invalidation has happened since we started if cache.generation != generation { return; @@ -590,4 +590,67 @@ mod tests { let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap(); assert_eq!(v, "fresh"); } + + /// Helper: poison the inner mutex of a BackgroundCache. + fn poison_cache(cache: &BackgroundCache) { + let inner = cache.inner.clone(); + let handle = std::thread::spawn(move || { + let _guard = inner.lock().unwrap(); + panic!("intentional panic to poison mutex"); + }); + let _ = handle.join(); + assert!(cache.inner.lock().is_err(), "mutex should be poisoned"); + } + + #[tokio::test] + async fn test_try_get_recovers_from_poisoned_lock() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + // Seed a value first + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek + + poison_cache(&cache); + + // try_get() should not panic — it recovers via unwrap_or_else + let result = cache.try_get(); + // The value may or may not be fresh depending on timing, but it must not panic + let _ = result; + } + + #[tokio::test] + async fn test_get_recovers_from_poisoned_lock() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + poison_cache(&cache); + + // get() should not panic — it recovers and can still fetch + let result = cache.get(ok_fetcher(count.clone(), "recovered")).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "recovered"); + } + + #[tokio::test] + async fn 
test_seed_recovers_from_poisoned_lock() { + let cache = new_cache(); + poison_cache(&cache); + + // seed() should not panic + cache.seed("seeded".to_string()); + } + + #[tokio::test] + async fn test_invalidate_recovers_from_poisoned_lock() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + + poison_cache(&cache); + + // invalidate() should not panic + cache.invalidate(); + } }