update ssl handling and add some logs

This commit is contained in:
Conrad Ludgate
2024-09-13 09:25:11 +01:00
parent d698a50984
commit 24c48856a2
2 changed files with 56 additions and 23 deletions

View File

@@ -9,6 +9,8 @@ use anyhow::{anyhow, bail, Context, Result};
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use itertools::Itertools;
use pq_proto::BeMessage;
use proxy::{
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
@@ -17,7 +19,7 @@ use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::{join, AsyncReadExt, Join},
io::{join, AsyncReadExt, AsyncWriteExt, Join},
net::{TcpListener, TcpStream},
select,
time::timeout,
@@ -54,7 +56,7 @@ async fn main() -> Result<()> {
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?;
let frontend_config = frontent_tls_config("*.localtest.me", "*.localtest.me")?;
let _frontend_handle = tokio::spawn(start_frontend(
"0.0.0.0:5432".parse()?,
@@ -131,24 +133,37 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
}
fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
// let ca = rcgen::Certificate::from_params({
// let mut params = rcgen::CertificateParams::default();
// params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
// params
// })?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
// let cert = rcgen::Certificate::from_params({
// let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
// params.distinguished_name = rcgen::DistinguishedName::new();
// params
// .distinguished_name
// .push(rcgen::DnType::CommonName, common_name);
// params
// })?;
// let (cert, key) = (
// CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
// PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
// );
let (cert, key) = (
CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
.collect_vec()
.remove(0)
.unwrap(),
PrivateKeyDer::Pkcs8(
rustls_pemfile::pkcs8_private_keys(&mut &*std::fs::read("proxy.key").unwrap())
.collect_vec()
.remove(0)
.unwrap(),
),
);
let config = rustls::ServerConfig::builder()
@@ -173,11 +188,14 @@ async fn start_frontend(
let listener = TcpListener::bind(addr).await?;
socket2::SockRef::from(&listener).set_keepalive(true)?;
println!("starting");
let connections = tokio_util::task::task_tracker::TaskTracker::new();
loop {
match listener.accept().await {
Ok((socket, client_addr)) => {
println!("accepted");
let conn = PglbConn::new(&state, &tls)?;
connections.spawn(conn.handle(socket, client_addr));
}
@@ -269,6 +287,12 @@ impl PglbConn<Start> {
};
let (client_stream, payload) = if tls_requested {
println!("starting tls upgrade");
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::EncryptionResponse(true)).unwrap();
client_stream.write_all(&buf).await?;
let (stream, tls_server_end_point, server_name) =
match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
@@ -290,6 +314,8 @@ impl PglbConn<Start> {
bail!("closing non-TLS connection");
};
println!("tls done");
Ok(PglbConn {
inner: self.inner,
state: ClientConnect {
@@ -300,7 +326,8 @@ impl PglbConn<Start> {
}
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
let mut buf = BytesMut::with_capacity(8);
println!("checking for ssl request");
let mut buf = vec![0u8; 8];
let n_peek = stream.peek(&mut buf).await?;
if n_peek == 0 {
@@ -315,11 +342,7 @@ impl PglbConn<Start> {
{
return Ok(false);
}
buf.clear();
let n_read = stream.read(&mut buf).await?;
assert_eq!(n_peek, n_read); // TODO: loop, read more
stream.read_exact(&mut buf).await?;
Ok(true)
}
@@ -376,6 +399,7 @@ impl PglbConn<ClientConnect> {
// TODO: check closed?
};
println!("connecting to {}", auth_conn.conn.stable_id());
let (send, recv) = auth_conn.conn.open_bi().await?;
let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
@@ -431,6 +455,7 @@ impl PglbConn<AuthPassthrough> {
bail!("auth proxy sent unexpected message");
}
PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
println!("socket");
return Ok(PglbConn {
inner: self.inner,
state: ComputeConnect {

View File

@@ -477,18 +477,26 @@ pub async fn handle_stream(
panic!("invalid first msg")
};
println!("new conn: {conn_info:?}");
// read startup packet
let startup = stream.read_startup_packet().await?;
let FeStartupPacket::StartupMessage { version: _, params } = startup else {
panic!("invalid startup message")
};
println!("params: {params:?}");
let user_info = auth_with_user(&mut stream, config, &conn_info, &params).await?;
println!("authenticated");
// wake the compute
let node_info = user_info.wake_compute(&RequestMonitoring::test()).await?;
let socket: SocketAddr = node_info.config.get_host()?.parse()?;
println!("woke compute");
// tell pglb that the compute is up
stream
.write_message(&pq_proto::BeMessage::AuthenticationOk)