diff --git a/lettre/Cargo.toml b/lettre/Cargo.toml index c51a59f..6948cbe 100644 --- a/lettre/Cargo.toml +++ b/lettre/Cargo.toml @@ -31,6 +31,7 @@ serde_derive = { version = "^1.0", optional = true } failure = "^0.1" failure_derive = "^0.1" fast_chemail = "^0.9" +r2d2 = { version = "^0.8", optional = true} [dev-dependencies] env_logger = "^0.5" @@ -43,6 +44,7 @@ serde-impls = ["serde", "serde_derive"] file-transport = ["serde-impls", "serde_json"] smtp-transport = ["bufstream", "native-tls", "base64", "nom", "hostname"] sendmail-transport = [] +connection-pool = [ "r2d2" ] [[example]] name = "smtp" diff --git a/lettre/src/lib.rs b/lettre/src/lib.rs index f34ac22..2630ecb 100644 --- a/lettre/src/lib.rs +++ b/lettre/src/lib.rs @@ -32,6 +32,8 @@ extern crate serde_json; #[macro_use] extern crate failure_derive; extern crate fast_chemail; +#[cfg(feature = "connection-pool")] +extern crate r2d2; pub mod error; #[cfg(feature = "file-transport")] @@ -53,6 +55,8 @@ pub use sendmail::SendmailTransport; pub use smtp::client::net::ClientTlsParameters; #[cfg(feature = "smtp-transport")] pub use smtp::{ClientSecurity, SmtpClient, SmtpTransport}; +#[cfg(all(feature = "smtp-transport", feature = "connection-pool"))] +pub use smtp::r2d2::SmtpConnectionManager; use std::ffi::OsStr; use std::fmt::{self, Display, Formatter}; use std::io; diff --git a/lettre/src/smtp/client/mod.rs b/lettre/src/smtp/client/mod.rs index c2c1843..a3b19a3 100644 --- a/lettre/src/smtp/client/mod.rs +++ b/lettre/src/smtp/client/mod.rs @@ -167,7 +167,7 @@ impl InnerClient { /// Checks if the server is connected using the NOOP SMTP command #[cfg_attr(feature = "cargo-clippy", allow(wrong_self_convention))] pub fn is_connected(&mut self) -> bool { - self.command(NoopCommand).is_ok() + self.stream.is_some() && self.command(NoopCommand).is_ok() } /// Sends an AUTH command with the given mechanism, and handles challenge if needed diff --git a/lettre/src/smtp/mod.rs b/lettre/src/smtp/mod.rs index 78bb695..85f10cf 100644 --- a/lettre/src/smtp/mod.rs +++ b/lettre/src/smtp/mod.rs @@ -31,6 +31,8 @@ pub mod client; pub mod commands; pub mod error; pub mod extension; +#[cfg(feature = "connection-pool")] +pub mod r2d2; pub mod response; pub mod util; @@ -73,6 +75,7 @@ pub enum ConnectionReuseParameters { /// Contains client configuration #[allow(missing_debug_implementations)] +#[derive(Clone)] pub struct SmtpClient { /// Enable connection reuse connection_reuse: ConnectionReuseParameters, @@ -242,6 +245,95 @@ impl<'a> SmtpTransport { } } + fn connect(&mut self) -> Result<(), Error> { + // Check if the connection is still available + if (self.state.connection_reuse_count > 0) && (!self.client.is_connected()) { + self.close(); + } + + if self.state.connection_reuse_count > 0 { + info!("connection already established to {}", self.client_info.server_addr); + return Ok(()); + } + + self.client.connect( + &self.client_info.server_addr, + match self.client_info.security { + ClientSecurity::Wrapper(ref tls_parameters) => Some(tls_parameters), + _ => None, + }, + )?; + + self.client.set_timeout(self.client_info.timeout)?; + + // Log the connection + info!("connection established to {}", self.client_info.server_addr); + + self.ehlo()?; + + match ( + &self.client_info.security.clone(), + self.server_info + .as_ref() + .unwrap() + .supports_feature(Extension::StartTls), + ) { + (&ClientSecurity::Required(_), false) => { + return Err(From::from("Could not encrypt connection, aborting")) + } + (&ClientSecurity::Opportunistic(_), false) => (), + (&ClientSecurity::None, _) => (), + (&ClientSecurity::Wrapper(_), _) => (), + (&ClientSecurity::Opportunistic(ref tls_parameters), true) + | (&ClientSecurity::Required(ref tls_parameters), true) => { + try_smtp!(self.client.command(StarttlsCommand), self); + try_smtp!(self.client.upgrade_tls_stream(tls_parameters), self); + + debug!("connection encrypted"); + + // Send EHLO again + self.ehlo()?; + } + } + + if self.client_info.credentials.is_some() { + let mut found = false; + + // Compute accepted mechanism + let accepted_mechanisms = match self.client_info.authentication_mechanism { + Some(mechanism) => vec![mechanism], + None => { + if self.client.is_encrypted() { + DEFAULT_ENCRYPTED_MECHANISMS.to_vec() + } else { + DEFAULT_UNENCRYPTED_MECHANISMS.to_vec() + } + } + }; + + for mechanism in accepted_mechanisms { + if self.server_info + .as_ref() + .unwrap() + .supports_auth_mechanism(mechanism) + { + found = true; + try_smtp!( + self.client + .auth(mechanism, self.client_info.credentials.as_ref().unwrap(),), + self + ); + break; + } + } + + if !found { + info!("No supported authentication mechanisms available"); + } + } + Ok(()) + } + /// Gets the EHLO response and updates server information fn ehlo(&mut self) -> SmtpResult { // Extended Hello @@ -280,87 +372,8 @@ impl<'a> Transport<'a> for SmtpTransport { fn send(&mut self, email: SendableEmail) -> SmtpResult { let message_id = email.message_id().to_string(); - // Check if the connection is still available - if (self.state.connection_reuse_count > 0) && (!self.client.is_connected()) { - self.close(); - } - - if self.state.connection_reuse_count == 0 { - self.client.connect( - &self.client_info.server_addr, - match self.client_info.security { - ClientSecurity::Wrapper(ref tls_parameters) => Some(tls_parameters), - _ => None, - }, - )?; - - self.client.set_timeout(self.client_info.timeout)?; - - // Log the connection - info!("connection established to {}", self.client_info.server_addr); - - self.ehlo()?; - - match ( - &self.client_info.security.clone(), - self.server_info - .as_ref() - .unwrap() - .supports_feature(Extension::StartTls), - ) { - (&ClientSecurity::Required(_), false) => { - return Err(From::from("Could not encrypt connection, aborting")) - } - (&ClientSecurity::Opportunistic(_), false) => (), - (&ClientSecurity::None, _) => (), - (&ClientSecurity::Wrapper(_), _) => (), - (&ClientSecurity::Opportunistic(ref tls_parameters), true) - | (&ClientSecurity::Required(ref tls_parameters), true) => { - try_smtp!(self.client.command(StarttlsCommand), self); - try_smtp!(self.client.upgrade_tls_stream(tls_parameters), self); - - debug!("connection encrypted"); - - // Send EHLO again - self.ehlo()?; - } - } - - if self.client_info.credentials.is_some() { - let mut found = false; - - // Compute accepted mechanism - let accepted_mechanisms = match self.client_info.authentication_mechanism { - Some(mechanism) => vec![mechanism], - None => { - if self.client.is_encrypted() { - DEFAULT_ENCRYPTED_MECHANISMS.to_vec() - } else { - DEFAULT_UNENCRYPTED_MECHANISMS.to_vec() - } - } - }; - - for mechanism in accepted_mechanisms { - if self.server_info - .as_ref() - .unwrap() - .supports_auth_mechanism(mechanism) - { - found = true; - try_smtp!( - self.client - .auth(mechanism, self.client_info.credentials.as_ref().unwrap(),), - self - ); - break; - } - } - - if !found { - info!("No supported authentication mechanisms available"); - } - } + if !self.client.is_connected() { + self.connect()?; } // Mail diff --git a/lettre/src/smtp/r2d2.rs b/lettre/src/smtp/r2d2.rs new file mode 100644 index 0000000..41e1afe --- /dev/null +++ b/lettre/src/smtp/r2d2.rs @@ -0,0 +1,38 @@ +use r2d2::ManageConnection; +use smtp::{ConnectionReuseParameters, SmtpClient, SmtpTransport}; +use smtp::error::Error; + +pub struct SmtpConnectionManager { + transport_builder: SmtpClient, +} + +impl SmtpConnectionManager { + pub fn new(transport_builder: SmtpClient) -> Result { + Ok(SmtpConnectionManager { + transport_builder: transport_builder + .connection_reuse(ConnectionReuseParameters::ReuseUnlimited), + }) + } +} + +impl ManageConnection for SmtpConnectionManager { + type Connection = SmtpTransport; + type Error = Error; + + fn connect(&self) -> Result { + let mut transport = SmtpTransport::new(self.transport_builder.clone()); + transport.connect()?; + Ok(transport) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> { + if conn.client.is_connected() { + return Ok(()); + } + Err(Error::Client("is not connected anymore")) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.state.panic + } +} diff --git a/lettre/tests/r2d2_smtp.rs b/lettre/tests/r2d2_smtp.rs new file mode 100644 index 0000000..055fdbe --- /dev/null +++ b/lettre/tests/r2d2_smtp.rs @@ -0,0 +1,73 @@ +#[cfg(all(test, feature = "smtp-transport", feature = "connection-pool"))] +mod test { + extern crate lettre; + extern crate r2d2; + + use self::lettre::{SmtpConnectionManager, Transport}; + use self::lettre::{ClientSecurity, EmailAddress, Envelope, SendableEmail, SmtpClient}; + use self::r2d2::Pool; + use std::sync::mpsc; + use std::thread; + + fn email(message: &str) -> SendableEmail { + SendableEmail::new( + Envelope::new( + Some(EmailAddress::new("user@localhost".to_string()).unwrap()), + vec![EmailAddress::new("root@localhost".to_string()).unwrap()], + ).unwrap(), + "id".to_string(), + message.to_string().into_bytes(), + ) + } + + #[test] + fn send_one() { + let client = SmtpClient::new("localhost:2525", ClientSecurity::None).unwrap(); + let manager = SmtpConnectionManager::new(client).unwrap(); + let pool = Pool::builder().max_size(1).build(manager).unwrap(); + + let mut mailer = pool.get().unwrap(); + let result = (*mailer).send(email("send one")); + assert!(result.is_ok()); + } + + #[test] + fn send_from_thread() { + let client = SmtpClient::new("127.0.0.1:2525", ClientSecurity::None).unwrap(); + let manager = SmtpConnectionManager::new(client).unwrap(); + let pool = Pool::builder().max_size(2).build(manager).unwrap(); + + let (s1, r1) = mpsc::channel(); + let (s2, r2) = mpsc::channel(); + + let pool1 = pool.clone(); + let t1 = thread::spawn(move || { + let mut conn = pool1.get().unwrap(); + s1.send(()).unwrap(); + r2.recv().unwrap(); + (*conn) + .send(email("send from thread 1")) + .expect("Send failed from thread 1"); + drop(conn); + }); + + let pool2 = pool.clone(); + let t2 = thread::spawn(move || { + let mut conn = pool2.get().unwrap(); + s2.send(()).unwrap(); + r1.recv().unwrap(); + (*conn) + .send(email("send from thread 2")) + .expect("Send failed from thread 2"); + drop(conn); + }); + + t1.join().unwrap(); + t2.join().unwrap(); + + let mut mailer = pool.get().unwrap(); + (*mailer) + .send(email("send from main thread")) + .expect("Send failed from main thread"); + } +}