diff --git a/src/executor.rs b/src/executor.rs index 80774e5..92830a8 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -84,7 +84,7 @@ pub trait Executor: Debug + Send + Sync + 'static + private::Sealed { #[cfg(feature = "smtp-transport")] #[async_trait] pub(crate) trait SpawnHandle: Debug + Send + Sync + 'static + private::Sealed { - async fn shutdown(self); + async fn shutdown(&self); } /// Async [`Executor`] using `tokio` `1.x` @@ -178,7 +178,7 @@ impl Executor for Tokio1Executor { #[cfg(all(feature = "smtp-transport", feature = "tokio1"))] #[async_trait] impl SpawnHandle for tokio1_crate::task::JoinHandle<()> { - async fn shutdown(self) { + async fn shutdown(&self) { self.abort(); } } @@ -202,7 +202,7 @@ pub struct AsyncStd1Executor; #[cfg(feature = "async-std1")] impl Executor for AsyncStd1Executor { #[cfg(feature = "smtp-transport")] - type Handle = async_std::task::JoinHandle<()>; + type Handle = futures_util::future::AbortHandle; #[cfg(feature = "smtp-transport")] type Sleep = BoxFuture<'static, ()>; @@ -212,7 +212,9 @@ impl Executor for AsyncStd1Executor { F: Future + Send + 'static, F::Output: Send + 'static, { - async_std::task::spawn(fut) + let (handle, registration) = futures_util::future::AbortHandle::new_pair(); + async_std::task::spawn(futures_util::future::Abortable::new(fut, registration)); + handle } #[cfg(feature = "smtp-transport")] @@ -273,9 +275,9 @@ impl Executor for AsyncStd1Executor { #[cfg(all(feature = "smtp-transport", feature = "async-std1"))] #[async_trait] -impl SpawnHandle for async_std::task::JoinHandle<()> { - async fn shutdown(self) { - self.cancel().await; +impl SpawnHandle for futures_util::future::AbortHandle { + async fn shutdown(&self) { + self.abort(); } } @@ -292,5 +294,5 @@ mod private { impl Sealed for tokio1_crate::task::JoinHandle<()> {} #[cfg(all(feature = "smtp-transport", feature = "async-std1"))] - impl Sealed for async_std::task::JoinHandle<()> {} + impl Sealed for futures_util::future::AbortHandle {} } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 13b0267..1bc4f5d 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -140,6 +140,10 @@ pub trait Transport { } fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result; + + /// Shuts down the transport. Future calls to [`send`] and [`send_raw`] might + /// fail. + fn shutdown(&self) {} } /// Async Transport method for emails @@ -166,4 +170,8 @@ pub trait AsyncTransport { } async fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result; + + /// Shuts down the transport. Future calls to [`send`] and [`send_raw`] might + /// fail. + async fn shutdown(&self) {} } diff --git a/src/transport/smtp/async_transport.rs b/src/transport/smtp/async_transport.rs index b0f3be5..e6cfc39 100644 --- a/src/transport/smtp/async_transport.rs +++ b/src/transport/smtp/async_transport.rs @@ -79,6 +79,11 @@ impl AsyncTransport for AsyncSmtpTransport { Ok(result) } + + async fn shutdown(&self) { + #[cfg(feature = "pool")] + self.inner.shutdown().await; + } } #[cfg(feature = "async-std1")] @@ -97,6 +102,11 @@ impl AsyncTransport for AsyncSmtpTransport { Ok(result) } + + async fn shutdown(&self) { + #[cfg(feature = "pool")] + self.inner.shutdown().await; + } } impl AsyncSmtpTransport diff --git a/src/transport/smtp/error.rs b/src/transport/smtp/error.rs index 4131fb0..5f8f352 100644 --- a/src/transport/smtp/error.rs +++ b/src/transport/smtp/error.rs @@ -77,6 +77,11 @@ impl Error { matches!(self.inner.kind, Kind::Tls) } + /// Returns true if the error is because the transport was shut down + pub fn is_transport_shutdown(&self) -> bool { + matches!(self.inner.kind, Kind::TransportShutdown) + } + /// Returns the status code, if the error was generated from a response. pub fn status(&self) -> Option { match self.inner.kind { @@ -111,6 +116,8 @@ pub(crate) enum Kind { )] #[cfg(any(feature = "native-tls", feature = "rustls", feature = "boring-tls"))] Tls, + /// Transport shutdown error + TransportShutdown, } impl fmt::Debug for Error { @@ -136,6 +143,7 @@ impl fmt::Display for Error { Kind::Connection => f.write_str("Connection error")?, #[cfg(any(feature = "native-tls", feature = "rustls", feature = "boring-tls"))] Kind::Tls => f.write_str("tls error")?, + Kind::TransportShutdown => f.write_str("transport has been shut down")?, Kind::Transient(code) => { write!(f, "transient error ({code})")?; } @@ -189,3 +197,7 @@ pub(crate) fn connection>(e: E) -> Error { pub(crate) fn tls>(e: E) -> Error { Error::new(Kind::Tls, Some(e)) } + +pub(crate) fn transport_shutdown() -> Error { + Error::new::(Kind::TransportShutdown, None) +} diff --git a/src/transport/smtp/pool/async_impl.rs b/src/transport/smtp/pool/async_impl.rs index d6d41c2..7f17be8 100644 --- a/src/transport/smtp/pool/async_impl.rs +++ b/src/transport/smtp/pool/async_impl.rs @@ -1,6 +1,5 @@ use std::{ fmt::{self, Debug}, - mem, ops::{Deref, DerefMut}, sync::{Arc, OnceLock}, time::{Duration, Instant}, @@ -15,11 +14,15 @@ use super::{ super::{client::AsyncSmtpConnection, Error}, PoolConfig, }; -use crate::{executor::SpawnHandle, transport::smtp::async_transport::AsyncSmtpClient, Executor}; +use crate::{ + executor::SpawnHandle, + transport::smtp::{async_transport::AsyncSmtpClient, error}, + Executor, +}; pub(crate) struct Pool { config: PoolConfig, - connections: Mutex>, + connections: Mutex>>, client: AsyncSmtpClient, handle: OnceLock, } @@ -38,7 +41,7 @@ impl Pool { pub(crate) fn new(config: PoolConfig, client: AsyncSmtpClient) -> Arc { let pool = Arc::new(Self { config, - connections: Mutex::new(Vec::new()), + connections: Mutex::new(Some(Vec::new())), client, handle: OnceLock::new(), }); @@ -60,6 +63,10 @@ impl Pool { #[allow(clippy::needless_collect)] let (count, dropped) = { let mut connections = pool.connections.lock().await; + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return; + }; let to_drop = connections .iter() @@ -92,6 +99,11 @@ impl Pool { }; let mut connections = pool.connections.lock().await; + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return; + }; + connections.push(ParkedConnection::park(conn)); #[cfg(feature = "tracing")] @@ -134,10 +146,29 @@ impl Pool { pool } + pub(crate) async fn shutdown(&self) { + let connections = { self.connections.lock().await.take() }; + if let Some(connections) = connections { + stream::iter(connections) + .for_each_concurrent(8, |conn| async move { + conn.unpark().abort().await; + }) + .await; + } + + if let Some(handle) = self.handle.get() { + handle.shutdown().await; + } + } + pub(crate) async fn connection(self: &Arc) -> Result, Error> { loop { let conn = { let mut connections = self.connections.lock().await; + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return Err(error::transport_shutdown()); + }; connections.pop() }; @@ -181,13 +212,20 @@ impl Pool { #[cfg(feature = "tracing")] tracing::debug!("recycling connection"); - let mut connections = self.connections.lock().await; - if connections.len() >= self.config.max_size as usize { - drop(connections); - conn.abort().await; + let mut connections_guard = self.connections.lock().await; + + if let Some(connections) = connections_guard.as_mut() { + if connections.len() >= self.config.max_size as usize { + drop(connections_guard); + conn.abort().await; + } else { + let conn = ParkedConnection::park(conn); + connections.push(conn); + } } else { - let conn = ParkedConnection::park(conn); - connections.push(conn); + // The pool has already been shut down + drop(connections_guard); + conn.abort().await; } } } @@ -200,7 +238,13 @@ impl Debug for Pool { .field( "connections", &match self.connections.try_lock() { - Some(connections) => format!("{} connections", connections.len()), + Some(connections) => { + if let Some(connections) = connections.as_ref() { + format!("{} connections", connections.len()) + } else { + "SHUT DOWN".to_owned() + } + } None => "LOCKED".to_owned(), }, @@ -222,14 +266,16 @@ impl Drop for Pool { #[cfg(feature = "tracing")] tracing::debug!("dropping Pool"); - let connections = mem::take(self.connections.get_mut()); + let connections = self.connections.get_mut().take(); let handle = self.handle.take(); E::spawn(async move { if let Some(handle) = handle { handle.shutdown().await; } - abort_concurrent(connections.into_iter().map(ParkedConnection::unpark)).await; + if let Some(connections) = connections { + abort_concurrent(connections.into_iter().map(ParkedConnection::unpark)).await; + } }); } } diff --git a/src/transport/smtp/pool/sync_impl.rs b/src/transport/smtp/pool/sync_impl.rs index 6c24322..c9e5690 100644 --- a/src/transport/smtp/pool/sync_impl.rs +++ b/src/transport/smtp/pool/sync_impl.rs @@ -1,8 +1,7 @@ use std::{ fmt::{self, Debug}, - mem, ops::{Deref, DerefMut}, - sync::{Arc, Mutex, TryLockError}, + sync::{mpsc, Arc, Mutex, TryLockError}, thread, time::{Duration, Instant}, }; @@ -11,11 +10,12 @@ use super::{ super::{client::SmtpConnection, Error}, PoolConfig, }; -use crate::transport::smtp::transport::SmtpClient; +use crate::transport::smtp::{error, transport::SmtpClient}; pub(crate) struct Pool { config: PoolConfig, - connections: Mutex>, + connections: Mutex>>, + thread_terminator: mpsc::SyncSender<()>, client: SmtpClient, } @@ -31,9 +31,12 @@ pub(crate) struct PooledConnection { impl Pool { pub(crate) fn new(config: PoolConfig, client: SmtpClient) -> Arc { + let (thread_tx, thread_rx) = mpsc::sync_channel(1); + let pool = Arc::new(Self { config, - connections: Mutex::new(Vec::new()), + connections: Mutex::new(Some(Vec::new())), + thread_terminator: thread_tx, client, }); @@ -54,6 +57,10 @@ impl Pool { #[allow(clippy::needless_collect)] let (count, dropped) = { let mut connections = pool.connections.lock().unwrap(); + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return; + }; let to_drop = connections .iter() @@ -86,6 +93,11 @@ impl Pool { }; let mut connections = pool.connections.lock().unwrap(); + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return; + }; + connections.push(ParkedConnection::park(conn)); #[cfg(feature = "tracing")] @@ -110,7 +122,14 @@ impl Pool { } drop(pool); - thread::sleep(idle_timeout); + + match thread_rx.recv_timeout(idle_timeout) { + Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => { + // The transport was shut down + return; + } + Err(mpsc::RecvTimeoutError::Timeout) => {} + } } }) .expect("couldn't spawn the Pool thread"); @@ -119,10 +138,25 @@ impl Pool { pool } + pub(crate) fn shutdown(&self) { + let connections = { self.connections.lock().unwrap().take() }; + if let Some(connections) = connections { + for conn in connections { + conn.unpark().abort(); + } + } + + _ = self.thread_terminator.try_send(()); + } + pub(crate) fn connection(self: &Arc) -> Result { loop { let conn = { let mut connections = self.connections.lock().unwrap(); + let Some(connections) = connections.as_mut() else { + // The transport was shut down + return Err(error::transport_shutdown()); + }; connections.pop() }; @@ -166,13 +200,20 @@ impl Pool { #[cfg(feature = "tracing")] tracing::debug!("recycling connection"); - let mut connections = self.connections.lock().unwrap(); - if connections.len() >= self.config.max_size as usize { - drop(connections); - conn.abort(); + let mut connections_guard = self.connections.lock().unwrap(); + + if let Some(connections) = connections_guard.as_mut() { + if connections.len() >= self.config.max_size as usize { + drop(connections_guard); + conn.abort(); + } else { + let conn = ParkedConnection::park(conn); + connections.push(conn); + } } else { - let conn = ParkedConnection::park(conn); - connections.push(conn); + // The pool has already been shut down + drop(connections_guard); + conn.abort(); } } } @@ -185,7 +226,13 @@ impl Debug for Pool { .field( "connections", &match self.connections.try_lock() { - Ok(connections) => format!("{} connections", connections.len()), + Ok(connections) => { + if let Some(connections) = connections.as_ref() { + format!("{} connections", connections.len()) + } else { + "SHUT DOWN".to_owned() + } + } Err(TryLockError::WouldBlock) => "LOCKED".to_owned(), Err(TryLockError::Poisoned(_)) => "POISONED".to_owned(), @@ -201,10 +248,11 @@ impl Drop for Pool { #[cfg(feature = "tracing")] tracing::debug!("dropping Pool"); - let connections = mem::take(&mut *self.connections.get_mut().unwrap()); - for conn in connections { - let mut conn = conn.unpark(); - conn.abort(); + if let Some(connections) = self.connections.get_mut().unwrap().take() { + for conn in connections { + let mut conn = conn.unpark(); + conn.abort(); + } } } } diff --git a/src/transport/smtp/transport.rs b/src/transport/smtp/transport.rs index be158dc..e787e32 100644 --- a/src/transport/smtp/transport.rs +++ b/src/transport/smtp/transport.rs @@ -60,6 +60,11 @@ impl Transport for SmtpTransport { Ok(result) } + + fn shutdown(&self) { + #[cfg(feature = "pool")] + self.inner.shutdown(); + } } impl Debug for SmtpTransport {