From 57d7bf25ccdc93d2964a03c6793a0e743fc42dcd Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Sun, 10 Mar 2024 15:23:09 +0100 Subject: [PATCH] Replace try_smtp macro with more resilient method --- src/transport/smtp/client/async_connection.rs | 81 +++++++++---------- src/transport/smtp/client/async_net.rs | 28 ++++++- src/transport/smtp/client/connection.rs | 79 +++++++++--------- src/transport/smtp/client/mod.rs | 18 +++++ src/transport/smtp/client/net.rs | 25 +++++- 5 files changed, 135 insertions(+), 96 deletions(-) diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index 5d81466..59ff3ad 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -6,7 +6,7 @@ use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use super::async_net::AsyncTokioStream; #[cfg(feature = "tracing")] use super::escape_crlf; -use super::{AsyncNetworkStream, ClientCodec, TlsParameters}; +use super::{AsyncNetworkStream, ClientCodec, ConnectionState, TlsParameters}; use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism}, @@ -19,25 +19,11 @@ use crate::{ Envelope, }; -macro_rules! try_smtp ( - ($err: expr, $client: ident) => ({ - match $err { - Ok(val) => val, - Err(err) => { - $client.abort().await; - return Err(From::from(err)) - }, - } - }) -); - /// Structure that implements the SMTP client pub struct AsyncSmtpConnection { /// TCP stream between client and server /// Value is None before connection stream: BufReader, - /// Panic state - panic: bool, /// Information about the server server_info: ServerInfo, } @@ -126,7 +112,6 @@ impl AsyncSmtpConnection { let stream = BufReader::new(stream); let mut conn = AsyncSmtpConnection { stream, - panic: false, server_info: ServerInfo::default(), }; // TODO log @@ -170,30 +155,26 @@ impl AsyncSmtpConnection { mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); } - try_smtp!( - self.command(Mail::new(envelope.from().cloned(), mail_options)) - .await, - self - ); + self.command(Mail::new(envelope.from().cloned(), mail_options)) + .await?; // Recipient for to_address in envelope.to() { - try_smtp!( - self.command(Rcpt::new(to_address.clone(), vec![])).await, - self - ); + self.command(Rcpt::new(to_address.clone(), vec![])).await?; } // Data - try_smtp!(self.command(Data).await, self); + self.command(Data).await?; // Message content - let result = try_smtp!(self.message(email).await, self); - Ok(result) + self.message(email).await } pub fn has_broken(&self) -> bool { - self.panic + match self.stream.get_ref().state() { + ConnectionState::Ok => false, + ConnectionState::Broken | ConnectionState::Closed => true, + } } pub fn can_starttls(&self) -> bool { @@ -213,14 +194,14 @@ impl AsyncSmtpConnection { hello_name: &ClientId, ) -> Result { if self.server_info.supports_feature(Extension::StartTls) { - try_smtp!(self.command(Starttls).await, self); + self.command(Starttls).await?; let stream = self.stream.into_inner(); let stream = stream.upgrade_tls(tls_parameters).await?; self.stream = BufReader::new(stream); #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again - try_smtp!(self.ehlo(hello_name).await, self); + self.ehlo(hello_name).await?; Ok(self) } else { Err(error::client("STARTTLS is not supported on this server")) @@ -229,22 +210,24 @@ impl AsyncSmtpConnection { /// Send EHLO and update server info async fn ehlo(&mut self, hello_name: &ClientId) -> Result<(), Error> { - let ehlo_response = try_smtp!(self.command(Ehlo::new(hello_name.clone())).await, self); - self.server_info = try_smtp!(ServerInfo::from_response(&ehlo_response), self); + let ehlo_response = self.command(Ehlo::new(hello_name.clone())).await?; + self.server_info = ServerInfo::from_response(&ehlo_response)?; Ok(()) } pub async fn quit(&mut self) -> Result { - Ok(try_smtp!(self.command(Quit).await, self)) + self.command(Quit).await } pub async fn abort(&mut self) { - // Only try to quit if we are not already broken - if !self.panic { - self.panic = true; - let _ = self.command(Quit).await; + match self.stream.get_ref().state() { + ConnectionState::Ok | ConnectionState::Broken => { + let _ = self.command(Quit).await; + let _ = self.stream.close().await; + self.stream.get_mut().set_state(ConnectionState::Closed); + } + ConnectionState::Closed => {} } - let _ = self.stream.close().await; } /// Sets the underlying stream @@ -281,15 +264,13 @@ impl AsyncSmtpConnection { while challenges > 0 && response.has_code(334) { challenges -= 1; - response = try_smtp!( - self.command(Auth::new_from_response( + response = self + .command(Auth::new_from_response( mechanism, credentials.clone(), &response, )?) - .await, - self - ); + .await?; } if challenges == 0 { @@ -317,6 +298,9 @@ impl AsyncSmtpConnection { /// Writes a string to the server async fn write(&mut self, string: &[u8]) -> Result<(), Error> { + self.stream.get_ref().state().verify()?; + self.stream.get_mut().set_state(ConnectionState::Broken); + self.stream .get_mut() .write_all(string) @@ -328,6 +312,8 @@ impl AsyncSmtpConnection { .await .map_err(error::network)?; + self.stream.get_mut().set_state(ConnectionState::Ok); + #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); Ok(()) @@ -335,6 +321,9 @@ impl AsyncSmtpConnection { /// Gets the SMTP response pub async fn read_response(&mut self) -> Result { + self.stream.get_ref().state().verify()?; + self.stream.get_mut().set_state(ConnectionState::Broken); + let mut buffer = String::with_capacity(100); while self @@ -348,6 +337,8 @@ impl AsyncSmtpConnection { tracing::debug!("<< {}", escape_crlf(&buffer)); match parse_response(&buffer) { Ok((_remaining, response)) => { + self.stream.get_mut().set_state(ConnectionState::Ok); + return if response.is_positive() { Ok(response) } else { @@ -355,7 +346,7 @@ impl AsyncSmtpConnection { response.code(), Some(response.message().collect()), )) - } + }; } Err(nom::Err::Failure(e)) => { return Err(error::response(e.to_string())); diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs index 31eca13..6a113be 100644 --- a/src/transport/smtp/client/async_net.rs +++ b/src/transport/smtp/client/async_net.rs @@ -39,7 +39,7 @@ use tokio1_rustls::client::TlsStream as Tokio1RustlsTlsStream; feature = "async-std1-rustls-tls" ))] use super::InnerTlsParameters; -use super::TlsParameters; +use super::{ConnectionState, TlsParameters}; #[cfg(feature = "tokio1")] use crate::transport::smtp::client::net::resolved_address_filter; use crate::transport::smtp::{error, Error}; @@ -48,6 +48,7 @@ use crate::transport::smtp::{error, Error}; #[derive(Debug)] pub struct AsyncNetworkStream { inner: InnerAsyncNetworkStream, + state: ConnectionState, } #[cfg(feature = "tokio1")] @@ -94,7 +95,18 @@ enum InnerAsyncNetworkStream { impl AsyncNetworkStream { fn new(inner: InnerAsyncNetworkStream) -> Self { - AsyncNetworkStream { inner } + AsyncNetworkStream { + inner, + state: ConnectionState::Ok, + } + } + + pub(super) fn state(&self) -> ConnectionState { + self.state + } + + pub(super) fn set_state(&mut self, state: ConnectionState) { + self.state = state; } /// Returns peer's address @@ -265,7 +277,10 @@ impl AsyncNetworkStream { let inner = Self::upgrade_tokio1_tls(tcp_stream, tls_parameters) .await .map_err(error::connection)?; - Ok(Self { inner }) + Ok(Self { + inner, + state: ConnectionState::Ok, + }) } #[cfg(all( feature = "async-std1", @@ -281,7 +296,10 @@ impl AsyncNetworkStream { let inner = Self::upgrade_asyncstd1_tls(tcp_stream, tls_parameters) .await .map_err(error::connection)?; - Ok(Self { inner }) + Ok(Self { + inner, + state: ConnectionState::Ok, + }) } _ => Ok(self), } @@ -581,6 +599,8 @@ impl FuturesAsyncWrite for AsyncNetworkStream { } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.state = ConnectionState::Closed; + match &mut self.inner { #[cfg(feature = "tokio1")] InnerAsyncNetworkStream::Tokio1Tcp(s) => Pin::new(s).poll_shutdown(cx), diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index 12ff230..9a98385 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -7,7 +7,7 @@ use std::{ #[cfg(feature = "tracing")] use super::escape_crlf; -use super::{ClientCodec, NetworkStream, TlsParameters}; +use super::{ClientCodec, ConnectionState, NetworkStream, TlsParameters}; use crate::{ address::Envelope, transport::smtp::{ @@ -20,25 +20,11 @@ use crate::{ }, }; -macro_rules! try_smtp ( - ($err: expr, $client: ident) => ({ - match $err { - Ok(val) => val, - Err(err) => { - $client.abort(); - return Err(From::from(err)) - }, - } - }) -); - /// Structure that implements the SMTP client pub struct SmtpConnection { /// TCP stream between client and server /// Value is None before connection stream: BufReader, - /// Panic state - panic: bool, /// Information about the server server_info: ServerInfo, } @@ -65,7 +51,6 @@ impl SmtpConnection { let stream = BufReader::new(stream); let mut conn = SmtpConnection { stream, - panic: false, server_info: ServerInfo::default(), }; conn.set_timeout(timeout).map_err(error::network)?; @@ -110,26 +95,25 @@ impl SmtpConnection { mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); } - try_smtp!( - self.command(Mail::new(envelope.from().cloned(), mail_options)), - self - ); + self.command(Mail::new(envelope.from().cloned(), mail_options))?; // Recipient for to_address in envelope.to() { - try_smtp!(self.command(Rcpt::new(to_address.clone(), vec![])), self); + self.command(Rcpt::new(to_address.clone(), vec![]))?; } // Data - try_smtp!(self.command(Data), self); + self.command(Data)?; // Message content - let result = try_smtp!(self.message(email), self); - Ok(result) + self.message(email) } pub fn has_broken(&self) -> bool { - self.panic + match self.stream.get_ref().state() { + ConnectionState::Ok => false, + ConnectionState::Broken | ConnectionState::Closed => true, + } } pub fn can_starttls(&self) -> bool { @@ -145,14 +129,14 @@ impl SmtpConnection { if self.server_info.supports_feature(Extension::StartTls) { #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] { - try_smtp!(self.command(Starttls), self); + self.command(Starttls)?; let stream = self.stream.into_inner(); let stream = stream.upgrade_tls(tls_parameters)?; self.stream = BufReader::new(stream); #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again - try_smtp!(self.ehlo(hello_name), self); + self.ehlo(hello_name)?; Ok(self) } #[cfg(not(any( @@ -170,22 +154,24 @@ impl SmtpConnection { /// Send EHLO and update server info fn ehlo(&mut self, hello_name: &ClientId) -> Result<(), Error> { - let ehlo_response = try_smtp!(self.command(Ehlo::new(hello_name.clone())), self); - self.server_info = try_smtp!(ServerInfo::from_response(&ehlo_response), self); + let ehlo_response = self.command(Ehlo::new(hello_name.clone()))?; + self.server_info = ServerInfo::from_response(&ehlo_response)?; Ok(()) } pub fn quit(&mut self) -> Result { - Ok(try_smtp!(self.command(Quit), self)) + self.command(Quit) } pub fn abort(&mut self) { - // Only try to quit if we are not already broken - if !self.panic { - self.panic = true; - let _ = self.command(Quit); + match self.stream.get_ref().state() { + ConnectionState::Ok | ConnectionState::Broken => { + let _ = self.command(Quit); + let _ = self.stream.get_mut().shutdown(std::net::Shutdown::Both); + self.stream.get_mut().set_state(ConnectionState::Closed); + } + ConnectionState::Closed => {} } - let _ = self.stream.get_mut().shutdown(std::net::Shutdown::Both); } /// Sets the underlying stream @@ -226,14 +212,11 @@ impl SmtpConnection { while challenges > 0 && response.has_code(334) { challenges -= 1; - response = try_smtp!( - self.command(Auth::new_from_response( - mechanism, - credentials.clone(), - &response, - )?), - self - ); + response = self.command(Auth::new_from_response( + mechanism, + credentials.clone(), + &response, + )?)?; } if challenges == 0 { @@ -262,12 +245,17 @@ impl SmtpConnection { /// Writes a string to the server fn write(&mut self, string: &[u8]) -> Result<(), Error> { + self.stream.get_ref().state().verify()?; + self.stream.get_mut().set_state(ConnectionState::Broken); + self.stream .get_mut() .write_all(string) .map_err(error::network)?; self.stream.get_mut().flush().map_err(error::network)?; + self.stream.get_mut().set_state(ConnectionState::Ok); + #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); Ok(()) @@ -275,6 +263,9 @@ impl SmtpConnection { /// Gets the SMTP response pub fn read_response(&mut self) -> Result { + self.stream.get_ref().state().verify()?; + self.stream.get_mut().set_state(ConnectionState::Broken); + let mut buffer = String::with_capacity(100); while self.stream.read_line(&mut buffer).map_err(error::network)? > 0 { @@ -282,6 +273,8 @@ impl SmtpConnection { tracing::debug!("<< {}", escape_crlf(&buffer)); match parse_response(&buffer) { Ok((_remaining, response)) => { + self.stream.get_mut().set_state(ConnectionState::Ok); + return if response.is_positive() { Ok(response) } else { diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index be014c1..2e0ac9a 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -40,6 +40,7 @@ pub use self::{ connection::SmtpConnection, tls::{Certificate, CertificateStore, Tls, TlsParameters, TlsParametersBuilder}, }; +use super::{error, Error}; #[cfg(any(feature = "tokio1", feature = "async-std1"))] mod async_connection; @@ -49,6 +50,23 @@ mod connection; mod net; mod tls; +#[derive(Debug, Copy, Clone)] +enum ConnectionState { + Ok, + Broken, + Closed, +} + +impl ConnectionState { + fn verify(&mut self) -> Result<(), Error> { + match self { + Self::Ok => Ok(()), + Self::Broken => Err(error::connection("connection broken")), + Self::Closed => Err(error::connection("connection closed")), + } + } +} + /// The codec used for transparency #[derive(Debug)] struct ClientCodec { diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index e20b05b..a291a00 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -16,12 +16,13 @@ use socket2::{Domain, Protocol, Type}; #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] use super::InnerTlsParameters; -use super::TlsParameters; +use super::{ConnectionState, TlsParameters}; use crate::transport::smtp::{error, Error}; /// A network stream pub struct NetworkStream { inner: InnerNetworkStream, + state: ConnectionState, } /// Represents the different types of underlying network streams @@ -43,7 +44,18 @@ enum InnerNetworkStream { impl NetworkStream { fn new(inner: InnerNetworkStream) -> Self { - NetworkStream { inner } + NetworkStream { + inner, + state: ConnectionState::Ok, + } + } + + pub(super) fn state(&self) -> ConnectionState { + self.state + } + + pub(super) fn set_state(&mut self, state: ConnectionState) { + self.state = state; } /// Returns peer's address @@ -60,7 +72,9 @@ impl NetworkStream { } /// Shutdowns the connection - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + self.state = ConnectionState::Closed; + match &self.inner { InnerNetworkStream::Tcp(s) => s.shutdown(how), #[cfg(feature = "native-tls")] @@ -141,7 +155,10 @@ impl NetworkStream { #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] InnerNetworkStream::Tcp(tcp_stream) => { let inner = Self::upgrade_tls_impl(tcp_stream, tls_parameters)?; - Ok(Self { inner }) + Ok(Self { + inner, + state: ConnectionState::Ok, + }) } _ => Ok(self), }