refactor: backport improvements from Tokio02 support

This commit is contained in:
Paolo Barbolini
2020-08-13 23:23:12 +02:00
committed by Alexis Mousset
parent c8ec8984b8
commit 60e3a0b7cb
3 changed files with 77 additions and 73 deletions

View File

@@ -1,10 +1,7 @@
use std::{
fmt::Display,
io::{self, BufRead, BufReader, Write},
net::ToSocketAddrs,
string::String,
time::Duration,
};
use std::fmt::Display;
use std::io::{self, BufRead, BufReader, Write};
use std::net::ToSocketAddrs;
use std::time::Duration;
#[cfg(feature = "log")]
use log::debug;
@@ -61,7 +58,8 @@ impl SmtpConnection {
hello_name: &ClientId,
tls_parameters: Option<&TlsParameters>,
) -> Result<SmtpConnection, Error> {
let stream = BufReader::new(NetworkStream::connect(server, timeout, tls_parameters)?);
let stream = NetworkStream::connect(server, timeout, tls_parameters)?;
let stream = BufReader::new(stream);
let mut conn = SmtpConnection {
stream,
panic: false,
@@ -87,7 +85,7 @@ impl SmtpConnection {
mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime));
}
try_smtp!(
self.command(Mail::new(envelope.from().cloned(), mail_options,)),
self.command(Mail::new(envelope.from().cloned(), mail_options)),
self
);
@@ -187,14 +185,12 @@ impl SmtpConnection {
mechanisms: &[Mechanism],
credentials: &Credentials,
) -> Result<Response, Error> {
let mechanism = match self.server_info.get_auth_mechanism(mechanisms) {
Some(m) => m,
None => {
return Err(Error::Client(
"No compatible authentication mechanism was found",
))
}
};
let mechanism = self
.server_info
.get_auth_mechanism(mechanisms)
.ok_or(Error::Client(
"No compatible authentication mechanism was found",
))?;
// Limit challenges to avoid blocking
let mut challenges = 10;
@@ -241,10 +237,7 @@ impl SmtpConnection {
self.stream.get_mut().flush()?;
#[cfg(feature = "log")]
debug!(
"Wrote: {}",
escape_crlf(String::from_utf8_lossy(string).as_ref())
);
debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string)));
Ok(())
}

View File

@@ -35,7 +35,7 @@ enum InnerNetworkStream {
}
impl NetworkStream {
pub(self) fn new(inner: InnerNetworkStream) -> Self {
fn new(inner: InnerNetworkStream) -> Self {
NetworkStream { inner }
}
@@ -100,41 +100,57 @@ impl NetworkStream {
Ok(stream)
}
#[allow(unused_variables)]
pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> {
match self.inner {
InnerNetworkStream::Tcp(ref mut stream) => {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
match &tls_parameters.connector {
#[cfg(feature = "native-tls")]
InnerTlsParameters::NativeTls(connector) => {
let stream = connector
.connect(tls_parameters.domain(), stream.try_clone().unwrap())
.map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?;
*self = NetworkStream::new(InnerNetworkStream::NativeTls(stream));
}
#[cfg(feature = "rustls-tls")]
InnerTlsParameters::RustlsTls(connector) => {
use webpki::DNSNameRef;
let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?;
let stream = StreamOwned::new(
ClientSession::new(&Arc::new(connector.clone()), domain),
stream.try_clone().unwrap(),
);
*self = NetworkStream::new(InnerNetworkStream::RustlsTls(Box::new(stream)));
}
};
match &self.inner {
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
InnerNetworkStream::Tcp(_) => {
let _ = tls_parameters;
panic!("Trying to upgrade an NetworkStream without having enabled either the native-tls or the rustls-tls feature");
}
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(_) => (),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(_) => (),
InnerNetworkStream::Mock(_) => (),
};
Ok(())
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
InnerNetworkStream::Tcp(_) => {
// get owned TcpStream
let tcp_stream =
std::mem::replace(&mut self.inner, InnerNetworkStream::Mock(MockStream::new()));
let tcp_stream = match tcp_stream {
InnerNetworkStream::Tcp(tcp_stream) => tcp_stream,
_ => unreachable!(),
};
self.inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?;
Ok(())
}
_ => Ok(()),
}
}
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
fn upgrade_tls_impl(
tcp_stream: TcpStream,
tls_parameters: &TlsParameters,
) -> Result<InnerNetworkStream, Error> {
Ok(match &tls_parameters.connector {
#[cfg(feature = "native-tls")]
InnerTlsParameters::NativeTls(connector) => {
let stream = connector
.connect(tls_parameters.domain(), tcp_stream)
.map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?;
InnerNetworkStream::NativeTls(stream)
}
#[cfg(feature = "rustls-tls")]
InnerTlsParameters::RustlsTls(connector) => {
use webpki::DNSNameRef;
let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?;
let stream = StreamOwned::new(
ClientSession::new(&Arc::new(connector.clone()), domain),
tcp_stream,
);
InnerNetworkStream::RustlsTls(Box::new(stream))
}
})
}
pub fn is_encrypted(&self) -> bool {

View File

@@ -3,10 +3,9 @@ use std::time::Duration;
#[cfg(feature = "r2d2")]
use r2d2::{Builder, Pool};
use super::{
ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo, Tls,
TlsParameters, SUBMISSIONS_PORT,
};
use super::{ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo};
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use super::{Tls, TlsParameters, SUBMISSIONS_PORT};
use crate::{Envelope, Transport};
#[allow(missing_debug_implementations)]
@@ -156,40 +155,36 @@ impl SmtpClient {
///
/// Handles encryption and authentication
pub fn connection(&self) -> Result<SmtpConnection, Error> {
#[allow(clippy::match_single_binding)]
let tls_parameters = match self.info.tls {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Tls::Wrapper(ref tls_parameters) => Some(tls_parameters),
_ => None,
};
let mut conn = SmtpConnection::connect::<(&str, u16)>(
(self.info.server.as_ref(), self.info.port),
self.info.timeout,
&self.info.hello_name,
#[allow(clippy::match_single_binding)]
match self.info.tls {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Tls::Wrapper(ref tls_parameters) => Some(tls_parameters),
_ => None,
},
tls_parameters,
)?;
#[allow(clippy::match_single_binding)]
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
match self.info.tls {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Tls::Opportunistic(ref tls_parameters) => {
if conn.can_starttls() {
conn.starttls(tls_parameters, &self.info.hello_name)?;
}
}
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Tls::Required(ref tls_parameters) => {
conn.starttls(tls_parameters, &self.info.hello_name)?;
}
_ => (),
}
match &self.info.credentials {
Some(credentials) => {
conn.auth(self.info.authentication.as_slice(), &credentials)?;
}
None => (),
if let Some(credentials) = &self.info.credentials {
conn.auth(&self.info.authentication, &credentials)?;
}
Ok(conn)
}
}