mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 05:22:56 +00:00
Support extra domain names for proxy.
Make it possible to specify directory where proxy will look up for
extra certificates. Proxy will iterate through subdirs of that directory
and load `key.pem` and `cert.pem` files from each subdir. Certs directory
structure may look like that:
certs
|--example.com
| |--key.pem
| |--cert.pem
|--foo.bar
|--key.pem
|--cert.pem
Actual domain names are taken from certs and key, subdir names are
ignored.
This commit is contained in:
@@ -53,7 +53,7 @@ pub async fn password_hack(
|
||||
.await?;
|
||||
|
||||
info!(project = &payload.project, "received missing parameter");
|
||||
creds.project = Some(payload.project.into());
|
||||
creds.project = Some(payload.project);
|
||||
|
||||
let mut node = api.wake_compute(extra, creds).await?;
|
||||
node.config.password(payload.password);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashSet;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
@@ -19,11 +19,10 @@ pub enum ClientCredsParseError {
|
||||
InconsistentProjectNames { domain: String, option: String },
|
||||
|
||||
#[error(
|
||||
"SNI ('{}') inconsistently formatted with respect to common name ('{}'). \
|
||||
SNI should be formatted as '<project-name>.{}'.",
|
||||
.sni, .cn, .cn,
|
||||
"Common name inferred from SNI ('{}') is not known",
|
||||
.cn,
|
||||
)]
|
||||
InconsistentSni { sni: String, cn: String },
|
||||
UnknownCommonName { cn: String },
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(String),
|
||||
@@ -37,7 +36,7 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
// TODO: this is a severe misnomer! We should think of a new name ASAP.
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
pub project: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
@@ -51,7 +50,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
pub fn parse(
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
common_names: Option<HashSet<String>>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
@@ -60,37 +59,43 @@ impl<'a> ClientCredentials<'a> {
|
||||
let user = get_param("user")?;
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_option = params.options_raw().and_then(|mut options| {
|
||||
options
|
||||
.find_map(|opt| opt.strip_prefix("project="))
|
||||
.map(Cow::Borrowed)
|
||||
});
|
||||
let project_option = params
|
||||
.options_raw()
|
||||
.and_then(|mut options| options.find_map(|opt| opt.strip_prefix("project=")))
|
||||
.map(|name| name.to_string());
|
||||
|
||||
// Alternative project name is in fact a subdomain from SNI.
|
||||
// NOTE: we do not consider SNI if `common_name` is missing.
|
||||
let project_domain = sni
|
||||
.zip(common_name)
|
||||
.map(|(sni, cn)| {
|
||||
subdomain_from_sni(sni, cn)
|
||||
.ok_or_else(|| InconsistentSni {
|
||||
sni: sni.into(),
|
||||
cn: cn.into(),
|
||||
let project_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
|
||||
|
||||
let project = common_name_from_sni
|
||||
.and_then(|domain| {
|
||||
if cn.contains(domain) {
|
||||
subdomain_from_sni(sni_str, domain)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(Cow::<'static, str>::Owned)
|
||||
})
|
||||
.transpose()?;
|
||||
.ok_or_else(|| UnknownCommonName {
|
||||
cn: common_name_from_sni.unwrap_or("").into(),
|
||||
})?;
|
||||
|
||||
let project = match (project_option, project_domain) {
|
||||
Some(project)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let project = match (project_option, project_from_domain) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(option), Some(domain)) if option != domain => {
|
||||
Some(Err(InconsistentProjectNames {
|
||||
domain: domain.into(),
|
||||
option: option.into(),
|
||||
}))
|
||||
Some(Err(InconsistentProjectNames { domain, option }))
|
||||
}
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
|
||||
false => Err(MalformedProjectName(name.into())),
|
||||
false => Err(MalformedProjectName(name)),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
@@ -149,9 +154,9 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
|
||||
@@ -177,24 +182,41 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
|
||||
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multi_common_names() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_projects_different() {
|
||||
let options =
|
||||
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
|
||||
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -209,13 +231,12 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_name = Some("example.com");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err = ClientCredentials::parse(&options, sni, common_names).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentSni { sni, cn } => {
|
||||
assert_eq!(sni, "project.localhost");
|
||||
assert_eq!(cn, "example.com");
|
||||
UnknownCommonName { cn } => {
|
||||
assert_eq!(cn, "localhost");
|
||||
}
|
||||
_ => panic!("bad error: {err:?}"),
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
use crate::auth;
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::sign;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
@@ -16,7 +22,7 @@ pub struct MetricCollectionConfig {
|
||||
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_name: Option<String>,
|
||||
pub common_names: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
@@ -26,28 +32,33 @@ impl TlsConfig {
|
||||
}
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
pub fn configure_tls(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
certs_dir: Option<&String>,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
// add default certificate
|
||||
cert_resolver.add_cert(key_path, cert_path)?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
for entry in std::fs::read_dir(certs_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let key_path = path.join("key.pem");
|
||||
let cert_path = path.join("cert.pem");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver
|
||||
.add_cert(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
@@ -55,27 +66,105 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.with_cert_resolver(Arc::new(cert_resolver))
|
||||
.into();
|
||||
|
||||
// determine common name from tls-cert (-c server.crt param).
|
||||
// used in asserting project name formatting invariant.
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
};
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_name,
|
||||
common_names: Some(common_names),
|
||||
})
|
||||
}
|
||||
|
||||
struct CertResolver {
|
||||
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
certs: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_cert(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
}
|
||||
.context(format!(
|
||||
"Failed to parse common name from certificate at '{cert_path}'."
|
||||
))?;
|
||||
|
||||
self.certs.insert(
|
||||
common_name,
|
||||
Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
_client_hello: rustls::server::ClientHello,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = _client_hello.server_name() {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
}
|
||||
if let Some((_, rest)) = sni_name.split_once('.') {
|
||||
sni_name = rest;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
pub struct CacheOptions {
|
||||
/// Max number of entries.
|
||||
|
||||
@@ -132,7 +132,11 @@ fn build_config(args: &clap::ArgMatches) -> anyhow::Result<&'static ProxyConfig>
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?),
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
cert_path,
|
||||
args.get_one::<String>("certs-dir"),
|
||||
)?),
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
@@ -254,6 +258,12 @@ fn cli() -> clap::Command {
|
||||
.alias("ssl-cert") // backwards compatibility
|
||||
.help("path to TLS cert for client postgres connections"),
|
||||
)
|
||||
// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
.arg(
|
||||
Arg::new("certs-dir")
|
||||
.long("certs-dir")
|
||||
.help("path to directory with TLS certificates for client postgres connections"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("metric-collection-endpoint")
|
||||
.long("metric-collection-endpoint")
|
||||
|
||||
@@ -124,11 +124,11 @@ pub async fn handle_ws_client(
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
@@ -163,11 +163,11 @@ async fn handle_client(
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let sni = stream.get_ref().sni_hostname();
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_names))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
|
||||
@@ -54,9 +54,11 @@ fn generate_tls_config<'a>(
|
||||
.with_single_cert(vec![cert], key)?
|
||||
.into();
|
||||
|
||||
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
|
||||
|
||||
TlsConfig {
|
||||
config,
|
||||
common_name: Some(common_name.to_string()),
|
||||
common_names,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user