diff --git a/src/lib.rs b/src/lib.rs index 35c947b..e1de1c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,11 +49,11 @@ pub use crate::transport::file::FileTransport; #[cfg(feature = "sendmail-transport")] pub use crate::transport::sendmail::SendmailTransport; #[cfg(feature = "smtp-transport")] -pub use crate::transport::smtp::client::net::TlsParameters; +pub use crate::transport::smtp::client::TlsParameters; #[cfg(all(feature = "smtp-transport", feature = "connection-pool"))] pub use crate::transport::smtp::r2d2::SmtpConnectionManager; #[cfg(feature = "smtp-transport")] -pub use crate::transport::smtp::{SmtpTransport, Tls}; +pub use crate::transport::smtp::{client::Tls, SmtpTransport}; pub use crate::{address::Address, transport::stub::StubTransport}; #[cfg(any(feature = "async-std1", feature = "tokio02"))] use async_trait::async_trait; diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index 96fc5b3..baf2e18 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -1,18 +1,5 @@ //! SMTP client -use crate::{ - transport::smtp::{ - authentication::{Credentials, Mechanism}, - client::net::{NetworkStream, TlsParameters}, - commands::*, - error::Error, - extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, - response::{parse_response, Response}, - }, - Envelope, -}; -#[cfg(feature = "log")] -use log::debug; #[cfg(feature = "serde")] use std::fmt::Debug; use std::{ @@ -23,8 +10,29 @@ use std::{ time::Duration, }; -pub mod mock; -pub mod net; +#[cfg(feature = "log")] +use log::debug; + +use crate::{ + transport::smtp::{ + authentication::{Credentials, Mechanism}, + client::net::NetworkStream, + commands::*, + error::Error, + extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, + response::{parse_response, Response}, + }, + Envelope, +}; + +pub use self::mock::MockStream; +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +pub(super) use self::tls::InnerTlsParameters; +pub use self::tls::{Tls, TlsParameters}; + +mod mock; +mod net; +mod tls; /// The codec used for transparency #[derive(Default, Clone, Copy, Debug)] diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index c112d29..8058e8d 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -1,75 +1,57 @@ -//! A trait to represent a stream - -use crate::transport::smtp::{client::mock::MockStream, error::Error}; -#[cfg(feature = "native-tls")] -use native_tls::{TlsConnector, TlsStream}; -#[cfg(feature = "rustls-tls")] -use rustls::{ClientConfig, ClientSession}; -#[cfg(feature = "native-tls")] -use std::io::ErrorKind; +use std::io::{self, Read, Write}; +use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}; #[cfg(feature = "rustls-tls")] use std::sync::Arc; -use std::{ - io::{self, Read, Write}, - net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}, - time::Duration, -}; +use std::time::Duration; -/// Parameters to use for secure clients -#[derive(Clone)] -#[allow(missing_debug_implementations)] -pub struct TlsParameters { - /// A connector from `native-tls` - #[cfg(feature = "native-tls")] - connector: TlsConnector, - /// A client from `rustls` - #[cfg(feature = "rustls-tls")] - // TODO use the same in all transports of the client - connector: Box, - /// The domain name which is expected in the TLS certificate from the server - domain: String, -} +#[cfg(feature = "native-tls")] +use native_tls::TlsStream; -impl TlsParameters { - /// Creates a `TlsParameters` - #[cfg(feature = "native-tls")] - pub fn new(domain: String, connector: TlsConnector) -> Self { - Self { connector, domain } - } +#[cfg(feature = "rustls-tls")] +use rustls::{ClientSession, StreamOwned}; - /// Creates a `TlsParameters` - #[cfg(feature = "rustls-tls")] - pub fn new(domain: String, connector: ClientConfig) -> Self { - Self { - connector: Box::new(connector), - domain, - } - } +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +use super::InnerTlsParameters; +use super::{MockStream, TlsParameters}; +use crate::transport::smtp::Error; + +/// A network stream +pub struct NetworkStream { + inner: InnerNetworkStream, } /// Represents the different types of underlying network streams -pub enum NetworkStream { +enum InnerNetworkStream { /// Plain TCP stream Tcp(TcpStream), /// Encrypted TCP stream #[cfg(feature = "native-tls")] - Tls(Box>), + NativeTls(TlsStream), + /// Encrypted TCP stream #[cfg(feature = "rustls-tls")] - Tls(Box>), + RustlsTls(Box>), /// Mock stream Mock(MockStream), } impl NetworkStream { + pub(self) fn new(inner: InnerNetworkStream) -> Self { + NetworkStream { inner } + } + + pub fn new_mock(mock: MockStream) -> Self { + Self::new(InnerNetworkStream::Mock(mock)) + } + /// Returns peer's address pub fn peer_addr(&self) -> io::Result { - match *self { - NetworkStream::Tcp(ref s) => s.peer_addr(), + match self.inner { + InnerNetworkStream::Tcp(ref s) => s.peer_addr(), #[cfg(feature = "native-tls")] - NetworkStream::Tls(ref s) => s.get_ref().peer_addr(), + InnerNetworkStream::NativeTls(ref s) => s.get_ref().peer_addr(), #[cfg(feature = "rustls-tls")] - NetworkStream::Tls(ref s) => s.get_ref().peer_addr(), - NetworkStream::Mock(_) => Ok(SocketAddr::V4(SocketAddrV4::new( + InnerNetworkStream::RustlsTls(ref s) => s.get_ref().peer_addr(), + InnerNetworkStream::Mock(_) => Ok(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::new(127, 0, 0, 1), 80, ))), @@ -78,13 +60,13 @@ impl NetworkStream { /// Shutdowns the connection pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - match *self { - NetworkStream::Tcp(ref s) => s.shutdown(how), + match self.inner { + InnerNetworkStream::Tcp(ref s) => s.shutdown(how), #[cfg(feature = "native-tls")] - NetworkStream::Tls(ref s) => s.get_ref().shutdown(how), + InnerNetworkStream::NativeTls(ref s) => s.get_ref().shutdown(how), #[cfg(feature = "rustls-tls")] - NetworkStream::Tls(ref s) => s.get_ref().shutdown(how), - NetworkStream::Mock(_) => Ok(()), + InnerNetworkStream::RustlsTls(ref s) => s.get_ref().shutdown(how), + InnerNetworkStream::Mock(_) => Ok(()), } } @@ -111,119 +93,127 @@ impl NetworkStream { None => TcpStream::connect(server)?, }; - match tls_parameters { - #[cfg(feature = "native-tls")] - Some(context) => context - .connector - .connect(context.domain.as_ref(), tcp_stream) - .map(|tls| NetworkStream::Tls(Box::new(tls))) - .map_err(|e| Error::Io(io::Error::new(ErrorKind::Other, e))), - #[cfg(feature = "rustls-tls")] - Some(context) => { - let domain = webpki::DNSNameRef::try_from_ascii_str(&context.domain)?; - - Ok(NetworkStream::Tls(Box::new(rustls::StreamOwned::new( - ClientSession::new(&Arc::new(*context.connector.clone()), domain), - tcp_stream, - )))) - } - #[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] - Some(_) => panic!("TLS configuration without support"), - None => Ok(NetworkStream::Tcp(tcp_stream)), + let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream)); + if let Some(tls_parameters) = tls_parameters { + stream.upgrade_tls(tls_parameters)?; } + Ok(stream) } - #[allow(unused_variables, unreachable_code)] + #[allow(unused_variables)] pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> { - *self = match *self { - #[cfg(feature = "native-tls")] - NetworkStream::Tcp(ref mut stream) => match tls_parameters - .connector - .connect(tls_parameters.domain.as_ref(), stream.try_clone().unwrap()) - { - Ok(tls_stream) => NetworkStream::Tls(Box::new(tls_stream)), - Err(err) => return Err(Error::Io(io::Error::new(ErrorKind::Other, err))), - }, - #[cfg(feature = "rustls-tls")] - NetworkStream::Tcp(ref mut stream) => { - let domain = webpki::DNSNameRef::try_from_ascii_str(&tls_parameters.domain)?; + match self.inner { + InnerNetworkStream::Tcp(ref mut stream) => { + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + match &tls_parameters.connector { + #[cfg(feature = "native-tls")] + InnerTlsParameters::NativeTls(connector) => { + let stream = connector + .connect(tls_parameters.domain(), stream.try_clone().unwrap()) + .map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?; + *self = NetworkStream::new(InnerNetworkStream::NativeTls(stream)); + } + #[cfg(feature = "rustls-tls")] + InnerTlsParameters::RustlsTls(connector) => { + use webpki::DNSNameRef; - NetworkStream::Tls(Box::new(rustls::StreamOwned::new( - ClientSession::new(&Arc::new(*tls_parameters.connector.clone()), domain), - stream.try_clone().unwrap(), - ))) + let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?; + let stream = StreamOwned::new( + ClientSession::new(&Arc::new(connector.clone()), domain), + stream.try_clone().unwrap(), + ); + + *self = NetworkStream::new(InnerNetworkStream::RustlsTls(Box::new(stream))); + } + }; } - #[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] - NetworkStream::Tcp(_) => panic!("STARTTLS without TLS support"), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - NetworkStream::Tls(_) => return Ok(()), - NetworkStream::Mock(_) => return Ok(()), + #[cfg(feature = "native-tls")] + InnerNetworkStream::NativeTls(_) => (), + #[cfg(feature = "rustls-tls")] + InnerNetworkStream::RustlsTls(_) => (), + InnerNetworkStream::Mock(_) => (), }; Ok(()) } pub fn is_encrypted(&self) -> bool { - match *self { - NetworkStream::Tcp(_) | NetworkStream::Mock(_) => false, - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - NetworkStream::Tls(_) => true, + match self.inner { + InnerNetworkStream::Tcp(_) | InnerNetworkStream::Mock(_) => false, + #[cfg(feature = "native-tls")] + InnerNetworkStream::NativeTls(_) => true, + #[cfg(feature = "rustls-tls")] + InnerNetworkStream::RustlsTls(_) => true, } } pub fn set_read_timeout(&mut self, duration: Option) -> io::Result<()> { - match *self { - NetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - NetworkStream::Tls(ref mut stream) => stream.get_ref().set_read_timeout(duration), - NetworkStream::Mock(_) => Ok(()), + match self.inner { + InnerNetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration), + #[cfg(feature = "native-tls")] + InnerNetworkStream::NativeTls(ref mut stream) => { + stream.get_ref().set_read_timeout(duration) + } + #[cfg(feature = "rustls-tls")] + InnerNetworkStream::RustlsTls(ref mut stream) => { + stream.get_ref().set_read_timeout(duration) + } + InnerNetworkStream::Mock(_) => Ok(()), } } /// Set write timeout for IO calls pub fn set_write_timeout(&mut self, duration: Option) -> io::Result<()> { - match *self { - NetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - NetworkStream::Tls(ref mut stream) => stream.get_ref().set_write_timeout(duration), - NetworkStream::Mock(_) => Ok(()), + match self.inner { + InnerNetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration), + + #[cfg(feature = "native-tls")] + InnerNetworkStream::NativeTls(ref mut stream) => { + stream.get_ref().set_write_timeout(duration) + } + #[cfg(feature = "rustls-tls")] + InnerNetworkStream::RustlsTls(ref mut stream) => { + stream.get_ref().set_write_timeout(duration) + } + + InnerNetworkStream::Mock(_) => Ok(()), } } } impl Read for NetworkStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - NetworkStream::Tcp(ref mut s) => s.read(buf), + match self.inner { + InnerNetworkStream::Tcp(ref mut s) => s.read(buf), #[cfg(feature = "native-tls")] - NetworkStream::Tls(ref mut s) => s.read(buf), + InnerNetworkStream::NativeTls(ref mut s) => s.read(buf), #[cfg(feature = "rustls-tls")] - NetworkStream::Tls(ref mut s) => s.read(buf), - NetworkStream::Mock(ref mut s) => s.read(buf), + InnerNetworkStream::RustlsTls(ref mut s) => s.read(buf), + InnerNetworkStream::Mock(ref mut s) => s.read(buf), } } } impl Write for NetworkStream { fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { - NetworkStream::Tcp(ref mut s) => s.write(buf), + match self.inner { + InnerNetworkStream::Tcp(ref mut s) => s.write(buf), #[cfg(feature = "native-tls")] - NetworkStream::Tls(ref mut s) => s.write(buf), + InnerNetworkStream::NativeTls(ref mut s) => s.write(buf), #[cfg(feature = "rustls-tls")] - NetworkStream::Tls(ref mut s) => s.write(buf), - NetworkStream::Mock(ref mut s) => s.write(buf), + InnerNetworkStream::RustlsTls(ref mut s) => s.write(buf), + InnerNetworkStream::Mock(ref mut s) => s.write(buf), } } fn flush(&mut self) -> io::Result<()> { - match *self { - NetworkStream::Tcp(ref mut s) => s.flush(), + match self.inner { + InnerNetworkStream::Tcp(ref mut s) => s.flush(), #[cfg(feature = "native-tls")] - NetworkStream::Tls(ref mut s) => s.flush(), + InnerNetworkStream::NativeTls(ref mut s) => s.flush(), #[cfg(feature = "rustls-tls")] - NetworkStream::Tls(ref mut s) => s.flush(), - NetworkStream::Mock(ref mut s) => s.flush(), + InnerNetworkStream::RustlsTls(ref mut s) => s.flush(), + InnerNetworkStream::Mock(ref mut s) => s.flush(), } } } diff --git a/src/transport/smtp/client/tls.rs b/src/transport/smtp/client/tls.rs new file mode 100644 index 0000000..cd8a3c0 --- /dev/null +++ b/src/transport/smtp/client/tls.rs @@ -0,0 +1,89 @@ +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +use crate::transport::smtp::error::Error; + +#[cfg(feature = "native-tls")] +use native_tls::{Protocol, TlsConnector}; +#[cfg(feature = "rustls-tls")] +use rustls::ClientConfig; + +/// Accepted protocols by default. +/// This removes TLS 1.0 and 1.1 compared to tls-native defaults. +// This is also rustls' default behavior +#[cfg(feature = "native-tls")] +const DEFAULT_TLS_MIN_PROTOCOL: Protocol = Protocol::Tlsv12; + +/// How to apply TLS to a client connection +#[derive(Clone)] +#[allow(missing_copy_implementations)] +pub enum Tls { + /// Insecure connection only (for testing purposes) + None, + /// Start with insecure connection and use `STARTTLS` when available + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Opportunistic(TlsParameters), + /// Start with insecure connection and require `STARTTLS` + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Required(TlsParameters), + /// Use TLS wrapped connection + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Wrapper(TlsParameters), +} + +/// Parameters to use for secure clients +#[derive(Clone)] +#[allow(missing_debug_implementations)] +pub struct TlsParameters { + pub(crate) connector: InnerTlsParameters, + /// The domain name which is expected in the TLS certificate from the server + domain: String, +} + +#[derive(Clone)] +pub enum InnerTlsParameters { + #[cfg(feature = "native-tls")] + NativeTls(TlsConnector), + #[cfg(feature = "rustls-tls")] + RustlsTls(ClientConfig), +} + +impl TlsParameters { + /// Creates a new `TlsParameters` using native-tls or rustls + /// depending on which one is available + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + pub fn new(domain: String) -> Result { + #[cfg(feature = "native-tls")] + return Self::new_native(domain); + + #[cfg(not(feature = "native-tls"))] + return Self::new_rustls(domain); + } + + /// Creates a new `TlsParameters` using native-tls + #[cfg(feature = "native-tls")] + pub fn new_native(domain: String) -> Result { + let mut tls_builder = TlsConnector::builder(); + tls_builder.min_protocol_version(Some(DEFAULT_TLS_MIN_PROTOCOL)); + let connector = tls_builder.build()?; + Ok(Self { + connector: InnerTlsParameters::NativeTls(connector), + domain, + }) + } + + /// Creates a new `TlsParameters` using rustls + #[cfg(feature = "rustls-tls")] + pub fn new_rustls(domain: String) -> Result { + use webpki_roots::TLS_SERVER_ROOTS; + + let mut tls = ClientConfig::new(); + tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS); + Ok(Self { + connector: InnerTlsParameters::RustlsTls(tls), + domain, + }) + } + + pub fn domain(&self) -> &str { + &self.domain + } +} diff --git a/src/transport/smtp/mod.rs b/src/transport/smtp/mod.rs index e684eef..75a80e7 100644 --- a/src/transport/smtp/mod.rs +++ b/src/transport/smtp/mod.rs @@ -176,8 +176,13 @@ //! # } //! ``` +use std::time::Duration; + +#[cfg(feature = "r2d2")] +use r2d2::{Builder, Pool}; + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] -use crate::transport::smtp::client::net::TlsParameters; +use crate::transport::smtp::client::TlsParameters; use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism, DEFAULT_MECHANISMS}, @@ -188,15 +193,8 @@ use crate::{ }, Envelope, Transport, }; -#[cfg(feature = "native-tls")] -use native_tls::{Protocol, TlsConnector}; -#[cfg(feature = "r2d2")] -use r2d2::{Builder, Pool}; -#[cfg(feature = "rustls-tls")] -use rustls::ClientConfig; -use std::time::Duration; -#[cfg(feature = "rustls-tls")] -use webpki_roots::TLS_SERVER_ROOTS; + +use client::Tls; pub mod authentication; pub mod client; @@ -224,29 +222,6 @@ pub const SUBMISSIONS_PORT: u16 = 465; /// Default timeout pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60); -/// Accepted protocols by default. -/// This removes TLS 1.0 and 1.1 compared to tls-native defaults. -// This is also rustls' default behavior -#[cfg(feature = "native-tls")] -const DEFAULT_TLS_MIN_PROTOCOL: Protocol = Protocol::Tlsv12; - -/// How to apply TLS to a client connection -#[derive(Clone)] -#[allow(missing_debug_implementations, missing_copy_implementations)] -pub enum Tls { - /// Insecure connection only (for testing purposes) - None, - /// Start with insecure connection and use `STARTTLS` when available - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Opportunistic(TlsParameters), - /// Start with insecure connection and require `STARTTLS` - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Required(TlsParameters), - /// Use TLS wrapped connection - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - Wrapper(TlsParameters), -} - #[allow(missing_debug_implementations)] #[derive(Clone)] pub struct SmtpTransport { @@ -282,19 +257,7 @@ impl SmtpTransport { /// to validate TLS certificates. #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub fn relay(relay: &str) -> Result { - #[cfg(feature = "native-tls")] - let mut tls_builder = TlsConnector::builder(); - #[cfg(feature = "native-tls")] - tls_builder.min_protocol_version(Some(DEFAULT_TLS_MIN_PROTOCOL)); - #[cfg(feature = "native-tls")] - let tls_parameters = TlsParameters::new(relay.to_string(), tls_builder.build()?); - - #[cfg(feature = "rustls-tls")] - let mut tls = ClientConfig::new(); - #[cfg(feature = "rustls-tls")] - tls.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS); - #[cfg(feature = "rustls-tls")] - let tls_parameters = TlsParameters::new(relay.to_string(), tls); + let tls_parameters = TlsParameters::new(relay.into())?; Ok(Self::builder_dangerous(relay) .port(SUBMISSIONS_PORT)