Remove MockStream and all internal uses of it (#580)

Co-authored-by: Alexis Mousset <contact@amousset.me>
This commit is contained in:
Paolo Barbolini
2021-03-19 09:20:05 +01:00
committed by GitHub
parent f041c00df7
commit 3bc729ca64
4 changed files with 52 additions and 150 deletions

View File

@@ -5,6 +5,7 @@
))]
use std::sync::Arc;
use std::{
mem,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
@@ -203,7 +204,7 @@ impl AsyncNetworkStream {
#[cfg(any(feature = "tokio02-native-tls", feature = "tokio02-rustls-tls"))]
InnerAsyncNetworkStream::Tokio02Tcp(_) => {
// get owned TcpStream
let tcp_stream = std::mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = match tcp_stream {
InnerAsyncNetworkStream::Tokio02Tcp(tcp_stream) => tcp_stream,
_ => unreachable!(),
@@ -226,7 +227,7 @@ impl AsyncNetworkStream {
#[cfg(any(feature = "tokio1-native-tls", feature = "tokio1-rustls-tls"))]
InnerAsyncNetworkStream::Tokio1Tcp(_) => {
// get owned TcpStream
let tcp_stream = std::mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = match tcp_stream {
InnerAsyncNetworkStream::Tokio1Tcp(tcp_stream) => tcp_stream,
_ => unreachable!(),
@@ -249,7 +250,7 @@ impl AsyncNetworkStream {
#[cfg(any(feature = "async-std1-native-tls", feature = "async-std1-rustls-tls"))]
InnerAsyncNetworkStream::AsyncStd1Tcp(_) => {
// get owned TcpStream
let tcp_stream = std::mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = mem::replace(&mut self.inner, InnerAsyncNetworkStream::None);
let tcp_stream = match tcp_stream {
InnerAsyncNetworkStream::AsyncStd1Tcp(tcp_stream) => tcp_stream,
_ => unreachable!(),
@@ -270,7 +271,7 @@ impl AsyncNetworkStream {
tcp_stream: Tokio02TcpStream,
mut tls_parameters: TlsParameters,
) -> Result<InnerAsyncNetworkStream, Error> {
let domain = std::mem::take(&mut tls_parameters.domain);
let domain = mem::take(&mut tls_parameters.domain);
match tls_parameters.connector {
#[cfg(feature = "native-tls")]
@@ -319,7 +320,7 @@ impl AsyncNetworkStream {
tcp_stream: Tokio1TcpStream,
mut tls_parameters: TlsParameters,
) -> Result<InnerAsyncNetworkStream, Error> {
let domain = std::mem::take(&mut tls_parameters.domain);
let domain = mem::take(&mut tls_parameters.domain);
match tls_parameters.connector {
#[cfg(feature = "native-tls")]
@@ -368,7 +369,7 @@ impl AsyncNetworkStream {
tcp_stream: AsyncStd1TcpStream,
mut tls_parameters: TlsParameters,
) -> Result<InnerAsyncNetworkStream, Error> {
let domain = std::mem::take(&mut tls_parameters.domain);
let domain = mem::take(&mut tls_parameters.domain);
match tls_parameters.connector {
#[cfg(feature = "native-tls")]

View File

@@ -1,122 +0,0 @@
#![allow(missing_docs)]
// Comes from https://github.com/inre/rust-mq/blob/master/netopt
use std::{
io::{self, Cursor, Read, Write},
sync::{Arc, Mutex},
};
pub type MockCursor = Cursor<Vec<u8>>;
#[derive(Clone, Debug)]
pub struct MockStream {
reader: Arc<Mutex<MockCursor>>,
writer: Arc<Mutex<MockCursor>>,
}
impl Default for MockStream {
fn default() -> Self {
Self::new()
}
}
impl MockStream {
pub fn new() -> MockStream {
MockStream {
reader: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
}
}
pub fn with_vec(vec: Vec<u8>) -> MockStream {
MockStream {
reader: Arc::new(Mutex::new(MockCursor::new(vec))),
writer: Arc::new(Mutex::new(MockCursor::new(Vec::new()))),
}
}
pub fn take_vec(&mut self) -> Vec<u8> {
let mut cursor = self.writer.lock().unwrap();
let vec = cursor.get_ref().to_vec();
cursor.set_position(0);
cursor.get_mut().clear();
vec
}
pub fn next_vec(&mut self, vec: &[u8]) {
let mut cursor = self.reader.lock().unwrap();
cursor.set_position(0);
cursor.get_mut().clear();
cursor.get_mut().extend_from_slice(vec);
}
pub fn swap(&mut self) {
let mut cur_write = self.writer.lock().unwrap();
let mut cur_read = self.reader.lock().unwrap();
let vec_write = cur_write.get_ref().to_vec();
let vec_read = cur_read.get_ref().to_vec();
cur_write.set_position(0);
cur_read.set_position(0);
cur_write.get_mut().clear();
cur_read.get_mut().clear();
// swap cursors
cur_read.get_mut().extend_from_slice(vec_write.as_slice());
cur_write.get_mut().extend_from_slice(vec_read.as_slice());
}
}
impl Write for MockStream {
fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
self.writer.lock().unwrap().write(msg)
}
fn flush(&mut self) -> io::Result<()> {
self.writer.lock().unwrap().flush()
}
}
impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.reader.lock().unwrap().read(buf)
}
}
#[cfg(test)]
mod test {
use super::MockStream;
use std::io::{Read, Write};
#[test]
fn write_take_test() {
let mut mock = MockStream::new();
// write to mock stream
mock.write_all(&[1, 2, 3]).unwrap();
assert_eq!(mock.take_vec(), vec![1, 2, 3]);
}
#[test]
fn read_with_vec_test() {
let mut mock = MockStream::with_vec(vec![4, 5]);
let mut vec = Vec::new();
mock.read_to_end(&mut vec).unwrap();
assert_eq!(vec, vec![4, 5]);
}
#[test]
fn clone_test() {
let mut mock = MockStream::new();
let mut cloned = mock.clone();
mock.write_all(&[6, 7]).unwrap();
assert_eq!(cloned.take_vec(), vec![6, 7]);
}
#[test]
fn swap_test() {
let mut mock = MockStream::new();
let mut vec = Vec::new();
mock.write_all(&[8, 9, 10]).unwrap();
mock.swap();
mock.read_to_end(&mut vec).unwrap();
assert_eq!(vec, vec![8, 9, 10]);
}
}

View File

@@ -36,7 +36,6 @@ use self::net::NetworkStream;
pub(super) use self::tls::InnerTlsParameters;
pub use self::{
connection::SmtpConnection,
mock::MockStream,
tls::{Certificate, Tls, TlsParameters, TlsParametersBuilder},
};
@@ -45,7 +44,6 @@ mod async_connection;
#[cfg(any(feature = "tokio02", feature = "tokio1", feature = "async-std1"))]
mod async_net;
mod connection;
mod mock;
mod net;
mod tls;

View File

@@ -2,6 +2,7 @@
use std::sync::Arc;
use std::{
io::{self, Read, Write},
mem,
net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs},
time::Duration,
};
@@ -14,7 +15,7 @@ use rustls::{ClientSession, StreamOwned};
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
use super::InnerTlsParameters;
use super::{MockStream, TlsParameters};
use super::TlsParameters;
use crate::transport::smtp::{error, Error};
/// A network stream
@@ -35,17 +36,17 @@ enum InnerNetworkStream {
/// Encrypted TCP stream
#[cfg(feature = "rustls-tls")]
RustlsTls(StreamOwned<ClientSession, TcpStream>),
/// Mock stream
Mock(MockStream),
/// Can't be built
None,
}
impl NetworkStream {
fn new(inner: InnerNetworkStream) -> Self {
NetworkStream { inner }
}
if let InnerNetworkStream::None = inner {
debug_assert!(false, "InnerNetworkStream::None must never be built");
}
pub fn new_mock(mock: MockStream) -> Self {
Self::new(InnerNetworkStream::Mock(mock))
NetworkStream { inner }
}
/// Returns peer's address
@@ -56,10 +57,13 @@ impl NetworkStream {
InnerNetworkStream::NativeTls(ref s) => s.get_ref().peer_addr(),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref s) => s.get_ref().peer_addr(),
InnerNetworkStream::Mock(_) => Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
80,
))),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
80,
)))
}
}
}
@@ -71,7 +75,10 @@ impl NetworkStream {
InnerNetworkStream::NativeTls(ref s) => s.get_ref().shutdown(how),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref s) => s.get_ref().shutdown(how),
InnerNetworkStream::Mock(_) => Ok(()),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(())
}
}
}
@@ -116,8 +123,7 @@ impl NetworkStream {
#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
InnerNetworkStream::Tcp(_) => {
// get owned TcpStream
let tcp_stream =
std::mem::replace(&mut self.inner, InnerNetworkStream::Mock(MockStream::new()));
let tcp_stream = mem::replace(&mut self.inner, InnerNetworkStream::None);
let tcp_stream = match tcp_stream {
InnerNetworkStream::Tcp(tcp_stream) => tcp_stream,
_ => unreachable!(),
@@ -161,11 +167,15 @@ impl NetworkStream {
pub fn is_encrypted(&self) -> bool {
match self.inner {
InnerNetworkStream::Tcp(_) | InnerNetworkStream::Mock(_) => false,
InnerNetworkStream::Tcp(_) => false,
#[cfg(feature = "native-tls")]
InnerNetworkStream::NativeTls(_) => true,
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(_) => true,
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
false
}
}
}
@@ -180,7 +190,10 @@ impl NetworkStream {
InnerNetworkStream::RustlsTls(ref mut stream) => {
stream.get_ref().set_read_timeout(duration)
}
InnerNetworkStream::Mock(_) => Ok(()),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(())
}
}
}
@@ -198,7 +211,10 @@ impl NetworkStream {
stream.get_ref().set_write_timeout(duration)
}
InnerNetworkStream::Mock(_) => Ok(()),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(())
}
}
}
}
@@ -211,7 +227,10 @@ impl Read for NetworkStream {
InnerNetworkStream::NativeTls(ref mut s) => s.read(buf),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref mut s) => s.read(buf),
InnerNetworkStream::Mock(ref mut s) => s.read(buf),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(0)
}
}
}
}
@@ -224,7 +243,10 @@ impl Write for NetworkStream {
InnerNetworkStream::NativeTls(ref mut s) => s.write(buf),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref mut s) => s.write(buf),
InnerNetworkStream::Mock(ref mut s) => s.write(buf),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(0)
}
}
}
@@ -235,7 +257,10 @@ impl Write for NetworkStream {
InnerNetworkStream::NativeTls(ref mut s) => s.flush(),
#[cfg(feature = "rustls-tls")]
InnerNetworkStream::RustlsTls(ref mut s) => s.flush(),
InnerNetworkStream::Mock(ref mut s) => s.flush(),
InnerNetworkStream::None => {
debug_assert!(false, "InnerNetworkStream::None must never be built");
Ok(())
}
}
}
}