diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index f710581cb2..d45806461e 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -53,7 +53,7 @@ pub async fn password_hack( .await?; info!(project = &payload.project, "received missing parameter"); - creds.project = Some(payload.project.into()); + creds.project = Some(payload.project); let mut node = api.wake_compute(extra, creds).await?; node.config.password(payload.password); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index c556c33197..b21cd79ddf 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -2,7 +2,7 @@ use crate::error::UserFacingError; use pq_proto::StartupMessageParams; -use std::borrow::Cow; +use std::collections::HashSet; use thiserror::Error; use tracing::info; @@ -19,11 +19,10 @@ pub enum ClientCredsParseError { InconsistentProjectNames { domain: String, option: String }, #[error( - "SNI ('{}') inconsistently formatted with respect to common name ('{}'). \ - SNI should be formatted as '.{}'.", - .sni, .cn, .cn, + "Common name inferred from SNI ('{}') is not known", + .cn, )] - InconsistentSni { sni: String, cn: String }, + UnknownCommonName { cn: String }, #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] MalformedProjectName(String), @@ -37,7 +36,7 @@ impl UserFacingError for ClientCredsParseError {} pub struct ClientCredentials<'a> { pub user: &'a str, // TODO: this is a severe misnomer! We should think of a new name ASAP. - pub project: Option>, + pub project: Option, } impl ClientCredentials<'_> { @@ -51,7 +50,7 @@ impl<'a> ClientCredentials<'a> { pub fn parse( params: &'a StartupMessageParams, sni: Option<&str>, - common_name: Option<&str>, + common_names: Option>, ) -> Result { use ClientCredsParseError::*; @@ -60,37 +59,43 @@ impl<'a> ClientCredentials<'a> { let user = get_param("user")?; // Project name might be passed via PG's command-line options. - let project_option = params.options_raw().and_then(|mut options| { - options - .find_map(|opt| opt.strip_prefix("project=")) - .map(Cow::Borrowed) - }); + let project_option = params + .options_raw() + .and_then(|mut options| options.find_map(|opt| opt.strip_prefix("project="))) + .map(|name| name.to_string()); - // Alternative project name is in fact a subdomain from SNI. - // NOTE: we do not consider SNI if `common_name` is missing. - let project_domain = sni - .zip(common_name) - .map(|(sni, cn)| { - subdomain_from_sni(sni, cn) - .ok_or_else(|| InconsistentSni { - sni: sni.into(), - cn: cn.into(), + let project_from_domain = if let Some(sni_str) = sni { + if let Some(cn) = common_names { + let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain); + + let project = common_name_from_sni + .and_then(|domain| { + if cn.contains(domain) { + subdomain_from_sni(sni_str, domain) + } else { + None + } }) - .map(Cow::<'static, str>::Owned) - }) - .transpose()?; + .ok_or_else(|| UnknownCommonName { + cn: common_name_from_sni.unwrap_or("").into(), + })?; - let project = match (project_option, project_domain) { + Some(project) + } else { + None + } + } else { + None + }; + + let project = match (project_option, project_from_domain) { // Invariant: if we have both project name variants, they should match. (Some(option), Some(domain)) if option != domain => { - Some(Err(InconsistentProjectNames { - domain: domain.into(), - option: option.into(), - })) + Some(Err(InconsistentProjectNames { domain, option })) } // Invariant: project name may not contain certain characters. (a, b) => a.or(b).map(|name| match project_name_valid(&name) { - false => Err(MalformedProjectName(name.into())), + false => Err(MalformedProjectName(name)), true => Ok(name), }), } @@ -149,9 +154,9 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); let sni = Some("foo.localhost"); - let common_name = Some("localhost"); + let common_names = Some(["localhost".into()].into()); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_names)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -177,24 +182,41 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]); let sni = Some("baz.localhost"); - let common_name = Some("localhost"); + let common_names = Some(["localhost".into()].into()); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni, common_names)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("baz")); Ok(()) } + #[test] + fn parse_multi_common_names() -> anyhow::Result<()> { + let options = StartupMessageParams::new([("user", "john_doe")]); + + let common_names = Some(["a.com".into(), "b.com".into()].into()); + let sni = Some("p1.a.com"); + let creds = ClientCredentials::parse(&options, sni, common_names)?; + assert_eq!(creds.project.as_deref(), Some("p1")); + + let common_names = Some(["a.com".into(), "b.com".into()].into()); + let sni = Some("p1.b.com"); + let creds = ClientCredentials::parse(&options, sni, common_names)?; + assert_eq!(creds.project.as_deref(), Some("p1")); + + Ok(()) + } + #[test] fn parse_projects_different() { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]); let sni = Some("second.localhost"); - let common_name = Some("localhost"); + let common_names = Some(["localhost".into()].into()); - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -209,13 +231,12 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); let sni = Some("project.localhost"); - let common_name = Some("example.com"); + let common_names = Some(["example.com".into()].into()); - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail"); match err { - InconsistentSni { sni, cn } => { - assert_eq!(sni, "project.localhost"); - assert_eq!(cn, "example.com"); + UnknownCommonName { cn } => { + assert_eq!(cn, "localhost"); } _ => panic!("bad error: {err:?}"), } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 600db7f8ec..9f6241d733 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,6 +1,12 @@ use crate::auth; -use anyhow::{bail, ensure, Context}; -use std::{str::FromStr, sync::Arc, time::Duration}; +use anyhow::{bail, ensure, Context, Ok}; +use rustls::sign; +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, + sync::Arc, + time::Duration, +}; pub struct ProxyConfig { pub tls_config: Option, @@ -16,7 +22,7 @@ pub struct MetricCollectionConfig { pub struct TlsConfig { pub config: Arc, - pub common_name: Option, + pub common_names: Option>, } impl TlsConfig { @@ -26,28 +32,33 @@ impl TlsConfig { } /// Configure TLS for the main endpoint. -pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result { - let key = { - let key_bytes = std::fs::read(key_path).context("TLS key file")?; - let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) - .context(format!("Failed to read TLS keys at '{key_path}'"))?; +pub fn configure_tls( + key_path: &str, + cert_path: &str, + certs_dir: Option<&String>, +) -> anyhow::Result { + let mut cert_resolver = CertResolver::new(); - ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); - keys.pop().map(rustls::PrivateKey).unwrap() - }; + // add default certificate + cert_resolver.add_cert(key_path, cert_path)?; - let cert_chain_bytes = std::fs::read(cert_path) - .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + // add extra certificates + if let Some(certs_dir) = certs_dir { + for entry in std::fs::read_dir(certs_dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + let key_path = path.join("key.pem"); + let cert_path = path.join("cert.pem"); + if key_path.exists() && cert_path.exists() { + cert_resolver + .add_cert(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?; + } + } + } + } - let cert_chain = { - rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .context(format!( - "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." - ))? - .into_iter() - .map(rustls::Certificate) - .collect() - }; + let common_names = cert_resolver.get_common_names(); let config = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() @@ -55,27 +66,105 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result>, +} + +impl CertResolver { + fn new() -> Self { + Self { + certs: HashMap::new(), + } + } + + fn add_cert(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> { + let priv_key = { + let key_bytes = std::fs::read(key_path).context("TLS key file")?; + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) + .context(format!("Failed to read TLS keys at '{key_path}'"))?; + + ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); + keys.pop().map(rustls::PrivateKey).unwrap() + }; + + let key = sign::any_supported_type(&priv_key).context("invalid private key")?; + + let cert_chain_bytes = std::fs::read(cert_path) + .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + + let cert_chain = { + rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .context(format!( + "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." + ))? + .into_iter() + .map(rustls::Certificate) + .collect() + }; + + let common_name = { + let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes) + .context(format!( + "Failed to parse PEM object from bytes from file at '{cert_path}'." + ))? + .1; + let common_name = pem.parse_x509()?.subject().to_string(); + common_name.strip_prefix("CN=*.").map(|s| s.to_string()) + } + .context(format!( + "Failed to parse common name from certificate at '{cert_path}'." + ))?; + + self.certs.insert( + common_name, + Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)), + ); + + Ok(()) + } + + fn get_common_names(&self) -> HashSet { + self.certs.keys().map(|s| s.to_string()).collect() + } +} + +impl rustls::server::ResolvesServerCert for CertResolver { + fn resolve( + &self, + _client_hello: rustls::server::ClientHello, + ) -> Option> { + // loop here and cut off more and more subdomains until we find + // a match to get a proper wildcard support. OTOH, we now do not + // use nested domains, so keep this simple for now. + // + // With the current coding foo.com will match *.foo.com and that + // repeats behavior of the old code. + if let Some(mut sni_name) = _client_hello.server_name() { + loop { + if let Some(cert) = self.certs.get(sni_name) { + return Some(cert.clone()); + } + if let Some((_, rest)) = sni_name.split_once('.') { + sni_name = rest; + } else { + return None; + } + } + } else { + None + } + } +} + /// Helper for cmdline cache options parsing. pub struct CacheOptions { /// Max number of entries. diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 85478da3bc..c6526e9aff 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -132,7 +132,11 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig> args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { - (Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?), + (Some(key_path), Some(cert_path)) => Some(config::configure_tls( + key_path, + cert_path, + args.get_one::("certs-dir"), + )?), (None, None) => None, _ => bail!("either both or neither tls-key and tls-cert must be specified"), }; @@ -254,6 +258,12 @@ fn cli() -> clap::Command { .alias("ssl-cert") // backwards compatibility .help("path to TLS cert for client postgres connections"), ) + // tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir + .arg( + Arg::new("certs-dir") + .long("certs-dir") + .help("path to directory with TLS certificates for client postgres connections"), + ) .arg( Arg::new("metric-collection-endpoint") .long("metric-collection-endpoint") diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 03c9c72f30..70fb25474e 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -124,11 +124,11 @@ pub async fn handle_ws_client( // Extract credentials which we're going to use for auth. let creds = { - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); + let common_names = tls.and_then(|tls| tls.common_names.clone()); let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? @@ -163,11 +163,11 @@ async fn handle_client( // Extract credentials which we're going to use for auth. let creds = { let sni = stream.get_ref().sni_hostname(); - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); + let common_names = tls.and_then(|tls| tls.common_names.clone()); let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_names)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index ed429df421..60acb588dc 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -54,9 +54,11 @@ fn generate_tls_config<'a>( .with_single_cert(vec![cert], key)? .into(); + let common_names = Some([common_name.to_owned()].iter().cloned().collect()); + TlsConfig { config, - common_name: Some(common_name.to_string()), + common_names, } };