From 78ad89b4d5bac25f709fdde4091667536e33b484 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 11 Jul 2025 16:51:40 +0000 Subject: [PATCH] some abstraction for notifiers --- libs/neon_failpoint/src/lib.rs | 202 +++++++++++++++++++++------------ 1 file changed, 128 insertions(+), 74 deletions(-) diff --git a/libs/neon_failpoint/src/lib.rs b/libs/neon_failpoint/src/lib.rs index ac5538c448..ee793129cc 100644 --- a/libs/neon_failpoint/src/lib.rs +++ b/libs/neon_failpoint/src/lib.rs @@ -30,14 +30,14 @@ static FAILPOINTS: Lazy>> Lazy::new(|| Default::default()); /// Configuration for a single failpoint -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FailpointConfig { /// The action specification including probability pub action_spec: FailpointActionSpec, /// Optional context matching rules pub context_matchers: Option>, /// Notify objects for tasks waiting on this failpoint - pub notifiers: Vec>, + pub notifiers: FailpointNotifiers, /// Counter for probability-based actions pub trigger_count: u32, } @@ -84,6 +84,86 @@ pub enum FailpointResult { Cancelled, } +/// Collection of notifiers for a failpoint +/// +/// This abstraction manages the lifecycle of notification objects +/// and provides a clean interface for creating notifiers and broadcasting notifications. +#[derive(Debug, Default)] +pub struct FailpointNotifiers { + notifiers: Vec>, +} + +impl FailpointNotifiers { + /// Create a new empty collection of notifiers + pub fn new() -> Self { + Self { + notifiers: Vec::new(), + } + } + + /// Create a new notifier and add it to the collection + /// + /// Returns a `FailpointNotifier` that automatically removes itself + /// from the collection when dropped. + pub fn create_notifier(&mut self, cleanup_callback: F) -> FailpointNotifier + where + F: FnOnce(&Arc) + Send + 'static, + { + let notifier = Arc::new(Notify::new()); + self.notifiers.push(notifier.clone()); + + FailpointNotifier::new(notifier, cleanup_callback) + } + + /// Notify all waiting tasks + pub fn notify_all(&self) { + for notifier in &self.notifiers { + notifier.notify_waiters(); + } + } + + /// Remove a specific notifier from the collection + pub fn remove_notifier(&mut self, notifier: &Arc) { + self.notifiers.retain(|n| !Arc::ptr_eq(n, notifier)); + } +} + +/// Abstraction for managing failpoint notifications +/// +/// This handles the lifecycle of a notifier for a failpoint: +/// - Provides a future that can be awaited to receive notifications +/// - Automatically cleans up when dropped using a provided callback +pub struct FailpointNotifier { + notifier: Arc, + cleanup: Option) + Send>>, +} + +impl FailpointNotifier { + /// Create a new notifier with a cleanup callback + pub fn new(notifier: Arc, cleanup_callback: F) -> Self + where + F: FnOnce(&Arc) + Send + 'static, + { + Self { + notifier, + cleanup: Some(Box::new(cleanup_callback)), + } + } + + /// Get a future that will be notified when the failpoint configuration changes + pub fn notified(&self) -> impl Future + '_ { + self.notifier.notified() + } +} + +impl Drop for FailpointNotifier { + fn drop(&mut self) { + if let Some(cleanup) = self.cleanup.take() { + cleanup(&self.notifier); + } + } +} + /// Initialize failpoints from environment variable pub fn init() -> Result<()> { if let Ok(env_failpoints) = std::env::var("FAILPOINTS") { @@ -102,7 +182,7 @@ pub fn configure_failpoint(name: &str, actions: &str) -> Result<()> { let config = FailpointConfig { action_spec, context_matchers: None, - notifiers: Vec::new(), + notifiers: FailpointNotifiers::new(), trigger_count: 0, }; @@ -111,9 +191,7 @@ pub fn configure_failpoint(name: &str, actions: &str) -> Result<()> { // If this failpoint already exists, notify all waiting tasks if let Some(existing_config) = failpoints.get(name) { // Notify all waiting tasks about the configuration change - for notifier in &existing_config.notifiers { - notifier.notify_waiters(); - } + existing_config.notifiers.notify_all(); } failpoints.insert(name.to_string(), config); @@ -132,7 +210,7 @@ pub fn configure_failpoint_with_context( let config = FailpointConfig { action_spec, context_matchers: Some(context_matchers), - notifiers: Vec::new(), + notifiers: FailpointNotifiers::new(), trigger_count: 0, }; @@ -141,9 +219,7 @@ pub fn configure_failpoint_with_context( // If this failpoint already exists, notify all waiting tasks if let Some(existing_config) = failpoints.get(name) { // Notify all waiting tasks about the configuration change - for notifier in &existing_config.notifiers { - notifier.notify_waiters(); - } + existing_config.notifiers.notify_all(); } failpoints.insert(name.to_string(), config); @@ -158,9 +234,7 @@ pub fn remove_failpoint(name: &str) { // Notify all waiting tasks before removing if let Some(existing_config) = failpoints.get(name) { - for notifier in &existing_config.notifiers { - notifier.notify_waiters(); - } + existing_config.notifiers.notify_all(); } failpoints.remove(name); @@ -201,27 +275,36 @@ pub fn failpoint_with_cancellation( return Either::Left(FailpointResult::Continue); } - let config = { + // Check if the failpoint exists and get the necessary info + let (action_spec, context_matchers) = { let failpoints = FAILPOINTS.read().unwrap(); - failpoints.get(name).cloned() - }; - - let Some(config) = config else { - return Either::Left(FailpointResult::Continue); + let Some(config) = failpoints.get(name) else { + return Either::Left(FailpointResult::Continue); + }; + (config.action_spec.clone(), config.context_matchers.clone()) }; // Check context matchers if provided - if let (Some(matchers), Some(ctx)) = (&config.context_matchers, context) { + if let (Some(matchers), Some(ctx)) = (&context_matchers, context) { if !matches_context(matchers, ctx) { return Either::Left(FailpointResult::Continue); } } // Check probability and max_count - if let Some(probability) = config.action_spec.probability { + if let Some(probability) = action_spec.probability { // Check if we've hit the max count - if let Some(max_count) = config.action_spec.max_count { - if config.trigger_count >= max_count { + if let Some(max_count) = action_spec.max_count { + // Get the current trigger count + let trigger_count = { + let failpoints = FAILPOINTS.read().unwrap(); + failpoints + .get(name) + .map(|config| config.trigger_count) + .unwrap_or(0) + }; + + if trigger_count >= max_count { return Either::Left(FailpointResult::Continue); } } @@ -235,8 +318,7 @@ pub fn failpoint_with_cancellation( // Increment trigger count { - let mut failpoints: std::sync::RwLockWriteGuard<'_, HashMap> = - FAILPOINTS.write().unwrap(); + let mut failpoints = FAILPOINTS.write().unwrap(); if let Some(fp_config) = failpoints.get_mut(name) { fp_config.trigger_count += 1; } @@ -245,7 +327,25 @@ pub fn failpoint_with_cancellation( tracing::info!("Hit failpoint: {}", name); - execute_action(name, &config.action_spec, context, cancel_token) + execute_action(name, &action_spec, context, cancel_token) +} + +/// Create a notifier for a failpoint +fn create_failpoint_notifier(name: &str) -> FailpointNotifier { + let mut failpoints = FAILPOINTS.write().unwrap(); + if let Some(fp_config) = failpoints.get_mut(name) { + let cleanup_name = name.to_string(); + fp_config.notifiers.create_notifier(move |notifier| { + let mut failpoints = FAILPOINTS.write().unwrap(); + if let Some(fp_config) = failpoints.get_mut(&cleanup_name) { + fp_config.notifiers.remove_notifier(notifier); + } + }) + } else { + // Failpoint doesn't exist, create a dummy notifier + let notifier = Arc::new(Notify::new()); + FailpointNotifier::new(notifier, |_| {}) + } } /// Execute a specific action (used for recursive execution in probability-based actions) @@ -278,21 +378,7 @@ fn execute_action( tracing::info!("Failpoint {} sleeping for {}ms", name, millis); // Create a notifier for this task - let notifier = Arc::new(Notify::new()); - - // Add the notifier to the failpoint configuration - { - let mut failpoints = FAILPOINTS.write().unwrap(); - if let Some(fp_config) = failpoints.get_mut(&name) { - fp_config.notifiers.push(notifier.clone()); - } - } - - // Create cleanup guard to remove notifier when done - let _guard = NotifierCleanupGuard { - failpoint_name: name.clone(), - notifier: notifier.clone(), - }; + let notifier = create_failpoint_notifier(&name); // Sleep with cancellation support tokio::select! { @@ -319,21 +405,7 @@ fn execute_action( tracing::info!("Failpoint {} pausing", name); // Create a notifier for this task - let notifier = Arc::new(Notify::new()); - - // Add the notifier to the failpoint configuration - { - let mut failpoints = FAILPOINTS.write().unwrap(); - if let Some(fp_config) = failpoints.get_mut(&name) { - fp_config.notifiers.push(notifier.clone()); - } - } - - // Create cleanup guard to remove notifier when done - let _guard = NotifierCleanupGuard { - failpoint_name: name.clone(), - notifier: notifier.clone(), - }; + let notifier = create_failpoint_notifier(&name); // Wait until cancelled or notified tokio::select! { @@ -442,24 +514,6 @@ fn matches_context(matchers: &HashMap, context: &FailpointContex true } -/// RAII guard that removes a notifier from a failpoint when dropped -struct NotifierCleanupGuard { - failpoint_name: String, - notifier: Arc, -} - -impl Drop for NotifierCleanupGuard { - fn drop(&mut self) { - let mut failpoints = FAILPOINTS.write().unwrap(); - if let Some(fp_config) = failpoints.get_mut(&self.failpoint_name) { - // Remove this specific notifier from the list - fp_config - .notifiers - .retain(|n| !Arc::ptr_eq(n, &self.notifier)); - } - } -} - #[cfg(test)] mod tests { use super::*;