From 77f3a30440aba4845da3a5203a2764fed4d96648 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 22 Mar 2024 13:31:10 +0000 Subject: [PATCH] proxy: unit tests for auth_quirks (#7199) ## Problem I noticed code coverage for auth_quirks was pretty bare ## Summary of changes Adds 3 happy path unit tests for auth_quirks * scram * cleartext (websockets) * cleartext (password hack) --- Cargo.lock | 1 + Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/auth/backend.rs | 225 +++++++++++++++++++++++++++++ proxy/src/compute.rs | 11 +- proxy/src/console.rs | 2 +- proxy/src/console/provider.rs | 5 +- proxy/src/console/provider/mock.rs | 2 - proxy/src/console/provider/neon.rs | 2 - proxy/src/scram/exchange.rs | 28 ++-- proxy/src/scram/key.rs | 16 +- proxy/src/scram/messages.rs | 22 +++ proxy/src/scram/secret.rs | 7 + 13 files changed, 285 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dcf1c49924..6409c79ef9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4237,6 +4237,7 @@ dependencies = [ "consumption_metrics", "dashmap", "env_logger", + "fallible-iterator", "futures", "git-version", "hashbrown 0.13.2", diff --git a/Cargo.toml b/Cargo.toml index 2741bd046b..4dda63ff58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" fail = "0.5.0" +fallible-iterator = "0.2" fs2 = "0.4.3" futures = "0.3" futures-core = "0.3" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3566d8b728..57a2736d5b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -97,6 +97,7 @@ workspace_hack.workspace = true [dev-dependencies] camino-tempfile.workspace = true +fallible-iterator.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index bc307230dd..04fe83d8eb 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -408,3 +408,228 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use bytes::BytesMut; + use fallible_iterator::FallibleIterator; + use postgres_protocol::{ + authentication::sasl::{ChannelBinding, ScramSha256}, + message::{backend::Message as PgMessage, frontend}, + }; + use provider::AuthSecret; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + + use crate::{ + auth::{ComputeUserInfoMaybeEndpoint, IpPattern}, + config::AuthenticationConfig, + console::{ + self, + provider::{self, CachedAllowedIps, CachedRoleSecret}, + CachedNodeInfo, + }, + context::RequestMonitoring, + proxy::NeonOptions, + scram::ServerSecret, + stream::{PqStream, Stream}, + }; + + use super::auth_quirks; + + struct Auth { + ips: Vec, + secret: AuthSecret, + } + + impl console::Api for Auth { + async fn get_role_secret( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result { + Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + } + + async fn get_allowed_ips_and_secret( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError> + { + Ok(( + CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())), + Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))), + )) + } + + async fn wake_compute( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result { + unimplemented!() + } + } + + static CONFIG: &AuthenticationConfig = &AuthenticationConfig { + scram_protocol_timeout: std::time::Duration::from_secs(5), + }; + + async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { + loop { + r.read_buf(&mut *b).await.unwrap(); + if let Some(m) = PgMessage::parse(&mut *b).unwrap() { + break m; + } + } + } + + #[tokio::test] + async fn auth_quirks_scram() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: Some("endpoint".into()), + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported()); + + let mut read = BytesMut::new(); + + // server should offer scram + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSasl(a) => { + let options: Vec<&str> = a.mechanisms().collect().unwrap(); + assert_eq!(options, ["SCRAM-SHA-256"]); + } + _ => panic!("wrong message"), + } + + // client sends client-first-message + let mut write = BytesMut::new(); + frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + + // server response with server-first-message + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSaslContinue(a) => { + scram.update(a.data()).await.unwrap(); + } + _ => panic!("wrong message"), + } + + // client response with client-final-message + write.clear(); + frontend::sasl_response(scram.message(), &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + + // server response with server-final-message + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSaslFinal(a) => { + scram.finish(a.data()).unwrap(); + } + _ => panic!("wrong message"), + } + }); + + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG) + .await + .unwrap(); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn auth_quirks_cleartext() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: Some("endpoint".into()), + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut read = BytesMut::new(); + let mut write = BytesMut::new(); + + // server should offer cleartext + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationCleartextPassword => {} + _ => panic!("wrong message"), + } + + // client responds with password + write.clear(); + frontend::password_message(b"my-secret-password", &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + }); + + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + .await + .unwrap(); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn auth_quirks_password_hack() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: None, + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut read = BytesMut::new(); + + // server should offer cleartext + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationCleartextPassword => {} + _ => panic!("wrong message"), + } + + // client responds with password + let mut write = BytesMut::new(); + frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write) + .unwrap(); + client.write_all(&write).await.unwrap(); + }); + + let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + .await + .unwrap(); + + assert_eq!(creds.info.endpoint, "my-endpoint"); + + handle.await.unwrap(); + } +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index b61c1fb9ef..65153babcb 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -82,14 +82,13 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. -#[derive(Clone)] -#[repr(transparent)] +#[derive(Clone, Default)] pub struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { pub fn new() -> Self { - Self(Default::default()) + Self::default() } /// Reuse password or auth keys from the other config. @@ -165,12 +164,6 @@ impl std::ops::DerefMut for ConnCfg { } } -impl Default for ConnCfg { - fn default() -> Self { - Self::new() - } -} - impl ConnCfg { /// Establish a raw TCP connection to the compute node. async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { diff --git a/proxy/src/console.rs b/proxy/src/console.rs index fd3c46b946..ea95e83437 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -6,7 +6,7 @@ pub mod messages; /// Wrappers for console APIs and their mocks. pub mod provider; -pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; +pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; /// Various cache-related types. pub mod caches { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 8609606273..69bfd6b045 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -14,7 +14,6 @@ use crate::{ context::RequestMonitoring, scram, EndpointCacheKey, ProjectId, }; -use async_trait::async_trait; use dashmap::DashMap; use std::{sync::Arc, time::Duration}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; @@ -326,8 +325,7 @@ pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc), } -#[async_trait] impl Api for ConsoleBackend { async fn get_role_secret( &self, diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 0579ef6fc4..b759c81373 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -8,7 +8,6 @@ use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; use crate::context::RequestMonitoring; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use crate::{auth::IpPattern, cache::Cached}; -use async_trait::async_trait; use futures::TryFutureExt; use std::{str::FromStr, sync::Arc}; use thiserror::Error; @@ -144,7 +143,6 @@ async fn get_execute_postgres_query( Ok(Some(entry)) } -#[async_trait] impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index b36663518d..89ebfa57f1 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -14,7 +14,6 @@ use crate::{ context::RequestMonitoring, metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, }; -use async_trait::async_trait; use futures::TryFutureExt; use std::sync::Arc; use tokio::time::Instant; @@ -168,7 +167,6 @@ impl Api { } } -#[async_trait] impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 682cbe795f..89dd33e59f 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -3,9 +3,7 @@ use std::convert::Infallible; use hmac::{Hmac, Mac}; -use sha2::digest::FixedOutput; -use sha2::{Digest, Sha256}; -use subtle::{Choice, ConstantTimeEq}; +use sha2::Sha256; use tokio::task::yield_now; use super::messages::{ @@ -13,6 +11,7 @@ use super::messages::{ }; use super::secret::ServerSecret; use super::signature::SignatureBuilder; +use super::ScramKey; use crate::config; use crate::sasl::{self, ChannelBinding, Error as SaslError}; @@ -104,7 +103,7 @@ async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { } // copied from -async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], [u8; 32]) { +async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey { let salted_password = pbkdf2(password, salt, iterations).await; let make_key = |name| { @@ -116,7 +115,7 @@ async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32] <[u8; 32]>::from(key.into_bytes()) }; - (make_key(b"Client Key"), make_key(b"Server Key")) + make_key(b"Client Key").into() } pub async fn exchange( @@ -124,21 +123,12 @@ pub async fn exchange( password: &[u8], ) -> sasl::Result> { let salt = base64::decode(&secret.salt_base64)?; - let (client_key, server_key) = derive_keys(password, &salt, secret.iterations).await; - let stored_key: [u8; 32] = Sha256::default() - .chain_update(client_key) - .finalize_fixed() - .into(); + let client_key = derive_client_key(password, &salt, secret.iterations).await; - // constant time to not leak partial key match - let valid = stored_key.ct_eq(&secret.stored_key.as_bytes()) - | server_key.ct_eq(&secret.server_key.as_bytes()) - | Choice::from(secret.doomed as u8); - - if valid.into() { - Ok(sasl::Outcome::Success(super::ScramKey::from(client_key))) - } else { + if secret.is_password_invalid(&client_key).into() { Ok(sasl::Outcome::Failure("password doesn't match")) + } else { + Ok(sasl::Outcome::Success(client_key)) } } @@ -220,7 +210,7 @@ impl SaslSentInner { .derive_client_key(&client_final_message.proof); // Auth fails either if keys don't match or it's pre-determined to fail. - if client_key.sha256() != secret.stored_key || secret.doomed { + if secret.is_password_invalid(&client_key).into() { return Ok(sasl::Step::Failure("password doesn't match")); } diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index 973126e729..32a3dbd203 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,17 +1,31 @@ //! Tools for client/server/stored key management. +use subtle::ConstantTimeEq; + /// Faithfully taken from PostgreSQL. pub const SCRAM_KEY_LEN: usize = 32; /// One of the keys derived from the user's password. /// We use the same structure for all keys, i.e. /// `ClientKey`, `StoredKey`, and `ServerKey`. -#[derive(Clone, Default, PartialEq, Eq, Debug)] +#[derive(Clone, Default, Eq, Debug)] #[repr(transparent)] pub struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl PartialEq for ScramKey { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl ConstantTimeEq for ScramKey { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.bytes.ct_eq(&other.bytes) + } +} + impl ScramKey { pub fn sha256(&self) -> Self { super::sha256([self.as_ref()]).into() diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index b59baec508..f9372540ca 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -206,6 +206,28 @@ mod tests { } } + #[test] + fn parse_client_first_message_with_invalid_gs2_authz() { + assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none()) + } + + #[test] + fn parse_client_first_message_with_extra_params() { + let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap(); + assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz"); + assert_eq!(msg.username, "user"); + assert_eq!(msg.nonce, "nonce"); + assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient); + } + + #[test] + fn parse_client_first_message_with_extra_params_invalid() { + // must be of the form `=<...>` + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none()); + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none()); + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none()); + } + #[test] fn parse_client_final_message() { let input = [ diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index b46d8c3ab5..f3414cb8ec 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -1,5 +1,7 @@ //! Tools for SCRAM server secret management. +use subtle::{Choice, ConstantTimeEq}; + use super::base64_decode_array; use super::key::ScramKey; @@ -40,6 +42,11 @@ impl ServerSecret { Some(secret) } + pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { + // constant time to not leak partial key match + client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8) + } + /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details.