[proxy] Prevent TLS stream from hanging

This change causes writer halves of a TLS stream to always flush after a
portion of bytes has been written by `std::io::copy`. Furthermore, some
cosmetic and minor functional changes are made to facilitate debug.
This commit is contained in:
Dmitry Ivanov
2021-10-19 18:31:26 +03:00
parent e42c884c2b
commit 85116a8375
4 changed files with 81 additions and 60 deletions

View File

@@ -107,21 +107,12 @@ impl io::Write for WriteStream {
}
}
pub struct TlsBoxed {
stream: BufStream,
session: rustls::ServerSession,
}
impl TlsBoxed {
fn rustls_stream(&mut self) -> rustls::Stream<rustls::ServerSession, BufStream> {
rustls::Stream::new(&mut self.session, &mut self.stream)
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerSession, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerSession`] is quite larger than [`BufStream`].
Tls(Box<TlsBoxed>),
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
@@ -134,11 +125,11 @@ impl BidiStream {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.stream.get_ref().shutdown(how)
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.session.send_close_notify();
let res = tls_boxed.rustls_stream().flush();
tls_boxed.stream.get_ref().shutdown(how)?;
tls_boxed.sess.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
@@ -155,7 +146,7 @@ impl BidiStream {
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.stream.into_reader();
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
@@ -164,7 +155,7 @@ impl BidiStream {
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) =
rustls_split::split(socket, tls_boxed.session, read_buf_cfg, write_buf_cfg);
rustls_split::split(socket, tls_boxed.sess, read_buf_cfg, write_buf_cfg);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
@@ -175,7 +166,7 @@ impl BidiStream {
Self::Tcp(mut stream) => {
session.complete_io(&mut stream)?;
assert!(!session.is_handshaking());
Ok(Self::Tls(Box::new(TlsBoxed { stream, session })))
Ok(Self::Tls(Box::new(TlsStream::new(session, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -189,7 +180,7 @@ impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
@@ -198,14 +189,14 @@ impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}