Simplify SNI parsing

This commit is contained in:
Dmitry Ivanov
2023-04-05 13:03:16 +03:00
parent aba8cec279
commit 3f8751191b
4 changed files with 19 additions and 77 deletions

View File

@@ -18,13 +18,6 @@ pub enum ClientCredsParseError {
)]
InconsistentProjectNames { domain: String, option: String },
#[error(
"SNI ('{}') inconsistently formatted with respect to common name ('{}'). \
SNI should be formatted as '<project-name>.{}'.",
.sni, .cn, .cn,
)]
InconsistentSni { sni: String, cn: String },
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(String),
}
@@ -51,7 +44,6 @@ impl<'a> ClientCredentials<'a> {
pub fn parse(
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
@@ -67,18 +59,10 @@ impl<'a> ClientCredentials<'a> {
});
// Alternative project name is in fact a subdomain from SNI.
// NOTE: we do not consider SNI if `common_name` is missing.
let project_domain = sni
.zip(common_name)
.map(|(sni, cn)| {
subdomain_from_sni(sni, cn)
.ok_or_else(|| InconsistentSni {
sni: sni.into(),
cn: cn.into(),
})
.map(Cow::<'static, str>::Owned)
})
.transpose()?;
let project_domain = sni.and_then(|sni| {
let (domain, _) = sni.split_once('.')?;
Some(Cow::from(domain.to_owned()))
});
let project = match (project_option, project_domain) {
// Invariant: if we have both project name variants, they should match.
@@ -106,12 +90,6 @@ fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
sni.strip_suffix(common_name)?
.strip_suffix('.')
.map(str::to_owned)
}
#[cfg(test)]
mod tests {
use super::*;
@@ -122,7 +100,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -137,7 +115,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -149,9 +127,8 @@ mod tests {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -165,7 +142,7 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let creds = ClientCredentials::parse(&options, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -177,9 +154,8 @@ mod tests {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&options, sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -192,9 +168,8 @@ mod tests {
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err = ClientCredentials::parse(&options, sni).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -203,21 +178,4 @@ mod tests {
_ => panic!("bad error: {err:?}"),
}
}
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_name = Some("example.com");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");
assert_eq!(cn, "example.com");
}
_ => panic!("bad error: {err:?}"),
}
}
}

View File

@@ -16,7 +16,6 @@ pub struct MetricCollectionConfig {
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_name: Option<String>,
}
impl TlsConfig {
@@ -27,19 +26,16 @@ impl TlsConfig {
impl TlsConfig {
pub fn new(resolver: certs::CertResolver) -> anyhow::Result<Self> {
let resolver = Arc::new(resolver);
let rustls_config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(resolver.clone());
.with_cert_resolver(Arc::new(resolver));
let config = TlsConfig {
config: Arc::new(rustls_config),
common_name: None,
};
Ok(config)

View File

@@ -112,7 +112,6 @@ pub async fn handle_ws_client(
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
@@ -124,11 +123,10 @@ pub async fn handle_ws_client(
// Extract credentials which we're going to use for auth.
let creds = {
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.map(|_| auth::ClientCredentials::parse(&params, hostname))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
@@ -163,11 +161,10 @@ async fn handle_client(
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, sni))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?

View File

@@ -41,10 +41,7 @@ impl ClientConfig<'_> {
}
/// Generate TLS certificates and build rustls configs for client and server.
fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
fn generate_tls_config(hostname: &str) -> anyhow::Result<(ClientConfig<'_>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname)?;
let tls_config = {
@@ -54,10 +51,7 @@ fn generate_tls_config<'a>(
.with_single_cert(vec![cert], key)?
.into();
TlsConfig {
config,
common_name: Some(common_name.to_string()),
}
TlsConfig { config }
};
let client_config = {
@@ -150,7 +144,7 @@ async fn dummy_proxy(
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let (_, server_config) = generate_tls_config("generic-project-name.localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let client_err = tokio_postgres::Config::new()
@@ -178,8 +172,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
@@ -237,8 +230,7 @@ async fn keepalive_is_inherited() -> anyhow::Result<()> {
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
@@ -260,8 +252,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let (client_config, server_config) = generate_tls_config("generic-project-name.localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),