mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
add auth proxy client connection handling
This commit is contained in:
@@ -1,2 +1,105 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use quinn::{
|
||||
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
|
||||
VarInt,
|
||||
};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
select,
|
||||
signal::unix::{signal, SignalKind},
|
||||
time::interval,
|
||||
};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {}
|
||||
async fn main() {
|
||||
let server = "127.0.0.1:5634".parse().unwrap();
|
||||
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
|
||||
let crypto = quinn::rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerify))
|
||||
.with_no_client_auth();
|
||||
|
||||
let crypto = QuicClientConfig::try_from(crypto).unwrap();
|
||||
|
||||
let config = quinn::ClientConfig::new(Arc::new(crypto));
|
||||
endpoint.set_default_client_config(config);
|
||||
|
||||
let mut int = signal(SignalKind::interrupt()).unwrap();
|
||||
let mut term = signal(SignalKind::terminate()).unwrap();
|
||||
|
||||
let conn = endpoint.connect(server, "pglb").unwrap().await.unwrap();
|
||||
let mut interval = interval(Duration::from_secs(2));
|
||||
|
||||
let tasks = TaskTracker::new();
|
||||
|
||||
loop {
|
||||
select! {
|
||||
_ = int.recv() => break,
|
||||
_ = term.recv() => break,
|
||||
_ = interval.tick() => {
|
||||
let mut stream = conn.open_uni().await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
stream.finish().unwrap();
|
||||
}
|
||||
stream = conn.accept_bi() => {
|
||||
let (send, recv) = stream.unwrap();
|
||||
tasks.spawn(handle_stream(send, recv));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// graceful shutdown
|
||||
{
|
||||
let mut stream = conn.open_uni().await.unwrap();
|
||||
stream.write_all(b"shutdown").await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
stream.finish().unwrap();
|
||||
}
|
||||
|
||||
tasks.close();
|
||||
tasks.wait().await;
|
||||
conn.close(VarInt::from_u32(1), b"graceful shutdown");
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
struct NoVerify;
|
||||
|
||||
impl danger::ServerCertVerifier for NoVerify {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<danger::ServerCertVerified, quinn::rustls::Error> {
|
||||
Ok(danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &quinn::rustls::DigitallySignedStruct,
|
||||
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
|
||||
Ok(danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &quinn::rustls::DigitallySignedStruct,
|
||||
) -> Result<danger::HandshakeSignatureValid, quinn::rustls::Error> {
|
||||
Ok(danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
|
||||
vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256]
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream(_send: SendStream, _recv: RecvStream) {}
|
||||
|
||||
@@ -39,10 +39,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
conns: Mutex::new(IndexMap::new()),
|
||||
});
|
||||
|
||||
let quinn_handle = tokio::spawn(quinn_server(
|
||||
auth_endpoint.clone(),
|
||||
auth_connections.clone(),
|
||||
));
|
||||
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0"));
|
||||
|
||||
@@ -52,7 +49,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
async fn endpoint_config(addr: SocketAddr) -> anyhow::Result<Endpoint> {
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
|
||||
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
|
||||
params
|
||||
@@ -87,21 +84,31 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
|
||||
.unwrap()
|
||||
.insert(conn_id, AuthConn { conn: conn.clone() });
|
||||
|
||||
// heartbeat loop
|
||||
loop {
|
||||
match timeout(Duration::from_secs(1), conn.accept_uni()).await {
|
||||
Ok(Ok(_)) => {}
|
||||
match timeout(Duration::from_secs(10), conn.accept_uni()).await {
|
||||
Ok(Ok(mut heartbeat_stream)) => {
|
||||
let data = heartbeat_stream.read_to_end(128).await.unwrap();
|
||||
if data.starts_with(b"shutdown") {
|
||||
println!("[{conn_id:?}] conn shutdown");
|
||||
break;
|
||||
}
|
||||
// else update latency info
|
||||
}
|
||||
Ok(Err(conn_err)) => {
|
||||
println!("[{conn_id:?}] conn err {conn_err:?}");
|
||||
state.conns.lock().unwrap().remove(&conn_id);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
println!("[{conn_id:?}] conn timeout err");
|
||||
state.conns.lock().unwrap().remove(&conn_id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state.conns.lock().unwrap().remove(&conn_id);
|
||||
let conn_closed = conn.closed().await;
|
||||
println!("[{conn_id:?}] conn closed {conn_closed:?}");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user