mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 23:20:40 +00:00
Merge branch 'cloneable/pglb-tls' into pglb
This commit is contained in:
@@ -5,14 +5,19 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use indexmap::IndexMap;
|
||||
use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL};
|
||||
use quinn::{Connection, Endpoint};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use tokio::{
|
||||
io::AsyncReadExt,
|
||||
net::{TcpListener, TcpStream},
|
||||
time::timeout,
|
||||
};
|
||||
use tracing::error;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tracing::{error, warn};
|
||||
|
||||
type AuthConnId = usize;
|
||||
struct AuthConnState {
|
||||
@@ -28,12 +33,10 @@ struct AuthConn {
|
||||
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
async fn main() -> Result<()> {
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
|
||||
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
let auth_endpoint: Endpoint = endpoint_config("0.0.0.0:5634".parse()?).await?;
|
||||
|
||||
let auth_connections = Arc::new(AuthConnState {
|
||||
conns: Mutex::new(IndexMap::new()),
|
||||
@@ -41,16 +44,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend("127.0.0.1:0"));
|
||||
let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?;
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config));
|
||||
|
||||
quinn_handle.await.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn endpoint_config(addr: SocketAddr) -> anyhow::Result<Endpoint> {
|
||||
use quinn::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
|
||||
async fn endpoint_config(addr: SocketAddr) -> Result<Endpoint> {
|
||||
let mut params = rcgen::CertificateParams::new(vec!["pglb".to_string()]);
|
||||
params
|
||||
.distinguished_name
|
||||
@@ -113,8 +116,42 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_frontend(addr: &str) -> anyhow::Result<Infallible> {
|
||||
let addr: SocketAddr = addr.parse()?;
|
||||
fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
|
||||
params
|
||||
})?;
|
||||
|
||||
let cert = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
|
||||
params.distinguished_name = rcgen::DistinguishedName::new();
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, common_name);
|
||||
params
|
||||
})?;
|
||||
|
||||
let (cert, key) = (
|
||||
CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
|
||||
PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
|
||||
);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert.clone()], key.clone_key())?
|
||||
.into();
|
||||
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
cert_resolver: Arc::new(cert_resolver),
|
||||
})
|
||||
}
|
||||
|
||||
async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
@@ -123,23 +160,102 @@ async fn start_frontend(addr: &str) -> anyhow::Result<Infallible> {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, peer_addr)) => {
|
||||
connections.spawn(handle_frontend_connection(socket, peer_addr));
|
||||
let tls = tls.clone();
|
||||
connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("connection accept error: {e}");
|
||||
}
|
||||
Err(e) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_frontend_connection(socket: TcpStream, _peer_addr: SocketAddr) {
|
||||
match socket.set_nodelay(true) {
|
||||
async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) {
|
||||
match stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: failed to set socket option: {e:#}");
|
||||
error!("socket option error: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: HAProxy protocol?
|
||||
|
||||
let tls_requested = match handle_ssl_request_message(&mut stream).await {
|
||||
Ok(tls_requested) => tls_requested,
|
||||
Err(e) => {
|
||||
error!("check_for_ssl_request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if tls_requested {
|
||||
let (stream, ep, sn) = match tls_upgrade(stream, tls).await {
|
||||
Ok((stream, ep, sn)) => (stream, ep, sn),
|
||||
Err(e) => {
|
||||
error!("tls_upgrade: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: send auth msg with tls ep and server name
|
||||
} else {
|
||||
// TODO: send auth msg without server name
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: client state machine
|
||||
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
|
||||
let mut buf = BytesMut::with_capacity(8);
|
||||
|
||||
let n_peek = stream.peek(&mut buf).await?;
|
||||
if n_peek == 0 {
|
||||
bail!("EOF");
|
||||
}
|
||||
|
||||
assert_eq!(buf.len(), 8); // TODO: loop, read more
|
||||
|
||||
if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
buf.clear();
|
||||
let n_read = stream.read(&mut buf).await?;
|
||||
|
||||
assert_eq!(n_peek, n_read); // TODO: loop, read more
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn tls_upgrade(
|
||||
stream: TcpStream,
|
||||
tls: TlsConfig,
|
||||
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
|
||||
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
|
||||
.accept(stream)
|
||||
.await?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
let server_name = conn_info.server_name().map(|s| s.to_string());
|
||||
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
bail!("protocol violation");
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(server_name.as_deref())
|
||||
.ok_or(anyhow!("missing cert"))?;
|
||||
|
||||
Ok((tls_stream, tls_server_end_point, server_name))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user