diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 6f43476d67..bbc12ae876 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -377,35 +377,15 @@ impl Instance { ) -> server_error::Result<bool> { let db_string = ctx.get_db_string(); // fast cache check - { - let cache = self - .otlp_metrics_table_legacy_cache - .entry(db_string.clone()) - .or_default(); - - let hit_cache = names - .iter() - .filter_map(|name| cache.get(*name)) - .collect::<Vec<_>>(); - if !hit_cache.is_empty() { - let hit_legacy = hit_cache.iter().any(|en| *en.value()); - let hit_prom = hit_cache.iter().any(|en| !*en.value()); - - // hit but have true and false, means both legacy and new mode are used - // we cannot handle this case, so return error - // add doc links in err msg later - ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu); - - let flag = hit_legacy; - // set cache for all names - names.iter().for_each(|name| { - if !cache.contains_key(*name) { - cache.insert(name.to_string(), flag); - } - }); - return Ok(flag); - } + let cache = self + .otlp_metrics_table_legacy_cache + .entry(db_string.clone()) + .or_default(); + if let Some(flag) = fast_legacy_check(&cache, names)? 
{ + return Ok(flag); } + // release cache reference to avoid lock contention + drop(cache); let catalog = ctx.current_catalog(); let schema = ctx.current_schema(); @@ -486,6 +466,39 @@ impl Instance { } } +fn fast_legacy_check( + cache: &DashMap<String, bool>, + names: &[&String], +) -> server_error::Result<Option<bool>> { + let hit_cache = names + .iter() + .filter_map(|name| cache.get(*name)) + .collect::<Vec<_>>(); + if !hit_cache.is_empty() { + let hit_legacy = hit_cache.iter().any(|en| *en.value()); + let hit_prom = hit_cache.iter().any(|en| !*en.value()); + + // hit but have true and false, means both legacy and new mode are used + // we cannot handle this case, so return error + // add doc links in err msg later + ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu); + + let flag = hit_legacy; + // drop hit_cache to release references before inserting to avoid deadlock + drop(hit_cache); + + // set cache for all names + names.iter().for_each(|name| { + if !cache.contains_key(*name) { + cache.insert(name.to_string(), flag); + } + }); + Ok(Some(flag)) + } else { + Ok(None) + } +} + /// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements. /// For MySQL, it applies only to read-only statements. 
fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Duration> { @@ -1050,6 +1063,10 @@ fn should_capture_statement(stmt: Option<&Statement>) -> bool { #[cfg(test)] mod tests { use std::collections::HashMap; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Barrier}; + use std::thread; + use std::time::{Duration, Instant}; use common_base::Plugins; use query::query_engine::options::QueryOptions; @@ -1059,6 +1076,122 @@ mod tests { use super::*; + #[test] + fn test_fast_legacy_check_deadlock_prevention() { + // Create a DashMap to simulate the cache + let cache = DashMap::new(); + + // Pre-populate cache with some entries + cache.insert("metric1".to_string(), true); // legacy mode + cache.insert("metric2".to_string(), false); // prom mode + cache.insert("metric3".to_string(), true); // legacy mode + + // Test case 1: Normal operation with cache hits + let metric1 = "metric1".to_string(); + let metric4 = "metric4".to_string(); + let names1 = vec![&metric1, &metric4]; + let result = fast_legacy_check(&cache, &names1); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(true)); // should return legacy mode + + // Verify that metric4 was added to cache + assert!(cache.contains_key("metric4")); + assert!(*cache.get("metric4").unwrap().value()); + + // Test case 2: No cache hits + let metric5 = "metric5".to_string(); + let metric6 = "metric6".to_string(); + let names2 = vec![&metric5, &metric6]; + let result = fast_legacy_check(&cache, &names2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), None); // should return None as no cache hits + + // Test case 3: Incompatible modes should return error + let cache_incompatible = DashMap::new(); + cache_incompatible.insert("metric1".to_string(), true); // legacy + cache_incompatible.insert("metric2".to_string(), false); // prom + let metric1_test = "metric1".to_string(); + let metric2_test = "metric2".to_string(); + let names3 = vec![&metric1_test, &metric2_test]; + let result 
= fast_legacy_check(&cache_incompatible, &names3); + assert!(result.is_err()); // should error due to incompatible modes + + // Test case 4: Intensive concurrent access to test deadlock prevention + // This test specifically targets the scenario where multiple threads + // access the same cache entries simultaneously + let cache_concurrent = Arc::new(DashMap::new()); + cache_concurrent.insert("shared_metric".to_string(), true); + + let num_threads = 8; + let operations_per_thread = 100; + let barrier = Arc::new(Barrier::new(num_threads)); + let success_flag = Arc::new(AtomicBool::new(true)); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let cache_clone = Arc::clone(&cache_concurrent); + let barrier_clone = Arc::clone(&barrier); + let success_flag_clone = Arc::clone(&success_flag); + + thread::spawn(move || { + // Wait for all threads to be ready + barrier_clone.wait(); + + let start_time = Instant::now(); + for i in 0..operations_per_thread { + // Each operation references existing cache entry and adds new ones + let shared_metric = "shared_metric".to_string(); + let new_metric = format!("thread_{}_metric_{}", thread_id, i); + let names = vec![&shared_metric, &new_metric]; + + match fast_legacy_check(&cache_clone, &names) { + Ok(_) => {} + Err(_) => { + success_flag_clone.store(false, Ordering::Relaxed); + return; + } + } + + // If the test takes too long, it likely means deadlock + if start_time.elapsed() > Duration::from_secs(10) { + success_flag_clone.store(false, Ordering::Relaxed); + return; + } + } + }) + }) + .collect(); + + // Join all threads with timeout + let start_time = Instant::now(); + for (i, handle) in handles.into_iter().enumerate() { + let join_result = handle.join(); + + // Check if we're taking too long (potential deadlock) + if start_time.elapsed() > Duration::from_secs(30) { + panic!("Test timed out - possible deadlock detected!"); + } + + if join_result.is_err() { + panic!("Thread {} panicked during execution", i); + } + 
} + + // Verify all operations completed successfully + assert!( + success_flag.load(Ordering::Relaxed), + "Some operations failed" + ); + + // Verify that many new entries were added (proving operations completed) + let final_count = cache_concurrent.len(); + assert!( + final_count > 1 + num_threads * operations_per_thread / 2, + "Expected more cache entries, got {}", + final_count + ); + } + #[test] fn test_exec_validation() { let query_ctx = QueryContext::arc();