mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-29 19:10:38 +00:00
chore(proxy): pre-load native tls certificates and propagate compute client config (#10182)
Now that we construct the TLS client config for cancellation as well as connect, it feels appropriate to construct the same config once and re-use it elsewhere. It might also help should #7500 require any extra setup, so we can easily add it to all the appropriate call sites.
This commit is contained in:
42
proxy/src/tls/client_config.rs
Normal file
42
proxy/src/tls/client_config.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use rustls::crypto::ring;
|
||||
|
||||
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
|
||||
let der_certs = rustls_native_certs::load_native_certs();
|
||||
|
||||
if !der_certs.errors.is_empty() {
|
||||
bail!("could not parse certificates: {:?}", der_certs.errors);
|
||||
}
|
||||
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs.certs);
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
|
||||
/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute.
|
||||
/// This function is blocking.
|
||||
pub fn compute_client_config_with_root_certs() -> anyhow::Result<rustls::ClientConfig> {
|
||||
Ok(
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(load_certs()?)
|
||||
.with_no_client_auth(),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn compute_client_config_with_certs(
|
||||
certs: impl IntoIterator<Item = rustls::pki_types::CertificateDer<'static>>,
|
||||
) -> rustls::ClientConfig {
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(certs);
|
||||
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(store)
|
||||
.with_no_client_auth()
|
||||
}
|
||||
72
proxy/src/tls/mod.rs
Normal file
72
proxy/src/tls/mod.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
pub mod client_config;
|
||||
pub mod postgres_rustls;
|
||||
pub mod server_config;
|
||||
|
||||
use anyhow::Context;
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
|
||||
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
|
||||
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
156
proxy/src/tls/postgres_rustls.rs
Normal file
156
proxy/src/tls/postgres_rustls.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use postgres_client::tls::{ChannelBinding, TlsConnect};
|
||||
use rustls::pki_types::ServerName;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
pub struct TlsConnectFuture<S> {
|
||||
inner: tokio_rustls::Connect<S>,
|
||||
}
|
||||
|
||||
impl<S> Future for TlsConnectFuture<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = io::Result<RustlsStream<S>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsConnect(pub RustlsConnectData);
|
||||
|
||||
pub struct RustlsConnectData {
|
||||
pub hostname: ServerName<'static>,
|
||||
pub connector: TlsConnector,
|
||||
}
|
||||
|
||||
impl<S> TlsConnect<S> for RustlsConnect
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = RustlsStream<S>;
|
||||
type Error = io::Error;
|
||||
type Future = TlsConnectFuture<S>;
|
||||
|
||||
fn connect(self, stream: S) -> Self::Future {
|
||||
TlsConnectFuture {
|
||||
inner: self.0.connector.connect(self.0.hostname, stream),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn channel_binding(&self) -> ChannelBinding {
|
||||
let (_, session) = self.0.get_ref();
|
||||
match session.peer_certificates() {
|
||||
Some([cert, ..]) => TlsServerEndPoint::new(cert)
|
||||
.ok()
|
||||
.and_then(|cb| match cb {
|
||||
TlsServerEndPoint::Sha256(hash) => Some(hash),
|
||||
TlsServerEndPoint::Undefined => None,
|
||||
})
|
||||
.map_or_else(ChannelBinding::none, |hash| {
|
||||
ChannelBinding::tls_server_end_point(hash.to_vec())
|
||||
}),
|
||||
_ => ChannelBinding::none(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<tokio::io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<tokio::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
pub config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
#[must_use]
|
||||
pub fn new(config: Arc<ClientConfig>) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
218
proxy/src/tls/server_config.rs
Normal file
218
proxy/src/tls/server_config.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use itertools::Itertools;
|
||||
use rustls::crypto::ring::{self, sign};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
|
||||
use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL};
|
||||
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: HashSet<String>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
certs_dir: Option<&String>,
|
||||
allow_tls_keylogfile: bool,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
for entry in std::fs::read_dir(certs_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
// file names aligned with default cert-manager names
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
let mut config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.context("ring should support TLS1.2 and TLS1.3")?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone());
|
||||
|
||||
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
|
||||
|
||||
if allow_tls_keylogfile {
|
||||
// KeyLogFile will check for the SSLKEYLOGFILE environment variable.
|
||||
config.key_log = Arc::new(rustls::KeyLogFile::new());
|
||||
}
|
||||
|
||||
Ok(TlsConfig {
|
||||
config: Arc::new(config),
|
||||
common_names,
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
rustls_pemfile::private_key(&mut &key_bytes[..])
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(first_cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
}
|
||||
if let Some((_, rest)) = sni_name.split_once('.') {
|
||||
sni_name = rest;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No SNI, use the default certificate, otherwise we can't get to
|
||||
// options parameter which can be used to set endpoint name too.
|
||||
// That means that non-SNI flow will not work for CNAME domains in
|
||||
// verify-full mode.
|
||||
//
|
||||
// If that will be a problem we can:
|
||||
//
|
||||
// a) Instead of multi-cert approach use single cert with extra
|
||||
// domains listed in Subject Alternative Name (SAN).
|
||||
// b) Deploy separate proxy instances for extra domains.
|
||||
self.default.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user