From dc109c42bcd9e3f3805db1edb3caa17ca72052cf Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 19 Dec 2023 15:39:26 +0000 Subject: [PATCH] update rustls --- Cargo.lock | 116 +++++++++++++------ Cargo.toml | 13 ++- libs/postgres_backend/Cargo.toml | 5 +- libs/postgres_backend/src/lib.rs | 114 +++++++++++++++++- libs/postgres_backend/tests/simple_select.rs | 26 +++-- proxy/src/config.rs | 49 ++++---- proxy/src/protocol2.rs | 14 ++- proxy/src/proxy/tests.rs | 15 ++- proxy/src/serverless.rs | 21 ++-- 9 files changed, 280 insertions(+), 93 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e51e88e3b..26cf17b807 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -572,7 +572,7 @@ dependencies = [ "once_cell", "pin-project-lite", "pin-utils", - "rustls", + "rustls 0.21.9", "tokio", "tracing", ] @@ -2278,10 +2278,10 @@ dependencies = [ "http", "hyper", "log", - "rustls", + "rustls 0.21.9", "rustls-native-certs", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", ] [[package]] @@ -3482,14 +3482,14 @@ dependencies = [ "futures", "once_cell", "pq_proto", - "rustls", - "rustls-pemfile", + "ring 0.17.6", + "rustls 0.22.1", + "rustls-pemfile 2.0.0", "serde", "thiserror", "tokio", "tokio-postgres", - "tokio-postgres-rustls", - "tokio-rustls", + "tokio-rustls 0.25.0", "tracing", "workspace_hack", ] @@ -3717,8 +3717,8 @@ dependencies = [ "routerify", "rstest", "rustc-hash", - "rustls", - "rustls-pemfile", + "rustls 0.22.1", + "rustls-pemfile 2.0.0", "scopeguard", "serde", "serde_json", @@ -3732,7 +3732,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-postgres-rustls", - "tokio-rustls", + "tokio-rustls 0.25.0", "tokio-util", "tracing", "tracing-opentelemetry", @@ -3860,12 +3860,12 @@ dependencies = [ [[package]] name = "rcgen" -version = "0.11.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4954fbc00dcd4d8282c987710e50ba513d351400dbdd00e803a05172a90d8976" +checksum = "5d918c80c5a4c7560db726763020bd16db179e4d5b828078842274a443addb5d" dependencies = [ - "pem 2.0.1", - "ring 0.16.20", + "pem 3.0.3", + "ring 0.17.6", "time", "yasna", ] @@ -4003,14 +4003,14 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", - "rustls-pemfile", + "rustls 0.21.9", + "rustls-pemfile 1.0.2", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "tower-service", "url", @@ -4250,6 +4250,20 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe6b63262c9fcac8659abfaa96cac103d28166d3ff3eaf8f412e19f3ae9e5a48" +dependencies = [ + "log", + "ring 0.17.6", + "rustls-pki-types", + "rustls-webpki 0.102.0", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.2" @@ -4257,7 +4271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" dependencies = [ "openssl-probe", - "rustls-pemfile", + "rustls-pemfile 1.0.2", "schannel", "security-framework", ] @@ -4271,6 +4285,22 @@ dependencies = [ "base64 0.21.1", ] +[[package]] +name = "rustls-pemfile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +dependencies = [ + "base64 0.21.1", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b" + [[package]] name = "rustls-webpki" version = "0.100.2" @@ -4291,6 +4321,17 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustls-webpki" +version = "0.102.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de2635c8bc2b88d367767c5de8ea1d8db9af3f6219eba28442242d9ab81d1b89" +dependencies = [ + "ring 0.17.6", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -4331,7 +4372,7 @@ dependencies = [ "serde_with", "thiserror", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", "tokio-stream", "tracing", "tracing-appender", @@ -4495,7 +4536,7 @@ checksum = "2e95efd0cefa32028cdb9766c96de71d96671072f9fb494dc9fb84c0ef93e52b" dependencies = [ "httpdate", "reqwest", - "rustls", + "rustls 0.21.9", "sentry-backtrace", "sentry-contexts", "sentry-core", @@ -5161,16 +5202,14 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tls-listener" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd" +version = "0.9.0" +source = "git+https://github.com/conradludgate/tls-listener?branch=main#4801141b5660613e77816044da6540aa64f388ec" dependencies = [ "futures-util", - "hyper", "pin-project-lite", "thiserror", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", ] [[package]] @@ -5253,10 +5292,10 @@ checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f" dependencies = [ "futures", "ring 0.16.20", - "rustls", + "rustls 0.21.9", "tokio", "tokio-postgres", - "tokio-rustls", + "tokio-rustls 0.24.0", ] [[package]] @@ -5265,7 +5304,18 @@ version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" dependencies = [ - "rustls", + "rustls 0.21.9", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.1", + "rustls-pki-types", "tokio", ] @@ -5412,9 +5462,9 @@ dependencies = [ "pin-project", "prost", "rustls-native-certs", - "rustls-pemfile", + "rustls-pemfile 1.0.2", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-stream", "tower", "tower-layer", @@ -5726,7 +5776,7 @@ dependencies = [ "base64 0.21.1", "log", "once_cell", - "rustls", + "rustls 0.21.9", "rustls-webpki 0.100.2", "url", "webpki-roots 0.23.1", @@ -6319,7 +6369,7 @@ dependencies = [ "regex-syntax 0.8.2", "reqwest", "ring 0.16.20", - "rustls", + "rustls 0.21.9", "scopeguard", "serde", "serde_json", @@ -6330,7 +6380,7 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "toml_datetime", "toml_edit", diff --git a/Cargo.toml b/Cargo.toml index 6884de7bf5..d05768e9ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,11 +115,12 @@ reqwest = { version = "0.11", default-features = false, features = ["rustls-tls" reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] } reqwest-middleware = "0.2.0" reqwest-retry = "0.2.2" +ring = "0.17" routerify = "3" rpds = "0.13" rustc-hash = "1.1.0" -rustls = "0.21" -rustls-pemfile = "1" +rustls = "0.22.1" +rustls-pemfile = "2.0.0" rustls-split = "0.3" scopeguard = "1.1" sysinfo = "0.29.2" @@ -143,11 +144,11 @@ tar = "0.4" task-local-extensions = "0.1.4" test-context = "0.1" thiserror = "1.0" -tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] } +tls-listener = { version = "0.9.0", features = ["rustls"] } tokio = { version = "1.17", features = ["macros"] } tokio-io-timeout = "1.2.0" tokio-postgres-rustls = "0.10.0" -tokio-rustls = "0.24" +tokio-rustls = "0.25.0" tokio-stream = "0.1" tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } @@ -202,7 +203,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" } ## Build dependencies criterion = "0.5.1" -rcgen = "0.11" +rcgen = "0.12" rstest = "0.18" camino-tempfile = "1.0.2" tonic-build = "0.9" @@ -213,6 +214,8 @@ tonic-build = "0.9" # TODO: we should probably fork `tokio-postgres-rustls` instead. tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" } +tls-listener = { git = "https://github.com/conradludgate/tls-listener", branch="main" } + ################# Binary contents sections [profile.release] diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml index 8e249c09f7..ee0e1acb97 100644 --- a/libs/postgres_backend/Cargo.toml +++ b/libs/postgres_backend/Cargo.toml @@ -9,10 +9,12 @@ async-trait.workspace = true anyhow.workspace = true bytes.workspace = true futures.workspace = true +ring.workspace = true rustls.workspace = true serde.workspace = true thiserror.workspace = true tokio.workspace = true +tokio-postgres.workspace = true tokio-rustls.workspace = true tracing.workspace = true @@ -22,5 +24,4 @@ workspace_hack.workspace = true [dev-dependencies] once_cell.workspace = true rustls-pemfile.workspace = true -tokio-postgres.workspace = true -tokio-postgres-rustls.workspace = true \ No newline at end of file +# tokio-postgres-rustls.workspace = true diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 1dae008a4f..9dca2f98c3 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -6,7 +6,7 @@ #![deny(clippy::undocumented_unsafe_blocks)] use anyhow::Context; use bytes::Bytes; -use futures::pin_mut; +use futures::{pin_mut, TryFutureExt, FutureExt}; use serde::{Deserialize, Serialize}; use std::io::ErrorKind; use std::net::SocketAddr; @@ -1030,3 +1030,115 @@ pub enum CopyStreamHandlerEnd { #[error(transparent)] Other(#[from] anyhow::Error), } + +#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} + +impl MakeRustlsConnect { + pub fn new(config: rustls::ClientConfig) -> Self { + Self { + config: Arc::new(config), + } + } +} + +impl tokio_postgres::tls::MakeTlsConnect for MakeRustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = RustlsStream; + type TlsConnect = RustlsConnect; + type Error = io::Error; + + fn make_tls_connect(&mut self, hostname: &str) -> io::Result { + rustls::pki_types::ServerName::try_from(hostname) + .map(|dns_name| { + RustlsConnect(Some(RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + })) + }) + .or(Ok(RustlsConnect(None))) + } +} + +pub struct RustlsConnect(Option); + +struct RustlsConnectData { + hostname: rustls::pki_types::ServerName<'static>, + connector: tokio_rustls::TlsConnector, +} + +impl tokio_postgres::tls::TlsConnect for RustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = RustlsStream; + type Error = io::Error; + type Future = Pin>> + Send>>; + + fn connect(self, stream: S) -> Self::Future { + match self.0 { + None => Box::pin(core::future::ready(Err(io::ErrorKind::InvalidInput.into()))), + Some(c) => c + .connector + .connect(c.hostname, stream) + .map_ok(|s| RustlsStream(Box::pin(s))) + .boxed(), + } + } +} + +pub struct RustlsStream(Pin>>); + +impl tokio_postgres::tls::TlsStream for RustlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some(certs) if !certs.is_empty() => { + let sha256 = ring::digest::digest(&ring::digest::SHA256, certs[0].as_ref()); + tokio_postgres::tls::ChannelBinding::tls_server_end_point(sha256.as_ref().into()) + } + _ => tokio_postgres::tls::ChannelBinding::none(), + } + } +} + +impl AsyncRead for RustlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task:: Context, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + self.0.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for RustlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task:: Context, + buf: &[u8], + ) -> Poll> { + self.0.as_mut().poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut std::task:: Context) -> Poll> { + self.0.as_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut std::task:: Context) -> Poll> { + self.0.as_mut().poll_shutdown(cx) + } +} diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs index e046fa5260..3738e54f35 100644 --- a/libs/postgres_backend/tests/simple_select.rs +++ b/libs/postgres_backend/tests/simple_select.rs @@ -1,5 +1,6 @@ /// Test postgres_backend_async with tokio_postgres use once_cell::sync::Lazy; +use postgres_backend::MakeRustlsConnect; use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; use pq_proto::{BeMessage, RowDescriptor}; use std::io::Cursor; @@ -9,7 +10,6 @@ use tokio::net::{TcpListener, TcpStream}; use tokio_postgres::config::SslMode; use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; -use tokio_postgres_rustls::MakeRustlsConnect; // generate client, server test streams async fn make_tcp_pair() -> (TcpStream, TcpStream) { @@ -72,14 +72,21 @@ async fn simple_select() { } } -static KEY: Lazy = Lazy::new(|| { +static KEY: Lazy> = Lazy::new(|| { let mut cursor = Cursor::new(include_bytes!("key.pem")); - rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone()) + + let key = rustls_pemfile::rsa_private_keys(&mut cursor) + .next() + .unwrap() + .unwrap(); + key.secret_pkcs1_der().to_owned().into() }); -static CERT: Lazy = Lazy::new(|| { +static CERT: Lazy> = Lazy::new(|| { let mut cursor = Cursor::new(include_bytes!("cert.pem")); - rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) + let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap(); + + cert.into_owned() }); // test that basic select with ssl works @@ -87,10 +94,10 @@ static CERT: Lazy = Lazy::new(|| { async fn simple_select_ssl() { let (client_sock, server_sock) = make_tcp_pair().await; + let key = rustls::pki_types::PrivateKeyDer::Pkcs1(KEY.secret_pkcs1_der().to_owned().into()); let server_cfg = rustls::ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(vec![CERT.clone()], KEY.clone()) + .with_single_cert(vec![CERT.clone()], key) .unwrap(); let tls_config = Some(Arc::new(server_cfg)); let pgbackend = @@ -102,14 +109,13 @@ async fn simple_select_ssl() { }); let client_cfg = rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates({ let mut store = rustls::RootCertStore::empty(); - store.add(&CERT).unwrap(); + store.add(CERT.clone()).unwrap(); store }) .with_no_client_auth(); - let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg); + let mut make_tls_connect = MakeRustlsConnect::new(client_cfg); let tls_connect = >::make_tls_connect( &mut make_tls_connect, "localhost", diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 2ed248af8d..06582c4e63 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,6 +1,9 @@ use crate::{auth, rate_limiter::RateBucketInfo}; use anyhow::{bail, ensure, Context, Ok}; -use rustls::{sign, Certificate, PrivateKey}; +use rustls::{ + crypto::ring::sign, + pki_types::{CertificateDer, PrivateKeyDer}, +}; use sha2::{Digest, Sha256}; use std::{ collections::{HashMap, HashSet}, @@ -85,14 +88,14 @@ pub fn configure_tls( let cert_resolver = Arc::new(cert_resolver); - let config = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - // allow TLS 1.2 to be compatible with older client libraries - .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? - .with_no_client_auth() - .with_cert_resolver(cert_resolver.clone()) - .into(); + // allow TLS 1.2 to be compatible with older client libraries + let config = rustls::ServerConfig::builder_with_protocol_versions(&[ + &rustls::version::TLS13, + &rustls::version::TLS12, + ]) + .with_no_client_auth() + .with_cert_resolver(cert_resolver.clone()) + .into(); Ok(TlsConfig { config, @@ -130,14 +133,14 @@ pub enum TlsServerEndPoint { } impl TlsServerEndPoint { - pub fn new(cert: &Certificate) -> anyhow::Result { + pub fn new(cert: &CertificateDer) -> anyhow::Result { let sha256_oids = [ // I'm explicitly not adding MD5 or SHA1 here... They're bad. oid_registry::OID_SIG_ECDSA_WITH_SHA256, oid_registry::OID_PKCS1_SHA256WITHRSA, ]; - let pem = x509_parser::parse_x509_certificate(&cert.0) + let pem = x509_parser::parse_x509_certificate(&cert) .context("Failed to parse PEM object from cerficiate")? .1; @@ -148,7 +151,7 @@ impl TlsServerEndPoint { let alg = reg.get(oid); if sha256_oids.contains(oid) { let tls_server_end_point: [u8; 32] = - Sha256::new().chain_update(&cert.0).finalize().into(); + Sha256::new().chain_update(&cert).finalize().into(); info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); Ok(Self::Sha256(tls_server_end_point)) } else { @@ -162,7 +165,7 @@ impl TlsServerEndPoint { } } -#[derive(Default)] +#[derive(Default, Debug)] pub struct CertResolver { certs: HashMap, TlsServerEndPoint)>, default: Option<(Arc, TlsServerEndPoint)>, @@ -182,11 +185,12 @@ impl CertResolver { let priv_key = { let key_bytes = std::fs::read(key_path) .context(format!("Failed to read TLS keys at '{key_path}'"))?; - let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) - .context(format!("Failed to parse TLS keys at '{key_path}'"))?; + let keys: Result, _> = + rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect(); + let mut keys = keys.context(format!("Failed to parse TLS keys at '{key_path}'"))?; ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); - keys.pop().map(rustls::PrivateKey).unwrap() + keys.pop().unwrap() }; let cert_chain_bytes = std::fs::read(cert_path) @@ -194,30 +198,29 @@ impl CertResolver { let cert_chain = { rustls_pemfile::certs(&mut &cert_chain_bytes[..]) + .into_iter() + .collect::, _>>() .with_context(|| { format!( "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." ) })? - .into_iter() - .map(rustls::Certificate) - .collect() }; - self.add_cert(priv_key, cert_chain, is_default) + self.add_cert(PrivateKeyDer::Pkcs8(priv_key), cert_chain, is_default) } pub fn add_cert( &mut self, - priv_key: PrivateKey, - cert_chain: Vec, + priv_key: PrivateKeyDer, + cert_chain: Vec, is_default: bool, ) -> anyhow::Result<()> { let key = sign::any_supported_type(&priv_key).context("invalid private key")?; let first_cert = &cert_chain[0]; let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let pem = x509_parser::parse_x509_certificate(&first_cert.0) + let pem = x509_parser::parse_x509_certificate(&first_cert) .context("Failed to parse PEM object from cerficiate")? .1; diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 1d8931be85..8280d7b986 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -328,19 +328,23 @@ impl AsyncRead for WithClientIp { impl AsyncAccept for ProxyProtocolAccept { type Connection = WithClientIp; - + type Address = std::net::SocketAddr; type Error = io::Error; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll> { + use hyper::server::accept::Accept; let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); let Some(conn) = conn else { - return Poll::Ready(None); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::NotConnected, + "no incoming connection?", + ))); }; - - Poll::Ready(Some(Ok(WithClientIp::new(conn)))) + let addr = conn.remote_addr(); + Poll::Ready(Ok((WithClientIp::new(conn), addr))) } } diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3c483c59ee..9cb73cb28e 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -12,6 +12,7 @@ use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; use rstest::rstest; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer}; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream}; @@ -20,7 +21,11 @@ use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream}; fn generate_certs( hostname: &str, common_name: &str, -) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { +) -> anyhow::Result<( + CertificateDer<'static>, + CertificateDer<'static>, + PrivateKeyDer<'static>, +)> { let ca = rcgen::Certificate::from_params({ let mut params = rcgen::CertificateParams::default(); params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); @@ -37,9 +42,9 @@ fn generate_certs( })?; Ok(( - rustls::Certificate(ca.serialize_der()?), - rustls::Certificate(cert.serialize_der_with_signer(&ca)?), - rustls::PrivateKey(cert.serialize_private_key_der()), + CertificateDer::from(ca.serialize_der()?), + CertificateDer::from(cert.serialize_der_with_signer(&ca)?), + PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(cert.serialize_private_key_der())), )) } @@ -74,7 +79,6 @@ fn generate_tls_config<'a>( let tls_config = { let config = rustls::ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![cert.clone()], key.clone())? .into(); @@ -93,7 +97,6 @@ fn generate_tls_config<'a>( let client_config = { let config = rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates({ let mut store = rustls::RootCertStore::empty(); store.add(&ca)?; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index e358a0712f..08fcf07be2 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -77,14 +77,19 @@ pub async fn task_main( let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { - if let Err(err) = conn { - error!("failed to accept TLS connection for websockets: {err:?}"); - ready(false) - } else { - ready(true) - } - }); + let tls_listener = TlsListener::new(tls_acceptor, addr_incoming) + .map(|x| match x { + Ok((conn, _)) => Ok(conn), + Err(e) => Err(e), + }) + .filter(|conn| { + if let Err(err) = conn { + error!("failed to accept TLS connection for websockets: {err:?}"); + ready(false) + } else { + ready(true) + } + }); let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream>| {