diff --git a/Cargo.lock b/Cargo.lock index c7c96ab52f..c1b737ca94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1443,6 +1443,7 @@ dependencies = [ "hex", "md5", "rand", + "rustls", "serde", "serde_json", "tokio", @@ -1679,6 +1680,29 @@ dependencies = [ "semver", ] +[[package]] +name = "rustls" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" +dependencies = [ + "base64 0.13.0", + "log", + "ring", + "sct", + "webpki", +] + +[[package]] +name = "rustls-split" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d63f11490b4d8d45a362e171d7fe4a9ef154770a339e696a05eb354bc36837" +dependencies = [ + "rustls", + "webpki", +] + [[package]] name = "rustversion" version = "1.0.5" @@ -1716,6 +1740,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.3.1" @@ -2463,6 +2497,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "which" version = "3.1.1" @@ -2595,10 +2639,13 @@ dependencies = [ "postgres", "rand", "routerify", + "rustls", + "rustls-split", "serde", "serde_json", "thiserror", "tokio", + "webpki", "workspace_hack", "zenith_metrics", ] diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 83ea21f95e..b64fe261ef 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -176,7 +176,7 @@ fn page_service_conn_main( } let mut conn_handler = PageServerHandler::new(conf, auth); - let pgbackend = PostgresBackend::new(socket, auth_type)?; + let pgbackend = PostgresBackend::new(socket, auth_type, None)?; pgbackend.run(&mut conn_handler) } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 051a070f2f..fd13b7ba9e 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -17,5 +17,6 @@ serde_json = "1" tokio = { version = "1.7.1", features = ["full"] } tokio-postgres = "0.7.2" clap = "2.33.0" +rustls = "0.19.1" zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 15c446c9c1..df0291ec59 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -8,13 +8,15 @@ use std::{ collections::HashMap, net::{SocketAddr, TcpListener}, - sync::{mpsc, Mutex}, + sync::{mpsc, Arc, Mutex}, thread, }; -use clap::{App, Arg}; +use anyhow::{anyhow, bail, ensure, Context}; +use clap::{App, Arg, ArgMatches}; use cplane_api::DatabaseInfo; +use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig}; mod cplane_api; mod mgmt; @@ -33,6 +35,8 @@ pub struct ProxyConf { /// control plane address where we would check auth. pub cplane_address: SocketAddr, + + pub ssl_config: Option>, } pub struct ProxyState { @@ -40,6 +44,38 @@ pub struct ProxyState { pub waiters: Mutex>>>, } +fn configure_ssl(arg_matches: &ArgMatches) -> anyhow::Result>> { + let (key_path, cert_path) = match ( + arg_matches.value_of("ssl-key"), + arg_matches.value_of("ssl-cert"), + ) { + (Some(key_path), Some(cert_path)) => (key_path, cert_path), + (None, None) => return Ok(None), + _ => bail!("either both or neither ssl-key and ssl-cert must be specified"), + }; + + let key = { + let key_bytes = std::fs::read(key_path).context("SSL key file")?; + let mut keys = pemfile::rsa_private_keys(&mut &key_bytes[..]) + .or_else(|_| pemfile::pkcs8_private_keys(&mut &key_bytes[..])) + .map_err(|_| anyhow!("couldn't read TLS keys"))?; + ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); + keys.pop().unwrap() + }; + + let cert_chain = { + let cert_chain_bytes = std::fs::read(cert_path).context("SSL cert file")?; + pemfile::certs(&mut &cert_chain_bytes[..]) + .map_err(|_| anyhow!("couldn't read TLS certificates"))? + }; + + let mut config = ServerConfig::new(NoClientAuth::new()); + config.set_single_cert(cert_chain, key)?; + config.versions = vec![ProtocolVersion::TLSv1_3]; + + Ok(Some(Arc::new(config))) +} + fn main() -> anyhow::Result<()> { let arg_matches = App::new("Zenith proxy/router") .arg( @@ -66,6 +102,20 @@ fn main() -> anyhow::Result<()> { .help("redirect unauthenticated users to given uri") .default_value("http://localhost:3000/psql_session/"), ) + .arg( + Arg::with_name("ssl-key") + .short("k") + .long("ssl-key") + .takes_value(true) + .help("path to SSL key for client postgres connections"), + ) + .arg( + Arg::with_name("ssl-cert") + .short("c") + .long("ssl-cert") + .takes_value(true) + .help("path to SSL cert for client postgres connections"), + ) .get_matches(); let conf = ProxyConf { @@ -73,6 +123,7 @@ fn main() -> anyhow::Result<()> { mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?, redirect_uri: arg_matches.value_of("uri").unwrap().parse()?, cplane_address: "127.0.0.1:3000".parse()?, + ssl_config: configure_ssl(&arg_matches)?, }; let state = ProxyState { conf, diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index b8c59bbe03..0d3c114421 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -34,7 +34,7 @@ pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow: pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> { let mut conn_handler = MgmtHandler { state }; - let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; + let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; pgbackend.run(&mut conn_handler) } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 4e9e242467..802a0bd305 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -6,11 +6,12 @@ use anyhow::bail; use tokio_postgres::NoTls; use rand::Rng; -use std::sync::mpsc::channel; -use std::thread; -use tokio::io::AsyncWriteExt; +use std::io::Write; +use std::{io, sync::mpsc::channel, thread}; +use zenith_utils::postgres_backend::Stream; use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; use zenith_utils::pq_proto::*; +use zenith_utils::sock_split::{ReadStream, WriteStream}; use zenith_utils::{postgres_backend, pq_proto::BeMessage}; /// @@ -59,7 +60,11 @@ pub fn proxy_conn_main( cplane: CPlaneApi::new(&state.conf.cplane_address), user: "".into(), database: "".into(), - pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?, + pgb: PostgresBackend::new( + socket, + postgres_backend::AuthType::MD5, + state.conf.ssl_config.clone(), + )?, md5_salt: [0u8; 4], psql_session_id: "".into(), }; @@ -75,17 +80,7 @@ pub fn proxy_conn_main( conn.handle_new_user()? }; - // ok, proxy pass user connection to database_uri - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - let _ = runtime.block_on(proxy_pass(conn.pgb, db_info))?; - - println!("proxy_conn_main done;"); - - Ok(()) + proxy_pass(conn.pgb, db_info) } impl ProxyConnection { @@ -94,6 +89,7 @@ impl ProxyConnection { } fn handle_startup(&mut self) -> anyhow::Result<()> { + let mut encrypted = false; loop { let msg = self.pgb.read_message()?; println!("got message {:?}", msg); @@ -102,11 +98,29 @@ impl ProxyConnection { println!("got startup message {:?}", m); match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + StartupRequestCode::NegotiateGss => { + self.pgb + .write_message(&BeMessage::EncryptionResponse(false))?; + } + StartupRequestCode::NegotiateSsl => { println!("SSL requested"); - self.pgb.write_message(&BeMessage::Negotiate)?; + if self.pgb.tls_config.is_some() { + self.pgb + .write_message(&BeMessage::EncryptionResponse(true))?; + self.pgb.start_tls()?; + encrypted = true; + } else { + self.pgb + .write_message(&BeMessage::EncryptionResponse(false))?; + } } StartupRequestCode::Normal => { + if self.state.conf.ssl_config.is_some() && !encrypted { + self.pgb.write_message(&BeMessage::ErrorResponse( + "must connect with TLS".to_string(), + ))?; + bail!("client did not connect with TLS"); + } self.user = m .params .get("user") @@ -226,31 +240,52 @@ databases without opening the browser. } } -async fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> { +/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message +async fn connect_to_db(db_info: DatabaseInfo) -> anyhow::Result { let mut socket = tokio::net::TcpStream::connect(db_info.socket_addr()).await?; let config = db_info.conn_string().parse::()?; let _ = config.connect_raw(&mut socket, NoTls).await?; + Ok(socket) +} - println!("Connected to pg, proxying"); +/// Concurrently proxy both directions of the client and server connections +fn proxy( + client_read: ReadStream, + client_write: WriteStream, + server_read: ReadStream, + server_write: WriteStream, +) -> anyhow::Result<()> { + fn do_proxy(mut reader: ReadStream, mut writer: WriteStream) -> io::Result<()> { + std::io::copy(&mut reader, &mut writer)?; + writer.flush()?; + writer.shutdown(std::net::Shutdown::Both) + } - let incoming_std = pgb.into_stream(); - incoming_std.set_nonblocking(true)?; - let mut incoming_conn = tokio::net::TcpStream::from_std(incoming_std)?; + let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write)); - let (mut ri, mut wi) = incoming_conn.split(); - let (mut ro, mut wo) = socket.split(); - - let client_to_server = async { - tokio::io::copy(&mut ri, &mut wo).await?; - wo.shutdown().await - }; - - let server_to_client = async { - tokio::io::copy(&mut ro, &mut wi).await?; - wi.shutdown().await - }; - - tokio::try_join!(client_to_server, server_to_client)?; + let res1 = do_proxy(server_read, client_write); + let res2 = client_to_server_jh.join().unwrap(); + res1?; + res2?; Ok(()) } + +/// Proxy a client connection to a postgres database +fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_current_thread().build()?; + let db_stream = runtime.block_on(connect_to_db(db_info))?; + let db_stream = db_stream.into_std()?; + db_stream.set_nonblocking(false)?; + + let db_stream = zenith_utils::sock_split::BidiStream::from_tcp(db_stream); + let (db_read, db_write) = db_stream.split(); + + let stream = match pgb.into_stream() { + Stream::Bidirectional(bidi_stream) => bidi_stream, + _ => bail!("invalid stream"), + }; + + let (client_read, client_write) = stream.split(); + proxy(client_read, client_write, db_read, db_write) +} diff --git a/walkeeper/src/receive_wal.rs b/walkeeper/src/receive_wal.rs index 646f6d9f88..611f36abed 100644 --- a/walkeeper/src/receive_wal.rs +++ b/walkeeper/src/receive_wal.rs @@ -74,7 +74,7 @@ fn request_callback(conf: WalAcceptorConf, timelineid: ZTimelineId, tenantid: ZT impl<'pg> ReceiveWalConn<'pg> { pub fn new(pg: &'pg mut PostgresBackend) -> Result> { - let peer_addr = pg.get_peer_addr()?; + let peer_addr = pg.get_peer_addr().clone(); Ok(ReceiveWalConn { pg_backend: pg, peer_addr, diff --git a/walkeeper/src/replication.rs b/walkeeper/src/replication.rs index e011b428f5..c1a4266741 100644 --- a/walkeeper/src/replication.rs +++ b/walkeeper/src/replication.rs @@ -11,8 +11,7 @@ use regex::Regex; use serde::{Deserialize, Serialize}; use std::cmp::min; use std::fs::File; -use std::io::{BufReader, Read, Seek, SeekFrom}; -use std::net::TcpStream; +use std::io::{Read, Seek, SeekFrom}; use std::path::Path; use std::sync::Arc; use std::thread::sleep; @@ -22,6 +21,7 @@ use zenith_utils::bin_ser::BeSer; use zenith_utils::lsn::Lsn; use zenith_utils::postgres_backend::PostgresBackend; use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody}; +use zenith_utils::sock_split::ReadStream; pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX; @@ -49,7 +49,7 @@ impl HotStandbyFeedback { pub struct ReplicationConn { /// This is an `Option` because we will spawn a background thread that will /// `take` it from us. - stream_in: Option>, + stream_in: Option, } // TODO: move this to crate::timeline when there's more users diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index 33d619bcb8..c77078560c 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -41,7 +41,7 @@ fn handle_socket(socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { socket.set_nodelay(true)?; let mut conn_handler = SendWalHandler::new(conf); - let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; + let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; // libpq replication protocol between wal_acceptor and replicas/pagers pgbackend.run(&mut conn_handler)?; diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index 68b8f69c12..731e2d45a2 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -25,6 +25,10 @@ rand = "0.8.3" jsonwebtoken = "7" hex = { version = "0.4.3", features = ["serde"] } +rustls = "0.19.1" +rustls-split = "0.2.1" + [dev-dependencies] hex-literal = "0.3" bytes = "1.0" +webpki = "0.21" diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index 4f37d1ce49..ea90642c6f 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -23,3 +23,6 @@ pub mod auth; pub mod zid; // http endpoint utils pub mod http; + +// socket splitting utils +pub mod sock_split; diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 27aa0e4277..f5fe2319f3 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -4,14 +4,16 @@ //! is rather narrow, but we can extend it once required. use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode}; -use anyhow::{bail, ensure, Result}; +use crate::sock_split::{BidiStream, ReadStream, WriteStream}; +use anyhow::{anyhow, bail, ensure, Result}; use bytes::{Bytes, BytesMut}; use log::*; use rand::Rng; use serde::{Deserialize, Serialize}; -use std::io::{self, BufReader, Write}; +use std::io::{self, Write}; use std::net::{Shutdown, SocketAddr, TcpStream}; use std::str::FromStr; +use std::sync::Arc; pub trait Handler { /// Handle single query. @@ -45,6 +47,7 @@ pub trait Handler { #[derive(Clone, Copy, PartialEq, PartialOrd)] pub enum ProtoState { Initialization, + Encrypted, Authentication, Established, } @@ -76,12 +79,40 @@ pub enum ProcessMsgResult { Break, } +/// Always-writeable sock_split stream. +/// May not be readable. See [`PostgresBackend::take_stream_in`] +pub enum Stream { + Bidirectional(BidiStream), + WriteOnly(WriteStream), +} + +impl Stream { + fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + match self { + Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how), + Self::WriteOnly(write_stream) => write_stream.shutdown(how), + } + } +} + +impl io::Write for Stream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + Self::Bidirectional(bidi_stream) => bidi_stream.write(buf), + Self::WriteOnly(write_stream) => write_stream.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + Self::Bidirectional(bidi_stream) => bidi_stream.flush(), + Self::WriteOnly(write_stream) => write_stream.flush(), + } + } +} + pub struct PostgresBackend { - // replication.rs wants to handle reading on its own in separate thread, so - // wrap in Option to be able to take and transfer the BufReader. Ugly, but I - // have no better ideas. - stream_in: Option>, - stream_out: TcpStream, + stream: Option, // Output buffer. c.f. BeMessage::write why we are using BytesMut here. buf_out: BytesMut, @@ -89,6 +120,9 @@ pub struct PostgresBackend { md5_salt: [u8; 4], auth_type: AuthType, + + peer_addr: SocketAddr, + pub tls_config: Option>, } pub fn query_from_cstring(query_string: Bytes) -> Vec { @@ -102,47 +136,52 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec { } impl PostgresBackend { - pub fn new(socket: TcpStream, auth_type: AuthType) -> io::Result { - let mut pb = PostgresBackend { - stream_in: None, - stream_out: socket, + pub fn new( + socket: TcpStream, + auth_type: AuthType, + tls_config: Option>, + ) -> io::Result { + let peer_addr = socket.peer_addr()?; + Ok(Self { + stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))), buf_out: BytesMut::with_capacity(10 * 1024), state: ProtoState::Initialization, md5_salt: [0u8; 4], auth_type, - }; - - // if socket cloning fails, report the error and bail out - pb.stream_in = match pb.stream_out.try_clone() { - Ok(read_sock) => Some(BufReader::new(read_sock)), - Err(error) => { - let errmsg = format!("{}", error); - let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg)); - return Err(error); - } - }; - - Ok(pb) + tls_config, + peer_addr, + }) } - pub fn into_stream(self) -> TcpStream { - self.stream_out + pub fn into_stream(self) -> Stream { + self.stream.unwrap() } /// Get direct reference (into the Option) to the read stream. - fn get_stream_in(&mut self) -> Result<&mut BufReader> { - match self.stream_in { - Some(ref mut stream_in) => Ok(stream_in), - None => bail!("stream_in was taken"), + fn get_stream_in(&mut self) -> Result<&mut BidiStream> { + match &mut self.stream { + Some(Stream::Bidirectional(stream)) => Ok(stream), + _ => Err(anyhow!("reader taken")), } } - pub fn get_peer_addr(&self) -> Result { - Ok(self.stream_out.peer_addr()?) + pub fn get_peer_addr(&self) -> &SocketAddr { + &self.peer_addr } - pub fn take_stream_in(&mut self) -> Option> { - self.stream_in.take() + pub fn take_stream_in(&mut self) -> Option { + let stream = self.stream.take(); + match stream { + Some(Stream::Bidirectional(bidi_stream)) => { + let (read, write) = bidi_stream.split(); + self.stream = Some(Stream::WriteOnly(write)); + Some(read) + } + stream => { + self.stream = stream; + None + } + } } /// Read full message or return None if connection is closed. @@ -151,7 +190,7 @@ impl PostgresBackend { use ProtoState::*; match state { - Initialization => FeStartupMessage::read(stream), + Initialization | Encrypted => FeStartupMessage::read(stream), Authentication | Established => FeMessage::read(stream), } } @@ -164,7 +203,8 @@ impl PostgresBackend { /// Flush output buffer into the socket. pub fn flush(&mut self) -> io::Result<&mut Self> { - self.stream_out.write_all(&self.buf_out)?; + let stream = self.stream.as_mut().unwrap(); + stream.write_all(&self.buf_out)?; self.buf_out.clear(); Ok(self) } @@ -178,13 +218,14 @@ impl PostgresBackend { // Wrapper for run_message_loop() that shuts down socket when we are done pub fn run(mut self, handler: &mut impl Handler) -> Result<()> { let ret = self.run_message_loop(handler); - let _res = self.stream_out.shutdown(Shutdown::Both); + if let Some(stream) = self.stream.as_mut() { + let _ = stream.shutdown(Shutdown::Both); + } ret } fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> { - let peer_addr = self.stream_out.peer_addr()?; - trace!("postgres backend to {:?} started", peer_addr); + trace!("postgres backend to {:?} started", self.peer_addr); let mut unnamed_query_string = Bytes::new(); @@ -197,10 +238,24 @@ impl PostgresBackend { } } - trace!("postgres backend to {:?} exited", peer_addr); + trace!("postgres backend to {:?} exited", self.peer_addr); Ok(()) } + pub fn start_tls(&mut self) -> anyhow::Result<()> { + match self.stream.take() { + Some(Stream::Bidirectional(bidi_stream)) => { + let session = rustls::ServerSession::new(&self.tls_config.clone().unwrap()); + self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(session)?)); + Ok(()) + } + stream => { + self.stream = stream; + bail!("can't start TLs without bidi stream"); + } + } + } + fn process_message( &mut self, handler: &mut impl Handler, @@ -224,11 +279,30 @@ impl PostgresBackend { trace!("got startup message {:?}", m); match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + StartupRequestCode::NegotiateSsl => { info!("SSL requested"); - self.write_message(&BeMessage::Negotiate)?; + + if self.tls_config.is_some() { + self.write_message(&BeMessage::EncryptionResponse(true))?; + self.start_tls()?; + self.state = ProtoState::Encrypted; + } else { + self.write_message(&BeMessage::EncryptionResponse(false))?; + } + } + StartupRequestCode::NegotiateGss => { + info!("GSS requested"); + self.write_message(&BeMessage::EncryptionResponse(false))?; } StartupRequestCode::Normal => { + if self.tls_config.is_some() && !matches!(self.state, ProtoState::Encrypted) + { + self.write_message(&BeMessage::ErrorResponse( + "must connect with TLS".to_string(), + ))?; + bail!("client did not connect with TLS"); + } + // NB: startup() may change self.auth_type -- we are using that in proxy code // to bypass auth for new users. handler.startup(self, &m)?; diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index b2722f3b13..76a1c8cc4a 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -348,8 +348,8 @@ pub enum BeMessage<'a> { // None means column is NULL DataRow(&'a [Option<&'a [u8]>]), ErrorResponse(String), - // see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11 - Negotiate, + // single byte - used in response to SSLRequest/GSSENCRequest + EncryptionResponse(bool), NoData, ParameterDescription, ParameterStatus, @@ -657,8 +657,9 @@ impl<'a> BeMessage<'a> { write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } - BeMessage::Negotiate => { - buf.put_u8(b'N'); + BeMessage::EncryptionResponse(should_negotiate) => { + let response = if *should_negotiate { b'Y' } else { b'N' }; + buf.put_u8(response); } BeMessage::ParameterStatus => { diff --git a/zenith_utils/src/sock_split.rs b/zenith_utils/src/sock_split.rs new file mode 100644 index 0000000000..5d47d933ff --- /dev/null +++ b/zenith_utils/src/sock_split.rs @@ -0,0 +1,206 @@ +use std::{ + io::{self, BufReader, Write}, + net::{Shutdown, TcpStream}, + sync::Arc, +}; + +use rustls::Session; + +/// Wrapper supporting reads of a shared TcpStream. +pub struct ArcTcpRead(Arc); + +impl io::Read for ArcTcpRead { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + (&*self.0).read(buf) + } +} + +impl std::ops::Deref for ArcTcpRead { + type Target = TcpStream; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +/// Wrapper around a TCP Stream supporting buffered reads. +pub struct BufStream(BufReader); + +impl io::Read for BufStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for BufStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.get_ref().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.get_ref().flush() + } +} + +impl BufStream { + /// Unwrap into the internal BufReader. + fn into_reader(self) -> BufReader { + self.0 + } + + /// Returns a reference to the underlying TcpStream. + fn get_ref(&self) -> &TcpStream { + &*self.0.get_ref().0 + } +} + +pub enum ReadStream { + Tcp(BufReader), + Tls(rustls_split::ReadHalf), +} + +impl io::Read for ReadStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Self::Tcp(reader) => reader.read(buf), + Self::Tls(read_half) => read_half.read(buf), + } + } +} + +impl ReadStream { + pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + match self { + Self::Tcp(stream) => stream.get_ref().shutdown(how), + Self::Tls(write_half) => write_half.shutdown(how), + } + } +} + +pub enum WriteStream { + Tcp(Arc), + Tls(rustls_split::WriteHalf), +} + +impl WriteStream { + pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + match self { + Self::Tcp(stream) => stream.shutdown(how), + Self::Tls(write_half) => write_half.shutdown(how), + } + } +} + +impl io::Write for WriteStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + Self::Tcp(stream) => stream.as_ref().write(buf), + Self::Tls(write_half) => write_half.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + Self::Tcp(stream) => stream.as_ref().flush(), + Self::Tls(write_half) => write_half.flush(), + } + } +} + +pub enum BidiStream { + Tcp(BufStream), + Tls { + stream: BufStream, + session: rustls::ServerSession, + }, +} + +impl BidiStream { + pub fn from_tcp(stream: TcpStream) -> Self { + Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream))))) + } + + pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + match self { + Self::Tcp(stream) => stream.get_ref().shutdown(how), + Self::Tls { + stream: reader, + session, + } => { + if how == Shutdown::Read { + reader.get_ref().shutdown(how) + } else { + session.send_close_notify(); + let mut stream = rustls::Stream::new(session, reader); + let res = stream.flush(); + reader.get_ref().shutdown(how)?; + res + } + } + } + } + + /// Split the bi-directional stream into two owned read and write halves. + pub fn split(self) -> (ReadStream, WriteStream) { + match self { + Self::Tcp(stream) => { + let reader = stream.into_reader(); + let stream: Arc = reader.get_ref().0.clone(); + + (ReadStream::Tcp(reader), WriteStream::Tcp(stream)) + } + Self::Tls { stream, session } => { + let reader = stream.into_reader(); + let buffer_data = reader.buffer().to_owned(); + let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192); + let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192); + + // TODO would be nice to avoid the Arc here + let socket = Arc::try_unwrap(reader.into_inner().0).unwrap(); + + let (read_half, write_half) = + rustls_split::split(socket, session, read_buf_cfg, write_buf_cfg); + (ReadStream::Tls(read_half), WriteStream::Tls(write_half)) + } + } + } + + pub fn start_tls(self, mut session: rustls::ServerSession) -> io::Result { + match self { + Self::Tcp(mut stream) => { + session.complete_io(&mut stream)?; + assert!(!session.is_handshaking()); + Ok(Self::Tls { stream, session }) + } + Self::Tls { .. } => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "TLS is already started on this stream", + )), + } + } +} + +impl io::Read for BidiStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Self::Tcp(stream) => stream.read(buf), + Self::Tls { stream, session } => rustls::Stream::new(session, stream).read(buf), + } + } +} + +impl io::Write for BidiStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + Self::Tcp(stream) => stream.write(buf), + Self::Tls { stream, session } => rustls::Stream::new(session, stream).write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + Self::Tcp(stream) => stream.flush(), + Self::Tls { stream, session } => rustls::Stream::new(session, stream).flush(), + } + } +} diff --git a/zenith_utils/tests/cert.pem b/zenith_utils/tests/cert.pem new file mode 100644 index 0000000000..ecd3c17cec --- /dev/null +++ b/zenith_utils/tests/cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDbjCCAlagAwIBAgIUGHJukXa1bQathgBHC40+A18BsnYwDQYJKoZIhvcNAQEL +BQAwYzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcM +DVNhbiBGcmFuY2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxv +Y2FsaG9zdDAgFw0yMTA4MTMxODQyMjBaGA8yMTIxMDcyMDE4NDIyMFowYzELMAkG +A1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNhbiBGcmFu +Y2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxvY2FsaG9zdDCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOI9S+nh8ABMp5jpb7WWfAYr +tGJ4C7gi9IPTVIRxSSrt5KglEysrOiKlhan1Ut2e8CCudztdXtCvT8/goJWlmxpF +IQkErlCsOdGHeEJ0EZxoU1fMkBAQVf6Rb1JE9ladG2+D1e7yvxmMqfPVuU8lj+kN +nESP+I3ESNCtuqgtfcErxu3TuhSzV2slSi5lrYQCwERgCevl6LUNd2mEaYdS4mmJ +4RZqc2C4y7JO5wSDjga8GIBHJVo70HRVsvX7eE8r6tMP2HyGyonBitBKAc2QEQIv +cLCuMOTtTBlYcMvTmJEOHFKwIJXm0XmQfAWeKFfyK7493fB4Gu+8Dc1xC+IHaTEC +AwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IB +AQBjY+g3eF8m8lEWz+QgKp88MhTdtJTsEsSz0GAi58SnEkuyxVOHjKEyjGKJWTtT +ICgmEzC85uaS7VBdftoYNmsbvNewGiisDGQRWCjOGM7lTaA4FQPADguexMvXh/nO +9PQoTxtp7qwvGWO2mED6LWU6bjT3cL+XgrOwT9sticRTl6/BXV8wAmyxT0DkQ3nJ +zbRuTP/G2kE0bRK++67kK0ovopRkX6Dl6di1EFlkAnPBC2d8tdcNTXYhkxZk4O0q +GUolwiuWz/dtD3tZ2bx3vqzT7uIFHS4XP6Q3SRNWFTGhuvAc7DPvCZBqxy6odeyQ +VxBgJtq+pNjYYkeaSQVQ+UMU +-----END CERTIFICATE----- diff --git a/zenith_utils/tests/key.pem b/zenith_utils/tests/key.pem new file mode 100644 index 0000000000..8283be5f5f --- /dev/null +++ b/zenith_utils/tests/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA4j1L6eHwAEynmOlvtZZ8Biu0YngLuCL0g9NUhHFJKu3kqCUT +Kys6IqWFqfVS3Z7wIK53O11e0K9Pz+CglaWbGkUhCQSuUKw50Yd4QnQRnGhTV8yQ +EBBV/pFvUkT2Vp0bb4PV7vK/GYyp89W5TyWP6Q2cRI/4jcRI0K26qC19wSvG7dO6 +FLNXayVKLmWthALARGAJ6+XotQ13aYRph1LiaYnhFmpzYLjLsk7nBIOOBrwYgEcl +WjvQdFWy9ft4Tyvq0w/YfIbKicGK0EoBzZARAi9wsK4w5O1MGVhwy9OYkQ4cUrAg +lebReZB8BZ4oV/Irvj3d8Hga77wNzXEL4gdpMQIDAQABAoIBAQClKycO+zpinZQG +GPbLVa/6OVIaSZYUusBUtaaQgrxuMPusnlSeQZLR1JH/APGchvq8gWLe3k3ogPT9 +yPq0BhF0Xl+928L/dp1HkWWE7oQk8i1Wfiv27lY54iepoltN5KkxAsjfCC3oEz/I +mpINbFjiRmN90rYdmd2nLA6H1Z5ntZQm5AcTo3OJZlTVN9eH9TV8f0AQRQgUJsL9 +75agSmj7euqZOqvvwfpsYzaZEhzMSG2QIcS3WglInbHy8c6ikZSm36J36wgsatMz +CBZ6pMNtonRSKvAECQhBGEA73evtnGbLH0EY9KouN4KSHEHob89dGVeeXozksf9x +QUE1/yOhAoGBAP818f7vIH6Z3QwWgTMwQsPBW+wNOIbTZrbZaihnz2K9XMu39TV6 +DWQHMsOlvg2QURZGwqB3jFn4wqZHmt7XYwk553E60kIw4hDvgpkkqmXVwK3kZASQ +RRUax3hZ1gCWxpXlRZ1SvHNXjN9KEFwqQbR33XcxzC3TpSp0KYghT9jFAoGBAOLw +agejqSF+f/5W1QhEKlM+tSlluo2sn5kKVkM4nNezFukb3pu5oScFjoQQGsoaz5aU +kLlxW5h/aSxquhgcuo6I4Ux5dcgNm4QeonCCp+Qycn7tzyoJFL4odT9vYPQa5O9E +hD9aSqhBBD1IIOS2T3vcW6VxibKZx1CRMDdRz119AoGALflr1L8DHYteNLVBJRWG +kXkdtBJVooQmtr3Hz+uTgngWZWSIOc/45ZIeZPxQlmTvFpI8sWeX0wVrG0U+8vHe +F2Vk+hLcmavwrZhX8HqYb6vn/+tq0R+kMj8Wu+mDEawXrh0VQ1gKNsUIzZisBc5e +88G8FaLU41SDJniymqFVnvkCgYEA1ou/UfWRwg6b5tIkmKoI8aZJExgPpDzcrYyu +POLatLmlIUCt1b9K8V85evTWvtdWBd/yar8WfzeFMO69fGo8nOAfT3NMvJLQwblM +jN2Y6A4hXIpq3iyzpYsOPaiImn6KjQHTnSk5h5Pf9CeqoU8SGeEb629JZMYpPqvk +T4hSaOkCgYBPaf51oSAstqdj0vxrsFS3EN3D8Fk0xQWt9Ss3ZGFAlTaEq5xoIk4k +YfKVDv1S6/vlzbheIIzQ2lzVvG4AW+drQLsmEx5iMKvbNtFAur9kwUFU202Q2dki +ZQJ/JvjnPYFKxy+SVlLJ1h9RD9E3dgL/Ai7OUfbmX771vN0IQF7Z6Q== +-----END RSA PRIVATE KEY----- diff --git a/zenith_utils/tests/ssl_test.rs b/zenith_utils/tests/ssl_test.rs new file mode 100644 index 0000000000..3bc5ffb790 --- /dev/null +++ b/zenith_utils/tests/ssl_test.rs @@ -0,0 +1,224 @@ +use std::{ + collections::HashMap, + io::{Cursor, Read, Write}, + net::{TcpListener, TcpStream}, + sync::Arc, +}; + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use lazy_static::lazy_static; +use rustls::Session; + +use zenith_utils::postgres_backend::{AuthType, Handler, PostgresBackend}; + +fn make_tcp_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let client_stream = TcpStream::connect(addr).unwrap(); + let (server_stream, _) = listener.accept().unwrap(); + (server_stream, client_stream) +} + +lazy_static! { + static ref KEY: rustls::PrivateKey = { + let mut cursor = Cursor::new(include_bytes!("key.pem")); + rustls::internal::pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone() + }; + static ref CERT: rustls::Certificate = { + let mut cursor = Cursor::new(include_bytes!("cert.pem")); + rustls::internal::pemfile::certs(&mut cursor).unwrap()[0].clone() + }; +} + +#[test] +fn ssl() { + let (mut client_sock, server_sock) = make_tcp_pair(); + + const QUERY: &[u8] = b"hello world"; + + let client_jh = std::thread::spawn(move || { + // SSLRequest + client_sock.write_u32::(8).unwrap(); + client_sock.write_u32::(80877103).unwrap(); + + let ssl_response = client_sock.read_u8().unwrap(); + assert_eq!(b'Y', ssl_response); + + let mut cfg = rustls::ClientConfig::new(); + cfg.root_store.add(&CERT).unwrap(); + let client_config = Arc::new(cfg); + + let dns_name = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap(); + let mut session = rustls::ClientSession::new(&client_config, dns_name); + + session.complete_io(&mut client_sock).unwrap(); + assert!(!session.is_handshaking()); + + let mut stream = rustls::Stream::new(&mut session, &mut client_sock); + + // StartupMessage + stream.write_u32::(9).unwrap(); + stream.write_u32::(196608).unwrap(); + stream.write_u8(0).unwrap(); + stream.flush().unwrap(); + + // wait for ReadyForQuery + let mut msg_buf = Vec::new(); + loop { + let msg = stream.read_u8().unwrap(); + let size = stream.read_u32::().unwrap() - 4; + msg_buf.resize(size as usize, 0); + stream.read_exact(&mut msg_buf).unwrap(); + + if msg == b'Z' { + // ReadyForQuery + break; + } + } + + // Query + stream.write_u8(b'Q').unwrap(); + stream + .write_u32::(4u32 + QUERY.len() as u32) + .unwrap(); + stream.write_all(QUERY).unwrap(); + stream.flush().unwrap(); + + // ReadyForQuery + let msg = stream.read_u8().unwrap(); + assert_eq!(msg, b'Z'); + }); + + struct TestHandler { + got_query: bool, + } + impl Handler for TestHandler { + fn process_query( + &mut self, + _pgb: &mut PostgresBackend, + query_string: bytes::Bytes, + ) -> anyhow::Result<()> { + self.got_query = query_string.as_ref() == QUERY; + Ok(()) + } + } + let mut handler = TestHandler { got_query: false }; + + let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + cfg.set_single_cert(vec![CERT.clone()], KEY.clone()) + .unwrap(); + let tls_config = Some(Arc::new(cfg)); + + let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config).unwrap(); + pgb.run(&mut handler).unwrap(); + assert!(handler.got_query); + + client_jh.join().unwrap(); + + // TODO consider shutdown behavior +} + +#[test] +fn no_ssl() { + let (mut client_sock, server_sock) = make_tcp_pair(); + + let client_jh = std::thread::spawn(move || { + let mut buf = BytesMut::new(); + + // SSLRequest + buf.put_u32(8); + buf.put_u32(80877103); + client_sock.write_all(&buf).unwrap(); + buf.clear(); + + let ssl_response = client_sock.read_u8().unwrap(); + assert_eq!(b'N', ssl_response); + }); + + struct TestHandler; + + impl Handler for TestHandler { + fn process_query( + &mut self, + _pgb: &mut PostgresBackend, + _query_string: bytes::Bytes, + ) -> anyhow::Result<()> { + panic!() + } + } + + let mut handler = TestHandler; + + let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None).unwrap(); + pgb.run(&mut handler).unwrap(); + + client_jh.join().unwrap(); +} + +#[test] +fn server_forces_ssl() { + let (mut client_sock, server_sock) = make_tcp_pair(); + + let client_jh = std::thread::spawn(move || { + // StartupMessage + client_sock.write_u32::(9).unwrap(); + client_sock.write_u32::(196608).unwrap(); + client_sock.write_u8(0).unwrap(); + client_sock.flush().unwrap(); + + // ErrorResponse + assert_eq!(client_sock.read_u8().unwrap(), b'E'); + let len = client_sock.read_u32::().unwrap() - 4; + + let mut body = vec![0; len as usize]; + client_sock.read_exact(&mut body).unwrap(); + let mut body = Bytes::from(body); + + let mut errors = HashMap::new(); + loop { + let field_type = body.get_u8(); + if field_type == 0u8 { + break; + } + + let end_idx = body.iter().position(|&b| b == 0u8).unwrap(); + let mut value = body.split_to(end_idx + 1); + assert_eq!(value[end_idx], 0u8); + value.truncate(end_idx); + let old = errors.insert(field_type, value); + assert!(old.is_none()); + } + + assert!(!body.has_remaining()); + + assert_eq!("must connect with TLS", errors.get(&b'M').unwrap()); + + // TODO read failure + }); + + struct TestHandler; + impl Handler for TestHandler { + fn process_query( + &mut self, + _pgb: &mut PostgresBackend, + _query_string: bytes::Bytes, + ) -> anyhow::Result<()> { + panic!() + } + } + let mut handler = TestHandler; + + let mut cfg = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + cfg.set_single_cert(vec![CERT.clone()], KEY.clone()) + .unwrap(); + let tls_config = Some(Arc::new(cfg)); + + let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config).unwrap(); + let res = pgb.run(&mut handler).unwrap_err(); + assert_eq!("client did not connect with TLS", format!("{}", res)); + + client_jh.join().unwrap(); + + // TODO consider shutdown behavior +}