diff --git a/src/executor.rs b/src/executor.rs index 199a378..2bdbb50 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -151,11 +151,11 @@ impl Executor for Tokio1Executor { match tls { Tls::Opportunistic(ref tls_parameters) => { if conn.can_starttls() { - conn.starttls(tls_parameters.clone(), hello_name).await?; + conn = conn.starttls(tls_parameters.clone(), hello_name).await?; } } Tls::Required(ref tls_parameters) => { - conn.starttls(tls_parameters.clone(), hello_name).await?; + conn = conn.starttls(tls_parameters.clone(), hello_name).await?; } _ => (), } @@ -247,11 +247,11 @@ impl Executor for AsyncStd1Executor { match tls { Tls::Opportunistic(ref tls_parameters) => { if conn.can_starttls() { - conn.starttls(tls_parameters.clone(), hello_name).await?; + conn = conn.starttls(tls_parameters.clone(), hello_name).await?; } } Tls::Required(ref tls_parameters) => { - conn.starttls(tls_parameters.clone(), hello_name).await?; + conn = conn.starttls(tls_parameters.clone(), hello_name).await?; } _ => (), } diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index f265b95..5d81466 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -208,18 +208,20 @@ impl AsyncSmtpConnection { /// [rfc8314]: https://www.rfc-editor.org/rfc/rfc8314 #[allow(unused_variables)] pub async fn starttls( - &mut self, + mut self, tls_parameters: TlsParameters, hello_name: &ClientId, - ) -> Result<(), Error> { + ) -> Result { if self.server_info.supports_feature(Extension::StartTls) { try_smtp!(self.command(Starttls).await, self); - self.stream.get_mut().upgrade_tls(tls_parameters).await?; + let stream = self.stream.into_inner(); + let stream = stream.upgrade_tls(tls_parameters).await?; + self.stream = BufReader::new(stream); #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again try_smtp!(self.ehlo(hello_name).await, self); - Ok(()) + Ok(self) } else { Err(error::client("STARTTLS is not supported on this server")) } diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs index d9f1b46..b0c3245 100644 --- a/src/transport/smtp/client/async_net.rs +++ b/src/transport/smtp/client/async_net.rs @@ -11,8 +11,7 @@ use async_native_tls::TlsStream as AsyncStd1TlsStream; #[cfg(feature = "async-std1")] use async_std::net::{TcpStream as AsyncStd1TcpStream, ToSocketAddrs as AsyncStd1ToSocketAddrs}; use futures_io::{ - AsyncRead as FuturesAsyncRead, AsyncWrite as FuturesAsyncWrite, Error as IoError, ErrorKind, - Result as IoResult, + AsyncRead as FuturesAsyncRead, AsyncWrite as FuturesAsyncWrite, Result as IoResult, }; #[cfg(feature = "async-std1-rustls-tls")] use futures_rustls::client::TlsStream as AsyncStd1RustlsTlsStream; @@ -91,16 +90,10 @@ enum InnerAsyncNetworkStream { /// Encrypted Tokio 1.x TCP stream #[cfg(feature = "async-std1-rustls-tls")] AsyncStd1RustlsTls(AsyncStd1RustlsTlsStream), - /// Can't be built - None, } impl AsyncNetworkStream { fn new(inner: InnerAsyncNetworkStream) -> Self { - if let InnerAsyncNetworkStream::None = inner { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - } - AsyncNetworkStream { inner } } @@ -123,13 +116,6 @@ impl AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1NativeTls(ref s) => s.get_ref().peer_addr(), #[cfg(feature = "async-std1-rustls-tls")] InnerAsyncNetworkStream::AsyncStd1RustlsTls(ref s) => s.get_ref().0.peer_addr(), - InnerAsyncNetworkStream::None => { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - Err(IoError::new( - ErrorKind::Other, - "InnerAsyncNetworkStream::None must never be built", - )) - } } } @@ -199,7 +185,7 @@ impl AsyncNetworkStream { let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio1Tcp(Box::new(tcp_stream))); if let Some(tls_parameters) = tls_parameters { - stream.upgrade_tls(tls_parameters).await?; + stream = stream.upgrade_tls(tls_parameters).await?; } Ok(stream) } @@ -250,13 +236,13 @@ impl AsyncNetworkStream { let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { - stream.upgrade_tls(tls_parameters).await?; + stream = stream.upgrade_tls(tls_parameters).await?; } Ok(stream) } - pub async fn upgrade_tls(&mut self, tls_parameters: TlsParameters) -> Result<(), Error> { - match &self.inner { + pub async fn upgrade_tls(self, tls_parameters: TlsParameters) -> Result { + match self.inner { #[cfg(all( feature = "tokio1", not(any( @@ -275,18 +261,11 @@ impl AsyncNetworkStream { feature = "tokio1-rustls-tls", feature = "tokio1-boring-tls" ))] - InnerAsyncNetworkStream::Tokio1Tcp(_) => { - // get owned TcpStream - let tcp_stream = mem::replace(&mut self.inner, InnerAsyncNetworkStream::None); - let tcp_stream = match tcp_stream { - InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream) => tcp_stream, - _ => unreachable!(), - }; - - self.inner = Self::upgrade_tokio1_tls(tcp_stream, tls_parameters) + InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream) => { + let inner = Self::upgrade_tokio1_tls(tcp_stream, tls_parameters) .await .map_err(error::connection)?; - Ok(()) + Ok(Self { inner }) } #[cfg(all( feature = "async-std1", @@ -298,20 +277,13 @@ impl AsyncNetworkStream { } #[cfg(any(feature = "async-std1-native-tls", feature = "async-std1-rustls-tls"))] - InnerAsyncNetworkStream::AsyncStd1Tcp(_) => { - // get owned TcpStream - let tcp_stream = mem::replace(&mut self.inner, InnerAsyncNetworkStream::None); - let tcp_stream = match tcp_stream { - InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream) => tcp_stream, - _ => unreachable!(), - }; - - self.inner = Self::upgrade_asyncstd1_tls(tcp_stream, tls_parameters) + InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream) => { + let inner = Self::upgrade_asyncstd1_tls(tcp_stream, tls_parameters) .await .map_err(error::connection)?; - Ok(()) + Ok(Self { inner }) } - _ => Ok(()), + _ => Ok(self), } } @@ -460,7 +432,6 @@ impl AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1NativeTls(_) => true, #[cfg(feature = "async-std1-rustls-tls")] InnerAsyncNetworkStream::AsyncStd1RustlsTls(_) => true, - InnerAsyncNetworkStream::None => false, } } @@ -509,7 +480,6 @@ impl AsyncNetworkStream { .first() .unwrap() .to_vec()), - InnerAsyncNetworkStream::None => panic!("InnerNetworkStream::None must never be built"), } } } @@ -567,10 +537,6 @@ impl FuturesAsyncRead for AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1RustlsTls(ref mut s) => { Pin::new(s).poll_read(cx, buf) } - InnerAsyncNetworkStream::None => { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - Poll::Ready(Ok(0)) - } } } } @@ -600,10 +566,6 @@ impl FuturesAsyncWrite for AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1RustlsTls(ref mut s) => { Pin::new(s).poll_write(cx, buf) } - InnerAsyncNetworkStream::None => { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - Poll::Ready(Ok(0)) - } } } @@ -623,10 +585,6 @@ impl FuturesAsyncWrite for AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1NativeTls(ref mut s) => Pin::new(s).poll_flush(cx), #[cfg(feature = "async-std1-rustls-tls")] InnerAsyncNetworkStream::AsyncStd1RustlsTls(ref mut s) => Pin::new(s).poll_flush(cx), - InnerAsyncNetworkStream::None => { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - Poll::Ready(Ok(())) - } } } @@ -646,10 +604,6 @@ impl FuturesAsyncWrite for AsyncNetworkStream { InnerAsyncNetworkStream::AsyncStd1NativeTls(ref mut s) => Pin::new(s).poll_close(cx), #[cfg(feature = "async-std1-rustls-tls")] InnerAsyncNetworkStream::AsyncStd1RustlsTls(ref mut s) => Pin::new(s).poll_close(cx), - InnerAsyncNetworkStream::None => { - debug_assert!(false, "InnerAsyncNetworkStream::None must never be built"); - Poll::Ready(Ok(())) - } } } } diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index b3dc62f..12ff230 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -138,20 +138,22 @@ impl SmtpConnection { #[allow(unused_variables)] pub fn starttls( - &mut self, + mut self, tls_parameters: &TlsParameters, hello_name: &ClientId, - ) -> Result<(), Error> { + ) -> Result { if self.server_info.supports_feature(Extension::StartTls) { #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] { try_smtp!(self.command(Starttls), self); - self.stream.get_mut().upgrade_tls(tls_parameters)?; + let stream = self.stream.into_inner(); + let stream = stream.upgrade_tls(tls_parameters)?; + self.stream = BufReader::new(stream); #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again try_smtp!(self.ehlo(hello_name), self); - Ok(()) + Ok(self) } #[cfg(not(any( feature = "native-tls", diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index c2ed429..efede5d 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -2,8 +2,7 @@ use std::sync::Arc; use std::{ io::{self, Read, Write}, - mem, - net::{IpAddr, Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}, + net::{IpAddr, Shutdown, SocketAddr, TcpStream, ToSocketAddrs}, time::Duration, }; @@ -40,16 +39,10 @@ enum InnerNetworkStream { RustlsTls(StreamOwned), #[cfg(feature = "boring-tls")] BoringTls(SslStream), - /// Can't be built - None, } impl NetworkStream { fn new(inner: InnerNetworkStream) -> Self { - if let InnerNetworkStream::None = inner { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - } - NetworkStream { inner } } @@ -63,13 +56,6 @@ impl NetworkStream { InnerNetworkStream::RustlsTls(ref s) => s.get_ref().peer_addr(), #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(ref s) => s.get_ref().peer_addr(), - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new(127, 0, 0, 1), - 80, - ))) - } } } @@ -83,10 +69,6 @@ impl NetworkStream { InnerNetworkStream::RustlsTls(ref s) => s.get_ref().shutdown(how), #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(ref s) => s.get_ref().shutdown(how), - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(()) - } } } @@ -139,13 +121,13 @@ impl NetworkStream { let tcp_stream = try_connect(server, timeout, local_addr)?; let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { - stream.upgrade_tls(tls_parameters)?; + stream = stream.upgrade_tls(tls_parameters)?; } Ok(stream) } - pub fn upgrade_tls(&mut self, tls_parameters: &TlsParameters) -> Result<(), Error> { - match &self.inner { + pub fn upgrade_tls(self, tls_parameters: &TlsParameters) -> Result { + match self.inner { #[cfg(not(any( feature = "native-tls", feature = "rustls-tls", @@ -157,18 +139,11 @@ impl NetworkStream { } #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] - InnerNetworkStream::Tcp(_) => { - // get owned TcpStream - let tcp_stream = mem::replace(&mut self.inner, InnerNetworkStream::None); - let tcp_stream = match tcp_stream { - InnerNetworkStream::Tcp(tcp_stream) => tcp_stream, - _ => unreachable!(), - }; - - self.inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?; - Ok(()) + InnerNetworkStream::Tcp(tcp_stream) => { + let inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?; + Ok(Self { inner }) } - _ => Ok(()), + _ => Ok(self), } } @@ -216,10 +191,6 @@ impl NetworkStream { InnerNetworkStream::RustlsTls(_) => true, #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(_) => true, - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - false - } } } @@ -249,7 +220,6 @@ impl NetworkStream { .unwrap() .to_der() .map_err(error::tls)?), - InnerNetworkStream::None => panic!("InnerNetworkStream::None must never be built"), } } @@ -268,10 +238,6 @@ impl NetworkStream { InnerNetworkStream::BoringTls(ref mut stream) => { stream.get_ref().set_read_timeout(duration) } - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(()) - } } } @@ -292,10 +258,6 @@ impl NetworkStream { InnerNetworkStream::BoringTls(ref mut stream) => { stream.get_ref().set_write_timeout(duration) } - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(()) - } } } } @@ -310,10 +272,6 @@ impl Read for NetworkStream { InnerNetworkStream::RustlsTls(ref mut s) => s.read(buf), #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(ref mut s) => s.read(buf), - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(0) - } } } } @@ -328,10 +286,6 @@ impl Write for NetworkStream { InnerNetworkStream::RustlsTls(ref mut s) => s.write(buf), #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(ref mut s) => s.write(buf), - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(0) - } } } @@ -344,10 +298,6 @@ impl Write for NetworkStream { InnerNetworkStream::RustlsTls(ref mut s) => s.flush(), #[cfg(feature = "boring-tls")] InnerNetworkStream::BoringTls(ref mut s) => s.flush(), - InnerNetworkStream::None => { - debug_assert!(false, "InnerNetworkStream::None must never be built"); - Ok(()) - } } } } diff --git a/src/transport/smtp/extension.rs b/src/transport/smtp/extension.rs index 89fc299..afdadf5 100644 --- a/src/transport/smtp/extension.rs +++ b/src/transport/smtp/extension.rs @@ -4,7 +4,6 @@ use std::{ collections::HashSet, fmt::{self, Display, Formatter}, net::{Ipv4Addr, Ipv6Addr}, - result::Result, }; use crate::transport::smtp::{ diff --git a/src/transport/smtp/transport.rs b/src/transport/smtp/transport.rs index 3763aab..4baa99c 100644 --- a/src/transport/smtp/transport.rs +++ b/src/transport/smtp/transport.rs @@ -336,11 +336,11 @@ impl SmtpClient { match self.info.tls { Tls::Opportunistic(ref tls_parameters) => { if conn.can_starttls() { - conn.starttls(tls_parameters, &self.info.hello_name)?; + conn = conn.starttls(tls_parameters, &self.info.hello_name)?; } } Tls::Required(ref tls_parameters) => { - conn.starttls(tls_parameters, &self.info.hello_name)?; + conn = conn.starttls(tls_parameters, &self.info.hello_name)?; } _ => (), }