[proxy] Prevent TLS stream from hanging

This change causes writer halves of a TLS stream to always flush after a
portion of bytes has been written by `std::io::copy`. Furthermore, some
cosmetic and minor functional changes are made to facilitate debug.
This commit is contained in:
Dmitry Ivanov
2021-10-19 18:31:26 +03:00
parent e42c884c2b
commit 85116a8375
4 changed files with 81 additions and 60 deletions

View File

@@ -12,7 +12,7 @@ pub struct DatabaseInfo {
pub port: u16,
pub dbname: String,
pub user: String,
pub password: String,
pub password: Option<String>,
}
impl DatabaseInfo {
@@ -24,12 +24,23 @@ impl DatabaseInfo {
.next()
.ok_or_else(|| anyhow::Error::msg("cannot resolve at least one SocketAddr"))
}
}
pub fn conn_string(&self) -> String {
format!(
"dbname={} user={} password={}",
self.dbname, self.user, self.password
)
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}

View File

@@ -145,18 +145,18 @@ fn main() -> anyhow::Result<()> {
println!("Starting mgmt on {}", state.conf.mgmt_address);
let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?;
let threads = vec![
let threads = [
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
thread::Builder::new()
.name("Proxy thread".into())
.name("Listener thread".into())
.spawn(move || proxy::thread_main(state, pageserver_listener))?,
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(state, mgmt_listener))?,
];
for t in threads.into_iter() {
for t in threads {
t.join().unwrap()?;
}

View File

@@ -6,7 +6,6 @@ use anyhow::bail;
use tokio_postgres::NoTls;
use rand::Rng;
use std::io::Write;
use std::{io, sync::mpsc::channel, thread};
use zenith_utils::postgres_backend::Stream;
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
@@ -28,11 +27,13 @@ pub fn thread_main(
println!("accepted connection from {}", peer_addr);
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = proxy_conn_main(state, socket) {
println!("error: {}", err);
}
});
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || {
if let Err(err) = proxy_conn_main(state, socket) {
println!("error: {}", err);
}
})?;
}
}
@@ -158,6 +159,7 @@ impl ProxyConnection {
fn handle_existing_user(&mut self) -> anyhow::Result<DatabaseInfo> {
// ask password
rand::thread_rng().fill(&mut self.md5_salt);
self.pgb
.write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?;
self.pgb.state = ProtoState::Authentication; // XXX
@@ -250,51 +252,68 @@ databases without opening the browser.
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result<tokio::net::TcpStream> {
let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?;
let config = db_info.conn_string().parse::<tokio_postgres::Config>()?;
let config = tokio_postgres::Config::from(db_info);
let _ = config.connect_raw(&mut socket, NoTls).await?;
Ok(socket)
}
/// Concurrently proxy both directions of the client and server connections
fn proxy(
client_read: ReadStream,
client_write: WriteStream,
server_read: ReadStream,
server_write: WriteStream,
(client_read, client_write): (ReadStream, WriteStream),
(server_read, server_write): (ReadStream, WriteStream),
) -> anyhow::Result<()> {
fn do_proxy(mut reader: ReadStream, mut writer: WriteStream) -> io::Result<()> {
std::io::copy(&mut reader, &mut writer)?;
writer.flush()?;
writer.shutdown(std::net::Shutdown::Both)
fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
/// FlushWriter will make sure that every message is sent as soon as possible
struct FlushWriter<W>(W);
impl<W: io::Write> io::Write for FlushWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// `std::io::copy` is guaranteed to exit if we return an error,
// so we can afford to lose `res` in case `flush` fails
let res = self.0.write(buf);
if res.is_ok() {
self.0.flush()?;
}
res
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
writer.shutdown(std::net::Shutdown::Both)?;
res
}
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
let res1 = do_proxy(server_read, client_write);
let res2 = client_to_server_jh.join().unwrap();
res1?;
res2?;
do_proxy(server_read, client_write)?;
client_to_server_jh.join().unwrap()?;
Ok(())
}
/// Proxy a client connection to a postgres database
fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let db_stream = runtime.block_on(connect_to_db(db_info))?;
let db_stream = db_stream.into_std()?;
db_stream.set_nonblocking(false)?;
let db_stream = {
// We'll get rid of this once migration to async is complete
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let db_stream = zenith_utils::sock_split::BidiStream::from_tcp(db_stream);
let (db_read, db_write) = db_stream.split();
let stream = runtime.block_on(connect_to_db(db_info))?.into_std()?;
stream.set_nonblocking(false)?;
stream
};
let stream = match pgb.into_stream() {
let db = zenith_utils::sock_split::BidiStream::from_tcp(db_stream);
let client = match pgb.into_stream() {
Stream::Bidirectional(bidi_stream) => bidi_stream,
_ => bail!("invalid stream"),
};
let (client_read, client_write) = stream.split();
proxy(client_read, client_write, db_read, db_write)
proxy(client.split(), db.split())
}

View File

@@ -107,21 +107,12 @@ impl io::Write for WriteStream {
}
}
pub struct TlsBoxed {
stream: BufStream,
session: rustls::ServerSession,
}
impl TlsBoxed {
fn rustls_stream(&mut self) -> rustls::Stream<rustls::ServerSession, BufStream> {
rustls::Stream::new(&mut self.session, &mut self.stream)
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerSession, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerSession`] is quite larger than [`BufStream`].
Tls(Box<TlsBoxed>),
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
@@ -134,11 +125,11 @@ impl BidiStream {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.stream.get_ref().shutdown(how)
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.session.send_close_notify();
let res = tls_boxed.rustls_stream().flush();
tls_boxed.stream.get_ref().shutdown(how)?;
tls_boxed.sess.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
@@ -155,7 +146,7 @@ impl BidiStream {
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.stream.into_reader();
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
@@ -164,7 +155,7 @@ impl BidiStream {
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) =
rustls_split::split(socket, tls_boxed.session, read_buf_cfg, write_buf_cfg);
rustls_split::split(socket, tls_boxed.sess, read_buf_cfg, write_buf_cfg);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
@@ -175,7 +166,7 @@ impl BidiStream {
Self::Tcp(mut stream) => {
session.complete_io(&mut stream)?;
assert!(!session.is_handshaking());
Ok(Self::Tls(Box::new(TlsBoxed { stream, session })))
Ok(Self::Tls(Box::new(TlsStream::new(session, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -189,7 +180,7 @@ impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
@@ -198,14 +189,14 @@ impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.rustls_stream().flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}