mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
channel binding (#5683)
## Problem channel binding protects scram from sophisticated MITM attacks where the attacker is able to produce 'valid' TLS certificates. ## Summary of changes get the tls-server-end-point channel binding, and verify it is correct for the SCRAM-SHA-256-PLUS authentication flow
This commit is contained in:
@@ -76,3 +76,4 @@ tokio-util.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
|
||||
@@ -6,6 +6,7 @@ pub use link::LinkAuthError;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
config::AuthenticationConfig,
|
||||
@@ -131,7 +132,7 @@ async fn auth_quirks_creds(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -165,7 +166,7 @@ async fn auth_quirks(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
@@ -241,7 +242,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
pub async fn authenticate(
|
||||
&mut self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{
|
||||
console::{self, AuthInfo, ConsoleReqExtra},
|
||||
proxy::LatencyTimer,
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -15,7 +15,7 @@ pub(super) async fn authenticate(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::{AuthSuccess, ComputeCredentials};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
proxy::LatencyTimer,
|
||||
stream,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -12,7 +12,7 @@ use tracing::{info, warn};
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn cleartext_hack(
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
@@ -37,7 +37,7 @@ pub async fn cleartext_hack(
|
||||
/// Very similar to [`cleartext_hack`], but there's a specific password format.
|
||||
pub async fn password_hack(
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{sasl, scram, stream::PqStream};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
sasl, scram,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self) -> BeMessage<'_>;
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
@@ -21,8 +26,14 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +43,7 @@ pub struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
@@ -43,37 +54,44 @@ pub struct CleartextPassword;
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
pub struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<S>) -> Self {
|
||||
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream.write_message(&method.first_message()).await?;
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,9 +141,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let secret = self.state.0;
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(secret, rand::random, None))
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(outcome)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
@@ -65,7 +67,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let tls_config: Arc<rustls::ServerConfig> = match (
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -89,16 +91,22 @@ async fn main() -> anyhow::Result<()> {
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
.collect_vec()
|
||||
};
|
||||
|
||||
rustls::ServerConfig::builder()
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into()
|
||||
.into();
|
||||
|
||||
(tls_config, tls_server_end_point)
|
||||
}
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
@@ -113,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let main = tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -134,6 +143,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -159,7 +169,7 @@ async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
handle_client(dest_suffix, tls_config, socket).await
|
||||
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -207,6 +217,7 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
@@ -231,7 +242,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
Ok(raw.upgrade(tls_config).await?)
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(raw.upgrade(tls_config).await?),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -246,9 +261,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let tls_stream = ssl_handshake(stream, tls_config).await?;
|
||||
let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use crate::auth;
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::sign;
|
||||
use rustls::{sign, Certificate, PrivateKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
@@ -27,6 +30,7 @@ pub struct MetricCollectionConfig {
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: Option<HashSet<String>>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
@@ -52,7 +56,7 @@ pub fn configure_tls(
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert(key_path, cert_path, true)?;
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
@@ -64,7 +68,7 @@ pub fn configure_tls(
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert(
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
@@ -76,35 +80,97 @@ pub fn configure_tls(
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(cert_resolver))
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_names: Some(common_names),
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
struct CertResolver {
|
||||
certs: HashMap<String, Arc<rustls::sign::CertifiedKey>>,
|
||||
default: Option<Arc<rustls::sign::CertifiedKey>>,
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
certs: HashMap::new(),
|
||||
default: None,
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(&cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] =
|
||||
Sha256::new().chain_update(&cert.0).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_cert(
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
@@ -120,57 +186,65 @@ impl CertResolver {
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context(format!(
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
)
|
||||
})?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let common_name = pem.parse_x509()?.subject().to_string();
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKey,
|
||||
cert_chain: Vec<Certificate>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context(format!(
|
||||
"Failed to parse common name from certificate at '{cert_path}'."
|
||||
))?;
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some(cert.clone());
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, cert);
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_common_names(&self) -> HashSet<String> {
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
@@ -178,15 +252,24 @@ impl CertResolver {
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
_client_hello: rustls::server::ClientHello,
|
||||
client_hello: rustls::server::ClientHello,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = _client_hello.server_name() {
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
|
||||
@@ -470,7 +470,17 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
|
||||
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(tls_stream.get_ref().1.server_name())
|
||||
.context("missing certificate")?;
|
||||
|
||||
stream = PqStream::new(Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
@@ -875,7 +885,7 @@ pub async fn proxy_pass(
|
||||
/// Thin connection context.
|
||||
struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
@@ -889,7 +899,7 @@ struct Client<'a, S> {
|
||||
impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<S>,
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
//! A group of high-level tests for connection establishing logic and auth.
|
||||
//!
|
||||
|
||||
mod mitm;
|
||||
|
||||
use super::*;
|
||||
use crate::auth::backend::TestBackend;
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{CachedNodeInfo, NodeInfo};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
common_name: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
@@ -21,7 +25,15 @@ fn generate_certs(
|
||||
params
|
||||
})?;
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
|
||||
let cert = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
|
||||
params.distinguished_name = rcgen::DistinguishedName::new();
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, common_name);
|
||||
params
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
rustls::Certificate(ca.serialize_der()?),
|
||||
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
|
||||
@@ -37,7 +49,14 @@ struct ClientConfig<'a> {
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
self,
|
||||
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
|
||||
) -> anyhow::Result<
|
||||
impl tokio_postgres::tls::TlsConnect<
|
||||
S,
|
||||
Error = impl std::fmt::Debug,
|
||||
Future = impl Send,
|
||||
Stream = RustlsStream<S>,
|
||||
>,
|
||||
> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
@@ -49,20 +68,24 @@ fn generate_tls_config<'a>(
|
||||
hostname: &'a str,
|
||||
common_name: &'a str,
|
||||
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
|
||||
let (ca, cert, key) = generate_certs(hostname)?;
|
||||
let (ca, cert, key) = generate_certs(hostname, common_name)?;
|
||||
|
||||
let tls_config = {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert], key)?
|
||||
.with_single_cert(vec![cert.clone()], key.clone())?
|
||||
.into();
|
||||
|
||||
let common_names = Some([common_name.to_owned()].iter().cloned().collect());
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
let common_names = Some(cert_resolver.get_common_names());
|
||||
|
||||
TlsConfig {
|
||||
config,
|
||||
common_names,
|
||||
cert_resolver: Arc::new(cert_resolver),
|
||||
}
|
||||
};
|
||||
|
||||
@@ -253,6 +276,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
@@ -263,6 +287,30 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
257
proxy/src/proxy/tests/mitm.rs
Normal file
257
proxy/src/proxy/tests/mitm.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Man-in-the-middle tests
|
||||
//!
|
||||
//! Channel binding should prevent a proxy server
|
||||
//! - that has access to create valid certificates -
|
||||
//! from controlling the TLS connection.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::*;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
enum Intercept {
|
||||
None,
|
||||
Methods,
|
||||
SASLResponse,
|
||||
}
|
||||
|
||||
async fn proxy_mitm(
|
||||
intercept: Intercept,
|
||||
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
|
||||
let (end_server1, client1) = tokio::io::duplex(1024);
|
||||
let (server2, end_client2) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config1, server_config1) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
let (client_config2, server_config2) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// begin handshake with end_server
|
||||
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
|
||||
// process handshake with end_client
|
||||
let (end_client, startup) =
|
||||
handshake(client1, Some(&server_config1), &CancelMap::default())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
assert!(buf.is_empty());
|
||||
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(startup.iter(), &mut buf).unwrap();
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
|
||||
// proxy messages between end_client and end_server
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = end_server.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL and return only SCRAM-SHA-256 ;)
|
||||
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
|
||||
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_client.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
message = end_client.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
|
||||
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
|
||||
let sasl_message = &message[1+4+19+4..];
|
||||
let mut new_message = b"n,,".to_vec();
|
||||
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
|
||||
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_server.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
else => { break }
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(end_server1, end_client2, client_config1, server_config2)
|
||||
}
|
||||
|
||||
/// taken from tokio-postgres
|
||||
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
T::Error: Debug,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::ssl_request(&mut buf);
|
||||
stream.write_all(&buf).await.unwrap();
|
||||
|
||||
let mut buf = [0];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
if buf[0] != b'S' {
|
||||
panic!("ssl not supported by server");
|
||||
}
|
||||
|
||||
tls.connect(stream).await.unwrap()
|
||||
}
|
||||
|
||||
struct PgFrame;
|
||||
impl Decoder for PgFrame {
|
||||
type Item = Bytes;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if src.len() < 5 {
|
||||
src.reserve(5 - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
|
||||
if src.len() < len {
|
||||
src.reserve(len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(src.split_to(len).freeze()))
|
||||
}
|
||||
}
|
||||
impl Encoder<Bytes> for PgFrame {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// If the client doesn't support channel bindings, it can be exploited.
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_failure(
|
||||
intercept: Intercept,
|
||||
channel_binding: tokio_postgres::config::ChannelBinding,
|
||||
) -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password")?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(channel_binding)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err()
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err()
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -36,9 +36,9 @@ impl<'a> ChannelBinding<&'a str> {
|
||||
|
||||
impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
/// Encode channel binding data as base64 for subsequent checks.
|
||||
pub fn encode<E>(
|
||||
pub fn encode<'a, E>(
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
@@ -51,12 +51,11 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
let msg = format!(
|
||||
"p={mode},,{data}",
|
||||
mode = mode,
|
||||
data = get_cbind_data(mode)?
|
||||
);
|
||||
base64::encode(msg).into()
|
||||
use std::io::Write;
|
||||
let mut cbind_input = vec![];
|
||||
write!(&mut cbind_input, "p={mode},,",).unwrap();
|
||||
cbind_input.extend_from_slice(get_cbind_data(mode)?);
|
||||
base64::encode(&cbind_input).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -77,7 +76,7 @@ mod tests {
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -22,9 +22,12 @@ pub use secret::ServerSecret;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// TODO: add SCRAM-SHA-256-PLUS
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
|
||||
/// A list of supported SCRAM methods.
|
||||
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
|
||||
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
|
||||
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
|
||||
|
||||
/// Decode base64 into array without any heap allocations
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
@@ -80,7 +83,11 @@ mod tests {
|
||||
const NONCE: [u8; 18] = [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
];
|
||||
let mut exchange = Exchange::new(&secret, || NONCE, None);
|
||||
let mut exchange = Exchange::new(
|
||||
&secret,
|
||||
|| NONCE,
|
||||
crate::config::TlsServerEndPoint::Undefined,
|
||||
);
|
||||
|
||||
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
|
||||
|
||||
@@ -5,9 +5,11 @@ use super::messages::{
|
||||
};
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::config;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
struct TlsServerEndPoint;
|
||||
|
||||
impl std::fmt::Display for TlsServerEndPoint {
|
||||
@@ -43,20 +45,20 @@ pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial,
|
||||
secret,
|
||||
nonce,
|
||||
cert_digest,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,6 +73,14 @@ impl sasl::Mechanism for Exchange<'_> {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
|
||||
|
||||
// If the flag is set to "y" and the server supports channel
|
||||
// binding, the server MUST fail authentication
|
||||
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
|
||||
&& self.tls_server_end_point.supported()
|
||||
{
|
||||
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
|
||||
}
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&self.secret.salt_base64,
|
||||
@@ -94,10 +104,11 @@ impl sasl::Mechanism for Exchange<'_> {
|
||||
let client_final_message = ClientFinalMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| {
|
||||
self.cert_digest
|
||||
.map(base64::encode)
|
||||
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => {
|
||||
Err(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
}
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
@@ -17,7 +18,7 @@ use tokio_rustls::server::TlsStream;
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
framed: Framed<S>,
|
||||
pub(crate) framed: Framed<S>,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
@@ -118,19 +119,21 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
/// NOTE: it should be possible to decompose this object as necessary.
|
||||
#[project = StreamProj]
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { #[pin] raw: S },
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
Tls {
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
Tls { #[pin] tls: Box<TlsStream<S>> },
|
||||
}
|
||||
tls: Box<TlsStream<S>>,
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
@@ -141,7 +144,17 @@ impl<S> Stream<S> {
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls } => tls.get_ref().1.server_name(),
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
|
||||
match self {
|
||||
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
|
||||
Stream::Tls {
|
||||
tls_server_end_point,
|
||||
..
|
||||
} => *tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,12 +171,9 @@ pub enum StreamUpgradeError {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<Self, StreamUpgradeError> {
|
||||
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> Result<TlsStream<S>, StreamUpgradeError> {
|
||||
match self {
|
||||
Stream::Raw { raw } => {
|
||||
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
|
||||
Ok(Stream::Tls { tls })
|
||||
}
|
||||
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
}
|
||||
}
|
||||
@@ -171,50 +181,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_read(context, buf),
|
||||
Tls { tls } => tls.poll_read(context, buf),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_write(context, buf),
|
||||
Tls { tls } => tls.poll_write(context, buf),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_flush(context),
|
||||
Tls { tls } => tls.poll_flush(context),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_shutdown(context),
|
||||
Tls { tls } => tls.poll_shutdown(context),
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user