From fbc37acfdff247f5deb2bb3c2d61b48da5114c4b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 12 Sep 2024 12:44:48 +0100 Subject: [PATCH] add auth proxy client connection handling --- proxy/src/bin/auth_proxy.rs | 105 +++++++++++++++++++++++++++++++++++- proxy/src/bin/pglb.rs | 25 +++++---- 2 files changed, 120 insertions(+), 10 deletions(-) diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 7f755fb76d..85e5c8abdc 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -1,2 +1,105 @@ +use std::{sync::Arc, time::Duration}; + +use quinn::{ + crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream, + VarInt, +}; +use tokio::{ + io::AsyncWriteExt, + select, + signal::unix::{signal, SignalKind}, + time::interval, +}; +use tokio_util::task::TaskTracker; + #[tokio::main] -async fn main() {} +async fn main() { + let server = "127.0.0.1:5634".parse().unwrap(); + let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap(); + + let crypto = quinn::rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoVerify)) + .with_no_client_auth(); + + let crypto = QuicClientConfig::try_from(crypto).unwrap(); + + let config = quinn::ClientConfig::new(Arc::new(crypto)); + endpoint.set_default_client_config(config); + + let mut int = signal(SignalKind::interrupt()).unwrap(); + let mut term = signal(SignalKind::terminate()).unwrap(); + + let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap(); + let mut interval = interval(Duration::from_secs(2)); + + let tasks = TaskTracker::new(); + + loop { + select! { + _ = int.recv() => break, + _ = term.recv() => break, + _ = interval.tick() => { + let mut stream = conn.open_uni().await.unwrap(); + stream.flush().await.unwrap(); + stream.finish().unwrap(); + } + stream = conn.accept_bi() => { + let (send, recv) = stream.unwrap(); + tasks.spawn(handle_stream(send, recv)); + } + } + } + + // graceful shutdown + { + let mut stream = conn.open_uni().await.unwrap(); + stream.write_all(b"shutdown").await.unwrap(); + stream.flush().await.unwrap(); + stream.finish().unwrap(); + } + + tasks.close(); + tasks.wait().await; + conn.close(VarInt::from_u32(1), b"graceful shutdown"); +} + +#[derive(Copy, Clone, Debug)] +struct NoVerify; + +impl danger::ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &quinn::rustls::DigitallySignedStruct, + ) -> Result { + Ok(danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &quinn::rustls::DigitallySignedStruct, + ) -> Result { + Ok(danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256] + } +} + +async fn handle_stream(_send: SendStream, _recv: RecvStream) {} diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index a5e56564b1..b38bc2b991 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -39,10 +39,7 @@ async fn main() -> anyhow::Result<()> { conns: Mutex::new(IndexMap::new()), }); - let quinn_handle = tokio::spawn(quinn_server( - auth_endpoint.clone(), - auth_connections.clone(), - )); + let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone())); let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0")); @@ -52,7 +49,7 @@ async fn main() -> anyhow::Result<()> { } async fn endpoint_config(addr: SocketAddr) -> anyhow::Result { - use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]); params @@ -87,21 +84,31 @@ async fn quinn_server(ep: Endpoint, state: Arc) { .unwrap() .insert(conn_id, AuthConn { conn: conn.clone() }); + // heartbeat loop loop { - match timeout(Duration::from_secs(1), conn.accept_uni()).await { - Ok(Ok(_)) => {} + match timeout(Duration::from_secs(10), conn.accept_uni()).await { + Ok(Ok(mut heartbeat_stream)) => { + let data = heartbeat_stream.read_to_end(128).await.unwrap(); + if data.starts_with(b"shutdown") { + println!("[{conn_id:?}] conn shutdown"); + break; + } + // else update latency info + } Ok(Err(conn_err)) => { println!("[{conn_id:?}] conn err {conn_err:?}"); - state.conns.lock().unwrap().remove(&conn_id); break; } Err(_) => { println!("[{conn_id:?}] conn timeout err"); - state.conns.lock().unwrap().remove(&conn_id); break; } } } + + state.conns.lock().unwrap().remove(&conn_id); + let conn_closed = conn.closed().await; + println!("[{conn_id:?}] conn closed {conn_closed:?}"); }); } }