From 22efe341fea8482a55a52ad9732569c90157b0b1 Mon Sep 17 00:00:00 2001 From: Alexis Mousset Date: Sat, 13 Mar 2021 17:15:21 +0000 Subject: [PATCH] feat(builder): Seal SMTP error type (#564) * feat(builder): Seal SMTP error type * More precise error types --- Cargo.toml | 2 +- src/address/envelope.rs | 1 + src/address/types.rs | 1 + src/transport/file/mod.rs | 4 +- src/transport/sendmail/mod.rs | 2 +- src/transport/smtp/authentication.rs | 12 +- src/transport/smtp/client/async_connection.rs | 55 ++-- src/transport/smtp/client/async_net.rs | 60 +++- src/transport/smtp/client/connection.rs | 40 +-- src/transport/smtp/client/net.rs | 13 +- src/transport/smtp/client/tls.rs | 24 +- src/transport/smtp/commands.rs | 11 +- src/transport/smtp/error.rs | 275 ++++++++++-------- src/transport/smtp/extension.rs | 7 +- src/transport/smtp/pool.rs | 4 +- src/transport/smtp/response.rs | 19 +- src/transport/smtp/transport.rs | 6 +- 17 files changed, 308 insertions(+), 228 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f6a2e73..e9aadd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "1", optional = true, features = ["derive"] } serde_json = { version = "1", optional = true } # smtp -nom = { version = "6", default-features = false, features = ["alloc"], optional = true } +nom = { version = "6", default-features = false, features = ["alloc", "std"], optional = true } r2d2 = { version = "0.8", optional = true } # feature hostname = { version = "0.3", optional = true } # feature diff --git a/src/address/envelope.rs b/src/address/envelope.rs index 27ed35c..082e665 100644 --- a/src/address/envelope.rs +++ b/src/address/envelope.rs @@ -104,6 +104,7 @@ impl Envelope { self.reverse_path.as_ref() } + #[cfg(feature = "smtp-transport")] /// Check if any of the addresses in the envelope contains non-ascii chars pub(crate) fn has_non_ascii_addresses(&self) -> bool { self.reverse_path diff --git a/src/address/types.rs b/src/address/types.rs index 99e3e2e..fe6817b 100644 --- a/src/address/types.rs +++ b/src/address/types.rs @@ -174,6 +174,7 @@ impl Address { Err(AddressError::InvalidDomain) } + #[cfg(feature = "smtp-transport")] /// Check if the address contains non-ascii chars pub(super) fn is_ascii(&self) -> bool { self.serialized.is_ascii() diff --git a/src/transport/file/mod.rs b/src/transport/file/mod.rs index 139e9f9..8ba19c4 100644 --- a/src/transport/file/mod.rs +++ b/src/transport/file/mod.rs @@ -9,8 +9,8 @@ //! # //! # #[cfg(all(feature = "file-transport", feature = "builder"))] //! # fn main() -> Result<(), Box> { -//! use std::env::temp_dir; //! use lettre::{FileTransport, Message, Transport}; +//! use std::env::temp_dir; //! //! // Write to the local temp directory //! let sender = FileTransport::new(temp_dir()); @@ -41,8 +41,8 @@ //! # //! # #[cfg(all(feature = "file-transport-envelope", feature = "builder"))] //! # fn main() -> Result<(), Box> { -//! use std::env::temp_dir; //! use lettre::{FileTransport, Message, Transport}; +//! use std::env::temp_dir; //! //! // Write to the local temp directory //! let sender = FileTransport::with_envelope(temp_dir()); diff --git a/src/transport/sendmail/mod.rs b/src/transport/sendmail/mod.rs index 3d9e85c..06fc53c 100644 --- a/src/transport/sendmail/mod.rs +++ b/src/transport/sendmail/mod.rs @@ -7,7 +7,7 @@ //! # //! # #[cfg(all(feature = "sendmail-transport", feature = "builder"))] //! # fn main() -> Result<(), Box> { -//! use lettre::{Message, Transport, SendmailTransport}; +//! use lettre::{Message, SendmailTransport, Transport}; //! //! let email = Message::builder() //! .from("NoBody ".parse()?) diff --git a/src/transport/smtp/authentication.rs b/src/transport/smtp/authentication.rs index 37f8f7c..71d181c 100644 --- a/src/transport/smtp/authentication.rs +++ b/src/transport/smtp/authentication.rs @@ -1,6 +1,6 @@ //! Provides limited SASL authentication mechanisms -use crate::transport::smtp::error::Error; +use crate::transport::smtp::error::{self, Error}; use std::fmt::{self, Display, Formatter}; /// Accepted authentication mechanisms @@ -80,15 +80,15 @@ impl Mechanism { ) -> Result { match self { Mechanism::Plain => match challenge { - Some(_) => Err(Error::Client("This mechanism does not expect a challenge")), + Some(_) => Err(error::client("This mechanism does not expect a challenge")), None => Ok(format!( "\u{0}{}\u{0}{}", credentials.authentication_identity, credentials.secret )), }, Mechanism::Login => { - let decoded_challenge = - challenge.ok_or(Error::Client("This mechanism does expect a challenge"))?; + let decoded_challenge = challenge + .ok_or_else(|| error::client("This mechanism does expect a challenge"))?; if vec!["User Name", "Username:", "Username"].contains(&decoded_challenge) { return Ok(credentials.authentication_identity.to_string()); @@ -98,10 +98,10 @@ impl Mechanism { return Ok(credentials.secret.to_string()); } - Err(Error::Client("Unrecognized challenge")) + Err(error::client("Unrecognized challenge")) } Mechanism::Xoauth2 => match challenge { - Some(_) => Err(Error::Client("This mechanism does not expect a challenge")), + Some(_) => Err(error::client("This mechanism does not expect a challenge")), None => Ok(format!( "user={}\x01auth=Bearer {}\x01\x01", credentials.authentication_identity, credentials.secret diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index b4bf04c..19c9662 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -1,18 +1,17 @@ -use std::{fmt::Display, io}; - -use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; - use super::{AsyncNetworkStream, ClientCodec, TlsParameters}; use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism}, commands::*, + error, error::Error, extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, response::{parse_response, Response}, }, Envelope, }; +use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use std::fmt::Display; #[cfg(feature = "tracing")] use super::escape_crlf; @@ -121,7 +120,7 @@ impl AsyncSmtpConnection { if envelope.has_non_ascii_addresses() { if !self.server_info().supports_feature(Extension::SmtpUtfEight) { // don't try to send non-ascii addresses (per RFC) - return Err(Error::Client( + return Err(error::client( "Envelope contains non-ascii chars but server does not support SMTPUTF8", )); } @@ -131,7 +130,7 @@ impl AsyncSmtpConnection { // Check for non-ascii content in message if !email.is_ascii() { if !self.server_info().supports_feature(Extension::EightBitMime) { - return Err(Error::Client( + return Err(error::client( "Message contains non-ascii chars but server does not support 8BITMIME", )); } @@ -186,7 +185,7 @@ impl AsyncSmtpConnection { try_smtp!(self.ehlo(hello_name).await, self); Ok(()) } else { - Err(Error::Client("STARTTLS is not supported on this server")) + Err(error::client("STARTTLS is not supported on this server")) } } @@ -233,12 +232,10 @@ impl AsyncSmtpConnection { let mechanism = self .server_info .get_auth_mechanism(mechanisms) - .ok_or(Error::Client( - "No compatible authentication mechanism was found", - ))?; + .ok_or_else(|| error::client("No compatible authentication mechanism was found"))?; // Limit challenges to avoid blocking - let mut challenges = 10; + let mut challenges: u8 = 10; let mut response = self .command(Auth::new(mechanism, credentials.clone(), None)?) .await?; @@ -257,7 +254,7 @@ impl AsyncSmtpConnection { } if challenges == 0 { - Err(Error::ResponseParsing("Unexpected number of challenges")) + Err(error::response("Unexpected number of challenges")) } else { Ok(response) } @@ -281,8 +278,16 @@ impl AsyncSmtpConnection { /// Writes a string to the server async fn write(&mut self, string: &[u8]) -> Result<(), Error> { - self.stream.get_mut().write_all(string).await?; - self.stream.get_mut().flush().await?; + self.stream + .get_mut() + .write_all(string) + .await + .map_err(error::network)?; + self.stream + .get_mut() + .flush() + .await + .map_err(error::network)?; #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); @@ -293,27 +298,33 @@ impl AsyncSmtpConnection { pub async fn read_response(&mut self) -> Result { let mut buffer = String::with_capacity(100); - while self.stream.read_line(&mut buffer).await? > 0 { + while self + .stream + .read_line(&mut buffer) + .await + .map_err(error::network)? + > 0 + { #[cfg(feature = "tracing")] tracing::debug!("<< {}", escape_crlf(&buffer)); match parse_response(&buffer) { Ok((_remaining, response)) => { - if response.is_positive() { - return Ok(response); + return if response.is_positive() { + Ok(response) + } else { + Err(error::code(response.code)) } - - return Err(response.into()); } Err(nom::Err::Failure(e)) => { - return Err(Error::Parsing(e.code)); + return Err(error::response(e.to_string())); } Err(nom::Err::Incomplete(_)) => { /* read more */ } Err(nom::Err::Error(e)) => { - return Err(Error::Parsing(e.code)); + return Err(error::response(e.to_string())); } } } - Err(io::Error::new(io::ErrorKind::Other, "incomplete").into()) + Err(error::response("incomplete response")) } } diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs index 1c70d30..a12f1ca 100644 --- a/src/transport/smtp/client/async_net.rs +++ b/src/transport/smtp/client/async_net.rs @@ -50,7 +50,7 @@ use tokio1_rustls::client::TlsStream as Tokio1RustlsTlsStream; ))] use super::InnerTlsParameters; use super::TlsParameters; -use crate::transport::smtp::Error; +use crate::transport::smtp::{error, Error}; /// A network stream pub struct AsyncNetworkStream { @@ -144,7 +144,9 @@ impl AsyncNetworkStream { port: u16, tls_parameters: Option, ) -> Result { - let tcp_stream = Tokio02TcpStream::connect((hostname, port)).await?; + let tcp_stream = Tokio02TcpStream::connect((hostname, port)) + .await + .map_err(error::connection)?; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio02Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { @@ -159,7 +161,9 @@ impl AsyncNetworkStream { port: u16, tls_parameters: Option, ) -> Result { - let tcp_stream = Tokio1TcpStream::connect((hostname, port)).await?; + let tcp_stream = Tokio1TcpStream::connect((hostname, port)) + .await + .map_err(error::connection)?; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { @@ -174,7 +178,9 @@ impl AsyncNetworkStream { port: u16, tls_parameters: Option, ) -> Result { - let tcp_stream = AsyncStd1TcpStream::connect((hostname, port)).await?; + let tcp_stream = AsyncStd1TcpStream::connect((hostname, port)) + .await + .map_err(error::connection)?; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { @@ -203,7 +209,9 @@ impl AsyncNetworkStream { _ => unreachable!(), }; - self.inner = Self::upgrade_tokio02_tls(tcp_stream, tls_parameters).await?; + self.inner = Self::upgrade_tokio02_tls(tcp_stream, tls_parameters) + .await + .map_err(error::connection)?; Ok(()) } #[cfg(all( @@ -224,7 +232,9 @@ impl AsyncNetworkStream { _ => unreachable!(), }; - self.inner = Self::upgrade_tokio1_tls(tcp_stream, tls_parameters).await?; + self.inner = Self::upgrade_tokio1_tls(tcp_stream, tls_parameters) + .await + .map_err(error::connection)?; Ok(()) } #[cfg(all( @@ -245,7 +255,9 @@ impl AsyncNetworkStream { _ => unreachable!(), }; - self.inner = Self::upgrade_asyncstd1_tls(tcp_stream, tls_parameters).await?; + self.inner = Self::upgrade_asyncstd1_tls(tcp_stream, tls_parameters) + .await + .map_err(error::connection)?; Ok(()) } _ => Ok(()), @@ -271,7 +283,10 @@ impl AsyncNetworkStream { use tokio02_native_tls_crate::TlsConnector; let connector = TlsConnector::from(connector); - let stream = connector.connect(&domain, tcp_stream).await?; + let stream = connector + .connect(&domain, tcp_stream) + .await + .map_err(error::connection)?; Ok(InnerAsyncNetworkStream::Tokio02NativeTls(stream)) }; } @@ -284,10 +299,14 @@ impl AsyncNetworkStream { return { use tokio02_rustls::{webpki::DNSNameRef, TlsConnector}; - let domain = DNSNameRef::try_from_ascii_str(&domain)?; + let domain = + DNSNameRef::try_from_ascii_str(&domain).map_err(error::connection)?; let connector = TlsConnector::from(Arc::new(config)); - let stream = connector.connect(domain, tcp_stream).await?; + let stream = connector + .connect(domain, tcp_stream) + .await + .map_err(error::connection)?; Ok(InnerAsyncNetworkStream::Tokio02RustlsTls(stream)) }; } @@ -313,7 +332,10 @@ impl AsyncNetworkStream { use tokio1_native_tls_crate::TlsConnector; let connector = TlsConnector::from(connector); - let stream = connector.connect(&domain, tcp_stream).await?; + let stream = connector + .connect(&domain, tcp_stream) + .await + .map_err(error::connection)?; Ok(InnerAsyncNetworkStream::Tokio1NativeTls(stream)) }; } @@ -326,10 +348,14 @@ impl AsyncNetworkStream { return { use tokio1_rustls::{webpki::DNSNameRef, TlsConnector}; - let domain = DNSNameRef::try_from_ascii_str(&domain)?; + let domain = + DNSNameRef::try_from_ascii_str(&domain).map_err(error::connection)?; let connector = TlsConnector::from(Arc::new(config)); - let stream = connector.connect(domain, tcp_stream).await?; + let stream = connector + .connect(domain, tcp_stream) + .await + .map_err(error::connection)?; Ok(InnerAsyncNetworkStream::Tokio1RustlsTls(stream)) }; } @@ -374,10 +400,14 @@ impl AsyncNetworkStream { return { use async_rustls::{webpki::DNSNameRef, TlsConnector}; - let domain = DNSNameRef::try_from_ascii_str(&domain)?; + let domain = + DNSNameRef::try_from_ascii_str(&domain).map_err(error::connection)?; let connector = TlsConnector::from(Arc::new(config)); - let stream = connector.connect(domain, tcp_stream).await?; + let stream = connector + .connect(domain, tcp_stream) + .await + .map_err(error::connection)?; Ok(InnerAsyncNetworkStream::AsyncStd1RustlsTls(stream)) }; } diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index 362fa99..25676de 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -11,6 +11,7 @@ use crate::{ transport::smtp::{ authentication::{Credentials, Mechanism}, commands::*, + error, error::Error, extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo}, response::{parse_response, Response}, @@ -66,7 +67,7 @@ impl SmtpConnection { panic: false, server_info: ServerInfo::default(), }; - conn.set_timeout(timeout)?; + conn.set_timeout(timeout).map_err(error::network)?; // TODO log let _response = conn.read_response()?; @@ -91,7 +92,7 @@ impl SmtpConnection { if envelope.has_non_ascii_addresses() { if !self.server_info().supports_feature(Extension::SmtpUtfEight) { // don't try to send non-ascii addresses (per RFC) - return Err(Error::Client( + return Err(error::client( "Envelope contains non-ascii chars but server does not support SMTPUTF8", )); } @@ -101,7 +102,7 @@ impl SmtpConnection { // Check for non-ascii content in message if !email.is_ascii() { if !self.server_info().supports_feature(Extension::EightBitMime) { - return Err(Error::Client( + return Err(error::client( "Message contains non-ascii chars but server does not support 8BITMIME", )); } @@ -156,7 +157,7 @@ impl SmtpConnection { // when a TLS library is enabled unreachable!("TLS support required but not supported"); } else { - Err(Error::Client("STARTTLS is not supported on this server")) + Err(error::client("STARTTLS is not supported on this server")) } } @@ -209,9 +210,7 @@ impl SmtpConnection { let mechanism = self .server_info .get_auth_mechanism(mechanisms) - .ok_or(Error::Client( - "No compatible authentication mechanism was found", - ))?; + .ok_or_else(|| error::client("No compatible authentication mechanism was found"))?; // Limit challenges to avoid blocking let mut challenges = 10; @@ -230,7 +229,7 @@ impl SmtpConnection { } if challenges == 0 { - Err(Error::ResponseParsing("Unexpected number of challenges")) + Err(error::response("Unexpected number of challenges")) } else { Ok(response) } @@ -254,8 +253,11 @@ impl SmtpConnection { /// Writes a string to the server fn write(&mut self, string: &[u8]) -> Result<(), Error> { - self.stream.get_mut().write_all(string)?; - self.stream.get_mut().flush()?; + self.stream + .get_mut() + .write_all(string) + .map_err(error::network)?; + self.stream.get_mut().flush().map_err(error::network)?; #[cfg(feature = "tracing")] tracing::debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); @@ -266,27 +268,27 @@ impl SmtpConnection { pub fn read_response(&mut self) -> Result { let mut buffer = String::with_capacity(100); - while self.stream.read_line(&mut buffer)? > 0 { + while self.stream.read_line(&mut buffer).map_err(error::network)? > 0 { #[cfg(feature = "tracing")] tracing::debug!("<< {}", escape_crlf(&buffer)); match parse_response(&buffer) { Ok((_remaining, response)) => { - if response.is_positive() { - return Ok(response); - } - - return Err(response.into()); + return if response.is_positive() { + Ok(response) + } else { + Err(error::code(response.code)) + }; } Err(nom::Err::Failure(e)) => { - return Err(Error::Parsing(e.code)); + return Err(error::response(e.to_string())); } Err(nom::Err::Incomplete(_)) => { /* read more */ } Err(nom::Err::Error(e)) => { - return Err(Error::Parsing(e.code)); + return Err(error::response(e.to_string())); } } } - Err(io::Error::new(io::ErrorKind::Other, "incomplete").into()) + Err(error::response("incomplete response")) } } diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index bde6c7a..0e9ac5e 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -15,7 +15,7 @@ use rustls::{ClientSession, StreamOwned}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use super::InnerTlsParameters; use super::{MockStream, TlsParameters}; -use crate::transport::smtp::Error; +use crate::transport::smtp::{error, Error}; /// A network stream pub struct NetworkStream { @@ -84,18 +84,18 @@ impl NetworkStream { server: T, timeout: Duration, ) -> Result { - let addrs = server.to_socket_addrs()?; + let addrs = server.to_socket_addrs().map_err(error::connection)?; for addr in addrs { if let Ok(result) = TcpStream::connect_timeout(&addr, timeout) { return Ok(result); } } - Err(Error::Client("Could not connect")) + Err(error::connection("Could not connect")) } let tcp_stream = match timeout { Some(t) => try_connect_timeout(server, t)?, - None => TcpStream::connect(server)?, + None => TcpStream::connect(server).map_err(error::connection)?, }; let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream)); @@ -140,14 +140,15 @@ impl NetworkStream { InnerTlsParameters::NativeTls(connector) => { let stream = connector .connect(tls_parameters.domain(), tcp_stream) - .map_err(|err| Error::Io(io::Error::new(io::ErrorKind::Other, err)))?; + .map_err(error::connection)?; InnerNetworkStream::NativeTls(stream) } #[cfg(feature = "rustls-tls")] InnerTlsParameters::RustlsTls(connector) => { use webpki::DNSNameRef; - let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain())?; + let domain = DNSNameRef::try_from_ascii_str(tls_parameters.domain()) + .map_err(error::connection)?; let stream = StreamOwned::new( ClientSession::new(&Arc::new(connector.clone()), domain), tcp_stream, diff --git a/src/transport/smtp/client/tls.rs b/src/transport/smtp/client/tls.rs index 54f6b0d..861efc4 100644 --- a/src/transport/smtp/client/tls.rs +++ b/src/transport/smtp/client/tls.rs @@ -1,15 +1,14 @@ -#[cfg(feature = "rustls-tls")] -use std::sync::Arc; - +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +use crate::transport::smtp::{error, Error}; #[cfg(feature = "native-tls")] use native_tls::{Protocol, TlsConnector}; #[cfg(feature = "rustls-tls")] use rustls::{ClientConfig, RootCertStore, ServerCertVerified, ServerCertVerifier, TLSError}; #[cfg(feature = "rustls-tls")] +use std::sync::Arc; +#[cfg(feature = "rustls-tls")] use webpki::DNSNameRef; -use crate::transport::smtp::error::Error; - /// Accepted protocols by default. /// This removes TLS 1.0 and 1.1 compared to tls-native defaults. // This is also rustls' default behavior @@ -142,7 +141,7 @@ impl TlsParametersBuilder { tls_builder.danger_accept_invalid_certs(self.accept_invalid_certs); tls_builder.min_protocol_version(Some(DEFAULT_TLS_MIN_PROTOCOL)); - let connector = tls_builder.build()?; + let connector = tls_builder.build().map_err(error::tls)?; Ok(TlsParameters { connector: InnerTlsParameters::NativeTls(connector), domain: self.domain, @@ -159,9 +158,7 @@ impl TlsParametersBuilder { for cert in self.root_certs { for rustls_cert in cert.rustls { - tls.root_store - .add(&rustls_cert) - .map_err(|_| Error::InvalidCertificate)?; + tls.root_store.add(&rustls_cert).map_err(error::tls)?; } } if self.accept_invalid_certs { @@ -227,12 +224,12 @@ pub struct Certificate { rustls: Vec, } +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] impl Certificate { /// Create a `Certificate` from a DER encoded certificate pub fn from_der(der: Vec) -> Result { #[cfg(feature = "native-tls")] - let native_tls_cert = - native_tls::Certificate::from_der(&der).map_err(|_| Error::InvalidCertificate)?; + let native_tls_cert = native_tls::Certificate::from_der(&der).map_err(error::tls)?; Ok(Self { #[cfg(feature = "native-tls")] @@ -245,8 +242,7 @@ impl Certificate { /// Create a `Certificate` from a PEM encoded certificate pub fn from_pem(pem: &[u8]) -> Result { #[cfg(feature = "native-tls")] - let native_tls_cert = - native_tls::Certificate::from_pem(pem).map_err(|_| Error::InvalidCertificate)?; + let native_tls_cert = native_tls::Certificate::from_pem(pem).map_err(error::tls)?; #[cfg(feature = "rustls-tls")] let rustls_cert = { @@ -254,7 +250,7 @@ impl Certificate { use std::io::Cursor; let mut pem = Cursor::new(pem); - pemfile::certs(&mut pem).map_err(|_| Error::InvalidCertificate)? + pemfile::certs(&mut pem).map_err(|_| error::tls("invalid certificates"))? }; Ok(Self { diff --git a/src/transport/smtp/commands.rs b/src/transport/smtp/commands.rs index 5829aa2..c574013 100644 --- a/src/transport/smtp/commands.rs +++ b/src/transport/smtp/commands.rs @@ -1,13 +1,13 @@ //! SMTP commands use crate::{ + address::Address, transport::smtp::{ authentication::{Credentials, Mechanism}, - error::Error, + error::{self, Error}, extension::{ClientId, MailParameter, RcptParameter}, response::Response, }, - Address, }; use std::fmt::{self, Display, Formatter}; @@ -261,16 +261,17 @@ impl Auth { response: &Response, ) -> Result { if !response.has_code(334) { - return Err(Error::ResponseParsing("Expecting a challenge")); + return Err(error::response("Expecting a challenge")); } let encoded_challenge = response .first_word() - .ok_or(Error::ResponseParsing("Could not read auth challenge"))?; + .ok_or_else(|| error::response("Could not read auth challenge"))?; #[cfg(feature = "tracing")] tracing::debug!("auth encoded challenge: {}", encoded_challenge); - let decoded_challenge = String::from_utf8(base64::decode(&encoded_challenge)?)?; + let decoded_base64 = base64::decode(&encoded_challenge).map_err(error::response)?; + let decoded_challenge = String::from_utf8(decoded_base64).map_err(error::response)?; #[cfg(feature = "tracing")] tracing::debug!("auth decoded challenge: {}", decoded_challenge); diff --git a/src/transport/smtp/error.rs b/src/transport/smtp/error.rs index cb182b8..10c4aa6 100644 --- a/src/transport/smtp/error.rs +++ b/src/transport/smtp/error.rs @@ -1,162 +1,189 @@ //! Error and result type for SMTP clients -use self::Error::*; -use crate::transport::smtp::response::{Response, Severity}; -use base64::DecodeError; -use std::{ - error::Error as StdError, - fmt::{self, Display, Formatter}, - io, - string::FromUtf8Error, -}; +use crate::transport::smtp::response::{Code, Severity}; +use std::{error::Error as StdError, fmt, io}; + +// Inspired by https://github.com/seanmonstar/reqwest/blob/a8566383168c0ef06c21f38cbc9213af6ff6db31/src/error.rs + +/// The Errors that may occur when sending an email over SMTP +pub struct Error { + inner: Box, +} + +pub(crate) type BoxError = Box; + +struct Inner { + kind: Kind, + source: Option, +} + +impl Error { + pub(crate) fn new(kind: Kind, source: Option) -> Error + where + E: Into, + { + Error { + inner: Box::new(Inner { + kind, + source: source.map(Into::into), + }), + } + } + + /// Returns true if the error is from response + pub fn is_response(&self) -> bool { + matches!(self.inner.kind, Kind::Response) + } + + /// Returns true if the error is from client + pub fn is_client(&self) -> bool { + matches!(self.inner.kind, Kind::Client) + } + + /// Returns true if the error is a transient SMTP error + pub fn is_transient(&self) -> bool { + matches!(self.inner.kind, Kind::Transient(_)) + } + + /// Returns true if the error is a permanent SMTP error + pub fn is_permanent(&self) -> bool { + matches!(self.inner.kind, Kind::Permanent(_)) + } + + /// Returns true if the error is caused by a timeout + pub fn is_timeout(&self) -> bool { + let mut source = self.source(); + + while let Some(err) = source { + if let Some(io_err) = err.downcast_ref::() { + return io_err.kind() == std::io::ErrorKind::TimedOut; + } + + source = err.source(); + } + + false + } + + /// Returns true if the error is from TLS + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] + pub fn is_tls(&self) -> bool { + matches!(self.inner.kind, Kind::Tls) + } + + /// Returns the status code, if the error was generated from a response. + pub fn status(&self) -> Option { + match self.inner.kind { + Kind::Transient(code) => Some(code), + Kind::Permanent(code) => Some(code), + _ => None, + } + } + + #[allow(unused)] + pub(crate) fn into_io(self) -> io::Error { + io::Error::new(io::ErrorKind::Other, self) + } +} -/// An enum of all error kinds. #[derive(Debug)] -pub enum Error { +pub(crate) enum Kind { /// Transient SMTP error, 4xx reply code /// /// [RFC 5321, section 4.2.1](https://tools.ietf.org/html/rfc5321#section-4.2.1) - Transient(Response), + Transient(Code), /// Permanent SMTP error, 5xx reply code /// /// [RFC 5321, section 4.2.1](https://tools.ietf.org/html/rfc5321#section-4.2.1) - Permanent(Response), + Permanent(Code), /// Error parsing a response - ResponseParsing(&'static str), - /// Error parsing a base64 string in response - ChallengeParsing(DecodeError), - /// Error parsing UTF8 in response - Utf8Parsing(FromUtf8Error), + Response, /// Internal client error - Client(&'static str), - /// DNS resolution error - Resolution, - /// IO error - Io(io::Error), + Client, + /// Connection error + Connection, + /// Underlying network i/o error + Network, /// TLS error - #[cfg(feature = "native-tls")] - #[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] - Tls(native_tls::Error), - /// Parsing error - Parsing(nom::error::ErrorKind), - /// Invalid hostname - #[cfg(feature = "rustls-tls")] - #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - InvalidDNSName(webpki::InvalidDNSNameError), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - InvalidCertificate, - #[cfg(feature = "r2d2")] - #[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))] - Pool(r2d2::Error), + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Tls, } -impl Display for Error { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmt::Error> { - match *self { - // Try to display the first line of the server's response that usually - // contains a short humanly readable error message - Transient(ref err) => fmt.write_str( - err.first_line() - .unwrap_or("transient error during SMTP transaction"), - ), - Permanent(ref err) => fmt.write_str( - err.first_line() - .unwrap_or("permanent error during SMTP transaction"), - ), - ResponseParsing(err) => fmt.write_str(err), - ChallengeParsing(ref err) => err.fmt(fmt), - Utf8Parsing(ref err) => err.fmt(fmt), - Resolution => fmt.write_str("could not resolve hostname"), - Client(err) => fmt.write_str(err), - Io(ref err) => err.fmt(fmt), - #[cfg(feature = "native-tls")] - Tls(ref err) => err.fmt(fmt), - Parsing(ref err) => fmt.write_str(err.description()), - #[cfg(feature = "rustls-tls")] - InvalidDNSName(ref err) => err.fmt(fmt), - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] - InvalidCertificate => fmt.write_str("invalid certificate"), - #[cfg(feature = "r2d2")] - Pool(ref err) => err.fmt(fmt), +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("lettre::Error"); + + builder.field("kind", &self.inner.kind); + + if let Some(ref source) = self.inner.source { + builder.field("source", source); } + + builder.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner.kind { + Kind::Response => f.write_str("response error")?, + Kind::Client => f.write_str("internal client error")?, + Kind::Network => f.write_str("network error")?, + Kind::Connection => f.write_str("Connection error")?, + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + Kind::Tls => f.write_str("tls error")?, + Kind::Transient(ref code) => { + write!(f, "transient error ({})", code)?; + } + Kind::Permanent(ref code) => { + write!(f, "permanent error ({})", code)?; + } + }; + + if let Some(ref e) = self.inner.source { + write!(f, ": {}", e)?; + } + + Ok(()) } } impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - ChallengeParsing(ref err) => Some(&*err), - Utf8Parsing(ref err) => Some(&*err), - Io(ref err) => Some(&*err), - #[cfg(feature = "native-tls")] - Tls(ref err) => Some(&*err), - _ => None, - } - } -} - -impl From for Error { - fn from(err: io::Error) -> Error { - Io(err) - } -} - -#[cfg(feature = "native-tls")] -impl From for Error { - fn from(err: native_tls::Error) -> Error { - Tls(err) - } -} - -impl From>> for Error { - fn from(err: nom::Err>) -> Error { - Parsing(match err { - nom::Err::Incomplete(_) => nom::error::ErrorKind::Complete, - nom::Err::Failure(e) => e.code, - nom::Err::Error(e) => e.code, + self.inner.source.as_ref().map(|e| { + let r: &(dyn std::error::Error + 'static) = &**e; + r }) } } -impl From for Error { - fn from(err: DecodeError) -> Error { - ChallengeParsing(err) +pub(crate) fn code(c: Code) -> Error { + match c.severity { + Severity::TransientNegativeCompletion => Error::new::(Kind::Transient(c), None), + Severity::PermanentNegativeCompletion => Error::new::(Kind::Permanent(c), None), + _ => client("Unknown error code"), } } -impl From for Error { - fn from(err: FromUtf8Error) -> Error { - Utf8Parsing(err) - } +pub(crate) fn response>(e: E) -> Error { + Error::new(Kind::Response, Some(e)) } -#[cfg(feature = "rustls-tls")] -impl From for Error { - fn from(err: webpki::InvalidDNSNameError) -> Error { - InvalidDNSName(err) - } +pub(crate) fn client>(e: E) -> Error { + Error::new(Kind::Client, Some(e)) } -#[cfg(feature = "r2d2")] -impl From for Error { - fn from(err: r2d2::Error) -> Error { - Pool(err) - } +pub(crate) fn network>(e: E) -> Error { + Error::new(Kind::Network, Some(e)) } -impl From for Error { - fn from(response: Response) -> Error { - match response.code.severity { - Severity::TransientNegativeCompletion => Transient(response), - Severity::PermanentNegativeCompletion => Permanent(response), - _ => Client("Unknown error code"), - } - } +pub(crate) fn connection>(e: E) -> Error { + Error::new(Kind::Connection, Some(e)) } -impl From<&'static str> for Error { - fn from(string: &'static str) -> Error { - Client(string) - } +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +pub(crate) fn tls>(e: E) -> Error { + Error::new(Kind::Tls, Some(e)) } diff --git a/src/transport/smtp/extension.rs b/src/transport/smtp/extension.rs index 66372cc..fb58ad0 100644 --- a/src/transport/smtp/extension.rs +++ b/src/transport/smtp/extension.rs @@ -1,7 +1,10 @@ //! ESMTP features use crate::transport::smtp::{ - authentication::Mechanism, error::Error, response::Response, util::XText, + authentication::Mechanism, + error::{self, Error}, + response::Response, + util::XText, }; use std::{ collections::HashSet, @@ -126,7 +129,7 @@ impl ServerInfo { pub fn from_response(response: &Response) -> Result { let name = match response.first_word() { Some(name) => name, - None => return Err(Error::ResponseParsing("Could not read server name")), + None => return Err(error::response("Could not read server name")), }; let mut features: HashSet = HashSet::new(); diff --git a/src/transport/smtp/pool.rs b/src/transport/smtp/pool.rs index 07911a5..3c495a8 100644 --- a/src/transport/smtp/pool.rs +++ b/src/transport/smtp/pool.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use crate::transport::smtp::{client::SmtpConnection, error::Error, SmtpClient}; +use crate::transport::smtp::{client::SmtpConnection, error, error::Error, SmtpClient}; use r2d2::{CustomizeConnection, ManageConnection, Pool}; @@ -90,7 +90,7 @@ impl ManageConnection for SmtpClient { if conn.test_connected() { return Ok(()); } - Err(Error::Client("is not connected anymore")) + Err(error::network("is not connected anymore")) } fn has_broken(&self, conn: &mut Self::Connection) -> bool { diff --git a/src/transport/smtp/response.rs b/src/transport/smtp/response.rs index 40e18b3..c652e4b 100644 --- a/src/transport/smtp/response.rs +++ b/src/transport/smtp/response.rs @@ -1,7 +1,7 @@ //! SMTP response, containing a mandatory return code and an optional text //! message -use crate::transport::smtp::Error; +use crate::transport::smtp::{error, Error}; use nom::{ branch::alt, bytes::streaming::{tag, take_until}, @@ -120,6 +120,14 @@ impl Code { detail, } } + + /// Tells if the response is positive + pub fn is_positive(&self) -> bool { + matches!( + self.severity, + Severity::PositiveCompletion | Severity::PositiveIntermediate + ) + } } /// Contains an SMTP reply, with separated code and message @@ -139,7 +147,9 @@ impl FromStr for Response { type Err = Error; fn from_str(s: &str) -> result::Result { - parse_response(s).map(|(_, r)| r).map_err(|e| e.into()) + parse_response(s) + .map(|(_, r)| r) + .map_err(|e| error::response(e.to_string())) } } @@ -151,10 +161,7 @@ impl Response { /// Tells if the response is positive pub fn is_positive(&self) -> bool { - matches!( - self.code.severity, - Severity::PositiveCompletion | Severity::PositiveIntermediate - ) + self.code.is_positive() } /// Tests code equality diff --git a/src/transport/smtp/transport.rs b/src/transport/smtp/transport.rs index ed2316f..f3a6722 100644 --- a/src/transport/smtp/transport.rs +++ b/src/transport/smtp/transport.rs @@ -5,9 +5,9 @@ use r2d2::Pool; #[cfg(feature = "r2d2")] use super::PoolConfig; -use super::{ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] -use super::{Tls, TlsParameters, SUBMISSIONS_PORT, SUBMISSION_PORT}; +use super::{error, Tls, TlsParameters, SUBMISSIONS_PORT, SUBMISSION_PORT}; +use super::{ClientId, Credentials, Error, Mechanism, Response, SmtpConnection, SmtpInfo}; use crate::{address::Envelope, Transport}; /// Sends emails using the SMTP protocol @@ -28,7 +28,7 @@ impl Transport for SmtpTransport { /// Sends an email fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result { #[cfg(feature = "r2d2")] - let mut conn = self.inner.get()?; + let mut conn = self.inner.get().map_err(error::client)?; #[cfg(not(feature = "r2d2"))] let mut conn = self.inner.connection()?;