mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-15 20:20:38 +00:00
Merge branch 'main' into bojan-get-page-tests
This commit is contained in:
@@ -1,14 +1,24 @@
|
||||
mod credentials;
|
||||
|
||||
#[cfg(test)]
|
||||
mod flow;
|
||||
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::cplane_api::{self, CPlaneApi};
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use crate::waiters;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use flow::*;
|
||||
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
@@ -16,13 +26,17 @@ pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
Console(#[from] cplane_api::AuthError),
|
||||
|
||||
#[cfg(test)]
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
/// For passwords that couldn't be processed by [`parse_password`].
|
||||
#[error("Malformed password message")]
|
||||
MalformedPassword,
|
||||
|
||||
/// Errors produced by [`PqStream`].
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl AuthErrorImpl {
|
||||
@@ -67,70 +81,6 @@ impl UserFacingError for AuthError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet")]
|
||||
MissingKey(&'static str),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
fail::fail_point!("proxy-authenticate", |_| {
|
||||
Err(AuthError::auth_failed("failpoint triggered"))
|
||||
});
|
||||
|
||||
use crate::config::ClientAuthMethod::*;
|
||||
use crate::config::RouterConfig::*;
|
||||
match &config.router_config {
|
||||
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
|
||||
Dynamic(Mixed) => {
|
||||
if self.user.ends_with("@zenith") {
|
||||
handle_existing_user(config, client, self).await
|
||||
} else {
|
||||
handle_new_user(config, client).await
|
||||
}
|
||||
}
|
||||
Dynamic(Password) => handle_existing_user(config, client, self).await,
|
||||
Dynamic(Link) => handle_new_user(config, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
async fn handle_static(
|
||||
host: String,
|
||||
port: u16,
|
||||
@@ -169,7 +119,7 @@ async fn handle_existing_user(
|
||||
let md5_salt = rand::random();
|
||||
|
||||
client
|
||||
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
|
||||
.write_message(&Be::AuthenticationMD5Password(md5_salt))
|
||||
.await?;
|
||||
|
||||
// Read client's password hash
|
||||
@@ -213,6 +163,10 @@ async fn handle_new_user(
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
fn parse_password(bytes: &[u8]) -> Option<&str> {
|
||||
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
|
||||
}
|
||||
|
||||
70
proxy/src/auth/credentials.rs
Normal file
70
proxy/src/auth/credentials.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use super::AuthError;
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet")]
|
||||
MissingKey(&'static str),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
fail::fail_point!("proxy-authenticate", |_| {
|
||||
Err(AuthError::auth_failed("failpoint triggered"))
|
||||
});
|
||||
|
||||
use crate::config::ClientAuthMethod::*;
|
||||
use crate::config::RouterConfig::*;
|
||||
match &config.router_config {
|
||||
Static { host, port } => super::handle_static(host.clone(), *port, client, self).await,
|
||||
Dynamic(Mixed) => {
|
||||
if self.user.ends_with("@zenith") {
|
||||
super::handle_existing_user(config, client, self).await
|
||||
} else {
|
||||
super::handle_new_user(config, client).await
|
||||
}
|
||||
}
|
||||
Dynamic(Password) => super::handle_existing_user(config, client, self).await,
|
||||
Dynamic(Link) => super::handle_new_user(config, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
102
proxy/src/auth/flow.rs
Normal file
102
proxy/src/auth/flow.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthError, AuthErrorImpl};
|
||||
use crate::stream::PqStream;
|
||||
use crate::{sasl, scram};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
}
|
||||
}
|
||||
|
||||
/// Use password-based auth in [`AuthFlow`].
|
||||
pub struct Md5(
|
||||
/// Salt for client.
|
||||
pub [u8; 4],
|
||||
);
|
||||
|
||||
impl AuthMethod for Md5 {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationMD5Password(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
/// State might contain ancillary data (see [`AuthFlow::begin`]).
|
||||
state: State,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<S>) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream.write_message(&method.first_message()).await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling simple MD5 password auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Md5> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
#[allow(unused)]
|
||||
pub async fn authenticate(self) -> Result<(), AuthError> {
|
||||
unimplemented!("MD5 auth flow is yet to be implemented");
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> Result<(), AuthError> {
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(AuthErrorImpl::auth_failed("method not supported").into());
|
||||
}
|
||||
|
||||
let secret = self.state.0;
|
||||
sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(secret, rand::random, None))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,8 @@
|
||||
///
|
||||
/// Postgres protocol proxy/router.
|
||||
///
|
||||
/// This service listens psql port and can check auth via external service
|
||||
/// (control plane API in our case) and can create new databases and accounts
|
||||
/// in somewhat transparent manner (again via communication with control plane API).
|
||||
///
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{Arg, Command};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
use crate::config::{ClientAuthMethod, RouterConfig};
|
||||
//! Postgres protocol proxy/router.
|
||||
//!
|
||||
//! This service listens psql port and can check auth via external service
|
||||
//! (control plane API in our case) and can create new databases and accounts
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
|
||||
mod auth;
|
||||
mod cancellation;
|
||||
@@ -27,6 +16,24 @@ mod proxy;
|
||||
mod stream;
|
||||
mod waiters;
|
||||
|
||||
// Currently SCRAM is only used in tests
|
||||
#[cfg(test)]
|
||||
mod parse;
|
||||
#[cfg(test)]
|
||||
mod sasl;
|
||||
#[cfg(test)]
|
||||
mod scram;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{Arg, Command};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
use crate::config::{ClientAuthMethod, RouterConfig};
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
async fn flatten_err(
|
||||
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
|
||||
|
||||
18
proxy/src/parse.rs
Normal file
18
proxy/src/parse.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! Small parsing helpers.
|
||||
|
||||
use std::convert::TryInto;
|
||||
use std::ffi::CStr;
|
||||
|
||||
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
|
||||
let pos = bytes.iter().position(|&x| x == 0)?;
|
||||
let (cstr, other) = bytes.split_at(pos + 1);
|
||||
// SAFETY: we've already checked that there's a terminator
|
||||
Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other))
|
||||
}
|
||||
|
||||
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
|
||||
(bytes.len() >= N).then(|| {
|
||||
let (head, tail) = bytes.split_at(N);
|
||||
(head.try_into().unwrap(), tail)
|
||||
})
|
||||
}
|
||||
@@ -119,7 +119,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
@@ -219,32 +218,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use tokio::io::DuplexStream;
|
||||
use crate::{auth, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tls: Option<TlsConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
// TODO: add some infra + tests for credentials
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("no stream")?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
@@ -262,19 +243,115 @@ mod tests {
|
||||
))
|
||||
}
|
||||
|
||||
struct ClientConfig<'a> {
|
||||
config: rustls::ClientConfig,
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
self,
|
||||
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate TLS certificates and build rustls configs for client and server.
|
||||
fn generate_tls_config(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(ClientConfig<'_>, Arc<rustls::ServerConfig>)> {
|
||||
let (ca, cert, key) = generate_certs(hostname)?;
|
||||
|
||||
let server_config = {
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config.into()
|
||||
};
|
||||
|
||||
let client_config = {
|
||||
let mut config = rustls::ClientConfig::new();
|
||||
config.root_store.add(&ca)?;
|
||||
ClientConfig { config, hostname }
|
||||
};
|
||||
|
||||
Ok((client_config, server_config))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
_stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct NoAuth;
|
||||
impl TestAuth for NoAuth {}
|
||||
|
||||
struct Scram(scram::ServerSecret);
|
||||
|
||||
impl Scram {
|
||||
fn new(password: &str) -> anyhow::Result<Self> {
|
||||
let salt = rand::random::<[u8; 16]>();
|
||||
let secret = scram::ServerSecret::build(password, &salt, 256)
|
||||
.context("failed to generate scram secret")?;
|
||||
Ok(Scram(secret))
|
||||
}
|
||||
|
||||
fn mock(user: &str) -> Self {
|
||||
let salt = rand::random::<[u8; 32]>();
|
||||
Scram(scram::ServerSecret::mock(user, &salt))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A dummy proxy impl which performs a handshake and reports auth success.
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
tls: Option<TlsConfig>,
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("handshake failed")?;
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let server_config = {
|
||||
let (_ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
let (_, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
@@ -301,30 +378,14 @@ mod tests {
|
||||
async fn handshake_tls() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let server_config = {
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
|
||||
let client_config = {
|
||||
let mut config = rustls::ClientConfig::new();
|
||||
config.root_store.add(&ca)?;
|
||||
config
|
||||
};
|
||||
|
||||
let mut mk = MakeRustlsConnect::new(client_config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, tls)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -334,7 +395,7 @@ mod tests {
|
||||
async fn handshake_raw() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None));
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
@@ -350,7 +411,7 @@ mod tests {
|
||||
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None));
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.ssl_mode(SslMode::Disable)
|
||||
@@ -391,4 +452,66 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case("password_foo")]
|
||||
#[case("pwd-bar")]
|
||||
#[case("")]
|
||||
#[tokio::test]
|
||||
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new(password)?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::mock("user"),
|
||||
));
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
let password: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(rand::random::<u8>() as usize)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
47
proxy/src/sasl.rs
Normal file
47
proxy/src/sasl.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Simple Authentication and Security Layer.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
|
||||
|
||||
mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use channel_binding::ChannelBinding;
|
||||
pub use messages::FirstMessage;
|
||||
pub use stream::SaslStream;
|
||||
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Failed to authenticate client: {0}")]
|
||||
AuthenticationFailed(&'static str),
|
||||
|
||||
#[error("Channel binding failed: {0}")]
|
||||
ChannelBindingFailed(&'static str),
|
||||
|
||||
#[error("Unsupported channel binding method: {0}")]
|
||||
ChannelBindingBadMethod(Box<str>),
|
||||
|
||||
#[error("Bad client message")]
|
||||
BadClientMessage,
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
|
||||
pub trait Mechanism: Sized {
|
||||
/// Produce a server challenge to be sent to the client.
|
||||
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
|
||||
fn exchange(self, input: &str) -> Result<(Option<Self>, String)>;
|
||||
}
|
||||
85
proxy/src/sasl/channel_binding.rs
Normal file
85
proxy/src/sasl/channel_binding.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Definition and parser for channel binding flag (a part of the `GS2` header).
|
||||
|
||||
/// Channel binding flag (possibly with params).
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ChannelBinding<T> {
|
||||
/// Client doesn't support channel binding.
|
||||
NotSupportedClient,
|
||||
/// Client thinks server doesn't support channel binding.
|
||||
NotSupportedServer,
|
||||
/// Client wants to use this type of channel binding.
|
||||
Required(T),
|
||||
}
|
||||
|
||||
impl<T> ChannelBinding<T> {
|
||||
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => NotSupportedClient,
|
||||
NotSupportedServer => NotSupportedServer,
|
||||
Required(x) => Required(f(x)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ChannelBinding<&'a str> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
use ChannelBinding::*;
|
||||
Some(match input {
|
||||
"n" => NotSupportedClient,
|
||||
"y" => NotSupportedServer,
|
||||
other => Required(other.strip_prefix("p=")?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
/// Encode channel binding data as base64 for subsequent checks.
|
||||
pub fn encode<E>(
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => {
|
||||
// base64::encode("n,,")
|
||||
"biws".into()
|
||||
}
|
||||
NotSupportedServer => {
|
||||
// base64::encode("y,,")
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
let msg = format!(
|
||||
"p={mode},,{data}",
|
||||
mode = mode,
|
||||
data = get_cbind_data(mode)?
|
||||
);
|
||||
base64::encode(msg).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn channel_binding_encode() -> anyhow::Result<()> {
|
||||
use ChannelBinding::*;
|
||||
|
||||
let cases = [
|
||||
(NotSupportedClient, base64::encode("n,,")),
|
||||
(NotSupportedServer, base64::encode("y,,")),
|
||||
(Required("foo"), base64::encode("p=foo,,bar")),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
67
proxy/src/sasl/messages.rs
Normal file
67
proxy/src/sasl/messages.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Definitions for SASL messages.
|
||||
|
||||
use crate::parse::{split_at_const, split_cstr};
|
||||
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage).
|
||||
#[derive(Debug)]
|
||||
pub struct FirstMessage<'a> {
|
||||
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
|
||||
pub method: &'a str,
|
||||
/// Initial client message.
|
||||
pub message: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> FirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
|
||||
let (method_cstr, tail) = split_cstr(bytes)?;
|
||||
let method = method_cstr.to_str().ok()?;
|
||||
|
||||
let (len_bytes, bytes) = split_at_const(tail)?;
|
||||
let len = u32::from_be_bytes(*len_bytes) as usize;
|
||||
if len != bytes.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = std::str::from_utf8(bytes).ok()?;
|
||||
Some(Self { method, message })
|
||||
}
|
||||
}
|
||||
|
||||
/// A single SASL message.
|
||||
/// This struct is deliberately decoupled from lower-level
|
||||
/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage).
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ServerMessage<T> {
|
||||
/// We expect to see more steps.
|
||||
Continue(T),
|
||||
/// This is the final step.
|
||||
Final(T),
|
||||
}
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_sasl_first_message() {
|
||||
let proto = "SCRAM-SHA-256";
|
||||
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
|
||||
let sasl_len = (sasl.len() as u32).to_be_bytes();
|
||||
let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
|
||||
|
||||
let password = FirstMessage::parse(&bytes).unwrap();
|
||||
assert_eq!(password.method, proto);
|
||||
assert_eq!(password.message, sasl);
|
||||
}
|
||||
}
|
||||
70
proxy/src/sasl/stream.rs
Normal file
70
proxy/src/sasl/stream.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
//! Abstraction for the string-oriented SASL protocols.
|
||||
|
||||
use super::{messages::ServerMessage, Mechanism};
|
||||
use crate::stream::PqStream;
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let (moved, reply) = mechanism.exchange(input)?;
|
||||
match moved {
|
||||
Some(moved) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved;
|
||||
}
|
||||
None => {
|
||||
self.send(&ServerMessage::Final(&reply)).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
59
proxy/src/scram.rs
Normal file
59
proxy/src/scram.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
//! Salted Challenge Response Authentication Mechanism.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod exchange;
|
||||
mod key;
|
||||
mod messages;
|
||||
mod password;
|
||||
mod secret;
|
||||
mod signature;
|
||||
|
||||
pub use secret::*;
|
||||
|
||||
pub use exchange::Exchange;
|
||||
pub use secret::ServerSecret;
|
||||
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// TODO: add SCRAM-SHA-256-PLUS
|
||||
/// A list of supported SCRAM methods.
|
||||
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
|
||||
|
||||
/// Decode base64 into array without any heap allocations
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
let mut bytes = [0u8; N];
|
||||
|
||||
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
|
||||
if size != N {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(key).expect("bad key size");
|
||||
parts.into_iter().for_each(|s| mac.update(s));
|
||||
|
||||
// TODO: maybe newer `hmac` et al already migrated to regular arrays?
|
||||
let mut result = [0u8; 32];
|
||||
result.copy_from_slice(mac.finalize().into_bytes().as_slice());
|
||||
result
|
||||
}
|
||||
|
||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
parts.into_iter().for_each(|s| hasher.update(s));
|
||||
|
||||
let mut result = [0u8; 32];
|
||||
result.copy_from_slice(hasher.finalize().as_slice());
|
||||
result
|
||||
}
|
||||
134
proxy/src/scram/exchange.rs
Normal file
134
proxy/src/scram/exchange.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
//! Implementation of the SCRAM authentication algorithm.
|
||||
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
struct TlsServerEndPoint;
|
||||
|
||||
impl std::fmt::Display for TlsServerEndPoint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "tls-server-end-point")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TlsServerEndPoint {
|
||||
type Err = sasl::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"tls-server-end-point" => Ok(TlsServerEndPoint),
|
||||
_ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ExchangeState {
|
||||
/// Waiting for [`ClientFirstMessage`].
|
||||
Initial,
|
||||
/// Waiting for [`ClientFinalMessage`].
|
||||
SaltSent {
|
||||
cbind_flag: ChannelBinding<TlsServerEndPoint>,
|
||||
client_first_message_bare: String,
|
||||
server_first_message: OwnedServerFirstMessage,
|
||||
},
|
||||
}
|
||||
|
||||
/// Server's side of SCRAM auth algorithm.
|
||||
#[derive(Debug)]
|
||||
pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial,
|
||||
secret,
|
||||
nonce,
|
||||
cert_digest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl sasl::Mechanism for Exchange<'_> {
|
||||
fn exchange(mut self, input: &str) -> sasl::Result<(Option<Self>, String)> {
|
||||
use ExchangeState::*;
|
||||
match &self.state {
|
||||
Initial => {
|
||||
let client_first_message =
|
||||
ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&self.secret.salt_base64,
|
||||
self.secret.iterations,
|
||||
);
|
||||
let msg = server_first_message.as_str().to_owned();
|
||||
|
||||
self.state = SaltSent {
|
||||
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
|
||||
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||
server_first_message,
|
||||
};
|
||||
|
||||
Ok((Some(self), msg))
|
||||
}
|
||||
SaltSent {
|
||||
cbind_flag,
|
||||
client_first_message_bare,
|
||||
server_first_message,
|
||||
} => {
|
||||
let client_final_message =
|
||||
ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| {
|
||||
self.cert_digest
|
||||
.map(base64::encode)
|
||||
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
if client_final_message.channel_binding != channel_binding {
|
||||
return Err(SaslError::ChannelBindingFailed("data mismatch"));
|
||||
}
|
||||
|
||||
if client_final_message.nonce != server_first_message.nonce() {
|
||||
return Err(SaslError::AuthenticationFailed("bad nonce"));
|
||||
}
|
||||
|
||||
let signature_builder = SignatureBuilder {
|
||||
client_first_message_bare,
|
||||
server_first_message: server_first_message.as_str(),
|
||||
client_final_message_without_proof: client_final_message.without_proof,
|
||||
};
|
||||
|
||||
let client_key = signature_builder
|
||||
.build(&self.secret.stored_key)
|
||||
.derive_client_key(&client_final_message.proof);
|
||||
|
||||
if client_key.sha256() != self.secret.stored_key {
|
||||
return Err(SaslError::AuthenticationFailed("keys don't match"));
|
||||
}
|
||||
|
||||
let msg = client_final_message
|
||||
.build_server_final_message(signature_builder, &self.secret.server_key);
|
||||
|
||||
Ok((None, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
33
proxy/src/scram/key.rs
Normal file
33
proxy/src/scram/key.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Tools for client/server/stored key management.
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_KEY_LEN: usize = 32;
|
||||
|
||||
/// One of the keys derived from the [password](super::password::SaltedPassword).
|
||||
/// We use the same structure for all keys, i.e.
|
||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||
#[derive(Default, Debug, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
pub struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl ScramKey {
|
||||
pub fn sha256(&self) -> Self {
|
||||
super::sha256([self.as_ref()]).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
232
proxy/src/scram/messages.rs
Normal file
232
proxy/src/scram/messages.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
//! Definitions for SCRAM messages.
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::sasl::ChannelBinding;
|
||||
use std::fmt;
|
||||
use std::ops::Range;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
|
||||
|
||||
/// Although we ignore all extensions, we still have to validate the message.
|
||||
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
|
||||
for mut chars in parts.map(|s| s.chars()) {
|
||||
let attr = chars.next()?;
|
||||
if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) {
|
||||
return None;
|
||||
}
|
||||
let eq = chars.next()?;
|
||||
if eq != '=' {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFirstMessage<'a> {
|
||||
/// `client-first-message-bare`.
|
||||
pub bare: &'a str,
|
||||
/// Channel binding mode.
|
||||
pub cbind_flag: ChannelBinding<&'a str>,
|
||||
/// (Client username)[<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>].
|
||||
pub username: &'a str,
|
||||
/// Client nonce.
|
||||
pub nonce: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> ClientFirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let mut parts = input.split(',');
|
||||
|
||||
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
|
||||
|
||||
// PG doesn't support authorization identity,
|
||||
// so we don't bother defining GS2 header type
|
||||
let authzid = parts.next()?;
|
||||
if !authzid.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Unfortunately, `parts.as_str()` is unstable
|
||||
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
|
||||
let (_, bare) = input.split_at(pos);
|
||||
|
||||
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
|
||||
let username = parts.next()?.strip_prefix("n=")?;
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
Some(Self {
|
||||
bare,
|
||||
cbind_flag,
|
||||
username,
|
||||
nonce,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFirstMessage`].
|
||||
pub fn build_server_first_message(
|
||||
&self,
|
||||
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
|
||||
salt_base64: &str,
|
||||
iterations: u32,
|
||||
) -> OwnedServerFirstMessage {
|
||||
use std::fmt::Write;
|
||||
|
||||
let mut message = String::new();
|
||||
write!(&mut message, "r={}", self.nonce).unwrap();
|
||||
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
|
||||
let combined_nonce = 2..message.len();
|
||||
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
|
||||
|
||||
// This design guarantees that it's impossible to create a
|
||||
// server-first-message without receiving a client-first-message
|
||||
OwnedServerFirstMessage {
|
||||
message,
|
||||
nonce: combined_nonce,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFinalMessage<'a> {
|
||||
/// `client-final-message-without-proof`.
|
||||
pub without_proof: &'a str,
|
||||
/// Channel binding data (base64).
|
||||
pub channel_binding: &'a str,
|
||||
/// Combined client & server nonce.
|
||||
pub nonce: &'a str,
|
||||
/// Client auth proof.
|
||||
pub proof: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl<'a> ClientFinalMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let (without_proof, proof) = input.rsplit_once(',')?;
|
||||
|
||||
let mut parts = without_proof.split(',');
|
||||
let channel_binding = parts.next()?.strip_prefix("c=")?;
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
|
||||
|
||||
Some(Self {
|
||||
without_proof,
|
||||
channel_binding,
|
||||
nonce,
|
||||
proof,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFinalMessage`].
|
||||
pub fn build_server_final_message(
|
||||
&self,
|
||||
signature_builder: SignatureBuilder,
|
||||
server_key: &ScramKey,
|
||||
) -> String {
|
||||
let mut buf = String::from("v=");
|
||||
base64::encode_config_buf(
|
||||
signature_builder.build(server_key),
|
||||
base64::STANDARD,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
/// We need to keep a convenient representation of this
|
||||
/// message for the next authentication step.
|
||||
pub struct OwnedServerFirstMessage {
|
||||
/// Owned `server-first-message`.
|
||||
message: String,
|
||||
/// Slice into `message`.
|
||||
nonce: Range<usize>,
|
||||
}
|
||||
|
||||
impl OwnedServerFirstMessage {
|
||||
/// Extract combined nonce from the message.
|
||||
#[inline(always)]
|
||||
pub fn nonce(&self) -> &str {
|
||||
&self.message[self.nonce.clone()]
|
||||
}
|
||||
|
||||
/// Get reference to a text representation of the message.
|
||||
#[inline(always)]
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for OwnedServerFirstMessage {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ServerFirstMessage")
|
||||
.field("message", &self.as_str())
|
||||
.field("nonce", &self.nonce())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message() {
|
||||
use ChannelBinding::*;
|
||||
|
||||
// (Almost) real strings captured during debug sessions
|
||||
let cases = [
|
||||
(NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(
|
||||
Required("tls-server-end-point"),
|
||||
"p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
|
||||
),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
let msg = ClientFirstMessage::parse(input).unwrap();
|
||||
|
||||
assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.username, "pepe");
|
||||
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.cbind_flag, cb);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_final_message() {
|
||||
let input = [
|
||||
"c=eSws",
|
||||
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
|
||||
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
|
||||
]
|
||||
.join(",");
|
||||
|
||||
let msg = ClientFinalMessage::parse(&input).unwrap();
|
||||
assert_eq!(
|
||||
msg.without_proof,
|
||||
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
msg.nonce,
|
||||
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(msg.proof),
|
||||
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
|
||||
);
|
||||
}
|
||||
}
|
||||
48
proxy/src/scram/password.rs
Normal file
48
proxy/src/scram/password.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! Password hashing routines.
|
||||
|
||||
use super::key::ScramKey;
|
||||
|
||||
pub const SALTED_PASSWORD_LEN: usize = 32;
|
||||
|
||||
/// Salted hashed password is essential for [key](super::key) derivation.
|
||||
#[repr(transparent)]
|
||||
pub struct SaltedPassword {
|
||||
bytes: [u8; SALTED_PASSWORD_LEN],
|
||||
}
|
||||
|
||||
impl SaltedPassword {
|
||||
/// See `scram-common.c : scram_SaltedPassword` for details.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
|
||||
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
||||
let one = 1_u32.to_be_bytes(); // magic
|
||||
|
||||
let mut current = super::hmac_sha256(password, [salt, &one]);
|
||||
let mut result = current;
|
||||
for _ in 1..iterations {
|
||||
current = super::hmac_sha256(password, [current.as_ref()]);
|
||||
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
|
||||
for (i, x) in current.iter().enumerate() {
|
||||
result[i] ^= x;
|
||||
}
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
|
||||
/// Derive `ClientKey` from a salted hashed password.
|
||||
pub fn client_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
|
||||
}
|
||||
|
||||
/// Derive `ServerKey` from a salted hashed password.
|
||||
pub fn server_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
116
proxy/src/scram/secret.rs
Normal file
116
proxy/src/scram/secret.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
//! Tools for SCRAM server secret management.
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
|
||||
/// Server secret is produced from [password](super::password::SaltedPassword)
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Debug)]
|
||||
pub struct ServerSecret {
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub iterations: u32,
|
||||
/// Salt used to hash user's password.
|
||||
pub salt_base64: String,
|
||||
/// Hashed `ClientKey`.
|
||||
pub stored_key: ScramKey,
|
||||
/// Used by client to verify server's signature.
|
||||
pub server_key: ScramKey,
|
||||
}
|
||||
|
||||
impl ServerSecret {
|
||||
pub fn parse(input: &str) -> Option<Self> {
|
||||
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
|
||||
let s = input.strip_prefix("SCRAM-SHA-256$")?;
|
||||
let (params, keys) = s.split_once('$')?;
|
||||
|
||||
let ((iterations, salt), (stored_key, server_key)) =
|
||||
params.split_once(':').zip(keys.split_once(':'))?;
|
||||
|
||||
let secret = ServerSecret {
|
||||
iterations: iterations.parse().ok()?,
|
||||
salt_base64: salt.to_owned(),
|
||||
stored_key: base64_decode_array(stored_key)?.into(),
|
||||
server_key: base64_decode_array(server_key)?.into(),
|
||||
};
|
||||
|
||||
Some(secret)
|
||||
}
|
||||
|
||||
/// To avoid revealing information to an attacker, we use a
|
||||
/// mocked server secret even if the user doesn't exist.
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
|
||||
// Refer to `auth-scram.c : scram_mock_salt`.
|
||||
let mocked_salt = super::sha256([user.as_bytes(), nonce]);
|
||||
|
||||
Self {
|
||||
iterations: 4096,
|
||||
salt_base64: base64::encode(&mocked_salt),
|
||||
stored_key: ScramKey::default(),
|
||||
server_key: ScramKey::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a new server secret from the prerequisites.
|
||||
/// XXX: We only use this function in tests.
|
||||
#[cfg(test)]
|
||||
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
|
||||
// TODO: implement proper password normalization required by the RFC
|
||||
if !password.is_ascii() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
|
||||
|
||||
Some(Self {
|
||||
iterations,
|
||||
salt_base64: base64::encode(&salt),
|
||||
stored_key: password.client_key().sha256(),
|
||||
server_key: password.server_key(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_scram_secret() {
|
||||
let iterations = 4096;
|
||||
let salt = "+/tQQax7twvwTj64mjBsxQ==";
|
||||
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
|
||||
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
|
||||
|
||||
let secret = format!(
|
||||
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
|
||||
iterations = iterations,
|
||||
salt = salt,
|
||||
stored_key = stored_key,
|
||||
server_key = server_key,
|
||||
);
|
||||
|
||||
let parsed = ServerSecret::parse(&secret).unwrap();
|
||||
assert_eq!(parsed.iterations, iterations);
|
||||
assert_eq!(parsed.salt_base64, salt);
|
||||
|
||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_scram_secret() {
|
||||
let salt = b"salt";
|
||||
let secret = ServerSecret::build("password", salt, 4096).unwrap();
|
||||
assert_eq!(secret.iterations, 4096);
|
||||
assert_eq!(secret.salt_base64, base64::encode(salt));
|
||||
assert_eq!(
|
||||
base64::encode(secret.stored_key.as_ref()),
|
||||
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(secret.server_key.as_ref()),
|
||||
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
|
||||
);
|
||||
}
|
||||
}
|
||||
66
proxy/src/scram/signature.rs
Normal file
66
proxy/src/scram/signature.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Tools for client/server signature management.
|
||||
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
|
||||
/// A collection of message parts needed to derive the client's signature.
|
||||
#[derive(Debug)]
|
||||
pub struct SignatureBuilder<'a> {
|
||||
pub client_first_message_bare: &'a str,
|
||||
pub server_first_message: &'a str,
|
||||
pub client_final_message_without_proof: &'a str,
|
||||
}
|
||||
|
||||
impl SignatureBuilder<'_> {
|
||||
pub fn build(&self, key: &ScramKey) -> Signature {
|
||||
let parts = [
|
||||
self.client_first_message_bare.as_bytes(),
|
||||
b",",
|
||||
self.server_first_message.as_bytes(),
|
||||
b",",
|
||||
self.client_final_message_without_proof.as_bytes(),
|
||||
];
|
||||
|
||||
super::hmac_sha256(key.as_ref(), parts).into()
|
||||
}
|
||||
}
|
||||
|
||||
/// A computed value which, when xored with `ClientProof`,
|
||||
/// produces `ClientKey` that we need for authentication.
|
||||
#[derive(Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct Signature {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl Signature {
|
||||
/// Derive `ClientKey` from client's signature and proof.
|
||||
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
|
||||
// This is how the proof is calculated:
|
||||
//
|
||||
// 1. sha256(ClientKey) -> StoredKey
|
||||
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
|
||||
// 3. ClientKey ^ ClientSignature -> ClientProof
|
||||
//
|
||||
// Step 3 implies that we can restore ClientKey from the proof
|
||||
// by xoring the latter with the ClientSignature. Afterwards we
|
||||
// can check that the presumed ClientKey meets our expectations.
|
||||
let mut signature = self.bytes;
|
||||
for (i, x) in proof.iter().enumerate() {
|
||||
signature[i] ^= x;
|
||||
}
|
||||
|
||||
signature.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Signature {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user