This commit is contained in:
Conrad Ludgate
2023-12-19 15:47:38 +00:00
parent dc109c42bc
commit b78a8c4d53
3 changed files with 19 additions and 21 deletions

View File

@@ -6,7 +6,6 @@
use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::proxy::run_until_cancelled;
use tokio::net::TcpListener;
@@ -76,10 +75,12 @@ async fn main() -> anyhow::Result<()> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.collect::<Result<Vec<_>, _>>()
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().map(rustls::PrivateKey).unwrap()
let bytes = keys.pop().unwrap().secret_pkcs8_der().to_owned();
rustls::pki_types::PrivateKeyDer::Pkcs1(bytes.into())
};
let cert_chain_bytes = std::fs::read(cert_path)
@@ -87,25 +88,23 @@ async fn main() -> anyhow::Result<()> {
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.collect::<Result<Vec<_>,_>>()
.context(format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
))?
.into_iter()
.map(rustls::Certificate)
.collect_vec()
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
}

View File

@@ -140,7 +140,7 @@ impl TlsServerEndPoint {
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(&cert)
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
@@ -150,8 +150,7 @@ impl TlsServerEndPoint {
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] =
Sha256::new().chain_update(&cert).finalize().into();
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
@@ -198,7 +197,6 @@ impl CertResolver {
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.into_iter()
.collect::<Result<Vec<_>, _>>()
.with_context(|| {
format!(
@@ -213,14 +211,14 @@ impl CertResolver {
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer,
cert_chain: Vec<CertificateDer>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert)
let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")?
.1;

View File

@@ -11,11 +11,11 @@ use crate::console::{CachedNodeInfo, NodeInfo};
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use postgres_backend::{MakeRustlsConnect, RustlsStream};
use rstest::rstest;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -78,9 +78,10 @@ fn generate_tls_config<'a>(
let (ca, cert, key) = generate_certs(hostname, common_name)?;
let tls_config = {
let key_clone = rustls::pki_types::PrivateKeyDer::Pkcs1(key.secret_der().to_owned().into());
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone())?
.with_single_cert(vec![cert.clone()], key_clone)?
.into();
let mut cert_resolver = CertResolver::new();
@@ -99,7 +100,7 @@ fn generate_tls_config<'a>(
let config = rustls::ClientConfig::builder()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&ca)?;
store.add(ca)?;
store
})
.with_no_client_auth();