Merge branch 'cloneable/pglb-tls' into pglb

This commit is contained in:
Folke Behrens
2024-09-12 15:47:21 +01:00

View File

@@ -5,14 +5,19 @@ use std::{
time::Duration,
};
use anyhow::Context;
use anyhow::{anyhow, bail, Context, Result};
use bytes::BytesMut;
use indexmap::IndexMap;
use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL};
use quinn::{Connection, Endpoint};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::AsyncReadExt,
net::{TcpListener, TcpStream},
time::timeout,
};
use tracing::error;
use tokio_rustls::server::TlsStream;
use tracing::{error, warn};
type AuthConnId = usize;
struct AuthConnState {
@@ -28,12 +33,10 @@ struct AuthConn {
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
async fn main() -> Result<()> {
let _logging_guard = proxy::logging::init().await?;
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse().unwrap())
.await
.unwrap();
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?;
let auth_connections = Arc::new(AuthConnState {
conns: Mutex::new(IndexMap::new()),
@@ -41,16 +44,16 @@ async fn main() -> anyhow::Result<()> {
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0"));
let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?;
let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config));
quinn_handle.await.unwrap();
Ok(())
}
async fn endpoint_config(addr: SocketAddr) -> anyhow::Result<Endpoint> {
use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
async fn endpoint_config(addr: SocketAddr) -> Result<Endpoint> {
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
params
.distinguished_name
@@ -113,8 +116,42 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
}
}
async fn start_frontend(addr: &str) -> anyhow::Result<Infallible> {
let addr: SocketAddr = addr.parse()?;
fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
let (cert, key) = (
CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
);
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?
.into();
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
Ok(TlsConfig {
config,
cert_resolver: Arc::new(cert_resolver),
})
}
async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible> {
let listener = TcpListener::bind(addr).await?;
socket2::SockRef::from(&listener).set_keepalive(true)?;
@@ -123,23 +160,102 @@ async fn start_frontend(addr: &str) -> anyhow::Result<Infallible> {
loop {
match listener.accept().await {
Ok((socket, peer_addr)) => {
connections.spawn(handle_frontend_connection(socket, peer_addr));
let tls = tls.clone();
connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls));
}
Err(e) => {
error!("connection accept error: {e}");
}
Err(e) => {}
}
}
}
async fn handle_frontend_connection(socket: TcpStream, _peer_addr: SocketAddr) {
match socket.set_nodelay(true) {
async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) {
match stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
error!("socket option error: {e}");
return;
}
};
// TODO: HAProxy protocol?
let tls_requested = match handle_ssl_request_message(&mut stream).await {
Ok(tls_requested) => tls_requested,
Err(e) => {
error!("check_for_ssl_request: {e}");
return;
}
};
if tls_requested {
let (stream, ep, sn) = match tls_upgrade(stream, tls).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
Err(e) => {
error!("tls_upgrade: {e}");
return;
}
};
// TODO: send auth msg with tls ep and server name
} else {
// TODO: send auth msg without server name
}
}
// TODO: client state machine
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
let mut buf = BytesMut::with_capacity(8);
let n_peek = stream.peek(&mut buf).await?;
if n_peek == 0 {
bail!("EOF");
}
assert_eq!(buf.len(), 8); // TODO: loop, read more
if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() {
return Ok(false);
}
buf.clear();
let n_read = stream.read(&mut buf).await?;
assert_eq!(n_peek, n_read); // TODO: loop, read more
Ok(true)
}
async fn tls_upgrade(
stream: TcpStream,
tls: TlsConfig,
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
.accept(stream)
.await?;
let conn_info = tls_stream.get_ref().1;
let server_name = conn_info.server_name().map(|s| s.to_string());
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
let alpn = String::from_utf8_lossy(other);
warn!(%alpn, "unexpected ALPN");
bail!("protocol violation");
}
}
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(server_name.as_deref())
.ok_or(anyhow!("missing cert"))?;
Ok((tls_stream, tls_server_end_point, server_name))
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
}