channel binding (#5683)

## Problem

channel binding protects scram from sophisticated MITM attacks where the
attacker is able to produce 'valid' TLS certificates.

## Summary of changes

get the tls-server-end-point channel binding, and verify it is correct
for the SCRAM-SHA-256-PLUS authentication flow
This commit is contained in:
Conrad Ludgate
2023-11-27 21:45:15 +00:00
committed by GitHub
parent e09bb9974c
commit 316309c85b
15 changed files with 601 additions and 137 deletions

View File

@@ -76,3 +76,4 @@ tokio-util.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
postgres-protocol.workspace = true

View File

@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
use tokio_postgres::config::AuthKeys;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
config::AuthenticationConfig,
@@ -131,7 +132,7 @@ async fn auth_quirks_creds(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -165,7 +166,7 @@ async fn auth_quirks(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
@@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
pub async fn authenticate(
&mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,

View File

@@ -6,7 +6,7 @@ use crate::{
console::{self, AuthInfo, ConsoleReqExtra},
proxy::LatencyTimer,
sasl, scram,
stream::PqStream,
stream::{PqStream, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {

View File

@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
proxy::LatencyTimer,
stream,
stream::{self, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -12,7 +12,7 @@ use tracing::{info, warn};
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn cleartext_hack(
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("cleartext auth flow override is enabled, proceeding");
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
/// Very similar to [`cleartext_hack`], but there's a specific password format.
pub async fn password_hack(
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
warn!("project not specified, resorting to the password hack auth flow");

View File

@@ -1,16 +1,21 @@
//! Main authentication flow.
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{sasl, scram, stream::PqStream};
use crate::{
config::TlsServerEndPoint,
sasl, scram,
stream::{PqStream, Stream},
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self) -> BeMessage<'_>;
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
@@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
@@ -32,7 +43,7 @@ pub struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
@@ -43,37 +54,44 @@ pub struct CleartextPassword;
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
pub struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream>,
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<S>) -> Self {
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream.write_message(&method.first_message()).await?;
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
@@ -123,9 +141,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let secret = self.state.0;
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(secret, rand::random, None))
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
Ok(outcome)

View File

@@ -6,6 +6,8 @@
use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
@@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let tls_config: Arc<rustls::ServerConfig> = match (
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> {
))?
.into_iter()
.map(rustls::Certificate)
.collect()
.collect_vec()
};
rustls::ServerConfig::builder()
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into()
.into();
(tls_config, tls_server_end_point)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -159,7 +169,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, socket).await
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
@@ -231,7 +242,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(raw.upgrade(tls_config).await?)
Ok(Stream::Tls {
tls: Box::new(raw.upgrade(tls_config).await?),
tls_server_end_point,
})
}
unexpected => {
info!(
@@ -246,9 +261,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let tls_stream = ssl_handshake(stream, tls_config).await?;
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of

View File

@@ -1,12 +1,15 @@
use crate::auth;
use anyhow::{bail, ensure, Context, Ok};
use rustls::sign;
use rustls::{sign, Certificate, PrivateKey};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
@@ -27,6 +30,7 @@ pub struct MetricCollectionConfig {
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>,
pub cert_resolver: Arc<CertResolver>,
}
pub struct HttpConfig {
@@ -52,7 +56,7 @@ pub fn configure_tls(
let mut cert_resolver = CertResolver::new();
// add default certificate
cert_resolver.add_cert(key_path, cert_path, true)?;
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -64,7 +68,7 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert(
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
@@ -76,35 +80,97 @@ pub fn configure_tls(
let common_names = cert_resolver.get_common_names();
let cert_resolver = Arc::new(cert_resolver);
let config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(Arc::new(cert_resolver))
.with_cert_resolver(cert_resolver.clone())
.into();
Ok(TlsConfig {
config,
common_names: Some(common_names),
cert_resolver,
})
}
struct CertResolver {
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
default: Option<Arc<rustls::sign::CertifiedKey>>,
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl CertResolver {
fn new() -> Self {
Self {
certs: HashMap::new(),
default: None,
impl TlsServerEndPoint {
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(&cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] =
Sha256::new().chain_update(&cert.0).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
fn add_cert(
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
pub fn new() -> Self {
Self::default()
}
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
@@ -120,57 +186,65 @@ impl CertResolver {
keys.pop().map(rustls::PrivateKey).unwrap()
};
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context(format!(
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
))?
)
})?
.into_iter()
.map(rustls::Certificate)
.collect()
};
let common_name = {
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
.context(format!(
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let common_name = pem.parse_x509()?.subject().to_string();
self.add_cert(priv_key, cert_chain, is_default)
}
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
pub fn add_cert(
&mut self,
priv_key: PrivateKey,
cert_chain: Vec<Certificate>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
.context("Failed to parse PEM object from cerficiate")?
.1;
let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context(format!(
"Failed to parse common name from certificate at '{cert_path}'."
))?;
.context("Failed to parse common name from certificate")?;
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some(cert.clone());
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, cert);
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
fn get_common_names(&self) -> HashSet<String> {
pub fn get_common_names(&self) -> HashSet<String> {
self.certs.keys().map(|s| s.to_string()).collect()
}
}
@@ -178,15 +252,24 @@ impl CertResolver {
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
_client_hello: rustls::server::ClientHello,
client_hello: rustls::server::ClientHello,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
//
// With the current coding foo.com will match *.foo.com and that
// repeats behavior of the old code.
if let Some(mut sni_name) = _client_hello.server_name() {
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return Some(cert.clone());

View File

@@ -470,7 +470,17 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
@@ -875,7 +885,7 @@ pub async fn proxy_pass(
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
/// KV-dictionary with PostgreSQL connection params.
@@ -889,7 +899,7 @@ struct Client<'a, S> {
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<S>,
stream: PqStream<Stream<S>>,
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,

View File

@@ -1,19 +1,23 @@
//! A group of high-level tests for connection establishing logic and auth.
//!
mod mitm;
use super::*;
use crate::auth::backend::TestBackend;
use crate::auth::ClientCredentials;
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
hostname: &str,
common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
@@ -21,7 +25,15 @@ fn generate_certs(
params
})?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
@@ -37,7 +49,14 @@ struct ClientConfig<'a> {
impl ClientConfig<'_> {
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
self,
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
) -> anyhow::Result<
impl tokio_postgres::tls::TlsConnect<
S,
Error = impl std::fmt::Debug,
Future = impl Send,
Stream = RustlsStream<S>,
>,
> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
@@ -49,20 +68,24 @@ fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname)?;
let (ca, cert, key) = generate_certs(hostname, common_name)?;
let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?
.with_single_cert(vec![cert.clone()], key.clone())?
.into();
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = Some(cert_resolver.get_common_names());
TlsConfig {
config,
common_names,
cert_resolver: Arc::new(cert_resolver),
}
};
@@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
.password(password)
@@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
proxy.await?
}
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let (_client, _conn) = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
#[tokio::test]
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);

View File

@@ -0,0 +1,257 @@
//! Man-in-the-middle tests
//!
//! Channel binding should prevent a proxy server
//! - that has access to create valid certificates -
//! from controlling the TLS connection.
use std::fmt::Debug;
use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};
enum Intercept {
None,
Methods,
SASLResponse,
}
async fn proxy_mitm(
intercept: Intercept,
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
let (end_server1, client1) = tokio::io::duplex(1024);
let (server2, end_client2) = tokio::io::duplex(1024);
let (client_config1, server_config1) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
let (client_config2, server_config2) =
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(startup.iter(), &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
loop {
tokio::select! {
message = end_server.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL and return only SCRAM-SHA-256 ;)
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
continue;
}
end_client.send(message).await.unwrap()
}
_ => break,
}
}
message = end_client.next() => {
match message {
Some(Ok(message)) => {
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
let sasl_message = &message[1+4+19+4..];
let mut new_message = b"n,,".to_vec();
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
end_server.send(buf.freeze()).await.unwrap();
continue;
}
end_server.send(message).await.unwrap()
}
_ => break,
}
}
else => { break }
}
}
});
(end_server1, end_client2, client_config1, server_config2)
}
/// taken from tokio-postgres
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
T::Error: Debug,
{
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.unwrap();
let mut buf = [0];
stream.read_exact(&mut buf).await.unwrap();
if buf[0] != b'S' {
panic!("ssl not supported by server");
}
tls.connect(stream).await.unwrap()
}
struct PgFrame;
impl Decoder for PgFrame {
type Item = Bytes;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 5 {
src.reserve(5 - src.len());
return Ok(None);
}
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.len() < len {
src.reserve(len - src.len());
return Ok(None);
}
Ok(Some(src.split_to(len).freeze()))
}
}
impl Encoder<Bytes> for PgFrame {
type Error = io::Error;
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(&item);
Ok(())
}
}
/// If the client doesn't support channel bindings, it can be exploited.
#[tokio::test]
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
#[tokio::test]
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Prefer,
)
.await
}
/// If the client chooses SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
connect_failure(
Intercept::None,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
connect_failure(
Intercept::Methods,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
#[tokio::test]
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
connect_failure(
Intercept::SASLResponse,
tokio_postgres::config::ChannelBinding::Require,
)
.await
}
async fn connect_failure(
intercept: Intercept,
channel_binding: tokio_postgres::config::ChannelBinding,
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new("password")?,
));
let _client_err = tokio_postgres::Config::new()
.channel_binding(channel_binding)
.user("user")
.dbname("db")
.password("password")
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await
.err()
.context("client shouldn't be able to connect")?;
let _server_err = proxy
.await?
.err()
.context("server shouldn't accept client")?;
Ok(())
}

View File

@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {
impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<E>(
pub fn encode<'a, E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
@@ -51,12 +51,11 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Required(mode) => {
let msg = format!(
"p={mode},,{data}",
mode = mode,
data = get_cbind_data(mode)?
);
base64::encode(msg).into()
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
base64::encode(&cbind_input).into()
}
})
}
@@ -77,7 +76,7 @@ mod tests {
];
for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
}
Ok(())

View File

@@ -22,9 +22,12 @@ pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
// TODO: add SCRAM-SHA-256-PLUS
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
@@ -80,7 +83,11 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(&secret, || NONCE, None);
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";

View File

@@ -5,9 +5,11 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;
impl std::fmt::Display for TlsServerEndPoint {
@@ -43,20 +45,20 @@ pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial,
secret,
nonce,
cert_digest,
tls_server_end_point,
}
}
}
@@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> {
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
// If the flag is set to "y" and the server supports channel
// binding, the server MUST fail authentication
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
&& self.tls_server_end_point.supported()
{
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
}
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
@@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> {
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
.map(base64::encode)
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => {
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
}
})?;
// This might've been caused by a MITM attack

View File

@@ -1,7 +1,8 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
framed: Framed<S>,
pub(crate) framed: Framed<S>,
}
impl<S> PqStream<S> {
@@ -118,19 +119,21 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
}
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
tls: Box<TlsStream<S>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
@@ -141,7 +144,17 @@ impl<S> Stream<S> {
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls } => tls.get_ref().1.server_name(),
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {
tls_server_end_point,
..
} => *tls_server_end_point,
}
}
}
@@ -158,12 +171,9 @@ pub enum StreamUpgradeError {
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
@@ -171,50 +181,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
}
}
}