diff --git a/Cargo.toml b/Cargo.toml index 95a2310..015332e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ async-std = { version = "1.8", optional = true, features = ["unstable"] } async-rustls = { version = "0.2", optional = true } ## tokio -tokio1_crate = { package = "tokio", version = "1", features = ["fs", "process", "net", "io-util"], optional = true } +tokio1_crate = { package = "tokio", version = "1", features = ["fs", "process", "time", "net", "io-util"], optional = true } tokio1_native_tls_crate = { package = "tokio-native-tls", version = "0.3", optional = true } tokio1_rustls = { package = "tokio-rustls", version = "0.22", optional = true } diff --git a/src/executor.rs b/src/executor.rs index d084f5f..2dabfe1 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -5,6 +5,8 @@ use std::fmt::Debug; use std::io::Result as IoResult; #[cfg(feature = "file-transport")] use std::path::Path; +#[cfg(feature = "smtp-transport")] +use std::time::Duration; #[cfg(all( feature = "smtp-transport", @@ -43,6 +45,7 @@ pub trait Executor: Debug + Send + Sync + private::Sealed { async fn connect( hostname: &str, port: u16, + timeout: Option, hello_name: &ClientId, tls: &Tls, ) -> Result; @@ -79,6 +82,7 @@ impl Executor for Tokio1Executor { async fn connect( hostname: &str, port: u16, + timeout: Option, hello_name: &ClientId, tls: &Tls, ) -> Result { @@ -89,8 +93,13 @@ impl Executor for Tokio1Executor { _ => None, }; #[allow(unused_mut)] - let mut conn = - AsyncSmtpConnection::connect_tokio1(hostname, port, hello_name, tls_parameters).await?; + let mut conn = AsyncSmtpConnection::connect_tokio1( + (hostname, port), + timeout, + hello_name, + tls_parameters, + ) + .await?; #[cfg(any(feature = "tokio1-native-tls", feature = "tokio1-rustls-tls"))] match tls { @@ -144,6 +153,7 @@ impl Executor for AsyncStd1Executor { async fn connect( hostname: &str, port: u16, + timeout: Option, hello_name: &ClientId, tls: &Tls, ) -> Result { @@ -154,9 +164,13 @@ impl Executor for AsyncStd1Executor { _ => None, }; #[allow(unused_mut)] - let mut conn = - AsyncSmtpConnection::connect_asyncstd1(hostname, port, hello_name, tls_parameters) - .await?; + let mut conn = AsyncSmtpConnection::connect_asyncstd1( + (hostname, port), + timeout, + hello_name, + tls_parameters, + ) + .await?; #[cfg(any(feature = "async-std1-native-tls", feature = "async-std1-rustls-tls"))] match tls { diff --git a/src/transport/smtp/async_transport.rs b/src/transport/smtp/async_transport.rs index a64d598..3fd6700 100644 --- a/src/transport/smtp/async_transport.rs +++ b/src/transport/smtp/async_transport.rs @@ -1,6 +1,7 @@ use std::{ fmt::{self, Debug}, marker::PhantomData, + time::Duration, }; use async_trait::async_trait; @@ -141,6 +142,7 @@ where /// /// * No authentication /// * No TLS + /// * A 60 seconds timeout for smtp commands /// * Port 25 /// /// Consider using [`AsyncSmtpTransport::relay`](#method.relay) or @@ -208,6 +210,12 @@ impl AsyncSmtpTransportBuilder { self } + /// Set the timeout duration + pub fn timeout(mut self, timeout: Option) -> Self { + self.info.timeout = timeout; + self + } + /// Set the TLS settings to use #[cfg(any( feature = "tokio1-native-tls", @@ -259,6 +267,7 @@ where let mut conn = E::connect( &self.info.server, self.info.port, + self.info.timeout, &self.info.hello_name, &self.info.tls, ) diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index 2d459bb..e341726 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -11,7 +11,7 @@ use crate::{ Envelope, }; use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; #[cfg(feature = "tracing")] use super::escape_crlf; @@ -48,13 +48,13 @@ impl AsyncSmtpConnection { /// /// Sends EHLO and parses server information #[cfg(feature = "tokio1")] - pub async fn connect_tokio1( - hostname: &str, - port: u16, + pub async fn connect_tokio1( + server: T, + timeout: Option, hello_name: &ClientId, tls_parameters: Option, ) -> Result { - let stream = AsyncNetworkStream::connect_tokio1(hostname, port, tls_parameters).await?; + let stream = AsyncNetworkStream::connect_tokio1(server, timeout, tls_parameters).await?; Self::connect_impl(stream, hello_name).await } @@ -62,13 +62,13 @@ impl AsyncSmtpConnection { /// /// Sends EHLO and parses server information #[cfg(feature = "async-std1")] - pub async fn connect_asyncstd1( - hostname: &str, - port: u16, + pub async fn connect_asyncstd1( + server: T, + timeout: Option, hello_name: &ClientId, tls_parameters: Option, ) -> Result { - let stream = AsyncNetworkStream::connect_asyncstd1(hostname, port, tls_parameters).await?; + let stream = AsyncNetworkStream::connect_asyncstd1(server, timeout, tls_parameters).await?; Self::connect_impl(stream, hello_name).await } diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs index 0a98aa4..a757eb4 100644 --- a/src/transport/smtp/client/async_net.rs +++ b/src/transport/smtp/client/async_net.rs @@ -1,8 +1,9 @@ use std::{ - mem, + io, mem, net::SocketAddr, pin::Pin, task::{Context, Poll}, + time::Duration, }; use futures_io::{ @@ -13,9 +14,9 @@ use futures_io::{ use tokio1_crate::io::{AsyncRead as _, AsyncWrite as _, ReadBuf as Tokio1ReadBuf}; #[cfg(feature = "async-std1")] -use async_std::net::TcpStream as AsyncStd1TcpStream; +use async_std::net::{TcpStream as AsyncStd1TcpStream, ToSocketAddrs as AsyncStd1ToSocketAddrs}; #[cfg(feature = "tokio1")] -use tokio1_crate::net::TcpStream as Tokio1TcpStream; +use tokio1_crate::net::{TcpStream as Tokio1TcpStream, ToSocketAddrs as Tokio1ToSocketAddrs}; #[cfg(feature = "async-std1-native-tls")] use async_native_tls::TlsStream as AsyncStd1TlsStream; @@ -107,14 +108,47 @@ impl AsyncNetworkStream { } #[cfg(feature = "tokio1")] - pub async fn connect_tokio1( - hostname: &str, - port: u16, + pub async fn connect_tokio1( + server: T, + timeout: Option, tls_parameters: Option, ) -> Result { - let tcp_stream = Tokio1TcpStream::connect((hostname, port)) - .await - .map_err(error::connection)?; + async fn try_connect_timeout( + server: T, + timeout: Duration, + ) -> Result { + let addrs = tokio1_crate::net::lookup_host(server) + .await + .map_err(error::connection)?; + + let mut last_err = None; + + for addr in addrs { + let connect_future = Tokio1TcpStream::connect(&addr); + match tokio1_crate::time::timeout(timeout, connect_future).await { + Ok(Ok(stream)) => return Ok(stream), + Ok(Err(err)) => last_err = Some(err), + Err(_) => { + last_err = Some(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + )) + } + } + } + + Err(match last_err { + Some(last_err) => error::connection(last_err), + None => error::connection("could not resolve to any address"), + }) + } + + let tcp_stream = match timeout { + Some(t) => try_connect_timeout(server, t).await?, + None => Tokio1TcpStream::connect(server) + .await + .map_err(error::connection)?, + }; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { @@ -124,14 +158,45 @@ impl AsyncNetworkStream { } #[cfg(feature = "async-std1")] - pub async fn connect_asyncstd1( - hostname: &str, - port: u16, + pub async fn connect_asyncstd1( + server: T, + timeout: Option, tls_parameters: Option, ) -> Result { - let tcp_stream = AsyncStd1TcpStream::connect((hostname, port)) - .await - .map_err(error::connection)?; + async fn try_connect_timeout( + server: T, + timeout: Duration, + ) -> Result { + let addrs = server.to_socket_addrs().await.map_err(error::connection)?; + + let mut last_err = None; + + for addr in addrs { + let connect_future = AsyncStd1TcpStream::connect(&addr); + match async_std::future::timeout(timeout, connect_future).await { + Ok(Ok(stream)) => return Ok(stream), + Ok(Err(err)) => last_err = Some(err), + Err(_) => { + last_err = Some(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + )) + } + } + } + + Err(match last_err { + Some(last_err) => error::connection(last_err), + None => error::connection("could not resolve to any address"), + }) + } + + let tcp_stream = match timeout { + Some(t) => try_connect_timeout(server, t).await?, + None => AsyncStd1TcpStream::connect(server) + .await + .map_err(error::connection)?, + }; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index 40ad6b0..6f5a979 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -90,12 +90,20 @@ impl NetworkStream { timeout: Duration, ) -> Result { let addrs = server.to_socket_addrs().map_err(error::connection)?; + + let mut last_err = None; + for addr in addrs { - if let Ok(result) = TcpStream::connect_timeout(&addr, timeout) { - return Ok(result); + match TcpStream::connect_timeout(&addr, timeout) { + Ok(stream) => return Ok(stream), + Err(err) => last_err = Some(err), } } - Err(error::connection("Could not connect")) + + Err(match last_err { + Some(last_err) => error::connection(last_err), + None => error::connection("could not resolve to any address"), + }) } let tcp_stream = match timeout {