From 1391a834ce9616f66c8eeced108b19eeb64ca2fa Mon Sep 17 00:00:00 2001 From: Jacob Halsey Date: Sun, 29 May 2022 08:05:39 +0100 Subject: [PATCH] #715: Support setting the local IP address to connect from (#762) This adds a `local_address: Option` parameter to the synchronous, and tokio connect functions. (As far as I can see there is no current way to support this in async-std, because the library doesn't provide any way to do an async connect for an existing socket) --- Cargo.toml | 5 +- src/executor.rs | 1 + src/transport/smtp/client/async_connection.rs | 7 +- src/transport/smtp/client/async_net.rs | 67 ++++++++++----- src/transport/smtp/client/connection.rs | 5 +- src/transport/smtp/client/mod.rs | 2 +- src/transport/smtp/client/net.rs | 85 ++++++++++++++++--- src/transport/smtp/transport.rs | 1 + 8 files changed, 132 insertions(+), 41 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index af36658..0fdd2ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,9 +37,10 @@ uuid = { version = "1", features = ["v4"], optional = true } serde = { version = "1", optional = true, features = ["derive"] } serde_json = { version = "1", optional = true } -# smtp +# smtp-transport nom = { version = "7", optional = true } hostname = { version = "0.3", optional = true } # feature +socket2 = { version = "0.4.4", optional = true } ## tls native-tls = { version = "0.2", optional = true } # feature @@ -93,7 +94,7 @@ mime03 = ["mime"] file-transport = ["uuid"] file-transport-envelope = ["serde", "serde_json", "file-transport"] sendmail-transport = [] -smtp-transport = ["base64", "nom"] +smtp-transport = ["base64", "nom", "socket2"] pool = ["futures-util"] diff --git a/src/executor.rs b/src/executor.rs index 672503e..99adaa7 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -146,6 +146,7 @@ impl Executor for Tokio1Executor { timeout, hello_name, tls_parameters, + None, ) .await?; diff --git a/src/transport/smtp/client/async_connection.rs b/src/transport/smtp/client/async_connection.rs index b3cd03b..9a8d9bf 100644 --- a/src/transport/smtp/client/async_connection.rs +++ b/src/transport/smtp/client/async_connection.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, time::Duration}; +use std::{fmt::Display, net::IpAddr, time::Duration}; use futures_util::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; @@ -54,8 +54,11 @@ impl AsyncSmtpConnection { timeout: Option, hello_name: &ClientId, tls_parameters: Option, + local_address: Option, ) -> Result { - let stream = AsyncNetworkStream::connect_tokio1(server, timeout, tls_parameters).await?; + let stream = + AsyncNetworkStream::connect_tokio1(server, timeout, tls_parameters, local_address) + .await?; Self::connect_impl(stream, hello_name).await } diff --git a/src/transport/smtp/client/async_net.rs b/src/transport/smtp/client/async_net.rs index 3786a7e..d992490 100644 --- a/src/transport/smtp/client/async_net.rs +++ b/src/transport/smtp/client/async_net.rs @@ -1,6 +1,6 @@ use std::{ io, mem, - net::SocketAddr, + net::{IpAddr, SocketAddr}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -19,7 +19,10 @@ use futures_rustls::client::TlsStream as AsyncStd1RustlsTlsStream; #[cfg(feature = "tokio1")] use tokio1_crate::io::{AsyncRead as _, AsyncWrite as _, ReadBuf as Tokio1ReadBuf}; #[cfg(feature = "tokio1")] -use tokio1_crate::net::{TcpStream as Tokio1TcpStream, ToSocketAddrs as Tokio1ToSocketAddrs}; +use tokio1_crate::net::{ + TcpSocket as Tokio1TcpSocket, TcpStream as Tokio1TcpStream, + ToSocketAddrs as Tokio1ToSocketAddrs, +}; #[cfg(feature = "tokio1-native-tls")] use tokio1_native_tls_crate::TlsStream as Tokio1TlsStream; #[cfg(feature = "tokio1-rustls-tls")] @@ -33,6 +36,8 @@ use tokio1_rustls::client::TlsStream as Tokio1RustlsTlsStream; ))] use super::InnerTlsParameters; use super::TlsParameters; +#[cfg(feature = "tokio1")] +use crate::transport::smtp::client::net::resolved_address_filter; use crate::transport::smtp::{error, Error}; /// A network stream @@ -109,44 +114,59 @@ impl AsyncNetworkStream { server: T, timeout: Option, tls_parameters: Option, + local_addr: Option, ) -> Result { - async fn try_connect_timeout( + async fn try_connect( server: T, - timeout: Duration, + timeout: Option, + local_addr: Option, ) -> Result { let addrs = tokio1_crate::net::lookup_host(server) .await - .map_err(error::connection)?; + .map_err(error::connection)? + .filter(|resolved_addr| resolved_address_filter(resolved_addr, local_addr)); let mut last_err = None; for addr in addrs { - let connect_future = Tokio1TcpStream::connect(&addr); - match tokio1_crate::time::timeout(timeout, connect_future).await { - Ok(Ok(stream)) => return Ok(stream), - Ok(Err(err)) => last_err = Some(err), - Err(_) => { - last_err = Some(io::Error::new( - io::ErrorKind::TimedOut, - "connection timed out", - )) + let socket = match addr.ip() { + IpAddr::V4(_) => Tokio1TcpSocket::new_v4(), + IpAddr::V6(_) => Tokio1TcpSocket::new_v6(), + } + .map_err(error::connection)?; + if let Some(local_addr) = local_addr { + socket + .bind(SocketAddr::new(local_addr, 0)) + .map_err(error::connection)?; + } + + let connect_future = socket.connect(addr); + if let Some(timeout) = timeout { + match tokio1_crate::time::timeout(timeout, connect_future).await { + Ok(Ok(stream)) => return Ok(stream), + Ok(Err(err)) => last_err = Some(err), + Err(_) => { + last_err = Some(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + )) + } + } + } else { + match connect_future.await { + Ok(stream) => return Ok(stream), + Err(err) => last_err = Some(err), } } } Err(match last_err { Some(last_err) => error::connection(last_err), - None => error::connection("could not resolve to any address"), + None => error::connection("could not resolve to any supported address"), }) } - let tcp_stream = match timeout { - Some(t) => try_connect_timeout(server, t).await?, - None => Tokio1TcpStream::connect(server) - .await - .map_err(error::connection)?, - }; - + let tcp_stream = try_connect(server, timeout, local_addr).await?; let mut stream = AsyncNetworkStream::new(InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { stream.upgrade_tls(tls_parameters).await?; @@ -160,6 +180,9 @@ impl AsyncNetworkStream { timeout: Option, tls_parameters: Option, ) -> Result { + // Unfortunately there doesn't currently seem to be a way to set the local address + // Whilst we can create a AsyncStd1TcpStream from an existing socket, it needs to first have + // connected which is a blocking operation. async fn try_connect_timeout( server: T, timeout: Duration, diff --git a/src/transport/smtp/client/connection.rs b/src/transport/smtp/client/connection.rs index 45e1b73..71e3e34 100644 --- a/src/transport/smtp/client/connection.rs +++ b/src/transport/smtp/client/connection.rs @@ -1,7 +1,7 @@ use std::{ fmt::Display, io::{self, BufRead, BufReader, Write}, - net::ToSocketAddrs, + net::{IpAddr, ToSocketAddrs}, time::Duration, }; @@ -58,8 +58,9 @@ impl SmtpConnection { timeout: Option, hello_name: &ClientId, tls_parameters: Option<&TlsParameters>, + local_address: Option, ) -> Result { - let stream = NetworkStream::connect(server, timeout, tls_parameters)?; + let stream = NetworkStream::connect(server, timeout, tls_parameters, local_address)?; let stream = BufReader::new(stream); let mut conn = SmtpConnection { stream, diff --git a/src/transport/smtp/client/mod.rs b/src/transport/smtp/client/mod.rs index 68b750f..3a6123e 100644 --- a/src/transport/smtp/client/mod.rs +++ b/src/transport/smtp/client/mod.rs @@ -12,7 +12,7 @@ //! }; //! //! let hello = ClientId::Domain("my_hostname".to_string()); -//! let mut client = SmtpConnection::connect(&("localhost", SMTP_PORT), None, &hello, None)?; +//! let mut client = SmtpConnection::connect(&("localhost", SMTP_PORT), None, &hello, None, None)?; //! client.command(Mail::new(Some("user@example.com".parse()?), vec![]))?; //! client.command(Rcpt::new("user@example.org".parse()?, vec![]))?; //! client.command(Data)?; diff --git a/src/transport/smtp/client/net.rs b/src/transport/smtp/client/net.rs index dc63726..86fe2e9 100644 --- a/src/transport/smtp/client/net.rs +++ b/src/transport/smtp/client/net.rs @@ -1,7 +1,7 @@ use std::{ io::{self, Read, Write}, mem, - net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}, + net::{IpAddr, Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs}, time::Duration, }; @@ -9,6 +9,7 @@ use std::{ use native_tls::TlsStream; #[cfg(feature = "rustls-tls")] use rustls::{ClientConnection, ServerName, StreamOwned}; +use socket2::{Domain, Protocol, Type}; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] use super::InnerTlsParameters; @@ -83,19 +84,39 @@ impl NetworkStream { server: T, timeout: Option, tls_parameters: Option<&TlsParameters>, + local_addr: Option, ) -> Result { - fn try_connect_timeout( + fn try_connect( server: T, - timeout: Duration, + timeout: Option, + local_addr: Option, ) -> Result { - let addrs = server.to_socket_addrs().map_err(error::connection)?; + let addrs = server + .to_socket_addrs() + .map_err(error::connection)? + .filter(|resolved_addr| resolved_address_filter(resolved_addr, local_addr)); let mut last_err = None; for addr in addrs { - match TcpStream::connect_timeout(&addr, timeout) { - Ok(stream) => return Ok(stream), - Err(err) => last_err = Some(err), + let socket = socket2::Socket::new( + Domain::for_address(addr), + Type::STREAM, + Some(Protocol::TCP), + ) + .map_err(error::connection)?; + bind_local_address(&socket, &addr, local_addr)?; + + if let Some(timeout) = timeout { + match socket.connect_timeout(&addr.into(), timeout) { + Ok(_) => return Ok(socket.into()), + Err(err) => last_err = Some(err), + } + } else { + match socket.connect(&addr.into()) { + Ok(_) => return Ok(socket.into()), + Err(err) => last_err = Some(err), + } } } @@ -105,11 +126,7 @@ impl NetworkStream { }) } - let tcp_stream = match timeout { - Some(t) => try_connect_timeout(server, t)?, - None => TcpStream::connect(server).map_err(error::connection)?, - }; - + let tcp_stream = try_connect(server, timeout, local_addr)?; let mut stream = NetworkStream::new(InnerNetworkStream::Tcp(tcp_stream)); if let Some(tls_parameters) = tls_parameters { stream.upgrade_tls(tls_parameters)?; @@ -289,3 +306,47 @@ impl Write for NetworkStream { } } } + +/// If the local address is set, binds the socket to this address. +/// If local address is not set, then destination address is required to determine to the default +/// local address on some platforms. +/// See: https://github.com/hyperium/hyper/blob/faf24c6ad8eee1c3d5ccc9a4d4835717b8e2903f/src/client/connect/http.rs#L560 +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr: Option, +) -> Result<(), Error> { + match local_addr { + Some(local_addr) => { + socket + .bind(&SocketAddr::new(local_addr, 0).into()) + .map_err(error::connection)?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into()).map_err(error::connection)?; + } + } + } + Ok(()) +} + +/// When we have an iterator of resolved remote addresses, we must filter them to be the same +/// protocol as the local address binding. If no local address is set, then all will be matched. +pub(crate) fn resolved_address_filter( + resolved_addr: &SocketAddr, + local_addr: Option, +) -> bool { + match local_addr { + Some(local_addr) => match resolved_addr.ip() { + IpAddr::V4(_) => local_addr.is_ipv4(), + IpAddr::V6(_) => local_addr.is_ipv6(), + }, + None => true, + } +} diff --git a/src/transport/smtp/transport.rs b/src/transport/smtp/transport.rs index 5e22b76..ed6b39b 100644 --- a/src/transport/smtp/transport.rs +++ b/src/transport/smtp/transport.rs @@ -221,6 +221,7 @@ impl SmtpClient { self.info.timeout, &self.info.hello_name, tls_parameters, + None, )?; #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]