diff --git a/Cargo.lock b/Cargo.lock index dcef66c15d..824cac13b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4259,6 +4259,7 @@ dependencies = [ "smallvec", "smol_str", "socket2 0.5.5", + "subtle", "sync_wrapper", "task-local-extensions", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 0f3dbd4987..44e6ec9744 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,7 @@ smol_str = { version = "0.2.0", features = ["serde"] } socket2 = "0.5" strum = "0.24" strum_macros = "0.24" +"subtle" = "2.5.0" svg_fmt = "0.4.1" sync_wrapper = "0.1.2" tar = "0.4" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index b3a5bf873e..93a1fe85db 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -63,6 +63,7 @@ sha2 = { workspace = true, features = ["asm"] } smol_str.workspace = true smallvec.workspace = true socket2.workspace = true +subtle.workspace = true sync_wrapper.workspace = true task-local-extensions.workspace = true thiserror.workspace = true diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index f26dcb7c9a..45bbad8cb2 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -194,14 +194,7 @@ pub(crate) async fn validate_password_and_exchange( } // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256}; - let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported()); - let outcome = crate::scram::exchange( - &scram_secret, - sasl_client, - crate::config::TlsServerEndPoint::Undefined, - ) - .await?; + let outcome = crate::scram::exchange(&scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 1cf8b53e11..0811416ca2 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -33,6 +33,9 @@ pub enum Error { #[error("Internal error: missing digest")] MissingBinding, + #[error("could not decode salt: {0}")] + Base64(#[from] base64::DecodeError), + #[error(transparent)] Io(#[from] io::Error), } @@ -55,6 +58,7 @@ impl ReportableError for Error { Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, Error::MissingBinding => crate::error::ErrorKind::Service, + Error::Base64(_) => crate::error::ErrorKind::ControlPlane, Error::Io(_) => crate::error::ErrorKind::ClientDisconnect, } } diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 76541ae2f3..ed80675f8a 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -56,8 +56,6 @@ fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { #[cfg(test)] mod tests { - use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256}; - use crate::sasl::{Mechanism, Step}; use super::{Exchange, ServerSecret}; @@ -115,16 +113,9 @@ mod tests { async fn run_round_trip_test(server_password: &str, client_password: &str) { let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let sasl_client = - ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported()); - - let outcome = super::exchange( - &scram_secret, - sasl_client, - crate::config::TlsServerEndPoint::Undefined, - ) - .await - .unwrap(); + let outcome = super::exchange(&scram_secret, client_password.as_bytes()) + .await + .unwrap(); match outcome { crate::sasl::Outcome::Success(_) => {} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 51c0ba4e09..682cbe795f 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -2,7 +2,11 @@ use std::convert::Infallible; -use postgres_protocol::authentication::sasl::ScramSha256; +use hmac::{Hmac, Mac}; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; +use subtle::{Choice, ConstantTimeEq}; +use tokio::task::yield_now; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, @@ -71,40 +75,71 @@ impl<'a> Exchange<'a> { } } +// copied from +async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { + let hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + let mut prev = hmac + .clone() + .chain_update(salt) + .chain_update(1u32.to_be_bytes()) + .finalize() + .into_bytes(); + + let mut hi = prev; + + for i in 1..iterations { + prev = hmac.clone().chain_update(prev).finalize().into_bytes(); + + for (hi, prev) in hi.iter_mut().zip(prev) { + *hi ^= prev; + } + // yield every ~250us + // hopefully reduces tail latencies + if i % 1024 == 0 { + yield_now().await + } + } + + hi.into() +} + +// copied from +async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], [u8; 32]) { + let salted_password = pbkdf2(password, salt, iterations).await; + + let make_key = |name| { + let key = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes") + .chain_update(name) + .finalize(); + + <[u8; 32]>::from(key.into_bytes()) + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) +} + pub async fn exchange( secret: &ServerSecret, - mut client: ScramSha256, - tls_server_end_point: config::TlsServerEndPoint, + password: &[u8], ) -> sasl::Result> { - use sasl::Step::*; + let salt = base64::decode(&secret.salt_base64)?; + let (client_key, server_key) = derive_keys(password, &salt, secret.iterations).await; + let stored_key: [u8; 32] = Sha256::default() + .chain_update(client_key) + .finalize_fixed() + .into(); - let init = SaslInitial { - nonce: rand::random, - }; + // constant time to not leak partial key match + let valid = stored_key.ct_eq(&secret.stored_key.as_bytes()) + | server_key.ct_eq(&secret.server_key.as_bytes()) + | Choice::from(secret.doomed as u8); - let client_first = std::str::from_utf8(client.message()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - let sent = match init.transition(secret, &tls_server_end_point, client_first)? { - Continue(sent, server_first) => { - client.update(server_first.as_bytes()).await?; - sent - } - Success(x, _) => match x {}, - Failure(msg) => return Ok(sasl::Outcome::Failure(msg)), - }; - - let client_final = std::str::from_utf8(client.message()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - let keys = match sent.transition(secret, &tls_server_end_point, client_final)? { - Success(keys, server_final) => { - client.finish(server_final.as_bytes())?; - keys - } - Continue(x, _) => match x {}, - Failure(msg) => return Ok(sasl::Outcome::Failure(msg)), - }; - - Ok(sasl::Outcome::Success(keys)) + if valid.into() { + Ok(sasl::Outcome::Success(super::ScramKey::from(client_key))) + } else { + Ok(sasl::Outcome::Failure("password doesn't match")) + } } impl SaslInitial {