mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: allow invalid SNI (#11792)
## Problem Some PrivateLink customers are unable to use Private DNS. As such they use an invalid domain name to address Neon. We currently are rejecting those connections because we cannot resolve the correct certificate. ## Summary of changes 1. Ensure a certificate is always returned. 2. If there is an SNI field, use endpoint fallback if it doesn't match. I suggest reviewing each commit separately.
This commit is contained in:
@@ -32,12 +32,6 @@ pub(crate) enum ComputeUserInfoParseError {
|
||||
option: EndpointId,
|
||||
},
|
||||
|
||||
#[error(
|
||||
"Common name inferred from SNI ('{}') is not known",
|
||||
.cn,
|
||||
)]
|
||||
UnknownCommonName { cn: String },
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(EndpointId),
|
||||
}
|
||||
@@ -66,22 +60,15 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn endpoint_sni(
|
||||
sni: &str,
|
||||
common_names: &HashSet<String>,
|
||||
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
|
||||
let Some((subdomain, common_name)) = sni.split_once('.') else {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
|
||||
};
|
||||
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
|
||||
let (subdomain, common_name) = sni.split_once('.')?;
|
||||
if !common_names.contains(common_name) {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName {
|
||||
cn: common_name.into(),
|
||||
});
|
||||
return None;
|
||||
}
|
||||
if subdomain == SERVERLESS_DRIVER_SNI {
|
||||
return Ok(None);
|
||||
return None;
|
||||
}
|
||||
Ok(Some(EndpointId::from(subdomain)))
|
||||
Some(EndpointId::from(subdomain))
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
@@ -113,15 +100,8 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
})
|
||||
.map(|name| name.into());
|
||||
|
||||
let endpoint_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
endpoint_sni(sni_str, cn)?
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let endpoint_from_domain =
|
||||
sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
|
||||
|
||||
let endpoint = match (endpoint_option, endpoint_from_domain) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
@@ -424,21 +404,34 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_inconsistent_sni() {
|
||||
fn parse_unknown_sni() {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
UnknownCommonName { cn } => {
|
||||
assert_eq!(cn, "localhost");
|
||||
}
|
||||
_ => panic!("bad error: {err:?}"),
|
||||
}
|
||||
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.unwrap();
|
||||
|
||||
assert!(info.endpoint_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unknown_sni_with_options() {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("options", "endpoint=foo-bar-baz-1234"),
|
||||
]);
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -24,9 +24,6 @@ pub(crate) enum HandshakeError {
|
||||
#[error("protocol violation")]
|
||||
ProtocolViolation,
|
||||
|
||||
#[error("missing certificate")]
|
||||
MissingCertificate,
|
||||
|
||||
#[error("{0}")]
|
||||
StreamUpgradeError(#[from] StreamUpgradeError),
|
||||
|
||||
@@ -42,10 +39,6 @@ impl ReportableError for HandshakeError {
|
||||
match self {
|
||||
HandshakeError::EarlyData => crate::error::ErrorKind::User,
|
||||
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
|
||||
// This error should not happen, but will if we have no default certificate and
|
||||
// the client sends no SNI extension.
|
||||
// If they provide SNI then we can be sure there is a certificate that matches.
|
||||
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
|
||||
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
|
||||
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
|
||||
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
@@ -146,7 +139,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// try parse endpoint
|
||||
let ep = conn_info
|
||||
.server_name()
|
||||
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
|
||||
.and_then(|sni| endpoint_sni(sni, &tls.common_names));
|
||||
if let Some(ep) = ep {
|
||||
ctx.set_endpoint_id(ep);
|
||||
}
|
||||
@@ -161,10 +154,8 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(conn_info.server_name())
|
||||
.ok_or(HandshakeError::MissingCertificate)?;
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
|
||||
@@ -98,8 +98,7 @@ fn generate_tls_config<'a>(
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert.clone()], key.clone_key())?;
|
||||
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
let cert_resolver = CertResolver::new(key, vec![cert])?;
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
|
||||
@@ -199,8 +199,7 @@ fn get_conn_info(
|
||||
let endpoint = match connection_url.host() {
|
||||
Some(url::Host::Domain(hostname)) => {
|
||||
if let Some(tls) = tls {
|
||||
endpoint_sni(hostname, &tls.common_names)?
|
||||
.ok_or(ConnInfoError::MalformedEndpoint)?
|
||||
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
|
||||
} else {
|
||||
hostname
|
||||
.split_once('.')
|
||||
|
||||
@@ -5,6 +5,7 @@ use anyhow::{Context, bail};
|
||||
use itertools::Itertools;
|
||||
use rustls::crypto::ring::{self, sign};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::sign::CertifiedKey;
|
||||
use x509_cert::der::{Reader, SliceReader};
|
||||
|
||||
use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
|
||||
@@ -25,10 +26,8 @@ pub fn configure_tls(
|
||||
certs_dir: Option<&String>,
|
||||
allow_tls_keylogfile: bool,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
@@ -40,11 +39,8 @@ pub fn configure_tls(
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
)?;
|
||||
cert_resolver
|
||||
.add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -83,92 +79,42 @@ pub fn configure_tls(
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result<Self> {
|
||||
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
|
||||
Self::new(priv_key, cert_chain)
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
rustls_pemfile::private_key(&mut &key_bytes[..])
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
};
|
||||
pub fn new(
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
) -> anyhow::Result<Self> {
|
||||
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
let mut certs = HashMap::new();
|
||||
let default = (cert.clone(), tls_server_end_point);
|
||||
certs.insert(common_name, (cert, tls_server_end_point));
|
||||
Ok(Self { certs, default })
|
||||
}
|
||||
|
||||
pub fn add_cert(
|
||||
fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> {
|
||||
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
|
||||
self.add_cert(priv_key, cert_chain)
|
||||
}
|
||||
|
||||
fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let certificate = SliceReader::new(first_cert)
|
||||
.context("Failed to parse cerficiate")?
|
||||
.decode::<x509_cert::Certificate>()
|
||||
.context("Failed to parse cerficiate")?;
|
||||
|
||||
let common_name = certificate.tbs_certificate.subject.to_string();
|
||||
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -177,12 +123,82 @@ impl CertResolver {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_key_cert(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
rustls_pemfile::private_key(&mut &key_bytes[..])
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
)
|
||||
})?
|
||||
};
|
||||
|
||||
Ok((priv_key, cert_chain))
|
||||
}
|
||||
|
||||
fn process_key_cert(
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let certificate = SliceReader::new(first_cert)
|
||||
.context("Failed to parse cerficiate")?
|
||||
.decode::<x509_cert::Certificate>()
|
||||
.context("Failed to parse cerficiate")?;
|
||||
|
||||
let common_name = certificate.tbs_certificate.subject.to_string();
|
||||
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
Ok((common_name, cert, tls_server_end_point))
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
Some(self.resolve(client_hello.server_name()).0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +206,7 @@ impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
@@ -200,12 +216,17 @@ impl CertResolver {
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
return cert.clone();
|
||||
}
|
||||
if let Some((_, rest)) = sni_name.split_once('.') {
|
||||
sni_name = rest;
|
||||
} else {
|
||||
return None;
|
||||
// The customer has some custom DNS mapping - just return
|
||||
// a default certificate.
|
||||
//
|
||||
// This will error if the customer uses anything stronger
|
||||
// than sslmode=require. That's a choice they can make.
|
||||
return self.default.clone();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user