From b3414bd1ff729684308d4935aae86ee7ee405577 Mon Sep 17 00:00:00 2001 From: Alexis Mousset Date: Sat, 2 May 2020 01:07:31 +0200 Subject: [PATCH] fix(transport): Fix connection pool --- Cargo.toml | 6 +-- src/transport/smtp/client/mod.rs | 48 ++++++++++++++++--- src/transport/smtp/commands.rs | 1 - src/transport/smtp/error.rs | 15 ++++-- src/transport/smtp/mod.rs | 82 ++++++++------------------------ src/transport/smtp/pool.rs | 22 +++++++++ src/transport/smtp/r2d2.rs | 39 --------------- 7 files changed, 98 insertions(+), 115 deletions(-) create mode 100644 src/transport/smtp/pool.rs delete mode 100644 src/transport/smtp/r2d2.rs diff --git a/Cargo.toml b/Cargo.toml index a0acd43..a98df95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ serde = { version = "1", optional = true, features = ["derive"] } serde_json = { version = "1", optional = true } textnonce = { version = "0.7", optional = true } webpki = { version = "0.21", optional = true } +webpki-roots = { version = "0.19", optional = true } [dev-dependencies] criterion = "0.3" @@ -50,10 +51,9 @@ name = "transport_smtp" [features] builder = ["mime", "base64", "hyperx", "textnonce", "quoted_printable"] -connection-pool = ["r2d2"] -default = ["file-transport", "smtp-transport", "hostname", "sendmail-transport", "native-tls", "builder"] +default = ["file-transport", "smtp-transport", "hostname", "sendmail-transport", "rustls-tls", "builder", "r2d2"] file-transport = ["serde", "serde_json"] -rustls-tls = ["webpki", "rustls"] +rustls-tls = ["webpki", "webpki-roots", "rustls"] sendmail-transport = [] smtp-transport = ["bufstream", "base64", "nom"] unstable = [] diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index 96d71bb..7cde3f7 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -5,9 +5,10 @@ use crate::transport::smtp::{ client::net::{NetworkStream, TlsParameters}, commands::*, error::{Error, SmtpResult}, - extension::{ClientId, Extension, ServerInfo}, + extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, response::Response, }; +use crate::Envelope; use bufstream::BufStream; use log::debug; #[cfg(feature = "serde")] @@ -151,6 +152,31 @@ impl SmtpConnection { Ok(conn) } + pub fn send(&mut self, envelope: &Envelope, email: &[u8]) -> SmtpResult { + // Mail + let mut mail_options = vec![]; + + if self.server_info().supports_feature(Extension::EightBitMime) { + mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); + } + try_smtp!( + self.command(Mail::new(envelope.from().cloned(), mail_options,)), + self + ); + + // Recipient + for to_address in envelope.to() { + try_smtp!(self.command(Rcpt::new(to_address.clone(), vec![])), self); + } + + // Data + try_smtp!(self.command(Data), self); + + // Message content + let result = try_smtp!(self.message(email), self); + Ok(result) + } + pub fn has_broken(&self) -> bool { self.panic } @@ -172,7 +198,8 @@ impl SmtpConnection { try_smtp!(self.stream.get_mut().upgrade_tls(tls_parameters), self); debug!("connection encrypted"); // Send EHLO again - self.ehlo(hello_name) + try_smtp!(self.ehlo(hello_name), self); + Ok(()) } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] // This should never happen as `Tls` can only be created @@ -193,6 +220,10 @@ impl SmtpConnection { Ok(()) } + pub fn quit(&mut self) -> SmtpResult { + Ok(try_smtp!(self.command(Quit), self)) + } + pub fn abort(&mut self) { // Only try to quit if we are not already broken if !self.panic { @@ -239,11 +270,14 @@ impl SmtpConnection { while challenges > 0 && response.has_code(334) { challenges -= 1; - response = self.command(Auth::new_from_response( - mechanism, - credentials.clone(), - &response, - )?)?; + response = try_smtp!( + self.command(Auth::new_from_response( + mechanism, + credentials.clone(), + &response, + )?), + self + ); } if challenges == 0 { diff --git a/src/transport/smtp/commands.rs b/src/transport/smtp/commands.rs index 98467fe..a4eb498 100644 --- a/src/transport/smtp/commands.rs +++ b/src/transport/smtp/commands.rs @@ -24,7 +24,6 @@ pub struct Ehlo { impl Display for Ehlo { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - #[allow(clippy::write_with_newline)] write!(f, "EHLO {}\r\n", self.client_id) } } diff --git a/src/transport/smtp/error.rs b/src/transport/smtp/error.rs index f025873..4fad19a 100644 --- a/src/transport/smtp/error.rs +++ b/src/transport/smtp/error.rs @@ -3,11 +3,9 @@ use self::Error::*; use crate::transport::smtp::response::{Response, Severity}; use base64::DecodeError; -#[cfg(feature = "native-tls")] use std::{ error::Error as StdError, - fmt, - fmt::{Display, Formatter}, + fmt::{self, Display, Formatter}, io, string::FromUtf8Error, }; @@ -43,10 +41,11 @@ pub enum Error { /// Invalid hostname #[cfg(feature = "rustls-tls")] InvalidDNSName(webpki::InvalidDNSNameError), + #[cfg(feature = "r2d2")] + Pool(r2d2::Error), } impl Display for Error { - #[cfg_attr(feature = "cargo-clippy", allow(clippy::match_same_arms))] fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { match *self { // Try to display the first line of the server's response that usually @@ -70,6 +69,7 @@ impl Display for Error { Parsing(ref err) => fmt.write_str(err.description()), #[cfg(feature = "rustls-tls")] InvalidDNSName(ref err) => err.fmt(fmt), + Pool(ref err) => err.fmt(fmt), } } } @@ -129,6 +129,13 @@ impl From for Error { } } +#[cfg(feature = "r2d2")] +impl From for Error { + fn from(err: r2d2::Error) -> Error { + Pool(err) + } +} + impl From for Error { fn from(response: Response) -> Error { match response.code.severity { diff --git a/src/transport/smtp/mod.rs b/src/transport/smtp/mod.rs index 54f0fa9..70d7c98 100644 --- a/src/transport/smtp/mod.rs +++ b/src/transport/smtp/mod.rs @@ -12,18 +12,14 @@ //! * STARTTLS ([RFC 2487](http://tools.ietf.org/html/rfc2487)) //! -#[cfg(feature = "r2d2")] -use crate::transport::smtp::r2d2::SmtpConnectionManager; -use crate::Envelope; use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism, DEFAULT_MECHANISMS}, client::{net::TlsParameters, SmtpConnection}, - commands::*, error::{Error, SmtpResult}, - extension::{ClientId, Extension, MailBodyParameter, MailParameter}, + extension::ClientId, }, - Transport, + Envelope, Transport, }; #[cfg(feature = "native-tls")] use native_tls::{Protocol, TlsConnector}; @@ -31,8 +27,8 @@ use native_tls::{Protocol, TlsConnector}; use r2d2::Pool; #[cfg(feature = "rustls")] use rustls::ClientConfig; +use std::ops::DerefMut; use std::time::Duration; - #[cfg(feature = "rustls")] use webpki_roots::TLS_SERVER_ROOTS; @@ -41,8 +37,8 @@ pub mod client; pub mod commands; pub mod error; pub mod extension; -#[cfg(feature = "connection-pool")] -pub mod r2d2; +#[cfg(feature = "r2d2")] +pub mod pool; pub mod response; pub mod util; @@ -101,21 +97,9 @@ pub struct SmtpTransport { timeout: Option, /// Connection pool #[cfg(feature = "r2d2")] - pool: Option, + pool: Option>, } -macro_rules! try_smtp ( - ($err: expr, $client: ident) => ({ - match $err { - Ok(val) => val, - Err(err) => { - $client.abort(); - return Err(From::from(err)) - }, - } - }) -); - /// Builder for the SMTP `SmtpTransport` impl SmtpTransport { /// Creates a new SMTP client @@ -155,9 +139,7 @@ impl SmtpTransport { #[cfg(feature = "rustls")] let mut tls = ClientConfig::new(); #[cfg(feature = "rustls")] - tls.config - .root_store - .add_server_trust_anchors(&TLS_SERVER_ROOTS); + tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS); #[cfg(feature = "rustls")] let tls_parameters = TlsParameters::new(relay.to_string(), tls); @@ -167,8 +149,9 @@ impl SmtpTransport { #[cfg(feature = "r2d2")] // Pool with default configuration - let new = new.pool(Pool::new(SmtpConnectionManager))?; - + // FIXME avoid clone + let tpool = new.clone(); + let new = new.pool(Pool::new(tpool)?); Ok(new) } @@ -217,8 +200,8 @@ impl SmtpTransport { /// Set the TLS settings to use #[cfg(feature = "r2d2")] - pub fn pool(mut self, pool: Pool) -> Self { - self.pool = pool; + pub fn pool(mut self, pool: Pool) -> Self { + self.pool = Some(pool); self } @@ -240,22 +223,19 @@ impl SmtpTransport { #[cfg(any(feature = "native-tls", feature = "rustls"))] Tls::Opportunistic(ref tls_parameters) => { if conn.can_starttls() { - try_smtp!(conn.starttls(tls_parameters, &self.hello_name), conn); + conn.starttls(tls_parameters, &self.hello_name)?; } } #[cfg(any(feature = "native-tls", feature = "rustls"))] Tls::Required(ref tls_parameters) => { - try_smtp!(conn.starttls(tls_parameters, &self.hello_name), conn); + conn.starttls(tls_parameters, &self.hello_name)?; } _ => (), } match &self.credentials { Some(credentials) => { - try_smtp!( - conn.auth(self.authentication.as_slice(), &credentials), - conn - ); + conn.auth(self.authentication.as_slice(), &credentials)?; } None => (), } @@ -270,43 +250,23 @@ impl<'a> Transport<'a> for SmtpTransport { /// Sends an email fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Self::Result { #[cfg(feature = "r2d2")] - let mut conn = match self.pool { - Some(p) => p.get()?, - None => self.connection()?, + let mut conn: Box> = match self.pool { + Some(ref p) => Box::new(p.get()?), + None => Box::new(Box::new(self.connection()?)), }; #[cfg(not(feature = "r2d2"))] let mut conn = self.connection()?; - // Mail - let mut mail_options = vec![]; - - if conn.server_info().supports_feature(Extension::EightBitMime) { - mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); - } - try_smtp!( - conn.command(Mail::new(envelope.from().cloned(), mail_options,)), - conn - ); - - // Recipient - for to_address in envelope.to() { - try_smtp!(conn.command(Rcpt::new(to_address.clone(), vec![])), conn); - } - - // Data - try_smtp!(conn.command(Data), conn); - - // Message content - let result = try_smtp!(conn.message(email), conn); + let result = conn.send(envelope, email)?; #[cfg(feature = "r2d2")] { if self.pool.is_none() { - try_smtp!(conn.command(Quit), conn); + conn.quit()?; } } #[cfg(not(feature = "r2d2"))] - try_smtp!(conn.command(Quit), conn); + conn.quit()?; Ok(result) } diff --git a/src/transport/smtp/pool.rs b/src/transport/smtp/pool.rs new file mode 100644 index 0000000..b373bd6 --- /dev/null +++ b/src/transport/smtp/pool.rs @@ -0,0 +1,22 @@ +use crate::transport::smtp::{client::SmtpConnection, error::Error, SmtpTransport}; +use r2d2::ManageConnection; + +impl ManageConnection for SmtpTransport { + type Connection = SmtpConnection; + type Error = Error; + + fn connect(&self) -> Result { + self.connection() + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> { + if conn.test_connected() { + return Ok(()); + } + Err(Error::Client("is not connected anymore")) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.has_broken() + } +} diff --git a/src/transport/smtp/r2d2.rs b/src/transport/smtp/r2d2.rs deleted file mode 100644 index 980cdb2..0000000 --- a/src/transport/smtp/r2d2.rs +++ /dev/null @@ -1,39 +0,0 @@ -use crate::transport::smtp::{ - error::Error, ConnectionReuseParameters, SmtpTransport, SmtpTransport, -}; -use r2d2::ManageConnection; - -pub struct SmtpConnectionManager { - transport_builder: SmtpTransport, -} - -impl SmtpConnectionManager { - pub fn new(transport_builder: SmtpTransport) -> Result { - Ok(SmtpConnectionManager { - transport_builder: transport_builder - .connection_reuse(ConnectionReuseParameters::ReuseUnlimited), - }) - } -} - -impl ManageConnection for SmtpConnectionManager { - type Connection = SmtpTransport; - type Error = Error; - - fn connect(&self) -> Result { - let mut transport = SmtpTransport::new(self.transport_builder.clone()); - transport.connect()?; - Ok(transport) - } - - fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> { - if conn.client.test_connected() { - return Ok(()); - } - Err(Error::Client("is not connected anymore")) - } - - fn has_broken(&self, conn: &mut Self::Connection) -> bool { - conn.state.panic - } -}