diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 60d1962d7f..0992c6d875 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -10,7 +10,6 @@ use tracing::info; use super::backend::ComputeCredentialKeys; use super::{AuthError, PasswordHackPayload}; -use crate::config::TlsServerEndPoint; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; @@ -18,6 +17,7 @@ use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; +use crate::tls::TlsServerEndPoint; /// Every authentication selector is supposed to implement this trait. pub(crate) trait AuthMethod { diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 56bbd94850..644f670f88 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP}; use proxy::auth::{self}; use proxy::cancellation::CancellationHandlerMain; -use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}; +use proxy::config::{ + self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, +}; use proxy::control_plane::locks::ApiLocks; use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use proxy::http::health_server::AppMetrics; @@ -25,6 +27,7 @@ use proxy::rate_limiter::{ use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::{self, GlobalConnPoolOptions}; +use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::types::RoleName; use proxy::url::ApiUrl; @@ -209,6 +212,7 @@ async fn main() -> anyhow::Result<()> { http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( + &config.connect_to_compute, Arc::new(DashMap::new()), None, proxy::metrics::CancellationSource::Local, @@ -268,6 +272,12 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, }; + let compute_config = ComputeConfig { + retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?, + tls: Arc::new(compute_client_config_with_root_certs()?), + timeout: Duration::from_secs(2), + }; + Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, metric_collection: None, @@ -289,9 +299,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig region: "local".into(), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, - connect_to_compute_retry_config: RetryConfig::parse( - RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES, - )?, + connect_to_compute: compute_config, }))) } diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 9538384b9e..97d870a83a 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -10,12 +10,12 @@ use clap::Arg; use futures::future::Either; use futures::TryFutureExt; use itertools::Itertools; -use proxy::config::TlsServerEndPoint; use proxy::context::RequestContext; use proxy::metrics::{Metrics, ThreadPoolMetrics}; use proxy::protocol2::ConnectionInfo; use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource}; use proxy::stream::{PqStream, Stream}; +use proxy::tls::TlsServerEndPoint; use rustls::crypto::ring; use rustls::pki_types::PrivateKeyDer; use tokio::io::{AsyncRead, AsyncWrite}; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3dcf9ca060..3b122d771c 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,6 +1,7 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; +use std::time::Duration; use anyhow::bail; use futures::future::Either; @@ -8,7 +9,7 @@ use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; use proxy::cancellation::{CancelMap, CancellationHandler}; use proxy::config::{ - self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig, + self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, }; use proxy::context::parquet::ParquetUploadArgs; @@ -23,6 +24,7 @@ use proxy::redis::{elasticache, notifications}; use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::GlobalConnPoolOptions; +use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::{auth, control_plane, http, serverless, usage_metrics}; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; @@ -397,6 +399,7 @@ async fn main() -> anyhow::Result<()> { let cancellation_handler = Arc::new(CancellationHandler::< Option>>, >::new( + &config.connect_to_compute, cancel_map.clone(), redis_publisher, proxy::metrics::CancellationSource::FromClient, @@ -492,6 +495,7 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(client) = client1 { maintenance_tasks.spawn(notifications::task_main( + config, client, cache.clone(), cancel_map.clone(), @@ -500,6 +504,7 @@ async fn main() -> anyhow::Result<()> { } if let Some(client) = client2 { maintenance_tasks.spawn(notifications::task_main( + config, client, cache.clone(), cancel_map.clone(), @@ -632,6 +637,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, }; + let compute_config = ComputeConfig { + retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?, + tls: Arc::new(compute_client_config_with_root_certs()?), + timeout: Duration::from_secs(2), + }; + let config = ProxyConfig { tls_config, metric_collection, @@ -642,9 +653,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { region: args.region.clone(), wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, - connect_to_compute_retry_config: config::RetryConfig::parse( - &args.connect_to_compute_retry, - )?, + connect_to_compute: compute_config, }; let config = Box::leak(Box::new(config)); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index ebaea173ae..df618cf242 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -3,11 +3,9 @@ use std::sync::Arc; use dashmap::DashMap; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use once_cell::sync::OnceCell; use postgres_client::tls::MakeTlsConnect; use postgres_client::CancelToken; use pq_proto::CancelKeyData; -use rustls::crypto::ring; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::Mutex; @@ -15,15 +13,15 @@ use tracing::{debug, info}; use uuid::Uuid; use crate::auth::{check_peer_addr_is_in_list, IpPattern}; -use crate::compute::load_certs; +use crate::config::ComputeConfig; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancellationRequest, CancellationSource, Metrics}; -use crate::postgres_rustls::MakeRustlsConnect; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::cancellation_publisher::{ CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, }; +use crate::tls::postgres_rustls::MakeRustlsConnect; pub type CancelMap = Arc>>; pub type CancellationHandlerMain = CancellationHandler>>>; @@ -35,6 +33,7 @@ type IpSubnetKey = IpNet; /// /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. pub struct CancellationHandler

{ + compute_config: &'static ComputeConfig, map: CancelMap, client: P, /// This field used for the monitoring purposes. @@ -183,7 +182,7 @@ impl CancellationHandler

{ "cancelling query per user's request using key {key}, hostname {}, address: {}", cancel_closure.hostname, cancel_closure.socket_addr ); - cancel_closure.try_cancel_query().await + cancel_closure.try_cancel_query(self.compute_config).await } #[cfg(test)] @@ -198,8 +197,13 @@ impl CancellationHandler

{ } impl CancellationHandler<()> { - pub fn new(map: CancelMap, from: CancellationSource) -> Self { + pub fn new( + compute_config: &'static ComputeConfig, + map: CancelMap, + from: CancellationSource, + ) -> Self { Self { + compute_config, map, client: (), from, @@ -214,8 +218,14 @@ impl CancellationHandler<()> { } impl CancellationHandler>>> { - pub fn new(map: CancelMap, client: Option>>, from: CancellationSource) -> Self { + pub fn new( + compute_config: &'static ComputeConfig, + map: CancelMap, + client: Option>>, + from: CancellationSource, + ) -> Self { Self { + compute_config, map, client, from, @@ -229,8 +239,6 @@ impl CancellationHandler>>> { } } -static TLS_ROOTS: OnceCell> = OnceCell::new(); - /// This should've been a [`std::future::Future`], but /// it's impossible to name a type of an unboxed future /// (we'd need something like `#![feature(type_alias_impl_trait)]`). @@ -257,27 +265,14 @@ impl CancelClosure { } } /// Cancels the query running on user's compute node. - pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> { + pub(crate) async fn try_cancel_query( + self, + compute_config: &ComputeConfig, + ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let root_store = TLS_ROOTS - .get_or_try_init(load_certs) - .map_err(|_e| { - CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - "TLS root store initialization failed".to_string(), - )) - })? - .clone(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = + crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let tls = >::make_tls_connect( &mut mk_tls, &self.hostname, @@ -329,11 +324,30 @@ impl

Drop for Session

{ #[cfg(test)] #[expect(clippy::unwrap_used)] mod tests { + use std::time::Duration; + use super::*; + use crate::config::RetryConfig; + use crate::tls::client_config::compute_client_config_with_certs; + + fn config() -> ComputeConfig { + let retry = RetryConfig { + base_delay: Duration::from_secs(1), + max_retries: 5, + backoff_factor: 2.0, + }; + + ComputeConfig { + retry, + tls: Arc::new(compute_client_config_with_certs(std::iter::empty())), + timeout: Duration::from_secs(2), + } + } #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + Box::leak(Box::new(config())), CancelMap::default(), CancellationSource::FromRedis, )); @@ -349,8 +363,11 @@ mod tests { #[tokio::test] async fn cancel_session_noop_regression() { - let handler = - CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local); + let handler = CancellationHandler::<()>::new( + Box::leak(Box::new(config())), + CancelMap::default(), + CancellationSource::Local, + ); handler .cancel_session( CancelKeyData { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 8dc9b59e81..d60dfd0f80 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,16 +1,13 @@ use std::io; use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; -use once_cell::sync::OnceCell; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; use pq_proto::StartupMessageParams; -use rustls::crypto::ring; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; @@ -18,14 +15,15 @@ use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; use crate::cancellation::CancelClosure; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; -use crate::postgres_rustls::MakeRustlsConnect; use crate::proxy::neon_option; +use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -40,9 +38,6 @@ pub(crate) enum ConnectionError { #[error("{COULD_NOT_CONNECT}: {0}")] CouldNotConnect(#[from] io::Error), - #[error("Couldn't load native TLS certificates: {0:?}")] - TlsCertificateError(Vec), - #[error("{COULD_NOT_CONNECT}: {0}")] TlsError(#[from] InvalidDnsNameError), @@ -89,7 +84,6 @@ impl ReportableError for ConnectionError { } ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, - ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -251,25 +245,13 @@ impl ConnCfg { &self, ctx: &RequestContext, aux: MetricsAuxInfo, - timeout: Duration, + config: &ComputeConfig, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(timeout).await?; + let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; drop(pause); - let root_store = TLS_ROOTS - .get_or_try_init(load_certs) - .map_err(ConnectionError::TlsCertificateError)? - .clone(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); let tls = >::make_tls_connect( &mut mk_tls, host, @@ -341,19 +323,6 @@ fn filtered_options(options: &str) -> Option { Some(options) } -pub(crate) fn load_certs() -> Result, Vec> { - let der_certs = rustls_native_certs::load_native_certs(); - - if !der_certs.errors.is_empty() { - return Err(der_certs.errors); - } - - let mut store = rustls::RootCertStore::empty(); - store.add_parsable_certificates(der_certs.certs); - Ok(Arc::new(store)) -} -static TLS_ROOTS: OnceCell> = OnceCell::new(); - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 33d1d2e9e4..8502edcfab 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,17 +1,10 @@ -use std::collections::{HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use anyhow::{bail, ensure, Context, Ok}; use clap::ValueEnum; -use itertools::Itertools; use remote_storage::RemoteStorageConfig; -use rustls::crypto::ring::{self, sign}; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use sha2::{Digest, Sha256}; -use tracing::{error, info}; -use x509_parser::oid_registry; use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::AuthRateLimiter; @@ -20,6 +13,7 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig} use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::GlobalConnPoolOptions; +pub use crate::tls::server_config::{configure_tls, TlsConfig}; use crate::types::Host; pub struct ProxyConfig { @@ -32,7 +26,13 @@ pub struct ProxyConfig { pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, - pub connect_to_compute_retry_config: RetryConfig, + pub connect_to_compute: ComputeConfig, +} + +pub struct ComputeConfig { + pub retry: RetryConfig, + pub tls: Arc, + pub timeout: Duration, } #[derive(Copy, Clone, Debug, ValueEnum, PartialEq)] @@ -52,12 +52,6 @@ pub struct MetricCollectionConfig { pub backup_metric_collection_config: MetricBackupCollectionConfig, } -pub struct TlsConfig { - pub config: Arc, - pub common_names: HashSet, - pub cert_resolver: Arc, -} - pub struct HttpConfig { pub accept_websockets: bool, pub pool_options: GlobalConnPoolOptions, @@ -80,272 +74,6 @@ pub struct AuthenticationConfig { pub console_redirect_confirmation_timeout: tokio::time::Duration, } -impl TlsConfig { - pub fn to_server_config(&self) -> Arc { - self.config.clone() - } -} - -/// -pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql"; - -/// Configure TLS for the main endpoint. -pub fn configure_tls( - key_path: &str, - cert_path: &str, - certs_dir: Option<&String>, - allow_tls_keylogfile: bool, -) -> anyhow::Result { - let mut cert_resolver = CertResolver::new(); - - // add default certificate - cert_resolver.add_cert_path(key_path, cert_path, true)?; - - // add extra certificates - if let Some(certs_dir) = certs_dir { - for entry in std::fs::read_dir(certs_dir)? { - let entry = entry?; - let path = entry.path(); - if path.is_dir() { - // file names aligned with default cert-manager names - let key_path = path.join("tls.key"); - let cert_path = path.join("tls.crt"); - if key_path.exists() && cert_path.exists() { - cert_resolver.add_cert_path( - &key_path.to_string_lossy(), - &cert_path.to_string_lossy(), - false, - )?; - } - } - } - } - - let common_names = cert_resolver.get_common_names(); - - let cert_resolver = Arc::new(cert_resolver); - - // allow TLS 1.2 to be compatible with older client libraries - let mut config = - rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) - .context("ring should support TLS1.2 and TLS1.3")? - .with_no_client_auth() - .with_cert_resolver(cert_resolver.clone()); - - config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()]; - - if allow_tls_keylogfile { - // KeyLogFile will check for the SSLKEYLOGFILE environment variable. - config.key_log = Arc::new(rustls::KeyLogFile::new()); - } - - Ok(TlsConfig { - config: Arc::new(config), - common_names, - cert_resolver, - }) -} - -/// Channel binding parameter -/// -/// -/// Description: The hash of the TLS server's certificate as it -/// appears, octet for octet, in the server's Certificate message. Note -/// that the Certificate message contains a certificate_list, in which -/// the first element is the server's certificate. -/// -/// The hash function is to be selected as follows: -/// -/// * if the certificate's signatureAlgorithm uses a single hash -/// function, and that hash function is either MD5 or SHA-1, then use SHA-256; -/// -/// * if the certificate's signatureAlgorithm uses a single hash -/// function and that hash function neither MD5 nor SHA-1, then use -/// the hash function associated with the certificate's -/// signatureAlgorithm; -/// -/// * if the certificate's signatureAlgorithm uses no hash functions or -/// uses multiple hash functions, then this channel binding type's -/// channel bindings are undefined at this time (updates to is channel -/// binding type may occur to address this issue if it ever arises). -#[derive(Debug, Clone, Copy)] -pub enum TlsServerEndPoint { - Sha256([u8; 32]), - Undefined, -} - -impl TlsServerEndPoint { - pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result { - let sha256_oids = [ - // I'm explicitly not adding MD5 or SHA1 here... They're bad. - oid_registry::OID_SIG_ECDSA_WITH_SHA256, - oid_registry::OID_PKCS1_SHA256WITHRSA, - ]; - - let pem = x509_parser::parse_x509_certificate(cert) - .context("Failed to parse PEM object from cerficiate")? - .1; - - info!(subject = %pem.subject, "parsing TLS certificate"); - - let reg = oid_registry::OidRegistry::default().with_all_crypto(); - let oid = pem.signature_algorithm.oid(); - let alg = reg.get(oid); - if sha256_oids.contains(oid) { - let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into(); - info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); - Ok(Self::Sha256(tls_server_end_point)) - } else { - error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding"); - Ok(Self::Undefined) - } - } - - pub fn supported(&self) -> bool { - !matches!(self, TlsServerEndPoint::Undefined) - } -} - -#[derive(Default, Debug)] -pub struct CertResolver { - certs: HashMap, TlsServerEndPoint)>, - default: Option<(Arc, TlsServerEndPoint)>, -} - -impl CertResolver { - pub fn new() -> Self { - Self::default() - } - - fn add_cert_path( - &mut self, - key_path: &str, - cert_path: &str, - is_default: bool, - ) -> anyhow::Result<()> { - let priv_key = { - let key_bytes = std::fs::read(key_path) - .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; - rustls_pemfile::private_key(&mut &key_bytes[..]) - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? - .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? - }; - - let cert_chain_bytes = std::fs::read(cert_path) - .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; - - let cert_chain = { - rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .try_collect() - .with_context(|| { - format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.") - })? - }; - - self.add_cert(priv_key, cert_chain, is_default) - } - - pub fn add_cert( - &mut self, - priv_key: PrivateKeyDer<'static>, - cert_chain: Vec>, - is_default: bool, - ) -> anyhow::Result<()> { - let key = sign::any_supported_type(&priv_key).context("invalid private key")?; - - let first_cert = &cert_chain[0]; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let pem = x509_parser::parse_x509_certificate(first_cert) - .context("Failed to parse PEM object from cerficiate")? - .1; - - let common_name = pem.subject().to_string(); - - // We need to get the canonical name for this certificate so we can match them against any domain names - // seen within the proxy codebase. - // - // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI. - // We need to remove the wildcard prefix for the purposes of certificate selection. - // - // auth-broker does not use SNI and instead uses the Neon-Connection-String header. - // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String. - // - // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string - // validation, so let's we can continue with any common-name - let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") { - s.to_string() - } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") { - s.to_string() - } else if let Some(s) = common_name.strip_prefix("CN=") { - s.to_string() - } else { - bail!("Failed to parse common name from certificate") - }; - - let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); - - if is_default { - self.default = Some((cert.clone(), tls_server_end_point)); - } - - self.certs.insert(common_name, (cert, tls_server_end_point)); - - Ok(()) - } - - pub fn get_common_names(&self) -> HashSet { - self.certs.keys().map(|s| s.to_string()).collect() - } -} - -impl rustls::server::ResolvesServerCert for CertResolver { - fn resolve( - &self, - client_hello: rustls::server::ClientHello<'_>, - ) -> Option> { - self.resolve(client_hello.server_name()).map(|x| x.0) - } -} - -impl CertResolver { - pub fn resolve( - &self, - server_name: Option<&str>, - ) -> Option<(Arc, TlsServerEndPoint)> { - // loop here and cut off more and more subdomains until we find - // a match to get a proper wildcard support. OTOH, we now do not - // use nested domains, so keep this simple for now. - // - // With the current coding foo.com will match *.foo.com and that - // repeats behavior of the old code. - if let Some(mut sni_name) = server_name { - loop { - if let Some(cert) = self.certs.get(sni_name) { - return Some(cert.clone()); - } - if let Some((_, rest)) = sni_name.split_once('.') { - sni_name = rest; - } else { - return None; - } - } - } else { - // No SNI, use the default certificate, otherwise we can't get to - // options parameter which can be used to set endpoint name too. - // That means that non-SNI flow will not work for CNAME domains in - // verify-full mode. - // - // If that will be a problem we can: - // - // a) Instead of multi-cert approach use single cert with extra - // domains listed in Subject Alternative Name (SAN). - // b) Deploy separate proxy instances for extra domains. - self.default.clone() - } - } -} - #[derive(Debug)] pub struct EndpointCacheConfig { /// Batch size to receive all endpoints on the startup. diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index c477822e85..25a549039c 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -115,7 +115,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { error!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); @@ -216,7 +216,7 @@ pub(crate) async fn handle_client( }, &user_info, config.wake_compute_retry_config, - config.connect_to_compute_retry_config, + &config.connect_to_compute, ) .or_else(|e| stream.throw_error(e)) .await?; diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 0ca1a6aae0..c65041df0e 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -10,13 +10,13 @@ pub mod client; pub(crate) mod errors; use std::sync::Arc; -use std::time::Duration; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::IpPattern; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::cache::{Cached, TimedLru}; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::ProjectIdInt; @@ -73,9 +73,9 @@ impl NodeInfo { pub(crate) async fn connect( &self, ctx: &RequestContext, - timeout: Duration, + config: &ComputeConfig, ) -> Result { - self.config.connect(ctx, self.aux.clone(), timeout).await + self.config.connect(ctx, self.aux.clone(), config).await } pub(crate) fn reuse_settings(&mut self, other: Self) { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a5a72f26d9..c56474edd7 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -89,7 +89,6 @@ pub mod jemalloc; pub mod logging; pub mod metrics; pub mod parse; -pub mod postgres_rustls; pub mod protocol2; pub mod proxy; pub mod rate_limiter; @@ -99,6 +98,7 @@ pub mod scram; pub mod serverless; pub mod signals; pub mod stream; +pub mod tls; pub mod types; pub mod url; pub mod usage_metrics; diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 4a30d23985..8a80494860 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -6,7 +6,7 @@ use tracing::{debug, info, warn}; use super::retry::ShouldRetryWakeCompute; use crate::auth::backend::ComputeCredentialKeys; use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT}; -use crate::config::RetryConfig; +use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; @@ -19,8 +19,6 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; -const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); - /// If we couldn't connect, a cached connection info might be to blame /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. @@ -49,7 +47,7 @@ pub(crate) trait ConnectMechanism { &self, ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, - timeout: time::Duration, + config: &ComputeConfig, ) -> Result; fn update_connect_config(&self, conf: &mut compute::ConnCfg); @@ -86,11 +84,11 @@ impl ConnectMechanism for TcpMechanism<'_> { &self, ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, - timeout: time::Duration, + config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; - permit.release_result(node_info.connect(ctx, timeout).await) + permit.release_result(node_info.connect(ctx, config).await) } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -105,7 +103,7 @@ pub(crate) async fn connect_to_compute Result where M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, @@ -119,10 +117,7 @@ where mechanism.update_connect_config(&mut node_info.config); // try once - let err = match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) - .await - { + let err = match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( @@ -142,7 +137,7 @@ where let node_info = if !node_info.cached() || !err.should_retry_wake_compute() { // If we just recieved this from cplane and didn't get it from cache, we shouldn't retry. // Do not need to retrieve a new node_info, just return the old one. - if should_retry(&err, num_retries, connect_to_compute_retry_config) { + if should_retry(&err, num_retries, compute.retry) { Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { outcome: ConnectOutcome::Failed, @@ -172,10 +167,7 @@ where debug!("wake_compute success. attempting to connect"); num_retries = 1; loop { - match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) - .await - { + match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( @@ -190,7 +182,7 @@ where return Ok(res); } Err(e) => { - if !should_retry(&e, num_retries, connect_to_compute_retry_config) { + if !should_retry(&e, num_retries, compute.retry) { // Don't log an error here, caller will print the error Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { @@ -206,7 +198,7 @@ where } }; - let wait_duration = retry_after(num_retries, connect_to_compute_retry_config); + let wait_duration = retry_after(num_retries, compute.retry); num_retries += 1; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout); diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index e27c211932..955f754497 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -8,12 +8,13 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; use crate::auth::endpoint_sni; -use crate::config::{TlsConfig, PG_ALPN_PROTOCOL}; +use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; use crate::proxy::ERR_INSECURE_CONNECTION; use crate::stream::{PqStream, Stream, StreamUpgradeError}; +use crate::tls::PG_ALPN_PROTOCOL; #[derive(Error, Debug)] pub(crate) enum HandshakeError { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index dbe174cab7..3926c56fec 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -152,7 +152,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); @@ -351,7 +351,7 @@ pub(crate) async fn handle_client( }, &user_info, config.wake_compute_retry_config, - config.connect_to_compute_retry_config, + &config.connect_to_compute, ) .or_else(|e| stream.throw_error(e)) .await?; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index dcaa81e5cd..a42f9aad39 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -5,6 +5,7 @@ use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; use crate::cancellation; use crate::compute::PostgresConnection; +use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; use crate::stream::Stream; @@ -67,9 +68,17 @@ pub(crate) struct ProxyPassthrough { } impl ProxyPassthrough { - pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { + pub(crate) async fn proxy_pass( + self, + compute_config: &ComputeConfig, + ) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; - if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { + if let Err(err) = self + .compute + .cancel_closure + .try_cancel_query(compute_config) + .await + { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } res diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 95c518fed9..10db2bcb30 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -22,14 +22,16 @@ use super::*; use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, }; -use crate::config::{CertResolver, RetryConfig}; +use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{ self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; -use crate::postgres_rustls::MakeRustlsConnect; +use crate::tls::client_config::compute_client_config_with_certs; +use crate::tls::postgres_rustls::MakeRustlsConnect; +use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; @@ -67,7 +69,7 @@ fn generate_certs( } struct ClientConfig<'a> { - config: rustls::ClientConfig, + config: Arc, hostname: &'a str, } @@ -110,16 +112,7 @@ fn generate_tls_config<'a>( }; let client_config = { - let config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .context("ring should support the default protocol versions")? - .with_root_certificates({ - let mut store = rustls::RootCertStore::empty(); - store.add(ca)?; - store - }) - .with_no_client_auth(); + let config = Arc::new(compute_client_config_with_certs([ca])); ClientConfig { config, hostname } }; @@ -468,7 +461,7 @@ impl ConnectMechanism for TestConnectMechanism { &self, _ctx: &RequestContext, _node_info: &control_plane::CachedNodeInfo, - _timeout: std::time::Duration, + _config: &ComputeConfig, ) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; @@ -576,6 +569,20 @@ fn helper_create_connect_info( user_info } +fn config() -> ComputeConfig { + let retry = RetryConfig { + base_delay: Duration::from_secs(1), + max_retries: 5, + backoff_factor: 2.0, + }; + + ComputeConfig { + retry, + tls: Arc::new(compute_client_config_with_certs(std::iter::empty())), + timeout: Duration::from_secs(2), + } +} + #[tokio::test] async fn connect_to_compute_success() { let _ = env_logger::try_init(); @@ -583,12 +590,8 @@ async fn connect_to_compute_success() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -601,12 +604,8 @@ async fn connect_to_compute_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -620,12 +619,8 @@ async fn connect_to_compute_non_retry_1() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -639,12 +634,8 @@ async fn connect_to_compute_non_retry_2() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -665,17 +656,13 @@ async fn connect_to_compute_non_retry_3() { max_retries: 1, backoff_factor: 2.0, }; - let connect_to_compute_retry_config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; + let config = config(); connect_to_compute( &ctx, &mechanism, &user_info, wake_compute_retry_config, - connect_to_compute_retry_config, + &config, ) .await .unwrap_err(); @@ -690,12 +677,8 @@ async fn wake_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -709,12 +692,8 @@ async fn wake_non_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index d18dfd2465..80b93b6c4f 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -12,6 +12,7 @@ use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; use crate::cancellation::{CancelMap, CancellationHandler}; +use crate::config::ProxyConfig; use crate::intern::{ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; @@ -249,6 +250,7 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] pub async fn task_main( + config: &'static ProxyConfig, redis: ConnectionWithCredentialsProvider, cache: Arc, cancel_map: CancelMap, @@ -258,6 +260,7 @@ where C: ProjectInfoCache + Send + Sync + 'static, { let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + &config.connect_to_compute, cancel_map, crate::metrics::CancellationSource::FromRedis, )); diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 6a13f645a5..77853db3db 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -13,7 +13,6 @@ use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; use super::ScramKey; -use crate::config; use crate::intern::EndpointIdInt; use crate::sasl::{self, ChannelBinding, Error as SaslError}; @@ -59,14 +58,14 @@ enum ExchangeState { pub(crate) struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, - tls_server_end_point: config::TlsServerEndPoint, + tls_server_end_point: crate::tls::TlsServerEndPoint, } impl<'a> Exchange<'a> { pub(crate) fn new( secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], - tls_server_end_point: config::TlsServerEndPoint, + tls_server_end_point: crate::tls::TlsServerEndPoint, ) -> Self { Self { state: ExchangeState::Initial(SaslInitial { nonce }), @@ -120,7 +119,7 @@ impl SaslInitial { fn transition( &self, secret: &ServerSecret, - tls_server_end_point: &config::TlsServerEndPoint, + tls_server_end_point: &crate::tls::TlsServerEndPoint, input: &str, ) -> sasl::Result> { let client_first_message = ClientFirstMessage::parse(input) @@ -155,7 +154,7 @@ impl SaslSentInner { fn transition( &self, secret: &ServerSecret, - tls_server_end_point: &config::TlsServerEndPoint, + tls_server_end_point: &crate::tls::TlsServerEndPoint, input: &str, ) -> sasl::Result> { let Self { @@ -168,8 +167,8 @@ impl SaslSentInner { .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; let channel_binding = cbind_flag.encode(|_| match tls_server_end_point { - config::TlsServerEndPoint::Sha256(x) => Ok(x), - config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), + crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x), + crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), })?; // This might've been caused by a MITM attack diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index b49a9f32ee..cfa571cbe1 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -77,11 +77,8 @@ mod tests { const NONCE: [u8; 18] = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, ]; - let mut exchange = Exchange::new( - &secret, - || NONCE, - crate::config::TlsServerEndPoint::Undefined, - ); + let mut exchange = + Exchange::new(&secret, || NONCE, crate::tls::TlsServerEndPoint::Undefined); let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO"; let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0="; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 449d50b6e7..b398c3ddd0 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::ProxyConfig; +use crate::config::{ComputeConfig, ProxyConfig}; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; @@ -196,7 +196,7 @@ impl PoolingBackend { }, &backend, self.config.wake_compute_retry_config, - self.config.connect_to_compute_retry_config, + &self.config.connect_to_compute, ) .await } @@ -237,7 +237,7 @@ impl PoolingBackend { }, &backend, self.config.wake_compute_retry_config, - self.config.connect_to_compute_retry_config, + &self.config.connect_to_compute, ) .await } @@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism { &self, ctx: &RequestContext, node_info: &CachedNodeInfo, - timeout: Duration, + compute_config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; @@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism { let config = config .user(&self.conn_info.user_info.user) .dbname(&self.conn_info.dbname) - .connect_timeout(timeout); + .connect_timeout(compute_config.timeout); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let res = config.connect(postgres_client::NoTls).await; @@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism { &self, ctx: &RequestContext, node_info: &CachedNodeInfo, - timeout: Duration, + config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; @@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let port = node_info.config.get_port(); - let res = connect_http2(&host, port, timeout).await; + let res = connect_http2(&host, port, config.timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 812fedaf04..47326c1181 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket( Ok(Some(p)) => { ctx.set_success(); ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => Ok(()), Err(ErrorSource::Client(err)) => Err(err).context("client"), Err(ErrorSource::Compute(err)) => Err(err).context("compute"), diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 11f426819d..ace27a7284 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -11,9 +11,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; use tracing::debug; -use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. /// diff --git a/proxy/src/tls/client_config.rs b/proxy/src/tls/client_config.rs new file mode 100644 index 0000000000..a2d695aae1 --- /dev/null +++ b/proxy/src/tls/client_config.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use anyhow::bail; +use rustls::crypto::ring; + +pub(crate) fn load_certs() -> anyhow::Result> { + let der_certs = rustls_native_certs::load_native_certs(); + + if !der_certs.errors.is_empty() { + bail!("could not parse certificates: {:?}", der_certs.errors); + } + + let mut store = rustls::RootCertStore::empty(); + store.add_parsable_certificates(der_certs.certs); + Ok(Arc::new(store)) +} + +/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute. +/// This function is blocking. +pub fn compute_client_config_with_root_certs() -> anyhow::Result { + Ok( + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(load_certs()?) + .with_no_client_auth(), + ) +} + +#[cfg(test)] +pub fn compute_client_config_with_certs( + certs: impl IntoIterator>, +) -> rustls::ClientConfig { + let mut store = rustls::RootCertStore::empty(); + store.add_parsable_certificates(certs); + + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(store) + .with_no_client_auth() +} diff --git a/proxy/src/tls/mod.rs b/proxy/src/tls/mod.rs new file mode 100644 index 0000000000..d6ce6bd9fc --- /dev/null +++ b/proxy/src/tls/mod.rs @@ -0,0 +1,72 @@ +pub mod client_config; +pub mod postgres_rustls; +pub mod server_config; + +use anyhow::Context; +use rustls::pki_types::CertificateDer; +use sha2::{Digest, Sha256}; +use tracing::{error, info}; +use x509_parser::oid_registry; + +/// +pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql"; + +/// Channel binding parameter +/// +/// +/// Description: The hash of the TLS server's certificate as it +/// appears, octet for octet, in the server's Certificate message. Note +/// that the Certificate message contains a certificate_list, in which +/// the first element is the server's certificate. +/// +/// The hash function is to be selected as follows: +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function, and that hash function is either MD5 or SHA-1, then use SHA-256; +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function and that hash function neither MD5 nor SHA-1, then use +/// the hash function associated with the certificate's +/// signatureAlgorithm; +/// +/// * if the certificate's signatureAlgorithm uses no hash functions or +/// uses multiple hash functions, then this channel binding type's +/// channel bindings are undefined at this time (updates to is channel +/// binding type may occur to address this issue if it ever arises). +#[derive(Debug, Clone, Copy)] +pub enum TlsServerEndPoint { + Sha256([u8; 32]), + Undefined, +} + +impl TlsServerEndPoint { + pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result { + let sha256_oids = [ + // I'm explicitly not adding MD5 or SHA1 here... They're bad. + oid_registry::OID_SIG_ECDSA_WITH_SHA256, + oid_registry::OID_PKCS1_SHA256WITHRSA, + ]; + + let pem = x509_parser::parse_x509_certificate(cert) + .context("Failed to parse PEM object from cerficiate")? + .1; + + info!(subject = %pem.subject, "parsing TLS certificate"); + + let reg = oid_registry::OidRegistry::default().with_all_crypto(); + let oid = pem.signature_algorithm.oid(); + let alg = reg.get(oid); + if sha256_oids.contains(oid) { + let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into(); + info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); + Ok(Self::Sha256(tls_server_end_point)) + } else { + error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding"); + Ok(Self::Undefined) + } + } + + pub fn supported(&self) -> bool { + !matches!(self, TlsServerEndPoint::Undefined) + } +} diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/tls/postgres_rustls.rs similarity index 96% rename from proxy/src/postgres_rustls/mod.rs rename to proxy/src/tls/postgres_rustls.rs index 5ef20991c3..0ad279b635 100644 --- a/proxy/src/postgres_rustls/mod.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -18,7 +18,7 @@ mod private { use tokio_rustls::client::TlsStream; use tokio_rustls::TlsConnector; - use crate::config::TlsServerEndPoint; + use crate::tls::TlsServerEndPoint; pub struct TlsConnectFuture { inner: tokio_rustls::Connect, @@ -126,16 +126,14 @@ mod private { /// That way you can connect to PostgreSQL using `rustls` as the TLS stack. #[derive(Clone)] pub struct MakeRustlsConnect { - config: Arc, + pub config: Arc, } impl MakeRustlsConnect { /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. #[must_use] - pub fn new(config: ClientConfig) -> Self { - Self { - config: Arc::new(config), - } + pub fn new(config: Arc) -> Self { + Self { config } } } diff --git a/proxy/src/tls/server_config.rs b/proxy/src/tls/server_config.rs new file mode 100644 index 0000000000..2cc1657eea --- /dev/null +++ b/proxy/src/tls/server_config.rs @@ -0,0 +1,218 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use anyhow::{bail, Context}; +use itertools::Itertools; +use rustls::crypto::ring::{self, sign}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + +use super::{TlsServerEndPoint, PG_ALPN_PROTOCOL}; + +pub struct TlsConfig { + pub config: Arc, + pub common_names: HashSet, + pub cert_resolver: Arc, +} + +impl TlsConfig { + pub fn to_server_config(&self) -> Arc { + self.config.clone() + } +} + +/// Configure TLS for the main endpoint. +pub fn configure_tls( + key_path: &str, + cert_path: &str, + certs_dir: Option<&String>, + allow_tls_keylogfile: bool, +) -> anyhow::Result { + let mut cert_resolver = CertResolver::new(); + + // add default certificate + cert_resolver.add_cert_path(key_path, cert_path, true)?; + + // add extra certificates + if let Some(certs_dir) = certs_dir { + for entry in std::fs::read_dir(certs_dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + // file names aligned with default cert-manager names + let key_path = path.join("tls.key"); + let cert_path = path.join("tls.crt"); + if key_path.exists() && cert_path.exists() { + cert_resolver.add_cert_path( + &key_path.to_string_lossy(), + &cert_path.to_string_lossy(), + false, + )?; + } + } + } + } + + let common_names = cert_resolver.get_common_names(); + + let cert_resolver = Arc::new(cert_resolver); + + // allow TLS 1.2 to be compatible with older client libraries + let mut config = + rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .context("ring should support TLS1.2 and TLS1.3")? + .with_no_client_auth() + .with_cert_resolver(cert_resolver.clone()); + + config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()]; + + if allow_tls_keylogfile { + // KeyLogFile will check for the SSLKEYLOGFILE environment variable. + config.key_log = Arc::new(rustls::KeyLogFile::new()); + } + + Ok(TlsConfig { + config: Arc::new(config), + common_names, + cert_resolver, + }) +} + +#[derive(Default, Debug)] +pub struct CertResolver { + certs: HashMap, TlsServerEndPoint)>, + default: Option<(Arc, TlsServerEndPoint)>, +} + +impl CertResolver { + pub fn new() -> Self { + Self::default() + } + + fn add_cert_path( + &mut self, + key_path: &str, + cert_path: &str, + is_default: bool, + ) -> anyhow::Result<()> { + let priv_key = { + let key_bytes = std::fs::read(key_path) + .with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?; + rustls_pemfile::private_key(&mut &key_bytes[..]) + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + .with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))? + }; + + let cert_chain_bytes = std::fs::read(cert_path) + .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; + + let cert_chain = { + rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .try_collect() + .with_context(|| { + format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.") + })? + }; + + self.add_cert(priv_key, cert_chain, is_default) + } + + pub fn add_cert( + &mut self, + priv_key: PrivateKeyDer<'static>, + cert_chain: Vec>, + is_default: bool, + ) -> anyhow::Result<()> { + let key = sign::any_supported_type(&priv_key).context("invalid private key")?; + + let first_cert = &cert_chain[0]; + let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; + let pem = x509_parser::parse_x509_certificate(first_cert) + .context("Failed to parse PEM object from cerficiate")? + .1; + + let common_name = pem.subject().to_string(); + + // We need to get the canonical name for this certificate so we can match them against any domain names + // seen within the proxy codebase. + // + // In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI. + // We need to remove the wildcard prefix for the purposes of certificate selection. + // + // auth-broker does not use SNI and instead uses the Neon-Connection-String header. + // Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String. + // + // Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string + // validation, so let's we can continue with any common-name + let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") { + s.to_string() + } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") { + s.to_string() + } else if let Some(s) = common_name.strip_prefix("CN=") { + s.to_string() + } else { + bail!("Failed to parse common name from certificate") + }; + + let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); + + if is_default { + self.default = Some((cert.clone(), tls_server_end_point)); + } + + self.certs.insert(common_name, (cert, tls_server_end_point)); + + Ok(()) + } + + pub fn get_common_names(&self) -> HashSet { + self.certs.keys().map(|s| s.to_string()).collect() + } +} + +impl rustls::server::ResolvesServerCert for CertResolver { + fn resolve( + &self, + client_hello: rustls::server::ClientHello<'_>, + ) -> Option> { + self.resolve(client_hello.server_name()).map(|x| x.0) + } +} + +impl CertResolver { + pub fn resolve( + &self, + server_name: Option<&str>, + ) -> Option<(Arc, TlsServerEndPoint)> { + // loop here and cut off more and more subdomains until we find + // a match to get a proper wildcard support. OTOH, we now do not + // use nested domains, so keep this simple for now. + // + // With the current coding foo.com will match *.foo.com and that + // repeats behavior of the old code. + if let Some(mut sni_name) = server_name { + loop { + if let Some(cert) = self.certs.get(sni_name) { + return Some(cert.clone()); + } + if let Some((_, rest)) = sni_name.split_once('.') { + sni_name = rest; + } else { + return None; + } + } + } else { + // No SNI, use the default certificate, otherwise we can't get to + // options parameter which can be used to set endpoint name too. + // That means that non-SNI flow will not work for CNAME domains in + // verify-full mode. + // + // If that will be a problem we can: + // + // a) Instead of multi-cert approach use single cert with extra + // domains listed in Subject Alternative Name (SAN). + // b) Deploy separate proxy instances for extra domains. + self.default.clone() + } + } +}