update rustls

This commit is contained in:
Conrad Ludgate
2023-12-19 15:39:26 +00:00
parent 8b91bbc38e
commit dc109c42bc
9 changed files with 280 additions and 93 deletions

View File

@@ -1,6 +1,9 @@
use crate::{auth, rate_limiter::RateBucketInfo};
use anyhow::{bail, ensure, Context, Ok};
use rustls::{sign, Certificate, PrivateKey};
use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
@@ -85,14 +88,14 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);
let config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone())
.into();
// allow TLS 1.2 to be compatible with older client libraries
let config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone())
.into();
Ok(TlsConfig {
config,
@@ -130,14 +133,14 @@ pub enum TlsServerEndPoint {
}
impl TlsServerEndPoint {
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(&cert.0)
let pem = x509_parser::parse_x509_certificate(&cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
@@ -148,7 +151,7 @@ impl TlsServerEndPoint {
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] =
Sha256::new().chain_update(&cert.0).finalize().into();
Sha256::new().chain_update(&cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
@@ -162,7 +165,7 @@ impl TlsServerEndPoint {
}
}
#[derive(Default)]
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
@@ -182,11 +185,12 @@ impl CertResolver {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.context(format!("Failed to parse TLS keys at '{key_path}'"))?;
let keys: Result<Vec<_>, _> =
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect();
let mut keys = keys.context(format!("Failed to parse TLS keys at '{key_path}'"))?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().map(rustls::PrivateKey).unwrap()
keys.pop().unwrap()
};
let cert_chain_bytes = std::fs::read(cert_path)
@@ -194,30 +198,29 @@ impl CertResolver {
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.into_iter()
.collect::<Result<Vec<_>, _>>()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
)
})?
.into_iter()
.map(rustls::Certificate)
.collect()
};
self.add_cert(priv_key, cert_chain, is_default)
self.add_cert(PrivateKeyDer::Pkcs8(priv_key), cert_chain, is_default)
}
pub fn add_cert(
&mut self,
priv_key: PrivateKey,
cert_chain: Vec<Certificate>,
priv_key: PrivateKeyDer,
cert_chain: Vec<CertificateDer>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
let pem = x509_parser::parse_x509_certificate(&first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;

View File

@@ -328,19 +328,23 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithClientIp<AddrStream>;
type Address = std::net::SocketAddr;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
use hyper::server::accept::Accept;
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
let Some(conn) = conn else {
return Poll::Ready(None);
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"no incoming connection?",
)));
};
Poll::Ready(Some(Ok(WithClientIp::new(conn))))
let addr = conn.remote_addr();
Poll::Ready(Ok((WithClientIp::new(conn), addr)))
}
}

View File

@@ -12,6 +12,7 @@ use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
@@ -20,7 +21,11 @@ use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
fn generate_certs(
hostname: &str,
common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
) -> anyhow::Result<(
CertificateDer<'static>,
CertificateDer<'static>,
PrivateKeyDer<'static>,
)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
@@ -37,9 +42,9 @@ fn generate_certs(
})?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
rustls::PrivateKey(cert.serialize_private_key_der()),
CertificateDer::from(ca.serialize_der()?),
CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(cert.serialize_private_key_der())),
))
}
@@ -74,7 +79,6 @@ fn generate_tls_config<'a>(
let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone())?
.into();
@@ -93,7 +97,6 @@ fn generate_tls_config<'a>(
let client_config = {
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&ca)?;

View File

@@ -77,14 +77,19 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming)
.map(|x| match x {
Ok((conn, _)) => Ok(conn),
Err(e) => Err(e),
})
.filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {