#[cfg(feature = "rustls-tls")] use std::convert::TryFrom; use std::{ io::{self, Read, Write}, mem, net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}, time::Duration, }; #[cfg(feature = "native-tls")] use native_tls::TlsStream; #[cfg(feature = "rustls-tls")] use rustls::{ClientConnection, ServerName, StreamOwned}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use super::InnerTlsParameters; use super::TlsParameters; use crate::transport::smtp::{error, Error}; /// A network stream pub struct NetworkStream { inner: InnerNetworkStream, } /// Represents the different types of underlying network streams // usually only one TLS backend at a time is going to be enabled, // so clippy::large_enum_variant doesn't make sense here #[allow(clippy::large_enum_variant)] enum InnerNetworkStream { /// Plain TCP stream Tcp(TcpStream), /// Encrypted TCP stream #[cfg(feature = "native-tls")] NativeTls(TlsStream), /// Encrypted TCP stream #[cfg(feature = "rustls-tls")] RustlsTls(StreamOwned), /// Can't be built None, } impl NetworkStream { fn new(inner: InnerNetworkStream) -> Self { if let InnerNetworkStream::None = inner { debug_assert!(false, "InnerNetworkStream::None must never be built"); } NetworkStream { inner } } /// Returns peer's address pub fn peer_addr(&self) -> io::Result { match self.inner { InnerNetworkStream::Tcp(ref s) => s.peer_addr(), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref s) => s.get_ref().peer_addr(), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref s) => s.get_ref().peer_addr(), InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::new(127, 0, 0, 1), 80, ))) } } } /// Shutdowns the connection pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { match self.inner { InnerNetworkStream::Tcp(ref s) => s.shutdown(how), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref s) => s.get_ref().shutdown(how), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref s) => s.get_ref().shutdown(how), InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(()) } } } pub fn connect( server: T, timeout: Option, tls_parameters: Option<&TlsParameters>, ) -> Result { fn try_connect_timeout( server: T, timeout: Duration, ) -> Result { let addrs = server.to_socket_addrs().map_err(error::connection)?; let mut last_err = None; for addr in addrs { match TcpStream::connect_timeout(&addr, timeout) { Ok(stream) => return Ok(stream), Err(err) => last_err = Some(err), } } Err(match last_err { Some(last_err) => error::connection(last_err), None => error::connection("could not resolve to any address"), }) } let tcp_stream = match timeout { Some(t) => try_connect_timeout(server, t)?, None => TcpStream::connect(server).map_err(error::connection)?, }; let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { stream.upgrade_tls(tls_parameters)?; } Ok(stream) } pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> { match &self.inner { #[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] InnerNetworkStream::Tcp(_) => { let _ = tls_parameters; panic!("Trying to upgrade an NetworkStream without having enabled either the native-tls or the rustls-tls feature"); } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] InnerNetworkStream::Tcp(_) => { // get owned TcpStream let tcp_stream = mem::replace(&mut self.inner, InnerNetworkStream::None); let tcp_stream = match tcp_stream { InnerNetworkStream::Tcp(tcp_stream) => tcp_stream, _ => unreachable!(), }; self.inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?; Ok(()) } _ => Ok(()), } } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] fn upgrade_tls_impl( tcp_stream: TcpStream, tls_parameters: &TlsParameters, ) -> Result { Ok(match &tls_parameters.connector { #[cfg(feature = "native-tls")] InnerTlsParameters::NativeTls(connector) => { let stream = connector .connect(tls_parameters.domain(), tcp_stream) .map_err(error::connection)?; InnerNetworkStream::NativeTls(stream) } #[cfg(feature = "rustls-tls")] InnerTlsParameters::RustlsTls(connector) => { let domain = ServerName::try_from(tls_parameters.domain()) .map_err(|_| error::connection("domain isn't a valid DNS name"))?; let connection = ClientConnection::new(connector.clone(), domain).map_err(error::connection)?; let stream = StreamOwned::new(connection, tcp_stream); InnerNetworkStream::RustlsTls(stream) } }) } pub fn is_encrypted(&self) -> bool { match self.inner { InnerNetworkStream::Tcp(_) => false, #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(_) => true, #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(_) => true, InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); false } } } #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub fn peer_certificate(&self) -> Result, Error> { match &self.inner { InnerNetworkStream::Tcp(_) => Err(error::client("Connection is not encrypted")), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(stream) => Ok(stream .peer_certificate() .map_err(error::tls)? .unwrap() .to_der() .map_err(error::tls)?), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(stream) => Ok(stream .conn .peer_certificates() .unwrap() .first() .unwrap() .clone() .0), InnerNetworkStream::None => panic!("InnerNetworkStream::None must never be built"), } } pub fn set_read_timeout(&mut self, duration: Option) -> io::Result<()> { match self.inner { InnerNetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref mut stream) => { stream.get_ref().set_read_timeout(duration) } #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref mut stream) => { stream.get_ref().set_read_timeout(duration) } InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(()) } } } /// Set write timeout for IO calls pub fn set_write_timeout(&mut self, duration: Option) -> io::Result<()> { match self.inner { InnerNetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref mut stream) => { stream.get_ref().set_write_timeout(duration) } #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref mut stream) => { stream.get_ref().set_write_timeout(duration) } InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(()) } } } } impl Read for NetworkStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self.inner { InnerNetworkStream::Tcp(ref mut s) => s.read(buf), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref mut s) => s.read(buf), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref mut s) => s.read(buf), InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(0) } } } } impl Write for NetworkStream { fn write(&mut self, buf: &[u8]) -> io::Result { match self.inner { InnerNetworkStream::Tcp(ref mut s) => s.write(buf), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref mut s) => s.write(buf), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref mut s) => s.write(buf), InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(0) } } } fn flush(&mut self) -> io::Result<()> { match self.inner { InnerNetworkStream::Tcp(ref mut s) => s.flush(), #[cfg(feature = "native-tls")] InnerNetworkStream::NativeTls(ref mut s) => s.flush(), #[cfg(feature = "rustls-tls")] InnerNetworkStream::RustlsTls(ref mut s) => s.flush(), InnerNetworkStream::None => { debug_assert!(false, "InnerNetworkStream::None must never be built"); Ok(()) } } } }