diff --git a/Cargo.lock b/Cargo.lock index d605169986..16fcd0c4c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3505,6 +3505,7 @@ dependencies = [ "pbkdf2", "pin-project-lite", "postgres-native-tls", + "postgres-protocol", "postgres_backend", "pq_proto", "prometheus", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0ec7efd316..39a9c3ddb0 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -76,3 +76,4 @@ tokio-util.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true +postgres-protocol.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 9cf45c0eec..f0197cc31b 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -6,6 +6,7 @@ pub use link::LinkAuthError; use tokio_postgres::config::AuthKeys; use crate::proxy::{handle_try_wake, retry_after, LatencyTimer}; +use crate::stream::Stream; use crate::{ auth::{self, ClientCredentials}, config::AuthenticationConfig, @@ -131,7 +132,7 @@ async fn auth_quirks_creds( api: &impl console::Api, extra: &ConsoleReqExtra<'_>, creds: &mut ClientCredentials<'_>, - client: &mut stream::PqStream, + client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, @@ -165,7 +166,7 @@ async fn auth_quirks( api: &impl console::Api, extra: &ConsoleReqExtra<'_>, creds: &mut ClientCredentials<'_>, - client: &mut stream::PqStream, + client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, @@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> { pub async fn authenticate( &mut self, extra: &ConsoleReqExtra<'_>, - client: &mut stream::PqStream, + client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index aee0057606..ac0d490db1 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -6,7 +6,7 @@ use crate::{ console::{self, AuthInfo, ConsoleReqExtra}, proxy::LatencyTimer, sasl, scram, - stream::PqStream, + stream::{PqStream, Stream}, }; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -15,7 +15,7 @@ pub(super) async fn authenticate( api: &impl console::Api, extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials<'_>, - client: &mut PqStream, + client: &mut PqStream>, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, ) -> auth::Result> { diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 895683af1b..4448dbc56a 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials}; use crate::{ auth::{self, AuthFlow, ClientCredentials}, proxy::LatencyTimer, - stream, + stream::{self, Stream}, }; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -12,7 +12,7 @@ use tracing::{info, warn}; /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. pub async fn cleartext_hack( - client: &mut stream::PqStream, + client: &mut stream::PqStream>, latency_timer: &mut LatencyTimer, ) -> auth::Result> { warn!("cleartext auth flow override is enabled, proceeding"); @@ -37,7 +37,7 @@ pub async fn cleartext_hack( /// Very similar to [`cleartext_hack`], but there's a specific password format. pub async fn password_hack( creds: &mut ClientCredentials<'_>, - client: &mut stream::PqStream, + client: &mut stream::PqStream>, latency_timer: &mut LatencyTimer, ) -> auth::Result> { warn!("project not specified, resorting to the password hack auth flow"); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 190abc9b2e..efb90733d6 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,16 +1,21 @@ //! Main authentication flow. use super::{AuthErrorImpl, PasswordHackPayload}; -use crate::{sasl, scram, stream::PqStream}; +use crate::{ + config::TlsServerEndPoint, + sasl, scram, + stream::{PqStream, Stream}, +}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::info; /// Every authentication selector is supposed to implement this trait. pub trait AuthMethod { /// Any authentication selector should provide initial backend message /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self) -> BeMessage<'_>; + fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; } /// Initial state of [`AuthFlow`]. @@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret); impl AuthMethod for Scram<'_> { #[inline(always)] - fn first_message(&self) -> BeMessage<'_> { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { + if channel_binding { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + } else { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + scram::METHODS_WITHOUT_PLUS, + )) + } } } @@ -32,7 +43,7 @@ pub struct PasswordHack; impl AuthMethod for PasswordHack { #[inline(always)] - fn first_message(&self) -> BeMessage<'_> { + fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { Be::AuthenticationCleartextPassword } } @@ -43,37 +54,44 @@ pub struct CleartextPassword; impl AuthMethod for CleartextPassword { #[inline(always)] - fn first_message(&self) -> BeMessage<'_> { + fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { Be::AuthenticationCleartextPassword } } /// This wrapper for [`PqStream`] performs client authentication. #[must_use] -pub struct AuthFlow<'a, Stream, State> { +pub struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. - stream: &'a mut PqStream, + stream: &'a mut PqStream>, /// State might contain ancillary data (see [`Self::begin`]). state: State, + tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { /// Create a new wrapper for client authentication. - pub fn new(stream: &'a mut PqStream) -> Self { + pub fn new(stream: &'a mut PqStream>) -> Self { + let tls_server_end_point = stream.get_ref().tls_server_end_point(); + Self { stream, state: Begin, + tls_server_end_point, } } /// Move to the next step by sending auth method's name & params to client. pub async fn begin(self, method: M) -> io::Result> { - self.stream.write_message(&method.first_message()).await?; + self.stream + .write_message(&method.first_message(self.tls_server_end_point.supported())) + .await?; Ok(AuthFlow { stream: self.stream, state: method, + tls_server_end_point: self.tls_server_end_point, }) } } @@ -123,9 +141,15 @@ impl AuthFlow<'_, S, Scram<'_>> { return Err(super::AuthError::bad_auth_method(sasl.method)); } + info!("client chooses {}", sasl.method); + let secret = self.state.0; let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new(secret, rand::random, None)) + .authenticate(scram::Exchange::new( + secret, + rand::random, + self.tls_server_end_point, + )) .await?; Ok(outcome) diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 42aecdb6fe..2b859fc2db 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -6,6 +6,8 @@ use std::{net::SocketAddr, sync::Arc}; use futures::future::Either; +use itertools::Itertools; +use proxy::config::TlsServerEndPoint; use tokio::net::TcpListener; use anyhow::{anyhow, bail, ensure, Context}; @@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> { let destination: String = args.get_one::("dest").unwrap().parse()?; // Configure TLS - let tls_config: Arc = match ( + let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> { ))? .into_iter() .map(rustls::Certificate) - .collect() + .collect_vec() }; - rustls::ServerConfig::builder() + // needed for channel bindings + let first_cert = cert_chain.first().context("missing certificate")?; + let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; + + let tls_config = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? .with_no_client_auth() .with_single_cert(cert_chain, key)? - .into() + .into(); + + (tls_config, tls_server_end_point) } _ => bail!("tls-key and tls-cert must be specified"), }; @@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> { let main = tokio::spawn(task_main( Arc::new(destination), tls_config, + tls_server_end_point, proxy_listener, cancellation_token.clone(), )); @@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> { async fn task_main( dest_suffix: Arc, tls_config: Arc, + tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -159,7 +169,7 @@ async fn task_main( .context("failed to set socket option")?; info!(%peer_addr, "serving"); - handle_client(dest_suffix, tls_config, socket).await + handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod async fn ssl_handshake( raw_stream: S, tls_config: Arc, + tls_server_end_point: TlsServerEndPoint, ) -> anyhow::Result> { let mut stream = PqStream::new(Stream::from_raw(raw_stream)); @@ -231,7 +242,11 @@ async fn ssl_handshake( if !read_buf.is_empty() { bail!("data is sent before server replied with EncryptionResponse"); } - Ok(raw.upgrade(tls_config).await?) + + Ok(Stream::Tls { + tls: Box::new(raw.upgrade(tls_config).await?), + tls_server_end_point, + }) } unexpected => { info!( @@ -246,9 +261,10 @@ async fn ssl_handshake( async fn handle_client( dest_suffix: Arc, tls_config: Arc, + tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let tls_stream = ssl_handshake(stream, tls_config).await?; + let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of diff --git a/proxy/src/config.rs b/proxy/src/config.rs index bd00123905..0c094ff4aa 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,12 +1,15 @@ use crate::auth; use anyhow::{bail, ensure, Context, Ok}; -use rustls::sign; +use rustls::{sign, Certificate, PrivateKey}; +use sha2::{Digest, Sha256}; use std::{ collections::{HashMap, HashSet}, str::FromStr, sync::Arc, time::Duration, }; +use tracing::{error, info}; +use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, @@ -27,6 +30,7 @@ pub struct MetricCollectionConfig { pub struct TlsConfig { pub config: Arc, pub common_names: Option>, + pub cert_resolver: Arc, } pub struct HttpConfig { @@ -52,7 +56,7 @@ pub fn configure_tls( let mut cert_resolver = CertResolver::new(); // add default certificate - cert_resolver.add_cert(key_path, cert_path, true)?; + cert_resolver.add_cert_path(key_path, cert_path, true)?; // add extra certificates if let Some(certs_dir) = certs_dir { @@ -64,7 +68,7 @@ pub fn configure_tls( let key_path = path.join("tls.key"); let cert_path = path.join("tls.crt"); if key_path.exists() && cert_path.exists() { - cert_resolver.add_cert( + cert_resolver.add_cert_path( &key_path.to_string_lossy(), &cert_path.to_string_lossy(), false, @@ -76,35 +80,97 @@ pub fn configure_tls( let common_names = cert_resolver.get_common_names(); + let cert_resolver = Arc::new(cert_resolver); + let config = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() // allow TLS 1.2 to be compatible with older client libraries .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? .with_no_client_auth() - .with_cert_resolver(Arc::new(cert_resolver)) + .with_cert_resolver(cert_resolver.clone()) .into(); Ok(TlsConfig { config, common_names: Some(common_names), + cert_resolver, }) } -struct CertResolver { - certs: HashMap>, - default: Option>, +/// Channel binding parameter +/// +/// +/// Description: The hash of the TLS server's certificate as it +/// appears, octet for octet, in the server's Certificate message. Note +/// that the Certificate message contains a certificate_list, in which +/// the first element is the server's certificate. +/// +/// The hash function is to be selected as follows: +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function, and that hash function is either MD5 or SHA-1, then use SHA-256; +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function and that hash function neither MD5 nor SHA-1, then use +/// the hash function associated with the certificate's +/// signatureAlgorithm; +/// +/// * if the certificate's signatureAlgorithm uses no hash functions or +/// uses multiple hash functions, then this channel binding type's +/// channel bindings are undefined at this time (updates to is channel +/// binding type may occur to address this issue if it ever arises). +#[derive(Debug, Clone, Copy)] +pub enum TlsServerEndPoint { + Sha256([u8; 32]), + Undefined, } -impl CertResolver { - fn new() -> Self { - Self { - certs: HashMap::new(), - default: None, +impl TlsServerEndPoint { + pub fn new(cert: &Certificate) -> anyhow::Result { + let sha256_oids = [ + // I'm explicitly not adding MD5 or SHA1 here... They're bad. + oid_registry::OID_SIG_ECDSA_WITH_SHA256, + oid_registry::OID_PKCS1_SHA256WITHRSA, + ]; + + let pem = x509_parser::parse_x509_certificate(&cert.0) + .context("Failed to parse PEM object from cerficiate")? + .1; + + info!(subject = %pem.subject, "parsing TLS certificate"); + + let reg = oid_registry::OidRegistry::default().with_all_crypto(); + let oid = pem.signature_algorithm.oid(); + let alg = reg.get(oid); + if sha256_oids.contains(oid) { + let tls_server_end_point: [u8; 32] = + Sha256::new().chain_update(&cert.0).finalize().into(); + info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); + Ok(Self::Sha256(tls_server_end_point)) + } else { + error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding"); + Ok(Self::Undefined) } } - fn add_cert( + pub fn supported(&self) -> bool { + !matches!(self, TlsServerEndPoint::Undefined) + } +} + +#[derive(Default)] +pub struct CertResolver { + certs: HashMap, TlsServerEndPoint)>, + default: Option<(Arc, TlsServerEndPoint)>, +} + +impl CertResolver { + pub fn new() -> Self { + Self::default() + } + + fn add_cert_path( &mut self, key_path: &str, cert_path: &str, @@ -120,57 +186,65 @@ impl CertResolver { keys.pop().map(rustls::PrivateKey).unwrap() }; - let key = sign::any_supported_type(&priv_key).context("invalid private key")?; - let cert_chain_bytes = std::fs::read(cert_path) .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; let cert_chain = { rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .context(format!( + .with_context(|| { + format!( "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." - ))? + ) + })? .into_iter() .map(rustls::Certificate) .collect() }; - let common_name = { - let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes) - .context(format!( - "Failed to parse PEM object from bytes from file at '{cert_path}'." - ))? - .1; - let common_name = pem.parse_x509()?.subject().to_string(); + self.add_cert(priv_key, cert_chain, is_default) + } - // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as - // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so - // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names - // and passed None instead, which blows up number of cases downstream code should handle. Proper coding - // here should better avoid Option for common_names, and do wildcard-based certificate selection instead - // of cutting off '*.' parts. - if common_name.starts_with("CN=*.") { - common_name.strip_prefix("CN=*.").map(|s| s.to_string()) - } else { - common_name.strip_prefix("CN=").map(|s| s.to_string()) - } + pub fn add_cert( + &mut self, + priv_key: PrivateKey, + cert_chain: Vec, + is_default: bool, + ) -> anyhow::Result<()> { + let key = sign::any_supported_type(&priv_key).context("invalid private key")?; + + let first_cert = &cert_chain[0]; + let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; + let pem = x509_parser::parse_x509_certificate(&first_cert.0) + .context("Failed to parse PEM object from cerficiate")? + .1; + + let common_name = pem.subject().to_string(); + + // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as + // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so + // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names + // and passed None instead, which blows up number of cases downstream code should handle. Proper coding + // here should better avoid Option for common_names, and do wildcard-based certificate selection instead + // of cutting off '*.' parts. + let common_name = if common_name.starts_with("CN=*.") { + common_name.strip_prefix("CN=*.").map(|s| s.to_string()) + } else { + common_name.strip_prefix("CN=").map(|s| s.to_string()) } - .context(format!( - "Failed to parse common name from certificate at '{cert_path}'." - ))?; + .context("Failed to parse common name from certificate")?; let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); if is_default { - self.default = Some(cert.clone()); + self.default = Some((cert.clone(), tls_server_end_point)); } - self.certs.insert(common_name, cert); + self.certs.insert(common_name, (cert, tls_server_end_point)); Ok(()) } - fn get_common_names(&self) -> HashSet { + pub fn get_common_names(&self) -> HashSet { self.certs.keys().map(|s| s.to_string()).collect() } } @@ -178,15 +252,24 @@ impl CertResolver { impl rustls::server::ResolvesServerCert for CertResolver { fn resolve( &self, - _client_hello: rustls::server::ClientHello, + client_hello: rustls::server::ClientHello, ) -> Option> { + self.resolve(client_hello.server_name()).map(|x| x.0) + } +} + +impl CertResolver { + pub fn resolve( + &self, + server_name: Option<&str>, + ) -> Option<(Arc, TlsServerEndPoint)> { // loop here and cut off more and more subdomains until we find // a match to get a proper wildcard support. OTOH, we now do not // use nested domains, so keep this simple for now. // // With the current coding foo.com will match *.foo.com and that // repeats behavior of the old code. - if let Some(mut sni_name) = _client_hello.server_name() { + if let Some(mut sni_name) = server_name { loop { if let Some(cert) = self.certs.get(sni_name) { return Some(cert.clone()); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index adcb1bffaf..9560c8546a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -470,7 +470,17 @@ async fn handshake( if !read_buf.is_empty() { bail!("data is sent before server replied with EncryptionResponse"); } - stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?); + let tls_stream = raw.upgrade(tls.to_server_config()).await?; + + let (_, tls_server_end_point) = tls + .cert_resolver + .resolve(tls_stream.get_ref().1.server_name()) + .context("missing certificate")?; + + stream = PqStream::new(Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, + }); } } _ => bail!(ERR_PROTO_VIOLATION), @@ -875,7 +885,7 @@ pub async fn proxy_pass( /// Thin connection context. struct Client<'a, S> { /// The underlying libpq protocol stream. - stream: PqStream, + stream: PqStream>, /// Client credentials that we care about. creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, /// KV-dictionary with PostgreSQL connection params. @@ -889,7 +899,7 @@ struct Client<'a, S> { impl<'a, S> Client<'a, S> { /// Construct a new connection context. fn new( - stream: PqStream, + stream: PqStream>, creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, params: &'a StartupMessageParams, session_id: uuid::Uuid, diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3ae4df46ef..de9cc0800b 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -1,19 +1,23 @@ //! A group of high-level tests for connection establishing logic and auth. -//! + +mod mitm; + use super::*; use crate::auth::backend::TestBackend; use crate::auth::ClientCredentials; +use crate::config::CertResolver; use crate::console::{CachedNodeInfo, NodeInfo}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; use rstest::rstest; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; -use tokio_postgres_rustls::MakeRustlsConnect; +use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream}; /// Generate a set of TLS certificates: CA + server. fn generate_certs( hostname: &str, + common_name: &str, ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { let ca = rcgen::Certificate::from_params({ let mut params = rcgen::CertificateParams::default(); @@ -21,7 +25,15 @@ fn generate_certs( params })?; - let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?; + let cert = rcgen::Certificate::from_params({ + let mut params = rcgen::CertificateParams::new(vec![hostname.into()]); + params.distinguished_name = rcgen::DistinguishedName::new(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, common_name); + params + })?; + Ok(( rustls::Certificate(ca.serialize_der()?), rustls::Certificate(cert.serialize_der_with_signer(&ca)?), @@ -37,7 +49,14 @@ struct ClientConfig<'a> { impl ClientConfig<'_> { fn make_tls_connect( self, - ) -> anyhow::Result> { + ) -> anyhow::Result< + impl tokio_postgres::tls::TlsConnect< + S, + Error = impl std::fmt::Debug, + Future = impl Send, + Stream = RustlsStream, + >, + > { let mut mk = MakeRustlsConnect::new(self.config); let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; Ok(tls) @@ -49,20 +68,24 @@ fn generate_tls_config<'a>( hostname: &'a str, common_name: &'a str, ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> { - let (ca, cert, key) = generate_certs(hostname)?; + let (ca, cert, key) = generate_certs(hostname, common_name)?; let tls_config = { let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(vec![cert], key)? + .with_single_cert(vec![cert.clone()], key.clone())? .into(); - let common_names = Some([common_name.to_owned()].iter().cloned().collect()); + let mut cert_resolver = CertResolver::new(); + cert_resolver.add_cert(key, vec![cert], true)?; + + let common_names = Some(cert_resolver.get_common_names()); TlsConfig { config, common_names, + cert_resolver: Arc::new(cert_resolver), } }; @@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { )); let (_client, _conn) = tokio_postgres::Config::new() + .channel_binding(tokio_postgres::config::ChannelBinding::Require) .user("user") .dbname("db") .password(password) @@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { proxy.await? } +#[tokio::test] +async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = + generate_tls_config("generic-project-name.localhost", "localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new("password")?, + )); + + let (_client, _conn) = tokio_postgres::Config::new() + .channel_binding(tokio_postgres::config::ChannelBinding::Disable) + .user("user") + .dbname("db") + .password("password") + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await?; + + proxy.await? +} + #[tokio::test] async fn scram_auth_mock() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs new file mode 100644 index 0000000000..50b3034936 --- /dev/null +++ b/proxy/src/proxy/tests/mitm.rs @@ -0,0 +1,257 @@ +//! Man-in-the-middle tests +//! +//! Channel binding should prevent a proxy server +//! - that has access to create valid certificates - +//! from controlling the TLS connection. + +use std::fmt::Debug; + +use super::*; +use bytes::{Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use postgres_protocol::message::frontend; +use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio_postgres::config::SslMode; +use tokio_postgres::tls::TlsConnect; +use tokio_util::codec::{Decoder, Encoder}; + +enum Intercept { + None, + Methods, + SASLResponse, +} + +async fn proxy_mitm( + intercept: Intercept, +) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) { + let (end_server1, client1) = tokio::io::duplex(1024); + let (server2, end_client2) = tokio::io::duplex(1024); + + let (client_config1, server_config1) = + generate_tls_config("generic-project-name.localhost", "localhost").unwrap(); + let (client_config2, server_config2) = + generate_tls_config("generic-project-name.localhost", "localhost").unwrap(); + + tokio::spawn(async move { + // begin handshake with end_server + let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await; + // process handshake with end_client + let (end_client, startup) = + handshake(client1, Some(&server_config1), &CancelMap::default()) + .await + .unwrap() + .unwrap(); + + let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); + let (end_client, buf) = end_client.framed.into_inner(); + assert!(buf.is_empty()); + let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); + + // give the end_server the startup parameters + let mut buf = BytesMut::new(); + frontend::startup_message(startup.iter(), &mut buf).unwrap(); + end_server.send(buf.freeze()).await.unwrap(); + + // proxy messages between end_client and end_server + loop { + tokio::select! { + message = end_server.next() => { + match message { + Some(Ok(message)) => { + // intercept SASL and return only SCRAM-SHA-256 ;) + if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) { + end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap(); + continue; + } + end_client.send(message).await.unwrap() + } + _ => break, + } + } + message = end_client.next() => { + match message { + Some(Ok(message)) => { + // intercept SASL response and return SCRAM-SHA-256 with no channel binding ;) + if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") { + let sasl_message = &message[1+4+19+4..]; + let mut new_message = b"n,,".to_vec(); + new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap()); + + let mut buf = BytesMut::new(); + frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap(); + + end_server.send(buf.freeze()).await.unwrap(); + continue; + } + end_server.send(message).await.unwrap() + } + _ => break, + } + } + else => { break } + } + } + }); + + (end_server1, end_client2, client_config1, server_config2) +} + +/// taken from tokio-postgres +pub async fn connect_tls(mut stream: S, tls: T) -> T::Stream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + T::Error: Debug, +{ + let mut buf = BytesMut::new(); + frontend::ssl_request(&mut buf); + stream.write_all(&buf).await.unwrap(); + + let mut buf = [0]; + stream.read_exact(&mut buf).await.unwrap(); + + if buf[0] != b'S' { + panic!("ssl not supported by server"); + } + + tls.connect(stream).await.unwrap() +} + +struct PgFrame; +impl Decoder for PgFrame { + type Item = Bytes; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.len() < 5 { + src.reserve(5 - src.len()); + return Ok(None); + } + let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1; + if src.len() < len { + src.reserve(len - src.len()); + return Ok(None); + } + Ok(Some(src.split_to(len).freeze())) + } +} +impl Encoder for PgFrame { + type Error = io::Error; + + fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.extend_from_slice(&item); + Ok(()) + } +} + +/// If the client doesn't support channel bindings, it can be exploited. +#[tokio::test] +async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { + let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new("password")?, + )); + + let _client_err = tokio_postgres::Config::new() + .channel_binding(tokio_postgres::config::ChannelBinding::Disable) + .user("user") + .dbname("db") + .password("password") + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await?; + + proxy.await? +} + +/// If the client chooses SCRAM-PLUS, it will fail +#[tokio::test] +async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> { + connect_failure( + Intercept::None, + tokio_postgres::config::ChannelBinding::Prefer, + ) + .await +} + +/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail +#[tokio::test] +async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> { + connect_failure( + Intercept::Methods, + tokio_postgres::config::ChannelBinding::Prefer, + ) + .await +} + +/// If the MITM pretends like the client doesn't support channel bindings, it will fail +#[tokio::test] +async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> { + connect_failure( + Intercept::SASLResponse, + tokio_postgres::config::ChannelBinding::Prefer, + ) + .await +} + +/// If the client chooses SCRAM-PLUS, it will fail +#[tokio::test] +async fn scram_auth_require_channel_binding() -> anyhow::Result<()> { + connect_failure( + Intercept::None, + tokio_postgres::config::ChannelBinding::Require, + ) + .await +} + +/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail +#[tokio::test] +async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> { + connect_failure( + Intercept::Methods, + tokio_postgres::config::ChannelBinding::Require, + ) + .await +} + +/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail +#[tokio::test] +async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> { + connect_failure( + Intercept::SASLResponse, + tokio_postgres::config::ChannelBinding::Require, + ) + .await +} + +async fn connect_failure( + intercept: Intercept, + channel_binding: tokio_postgres::config::ChannelBinding, +) -> anyhow::Result<()> { + let (server, client, client_config, server_config) = proxy_mitm(intercept).await; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new("password")?, + )); + + let _client_err = tokio_postgres::Config::new() + .channel_binding(channel_binding) + .user("user") + .dbname("db") + .password("password") + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await + .err() + .context("client shouldn't be able to connect")?; + + let _server_err = proxy + .await? + .err() + .context("server shouldn't accept client")?; + + Ok(()) +} diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs index 776adabe55..13d681de6d 100644 --- a/proxy/src/sasl/channel_binding.rs +++ b/proxy/src/sasl/channel_binding.rs @@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> { impl ChannelBinding { /// Encode channel binding data as base64 for subsequent checks. - pub fn encode( + pub fn encode<'a, E>( &self, - get_cbind_data: impl FnOnce(&T) -> Result, + get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>, ) -> Result, E> { use ChannelBinding::*; Ok(match self { @@ -51,12 +51,11 @@ impl ChannelBinding { "eSws".into() } Required(mode) => { - let msg = format!( - "p={mode},,{data}", - mode = mode, - data = get_cbind_data(mode)? - ); - base64::encode(msg).into() + use std::io::Write; + let mut cbind_input = vec![]; + write!(&mut cbind_input, "p={mode},,",).unwrap(); + cbind_input.extend_from_slice(get_cbind_data(mode)?); + base64::encode(&cbind_input).into() } }) } @@ -77,7 +76,7 @@ mod tests { ]; for (cb, input) in cases { - assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input); + assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input); } Ok(()) diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 2de26af96b..63271309e1 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -22,9 +22,12 @@ pub use secret::ServerSecret; use hmac::{Hmac, Mac}; use sha2::{Digest, Sha256}; -// TODO: add SCRAM-SHA-256-PLUS +const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; + /// A list of supported SCRAM methods. -pub const METHODS: &[&str] = &["SCRAM-SHA-256"]; +pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256]; +pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; /// Decode base64 into array without any heap allocations fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { @@ -80,7 +83,11 @@ mod tests { const NONCE: [u8; 18] = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, ]; - let mut exchange = Exchange::new(&secret, || NONCE, None); + let mut exchange = Exchange::new( + &secret, + || NONCE, + crate::config::TlsServerEndPoint::Undefined, + ); let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO"; let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0="; diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 882769a70d..319d9b1014 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -5,9 +5,11 @@ use super::messages::{ }; use super::secret::ServerSecret; use super::signature::SignatureBuilder; +use crate::config; use crate::sasl::{self, ChannelBinding, Error as SaslError}; /// The only channel binding mode we currently support. +#[derive(Debug)] struct TlsServerEndPoint; impl std::fmt::Display for TlsServerEndPoint { @@ -43,20 +45,20 @@ pub struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], - cert_digest: Option<&'a [u8]>, + tls_server_end_point: config::TlsServerEndPoint, } impl<'a> Exchange<'a> { pub fn new( secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], - cert_digest: Option<&'a [u8]>, + tls_server_end_point: config::TlsServerEndPoint, ) -> Self { Self { state: ExchangeState::Initial, secret, nonce, - cert_digest, + tls_server_end_point, } } } @@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> { let client_first_message = ClientFirstMessage::parse(input) .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?; + // If the flag is set to "y" and the server supports channel + // binding, the server MUST fail authentication + if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer + && self.tls_server_end_point.supported() + { + return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used")); + } + let server_first_message = client_first_message.build_server_first_message( &(self.nonce)(), &self.secret.salt_base64, @@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> { let client_final_message = ClientFinalMessage::parse(input) .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; - let channel_binding = cbind_flag.encode(|_| { - self.cert_digest - .map(base64::encode) - .ok_or(SaslError::ChannelBindingFailed("no cert digest provided")) + let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point { + config::TlsServerEndPoint::Sha256(x) => Ok(x), + config::TlsServerEndPoint::Undefined => { + Err(SaslError::ChannelBindingFailed("no cert digest provided")) + } })?; // This might've been caused by a MITM attack diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 6210601a80..f48b3fe39f 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,7 +1,8 @@ +use crate::config::TlsServerEndPoint; use crate::error::UserFacingError; use anyhow::bail; use bytes::BytesMut; -use pin_project_lite::pin_project; + use pq_proto::framed::{ConnectionError, Framed}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; @@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - framed: Framed, + pub(crate) framed: Framed, } impl PqStream { @@ -118,19 +119,21 @@ impl PqStream { } } -pin_project! { - /// Wrapper for upgrading raw streams into secure streams. - /// NOTE: it should be possible to decompose this object as necessary. - #[project = StreamProj] - pub enum Stream { - /// We always begin with a raw stream, - /// which may then be upgraded into a secure stream. - Raw { #[pin] raw: S }, +/// Wrapper for upgrading raw streams into secure streams. +pub enum Stream { + /// We always begin with a raw stream, + /// which may then be upgraded into a secure stream. + Raw { raw: S }, + Tls { /// We box [`TlsStream`] since it can be quite large. - Tls { #[pin] tls: Box> }, - } + tls: Box>, + /// Channel binding parameter + tls_server_end_point: TlsServerEndPoint, + }, } +impl Unpin for Stream {} + impl Stream { /// Construct a new instance from a raw stream. pub fn from_raw(raw: S) -> Self { @@ -141,7 +144,17 @@ impl Stream { pub fn sni_hostname(&self) -> Option<&str> { match self { Stream::Raw { .. } => None, - Stream::Tls { tls } => tls.get_ref().1.server_name(), + Stream::Tls { tls, .. } => tls.get_ref().1.server_name(), + } + } + + pub fn tls_server_end_point(&self) -> TlsServerEndPoint { + match self { + Stream::Raw { .. } => TlsServerEndPoint::Undefined, + Stream::Tls { + tls_server_end_point, + .. + } => *tls_server_end_point, } } } @@ -158,12 +171,9 @@ pub enum StreamUpgradeError { impl Stream { /// If possible, upgrade raw stream into a secure TLS-based stream. - pub async fn upgrade(self, cfg: Arc) -> Result { + pub async fn upgrade(self, cfg: Arc) -> Result, StreamUpgradeError> { match self { - Stream::Raw { raw } => { - let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?); - Ok(Stream::Tls { tls }) - } + Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?), Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } } @@ -171,50 +181,46 @@ impl Stream { impl AsyncRead for Stream { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, context: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> task::Poll> { - use StreamProj::*; - match self.project() { - Raw { raw } => raw.poll_read(context, buf), - Tls { tls } => tls.poll_read(context, buf), + match &mut *self { + Self::Raw { raw } => Pin::new(raw).poll_read(context, buf), + Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf), } } } impl AsyncWrite for Stream { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, context: &mut task::Context<'_>, buf: &[u8], ) -> task::Poll> { - use StreamProj::*; - match self.project() { - Raw { raw } => raw.poll_write(context, buf), - Tls { tls } => tls.poll_write(context, buf), + match &mut *self { + Self::Raw { raw } => Pin::new(raw).poll_write(context, buf), + Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf), } } fn poll_flush( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, context: &mut task::Context<'_>, ) -> task::Poll> { - use StreamProj::*; - match self.project() { - Raw { raw } => raw.poll_flush(context), - Tls { tls } => tls.poll_flush(context), + match &mut *self { + Self::Raw { raw } => Pin::new(raw).poll_flush(context), + Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context), } } fn poll_shutdown( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, context: &mut task::Context<'_>, ) -> task::Poll> { - use StreamProj::*; - match self.project() { - Raw { raw } => raw.poll_shutdown(context), - Tls { tls } => tls.poll_shutdown(context), + match &mut *self { + Self::Raw { raw } => Pin::new(raw).poll_shutdown(context), + Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context), } } }