diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index c55af325e3..183976374a 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -32,12 +32,6 @@ pub(crate) enum ComputeUserInfoParseError { option: EndpointId, }, - #[error( - "Common name inferred from SNI ('{}') is not known", - .cn, - )] - UnknownCommonName { cn: String }, - #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] MalformedProjectName(EndpointId), } @@ -66,22 +60,15 @@ impl ComputeUserInfoMaybeEndpoint { } } -pub(crate) fn endpoint_sni( - sni: &str, - common_names: &HashSet, -) -> Result, ComputeUserInfoParseError> { - let Some((subdomain, common_name)) = sni.split_once('.') else { - return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() }); - }; +pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet) -> Option { + let (subdomain, common_name) = sni.split_once('.')?; if !common_names.contains(common_name) { - return Err(ComputeUserInfoParseError::UnknownCommonName { - cn: common_name.into(), - }); + return None; } if subdomain == SERVERLESS_DRIVER_SNI { - return Ok(None); + return None; } - Ok(Some(EndpointId::from(subdomain))) + Some(EndpointId::from(subdomain)) } impl ComputeUserInfoMaybeEndpoint { @@ -113,15 +100,8 @@ impl ComputeUserInfoMaybeEndpoint { }) .map(|name| name.into()); - let endpoint_from_domain = if let Some(sni_str) = sni { - if let Some(cn) = common_names { - endpoint_sni(sni_str, cn)? - } else { - None - } - } else { - None - }; + let endpoint_from_domain = + sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn))); let endpoint = match (endpoint_option, endpoint_from_domain) { // Invariant: if we have both project name variants, they should match. @@ -424,21 +404,34 @@ mod tests { } #[test] - fn parse_inconsistent_sni() { + fn parse_unknown_sni() { let options = StartupMessageParams::new([("user", "john_doe")]); let sni = Some("project.localhost"); let common_names = Some(["example.com".into()].into()); let ctx = RequestContext::test(); - let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) - .expect_err("should fail"); - match err { - UnknownCommonName { cn } => { - assert_eq!(cn, "localhost"); - } - _ => panic!("bad error: {err:?}"), - } + let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) + .unwrap(); + + assert!(info.endpoint_id.is_none()); + } + + #[test] + fn parse_unknown_sni_with_options() { + let options = StartupMessageParams::new([ + ("user", "john_doe"), + ("options", "endpoint=foo-bar-baz-1234"), + ]); + + let sni = Some("project.localhost"); + let common_names = Some(["example.com".into()].into()); + + let ctx = RequestContext::test(); + let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) + .unwrap(); + + assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234")); } #[test] diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index c05031ad97..54c02f2c15 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -24,9 +24,6 @@ pub(crate) enum HandshakeError { #[error("protocol violation")] ProtocolViolation, - #[error("missing certificate")] - MissingCertificate, - #[error("{0}")] StreamUpgradeError(#[from] StreamUpgradeError), @@ -42,10 +39,6 @@ impl ReportableError for HandshakeError { match self { HandshakeError::EarlyData => crate::error::ErrorKind::User, HandshakeError::ProtocolViolation => crate::error::ErrorKind::User, - // This error should not happen, but will if we have no default certificate and - // the client sends no SNI extension. - // If they provide SNI then we can be sure there is a certificate that matches. - HandshakeError::MissingCertificate => crate::error::ErrorKind::Service, HandshakeError::StreamUpgradeError(upgrade) => match upgrade { StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service, StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect, @@ -146,7 +139,7 @@ pub(crate) async fn handshake( // try parse endpoint let ep = conn_info .server_name() - .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten()); + .and_then(|sni| endpoint_sni(sni, &tls.common_names)); if let Some(ep) = ep { ctx.set_endpoint_id(ep); } @@ -161,10 +154,8 @@ pub(crate) async fn handshake( } } - let (_, tls_server_end_point) = tls - .cert_resolver - .resolve(conn_info.server_name()) - .ok_or(HandshakeError::MissingCertificate)?; + let (_, tls_server_end_point) = + tls.cert_resolver.resolve(conn_info.server_name()); stream = PqStream { framed: Framed { diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 9a6864c33e..f47636cd71 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -98,8 +98,7 @@ fn generate_tls_config<'a>( .with_no_client_auth() .with_single_cert(vec![cert.clone()], key.clone_key())?; - let mut cert_resolver = CertResolver::new(); - cert_resolver.add_cert(key, vec![cert], true)?; + let cert_resolver = CertResolver::new(key, vec![cert])?; let common_names = cert_resolver.get_common_names(); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 7fb39553f9..fee5942b7e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -199,8 +199,7 @@ fn get_conn_info( let endpoint = match connection_url.host() { Some(url::Host::Domain(hostname)) => { if let Some(tls) = tls { - endpoint_sni(hostname, &tls.common_names)? - .ok_or(ConnInfoError::MalformedEndpoint)? + endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)? } else { hostname .split_once('.') diff --git a/proxy/src/tls/server_config.rs b/proxy/src/tls/server_config.rs index 5a95e69fde..8f8917ef62 100644 --- a/proxy/src/tls/server_config.rs +++ b/proxy/src/tls/server_config.rs @@ -5,6 +5,7 @@ use anyhow::{Context, bail}; use itertools::Itertools; use rustls::crypto::ring::{self, sign}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::sign::CertifiedKey; use x509_cert::der::{Reader, SliceReader}; use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint}; @@ -25,10 +26,8 @@ pub fn configure_tls( certs_dir: Option<&String>, allow_tls_keylogfile: bool, ) -> anyhow::Result { - let mut cert_resolver = CertResolver::new(); - // add default certificate - cert_resolver.add_cert_path(key_path, cert_path, true)?; + let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?; // add extra certificates if let Some(certs_dir) = certs_dir { @@ -40,11 +39,8 @@ pub fn configure_tls( let key_path = path.join("tls.key"); let cert_path = path.join("tls.crt"); if key_path.exists() && cert_path.exists() { - cert_resolver.add_cert_path( - &key_path.to_string_lossy(), - &cert_path.to_string_lossy(), - false, - )?; + cert_resolver + .add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?; } } } @@ -83,92 +79,42 @@ pub fn configure_tls( }) } -#[derive(Default, Debug)] +#[derive(Debug)] pub struct CertResolver { certs: HashMap, TlsServerEndPoint)>, - default: Option<(Arc, TlsServerEndPoint)>, + default: (Arc, TlsServerEndPoint), } impl CertResolver { - pub fn new() -> Self { - Self::default() + fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result { + let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?; + Self::new(priv_key, cert_chain) } - fn add_cert_path( - &mut self, - key_path: &str, - cert_path: &str, - is_default: bool, - ) -> anyhow::Result<()> { - let priv_key = { - let key_bytes = std::fs::read(key_path) - .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; - rustls_pemfile::private_key(&mut &key_bytes[..]) - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? - }; + pub fn new( + priv_key: PrivateKeyDer<'static>, + cert_chain: Vec>, + ) -> anyhow::Result { + let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?; - let cert_chain_bytes = std::fs::read(cert_path) - .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; - - let cert_chain = { - rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .try_collect() - .with_context(|| { - format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.") - })? - }; - - self.add_cert(priv_key, cert_chain, is_default) + let mut certs = HashMap::new(); + let default = (cert.clone(), tls_server_end_point); + certs.insert(common_name, (cert, tls_server_end_point)); + Ok(Self { certs, default }) } - pub fn add_cert( + fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> { + let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?; + self.add_cert(priv_key, cert_chain) + } + + fn add_cert( &mut self, priv_key: PrivateKeyDer<'static>, cert_chain: Vec>, - is_default: bool, ) -> anyhow::Result<()> { - let key = sign::any_supported_type(&priv_key).context("invalid private key")?; - - let first_cert = &cert_chain[0]; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - - let certificate = SliceReader::new(first_cert) - .context("Failed to parse cerficiate")? - .decode::() - .context("Failed to parse cerficiate")?; - - let common_name = certificate.tbs_certificate.subject.to_string(); - - // We need to get the canonical name for this certificate so we can match them against any domain names - // seen within the proxy codebase. - // - // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI. - // We need to remove the wildcard prefix for the purposes of certificate selection. - // - // auth-broker does not use SNI and instead uses the Neon-Connection-String header. - // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String. - // - // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string - // validation, so let's we can continue with any common-name - let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") { - s.to_string() - } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") { - s.to_string() - } else if let Some(s) = common_name.strip_prefix("CN=") { - s.to_string() - } else { - bail!("Failed to parse common name from certificate") - }; - - let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); - - if is_default { - self.default = Some((cert.clone(), tls_server_end_point)); - } - + let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?; self.certs.insert(common_name, (cert, tls_server_end_point)); - Ok(()) } @@ -177,12 +123,82 @@ impl CertResolver { } } +fn parse_key_cert( + key_path: &str, + cert_path: &str, +) -> anyhow::Result<(PrivateKeyDer<'static>, Vec>)> { + let priv_key = { + let key_bytes = std::fs::read(key_path) + .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; + rustls_pemfile::private_key(&mut &key_bytes[..]) + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + }; + + let cert_chain_bytes = std::fs::read(cert_path) + .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + + let cert_chain = { + rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .try_collect() + .with_context(|| { + format!( + "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." + ) + })? + }; + + Ok((priv_key, cert_chain)) +} + +fn process_key_cert( + priv_key: PrivateKeyDer<'static>, + cert_chain: Vec>, +) -> anyhow::Result<(String, Arc, TlsServerEndPoint)> { + let key = sign::any_supported_type(&priv_key).context("invalid private key")?; + + let first_cert = &cert_chain[0]; + let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; + + let certificate = SliceReader::new(first_cert) + .context("Failed to parse cerficiate")? + .decode::() + .context("Failed to parse cerficiate")?; + + let common_name = certificate.tbs_certificate.subject.to_string(); + + // We need to get the canonical name for this certificate so we can match them against any domain names + // seen within the proxy codebase. + // + // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI. + // We need to remove the wildcard prefix for the purposes of certificate selection. + // + // auth-broker does not use SNI and instead uses the Neon-Connection-String header. + // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String. + // + // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string + // validation, so let's we can continue with any common-name + let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") { + s.to_string() + } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") { + s.to_string() + } else if let Some(s) = common_name.strip_prefix("CN=") { + s.to_string() + } else { + bail!("Failed to parse common name from certificate") + }; + + let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); + + Ok((common_name, cert, tls_server_end_point)) +} + impl rustls::server::ResolvesServerCert for CertResolver { fn resolve( &self, client_hello: rustls::server::ClientHello<'_>, ) -> Option> { - self.resolve(client_hello.server_name()).map(|x| x.0) + Some(self.resolve(client_hello.server_name()).0) } } @@ -190,7 +206,7 @@ impl CertResolver { pub fn resolve( &self, server_name: Option<&str>, - ) -> Option<(Arc, TlsServerEndPoint)> { + ) -> (Arc, TlsServerEndPoint) { // loop here and cut off more and more subdomains until we find // a match to get a proper wildcard support. OTOH, we now do not // use nested domains, so keep this simple for now. @@ -200,12 +216,17 @@ impl CertResolver { if let Some(mut sni_name) = server_name { loop { if let Some(cert) = self.certs.get(sni_name) { - return Some(cert.clone()); + return cert.clone(); } if let Some((_, rest)) = sni_name.split_once('.') { sni_name = rest; } else { - return None; + // The customer has some custom DNS mapping - just return + // a default certificate. + // + // This will error if the customer uses anything stronger + // than sslmode=require. That's a choice they can make. + return self.default.clone(); } } } else {