mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 01:50:38 +00:00
update rustls
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
use crate::{auth, rate_limiter::RateBucketInfo};
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::{sign, Certificate, PrivateKey};
|
||||
use rustls::{
|
||||
crypto::ring::sign,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -85,14 +88,14 @@ pub fn configure_tls(
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
let config = rustls::ServerConfig::builder_with_protocol_versions(&[
|
||||
&rustls::version::TLS13,
|
||||
&rustls::version::TLS12,
|
||||
])
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone())
|
||||
.into();
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
@@ -130,14 +133,14 @@ pub enum TlsServerEndPoint {
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &Certificate) -> anyhow::Result<Self> {
|
||||
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(&cert.0)
|
||||
let pem = x509_parser::parse_x509_certificate(&cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
@@ -148,7 +151,7 @@ impl TlsServerEndPoint {
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] =
|
||||
Sha256::new().chain_update(&cert.0).finalize().into();
|
||||
Sha256::new().chain_update(&cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
@@ -162,7 +165,7 @@ impl TlsServerEndPoint {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
@@ -182,11 +185,12 @@ impl CertResolver {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
.context(format!("Failed to parse TLS keys at '{key_path}'"))?;
|
||||
let keys: Result<Vec<_>, _> =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect();
|
||||
let mut keys = keys.context(format!("Failed to parse TLS keys at '{key_path}'"))?;
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
keys.pop().unwrap()
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
@@ -194,30 +198,29 @@ impl CertResolver {
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
)
|
||||
})?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
};
|
||||
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
self.add_cert(PrivateKeyDer::Pkcs8(priv_key), cert_chain, is_default)
|
||||
}
|
||||
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKey,
|
||||
cert_chain: Vec<Certificate>,
|
||||
priv_key: PrivateKeyDer,
|
||||
cert_chain: Vec<CertificateDer>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(&first_cert.0)
|
||||
let pem = x509_parser::parse_x509_certificate(&first_cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
|
||||
@@ -328,19 +328,23 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
|
||||
|
||||
impl AsyncAccept for ProxyProtocolAccept {
|
||||
type Connection = WithClientIp<AddrStream>;
|
||||
|
||||
type Address = std::net::SocketAddr;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
|
||||
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
|
||||
use hyper::server::accept::Accept;
|
||||
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
|
||||
let Some(conn) = conn else {
|
||||
return Poll::Ready(None);
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"no incoming connection?",
|
||||
)));
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(WithClientIp::new(conn))))
|
||||
let addr = conn.remote_addr();
|
||||
Poll::Ready(Ok((WithClientIp::new(conn), addr)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
@@ -20,7 +21,11 @@ use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
common_name: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
) -> anyhow::Result<(
|
||||
CertificateDer<'static>,
|
||||
CertificateDer<'static>,
|
||||
PrivateKeyDer<'static>,
|
||||
)> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
|
||||
@@ -37,9 +42,9 @@ fn generate_certs(
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
rustls::Certificate(ca.serialize_der()?),
|
||||
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
|
||||
rustls::PrivateKey(cert.serialize_private_key_der()),
|
||||
CertificateDer::from(ca.serialize_der()?),
|
||||
CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
|
||||
PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(cert.serialize_private_key_der())),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -74,7 +79,6 @@ fn generate_tls_config<'a>(
|
||||
|
||||
let tls_config = {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert.clone()], key.clone())?
|
||||
.into();
|
||||
@@ -93,7 +97,6 @@ fn generate_tls_config<'a>(
|
||||
|
||||
let client_config = {
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&ca)?;
|
||||
|
||||
@@ -77,14 +77,19 @@ pub async fn task_main(
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {err:?}");
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming)
|
||||
.map(|x| match x {
|
||||
Ok((conn, _)) => Ok(conn),
|
||||
Err(e) => Err(e),
|
||||
})
|
||||
.filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {err:?}");
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
|
||||
|
||||
Reference in New Issue
Block a user