From 18d2d051ed59705bd734ecfe9d44970940262e4b Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Wed, 9 Oct 2024 21:38:52 +0200 Subject: [PATCH] Improve SMTP error handling --- src/transport/smtp/client/async_connection.rs | 121 +++++++++--------- src/transport/smtp/client/connection.rs | 116 ++++++++--------- src/transport/smtp/client/mod.rs | 96 ++++++++++++++ 3 files changed, 213 insertions(+), 120 deletions(-) diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index cdbdec5..93683f1 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -10,34 +10,21 @@ use super::{AsyncNetworkStream, ClientCodec, TlsParameters}; use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism}, + client::{ConnectionState, ConnectionWrapper}, commands::{Auth, Data, Ehlo, Mail, Noop, Quit, Rcpt, Starttls}, - error, - error::Error, + error::{self, Error}, extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, response::{parse_response, Response}, }, Envelope, }; -macro_rules! try_smtp ( - ($err: expr, $client: ident) => ({ - match $err { - Ok(val) => val, - Err(err) => { - $client.abort().await; - return Err(From::from(err)) - }, - } - }) -); - /// Structure that implements the SMTP client pub struct AsyncSmtpConnection { /// TCP stream between client and server - /// Value is None before connection - stream: BufReader, - /// Panic state - panic: bool, + stream: ConnectionWrapper>, + /// Whether QUIT has been sent + sent_quit: bool, /// Information about the server server_info: ServerInfo, } @@ -125,8 +112,8 @@ impl AsyncSmtpConnection { ) -> Result { let stream = BufReader::new(stream); let mut conn = AsyncSmtpConnection { - stream, - panic: false, + stream: ConnectionWrapper::new(stream), + sent_quit: false, server_info: ServerInfo::default(), }; // TODO log @@ -170,30 +157,28 @@ impl AsyncSmtpConnection { mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); } - try_smtp!( - self.command(Mail::new(envelope.from().cloned(), mail_options)) - .await, - self - ); + self.command(Mail::new(envelope.from().cloned(), mail_options)) + .await?; // Recipient for to_address in envelope.to() { - try_smtp!( - self.command(Rcpt::new(to_address.clone(), vec![])).await, - self - ); + self.command(Rcpt::new(to_address.clone(), vec![])).await?; } // Data - try_smtp!(self.command(Data).await, self); + self.command(Data).await?; // Message content - let result = try_smtp!(self.message(email).await, self); + let result = self.message(email).await?; Ok(result) } pub fn has_broken(&self) -> bool { - self.panic + self.sent_quit + || matches!( + self.stream.state(), + ConnectionState::BrokenConnection | ConnectionState::BrokenResponse + ) } pub fn can_starttls(&self) -> bool { @@ -213,12 +198,14 @@ impl AsyncSmtpConnection { hello_name: &ClientId, ) -> Result<(), Error> { if self.server_info.supports_feature(Extension::StartTls) { - try_smtp!(self.command(Starttls).await, self); - self.stream.get_mut().upgrade_tls(tls_parameters).await?; + self.command(Starttls).await?; + self.stream + .async_op(|stream| stream.get_mut().upgrade_tls(tls_parameters)) + .await?; #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again - try_smtp!(self.ehlo(hello_name).await, self); + self.ehlo(hello_name).await?; Ok(()) } else { Err(error::client("STARTTLS is not supported on this server")) @@ -227,32 +214,39 @@ impl AsyncSmtpConnection { /// Send EHLO and update server info async fn ehlo(&mut self, hello_name: &ClientId) -> Result<(), Error> { - let ehlo_response = try_smtp!(self.command(Ehlo::new(hello_name.clone())).await, self); - self.server_info = try_smtp!(ServerInfo::from_response(&ehlo_response), self); + let ehlo_response = self.command(Ehlo::new(hello_name.clone())).await?; + self.server_info = ServerInfo::from_response(&ehlo_response)?; Ok(()) } pub async fn quit(&mut self) -> Result { - Ok(try_smtp!(self.command(Quit).await, self)) + self.sent_quit = true; + self.command(Quit).await } pub async fn abort(&mut self) { // Only try to quit if we are not already broken - if !self.panic { - self.panic = true; - let _ = self.command(Quit).await; + // `write` already rejects writes if the connection state if bad + if !self.sent_quit { + let _ = self.quit().await; + } + + if !matches!(self.stream.state(), ConnectionState::BrokenConnection) { + let _ = self + .stream + .async_op(|stream| async { stream.close().await.map_err(error::network) }) + .await; } - let _ = self.stream.close().await; } /// Sets the underlying stream pub fn set_stream(&mut self, stream: AsyncNetworkStream) { - self.stream = BufReader::new(stream); + self.stream = ConnectionWrapper::new(BufReader::new(stream)); } /// Tells if the underlying stream is currently encrypted pub fn is_encrypted(&self) -> bool { - self.stream.get_ref().is_encrypted() + self.stream.get_ref().get_ref().is_encrypted() } /// Checks if the server is connected using the NOOP SMTP command @@ -279,15 +273,13 @@ impl AsyncSmtpConnection { while challenges > 0 && response.has_code(334) { challenges -= 1; - response = try_smtp!( - self.command(Auth::new_from_response( + response = self + .command(Auth::new_from_response( mechanism, credentials.clone(), &response, )?) - .await, - self - ); + .await?; } if challenges == 0 { @@ -316,15 +308,17 @@ impl AsyncSmtpConnection { /// Writes a string to the server async fn write(&mut self, string: &[u8]) -> Result<(), Error> { self.stream - .get_mut() - .write_all(string) - .await - .map_err(error::network)?; + .async_op(|stream| async { + stream + .get_mut() + .write_all(string) + .await + .map_err(error::network) + }) + .await?; self.stream - .get_mut() - .flush() - .await - .map_err(error::network)?; + .async_op(|stream| async { stream.get_mut().flush().await.map_err(error::network) }) + .await?; #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); @@ -337,9 +331,10 @@ impl AsyncSmtpConnection { while self .stream - .read_line(&mut buffer) - .await - .map_err(error::network)? + .async_op(|stream| async { + stream.read_line(&mut buffer).await.map_err(error::network) + }) + .await? > 0 { #[cfg(feature = "tracing")] @@ -356,10 +351,12 @@ impl AsyncSmtpConnection { } } Err(nom::Err::Failure(e)) => { + self.stream.set_state(ConnectionState::BrokenResponse); return Err(error::response(e.to_string())); } Err(nom::Err::Incomplete(_)) => { /* read more */ } Err(nom::Err::Error(e)) => { + self.stream.set_state(ConnectionState::BrokenResponse); return Err(error::response(e.to_string())); } } @@ -371,12 +368,12 @@ impl AsyncSmtpConnection { /// The X509 certificate of the server (DER encoded) #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] pub fn peer_certificate(&self) -> Result, Error> { - self.stream.get_ref().peer_certificate() + self.stream.get_ref().get_ref().peer_certificate() } /// All the X509 certificates of the chain (DER encoded) #[cfg(any(feature = "rustls-tls", feature = "boring-tls"))] pub fn certificate_chain(&self) -> Result>, Error> { - self.stream.get_ref().certificate_chain() + self.stream.get_ref().get_ref().certificate_chain() } } diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index 0e6ebdb..8f16060 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -7,38 +7,25 @@ use std::{ #[cfg(feature = "tracing")] use super::escape_crlf; -use super::{ClientCodec, NetworkStream, TlsParameters}; +use super::{ClientCodec, ConnectionWrapper, NetworkStream, TlsParameters}; use crate::{ address::Envelope, transport::smtp::{ authentication::{Credentials, Mechanism}, + client::ConnectionState, commands::{Auth, Data, Ehlo, Mail, Noop, Quit, Rcpt, Starttls}, - error, - error::Error, + error::{self, Error}, extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, response::{parse_response, Response}, }, }; -macro_rules! try_smtp ( - ($err: expr, $client: ident) => ({ - match $err { - Ok(val) => val, - Err(err) => { - $client.abort(); - return Err(From::from(err)) - }, - } - }) -); - /// Structure that implements the SMTP client pub struct SmtpConnection { /// TCP stream between client and server - /// Value is None before connection - stream: BufReader, - /// Panic state - panic: bool, + stream: ConnectionWrapper>, + /// Whether QUIT has been sent + sent_quit: bool, /// Information about the server server_info: ServerInfo, } @@ -64,8 +51,8 @@ impl SmtpConnection { let stream = NetworkStream::connect(server, timeout, tls_parameters, local_address)?; let stream = BufReader::new(stream); let mut conn = SmtpConnection { - stream, - panic: false, + stream: ConnectionWrapper::new(stream), + sent_quit: false, server_info: ServerInfo::default(), }; conn.set_timeout(timeout).map_err(error::network)?; @@ -110,26 +97,27 @@ impl SmtpConnection { mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); } - try_smtp!( - self.command(Mail::new(envelope.from().cloned(), mail_options)), - self - ); + self.command(Mail::new(envelope.from().cloned(), mail_options))?; // Recipient for to_address in envelope.to() { - try_smtp!(self.command(Rcpt::new(to_address.clone(), vec![])), self); + self.command(Rcpt::new(to_address.clone(), vec![]))?; } // Data - try_smtp!(self.command(Data), self); + self.command(Data)?; // Message content - let result = try_smtp!(self.message(email), self); + let result = self.message(email)?; Ok(result) } pub fn has_broken(&self) -> bool { - self.panic + self.sent_quit + || matches!( + self.stream.state(), + ConnectionState::BrokenConnection | ConnectionState::BrokenResponse + ) } pub fn can_starttls(&self) -> bool { @@ -145,12 +133,13 @@ impl SmtpConnection { if self.server_info.supports_feature(Extension::StartTls) { #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] { - try_smtp!(self.command(Starttls), self); - self.stream.get_mut().upgrade_tls(tls_parameters)?; + self.command(Starttls)?; + self.stream + .sync_op(|stream| stream.get_mut().upgrade_tls(tls_parameters))?; #[cfg(feature = "tracing")] tracing::debug!("connection encrypted"); // Send EHLO again - try_smtp!(self.ehlo(hello_name), self); + self.ehlo(hello_name)?; Ok(()) } #[cfg(not(any( @@ -168,38 +157,47 @@ impl SmtpConnection { /// Send EHLO and update server info fn ehlo(&mut self, hello_name: &ClientId) -> Result<(), Error> { - let ehlo_response = try_smtp!(self.command(Ehlo::new(hello_name.clone())), self); - self.server_info = try_smtp!(ServerInfo::from_response(&ehlo_response), self); + let ehlo_response = self.command(Ehlo::new(hello_name.clone()))?; + self.server_info = ServerInfo::from_response(&ehlo_response)?; Ok(()) } pub fn quit(&mut self) -> Result { - Ok(try_smtp!(self.command(Quit), self)) + self.sent_quit = true; + self.command(Quit) } pub fn abort(&mut self) { // Only try to quit if we are not already broken - if !self.panic { - self.panic = true; - let _ = self.command(Quit); + // `write` already rejects writes if the connection state if bad + if !self.sent_quit { + let _ = self.quit(); + } + + if !matches!(self.stream.state(), ConnectionState::BrokenConnection) { + let _ = self.stream.sync_op(|stream| { + stream + .get_mut() + .shutdown(std::net::Shutdown::Both) + .map_err(error::network) + }); } - let _ = self.stream.get_mut().shutdown(std::net::Shutdown::Both); } /// Sets the underlying stream pub fn set_stream(&mut self, stream: NetworkStream) { - self.stream = BufReader::new(stream); + self.stream = ConnectionWrapper::new(BufReader::new(stream)); } /// Tells if the underlying stream is currently encrypted pub fn is_encrypted(&self) -> bool { - self.stream.get_ref().is_encrypted() + self.stream.get_ref().get_ref().is_encrypted() } /// Set timeout pub fn set_timeout(&mut self, duration: Option) -> io::Result<()> { - self.stream.get_mut().set_read_timeout(duration)?; - self.stream.get_mut().set_write_timeout(duration) + self.stream.get_mut().get_mut().set_read_timeout(duration)?; + self.stream.get_mut().get_mut().set_write_timeout(duration) } /// Checks if the server is connected using the NOOP SMTP command @@ -224,14 +222,11 @@ impl SmtpConnection { while challenges > 0 && response.has_code(334) { challenges -= 1; - response = try_smtp!( - self.command(Auth::new_from_response( - mechanism, - credentials.clone(), - &response, - )?), - self - ); + response = self.command(Auth::new_from_response( + mechanism, + credentials.clone(), + &response, + )?)?; } if challenges == 0 { @@ -261,10 +256,9 @@ impl SmtpConnection { /// Writes a string to the server fn write(&mut self, string: &[u8]) -> Result<(), Error> { self.stream - .get_mut() - .write_all(string) - .map_err(error::network)?; - self.stream.get_mut().flush().map_err(error::network)?; + .sync_op(|stream| stream.get_mut().write_all(string).map_err(error::network))?; + self.stream + .sync_op(|stream| stream.get_mut().flush().map_err(error::network))?; #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); @@ -275,7 +269,11 @@ impl SmtpConnection { pub fn read_response(&mut self) -> Result { let mut buffer = String::with_capacity(100); - while self.stream.read_line(&mut buffer).map_err(error::network)? > 0 { + while self + .stream + .sync_op(|stream| stream.read_line(&mut buffer).map_err(error::network))? + > 0 + { #[cfg(feature = "tracing")] tracing::debug!("<< {}", escape_crlf(&buffer)); match parse_response(&buffer) { @@ -290,10 +288,12 @@ impl SmtpConnection { }; } Err(nom::Err::Failure(e)) => { + self.stream.set_state(ConnectionState::BrokenResponse); return Err(error::response(e.to_string())); } Err(nom::Err::Incomplete(_)) => { /* read more */ } Err(nom::Err::Error(e)) => { + self.stream.set_state(ConnectionState::BrokenResponse); return Err(error::response(e.to_string())); } } @@ -305,12 +305,12 @@ impl SmtpConnection { /// The X509 certificate of the server (DER encoded) #[cfg(any(feature = "native-tls", feature = "rustls-tls", feature = "boring-tls"))] pub fn peer_certificate(&self) -> Result, Error> { - self.stream.get_ref().peer_certificate() + self.stream.get_ref().get_ref().peer_certificate() } /// All the X509 certificates of the chain (DER encoded) #[cfg(any(feature = "rustls-tls", feature = "boring-tls"))] pub fn certificate_chain(&self) -> Result>, Error> { - self.stream.get_ref().certificate_chain() + self.stream.get_ref().get_ref().certificate_chain() } } diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index bb225b4..d73eaf8 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -24,6 +24,8 @@ #[cfg(feature = "serde")] use std::fmt::Debug; +#[cfg(any(feature = "tokio1", feature = "async-std1"))] +use std::future::Future; #[cfg(any(feature = "tokio1", feature = "async-std1"))] pub use self::async_connection::AsyncSmtpConnection; @@ -40,6 +42,7 @@ pub use self::{ connection::SmtpConnection, tls::{Certificate, CertificateStore, Identity, Tls, TlsParameters, TlsParametersBuilder}, }; +use super::{error, Error}; #[cfg(any(feature = "tokio1", feature = "async-std1"))] mod async_connection; @@ -49,6 +52,99 @@ mod connection; mod net; mod tls; +#[derive(Debug)] +pub(super) struct ConnectionWrapper { + conn: C, + state: ConnectionState, +} + +impl ConnectionWrapper { + pub(super) fn new(conn: C) -> Self { + Self { + conn, + state: ConnectionState::ProbablyConnected, + } + } + + pub(super) fn get_ref(&self) -> &C { + &self.conn + } + + pub(super) fn get_mut(&mut self) -> &mut C { + &mut self.conn + } + + pub(super) fn state(&self) -> ConnectionState { + self.state + } + + pub(super) fn set_state(&mut self, state: ConnectionState) { + self.state = state; + } + + pub(super) fn sync_op(&mut self, f: F) -> Result + where + F: FnOnce(&mut C) -> Result, + { + if !matches!( + self.state, + ConnectionState::ProbablyConnected | ConnectionState::BrokenResponse + ) { + return Err(error::client( + "attempted to send operation to broken connection", + )); + } + + self.state = ConnectionState::Writing; + match f(&mut self.conn) { + Ok(t) => { + self.state = ConnectionState::ProbablyConnected; + Ok(t) + } + Err(err) => { + self.state = ConnectionState::BrokenConnection; + Err(err) + } + } + } + + #[cfg(any(feature = "tokio1", feature = "async-std1"))] + pub(super) async fn async_op<'a, F, Fut, T>(&'a mut self, f: F) -> Result + where + F: FnOnce(&'a mut C) -> Fut, + Fut: Future>, + { + if !matches!( + self.state, + ConnectionState::ProbablyConnected | ConnectionState::BrokenResponse + ) { + return Err(error::client( + "attempted to send operation to broken connection", + )); + } + + self.state = ConnectionState::Writing; + match f(&mut self.conn).await { + Ok(t) => { + self.state = ConnectionState::ProbablyConnected; + Ok(t) + } + Err(err) => { + self.state = ConnectionState::BrokenConnection; + Err(err) + } + } + } +} + +#[derive(Debug, Copy, Clone)] +pub(super) enum ConnectionState { + ProbablyConnected, + Writing, + BrokenResponse, + BrokenConnection, +} + /// The codec used for transparency #[derive(Debug)] struct ClientCodec {