Merge pull request #141 from amousset/improve-network

feat(transport): Add a mock network stream
This commit is contained in:
Alexis Mousset
2017-06-14 00:32:14 +02:00
committed by GitHub
4 changed files with 192 additions and 55 deletions

View File

@@ -18,7 +18,7 @@ travis-ci = { repository = "lettre/lettre" }
bufstream = "^0.1"
log = "^0.3"
openssl = "^0.9"
base64 = "^0.5"
base64 = "^0.6"
hex = "^0.2"
rust-crypto = "^0.2"

View File

@@ -0,0 +1,114 @@
#![allow(missing_docs)]
// Comes from https://github.com/inre/rust-mq/blob/master/netopt
use std::io::{self, Cursor, Read, Write};
use std::sync::{Arc, Mutex};
pub type MockCursor = Cursor<Vec<u8>>;
#[derive(Clone,Debug)]
pub struct MockStream {
reader: Arc<Mutex<MockCursor>>,
writer: Arc<Mutex<MockCursor>>,
}
impl MockStream {
pub fn new() -> MockStream {
MockStream {
reader: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
}
}
pub fn with_vec(vec: Vec<u8>) -> MockStream {
MockStream {
reader: Arc::new(Mutex::new(MockCursor::new(vec))),
writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
}
}
pub fn take_vec(&mut self) -> Vec<u8> {
let mut cursor = self.writer.lock().unwrap();
let vec = cursor.get_ref().to_vec();
cursor.set_position(0);
cursor.get_mut().clear();
vec
}
pub fn next_vec(&mut self, vec: Vec<u8>) {
let mut cursor = self.reader.lock().unwrap();
cursor.set_position(0);
cursor.get_mut().clear();
cursor.get_mut().extend_from_slice(vec.as_slice());
}
pub fn swap(&mut self) {
let mut cur_write = self.writer.lock().unwrap();
let mut cur_read = self.reader.lock().unwrap();
let vec_write = cur_write.get_ref().to_vec();
let vec_read = cur_read.get_ref().to_vec();
cur_write.set_position(0);
cur_read.set_position(0);
cur_write.get_mut().clear();
cur_read.get_mut().clear();
// swap cursors
cur_read.get_mut().extend_from_slice(vec_write.as_slice());
cur_write.get_mut().extend_from_slice(vec_read.as_slice());
}
}
impl Write for MockStream {
fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
self.writer.lock().unwrap().write(msg)
}
fn flush(&mut self) -> io::Result<()> {
self.writer.lock().unwrap().flush()
}
}
impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.reader.lock().unwrap().read(buf)
}
}
#[cfg(test)]
mod test {
use super::MockStream;
use std::io::{Read, Write};
#[test]
fn write_take_test() {
let mut mock = MockStream::new();
// write to mock stream
mock.write(&[1, 2, 3]).unwrap();
assert_eq!(mock.take_vec(), vec![1, 2, 3]);
}
#[test]
fn read_with_vec_test() {
let mut mock = MockStream::with_vec(vec![4, 5]);
let mut vec = Vec::new();
mock.read_to_end(&mut vec).unwrap();
assert_eq!(vec, vec![4, 5]);
}
#[test]
fn clone_test() {
let mut mock = MockStream::new();
let mut clonned = mock.clone();
mock.write(&[6, 7]).unwrap();
assert_eq!(clonned.take_vec(), vec![6, 7]);
}
#[test]
fn swap_test() {
let mut mock = MockStream::new();
let mut vec = Vec::new();
mock.write(&[8, 9, 10]).unwrap();
mock.swap();
mock.read_to_end(&mut vec).unwrap();
assert_eq!(vec, vec![8, 9, 10]);
}
}

View File

@@ -1,6 +1,5 @@
//! SMTP client
use base64;
use bufstream::BufStream;
use openssl::ssl::SslContext;
@@ -17,6 +16,7 @@ use std::string::String;
use std::time::Duration;
pub mod net;
pub mod mock;
/// Returns the string after adding a dot at the beginning of each line starting with a dot
///
@@ -67,7 +67,7 @@ impl<S: Write + Read> Client<S> {
}
}
impl<S: Connector + Timeout + Write + Read + Debug> Client<S> {
impl<S: Connector + Write + Read + Timeout + Debug> Client<S> {
/// Closes the SMTP transaction if possible
pub fn close(&mut self) {
let _ = self.quit();

View File

@@ -1,13 +1,74 @@
//! A trait to represent a stream
use openssl::ssl::{Ssl, SslContext, SslStream};
use std::fmt;
use std::fmt::{Debug, Formatter};
use smtp::client::mock::MockStream;
use std::io;
use std::io::{ErrorKind, Read, Write};
use std::net::{SocketAddr, TcpStream};
use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream};
use std::time::Duration;
#[derive(Debug)]
/// Represents the different types of underlying network streams
pub enum NetworkStream {
/// Plain TCP stream
Tcp(TcpStream),
/// Encrypted TCP stream
Ssl(SslStream<TcpStream>),
/// Mock stream
Mock(MockStream),
}
impl NetworkStream {
/// Returns peer's address
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
match *self {
NetworkStream::Tcp(ref s) => s.peer_addr(),
NetworkStream::Ssl(ref s) => s.get_ref().peer_addr(),
NetworkStream::Mock(_) => {
Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 80)))
}
}
}
/// Shutdowns the connection
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref s) => s.shutdown(how),
NetworkStream::Ssl(ref s) => s.get_ref().shutdown(how),
NetworkStream::Mock(_) => Ok(()),
}
}
}
impl Read for NetworkStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
NetworkStream::Tcp(ref mut s) => s.read(buf),
NetworkStream::Ssl(ref mut s) => s.read(buf),
NetworkStream::Mock(ref mut s) => s.read(buf),
}
}
}
impl Write for NetworkStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
NetworkStream::Tcp(ref mut s) => s.write(buf),
NetworkStream::Ssl(ref mut s) => s.write(buf),
NetworkStream::Mock(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match *self {
NetworkStream::Tcp(ref mut s) => s.flush(),
NetworkStream::Ssl(ref mut s) => s.flush(),
NetworkStream::Mock(ref mut s) => s.flush(),
}
}
}
/// A trait for the concept of opening a stream
pub trait Connector: Sized {
/// Opens a connection to the given IP socket
@@ -33,14 +94,14 @@ impl Connector for NetworkStream {
Err(e) => Err(io::Error::new(ErrorKind::Other, e)),
}
}
None => Ok(NetworkStream::Plain(tcp_stream)),
None => Ok(NetworkStream::Tcp(tcp_stream)),
}
}
fn upgrade_tls(&mut self, ssl_context: &SslContext) -> io::Result<()> {
*self = match *self {
NetworkStream::Plain(ref mut stream) => {
NetworkStream::Tcp(ref mut stream) => {
match Ssl::new(ssl_context) {
Ok(ssl) => {
match ssl.connect(stream.try_clone().unwrap()) {
@@ -52,6 +113,7 @@ impl Connector for NetworkStream {
}
}
NetworkStream::Ssl(_) => return Ok(()),
NetworkStream::Mock(_) => return Ok(()),
};
Ok(())
@@ -60,50 +122,9 @@ impl Connector for NetworkStream {
fn is_encrypted(&self) -> bool {
match *self {
NetworkStream::Plain(_) => false,
NetworkStream::Tcp(_) => false,
NetworkStream::Ssl(_) => true,
}
}
}
/// Represents the different types of underlying network streams
pub enum NetworkStream {
/// Plain TCP
Plain(TcpStream),
/// SSL over TCP
Ssl(SslStream<TcpStream>),
}
impl Debug for NetworkStream {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.write_str("NetworkStream(_)")
}
}
impl Read for NetworkStream {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
NetworkStream::Plain(ref mut stream) => stream.read(buf),
NetworkStream::Ssl(ref mut stream) => stream.read(buf),
}
}
}
impl Write for NetworkStream {
#[inline]
fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
match *self {
NetworkStream::Plain(ref mut stream) => stream.write(msg),
NetworkStream::Ssl(ref mut stream) => stream.write(msg),
}
}
#[inline]
fn flush(&mut self) -> io::Result<()> {
match *self {
NetworkStream::Plain(ref mut stream) => stream.flush(),
NetworkStream::Ssl(ref mut stream) => stream.flush(),
NetworkStream::Mock(_) => false,
}
}
}
@@ -119,16 +140,18 @@ pub trait Timeout: Sized {
impl Timeout for NetworkStream {
fn set_read_timeout(&mut self, duration: Option<Duration>) -> io::Result<()> {
match *self {
NetworkStream::Plain(ref mut stream) => stream.set_read_timeout(duration),
NetworkStream::Ssl(ref mut stream) => stream.get_mut().set_read_timeout(duration),
NetworkStream::Tcp(ref mut stream) => stream.set_read_timeout(duration),
NetworkStream::Ssl(ref mut stream) => stream.get_ref().set_read_timeout(duration),
NetworkStream::Mock(_) => Ok(()),
}
}
/// Set write tiemout for IO calls
fn set_write_timeout(&mut self, duration: Option<Duration>) -> io::Result<()> {
match *self {
NetworkStream::Plain(ref mut stream) => stream.set_write_timeout(duration),
NetworkStream::Ssl(ref mut stream) => stream.get_mut().set_write_timeout(duration),
NetworkStream::Tcp(ref mut stream) => stream.set_write_timeout(duration),
NetworkStream::Ssl(ref mut stream) => stream.get_ref().set_write_timeout(duration),
NetworkStream::Mock(_) => Ok(()),
}
}
}