refactor: TlsParameters to not expose the inner tls library

Also made it compile with both TLS libraries enabled
This commit is contained in:
Paolo Barbolini
2020-08-06 21:53:21 +02:00
committed by Alexis Mousset
parent d75fb5956b
commit bcbdbecd95
5 changed files with 238 additions and 188 deletions

View File

@@ -49,11 +49,11 @@ pub use crate::transport::file::FileTransport;
#[cfg(feature = "sendmail-transport")]
pub use crate::transport::sendmail::SendmailTransport;
#[cfg(feature = "smtp-transport")]
pub use crate::transport::smtp::client::net::TlsParameters;
pub use crate::transport::smtp::client::TlsParameters;
#[cfg(all(feature = "smtp-transport", feature = "connection-pool"))]
pub use crate::transport::smtp::r2d2::SmtpConnectionManager;
#[cfg(feature = "smtp-transport")]
pub use crate::transport::smtp::{SmtpTransport, Tls};
pub use crate::transport::smtp::{client::Tls, SmtpTransport};
pub use crate::{address::Address, transport::stub::StubTransport};
#[cfg(any(feature = "async-std1", feature = "tokio02"))]
use async_trait::async_trait;

View File

@@ -1,18 +1,5 @@
//! SMTP client
use crate::{
transport::smtp::{
authentication::{Credentials, Mechanism},
client::net::{NetworkStream, TlsParameters},
commands::*,
error::Error,
extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo},
response::{parse_response, Response},
},
Envelope,
};
#[cfg(feature = "log")]
use log::debug;
#[cfg(feature = "serde")]
use std::fmt::Debug;
use std::{
@@ -23,8 +10,29 @@ use std::{
time::Duration,
};
pub mod mock;
pub mod net;
#[cfg(feature = "log")]
use log::debug;
use crate::{
transport::smtp::{
authentication::{Credentials, Mechanism},
client::net::NetworkStream,
commands::*,
error::Error,
extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo},
response::{parse_response, Response},
},
Envelope,
};
pub use self::mock::MockStream;
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
pub(super) use self::tls::InnerTlsParameters;
pub use self::tls::{Tls, TlsParameters};
mod mock;
mod net;
mod tls;
/// The codec used for transparency
#[derive(Default, Clone, Copy, Debug)]

View File

@@ -1,75 +1,57 @@
//! A trait to represent a stream
use crate::transport::smtp::{client::mock::MockStream, error::Error};
#[cfg(feature = "native-tls")]
use native_tls::{TlsConnector, TlsStream};
#[cfg(feature = "rustls-tls")]
use rustls::{ClientConfig, ClientSession};
#[cfg(feature = "native-tls")]
use std::io::ErrorKind;
use std::io::{self, Read, Write};
use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs};
#[cfg(feature = "rustls-tls")]
use std::sync::Arc;
use std::{
io::{self, Read, Write},
net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs},
time::Duration,
};
use std::time::Duration;
/// Parameters to use for secure clients
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct TlsParameters {
/// A connector from `native-tls`
#[cfg(feature = "native-tls")]
connector: TlsConnector,
/// A client from `rustls`
#[cfg(feature = "rustls-tls")]
// TODO use the same in all transports of the client
connector: Box<ClientConfig>,
/// The domain name which is expected in the TLS certificate from the server
domain: String,
}
#[cfg(feature = "native-tls")]
use native_tls::TlsStream;
impl TlsParameters {
/// Creates a `TlsParameters`
#[cfg(feature = "native-tls")]
pub fn new(domain: String, connector: TlsConnector) -> Self {
Self { connector, domain }
}
#[cfg(feature = "rustls-tls")]
use rustls::{ClientSession, StreamOwned};
/// Creates a `TlsParameters`
#[cfg(feature = "rustls-tls")]
pub fn new(domain: String, connector: ClientConfig) -> Self {
Self {
connector: Box::new(connector),
domain,
}
}
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use super::InnerTlsParameters;
use super::{MockStream, TlsParameters};
use crate::transport::smtp::Error;
/// A network stream
pub struct NetworkStream {
inner: InnerNetworkStream,
}
/// Represents the different types of underlying network streams
pub enum NetworkStream {
enum InnerNetworkStream {
/// Plain TCP stream
Tcp(TcpStream),
/// Encrypted TCP stream
#[cfg(feature = "native-tls")]
Tls(Box<TlsStream<TcpStream>>),
NativeTls(TlsStream<TcpStream>),
/// Encrypted TCP stream
#[cfg(feature = "rustls-tls")]
Tls(Box<rustls::StreamOwned<ClientSession, TcpStream>>),
RustlsTls(Box<StreamOwned<ClientSession, TcpStream>>),
/// Mock stream
Mock(MockStream),
}
impl NetworkStream {
pub(self) fn new(inner: InnerNetworkStream) -> Self {
NetworkStream { inner }
}
pub fn new_mock(mock: MockStream) -> Self {
Self::new(InnerNetworkStream::Mock(mock))
}
/// Returns peer's address
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
match *self {
NetworkStream::Tcp(ref s) => s.peer_addr(),
match self.inner {
InnerNetworkStream::Tcp(ref s) => s.peer_addr(),
#[cfg(feature = "native-tls")]
NetworkStream::Tls(ref s) => s.get_ref().peer_addr(),
InnerNetworkStream::NativeTls(ref s) => s.get_ref().peer_addr(),
#[cfg(feature = "rustls-tls")]
NetworkStream::Tls(ref s) => s.get_ref().peer_addr(),
NetworkStream::Mock(_) => Ok(SocketAddr::V4(SocketAddrV4::new(
InnerNetworkStream::RustlsTls(ref s) => s.get_ref().peer_addr(),
InnerNetworkStream::Mock(_) => Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
80,
))),
@@ -78,13 +60,13 @@ impl NetworkStream {
/// Shutdowns the connection
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref s) => s.shutdown(how),
match self.inner {
InnerNetworkStream::Tcp(ref s) => s.shutdown(how),
#[cfg(feature = "native-tls")]
NetworkStream::Tls(ref s) => s.get_ref().shutdown(how),
InnerNetworkStream::NativeTls(ref s) => s.get_ref().shutdown(how),
#[cfg(feature = "rustls-tls")]
NetworkStream::Tls(ref s) => s.get_ref().shutdown(how),
NetworkStream::Mock(_) => Ok(()),
InnerNetworkStream::RustlsTls(ref s) => s.get_ref().shutdown(how),
InnerNetworkStream::Mock(_) => Ok(()),
}
}
@@ -111,119 +93,127 @@ impl NetworkStream {
None => TcpStream::connect(server)?,
};
match tls_parameters {
#[cfg(feature = "native-tls")]
Some(context) => context
.connector
.connect(context.domain.as_ref(), tcp_stream)
.map(|tls| NetworkStream::Tls(Box::new(tls)))
.map_err(|e| Error::Io(io::Error::new(ErrorKind::Other, e))),
#[cfg(feature = "rustls-tls")]
Some(context) => {
let domain = webpki::DNSNameRef::try_from_ascii_str(&context.domain)?;
Ok(NetworkStream::Tls(Box::new(rustls::StreamOwned::new(
ClientSession::new(&Arc::new(*context.connector.clone()), domain),
tcp_stream,
))))
}
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
Some(_) => panic!("TLS configuration without support"),
None => Ok(NetworkStream::Tcp(tcp_stream)),
let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream));
if let Some(tls_parameters) = tls_parameters {
stream.upgrade_tls(tls_parameters)?;
}
Ok(stream)
}
#[allow(unused_variables, unreachable_code)]
#[allow(unused_variables)]
pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> {
*self = match *self {
#[cfg(feature = "native-tls")]
NetworkStream::Tcp(ref mut stream) => match tls_parameters
.connector
.connect(tls_parameters.domain.as_ref(), stream.try_clone().unwrap())
{
Ok(tls_stream) => NetworkStream::Tls(Box::new(tls_stream)),
Err(err) => return Err(Error::Io(io::Error::new(ErrorKind::Other, err))),
},
#[cfg(feature = "rustls-tls")]
NetworkStream::Tcp(ref mut stream) => {
let domain = webpki::DNSNameRef::try_from_ascii_str(&tls_parameters.domain)?;
match self.inner {
InnerNetworkStream::Tcp(ref mut stream) => {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
match &tls_parameters.connector {
#[cfg(feature = "native-tls")]
InnerTlsParameters::NativeTls(connector) => {
let stream = connector
.connect(tls_parameters.domain(), stream.try_clone().unwrap())
.map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?;
*self = NetworkStream::new(InnerNetworkStream::NativeTls(stream));
}
#[cfg(feature = "rustls-tls")]
InnerTlsParameters::RustlsTls(connector) => {
use webpki::DNSNameRef;
NetworkStream::Tls(Box::new(rustls::StreamOwned::new(
ClientSession::new(&Arc::new(*tls_parameters.connector.clone()), domain),
stream.try_clone().unwrap(),
)))
let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?;
let stream = StreamOwned::new(
ClientSession::new(&Arc::new(connector.clone()), domain),
stream.try_clone().unwrap(),
);
*self = NetworkStream::new(InnerNetworkStream::RustlsTls(Box::new(stream)));
}
};
}
#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
NetworkStream::Tcp(_) => panic!("STARTTLS without TLS support"),
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
NetworkStream::Tls(_) => return Ok(()),
NetworkStream::Mock(_) => return Ok(()),
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(_) => (),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(_) => (),
InnerNetworkStream::Mock(_) => (),
};
Ok(())
}
pub fn is_encrypted(&self) -> bool {
match *self {
NetworkStream::Tcp(_) | NetworkStream::Mock(_) => false,
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
NetworkStream::Tls(_) => true,
match self.inner {
InnerNetworkStream::Tcp(_) | InnerNetworkStream::Mock(_) => false,
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(_) => true,
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(_) => true,
}
}
pub fn set_read_timeout(&mut self, duration: Option<Duration>) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration),
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
NetworkStream::Tls(ref mut stream) => stream.get_ref().set_read_timeout(duration),
NetworkStream::Mock(_) => Ok(()),
match self.inner {
InnerNetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration),
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(ref mut stream) => {
stream.get_ref().set_read_timeout(duration)
}
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref mut stream) => {
stream.get_ref().set_read_timeout(duration)
}
InnerNetworkStream::Mock(_) => Ok(()),
}
}
/// Set write timeout for IO calls
pub fn set_write_timeout(&mut self, duration: Option<Duration>) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration),
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
NetworkStream::Tls(ref mut stream) => stream.get_ref().set_write_timeout(duration),
NetworkStream::Mock(_) => Ok(()),
match self.inner {
InnerNetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration),
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(ref mut stream) => {
stream.get_ref().set_write_timeout(duration)
}
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref mut stream) => {
stream.get_ref().set_write_timeout(duration)
}
InnerNetworkStream::Mock(_) => Ok(()),
}
}
}
impl Read for NetworkStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
NetworkStream::Tcp(ref mut s) => s.read(buf),
match self.inner {
InnerNetworkStream::Tcp(ref mut s) => s.read(buf),
#[cfg(feature = "native-tls")]
NetworkStream::Tls(ref mut s) => s.read(buf),
InnerNetworkStream::NativeTls(ref mut s) => s.read(buf),
#[cfg(feature = "rustls-tls")]
NetworkStream::Tls(ref mut s) => s.read(buf),
NetworkStream::Mock(ref mut s) => s.read(buf),
InnerNetworkStream::RustlsTls(ref mut s) => s.read(buf),
InnerNetworkStream::Mock(ref mut s) => s.read(buf),
}
}
}
impl Write for NetworkStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
NetworkStream::Tcp(ref mut s) => s.write(buf),
match self.inner {
InnerNetworkStream::Tcp(ref mut s) => s.write(buf),
#[cfg(feature = "native-tls")]
NetworkStream::Tls(ref mut s) => s.write(buf),
InnerNetworkStream::NativeTls(ref mut s) => s.write(buf),
#[cfg(feature = "rustls-tls")]
NetworkStream::Tls(ref mut s) => s.write(buf),
NetworkStream::Mock(ref mut s) => s.write(buf),
InnerNetworkStream::RustlsTls(ref mut s) => s.write(buf),
InnerNetworkStream::Mock(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref mut s) => s.flush(),
match self.inner {
InnerNetworkStream::Tcp(ref mut s) => s.flush(),
#[cfg(feature = "native-tls")]
NetworkStream::Tls(ref mut s) => s.flush(),
InnerNetworkStream::NativeTls(ref mut s) => s.flush(),
#[cfg(feature = "rustls-tls")]
NetworkStream::Tls(ref mut s) => s.flush(),
NetworkStream::Mock(ref mut s) => s.flush(),
InnerNetworkStream::RustlsTls(ref mut s) => s.flush(),
InnerNetworkStream::Mock(ref mut s) => s.flush(),
}
}
}

View File

@@ -0,0 +1,89 @@
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use crate::transport::smtp::error::Error;
#[cfg(feature = "native-tls")]
use native_tls::{Protocol, TlsConnector};
#[cfg(feature = "rustls-tls")]
use rustls::ClientConfig;
/// Accepted protocols by default.
/// This removes TLS 1.0 and 1.1 compared to tls-native defaults.
// This is also rustls' default behavior
#[cfg(feature = "native-tls")]
const DEFAULT_TLS_MIN_PROTOCOL: Protocol = Protocol::Tlsv12;
/// How to apply TLS to a client connection
#[derive(Clone)]
#[allow(missing_copy_implementations)]
pub enum Tls {
/// Insecure connection only (for testing purposes)
None,
/// Start with insecure connection and use `STARTTLS` when available
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Opportunistic(TlsParameters),
/// Start with insecure connection and require `STARTTLS`
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Required(TlsParameters),
/// Use TLS wrapped connection
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Wrapper(TlsParameters),
}
/// Parameters to use for secure clients
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct TlsParameters {
pub(crate) connector: InnerTlsParameters,
/// The domain name which is expected in the TLS certificate from the server
domain: String,
}
#[derive(Clone)]
pub enum InnerTlsParameters {
#[cfg(feature = "native-tls")]
NativeTls(TlsConnector),
#[cfg(feature = "rustls-tls")]
RustlsTls(ClientConfig),
}
impl TlsParameters {
/// Creates a new `TlsParameters` using native-tls or rustls
/// depending on which one is available
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
pub fn new(domain: String) -> Result<Self, Error> {
#[cfg(feature = "native-tls")]
return Self::new_native(domain);
#[cfg(not(feature = "native-tls"))]
return Self::new_rustls(domain);
}
/// Creates a new `TlsParameters` using native-tls
#[cfg(feature = "native-tls")]
pub fn new_native(domain: String) -> Result<Self, Error> {
let mut tls_builder = TlsConnector::builder();
tls_builder.min_protocol_version(Some(DEFAULT_TLS_MIN_PROTOCOL));
let connector = tls_builder.build()?;
Ok(Self {
connector: InnerTlsParameters::NativeTls(connector),
domain,
})
}
/// Creates a new `TlsParameters` using rustls
#[cfg(feature = "rustls-tls")]
pub fn new_rustls(domain: String) -> Result<Self, Error> {
use webpki_roots::TLS_SERVER_ROOTS;
let mut tls = ClientConfig::new();
tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS);
Ok(Self {
connector: InnerTlsParameters::RustlsTls(tls),
domain,
})
}
pub fn domain(&self) -> &str {
&self.domain
}
}

View File

@@ -176,8 +176,13 @@
//! # }
//! ```
use std::time::Duration;
#[cfg(feature = "r2d2")]
use r2d2::{Builder, Pool};
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use crate::transport::smtp::client::net::TlsParameters;
use crate::transport::smtp::client::TlsParameters;
use crate::{
transport::smtp::{
authentication::{Credentials, Mechanism, DEFAULT_MECHANISMS},
@@ -188,15 +193,8 @@ use crate::{
},
Envelope, Transport,
};
#[cfg(feature = "native-tls")]
use native_tls::{Protocol, TlsConnector};
#[cfg(feature = "r2d2")]
use r2d2::{Builder, Pool};
#[cfg(feature = "rustls-tls")]
use rustls::ClientConfig;
use std::time::Duration;
#[cfg(feature = "rustls-tls")]
use webpki_roots::TLS_SERVER_ROOTS;
use client::Tls;
pub mod authentication;
pub mod client;
@@ -224,29 +222,6 @@ pub const SUBMISSIONS_PORT: u16 = 465;
/// Default timeout
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
/// Accepted protocols by default.
/// This removes TLS 1.0 and 1.1 compared to tls-native defaults.
// This is also rustls' default behavior
#[cfg(feature = "native-tls")]
const DEFAULT_TLS_MIN_PROTOCOL: Protocol = Protocol::Tlsv12;
/// How to apply TLS to a client connection
#[derive(Clone)]
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub enum Tls {
/// Insecure connection only (for testing purposes)
None,
/// Start with insecure connection and use `STARTTLS` when available
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Opportunistic(TlsParameters),
/// Start with insecure connection and require `STARTTLS`
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Required(TlsParameters),
/// Use TLS wrapped connection
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Wrapper(TlsParameters),
}
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct SmtpTransport {
@@ -282,19 +257,7 @@ impl SmtpTransport {
/// to validate TLS certificates.
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
pub fn relay(relay: &str) -> Result<SmtpTransportBuilder, Error> {
#[cfg(feature = "native-tls")]
let mut tls_builder = TlsConnector::builder();
#[cfg(feature = "native-tls")]
tls_builder.min_protocol_version(Some(DEFAULT_TLS_MIN_PROTOCOL));
#[cfg(feature = "native-tls")]
let tls_parameters = TlsParameters::new(relay.to_string(), tls_builder.build()?);
#[cfg(feature = "rustls-tls")]
let mut tls = ClientConfig::new();
#[cfg(feature = "rustls-tls")]
tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS);
#[cfg(feature = "rustls-tls")]
let tls_parameters = TlsParameters::new(relay.to_string(), tls);
let tls_parameters = TlsParameters::new(relay.into())?;
Ok(Self::builder_dangerous(relay)
.port(SUBMISSIONS_PORT)