From c0ef9a38a17cae9de087ecfd6bd294e6bb0f1302 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Wed, 19 Aug 2020 12:56:02 +0200 Subject: [PATCH] Implement async smtp via tokio 0.2 --- Cargo.toml | 19 +- examples/tokio02_smtp_starttls.rs | 32 ++ examples/tokio02_smtp_tls.rs | 32 ++ src/lib.rs | 2 + src/transport/smtp/async_transport.rs | 230 +++++++++++++++ src/transport/smtp/client/async_connection.rs | 275 ++++++++++++++++++ src/transport/smtp/client/async_net.rs | 243 ++++++++++++++++ src/transport/smtp/client/mod.rs | 11 +- src/transport/smtp/client/tls.rs | 11 +- src/transport/smtp/mod.rs | 7 + 10 files changed, 857 insertions(+), 5 deletions(-) create mode 100644 examples/tokio02_smtp_starttls.rs create mode 100644 examples/tokio02_smtp_tls.rs create mode 100644 src/transport/smtp/async_transport.rs create mode 100644 src/transport/smtp/client/async_connection.rs create mode 100644 src/transport/smtp/client/async_net.rs diff --git a/Cargo.toml b/Cargo.toml index 221efba..d1a1218 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,11 @@ maintenance = { status = "actively-developed" } async-attributes = { version = "1.1", optional = true } async-std = { version = "1.5", optional = true, features = ["unstable"] } async-trait = { version = "0.1", optional = true } -tokio02_crate = { package = "tokio", version = "0.2.7", features = ["fs", "process", "io-util"], optional = true } +tokio02_crate = { package = "tokio", version = "0.2.7", features = ["fs", "process", "tcp", "dns", "io-util"], optional = true } +tokio02_native_tls_crate = { package = "tokio-native-tls", version = "0.1", optional = true } +tokio02_rustls = { package = "tokio-rustls", version = "0.14", optional = true } +futures-io = { version = "0.3", optional = true } +futures-util = { version = "0.3", features = ["io"], optional = true } base64 = { version = "0.12", optional = true } hostname = { version = "0.3", optional = true } hyperx = { version = "1", optional = true, features = ["headers"] } @@ -54,10 +58,13 @@ name = "transport_smtp" [features] async-std1 = ["async-std", "async-trait", "async-attributes"] -tokio02 = ["tokio02_crate", "async-trait"] +tokio02 = ["tokio02_crate", "async-trait", "futures-io", "futures-util"] +tokio02-native-tls = ["tokio02", "native-tls", "tokio02_native_tls_crate"] +tokio02-rustls-tls = ["tokio02", "rustls-tls", "tokio02_rustls"] builder = ["mime", "base64", "hyperx", "rand", "quoted_printable"] default = ["file-transport", "smtp-transport", "native-tls", "hostname", "r2d2", "sendmail-transport", "builder"] file-transport = ["serde", "serde_json"] +# native-tls rustls-tls = ["webpki", "webpki-roots", "rustls"] sendmail-transport = [] smtp-transport = ["base64", "nom"] @@ -77,3 +84,11 @@ required-features = ["smtp-transport", "native-tls"] [[example]] name = "smtp_starttls" required-features = ["smtp-transport", "native-tls"] + +[[example]] +name = "tokio02_smtp_tls" +required-features = ["smtp-transport", "tokio02", "tokio02-native-tls"] + +[[example]] +name = "tokio02_smtp_starttls" +required-features = ["smtp-transport", "tokio02", "tokio02-native-tls"] diff --git a/examples/tokio02_smtp_starttls.rs b/examples/tokio02_smtp_starttls.rs new file mode 100644 index 0000000..2f9663e --- /dev/null +++ b/examples/tokio02_smtp_starttls.rs @@ -0,0 +1,32 @@ +// This line is only to make it compile from lettre's examples folder, +// since it uses Rust 2018 crate renaming to import tokio. +// Won't be needed in user's code. +use tokio02_crate as tokio; + +use lettre::transport::smtp::authentication::Credentials; +use lettre::{AsyncSmtpTransport, Message, Tokio02Connector, Tokio02Transport}; + +#[tokio::main] +async fn main() { + let email = Message::builder() + .from("NoBody ".parse().unwrap()) + .reply_to("Yuin ".parse().unwrap()) + .to("Hei ".parse().unwrap()) + .subject("Happy new async year") + .body("Be happy with async!") + .unwrap(); + + let creds = Credentials::new("smtp_username".to_string(), "smtp_password".to_string()); + + // Open a remote connection to gmail using STARTTLS + let mailer = AsyncSmtpTransport::::starttls_relay("smtp.gmail.com") + .unwrap() + .credentials(creds) + .build(); + + // Send the email + match mailer.send(email).await { + Ok(_) => println!("Email sent successfully!"), + Err(e) => panic!("Could not send email: {:?}", e), + } +} diff --git a/examples/tokio02_smtp_tls.rs b/examples/tokio02_smtp_tls.rs new file mode 100644 index 0000000..979b125 --- /dev/null +++ b/examples/tokio02_smtp_tls.rs @@ -0,0 +1,32 @@ +// This line is only to make it compile from lettre's examples folder, +// since it uses Rust 2018 crate renaming to import tokio. +// Won't be needed in user's code. +use tokio02_crate as tokio; + +use lettre::transport::smtp::authentication::Credentials; +use lettre::{AsyncSmtpTransport, Message, Tokio02Connector, Tokio02Transport}; + +#[tokio::main] +async fn main() { + let email = Message::builder() + .from("NoBody ".parse().unwrap()) + .reply_to("Yuin ".parse().unwrap()) + .to("Hei ".parse().unwrap()) + .subject("Happy new async year") + .body("Be happy with async!") + .unwrap(); + + let creds = Credentials::new("smtp_username".to_string(), "smtp_password".to_string()); + + // Open a remote connection to gmail + let mailer = AsyncSmtpTransport::::relay("smtp.gmail.com") + .unwrap() + .credentials(creds) + .build(); + + // Send the email + match mailer.send(email).await { + Ok(_) => println!("Email sent successfully!"), + Err(e) => panic!("Could not send email: {:?}", e), + } +} diff --git a/src/lib.rs b/src/lib.rs index b0d886d..fbb3651 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,8 @@ pub use crate::transport::sendmail::SendmailTransport; pub use crate::transport::smtp::r2d2::SmtpConnectionManager; #[cfg(feature = "smtp-transport")] pub use crate::transport::smtp::SmtpTransport; +#[cfg(all(feature = "smtp-transport", feature = "tokio02"))] +pub use crate::transport::smtp::{AsyncSmtpTransport, Tokio02Connector}; pub use crate::{address::Address, transport::stub::StubTransport}; #[cfg(any(feature = "async-std1", feature = "tokio02"))] use async_trait::async_trait; diff --git a/src/transport/smtp/async_transport.rs b/src/transport/smtp/async_transport.rs new file mode 100644 index 0000000..b169b21 --- /dev/null +++ b/src/transport/smtp/async_transport.rs @@ -0,0 +1,230 @@ +use async_trait::async_trait; + +use super::client::AsyncSmtpConnection; +#[cfg(feature = "tokio02")] +use super::Tls; +use super::{ClientId, Credentials, Error, Mechanism, Response, SmtpInfo}; +use crate::{Envelope, Tokio02Transport}; + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub struct AsyncSmtpTransport { + // TODO: pool + inner: AsyncSmtpClient, +} + +#[async_trait] +impl Tokio02Transport for AsyncSmtpTransport { + type Ok = Response; + type Error = Error; + + /// Sends an email + async fn send_raw(&self, envelope: &Envelope, email: &[u8]) -> Result { + let mut conn = self.inner.connection().await?; + + let result = conn.send(envelope, email).await?; + + conn.quit().await?; + + Ok(result) + } +} + +impl AsyncSmtpTransport +where + C: AsyncSmtpConnector, +{ + /// Simple and secure transport, should be used when possible. + /// Creates an encrypted transport over submissions port, using the provided domain + /// to validate TLS certificates. + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + pub fn relay(relay: &str) -> Result { + use super::{TlsParameters, SUBMISSIONS_PORT}; + + let tls_parameters = TlsParameters::new_tokio02(relay.into())?; + + Ok(Self::builder_dangerous(relay) + .port(SUBMISSIONS_PORT) + .tls(Tls::Wrapper(tls_parameters))) + } + + /// Simple and secure transport, should be used when the server doesn't support wrapped TLS connections. + /// Creates an encrypted transport over submissions port, by first connecting using an unencrypted + /// connection and then upgrading it with STARTTLS, using the provided domain to validate TLS certificates. + /// If the connection can't be upgraded it will fail connecting altogether. + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + pub fn starttls_relay(relay: &str) -> Result { + use super::{TlsParameters, SUBMISSION_PORT}; + + let tls_parameters = TlsParameters::new(relay.into())?; + + Ok(Self::builder_dangerous(relay) + .port(SUBMISSION_PORT) + .tls(Tls::Required(tls_parameters))) + } + + /// Creates a new local SMTP client to port 25 + /// + /// Shortcut for local unencrypted relay (typical local email daemon that will handle relaying) + pub fn unencrypted_localhost() -> AsyncSmtpTransport { + Self::builder_dangerous("localhost").build() + } + + /// Creates a new SMTP client + /// + /// Defaults are: + /// + /// * No authentication + /// * No TLS + /// * Port 25 + /// + /// Consider using [`AsyncSmtpTransport::relay`] instead, if possible. + pub fn builder_dangerous>(server: T) -> AsyncSmtpTransportBuilder { + let mut new = SmtpInfo::default(); + new.server = server.into(); + AsyncSmtpTransportBuilder { info: new } + } +} + +/// Contains client configuration +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub struct AsyncSmtpTransportBuilder { + info: SmtpInfo, +} + +/// Builder for the SMTP `AsyncSmtpTransport` +impl AsyncSmtpTransportBuilder { + /// Set the name used during EHLO + pub fn hello_name(mut self, name: ClientId) -> Self { + self.info.hello_name = name; + self + } + + /// Set the authentication mechanism to use + pub fn credentials(mut self, credentials: Credentials) -> Self { + self.info.credentials = Some(credentials); + self + } + + /// Set the authentication mechanism to use + pub fn authentication(mut self, mechanisms: Vec) -> Self { + self.info.authentication = mechanisms; + self + } + + /// Set the port to use + pub fn port(mut self, port: u16) -> Self { + self.info.port = port; + self + } + + /// Set the TLS settings to use + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + pub fn tls(mut self, tls: Tls) -> Self { + self.info.tls = tls; + self + } + + /// Build the transport (with default pool if enabled) + pub fn build(self) -> AsyncSmtpTransport + where + C: AsyncSmtpConnector, + { + let connector = Default::default(); + let client = AsyncSmtpClient { + connector, + info: self.info, + }; + AsyncSmtpTransport { inner: client } + } +} + +/// Build client +#[derive(Clone)] +pub struct AsyncSmtpClient { + connector: C, + info: SmtpInfo, +} + +impl AsyncSmtpClient +where + C: AsyncSmtpConnector, +{ + /// Creates a new connection directly usable to send emails + /// + /// Handles encryption and authentication + pub async fn connection(&self) -> Result { + let mut conn = C::connect( + &self.info.server, + self.info.port, + &self.info.hello_name, + &self.info.tls, + ) + .await?; + + if let Some(credentials) = &self.info.credentials { + conn.auth(&self.info.authentication, &credentials).await?; + } + Ok(conn) + } +} + +#[async_trait] +pub trait AsyncSmtpConnector: Default + private::Sealed { + async fn connect( + hostname: &str, + port: u16, + hello_name: &ClientId, + tls: &Tls, + ) -> Result; +} + +#[derive(Debug, Copy, Clone, Default)] +#[cfg(feature = "tokio02")] +pub struct Tokio02Connector; + +#[async_trait] +#[cfg(feature = "tokio02")] +impl AsyncSmtpConnector for Tokio02Connector { + async fn connect( + hostname: &str, + port: u16, + hello_name: &ClientId, + tls: &Tls, + ) -> Result { + #[allow(clippy::match_single_binding)] + let tls_parameters = match tls { + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + Tls::Wrapper(ref tls_parameters) => Some(tls_parameters.clone()), + _ => None, + }; + let mut conn = + AsyncSmtpConnection::connect_tokio02(hostname, port, hello_name, tls_parameters) + .await?; + + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + match tls { + Tls::Opportunistic(ref tls_parameters) => { + if conn.can_starttls() { + conn.starttls(tls_parameters.clone(), hello_name).await?; + } + } + Tls::Required(ref tls_parameters) => { + conn.starttls(tls_parameters.clone(), hello_name).await?; + } + _ => (), + } + + Ok(conn) + } +} + +mod private { + use super::*; + + pub trait Sealed {} + + #[cfg(feature = "tokio02")] + impl Sealed for Tokio02Connector {} +} diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs new file mode 100644 index 0000000..a014d05 --- /dev/null +++ b/src/transport/smtp/client/async_connection.rs @@ -0,0 +1,275 @@ +use std::fmt::Display; +use std::io; + +use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; + +#[cfg(feature = "log")] +use log::debug; + +use super::{AsyncNetworkStream, ClientCodec, TlsParameters}; +use crate::transport::smtp::authentication::{Credentials, Mechanism}; +use crate::transport::smtp::commands::*; +use crate::transport::smtp::error::Error; +use crate::transport::smtp::extension::{ + ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo, +}; +use crate::transport::smtp::response::{parse_response, Response}; +use crate::Envelope; + +#[cfg(feature = "log")] +use super::escape_crlf; + +macro_rules! try_smtp ( + ($err: expr, $client: ident) => ({ + match $err { + Ok(val) => val, + Err(err) => { + $client.abort().await; + return Err(From::from(err)) + }, + } + }) +); + +/// Structure that implements the SMTP client +pub struct AsyncSmtpConnection { + /// TCP stream between client and server + /// Value is None before connection + stream: BufReader, + /// Panic state + panic: bool, + /// Information about the server + server_info: ServerInfo, +} + +impl AsyncSmtpConnection { + pub fn server_info(&self) -> &ServerInfo { + &self.server_info + } + + // FIXME add simple connect and rename this one + + /// Connects to the configured server + /// + /// Sends EHLO and parses server information + pub async fn connect_tokio02( + hostname: &str, + port: u16, + hello_name: &ClientId, + tls_parameters: Option, + ) -> Result { + let stream = AsyncNetworkStream::connect_tokio02(hostname, port, tls_parameters).await?; + Self::connect_impl(stream, hello_name).await + } + + async fn connect_impl( + stream: AsyncNetworkStream, + hello_name: &ClientId, + ) -> Result { + let stream = BufReader::new(stream); + let mut conn = AsyncSmtpConnection { + stream, + panic: false, + server_info: ServerInfo::default(), + }; + // TODO log + let _response = conn.read_response().await?; + + conn.ehlo(hello_name).await?; + + // Print server information + #[cfg(feature = "log")] + debug!("server {}", conn.server_info); + Ok(conn) + } + + pub async fn send(&mut self, envelope: &Envelope, email: &[u8]) -> Result { + // Mail + let mut mail_options = vec![]; + + if self.server_info().supports_feature(Extension::EightBitMime) { + mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime)); + } + try_smtp!( + self.command(Mail::new(envelope.from().cloned(), mail_options)) + .await, + self + ); + + // Recipient + for to_address in envelope.to() { + try_smtp!( + self.command(Rcpt::new(to_address.clone(), vec![])).await, + self + ); + } + + // Data + try_smtp!(self.command(Data).await, self); + + // Message content + let result = try_smtp!(self.message(email).await, self); + Ok(result) + } + + pub fn has_broken(&self) -> bool { + self.panic + } + + pub fn can_starttls(&self) -> bool { + !self.is_encrypted() && self.server_info.supports_feature(Extension::StartTls) + } + + #[allow(unused_variables)] + pub async fn starttls( + &mut self, + tls_parameters: TlsParameters, + hello_name: &ClientId, + ) -> Result<(), Error> { + if self.server_info.supports_feature(Extension::StartTls) { + try_smtp!(self.command(Starttls).await, self); + try_smtp!( + self.stream.get_mut().upgrade_tls(tls_parameters).await, + self + ); + #[cfg(feature = "log")] + debug!("connection encrypted"); + // Send EHLO again + try_smtp!(self.ehlo(hello_name).await, self); + Ok(()) + } else { + Err(Error::Client("STARTTLS is not supported on this server")) + } + } + + /// Send EHLO and update server info + async fn ehlo(&mut self, hello_name: &ClientId) -> Result<(), Error> { + let ehlo_response = try_smtp!( + self.command(Ehlo::new(ClientId::new(hello_name.to_string()))) + .await, + self + ); + self.server_info = try_smtp!(ServerInfo::from_response(&ehlo_response), self); + Ok(()) + } + + pub async fn quit(&mut self) -> Result { + Ok(try_smtp!(self.command(Quit).await, self)) + } + + pub async fn abort(&mut self) { + // Only try to quit if we are not already broken + if !self.panic { + self.panic = true; + let _ = self.command(Quit).await; + } + } + + /// Sets the underlying stream + pub fn set_stream(&mut self, stream: AsyncNetworkStream) { + self.stream = BufReader::new(stream); + } + + /// Tells if the underlying stream is currently encrypted + pub fn is_encrypted(&self) -> bool { + self.stream.get_ref().is_encrypted() + } + + /// Checks if the server is connected using the NOOP SMTP command + pub async fn test_connected(&mut self) -> bool { + self.command(Noop).await.is_ok() + } + + /// Sends an AUTH command with the given mechanism, and handles challenge if needed + pub async fn auth( + &mut self, + mechanisms: &[Mechanism], + credentials: &Credentials, + ) -> Result { + let mechanism = self + .server_info + .get_auth_mechanism(mechanisms) + .ok_or(Error::Client( + "No compatible authentication mechanism was found", + ))?; + + // Limit challenges to avoid blocking + let mut challenges = 10; + let mut response = self + .command(Auth::new(mechanism, credentials.clone(), None)?) + .await?; + + while challenges > 0 && response.has_code(334) { + challenges -= 1; + response = try_smtp!( + self.command(Auth::new_from_response( + mechanism, + credentials.clone(), + &response, + )?) + .await, + self + ); + } + + if challenges == 0 { + Err(Error::ResponseParsing("Unexpected number of challenges")) + } else { + Ok(response) + } + } + + /// Sends the message content + pub async fn message(&mut self, message: &[u8]) -> Result { + let mut out_buf: Vec = vec![]; + let mut codec = ClientCodec::new(); + codec.encode(message, &mut out_buf); + self.write(out_buf.as_slice()).await?; + self.write(b"\r\n.\r\n").await?; + self.read_response().await + } + + /// Sends an SMTP command + pub async fn command(&mut self, command: C) -> Result { + self.write(command.to_string().as_bytes()).await?; + self.read_response().await + } + + /// Writes a string to the server + async fn write(&mut self, string: &[u8]) -> Result<(), Error> { + self.stream.get_mut().write_all(string).await?; + self.stream.get_mut().flush().await?; + + #[cfg(feature = "log")] + debug!("Wrote: {}", escape_crlf(&String::from_utf8_lossy(string))); + Ok(()) + } + + /// Gets the SMTP response + pub async fn read_response(&mut self) -> Result { + let mut buffer = String::with_capacity(100); + + while self.stream.read_line(&mut buffer).await? > 0 { + #[cfg(feature = "log")] + debug!("<< {}", escape_crlf(&buffer)); + match parse_response(&buffer) { + Ok((_remaining, response)) => { + if response.is_positive() { + return Ok(response); + } + + return Err(response.into()); + } + Err(nom::Err::Failure(e)) => { + return Err(Error::Parsing(e.1)); + } + Err(nom::Err::Incomplete(_)) => { /* read more */ } + Err(nom::Err::Error(e)) => { + return Err(Error::Parsing(e.1)); + } + } + } + + Err(io::Error::new(io::ErrorKind::Other, "incomplete").into()) + } +} diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs new file mode 100644 index 0000000..5dcfead --- /dev/null +++ b/src/transport/smtp/client/async_net.rs @@ -0,0 +1,243 @@ +use std::net::{Shutdown, SocketAddr}; +use std::pin::Pin; +#[cfg(feature = "tokio02-rustls-tls")] +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures_io::{Error as IoError, ErrorKind, Result as IoResult}; +#[cfg(feature = "tokio02")] +use tokio02_crate::io::{AsyncRead, AsyncWrite}; +#[cfg(feature = "tokio02")] +use tokio02_crate::net::TcpStream; + +#[cfg(feature = "tokio02-native-tls")] +use tokio02_native_tls_crate::TlsStream; + +#[cfg(feature = "tokio02-rustls-tls")] +use tokio02_rustls::client::TlsStream as RustlsTlsStream; + +#[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] +use super::InnerTlsParameters; +use super::TlsParameters; +use crate::transport::smtp::Error; + +/// A network stream +pub struct AsyncNetworkStream { + inner: InnerAsyncNetworkStream, +} + +/// Represents the different types of underlying network streams +#[allow(dead_code)] +enum InnerAsyncNetworkStream { + /// Plain TCP stream + #[cfg(feature = "tokio02")] + Tokio02Tcp(TcpStream), + /// Encrypted TCP stream + #[cfg(feature = "tokio02-native-tls")] + Tokio02NativeTls(TlsStream), + /// Encrypted TCP stream + #[cfg(feature = "tokio02-rustls-tls")] + Tokio02RustlsTls(Box>), + /// Can't be built + None, +} + +impl AsyncNetworkStream { + fn new(inner: InnerAsyncNetworkStream) -> Self { + if let InnerAsyncNetworkStream::None = inner { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + } + + AsyncNetworkStream { inner } + } + + /// Returns peer's address + pub fn peer_addr(&self) -> IoResult { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(ref s) => s.peer_addr(), + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(ref s) => { + s.get_ref().get_ref().get_ref().peer_addr() + } + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(ref s) => s.get_ref().0.peer_addr(), + InnerAsyncNetworkStream::None => { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + Err(IoError::new( + ErrorKind::Other, + "InnerAsyncNetworkStream::None should never be built", + )) + } + } + } + + /// Shutdowns the connection + pub fn shutdown(&self, how: Shutdown) -> IoResult<()> { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(ref s) => s.shutdown(how), + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(ref s) => { + s.get_ref().get_ref().get_ref().shutdown(how) + } + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(ref s) => s.get_ref().0.shutdown(how), + InnerAsyncNetworkStream::None => { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + Ok(()) + } + } + } + + #[cfg(feature = "tokio02")] + pub async fn connect_tokio02( + hostname: &str, + port: u16, + tls_parameters: Option, + ) -> Result { + let tcp_stream = TcpStream::connect((hostname, port)).await?; + + let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio02Tcp(tcp_stream)); + if let Some(tls_parameters) = tls_parameters { + stream.upgrade_tls(tls_parameters).await?; + } + Ok(stream) + } + + pub async fn upgrade_tls(&mut self, tls_parameters: TlsParameters) -> Result<(), Error> { + match &self.inner { + #[cfg(not(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls")))] + InnerAsyncNetworkStream::Tokio02Tcp(_) => { + let _ = tls_parameters; + panic!("Trying to upgrade an AsyncNetworkStream without having enabled either the tokio02-native-tls or the tokio02-rustls-tls feature"); + } + + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + InnerAsyncNetworkStream::Tokio02Tcp(_) => { + // get owned TcpStream + let tcp_stream = std::mem::replace(&mut self.inner, InnerAsyncNetworkStream::None); + let tcp_stream = match tcp_stream { + InnerAsyncNetworkStream::Tokio02Tcp(tcp_stream) => tcp_stream, + _ => unreachable!(), + }; + + self.inner = Self::upgrade_tokio02_tls(tcp_stream, tls_parameters).await?; + Ok(()) + } + _ => Ok(()), + } + } + + #[allow(unused_variables)] + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + async fn upgrade_tokio02_tls( + tcp_stream: TcpStream, + mut tls_parameters: TlsParameters, + ) -> Result { + let domain = std::mem::take(&mut tls_parameters.domain); + + match tls_parameters.connector { + #[cfg(feature = "native-tls")] + InnerTlsParameters::NativeTls(connector) => { + #[cfg(not(feature = "tokio02-native-tls"))] + panic!("built without the tokio02-native-tls feature"); + + #[cfg(feature = "tokio02-native-tls")] + return { + use tokio02_native_tls_crate::TlsConnector; + + let connector = TlsConnector::from(connector); + let stream = connector.connect(&domain, tcp_stream).await?; + Ok(InnerAsyncNetworkStream::Tokio02NativeTls(stream)) + }; + } + #[cfg(feature = "rustls-tls")] + InnerTlsParameters::RustlsTls(config) => { + #[cfg(not(feature = "tokio02-rustls-tls"))] + panic!("built without the tokio02-rustls-tls feature"); + + #[cfg(feature = "tokio02-rustls-tls")] + return { + use tokio02_rustls::webpki::DNSNameRef; + use tokio02_rustls::TlsConnector; + + let domain = DNSNameRef::try_from_ascii_str(&domain)?; + + let connector = TlsConnector::from(Arc::new(config)); + let stream = connector.connect(domain, tcp_stream).await?; + Ok(InnerAsyncNetworkStream::Tokio02RustlsTls(Box::new(stream))) + }; + } + } + } + + pub fn is_encrypted(&self) -> bool { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(_) => false, + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(_) => true, + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(_) => true, + InnerAsyncNetworkStream::None => false, + } + } +} + +impl futures_io::AsyncRead for AsyncNetworkStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(ref mut s) => Pin::new(s).poll_read(cx, buf), + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(ref mut s) => Pin::new(s).poll_read(cx, buf), + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(ref mut s) => Pin::new(s).poll_read(cx, buf), + InnerAsyncNetworkStream::None => { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + Poll::Ready(Ok(0)) + } + } + } +} + +impl futures_io::AsyncWrite for AsyncNetworkStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(ref mut s) => Pin::new(s).poll_write(cx, buf), + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(ref mut s) => Pin::new(s).poll_write(cx, buf), + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(ref mut s) => Pin::new(s).poll_write(cx, buf), + InnerAsyncNetworkStream::None => { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + Poll::Ready(Ok(0)) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.inner { + #[cfg(feature = "tokio02")] + InnerAsyncNetworkStream::Tokio02Tcp(ref mut s) => Pin::new(s).poll_flush(cx), + #[cfg(feature = "tokio02-native-tls")] + InnerAsyncNetworkStream::Tokio02NativeTls(ref mut s) => Pin::new(s).poll_flush(cx), + #[cfg(feature = "tokio02-rustls-tls")] + InnerAsyncNetworkStream::Tokio02RustlsTls(ref mut s) => Pin::new(s).poll_flush(cx), + InnerAsyncNetworkStream::None => { + debug_assert!(false, "InnerAsyncNetworkStream::None should never be built"); + Poll::Ready(Ok(())) + } + } + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(self.shutdown(Shutdown::Write)) + } +} diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index b2e0851..4582a69 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -3,14 +3,21 @@ #[cfg(feature = "serde")] use std::fmt::Debug; -use self::net::NetworkStream; - +#[cfg(feature = "tokio02")] +pub(crate) use self::async_connection::AsyncSmtpConnection; +#[cfg(feature = "tokio02")] +pub(crate) use self::async_net::AsyncNetworkStream; pub use self::connection::SmtpConnection; pub use self::mock::MockStream; +use self::net::NetworkStream; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub(super) use self::tls::InnerTlsParameters; pub use self::tls::{Tls, TlsParameters}; +#[cfg(feature = "tokio02")] +mod async_connection; +#[cfg(feature = "tokio02")] +mod async_net; mod connection; mod mock; mod net; diff --git a/src/transport/smtp/client/tls.rs b/src/transport/smtp/client/tls.rs index cd8a3c0..e9c6958 100644 --- a/src/transport/smtp/client/tls.rs +++ b/src/transport/smtp/client/tls.rs @@ -35,7 +35,7 @@ pub enum Tls { pub struct TlsParameters { pub(crate) connector: InnerTlsParameters, /// The domain name which is expected in the TLS certificate from the server - domain: String, + pub(super) domain: String, } #[derive(Clone)] @@ -58,6 +58,15 @@ impl TlsParameters { return Self::new_rustls(domain); } + #[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))] + pub(crate) fn new_tokio02(domain: String) -> Result { + #[cfg(feature = "tokio02-native-tls")] + return Self::new_native(domain); + + #[cfg(not(feature = "tokio02-native-tls"))] + return Self::new_rustls(domain); + } + /// Creates a new `TlsParameters` using native-tls #[cfg(feature = "native-tls")] pub fn new_native(domain: String) -> Result { diff --git a/src/transport/smtp/mod.rs b/src/transport/smtp/mod.rs index 619b8d9..42ebb42 100644 --- a/src/transport/smtp/mod.rs +++ b/src/transport/smtp/mod.rs @@ -178,6 +178,11 @@ use std::time::Duration; +#[cfg(feature = "tokio02")] +pub use self::async_transport::{ + AsyncSmtpClient, AsyncSmtpConnector, AsyncSmtpTransport, AsyncSmtpTransportBuilder, + Tokio02Connector, +}; pub use self::transport::{SmtpClient, SmtpTransport, SmtpTransportBuilder}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use crate::transport::smtp::client::TlsParameters; @@ -190,6 +195,8 @@ use crate::transport::smtp::{ }; use client::Tls; +#[cfg(feature = "tokio02")] +mod async_transport; pub mod authentication; pub mod client; pub mod commands;