From b534a18017a325858c14cf6be6999119972ada97 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Tue, 29 Jun 2021 17:23:47 +0200 Subject: [PATCH] Async pool implementation (#637) --- Cargo.toml | 6 +- src/executor.rs | 94 +++++++- src/lib.rs | 1 + src/transport/smtp/async_transport.rs | 37 +++- src/transport/smtp/mod.rs | 4 +- src/transport/smtp/pool/async_impl.rs | 299 ++++++++++++++++++++++++++ src/transport/smtp/pool/mod.rs | 2 + 7 files changed, 434 insertions(+), 9 deletions(-) create mode 100644 src/transport/smtp/pool/async_impl.rs diff --git a/Cargo.toml b/Cargo.toml index 015332e..78bc7ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ maintenance = { status = "actively-developed" } [dependencies] idna = "0.2" +once_cell = "1" tracing = { version = "0.1.16", default-features = false, features = ["std"], optional = true } # feature # builder @@ -27,7 +28,6 @@ mime = { version = "0.3.4", optional = true } fastrand = { version = "1.4", optional = true } quoted_printable = { version = "0.4", optional = true } base64 = { version = "0.13", optional = true } -once_cell = "1" regex = { version = "1", default-features = false, features = ["std", "unicode-case"] } # file transport @@ -76,7 +76,7 @@ harness = false name = "transport_smtp" [features] -default = ["smtp-transport", "native-tls", "hostname", "r2d2", "builder"] +default = ["smtp-transport", "pool", "native-tls", "hostname", "r2d2", "builder"] builder = ["httpdate", "mime", "base64", "fastrand", "quoted_printable"] # transports @@ -85,6 +85,8 @@ file-transport-envelope = ["serde", "serde_json", "file-transport"] sendmail-transport = [] smtp-transport = ["base64", "nom"] +pool = ["futures-util"] + rustls-tls = ["webpki", "webpki-roots", "rustls"] # async diff --git a/src/executor.rs b/src/executor.rs index 2dabfe1..be00192 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -1,6 +1,10 @@ use async_trait::async_trait; +#[cfg(all(feature = "smtp-transport", feature = "async-std1"))] +use futures_util::future::BoxFuture; use std::fmt::Debug; +#[cfg(feature = "smtp-transport")] +use std::future::Future; #[cfg(feature = "file-transport")] use std::io::Result as IoResult; #[cfg(feature = "file-transport")] @@ -39,7 +43,23 @@ use crate::transport::smtp::Error; /// [`AsyncFileTransport`]: crate::AsyncFileTransport #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio1", feature = "async-std1"))))] #[async_trait] -pub trait Executor: Debug + Send + Sync + private::Sealed { +pub trait Executor: Debug + Send + Sync + 'static + private::Sealed { + #[cfg(feature = "smtp-transport")] + type Handle: SpawnHandle; + #[cfg(feature = "smtp-transport")] + type Sleep: Future + Send + 'static; + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn spawn(fut: F) -> Self::Handle + where + F: Future + Send + 'static, + F::Output: Send + 'static; + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn sleep(duration: Duration) -> Self::Sleep; + #[doc(hidden)] #[cfg(feature = "smtp-transport")] async fn connect( @@ -59,6 +79,13 @@ pub trait Executor: Debug + Send + Sync + private::Sealed { async fn fs_write(path: &Path, contents: &[u8]) -> IoResult<()>; } +#[doc(hidden)] +#[cfg(feature = "smtp-transport")] +#[async_trait] +pub trait SpawnHandle: Debug + Send + Sync + 'static + private::Sealed { + async fn shutdown(self); +} + /// Async [`Executor`] using `tokio` `1.x` /// /// Used by [`AsyncSmtpTransport`], [`AsyncSendmailTransport`] and [`AsyncFileTransport`] @@ -77,6 +104,27 @@ pub struct Tokio1Executor; #[async_trait] #[cfg(feature = "tokio1")] impl Executor for Tokio1Executor { + #[cfg(feature = "smtp-transport")] + type Handle = tokio1_crate::task::JoinHandle<()>; + #[cfg(feature = "smtp-transport")] + type Sleep = tokio1_crate::time::Sleep; + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn spawn(fut: F) -> Self::Handle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + tokio1_crate::spawn(fut) + } + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn sleep(duration: Duration) -> Self::Sleep { + tokio1_crate::time::sleep(duration) + } + #[doc(hidden)] #[cfg(feature = "smtp-transport")] async fn connect( @@ -130,6 +178,14 @@ impl Executor for Tokio1Executor { } } +#[cfg(all(feature = "smtp-transport", feature = "tokio1"))] +#[async_trait] +impl SpawnHandle for tokio1_crate::task::JoinHandle<()> { + async fn shutdown(self) { + self.abort(); + } +} + /// Async [`Executor`] using `async-std` `1.x` /// /// Used by [`AsyncSmtpTransport`], [`AsyncSendmailTransport`] and [`AsyncFileTransport`] @@ -148,6 +204,28 @@ pub struct AsyncStd1Executor; #[async_trait] #[cfg(feature = "async-std1")] impl Executor for AsyncStd1Executor { + #[cfg(feature = "smtp-transport")] + type Handle = async_std::task::JoinHandle<()>; + #[cfg(feature = "smtp-transport")] + type Sleep = BoxFuture<'static, ()>; + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn spawn(fut: F) -> Self::Handle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + async_std::task::spawn(fut) + } + + #[doc(hidden)] + #[cfg(feature = "smtp-transport")] + fn sleep(duration: Duration) -> Self::Sleep { + let fut = async move { async_std::task::sleep(duration).await }; + Box::pin(fut) + } + #[doc(hidden)] #[cfg(feature = "smtp-transport")] async fn connect( @@ -201,6 +279,14 @@ impl Executor for AsyncStd1Executor { } } +#[cfg(all(feature = "smtp-transport", feature = "async-std1"))] +#[async_trait] +impl SpawnHandle for async_std::task::JoinHandle<()> { + async fn shutdown(self) { + self.cancel().await; + } +} + mod private { use super::*; @@ -211,4 +297,10 @@ mod private { #[cfg(feature = "async-std1")] impl Sealed for AsyncStd1Executor {} + + #[cfg(all(feature = "smtp-transport", feature = "tokio1"))] + impl Sealed for tokio1_crate::task::JoinHandle<()> {} + + #[cfg(all(feature = "smtp-transport", feature = "async-std1"))] + impl Sealed for async_std::task::JoinHandle<()> {} } diff --git a/src/lib.rs b/src/lib.rs index a1c0482..179bb6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ //! //! * **smtp-transport** 📫: Enable the SMTP transport //! * **r2d2** 📫: Connection pool for SMTP transport +//! * **pool** 📫: Async connection pool for SMTP transport //! * **hostname** 📫: Try to use the actual system hostname for the SMTP `CLIENTID` //! //! #### SMTP over TLS via the native-tls crate diff --git a/src/transport/smtp/async_transport.rs b/src/transport/smtp/async_transport.rs index 3fd6700..eeea48d 100644 --- a/src/transport/smtp/async_transport.rs +++ b/src/transport/smtp/async_transport.rs @@ -1,11 +1,16 @@ use std::{ fmt::{self, Debug}, marker::PhantomData, + sync::Arc, time::Duration, }; use async_trait::async_trait; +#[cfg(feature = "pool")] +use super::pool::async_impl::Pool; +#[cfg(feature = "pool")] +use super::PoolConfig; use super::{ client::AsyncSmtpConnection, ClientId, Credentials, Error, Mechanism, Response, SmtpInfo, }; @@ -19,8 +24,10 @@ use crate::{Envelope, Executor}; /// Asynchronously sends emails using the SMTP protocol #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio1", feature = "async-std1"))))] -pub struct AsyncSmtpTransport { - // TODO: pool +pub struct AsyncSmtpTransport { + #[cfg(feature = "pool")] + inner: Arc>, + #[cfg(not(feature = "pool"))] inner: AsyncSmtpClient, } @@ -36,6 +43,7 @@ impl AsyncTransport for AsyncSmtpTransport { let result = conn.send(envelope, email).await?; + #[cfg(not(feature = "pool"))] conn.quit().await?; Ok(result) @@ -153,11 +161,15 @@ where server: server.into(), ..Default::default() }; - AsyncSmtpTransportBuilder { info } + AsyncSmtpTransportBuilder { + info, + #[cfg(feature = "pool")] + pool_config: PoolConfig::default(), + } } } -impl Debug for AsyncSmtpTransport { +impl Debug for AsyncSmtpTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = f.debug_struct("AsyncSmtpTransport"); builder.field("inner", &self.inner); @@ -182,6 +194,8 @@ where #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio1", feature = "async-std1"))))] pub struct AsyncSmtpTransportBuilder { info: SmtpInfo, + #[cfg(feature = "pool")] + pool_config: PoolConfig, } /// Builder for the SMTP `AsyncSmtpTransport` @@ -236,6 +250,16 @@ impl AsyncSmtpTransportBuilder { self } + /// Use a custom configuration for the connection pool + /// + /// Defaults can be found at [`PoolConfig`] + #[cfg(feature = "pool")] + #[cfg_attr(docsrs, doc(cfg(feature = "pool")))] + pub fn pool_config(mut self, pool_config: PoolConfig) -> Self { + self.pool_config = pool_config; + self + } + /// Build the transport pub fn build(self) -> AsyncSmtpTransport where @@ -246,6 +270,9 @@ impl AsyncSmtpTransportBuilder { marker_: PhantomData, }; + #[cfg(feature = "pool")] + let client = Pool::new(self.pool_config, client); + AsyncSmtpTransport { inner: client } } } @@ -288,6 +315,8 @@ impl Debug for AsyncSmtpClient { } } +// `clone` is unused when the `pool` feature is on +#[allow(dead_code)] impl AsyncSmtpClient where E: Executor, diff --git a/src/transport/smtp/mod.rs b/src/transport/smtp/mod.rs index 7292f2e..f9ad2db 100644 --- a/src/transport/smtp/mod.rs +++ b/src/transport/smtp/mod.rs @@ -118,7 +118,7 @@ #[cfg(any(feature = "tokio1", feature = "async-std1"))] pub use self::async_transport::{AsyncSmtpTransport, AsyncSmtpTransportBuilder}; -#[cfg(feature = "r2d2")] +#[cfg(any(feature = "r2d2", feature = "pool"))] pub use self::pool::PoolConfig; #[cfg(feature = "r2d2")] pub(crate) use self::transport::SmtpClient; @@ -144,7 +144,7 @@ pub mod client; pub mod commands; mod error; pub mod extension; -#[cfg(feature = "r2d2")] +#[cfg(any(feature = "r2d2", feature = "pool"))] mod pool; pub mod response; mod transport; diff --git a/src/transport/smtp/pool/async_impl.rs b/src/transport/smtp/pool/async_impl.rs new file mode 100644 index 0000000..b947a25 --- /dev/null +++ b/src/transport/smtp/pool/async_impl.rs @@ -0,0 +1,299 @@ +use std::fmt::{self, Debug}; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use futures_util::lock::Mutex; +use futures_util::stream::{self, StreamExt}; +use once_cell::sync::OnceCell; + +use crate::executor::SpawnHandle; +use crate::transport::smtp::async_transport::AsyncSmtpClient; +use crate::Executor; + +use super::super::client::AsyncSmtpConnection; +use super::super::Error; +use super::PoolConfig; + +pub struct Pool { + config: PoolConfig, + connections: Mutex>, + client: AsyncSmtpClient, + handle: OnceCell, +} + +struct ParkedConnection { + conn: AsyncSmtpConnection, + since: Instant, +} + +pub struct PooledConnection { + conn: Option, + pool: Arc>, +} + +impl Pool { + pub fn new(config: PoolConfig, client: AsyncSmtpClient) -> Arc { + let pool = Arc::new(Self { + config, + connections: Mutex::new(Vec::new()), + client, + handle: OnceCell::new(), + }); + + { + let pool_ = Arc::clone(&pool); + + let min_idle = pool_.config.min_idle; + let idle_timeout = pool_.config.idle_timeout; + let pool = Arc::downgrade(&pool_); + + let handle = E::spawn(async move { + loop { + #[cfg(feature = "tracing")] + tracing::trace!("running cleanup tasks"); + + match pool.upgrade() { + Some(pool) => { + #[allow(clippy::needless_collect)] + let (count, dropped) = { + let mut connections = pool.connections.lock().await; + + let to_drop = connections + .iter() + .enumerate() + .rev() + .filter(|(_, conn)| conn.idle_duration() > idle_timeout) + .map(|(i, _)| i) + .collect::>(); + let dropped = to_drop + .into_iter() + .map(|i| connections.remove(i)) + .collect::>(); + + (connections.len(), dropped) + }; + + #[cfg(feature = "tracing")] + let mut created = 0; + for _ in count..=(min_idle as usize) { + let conn = match pool.client.connection().await { + Ok(conn) => conn, + Err(err) => { + #[cfg(feature = "tracing")] + tracing::warn!("couldn't create idle connection {}", err); + #[cfg(not(feature = "tracing"))] + let _ = err; + + break; + } + }; + + let mut connections = pool.connections.lock().await; + connections.push(ParkedConnection::park(conn)); + + #[cfg(feature = "tracing")] + { + created += 1; + } + } + + #[cfg(feature = "tracing")] + if created > 0 { + tracing::debug!("created {} idle connections", created); + } + + if !dropped.is_empty() { + #[cfg(feature = "tracing")] + tracing::debug!("dropped {} idle connections", dropped.len()); + + abort_concurrent(dropped.into_iter().map(|conn| conn.unpark())) + .await; + } + } + None => { + #[cfg(feature = "tracing")] + tracing::warn!( + "breaking out of task - no more references to Pool are available" + ); + break; + } + } + + E::sleep(idle_timeout).await; + } + }); + pool_ + .handle + .set(handle) + .expect("handle hasn't been set yet"); + } + + pool + } + + pub async fn connection(self: &Arc) -> Result, Error> { + loop { + let conn = { + let mut connections = self.connections.lock().await; + connections.pop() + }; + + match conn { + Some(conn) => { + let mut conn = conn.unpark(); + + // TODO: handle the client try another connection if this one isn't good + if !conn.test_connected().await { + #[cfg(feature = "tracing")] + tracing::debug!("dropping a broken connection"); + + conn.abort().await; + continue; + } + + #[cfg(feature = "tracing")] + tracing::debug!("reusing a pooled connection"); + + return Ok(PooledConnection::wrap(conn, self.clone())); + } + None => { + #[cfg(feature = "tracing")] + tracing::debug!("creating a new connection"); + + let conn = self.client.connection().await?; + return Ok(PooledConnection::wrap(conn, self.clone())); + } + } + } + } + + async fn recycle(&self, mut conn: AsyncSmtpConnection) { + if conn.has_broken() { + #[cfg(feature = "tracing")] + tracing::debug!("dropping a broken connection instead of recycling it"); + + conn.abort().await; + drop(conn); + } else { + #[cfg(feature = "tracing")] + tracing::debug!("recycling connection"); + + let mut connections = self.connections.lock().await; + if connections.len() >= self.config.max_size as usize { + drop(connections); + conn.abort().await; + } else { + let conn = ParkedConnection::park(conn); + connections.push(conn); + } + } + } +} + +impl Debug for Pool { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pool") + .field("config", &self.config) + .field( + "connections", + &match self.connections.try_lock() { + Some(connections) => format!("{} connections", connections.len()), + + None => "LOCKED".to_string(), + }, + ) + .field("client", &self.client) + .field( + "handle", + &match self.handle.get() { + Some(_) => "Some(JoinHandle)", + None => "None", + }, + ) + .finish() + } +} + +impl Drop for Pool { + fn drop(&mut self) { + #[cfg(feature = "tracing")] + tracing::debug!("dropping Pool"); + + let connections = mem::take(self.connections.get_mut()); + let handle = self.handle.take(); + E::spawn(async move { + if let Some(handle) = handle { + handle.shutdown().await; + } + + abort_concurrent(connections.into_iter().map(|conn| conn.unpark())).await; + }); + } +} + +impl ParkedConnection { + fn park(conn: AsyncSmtpConnection) -> Self { + Self { + conn, + since: Instant::now(), + } + } + + fn idle_duration(&self) -> Duration { + self.since.elapsed() + } + + fn unpark(self) -> AsyncSmtpConnection { + self.conn + } +} + +impl PooledConnection { + fn wrap(conn: AsyncSmtpConnection, pool: Arc>) -> Self { + Self { + conn: Some(conn), + pool, + } + } +} + +impl Deref for PooledConnection { + type Target = AsyncSmtpConnection; + + fn deref(&self) -> &Self::Target { + self.conn.as_ref().expect("conn hasn't been dropped yet") + } +} + +impl DerefMut for PooledConnection { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn.as_mut().expect("conn hasn't been dropped yet") + } +} + +impl Drop for PooledConnection { + fn drop(&mut self) { + let conn = self + .conn + .take() + .expect("AsyncSmtpConnection hasn't been taken yet"); + let pool = Arc::clone(&self.pool); + + E::spawn(async move { + pool.recycle(conn).await; + }); + } +} + +async fn abort_concurrent(iter: I) +where + I: Iterator, +{ + stream::iter(iter) + .for_each_concurrent(8, |mut conn| async move { + conn.abort().await; + }) + .await; +} diff --git a/src/transport/smtp/pool/mod.rs b/src/transport/smtp/pool/mod.rs index 474348f..20181ea 100644 --- a/src/transport/smtp/pool/mod.rs +++ b/src/transport/smtp/pool/mod.rs @@ -1,5 +1,7 @@ use std::time::Duration; +#[cfg(all(feature = "pool", any(feature = "tokio1", feature = "async-std1")))] +pub mod async_impl; #[cfg(feature = "r2d2")] pub mod sync_impl;