From 469597fdb6ad766d7206e794663386cb97be932c Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Thu, 12 Sep 2024 15:44:26 +0100 Subject: [PATCH] TLS conn accept --- proxy/src/bin/pglb.rs | 152 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 18 deletions(-) diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index b38bc2b991..70ada56501 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -5,14 +5,19 @@ use std::{ time::Duration, }; -use anyhow::Context; +use anyhow::{anyhow, bail, Context, Result}; +use bytes::BytesMut; use indexmap::IndexMap; +use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL}; use quinn::{Connection, Endpoint}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use tokio::{ + io::AsyncReadExt, net::{TcpListener, TcpStream}, time::timeout, }; -use tracing::error; +use tokio_rustls::server::TlsStream; +use tracing::{error, warn}; type AuthConnId = usize; struct AuthConnState { @@ -28,12 +33,10 @@ struct AuthConn { static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { let _logging_guard = proxy::logging::init().await?; - let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse().unwrap()) - .await - .unwrap(); + let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?; let auth_connections = Arc::new(AuthConnState { conns: Mutex::new(IndexMap::new()), @@ -41,16 +44,16 @@ async fn main() -> anyhow::Result<()> { let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone())); - let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0")); + let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?; + + let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config)); quinn_handle.await.unwrap(); Ok(()) } -async fn endpoint_config(addr: SocketAddr) -> anyhow::Result { - use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; - +async fn endpoint_config(addr: SocketAddr) -> Result { let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]); params .distinguished_name @@ -113,8 +116,42 @@ async fn quinn_server(ep: Endpoint, state: Arc) { } } -async fn start_frontend(addr: &str) -> anyhow::Result { - let addr: SocketAddr = addr.parse()?; +fn frontent_tls_config(hostname: &str, common_name: &str) -> Result { + let ca = rcgen::Certificate::from_params({ + let mut params = rcgen::CertificateParams::default(); + params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params + })?; + + let cert = rcgen::Certificate::from_params({ + let mut params = rcgen::CertificateParams::new(vec![hostname.into()]); + params.distinguished_name = rcgen::DistinguishedName::new(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, common_name); + params + })?; + + let (cert, key) = ( + CertificateDer::from(cert.serialize_der_with_signer(&ca)?), + PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()), + ); + + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key.clone_key())? + .into(); + + let mut cert_resolver = CertResolver::new(); + cert_resolver.add_cert(key, vec![cert], true)?; + + Ok(TlsConfig { + config, + cert_resolver: Arc::new(cert_resolver), + }) +} + +async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result { let listener = TcpListener::bind(addr).await?; socket2::SockRef::from(&listener).set_keepalive(true)?; @@ -123,23 +160,102 @@ async fn start_frontend(addr: &str) -> anyhow::Result { loop { match listener.accept().await { Ok((socket, peer_addr)) => { - connections.spawn(handle_frontend_connection(socket, peer_addr)); + let tls = tls.clone(); + connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls)); + } + Err(e) => { + error!("connection accept error: {e}"); } - Err(e) => {} } } } -async fn handle_frontend_connection(socket: TcpStream, _peer_addr: SocketAddr) { - match socket.set_nodelay(true) { +async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) { + match stream.set_nodelay(true) { Ok(()) => {} Err(e) => { - error!("per-client task finished with an error: failed to set socket option: {e:#}"); + error!("socket option error: {e}"); return; } }; // TODO: HAProxy protocol? + + let tls_requested = match handle_ssl_request_message(&mut stream).await { + Ok(tls_requested) => tls_requested, + Err(e) => { + error!("check_for_ssl_request: {e}"); + return; + } + }; + + if tls_requested { + let (stream, ep, sn) = match tls_upgrade(stream, tls).await { + Ok((stream, ep, sn)) => (stream, ep, sn), + Err(e) => { + error!("tls_upgrade: {e}"); + return; + } + }; + + // TODO: send auth msg with tls ep and server name + } else { + // TODO: send auth msg without server name + } } -// TODO: client state machine +async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result { + let mut buf = BytesMut::with_capacity(8); + + let n_peek = stream.peek(&mut buf).await?; + if n_peek == 0 { + bail!("EOF"); + } + + assert_eq!(buf.len(), 8); // TODO: loop, read more + + if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() { + return Ok(false); + } + + buf.clear(); + let n_read = stream.read(&mut buf).await?; + + assert_eq!(n_peek, n_read); // TODO: loop, read more + + Ok(true) +} + +async fn tls_upgrade( + stream: TcpStream, + tls: TlsConfig, +) -> Result<(TlsStream, TlsServerEndPoint, Option)> { + let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config) + .accept(stream) + .await?; + + let conn_info = tls_stream.get_ref().1; + let server_name = conn_info.server_name().map(|s| s.to_string()); + + match conn_info.alpn_protocol() { + None | Some(PG_ALPN_PROTOCOL) => {} + Some(other) => { + let alpn = String::from_utf8_lossy(other); + warn!(%alpn, "unexpected ALPN"); + bail!("protocol violation"); + } + } + + let (_, tls_server_end_point) = tls + .cert_resolver + .resolve(server_name.as_deref()) + .ok_or(anyhow!("missing cert"))?; + + Ok((tls_stream, tls_server_end_point, server_name)) +} + +#[derive(Clone, Debug)] +struct TlsConfig { + config: Arc, + cert_resolver: Arc, +}