use std::{ io::{self, BufReader, Write}, net::{Shutdown, TcpStream}, sync::Arc, }; use rustls::Connection; /// Wrapper supporting reads of a shared TcpStream. pub struct ArcTcpRead(Arc); impl io::Read for ArcTcpRead { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { (&*self.0).read(buf) } } impl std::ops::Deref for ArcTcpRead { type Target = TcpStream; fn deref(&self) -> &Self::Target { self.0.deref() } } /// Wrapper around a TCP Stream supporting buffered reads. pub struct BufStream(BufReader); impl io::Read for BufStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.0.read(buf) } } impl io::Write for BufStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.get_ref().write(buf) } fn flush(&mut self) -> io::Result<()> { self.get_ref().flush() } } impl BufStream { /// Unwrap into the internal BufReader. fn into_reader(self) -> BufReader { self.0 } /// Returns a reference to the underlying TcpStream. fn get_ref(&self) -> &TcpStream { &*self.0.get_ref().0 } } pub enum ReadStream { Tcp(BufReader), Tls(rustls_split::ReadHalf), } impl io::Read for ReadStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { Self::Tcp(reader) => reader.read(buf), Self::Tls(read_half) => read_half.read(buf), } } } impl ReadStream { pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { match self { Self::Tcp(stream) => stream.get_ref().shutdown(how), Self::Tls(write_half) => write_half.shutdown(how), } } } pub enum WriteStream { Tcp(Arc), Tls(rustls_split::WriteHalf), } impl WriteStream { pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { match self { Self::Tcp(stream) => stream.shutdown(how), Self::Tls(write_half) => write_half.shutdown(how), } } } impl io::Write for WriteStream { fn write(&mut self, buf: &[u8]) -> io::Result { match self { Self::Tcp(stream) => stream.as_ref().write(buf), Self::Tls(write_half) => write_half.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self { Self::Tcp(stream) => stream.as_ref().flush(), Self::Tls(write_half) => write_half.flush(), } } } type TlsStream = rustls::StreamOwned; pub enum BidiStream { Tcp(BufStream), /// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`]. Tls(Box>), } impl BidiStream { pub fn from_tcp(stream: TcpStream) -> Self { Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream))))) } pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { match self { Self::Tcp(stream) => stream.get_ref().shutdown(how), Self::Tls(tls_boxed) => { if how == Shutdown::Read { tls_boxed.sock.get_ref().shutdown(how) } else { tls_boxed.conn.send_close_notify(); let res = tls_boxed.flush(); tls_boxed.sock.get_ref().shutdown(how)?; res } } } } /// Split the bi-directional stream into two owned read and write halves. pub fn split(self) -> (ReadStream, WriteStream) { match self { Self::Tcp(stream) => { let reader = stream.into_reader(); let stream: Arc = reader.get_ref().0.clone(); (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) } Self::Tls(tls_boxed) => { let reader = tls_boxed.sock.into_reader(); let buffer_data = reader.buffer().to_owned(); let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); // TODO would be nice to avoid the Arc here let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); let (read_half, write_half) = rustls_split::split( socket, Connection::Server(tls_boxed.conn), read_buf_cfg, write_buf_cfg, ); (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) } } } pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result { match self { Self::Tcp(mut stream) => { conn.complete_io(&mut stream)?; assert!(!conn.is_handshaking()); Ok(Self::Tls(Box::new(TlsStream::new(conn, stream)))) } Self::Tls { .. } => Err(io::Error::new( io::ErrorKind::InvalidInput, "TLS is already started on this stream", )), } } } impl io::Read for BidiStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { Self::Tcp(stream) => stream.read(buf), Self::Tls(tls_boxed) => tls_boxed.read(buf), } } } impl io::Write for BidiStream { fn write(&mut self, buf: &[u8]) -> io::Result { match self { Self::Tcp(stream) => stream.write(buf), Self::Tls(tls_boxed) => tls_boxed.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self { Self::Tcp(stream) => stream.flush(), Self::Tls(tls_boxed) => tls_boxed.flush(), } } }