fix(transport): Fix connection pool

This commit is contained in:
Alexis Mousset
2020-05-02 01:07:31 +02:00
parent 0604030b91
commit b3414bd1ff
7 changed files with 98 additions and 115 deletions

View File

@@ -37,6 +37,7 @@ serde = { version = "1", optional = true, features = ["derive"] }
serde_json = { version = "1", optional = true }
textnonce = { version = "0.7", optional = true }
webpki = { version = "0.21", optional = true }
webpki-roots = { version = "0.19", optional = true }
[dev-dependencies]
criterion = "0.3"
@@ -50,10 +51,9 @@ name = "transport_smtp"
[features]
builder = ["mime", "base64", "hyperx", "textnonce", "quoted_printable"]
connection-pool = ["r2d2"]
default = ["file-transport", "smtp-transport", "hostname", "sendmail-transport", "native-tls", "builder"]
default = ["file-transport", "smtp-transport", "hostname", "sendmail-transport", "rustls-tls", "builder", "r2d2"]
file-transport = ["serde", "serde_json"]
rustls-tls = ["webpki", "rustls"]
rustls-tls = ["webpki", "webpki-roots", "rustls"]
sendmail-transport = []
smtp-transport = ["bufstream", "base64", "nom"]
unstable = []

View File

@@ -5,9 +5,10 @@ use crate::transport::smtp::{
client::net::{NetworkStream, TlsParameters},
commands::*,
error::{Error, SmtpResult},
extension::{ClientId, Extension, ServerInfo},
extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo},
response::Response,
};
use crate::Envelope;
use bufstream::BufStream;
use log::debug;
#[cfg(feature = "serde")]
@@ -151,6 +152,31 @@ impl SmtpConnection {
Ok(conn)
}
pub fn send(&mut self, envelope: &Envelope, email: &[u8]) -> SmtpResult {
// Mail
let mut mail_options = vec![];
if self.server_info().supports_feature(Extension::EightBitMime) {
mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime));
}
try_smtp!(
self.command(Mail::new(envelope.from().cloned(), mail_options,)),
self
);
// Recipient
for to_address in envelope.to() {
try_smtp!(self.command(Rcpt::new(to_address.clone(), vec![])), self);
}
// Data
try_smtp!(self.command(Data), self);
// Message content
let result = try_smtp!(self.message(email), self);
Ok(result)
}
pub fn has_broken(&self) -> bool {
self.panic
}
@@ -172,7 +198,8 @@ impl SmtpConnection {
try_smtp!(self.stream.get_mut().upgrade_tls(tls_parameters), self);
debug!("connection encrypted");
// Send EHLO again
self.ehlo(hello_name)
try_smtp!(self.ehlo(hello_name), self);
Ok(())
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
// This should never happen as `Tls` can only be created
@@ -193,6 +220,10 @@ impl SmtpConnection {
Ok(())
}
pub fn quit(&mut self) -> SmtpResult {
Ok(try_smtp!(self.command(Quit), self))
}
pub fn abort(&mut self) {
// Only try to quit if we are not already broken
if !self.panic {
@@ -239,11 +270,14 @@ impl SmtpConnection {
while challenges > 0 && response.has_code(334) {
challenges -= 1;
response = self.command(Auth::new_from_response(
mechanism,
credentials.clone(),
&response,
)?)?;
response = try_smtp!(
self.command(Auth::new_from_response(
mechanism,
credentials.clone(),
&response,
)?),
self
);
}
if challenges == 0 {

View File

@@ -24,7 +24,6 @@ pub struct Ehlo {
impl Display for Ehlo {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
#[allow(clippy::write_with_newline)]
write!(f, "EHLO {}\r\n", self.client_id)
}
}

View File

@@ -3,11 +3,9 @@
use self::Error::*;
use crate::transport::smtp::response::{Response, Severity};
use base64::DecodeError;
#[cfg(feature = "native-tls")]
use std::{
error::Error as StdError,
fmt,
fmt::{Display, Formatter},
fmt::{self, Display, Formatter},
io,
string::FromUtf8Error,
};
@@ -43,10 +41,11 @@ pub enum Error {
/// Invalid hostname
#[cfg(feature = "rustls-tls")]
InvalidDNSName(webpki::InvalidDNSNameError),
#[cfg(feature = "r2d2")]
Pool(r2d2::Error),
}
impl Display for Error {
#[cfg_attr(feature = "cargo-clippy", allow(clippy::match_same_arms))]
fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
match *self {
// Try to display the first line of the server's response that usually
@@ -70,6 +69,7 @@ impl Display for Error {
Parsing(ref err) => fmt.write_str(err.description()),
#[cfg(feature = "rustls-tls")]
InvalidDNSName(ref err) => err.fmt(fmt),
Pool(ref err) => err.fmt(fmt),
}
}
}
@@ -129,6 +129,13 @@ impl From<webpki::InvalidDNSNameError> for Error {
}
}
#[cfg(feature = "r2d2")]
impl From<r2d2::Error> for Error {
fn from(err: r2d2::Error) -> Error {
Pool(err)
}
}
impl From<Response> for Error {
fn from(response: Response) -> Error {
match response.code.severity {

View File

@@ -12,18 +12,14 @@
//! * STARTTLS ([RFC 2487](http://tools.ietf.org/html/rfc2487))
//!
#[cfg(feature = "r2d2")]
use crate::transport::smtp::r2d2::SmtpConnectionManager;
use crate::Envelope;
use crate::{
transport::smtp::{
authentication::{Credentials, Mechanism, DEFAULT_MECHANISMS},
client::{net::TlsParameters, SmtpConnection},
commands::*,
error::{Error, SmtpResult},
extension::{ClientId, Extension, MailBodyParameter, MailParameter},
extension::ClientId,
},
Transport,
Envelope, Transport,
};
#[cfg(feature = "native-tls")]
use native_tls::{Protocol, TlsConnector};
@@ -31,8 +27,8 @@ use native_tls::{Protocol, TlsConnector};
use r2d2::Pool;
#[cfg(feature = "rustls")]
use rustls::ClientConfig;
use std::ops::DerefMut;
use std::time::Duration;
#[cfg(feature = "rustls")]
use webpki_roots::TLS_SERVER_ROOTS;
@@ -41,8 +37,8 @@ pub mod client;
pub mod commands;
pub mod error;
pub mod extension;
#[cfg(feature = "connection-pool")]
pub mod r2d2;
#[cfg(feature = "r2d2")]
pub mod pool;
pub mod response;
pub mod util;
@@ -101,21 +97,9 @@ pub struct SmtpTransport {
timeout: Option<Duration>,
/// Connection pool
#[cfg(feature = "r2d2")]
pool: Option<Pool>,
pool: Option<Pool<SmtpTransport>>,
}
macro_rules! try_smtp (
($err: expr, $client: ident) => ({
match $err {
Ok(val) => val,
Err(err) => {
$client.abort();
return Err(From::from(err))
},
}
})
);
/// Builder for the SMTP `SmtpTransport`
impl SmtpTransport {
/// Creates a new SMTP client
@@ -155,9 +139,7 @@ impl SmtpTransport {
#[cfg(feature = "rustls")]
let mut tls = ClientConfig::new();
#[cfg(feature = "rustls")]
tls.config
.root_store
.add_server_trust_anchors(&TLS_SERVER_ROOTS);
tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS);
#[cfg(feature = "rustls")]
let tls_parameters = TlsParameters::new(relay.to_string(), tls);
@@ -167,8 +149,9 @@ impl SmtpTransport {
#[cfg(feature = "r2d2")]
// Pool with default configuration
let new = new.pool(Pool::new(SmtpConnectionManager))?;
// FIXME avoid clone
let tpool = new.clone();
let new = new.pool(Pool::new(tpool)?);
Ok(new)
}
@@ -217,8 +200,8 @@ impl SmtpTransport {
/// Set the TLS settings to use
#[cfg(feature = "r2d2")]
pub fn pool(mut self, pool: Pool) -> Self {
self.pool = pool;
pub fn pool(mut self, pool: Pool<SmtpTransport>) -> Self {
self.pool = Some(pool);
self
}
@@ -240,22 +223,19 @@ impl SmtpTransport {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
Tls::Opportunistic(ref tls_parameters) => {
if conn.can_starttls() {
try_smtp!(conn.starttls(tls_parameters, &self.hello_name), conn);
conn.starttls(tls_parameters, &self.hello_name)?;
}
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
Tls::Required(ref tls_parameters) => {
try_smtp!(conn.starttls(tls_parameters, &self.hello_name), conn);
conn.starttls(tls_parameters, &self.hello_name)?;
}
_ => (),
}
match &self.credentials {
Some(credentials) => {
try_smtp!(
conn.auth(self.authentication.as_slice(), &credentials),
conn
);
conn.auth(self.authentication.as_slice(), &credentials)?;
}
None => (),
}
@@ -270,43 +250,23 @@ impl<'a> Transport<'a> for SmtpTransport {
/// Sends an email
fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Self::Result {
#[cfg(feature = "r2d2")]
let mut conn = match self.pool {
Some(p) => p.get()?,
None => self.connection()?,
let mut conn: Box<dyn DerefMut<Target = SmtpConnection>> = match self.pool {
Some(ref p) => Box::new(p.get()?),
None => Box::new(Box::new(self.connection()?)),
};
#[cfg(not(feature = "r2d2"))]
let mut conn = self.connection()?;
// Mail
let mut mail_options = vec![];
if conn.server_info().supports_feature(Extension::EightBitMime) {
mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime));
}
try_smtp!(
conn.command(Mail::new(envelope.from().cloned(), mail_options,)),
conn
);
// Recipient
for to_address in envelope.to() {
try_smtp!(conn.command(Rcpt::new(to_address.clone(), vec![])), conn);
}
// Data
try_smtp!(conn.command(Data), conn);
// Message content
let result = try_smtp!(conn.message(email), conn);
let result = conn.send(envelope, email)?;
#[cfg(feature = "r2d2")]
{
if self.pool.is_none() {
try_smtp!(conn.command(Quit), conn);
conn.quit()?;
}
}
#[cfg(not(feature = "r2d2"))]
try_smtp!(conn.command(Quit), conn);
conn.quit()?;
Ok(result)
}

View File

@@ -0,0 +1,22 @@
use crate::transport::smtp::{client::SmtpConnection, error::Error, SmtpTransport};
use r2d2::ManageConnection;
impl ManageConnection for SmtpTransport {
type Connection = SmtpConnection;
type Error = Error;
fn connect(&self) -> Result<Self::Connection, Error> {
self.connection()
}
fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> {
if conn.test_connected() {
return Ok(());
}
Err(Error::Client("is not connected anymore"))
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.has_broken()
}
}

View File

@@ -1,39 +0,0 @@
use crate::transport::smtp::{
error::Error, ConnectionReuseParameters, SmtpTransport, SmtpTransport,
};
use r2d2::ManageConnection;
pub struct SmtpConnectionManager {
transport_builder: SmtpTransport,
}
impl SmtpConnectionManager {
pub fn new(transport_builder: SmtpTransport) -> Result<SmtpConnectionManager, Error> {
Ok(SmtpConnectionManager {
transport_builder: transport_builder
.connection_reuse(ConnectionReuseParameters::ReuseUnlimited),
})
}
}
impl ManageConnection for SmtpConnectionManager {
type Connection = SmtpTransport;
type Error = Error;
fn connect(&self) -> Result<Self::Connection, Error> {
let mut transport = SmtpTransport::new(self.transport_builder.clone());
transport.connect()?;
Ok(transport)
}
fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> {
if conn.client.test_connected() {
return Ok(());
}
Err(Error::Client("is not connected anymore"))
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.state.panic
}
}