diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 2b3b8a64bd..3435dca7b2 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -12,7 +12,7 @@ pub struct DatabaseInfo { pub port: u16, pub dbname: String, pub user: String, - pub password: String, + pub password: Option, } impl DatabaseInfo { @@ -24,12 +24,23 @@ impl DatabaseInfo { .next() .ok_or_else(|| anyhow::Error::msg("cannot resolve at least one SocketAddr")) } +} - pub fn conn_string(&self) -> String { - format!( - "dbname={} user={} password={}", - self.dbname, self.user, self.password - ) +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config } } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index bbabab52f7..c183785635 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -145,18 +145,18 @@ fn main() -> anyhow::Result<()> { println!("Starting mgmt on {}", state.conf.mgmt_address); let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?; - let threads = vec![ + let threads = [ // Spawn a thread to listen for connections. It will spawn further threads // for each connection. thread::Builder::new() - .name("Proxy thread".into()) + .name("Listener thread".into()) .spawn(move || proxy::thread_main(state, pageserver_listener))?, thread::Builder::new() .name("Mgmt thread".into()) .spawn(move || mgmt::thread_main(state, mgmt_listener))?, ]; - for t in threads.into_iter() { + for t in threads { t.join().unwrap()?; } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 61a742cf38..1debabae9c 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -6,7 +6,6 @@ use anyhow::bail; use tokio_postgres::NoTls; use rand::Rng; -use std::io::Write; use std::{io, sync::mpsc::channel, thread}; use zenith_utils::postgres_backend::Stream; use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; @@ -28,11 +27,13 @@ pub fn thread_main( println!("accepted connection from {}", peer_addr); socket.set_nodelay(true).unwrap(); - thread::spawn(move || { - if let Err(err) = proxy_conn_main(state, socket) { - println!("error: {}", err); - } - }); + thread::Builder::new() + .name("Proxy thread".into()) + .spawn(move || { + if let Err(err) = proxy_conn_main(state, socket) { + println!("error: {}", err); + } + })?; } } @@ -158,6 +159,7 @@ impl ProxyConnection { fn handle_existing_user(&mut self) -> anyhow::Result { // ask password rand::thread_rng().fill(&mut self.md5_salt); + self.pgb .write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?; self.pgb.state = ProtoState::Authentication; // XXX @@ -250,51 +252,68 @@ databases without opening the browser. /// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result { let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()?).await?; - let config = db_info.conn_string().parse::()?; + let config = tokio_postgres::Config::from(db_info); let _ = config.connect_raw(&mut socket, NoTls).await?; Ok(socket) } /// Concurrently proxy both directions of the client and server connections fn proxy( - client_read: ReadStream, - client_write: WriteStream, - server_read: ReadStream, - server_write: WriteStream, + (client_read, client_write): (ReadStream, WriteStream), + (server_read, server_write): (ReadStream, WriteStream), ) -> anyhow::Result<()> { - fn do_proxy(mut reader: ReadStream, mut writer: WriteStream) -> io::Result<()> { - std::io::copy(&mut reader, &mut writer)?; - writer.flush()?; - writer.shutdown(std::net::Shutdown::Both) + fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result { + /// FlushWriter will make sure that every message is sent as soon as possible + struct FlushWriter(W); + + impl io::Write for FlushWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + // `std::io::copy` is guaranteed to exit if we return an error, + // so we can afford to lose `res` in case `flush` fails + let res = self.0.write(buf); + if res.is_ok() { + self.0.flush()?; + } + res + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer)); + writer.shutdown(std::net::Shutdown::Both)?; + res } let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write)); - let res1 = do_proxy(server_read, client_write); - let res2 = client_to_server_jh.join().unwrap(); - res1?; - res2?; + do_proxy(server_read, client_write)?; + client_to_server_jh.join().unwrap()?; Ok(()) } /// Proxy a client connection to a postgres database fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - let db_stream = runtime.block_on(connect_to_db(db_info))?; - let db_stream = db_stream.into_std()?; - db_stream.set_nonblocking(false)?; + let db_stream = { + // We'll get rid of this once migration to async is complete + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; - let db_stream = zenith_utils::sock_split::BidiStream::from_tcp(db_stream); - let (db_read, db_write) = db_stream.split(); + let stream = runtime.block_on(connect_to_db(db_info))?.into_std()?; + stream.set_nonblocking(false)?; + stream + }; - let stream = match pgb.into_stream() { + let db = zenith_utils::sock_split::BidiStream::from_tcp(db_stream); + + let client = match pgb.into_stream() { Stream::Bidirectional(bidi_stream) => bidi_stream, _ => bail!("invalid stream"), }; - let (client_read, client_write) = stream.split(); - proxy(client_read, client_write, db_read, db_write) + proxy(client.split(), db.split()) } diff --git a/zenith_utils/src/sock_split.rs b/zenith_utils/src/sock_split.rs index f7078dafc2..c62963e113 100644 --- a/zenith_utils/src/sock_split.rs +++ b/zenith_utils/src/sock_split.rs @@ -107,21 +107,12 @@ impl io::Write for WriteStream { } } -pub struct TlsBoxed { - stream: BufStream, - session: rustls::ServerSession, -} - -impl TlsBoxed { - fn rustls_stream(&mut self) -> rustls::Stream { - rustls::Stream::new(&mut self.session, &mut self.stream) - } -} +type TlsStream = rustls::StreamOwned; pub enum BidiStream { Tcp(BufStream), /// This variant is boxed, because [`rustls::ServerSession`] is quite larger than [`BufStream`]. - Tls(Box), + Tls(Box>), } impl BidiStream { @@ -134,11 +125,11 @@ impl BidiStream { Self::Tcp(stream) => stream.get_ref().shutdown(how), Self::Tls(tls_boxed) => { if how == Shutdown::Read { - tls_boxed.stream.get_ref().shutdown(how) + tls_boxed.sock.get_ref().shutdown(how) } else { - tls_boxed.session.send_close_notify(); - let res = tls_boxed.rustls_stream().flush(); - tls_boxed.stream.get_ref().shutdown(how)?; + tls_boxed.sess.send_close_notify(); + let res = tls_boxed.flush(); + tls_boxed.sock.get_ref().shutdown(how)?; res } } @@ -155,7 +146,7 @@ impl BidiStream { (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) } Self::Tls(tls_boxed) => { - let reader = tls_boxed.stream.into_reader(); + let reader = tls_boxed.sock.into_reader(); let buffer_data = reader.buffer().to_owned(); let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); @@ -164,7 +155,7 @@ impl BidiStream { let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); let (read_half, write_half) = - rustls_split::split(socket, tls_boxed.session, read_buf_cfg, write_buf_cfg); + rustls_split::split(socket, tls_boxed.sess, read_buf_cfg, write_buf_cfg); (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) } } @@ -175,7 +166,7 @@ impl BidiStream { Self::Tcp(mut stream) => { session.complete_io(&mut stream)?; assert!(!session.is_handshaking()); - Ok(Self::Tls(Box::new(TlsBoxed { stream, session }))) + Ok(Self::Tls(Box::new(TlsStream::new(session, stream)))) } Self::Tls { .. } => Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -189,7 +180,7 @@ impl io::Read for BidiStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { Self::Tcp(stream) => stream.read(buf), - Self::Tls(tls_boxed) => tls_boxed.rustls_stream().read(buf), + Self::Tls(tls_boxed) => tls_boxed.read(buf), } } } @@ -198,14 +189,14 @@ impl io::Write for BidiStream { fn write(&mut self, buf: &[u8]) -> io::Result { match self { Self::Tcp(stream) => stream.write(buf), - Self::Tls(tls_boxed) => tls_boxed.rustls_stream().write(buf), + Self::Tls(tls_boxed) => tls_boxed.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self { Self::Tcp(stream) => stream.flush(), - Self::Tls(tls_boxed) => tls_boxed.rustls_stream().flush(), + Self::Tls(tls_boxed) => tls_boxed.flush(), } } }