From 60e3a0b7cbb33ded86468930b2ec1be817c53f9e Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Thu, 13 Aug 2020 23:23:12 +0200 Subject: [PATCH] refactor: backport improvements from Tokio02 support --- src/transport/smtp/client/connection.rs | 35 +++++------ src/transport/smtp/client/net.rs | 82 +++++++++++++++---------- src/transport/smtp/transport.rs | 33 +++++----- 3 files changed, 77 insertions(+), 73 deletions(-) diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index 5782bbd..681e88e 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -1,10 +1,7 @@ -use std::{ - fmt::Display, - io::{self, BufRead, BufReader, Write}, - net::ToSocketAddrs, - string::String, - time::Duration, -}; +use std::fmt::Display; +use std::io::{self, BufRead, BufReader, Write}; +use std::net::ToSocketAddrs; +use std::time::Duration; #[cfg(feature = "log")] use log::debug; @@ -61,7 +58,8 @@ impl SmtpConnection { hello_name: &ClientId, tls_parameters: Option<&TlsParameters>, ) -> Result { - let stream = BufReader::new(NetworkStream::connect(server, timeout, tls_parameters)?); + let stream = NetworkStream::connect(server, timeout, tls_parameters)?; + let stream = BufReader::new(stream); let mut conn = SmtpConnection { stream, panic: false, @@ -87,7 +85,7 @@ impl SmtpConnection { mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); } try_smtp!( - self.command(Mail::new(envelope.from().cloned(), mail_options,)), + self.command(Mail::new(envelope.from().cloned(), mail_options)), self ); @@ -187,14 +185,12 @@ impl SmtpConnection { mechanisms: &[Mechanism], credentials: &Credentials, ) -> Result { - let mechanism = match self.server_info.get_auth_mechanism(mechanisms) { - Some(m) => m, - None => { - return Err(Error::Client( - "No compatible authentication mechanism was found", - )) - } - }; + let mechanism = self + .server_info + .get_auth_mechanism(mechanisms) + .ok_or(Error::Client( + "No compatible authentication mechanism was found", + ))?; // Limit challenges to avoid blocking let mut challenges = 10; @@ -241,10 +237,7 @@ impl SmtpConnection { self.stream.get_mut().flush()?; #[cfg(feature = "log")] - debug!( - "Wrote: {}", - escape_crlf(String::from_utf8_lossy(string).as_ref()) - ); + debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); Ok(()) } diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index 8058e8d..8dccc65 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -35,7 +35,7 @@ enum InnerNetworkStream { } impl NetworkStream { - pub(self) fn new(inner: InnerNetworkStream) -> Self { + fn new(inner: InnerNetworkStream) -> Self { NetworkStream { inner } } @@ -100,41 +100,57 @@ impl NetworkStream { Ok(stream) } - #[allow(unused_variables)] pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> { - match self.inner { - InnerNetworkStream::Tcp(ref mut stream) => { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - match &tls_parameters.connector { - #[cfg(feature = "native-tls")] - InnerTlsParameters::NativeTls(connector) => { - let stream = connector - .connect(tls_parameters.domain(), stream.try_clone().unwrap()) - .map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?; - *self = NetworkStream::new(InnerNetworkStream::NativeTls(stream)); - } - #[cfg(feature = "rustls-tls")] - InnerTlsParameters::RustlsTls(connector) => { - use webpki::DNSNameRef; - - let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?; - let stream = StreamOwned::new( - ClientSession::new(&Arc::new(connector.clone()), domain), - stream.try_clone().unwrap(), - ); - - *self = NetworkStream::new(InnerNetworkStream::RustlsTls(Box::new(stream))); - } - }; + match &self.inner { + #[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] + InnerNetworkStream::Tcp(_) => { + let _ = tls_parameters; + panic!("Trying to upgrade an NetworkStream without having enabled either the native-tls or the rustls-tls feature"); } - #[cfg(feature = "native-tls")] - InnerNetworkStream::NativeTls(_) => (), - #[cfg(feature = "rustls-tls")] - InnerNetworkStream::RustlsTls(_) => (), - InnerNetworkStream::Mock(_) => (), - }; - Ok(()) + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + InnerNetworkStream::Tcp(_) => { + // get owned TcpStream + let tcp_stream = + std::mem::replace(&mut self.inner, InnerNetworkStream::Mock(MockStream::new())); + let tcp_stream = match tcp_stream { + InnerNetworkStream::Tcp(tcp_stream) => tcp_stream, + _ => unreachable!(), + }; + + self.inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?; + Ok(()) + } + _ => Ok(()), + } + } + + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + fn upgrade_tls_impl( + tcp_stream: TcpStream, + tls_parameters: &TlsParameters, + ) -> Result { + Ok(match &tls_parameters.connector { + #[cfg(feature = "native-tls")] + InnerTlsParameters::NativeTls(connector) => { + let stream = connector + .connect(tls_parameters.domain(), tcp_stream) + .map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?; + InnerNetworkStream::NativeTls(stream) + } + #[cfg(feature = "rustls-tls")] + InnerTlsParameters::RustlsTls(connector) => { + use webpki::DNSNameRef; + + let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?; + let stream = StreamOwned::new( + ClientSession::new(&Arc::new(connector.clone()), domain), + tcp_stream, + ); + + InnerNetworkStream::RustlsTls(Box::new(stream)) + } + }) } pub fn is_encrypted(&self) -> bool { diff --git a/src/transport/smtp/transport.rs b/src/transport/smtp/transport.rs index e2da184..7d5fe00 100644 --- a/src/transport/smtp/transport.rs +++ b/src/transport/smtp/transport.rs @@ -3,10 +3,9 @@ use std::time::Duration; #[cfg(feature = "r2d2")] use r2d2::{Builder, Pool}; -use super::{ - ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo, Tls, - TlsParameters, SUBMISSIONS_PORT, -}; +use super::{ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo}; +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +use super::{Tls, TlsParameters, SUBMISSIONS_PORT}; use crate::{Envelope, Transport}; #[allow(missing_debug_implementations)] @@ -156,40 +155,36 @@ impl SmtpClient { /// /// Handles encryption and authentication pub fn connection(&self) -> Result { + #[allow(clippy::match_single_binding)] + let tls_parameters = match self.info.tls { + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Tls::Wrapper(ref tls_parameters) => Some(tls_parameters), + _ => None, + }; + let mut conn = SmtpConnection::connect::<(&str, u16)>( (self.info.server.as_ref(), self.info.port), self.info.timeout, &self.info.hello_name, - #[allow(clippy::match_single_binding)] - match self.info.tls { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Tls::Wrapper(ref tls_parameters) => Some(tls_parameters), - _ => None, - }, + tls_parameters, )?; - #[allow(clippy::match_single_binding)] + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] match self.info.tls { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] Tls::Opportunistic(ref tls_parameters) => { if conn.can_starttls() { conn.starttls(tls_parameters, &self.info.hello_name)?; } } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] Tls::Required(ref tls_parameters) => { conn.starttls(tls_parameters, &self.info.hello_name)?; } _ => (), } - match &self.info.credentials { - Some(credentials) => { - conn.auth(self.info.authentication.as_slice(), &credentials)?; - } - None => (), + if let Some(credentials) = &self.info.credentials { + conn.auth(&self.info.authentication, &credentials)?; } - Ok(conn) } }