From dc41d108e8bc1d07acbdf2ba2394f4cdc05e6ad5 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 12 Sep 2024 11:45:20 +0100 Subject: [PATCH] add conn state with heartbeat system --- proxy/src/bin/pglb.rs | 60 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index b8e1c78814..0b79cb6a42 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -1,7 +1,22 @@ -use std::net::SocketAddr; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; use anyhow::Context; -use quinn::Endpoint; +use quinn::{Connection, Endpoint}; +use tokio::time::timeout; + +struct ConnState { + conns: Mutex>, +} + +struct Conn { + conn: Connection, + // latency info... +} #[tokio::main] async fn main() { @@ -9,7 +24,11 @@ async fn main() { .await .unwrap(); - let quinn_handle = tokio::spawn(quinn_server(endpoint.clone())); + let connections = Arc::new(ConnState { + conns: Mutex::new(HashMap::new()), + }); + + let quinn_handle = tokio::spawn(quinn_server(endpoint.clone(), connections.clone())); // tcp listener goes here @@ -36,6 +55,37 @@ async fn endpoint_config(addr: SocketAddr) -> anyhow::Result { Endpoint::server(config, addr).context("endpoint") } -async fn quinn_server(_ep: Endpoint) { - std::future::pending().await +async fn quinn_server(ep: Endpoint, state: Arc) { + loop { + let incoming = ep.accept().await.expect("quinn server should not crash"); + let state = state.clone(); + tokio::spawn(async move { + let conn = incoming.await.unwrap(); + + let conn_id = conn.stable_id(); + println!("[{conn_id:?}] new conn"); + + state + .conns + .lock() + .unwrap() + .insert(conn_id, Conn { conn: conn.clone() }); + + loop { + match timeout(Duration::from_secs(1), conn.accept_uni()).await { + Ok(Ok(_)) => {} + Ok(Err(conn_err)) => { + println!("[{conn_id:?}] conn err {conn_err:?}"); + state.conns.lock().unwrap().remove(&conn_id); + break; + } + Err(_) => { + println!("[{conn_id:?}] conn timeout err"); + state.conns.lock().unwrap().remove(&conn_id); + break; + } + } + } + }); + } }