proxy: unit tests for auth_quirks (#7199)

## Problem

I noticed code coverage for auth_quirks was pretty bare

## Summary of changes

Adds 3 happy path unit tests for auth_quirks
* scram
* cleartext (websockets)
* cleartext (password hack)
This commit is contained in:
Conrad Ludgate
2024-03-22 13:31:10 +00:00
committed by GitHub
parent 62b318c928
commit 77f3a30440
13 changed files with 285 additions and 38 deletions

1
Cargo.lock generated
View File

@@ -4237,6 +4237,7 @@ dependencies = [
"consumption_metrics",
"dashmap",
"env_logger",
"fallible-iterator",
"futures",
"git-version",
"hashbrown 0.13.2",

View File

@@ -79,6 +79,7 @@ either = "1.8"
enum-map = "2.4.2"
enumset = "1.0.12"
fail = "0.5.0"
fallible-iterator = "0.2"
fs2 = "0.4.3"
futures = "0.3"
futures-core = "0.3"

View File

@@ -97,6 +97,7 @@ workspace_hack.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -408,3 +408,228 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
};
use provider::AuthSecret;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::{
auth::{ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
console::{
self,
provider::{self, CachedAllowedIps, CachedRoleSecret},
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
scram::ServerSecret,
stream::{PqStream, Stream},
};
use super::auth_quirks;
struct Auth {
ips: Vec<IpPattern>,
secret: AuthSecret,
}
impl console::Api for Auth {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
{
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
))
}
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
unimplemented!()
}
}
static CONFIG: &AuthenticationConfig = &AuthenticationConfig {
scram_protocol_timeout: std::time::Duration::from_secs(5),
};
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
loop {
r.read_buf(&mut *b).await.unwrap();
if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
break m;
}
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
let mut read = BytesMut::new();
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-first-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslContinue(a) => {
scram.update(a.data()).await.unwrap();
}
_ => panic!("wrong message"),
}
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-final-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslFinal(a) => {
scram.finish(a.data()).unwrap();
}
_ => panic!("wrong message"),
}
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
let mut write = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
client.write_all(&write).await.unwrap();
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: None,
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
client.write_all(&write).await.unwrap();
});
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}
}

View File

@@ -82,14 +82,13 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
#[repr(transparent)]
#[derive(Clone, Default)]
pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub fn new() -> Self {
Self(Default::default())
Self::default()
}
/// Reuse password or auth keys from the other config.
@@ -165,12 +164,6 @@ impl std::ops::DerefMut for ConnCfg {
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {

View File

@@ -6,7 +6,7 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
/// Various cache-related types.
pub mod caches {

View File

@@ -14,7 +14,6 @@ use crate::{
context::RequestMonitoring,
scram, EndpointCacheKey, ProjectId,
};
use async_trait::async_trait;
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
@@ -326,8 +325,7 @@ pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPatt
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
pub(crate) trait Api {
/// Get the client's auth secret for authentication.
/// Returns option because user not found situation is special.
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
@@ -363,7 +361,6 @@ pub enum ConsoleBackend {
Test(Box<dyn crate::auth::backend::TestBackend>),
}
#[async_trait]
impl Api for ConsoleBackend {
async fn get_role_secret(
&self,

View File

@@ -8,7 +8,6 @@ use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
@@ -144,7 +143,6 @@ async fn get_execute_postgres_query(
Ok(Some(entry))
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -14,7 +14,6 @@ use crate::{
context::RequestMonitoring,
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::time::Instant;
@@ -168,7 +167,6 @@ impl Api {
}
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -3,9 +3,7 @@
use std::convert::Infallible;
use hmac::{Hmac, Mac};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use subtle::{Choice, ConstantTimeEq};
use sha2::Sha256;
use tokio::task::yield_now;
use super::messages::{
@@ -13,6 +11,7 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::ScramKey;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -104,7 +103,7 @@ async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], [u8; 32]) {
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
let salted_password = pbkdf2(password, salt, iterations).await;
let make_key = |name| {
@@ -116,7 +115,7 @@ async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32]
<[u8; 32]>::from(key.into_bytes())
};
(make_key(b"Client Key"), make_key(b"Server Key"))
make_key(b"Client Key").into()
}
pub async fn exchange(
@@ -124,21 +123,12 @@ pub async fn exchange(
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let (client_key, server_key) = derive_keys(password, &salt, secret.iterations).await;
let stored_key: [u8; 32] = Sha256::default()
.chain_update(client_key)
.finalize_fixed()
.into();
let client_key = derive_client_key(password, &salt, secret.iterations).await;
// constant time to not leak partial key match
let valid = stored_key.ct_eq(&secret.stored_key.as_bytes())
| server_key.ct_eq(&secret.server_key.as_bytes())
| Choice::from(secret.doomed as u8);
if valid.into() {
Ok(sasl::Outcome::Success(super::ScramKey::from(client_key)))
} else {
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
} else {
Ok(sasl::Outcome::Success(client_key))
}
}
@@ -220,7 +210,7 @@ impl SaslSentInner {
.derive_client_key(&client_final_message.proof);
// Auth fails either if keys don't match or it's pre-determined to fail.
if client_key.sha256() != secret.stored_key || secret.doomed {
if secret.is_password_invalid(&client_key).into() {
return Ok(sasl::Step::Failure("password doesn't match"));
}

View File

@@ -1,17 +1,31 @@
//! Tools for client/server/stored key management.
use subtle::ConstantTimeEq;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, PartialEq, Eq, Debug)]
#[derive(Clone, Default, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl ConstantTimeEq for ScramKey {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.bytes.ct_eq(&other.bytes)
}
}
impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()

View File

@@ -206,6 +206,28 @@ mod tests {
}
}
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none())
}
#[test]
fn parse_client_first_message_with_extra_params() {
let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.username, "user");
assert_eq!(msg.nonce, "nonce");
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
}
#[test]
fn parse_client_first_message_with_extra_params_invalid() {
// must be of the form `<ascii letter>=<...>`
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none());
}
#[test]
fn parse_client_final_message() {
let input = [

View File

@@ -1,5 +1,7 @@
//! Tools for SCRAM server secret management.
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
use super::key::ScramKey;
@@ -40,6 +42,11 @@ impl ServerSecret {
Some(secret)
}
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
// constant time to not leak partial key match
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
}
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.