feat: add controlled shutdown methods (#1045)

This commit is contained in:
Popax21
2025-03-08 12:43:05 +01:00
committed by GitHub
parent 76bf68268f
commit 0b3a1ed278
7 changed files with 169 additions and 38 deletions

View File

@@ -84,7 +84,7 @@ pub trait Executor: Debug + Send + Sync + 'static + private::Sealed {
#[cfg(feature = "smtp-transport")]
#[async_trait]
pub(crate) trait SpawnHandle: Debug + Send + Sync + 'static + private::Sealed {
async fn shutdown(self);
async fn shutdown(&self);
}
/// Async [`Executor`] using `tokio` `1.x`
@@ -178,7 +178,7 @@ impl Executor for Tokio1Executor {
#[cfg(all(feature = "smtp-transport", feature = "tokio1"))]
#[async_trait]
impl SpawnHandle for tokio1_crate::task::JoinHandle<()> {
async fn shutdown(self) {
async fn shutdown(&self) {
self.abort();
}
}
@@ -202,7 +202,7 @@ pub struct AsyncStd1Executor;
#[cfg(feature = "async-std1")]
impl Executor for AsyncStd1Executor {
#[cfg(feature = "smtp-transport")]
type Handle = async_std::task::JoinHandle<()>;
type Handle = futures_util::future::AbortHandle;
#[cfg(feature = "smtp-transport")]
type Sleep = BoxFuture<'static, ()>;
@@ -212,7 +212,9 @@ impl Executor for AsyncStd1Executor {
F: Future<Output = ()> + Send + 'static,
F::Output: Send + 'static,
{
async_std::task::spawn(fut)
let (handle, registration) = futures_util::future::AbortHandle::new_pair();
async_std::task::spawn(futures_util::future::Abortable::new(fut, registration));
handle
}
#[cfg(feature = "smtp-transport")]
@@ -273,9 +275,9 @@ impl Executor for AsyncStd1Executor {
#[cfg(all(feature = "smtp-transport", feature = "async-std1"))]
#[async_trait]
impl SpawnHandle for async_std::task::JoinHandle<()> {
async fn shutdown(self) {
self.cancel().await;
impl SpawnHandle for futures_util::future::AbortHandle {
async fn shutdown(&self) {
self.abort();
}
}
@@ -292,5 +294,5 @@ mod private {
impl Sealed for tokio1_crate::task::JoinHandle<()> {}
#[cfg(all(feature = "smtp-transport", feature = "async-std1"))]
impl Sealed for async_std::task::JoinHandle<()> {}
impl Sealed for futures_util::future::AbortHandle {}
}

View File

@@ -140,6 +140,10 @@ pub trait Transport {
}
fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result<Self::Ok, Self::Error>;
/// Shuts down the transport. Future calls to [`send`] and [`send_raw`] might
/// fail.
fn shutdown(&self) {}
}
/// Async Transport method for emails
@@ -166,4 +170,8 @@ pub trait AsyncTransport {
}
async fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result<Self::Ok, Self::Error>;
/// Shuts down the transport. Future calls to [`send`] and [`send_raw`] might
/// fail.
async fn shutdown(&self) {}
}

View File

@@ -79,6 +79,11 @@ impl AsyncTransport for AsyncSmtpTransport<Tokio1Executor> {
Ok(result)
}
async fn shutdown(&self) {
#[cfg(feature = "pool")]
self.inner.shutdown().await;
}
}
#[cfg(feature = "async-std1")]
@@ -97,6 +102,11 @@ impl AsyncTransport for AsyncSmtpTransport<AsyncStd1Executor> {
Ok(result)
}
async fn shutdown(&self) {
#[cfg(feature = "pool")]
self.inner.shutdown().await;
}
}
impl<E> AsyncSmtpTransport<E>

View File

@@ -77,6 +77,11 @@ impl Error {
matches!(self.inner.kind, Kind::Tls)
}
/// Returns true if the error is because the transport was shut down
pub fn is_transport_shutdown(&self) -> bool {
matches!(self.inner.kind, Kind::TransportShutdown)
}
/// Returns the status code, if the error was generated from a response.
pub fn status(&self) -> Option<Code> {
match self.inner.kind {
@@ -111,6 +116,8 @@ pub(crate) enum Kind {
)]
#[cfg(any(feature = "native-tls", feature = "rustls", feature = "boring-tls"))]
Tls,
/// Transport shutdown error
TransportShutdown,
}
impl fmt::Debug for Error {
@@ -136,6 +143,7 @@ impl fmt::Display for Error {
Kind::Connection => f.write_str("Connection error")?,
#[cfg(any(feature = "native-tls", feature = "rustls", feature = "boring-tls"))]
Kind::Tls => f.write_str("tls error")?,
Kind::TransportShutdown => f.write_str("transport has been shut down")?,
Kind::Transient(code) => {
write!(f, "transient error ({code})")?;
}
@@ -189,3 +197,7 @@ pub(crate) fn connection<E: Into<BoxError>>(e: E) -> Error {
pub(crate) fn tls<E: Into<BoxError>>(e: E) -> Error {
Error::new(Kind::Tls, Some(e))
}
pub(crate) fn transport_shutdown() -> Error {
Error::new::<BoxError>(Kind::TransportShutdown, None)
}

View File

@@ -1,6 +1,5 @@
use std::{
fmt::{self, Debug},
mem,
ops::{Deref, DerefMut},
sync::{Arc, OnceLock},
time::{Duration, Instant},
@@ -15,11 +14,15 @@ use super::{
super::{client::AsyncSmtpConnection, Error},
PoolConfig,
};
use crate::{executor::SpawnHandle, transport::smtp::async_transport::AsyncSmtpClient, Executor};
use crate::{
executor::SpawnHandle,
transport::smtp::{async_transport::AsyncSmtpClient, error},
Executor,
};
pub(crate) struct Pool<E: Executor> {
config: PoolConfig,
connections: Mutex<Vec<ParkedConnection>>,
connections: Mutex<Option<Vec<ParkedConnection>>>,
client: AsyncSmtpClient<E>,
handle: OnceLock<E::Handle>,
}
@@ -38,7 +41,7 @@ impl<E: Executor> Pool<E> {
pub(crate) fn new(config: PoolConfig, client: AsyncSmtpClient<E>) -> Arc<Self> {
let pool = Arc::new(Self {
config,
connections: Mutex::new(Vec::new()),
connections: Mutex::new(Some(Vec::new())),
client,
handle: OnceLock::new(),
});
@@ -60,6 +63,10 @@ impl<E: Executor> Pool<E> {
#[allow(clippy::needless_collect)]
let (count, dropped) = {
let mut connections = pool.connections.lock().await;
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return;
};
let to_drop = connections
.iter()
@@ -92,6 +99,11 @@ impl<E: Executor> Pool<E> {
};
let mut connections = pool.connections.lock().await;
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return;
};
connections.push(ParkedConnection::park(conn));
#[cfg(feature = "tracing")]
@@ -134,10 +146,29 @@ impl<E: Executor> Pool<E> {
pool
}
pub(crate) async fn shutdown(&self) {
let connections = { self.connections.lock().await.take() };
if let Some(connections) = connections {
stream::iter(connections)
.for_each_concurrent(8, |conn| async move {
conn.unpark().abort().await;
})
.await;
}
if let Some(handle) = self.handle.get() {
handle.shutdown().await;
}
}
pub(crate) async fn connection(self: &Arc<Self>) -> Result<PooledConnection<E>, Error> {
loop {
let conn = {
let mut connections = self.connections.lock().await;
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return Err(error::transport_shutdown());
};
connections.pop()
};
@@ -181,13 +212,20 @@ impl<E: Executor> Pool<E> {
#[cfg(feature = "tracing")]
tracing::debug!("recycling connection");
let mut connections = self.connections.lock().await;
if connections.len() >= self.config.max_size as usize {
drop(connections);
conn.abort().await;
let mut connections_guard = self.connections.lock().await;
if let Some(connections) = connections_guard.as_mut() {
if connections.len() >= self.config.max_size as usize {
drop(connections_guard);
conn.abort().await;
} else {
let conn = ParkedConnection::park(conn);
connections.push(conn);
}
} else {
let conn = ParkedConnection::park(conn);
connections.push(conn);
// The pool has already been shut down
drop(connections_guard);
conn.abort().await;
}
}
}
@@ -200,7 +238,13 @@ impl<E: Executor> Debug for Pool<E> {
.field(
"connections",
&match self.connections.try_lock() {
Some(connections) => format!("{} connections", connections.len()),
Some(connections) => {
if let Some(connections) = connections.as_ref() {
format!("{} connections", connections.len())
} else {
"SHUT DOWN".to_owned()
}
}
None => "LOCKED".to_owned(),
},
@@ -222,14 +266,16 @@ impl<E: Executor> Drop for Pool<E> {
#[cfg(feature = "tracing")]
tracing::debug!("dropping Pool");
let connections = mem::take(self.connections.get_mut());
let connections = self.connections.get_mut().take();
let handle = self.handle.take();
E::spawn(async move {
if let Some(handle) = handle {
handle.shutdown().await;
}
abort_concurrent(connections.into_iter().map(ParkedConnection::unpark)).await;
if let Some(connections) = connections {
abort_concurrent(connections.into_iter().map(ParkedConnection::unpark)).await;
}
});
}
}

View File

@@ -1,8 +1,7 @@
use std::{
fmt::{self, Debug},
mem,
ops::{Deref, DerefMut},
sync::{Arc, Mutex, TryLockError},
sync::{mpsc, Arc, Mutex, TryLockError},
thread,
time::{Duration, Instant},
};
@@ -11,11 +10,12 @@ use super::{
super::{client::SmtpConnection, Error},
PoolConfig,
};
use crate::transport::smtp::transport::SmtpClient;
use crate::transport::smtp::{error, transport::SmtpClient};
pub(crate) struct Pool {
config: PoolConfig,
connections: Mutex<Vec<ParkedConnection>>,
connections: Mutex<Option<Vec<ParkedConnection>>>,
thread_terminator: mpsc::SyncSender<()>,
client: SmtpClient,
}
@@ -31,9 +31,12 @@ pub(crate) struct PooledConnection {
impl Pool {
pub(crate) fn new(config: PoolConfig, client: SmtpClient) -> Arc<Self> {
let (thread_tx, thread_rx) = mpsc::sync_channel(1);
let pool = Arc::new(Self {
config,
connections: Mutex::new(Vec::new()),
connections: Mutex::new(Some(Vec::new())),
thread_terminator: thread_tx,
client,
});
@@ -54,6 +57,10 @@ impl Pool {
#[allow(clippy::needless_collect)]
let (count, dropped) = {
let mut connections = pool.connections.lock().unwrap();
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return;
};
let to_drop = connections
.iter()
@@ -86,6 +93,11 @@ impl Pool {
};
let mut connections = pool.connections.lock().unwrap();
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return;
};
connections.push(ParkedConnection::park(conn));
#[cfg(feature = "tracing")]
@@ -110,7 +122,14 @@ impl Pool {
}
drop(pool);
thread::sleep(idle_timeout);
match thread_rx.recv_timeout(idle_timeout) {
Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => {
// The transport was shut down
return;
}
Err(mpsc::RecvTimeoutError::Timeout) => {}
}
}
})
.expect("couldn't spawn the Pool thread");
@@ -119,10 +138,25 @@ impl Pool {
pool
}
pub(crate) fn shutdown(&self) {
let connections = { self.connections.lock().unwrap().take() };
if let Some(connections) = connections {
for conn in connections {
conn.unpark().abort();
}
}
_ = self.thread_terminator.try_send(());
}
pub(crate) fn connection(self: &Arc<Self>) -> Result<PooledConnection, Error> {
loop {
let conn = {
let mut connections = self.connections.lock().unwrap();
let Some(connections) = connections.as_mut() else {
// The transport was shut down
return Err(error::transport_shutdown());
};
connections.pop()
};
@@ -166,13 +200,20 @@ impl Pool {
#[cfg(feature = "tracing")]
tracing::debug!("recycling connection");
let mut connections = self.connections.lock().unwrap();
if connections.len() >= self.config.max_size as usize {
drop(connections);
conn.abort();
let mut connections_guard = self.connections.lock().unwrap();
if let Some(connections) = connections_guard.as_mut() {
if connections.len() >= self.config.max_size as usize {
drop(connections_guard);
conn.abort();
} else {
let conn = ParkedConnection::park(conn);
connections.push(conn);
}
} else {
let conn = ParkedConnection::park(conn);
connections.push(conn);
// The pool has already been shut down
drop(connections_guard);
conn.abort();
}
}
}
@@ -185,7 +226,13 @@ impl Debug for Pool {
.field(
"connections",
&match self.connections.try_lock() {
Ok(connections) => format!("{} connections", connections.len()),
Ok(connections) => {
if let Some(connections) = connections.as_ref() {
format!("{} connections", connections.len())
} else {
"SHUT DOWN".to_owned()
}
}
Err(TryLockError::WouldBlock) => "LOCKED".to_owned(),
Err(TryLockError::Poisoned(_)) => "POISONED".to_owned(),
@@ -201,10 +248,11 @@ impl Drop for Pool {
#[cfg(feature = "tracing")]
tracing::debug!("dropping Pool");
let connections = mem::take(&mut *self.connections.get_mut().unwrap());
for conn in connections {
let mut conn = conn.unpark();
conn.abort();
if let Some(connections) = self.connections.get_mut().unwrap().take() {
for conn in connections {
let mut conn = conn.unpark();
conn.abort();
}
}
}
}

View File

@@ -60,6 +60,11 @@ impl Transport for SmtpTransport {
Ok(result)
}
fn shutdown(&self) {
#[cfg(feature = "pool")]
self.inner.shutdown();
}
}
impl Debug for SmtpTransport {