add auth proxy client connection handling

This commit is contained in:
Conrad Ludgate
2024-09-12 12:44:48 +01:00
parent b71bf47c33
commit fbc37acfdf
2 changed files with 120 additions and 10 deletions

View File

@@ -1,2 +1,105 @@
use std::{sync::Arc, time::Duration};
use quinn::{
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
VarInt,
};
use tokio::{
io::AsyncWriteExt,
select,
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::task::TaskTracker;
#[tokio::main]
async fn main() {}
async fn main() {
let server = "127.0.0.1:5634".parse().unwrap();
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
let crypto = quinn::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify))
.with_no_client_auth();
let crypto = QuicClientConfig::try_from(crypto).unwrap();
let config = quinn::ClientConfig::new(Arc::new(crypto));
endpoint.set_default_client_config(config);
let mut int = signal(SignalKind::interrupt()).unwrap();
let mut term = signal(SignalKind::terminate()).unwrap();
let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap();
let mut interval = interval(Duration::from_secs(2));
let tasks = TaskTracker::new();
loop {
select! {
_ = int.recv() => break,
_ = term.recv() => break,
_ = interval.tick() => {
let mut stream = conn.open_uni().await.unwrap();
stream.flush().await.unwrap();
stream.finish().unwrap();
}
stream = conn.accept_bi() => {
let (send, recv) = stream.unwrap();
tasks.spawn(handle_stream(send, recv));
}
}
}
// graceful shutdown
{
let mut stream = conn.open_uni().await.unwrap();
stream.write_all(b"shutdown").await.unwrap();
stream.flush().await.unwrap();
stream.finish().unwrap();
}
tasks.close();
tasks.wait().await;
conn.close(VarInt::from_u32(1), b"graceful shutdown");
}
#[derive(Copy, Clone, Debug)]
struct NoVerify;
impl danger::ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<danger::ServerCertVerified, quinn::rustls::Error> {
Ok(danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &quinn::rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
Ok(danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &quinn::rustls::DigitallySignedStruct,
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
Ok(danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
}
}
async fn handle_stream(_send: SendStream, _recv: RecvStream) {}

View File

@@ -39,10 +39,7 @@ async fn main() -> anyhow::Result<()> {
conns: Mutex::new(IndexMap::new()),
});
let quinn_handle = tokio::spawn(quinn_server(
auth_endpoint.clone(),
auth_connections.clone(),
));
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0"));
@@ -52,7 +49,7 @@ async fn main() -> anyhow::Result<()> {
}
async fn endpoint_config(addr: SocketAddr) -> anyhow::Result<Endpoint> {
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
params
@@ -87,21 +84,31 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
.unwrap()
.insert(conn_id, AuthConn { conn: conn.clone() });
// heartbeat loop
loop {
match timeout(Duration::from_secs(1), conn.accept_uni()).await {
Ok(Ok(_)) => {}
match timeout(Duration::from_secs(10), conn.accept_uni()).await {
Ok(Ok(mut heartbeat_stream)) => {
let data = heartbeat_stream.read_to_end(128).await.unwrap();
if data.starts_with(b"shutdown") {
println!("[{conn_id:?}] conn shutdown");
break;
}
// else update latency info
}
Ok(Err(conn_err)) => {
println!("[{conn_id:?}] conn err {conn_err:?}");
state.conns.lock().unwrap().remove(&conn_id);
break;
}
Err(_) => {
println!("[{conn_id:?}] conn timeout err");
state.conns.lock().unwrap().remove(&conn_id);
break;
}
}
}
state.conns.lock().unwrap().remove(&conn_id);
let conn_closed = conn.closed().await;
println!("[{conn_id:?}] conn closed {conn_closed:?}");
});
}
}