diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index c556c33197..719175e79e 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -18,13 +18,6 @@ pub enum ClientCredsParseError { )] InconsistentProjectNames { domain: String, option: String }, - #[error( - "SNI ('{}') inconsistently formatted with respect to common name ('{}'). \ - SNI should be formatted as '.{}'.", - .sni, .cn, .cn, - )] - InconsistentSni { sni: String, cn: String }, - #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] MalformedProjectName(String), } @@ -51,7 +44,6 @@ impl<'a> ClientCredentials<'a> { pub fn parse( params: &'a StartupMessageParams, sni: Option<&str>, - common_name: Option<&str>, ) -> Result { use ClientCredsParseError::*; @@ -67,18 +59,10 @@ impl<'a> ClientCredentials<'a> { }); // Alternative project name is in fact a subdomain from SNI. - // NOTE: we do not consider SNI if `common_name` is missing. - let project_domain = sni - .zip(common_name) - .map(|(sni, cn)| { - subdomain_from_sni(sni, cn) - .ok_or_else(|| InconsistentSni { - sni: sni.into(), - cn: cn.into(), - }) - .map(Cow::<'static, str>::Owned) - }) - .transpose()?; + let project_domain = sni.and_then(|sni| { + let (domain, _) = sni.split_once('.')?; + Some(Cow::from(domain.to_owned())) + }); let project = match (project_option, project_domain) { // Invariant: if we have both project name variants, they should match. @@ -106,12 +90,6 @@ fn project_name_valid(name: &str) -> bool { name.chars().all(|c| c.is_alphanumeric() || c == '-') } -fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { - sni.strip_suffix(common_name)? - .strip_suffix('.') - .map(str::to_owned) -} - #[cfg(test)] mod tests { use super::*; @@ -122,7 +100,7 @@ mod tests { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -137,7 +115,7 @@ mod tests { ("foo", "bar"), // should be ignored ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -149,9 +127,8 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); let sni = Some("foo.localhost"); - let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -165,7 +142,7 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None)?; + let creds = ClientCredentials::parse(&options, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -177,9 +154,8 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]); let sni = Some("baz.localhost"); - let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name)?; + let creds = ClientCredentials::parse(&options, sni)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -192,9 +168,8 @@ mod tests { StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]); let sni = Some("second.localhost"); - let common_name = Some("localhost"); - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni).expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -203,21 +178,4 @@ mod tests { _ => panic!("bad error: {err:?}"), } } - - #[test] - fn parse_inconsistent_sni() { - let options = StartupMessageParams::new([("user", "john_doe")]); - - let sni = Some("project.localhost"); - let common_name = Some("example.com"); - - let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); - match err { - InconsistentSni { sni, cn } => { - assert_eq!(sni, "project.localhost"); - assert_eq!(cn, "example.com"); - } - _ => panic!("bad error: {err:?}"), - } - } } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 31b9480703..3dce4998de 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -16,7 +16,6 @@ pub struct MetricCollectionConfig { pub struct TlsConfig { pub config: Arc, - pub common_name: Option, } impl TlsConfig { @@ -27,19 +26,16 @@ impl TlsConfig { impl TlsConfig { pub fn new(resolver: certs::CertResolver) -> anyhow::Result { - let resolver = Arc::new(resolver); - let rustls_config = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() // allow TLS 1.2 to be compatible with older client libraries .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? .with_no_client_auth() - .with_cert_resolver(resolver.clone()); + .with_cert_resolver(Arc::new(resolver)); let config = TlsConfig { config: Arc::new(rustls_config), - common_name: None, }; Ok(config) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 03c9c72f30..f1358627af 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -112,7 +112,6 @@ pub async fn handle_ws_client( NUM_CONNECTIONS_CLOSED_COUNTER.inc(); } - let tls = config.tls_config.as_ref(); let hostname = hostname.as_deref(); // TLS is None here, because the connection is already encrypted. @@ -124,11 +123,10 @@ pub async fn handle_ws_client( // Extract credentials which we're going to use for auth. let creds = { - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, hostname)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? @@ -163,11 +161,10 @@ async fn handle_client( // Extract credentials which we're going to use for auth. let creds = { let sni = stream.get_ref().sni_hostname(); - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index ed429df421..1441f5c8b0 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -41,10 +41,7 @@ impl ClientConfig<'_> { } /// Generate TLS certificates and build rustls configs for client and server. -fn generate_tls_config<'a>( - hostname: &'a str, - common_name: &'a str, -) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> { +fn generate_tls_config(hostname: &str) -> anyhow::Result<(ClientConfig<'_>, TlsConfig)> { let (ca, cert, key) = generate_certs(hostname)?; let tls_config = { @@ -54,10 +51,7 @@ fn generate_tls_config<'a>( .with_single_cert(vec![cert], key)? .into(); - TlsConfig { - config, - common_name: Some(common_name.to_string()), - } + TlsConfig { config } }; let client_config = { @@ -150,7 +144,7 @@ async fn dummy_proxy( async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?; + let (_, server_config) = generate_tls_config("generic-project-name.localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let client_err = tokio_postgres::Config::new() @@ -178,8 +172,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { async fn handshake_tls() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = - generate_tls_config("generic-project-name.localhost", "localhost")?; + let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() @@ -237,8 +230,7 @@ async fn keepalive_is_inherited() -> anyhow::Result<()> { async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = - generate_tls_config("generic-project-name.localhost", "localhost")?; + let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?; let proxy = tokio::spawn(dummy_proxy( client, Some(server_config), @@ -260,8 +252,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { async fn scram_auth_mock() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = - generate_tls_config("generic-project-name.localhost", "localhost")?; + let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?; let proxy = tokio::spawn(dummy_proxy( client, Some(server_config),