diff --git a/lettre/Cargo.toml b/lettre/Cargo.toml index 7ad3db8..c01ab13 100644 --- a/lettre/Cargo.toml +++ b/lettre/Cargo.toml @@ -18,7 +18,7 @@ travis-ci = { repository = "lettre/lettre" } bufstream = "^0.1" log = "^0.3" openssl = "^0.9" -base64 = "^0.5" +base64 = "^0.6" hex = "^0.2" rust-crypto = "^0.2" diff --git a/lettre/src/smtp/client/mock.rs b/lettre/src/smtp/client/mock.rs new file mode 100644 index 0000000..f1c158f --- /dev/null +++ b/lettre/src/smtp/client/mock.rs @@ -0,0 +1,114 @@ +#![allow(missing_docs)] +// Comes from https://github.com/inre/rust-mq/blob/master/netopt + +use std::io::{self, Cursor, Read, Write}; +use std::sync::{Arc, Mutex}; + +pub type MockCursor = Cursor>; + +#[derive(Clone,Debug)] +pub struct MockStream { + reader: Arc>, + writer: Arc>, +} + +impl MockStream { + pub fn new() -> MockStream { + MockStream { + reader: Arc::new(Mutex::new(MockCursor::new(Vec::new()))), + writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))), + } + } + + pub fn with_vec(vec: Vec) -> MockStream { + MockStream { + reader: Arc::new(Mutex::new(MockCursor::new(vec))), + writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))), + } + } + + pub fn take_vec(&mut self) -> Vec { + let mut cursor = self.writer.lock().unwrap(); + let vec = cursor.get_ref().to_vec(); + cursor.set_position(0); + cursor.get_mut().clear(); + vec + } + + pub fn next_vec(&mut self, vec: Vec) { + let mut cursor = self.reader.lock().unwrap(); + cursor.set_position(0); + cursor.get_mut().clear(); + cursor.get_mut().extend_from_slice(vec.as_slice()); + } + + pub fn swap(&mut self) { + let mut cur_write = self.writer.lock().unwrap(); + let mut cur_read = self.reader.lock().unwrap(); + let vec_write = cur_write.get_ref().to_vec(); + let vec_read = cur_read.get_ref().to_vec(); + cur_write.set_position(0); + cur_read.set_position(0); + cur_write.get_mut().clear(); + cur_read.get_mut().clear(); + // swap cursors + cur_read.get_mut().extend_from_slice(vec_write.as_slice()); + cur_write.get_mut().extend_from_slice(vec_read.as_slice()); + } +} + +impl Write for MockStream { + fn write(&mut self, msg: &[u8]) -> io::Result { + self.writer.lock().unwrap().write(msg) + } + + fn flush(&mut self) -> io::Result<()> { + self.writer.lock().unwrap().flush() + } +} + +impl Read for MockStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.reader.lock().unwrap().read(buf) + } +} + +#[cfg(test)] +mod test { + use super::MockStream; + use std::io::{Read, Write}; + + #[test] + fn write_take_test() { + let mut mock = MockStream::new(); + // write to mock stream + mock.write(&[1, 2, 3]).unwrap(); + assert_eq!(mock.take_vec(), vec![1, 2, 3]); + } + + #[test] + fn read_with_vec_test() { + let mut mock = MockStream::with_vec(vec![4, 5]); + let mut vec = Vec::new(); + mock.read_to_end(&mut vec).unwrap(); + assert_eq!(vec, vec![4, 5]); + } + + #[test] + fn clone_test() { + let mut mock = MockStream::new(); + let mut clonned = mock.clone(); + mock.write(&[6, 7]).unwrap(); + assert_eq!(clonned.take_vec(), vec![6, 7]); + } + + #[test] + fn swap_test() { + let mut mock = MockStream::new(); + let mut vec = Vec::new(); + mock.write(&[8, 9, 10]).unwrap(); + mock.swap(); + mock.read_to_end(&mut vec).unwrap(); + assert_eq!(vec, vec![8, 9, 10]); + } +} diff --git a/lettre/src/smtp/client/mod.rs b/lettre/src/smtp/client/mod.rs index 2305e0e..ef9c85d 100644 --- a/lettre/src/smtp/client/mod.rs +++ b/lettre/src/smtp/client/mod.rs @@ -1,6 +1,5 @@ //! SMTP client - use base64; use bufstream::BufStream; use openssl::ssl::SslContext; @@ -17,6 +16,7 @@ use std::string::String; use std::time::Duration; pub mod net; +pub mod mock; /// Returns the string after adding a dot at the beginning of each line starting with a dot /// @@ -67,7 +67,7 @@ impl Client { } } -impl Client { +impl Client { /// Closes the SMTP transaction if possible pub fn close(&mut self) { let _ = self.quit(); diff --git a/lettre/src/smtp/client/net.rs b/lettre/src/smtp/client/net.rs index 244fb32..873075b 100644 --- a/lettre/src/smtp/client/net.rs +++ b/lettre/src/smtp/client/net.rs @@ -1,13 +1,74 @@ //! A trait to represent a stream use openssl::ssl::{Ssl, SslContext, SslStream}; -use std::fmt; -use std::fmt::{Debug, Formatter}; + +use smtp::client::mock::MockStream; use std::io; use std::io::{ErrorKind, Read, Write}; -use std::net::{SocketAddr, TcpStream}; +use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream}; use std::time::Duration; +#[derive(Debug)] +/// Represents the different types of underlying network streams +pub enum NetworkStream { + /// Plain TCP stream + Tcp(TcpStream), + /// Encrypted TCP stream + Ssl(SslStream), + /// Mock stream + Mock(MockStream), +} + +impl NetworkStream { + /// Returns peer's address + pub fn peer_addr(&self) -> io::Result { + match *self { + NetworkStream::Tcp(ref s) => s.peer_addr(), + NetworkStream::Ssl(ref s) => s.get_ref().peer_addr(), + NetworkStream::Mock(_) => { + Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 80))) + } + } + } + + /// Shutdowns the connection + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + match *self { + NetworkStream::Tcp(ref s) => s.shutdown(how), + NetworkStream::Ssl(ref s) => s.get_ref().shutdown(how), + NetworkStream::Mock(_) => Ok(()), + } + } +} + +impl Read for NetworkStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + NetworkStream::Tcp(ref mut s) => s.read(buf), + NetworkStream::Ssl(ref mut s) => s.read(buf), + NetworkStream::Mock(ref mut s) => s.read(buf), + } + } +} + +impl Write for NetworkStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match *self { + NetworkStream::Tcp(ref mut s) => s.write(buf), + NetworkStream::Ssl(ref mut s) => s.write(buf), + NetworkStream::Mock(ref mut s) => s.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + NetworkStream::Tcp(ref mut s) => s.flush(), + NetworkStream::Ssl(ref mut s) => s.flush(), + NetworkStream::Mock(ref mut s) => s.flush(), + } + } +} + /// A trait for the concept of opening a stream pub trait Connector: Sized { /// Opens a connection to the given IP socket @@ -33,14 +94,14 @@ impl Connector for NetworkStream { Err(e) => Err(io::Error::new(ErrorKind::Other, e)), } } - None => Ok(NetworkStream::Plain(tcp_stream)), + None => Ok(NetworkStream::Tcp(tcp_stream)), } } fn upgrade_tls(&mut self, ssl_context: &SslContext) -> io::Result<()> { *self = match *self { - NetworkStream::Plain(ref mut stream) => { + NetworkStream::Tcp(ref mut stream) => { match Ssl::new(ssl_context) { Ok(ssl) => { match ssl.connect(stream.try_clone().unwrap()) { @@ -52,6 +113,7 @@ impl Connector for NetworkStream { } } NetworkStream::Ssl(_) => return Ok(()), + NetworkStream::Mock(_) => return Ok(()), }; Ok(()) @@ -60,50 +122,9 @@ impl Connector for NetworkStream { fn is_encrypted(&self) -> bool { match *self { - NetworkStream::Plain(_) => false, + NetworkStream::Tcp(_) => false, NetworkStream::Ssl(_) => true, - } - } -} - - -/// Represents the different types of underlying network streams -pub enum NetworkStream { - /// Plain TCP - Plain(TcpStream), - /// SSL over TCP - Ssl(SslStream), -} - -impl Debug for NetworkStream { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.write_str("NetworkStream(_)") - } -} - -impl Read for NetworkStream { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - NetworkStream::Plain(ref mut stream) => stream.read(buf), - NetworkStream::Ssl(ref mut stream) => stream.read(buf), - } - } -} - -impl Write for NetworkStream { - #[inline] - fn write(&mut self, msg: &[u8]) -> io::Result { - match *self { - NetworkStream::Plain(ref mut stream) => stream.write(msg), - NetworkStream::Ssl(ref mut stream) => stream.write(msg), - } - } - #[inline] - fn flush(&mut self) -> io::Result<()> { - match *self { - NetworkStream::Plain(ref mut stream) => stream.flush(), - NetworkStream::Ssl(ref mut stream) => stream.flush(), + NetworkStream::Mock(_) => false, } } } @@ -119,16 +140,18 @@ pub trait Timeout: Sized { impl Timeout for NetworkStream { fn set_read_timeout(&mut self, duration: Option) -> io::Result<()> { match *self { - NetworkStream::Plain(ref mut stream) => stream.set_read_timeout(duration), - NetworkStream::Ssl(ref mut stream) => stream.get_mut().set_read_timeout(duration), + NetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration), + NetworkStream::Ssl(ref mut stream) => stream.get_ref().set_read_timeout(duration), + NetworkStream::Mock(_) => Ok(()), } } /// Set write tiemout for IO calls fn set_write_timeout(&mut self, duration: Option) -> io::Result<()> { match *self { - NetworkStream::Plain(ref mut stream) => stream.set_write_timeout(duration), - NetworkStream::Ssl(ref mut stream) => stream.get_mut().set_write_timeout(duration), + NetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration), + NetworkStream::Ssl(ref mut stream) => stream.get_ref().set_write_timeout(duration), + NetworkStream::Mock(_) => Ok(()), } } }