From 24c48856a27399e1bc0b61efa5bd3bfc14ea573b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 13 Sep 2024 09:25:11 +0100 Subject: [PATCH] update ssl handling and add some logs --- proxy/src/bin/pglb.rs | 71 +++++++++++++++++++++++++++++-------------- proxy/src/proxy.rs | 8 +++++ 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index 634b576417..b7e4a02dd5 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -9,6 +9,8 @@ use anyhow::{anyhow, bail, Context, Result}; use bytes::{Buf, BufMut, BytesMut}; use futures::{SinkExt, StreamExt}; use indexmap::IndexMap; +use itertools::Itertools; +use pq_proto::BeMessage; use proxy::{ config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL}, ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage, @@ -17,7 +19,7 @@ use quinn::{Connection, Endpoint, RecvStream, SendStream}; use rand::Rng; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use tokio::{ - io::{join, AsyncReadExt, Join}, + io::{join, AsyncReadExt, AsyncWriteExt, Join}, net::{TcpListener, TcpStream}, select, time::timeout, @@ -54,7 +56,7 @@ async fn main() -> Result<()> { let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone())); - let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?; + let frontend_config = frontent_tls_config("*.localtest.me", "*.localtest.me")?; let _frontend_handle = tokio::spawn(start_frontend( "0.0.0.0:5432".parse()?, @@ -131,24 +133,37 @@ async fn quinn_server(ep: Endpoint, state: Arc) { } fn frontent_tls_config(hostname: &str, common_name: &str) -> Result { - let ca = rcgen::Certificate::from_params({ - let mut params = rcgen::CertificateParams::default(); - params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - params - })?; + // let ca = rcgen::Certificate::from_params({ + // let mut params = rcgen::CertificateParams::default(); + // params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + // params + // })?; - let cert = rcgen::Certificate::from_params({ - let mut params = rcgen::CertificateParams::new(vec![hostname.into()]); - params.distinguished_name = rcgen::DistinguishedName::new(); - params - .distinguished_name - .push(rcgen::DnType::CommonName, common_name); - params - })?; + // let cert = rcgen::Certificate::from_params({ + // let mut params = rcgen::CertificateParams::new(vec![hostname.into()]); + // params.distinguished_name = rcgen::DistinguishedName::new(); + // params + // .distinguished_name + // .push(rcgen::DnType::CommonName, common_name); + // params + // })?; + + // let (cert, key) = ( + // CertificateDer::from(cert.serialize_der_with_signer(&ca)?), + // PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()), + // ); let (cert, key) = ( - CertificateDer::from(cert.serialize_der_with_signer(&ca)?), - PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()), + rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap()) + .collect_vec() + .remove(0) + .unwrap(), + PrivateKeyDer::Pkcs8( + rustls_pemfile::pkcs8_private_keys(&mut &*std::fs::read("proxy.key").unwrap()) + .collect_vec() + .remove(0) + .unwrap(), + ), ); let config = rustls::ServerConfig::builder() @@ -173,11 +188,14 @@ async fn start_frontend( let listener = TcpListener::bind(addr).await?; socket2::SockRef::from(&listener).set_keepalive(true)?; + println!("starting"); + let connections = tokio_util::task::task_tracker::TaskTracker::new(); loop { match listener.accept().await { Ok((socket, client_addr)) => { + println!("accepted"); let conn = PglbConn::new(&state, &tls)?; connections.spawn(conn.handle(socket, client_addr)); } @@ -269,6 +287,12 @@ impl PglbConn { }; let (client_stream, payload) = if tls_requested { + println!("starting tls upgrade"); + + let mut buf = BytesMut::new(); + BeMessage::write(&mut buf, &BeMessage::EncryptionResponse(true)).unwrap(); + client_stream.write_all(&buf).await?; + let (stream, tls_server_end_point, server_name) = match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await { Ok((stream, ep, sn)) => (stream, ep, sn), @@ -290,6 +314,8 @@ impl PglbConn { bail!("closing non-TLS connection"); }; + println!("tls done"); + Ok(PglbConn { inner: self.inner, state: ClientConnect { @@ -300,7 +326,8 @@ impl PglbConn { } async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result { - let mut buf = BytesMut::with_capacity(8); + println!("checking for ssl request"); + let mut buf = vec![0u8; 8]; let n_peek = stream.peek(&mut buf).await?; if n_peek == 0 { @@ -315,11 +342,7 @@ impl PglbConn { { return Ok(false); } - - buf.clear(); - let n_read = stream.read(&mut buf).await?; - - assert_eq!(n_peek, n_read); // TODO: loop, read more + stream.read_exact(&mut buf).await?; Ok(true) } @@ -376,6 +399,7 @@ impl PglbConn { // TODO: check closed? }; + println!("connecting to {}", auth_conn.conn.stable_id()); let (send, recv) = auth_conn.conn.open_bi().await?; let mut auth_stream = Framed::new(join(recv, send), PglbCodec); @@ -431,6 +455,7 @@ impl PglbConn { bail!("auth proxy sent unexpected message"); } PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => { + println!("socket"); return Ok(PglbConn { inner: self.inner, state: ComputeConnect { diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 9b50686b51..146b4bf369 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -477,18 +477,26 @@ pub async fn handle_stream( panic!("invalid first msg") }; + println!("new conn: {conn_info:?}"); + // read startup packet let startup = stream.read_startup_packet().await?; let FeStartupPacket::StartupMessage { version: _, params } = startup else { panic!("invalid startup message") }; + println!("params: {params:?}"); + let user_info = auth_with_user(&mut stream, config, &conn_info, ¶ms).await?; + println!("authenticated"); + // wake the compute let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?; let socket: SocketAddr = node_info.config.get_host()?.parse()?; + println!("woke compute"); + // tell pglb that the compute is up stream .write_message(&pq_proto::BeMessage::AuthenticationOk)