Compare commits

...

4 Commits

Author SHA1 Message Date
Conrad Ludgate
a5c1716edc hakari 2023-12-19 15:56:19 +00:00
Conrad Ludgate
0f36927a17 deps 2023-12-19 15:55:56 +00:00
Conrad Ludgate
b78a8c4d53 fix 2023-12-19 15:49:35 +00:00
Conrad Ludgate
dc109c42bc update rustls 2023-12-19 15:49:35 +00:00
11 changed files with 297 additions and 149 deletions

155
Cargo.lock generated
View File

@@ -572,7 +572,7 @@ dependencies = [
"once_cell", "once_cell",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"rustls", "rustls 0.21.9",
"tokio", "tokio",
"tracing", "tracing",
] ]
@@ -2278,10 +2278,10 @@ dependencies = [
"http", "http",
"hyper", "hyper",
"log", "log",
"rustls", "rustls 0.21.9",
"rustls-native-certs", "rustls-native-certs",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.24.0",
] ]
[[package]] [[package]]
@@ -2493,7 +2493,7 @@ checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4"
dependencies = [ dependencies = [
"base64 0.21.1", "base64 0.21.1",
"js-sys", "js-sys",
"pem 3.0.3", "pem",
"ring 0.17.6", "ring 0.17.6",
"serde", "serde",
"serde_json", "serde_json",
@@ -3290,16 +3290,6 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pem"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b13fe415cdf3c8e44518e18a7c95a13431d9bdf6d15367d82b23c377fdd441a"
dependencies = [
"base64 0.21.1",
"serde",
]
[[package]] [[package]]
name = "pem" name = "pem"
version = "3.0.3" version = "3.0.3"
@@ -3482,14 +3472,14 @@ dependencies = [
"futures", "futures",
"once_cell", "once_cell",
"pq_proto", "pq_proto",
"rustls", "ring 0.17.6",
"rustls-pemfile", "rustls 0.22.1",
"rustls-pemfile 2.0.0",
"serde", "serde",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-postgres", "tokio-postgres",
"tokio-postgres-rustls", "tokio-rustls 0.25.0",
"tokio-rustls",
"tracing", "tracing",
"workspace_hack", "workspace_hack",
] ]
@@ -3717,8 +3707,8 @@ dependencies = [
"routerify", "routerify",
"rstest", "rstest",
"rustc-hash", "rustc-hash",
"rustls", "rustls 0.22.1",
"rustls-pemfile", "rustls-pemfile 2.0.0",
"scopeguard", "scopeguard",
"serde", "serde",
"serde_json", "serde_json",
@@ -3732,7 +3722,7 @@ dependencies = [
"tokio", "tokio",
"tokio-postgres", "tokio-postgres",
"tokio-postgres-rustls", "tokio-postgres-rustls",
"tokio-rustls", "tokio-rustls 0.25.0",
"tokio-util", "tokio-util",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
@@ -3741,7 +3731,7 @@ dependencies = [
"url", "url",
"utils", "utils",
"uuid", "uuid",
"webpki-roots 0.25.2", "webpki-roots",
"workspace_hack", "workspace_hack",
"x509-parser", "x509-parser",
] ]
@@ -3860,12 +3850,12 @@ dependencies = [
[[package]] [[package]]
name = "rcgen" name = "rcgen"
version = "0.11.1" version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4954fbc00dcd4d8282c987710e50ba513d351400dbdd00e803a05172a90d8976" checksum = "5d918c80c5a4c7560db726763020bd16db179e4d5b828078842274a443addb5d"
dependencies = [ dependencies = [
"pem 2.0.1", "pem",
"ring 0.16.20", "ring 0.17.6",
"time", "time",
"yasna", "yasna",
] ]
@@ -4003,14 +3993,14 @@ dependencies = [
"once_cell", "once_cell",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustls", "rustls 0.21.9",
"rustls-pemfile", "rustls-pemfile 1.0.2",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-rustls", "tokio-rustls 0.24.0",
"tokio-util", "tokio-util",
"tower-service", "tower-service",
"url", "url",
@@ -4018,7 +4008,7 @@ dependencies = [
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams", "wasm-streams",
"web-sys", "web-sys",
"webpki-roots 0.25.2", "webpki-roots",
"winreg", "winreg",
] ]
@@ -4250,6 +4240,20 @@ dependencies = [
"sct", "sct",
] ]
[[package]]
name = "rustls"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe6b63262c9fcac8659abfaa96cac103d28166d3ff3eaf8f412e19f3ae9e5a48"
dependencies = [
"log",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.0",
"subtle",
"zeroize",
]
[[package]] [[package]]
name = "rustls-native-certs" name = "rustls-native-certs"
version = "0.6.2" version = "0.6.2"
@@ -4257,7 +4261,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
dependencies = [ dependencies = [
"openssl-probe", "openssl-probe",
"rustls-pemfile", "rustls-pemfile 1.0.2",
"schannel", "schannel",
"security-framework", "security-framework",
] ]
@@ -4272,15 +4276,21 @@ dependencies = [
] ]
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-pemfile"
version = "0.100.2" version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4"
dependencies = [ dependencies = [
"ring 0.16.20", "base64 0.21.1",
"untrusted 0.7.1", "rustls-pki-types",
] ]
[[package]]
name = "rustls-pki-types"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.101.7" version = "0.101.7"
@@ -4291,6 +4301,17 @@ dependencies = [
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
[[package]]
name = "rustls-webpki"
version = "0.102.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de2635c8bc2b88d367767c5de8ea1d8db9af3f6219eba28442242d9ab81d1b89"
dependencies = [
"ring 0.17.6",
"rustls-pki-types",
"untrusted 0.9.0",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.12" version = "1.0.12"
@@ -4331,7 +4352,7 @@ dependencies = [
"serde_with", "serde_with",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.25.0",
"tokio-stream", "tokio-stream",
"tracing", "tracing",
"tracing-appender", "tracing-appender",
@@ -4495,7 +4516,7 @@ checksum = "2e95efd0cefa32028cdb9766c96de71d96671072f9fb494dc9fb84c0ef93e52b"
dependencies = [ dependencies = [
"httpdate", "httpdate",
"reqwest", "reqwest",
"rustls", "rustls 0.21.9",
"sentry-backtrace", "sentry-backtrace",
"sentry-contexts", "sentry-contexts",
"sentry-core", "sentry-core",
@@ -4503,7 +4524,7 @@ dependencies = [
"sentry-tracing", "sentry-tracing",
"tokio", "tokio",
"ureq", "ureq",
"webpki-roots 0.25.2", "webpki-roots",
] ]
[[package]] [[package]]
@@ -5161,16 +5182,14 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tls-listener" name = "tls-listener"
version = "0.7.0" version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/conradludgate/tls-listener?branch=main#4801141b5660613e77816044da6540aa64f388ec"
checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"hyper",
"pin-project-lite", "pin-project-lite",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.25.0",
] ]
[[package]] [[package]]
@@ -5253,10 +5272,10 @@ checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f"
dependencies = [ dependencies = [
"futures", "futures",
"ring 0.16.20", "ring 0.16.20",
"rustls", "rustls 0.21.9",
"tokio", "tokio",
"tokio-postgres", "tokio-postgres",
"tokio-rustls", "tokio-rustls 0.24.0",
] ]
[[package]] [[package]]
@@ -5265,7 +5284,18 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
dependencies = [ dependencies = [
"rustls", "rustls 0.21.9",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.1",
"rustls-pki-types",
"tokio", "tokio",
] ]
@@ -5412,9 +5442,9 @@ dependencies = [
"pin-project", "pin-project",
"prost", "prost",
"rustls-native-certs", "rustls-native-certs",
"rustls-pemfile", "rustls-pemfile 1.0.2",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.24.0",
"tokio-stream", "tokio-stream",
"tower", "tower",
"tower-layer", "tower-layer",
@@ -5719,17 +5749,17 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "ureq" name = "ureq"
version = "2.7.1" version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
dependencies = [ dependencies = [
"base64 0.21.1", "base64 0.21.1",
"log", "log",
"once_cell", "once_cell",
"rustls", "rustls 0.21.9",
"rustls-webpki 0.100.2", "rustls-webpki 0.101.7",
"url", "url",
"webpki-roots 0.23.1", "webpki-roots",
] ]
[[package]] [[package]]
@@ -6038,15 +6068,6 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "webpki-roots"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
dependencies = [
"rustls-webpki 0.100.2",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "0.25.2" version = "0.25.2"
@@ -6295,11 +6316,8 @@ dependencies = [
"either", "either",
"fail", "fail",
"futures", "futures",
"futures-channel",
"futures-core",
"futures-executor", "futures-executor",
"futures-io", "futures-io",
"futures-sink",
"futures-util", "futures-util",
"hex", "hex",
"hmac", "hmac",
@@ -6318,8 +6336,7 @@ dependencies = [
"regex-automata 0.4.3", "regex-automata 0.4.3",
"regex-syntax 0.8.2", "regex-syntax 0.8.2",
"reqwest", "reqwest",
"ring 0.16.20", "rustls 0.21.9",
"rustls",
"scopeguard", "scopeguard",
"serde", "serde",
"serde_json", "serde_json",
@@ -6330,7 +6347,7 @@ dependencies = [
"time", "time",
"time-macros", "time-macros",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls 0.24.0",
"tokio-util", "tokio-util",
"toml_datetime", "toml_datetime",
"toml_edit", "toml_edit",

View File

@@ -115,11 +115,12 @@ reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] } reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }
reqwest-middleware = "0.2.0" reqwest-middleware = "0.2.0"
reqwest-retry = "0.2.2" reqwest-retry = "0.2.2"
ring = "0.17"
routerify = "3" routerify = "3"
rpds = "0.13" rpds = "0.13"
rustc-hash = "1.1.0" rustc-hash = "1.1.0"
rustls = "0.21" rustls = "0.22.1"
rustls-pemfile = "1" rustls-pemfile = "2.0.0"
rustls-split = "0.3" rustls-split = "0.3"
scopeguard = "1.1" scopeguard = "1.1"
sysinfo = "0.29.2" sysinfo = "0.29.2"
@@ -143,11 +144,11 @@ tar = "0.4"
task-local-extensions = "0.1.4" task-local-extensions = "0.1.4"
test-context = "0.1" test-context = "0.1"
thiserror = "1.0" thiserror = "1.0"
tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] } tls-listener = { version = "0.9.0", features = ["rustls"] }
tokio = { version = "1.17", features = ["macros"] } tokio = { version = "1.17", features = ["macros"] }
tokio-io-timeout = "1.2.0" tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.10.0" tokio-postgres-rustls = "0.10.0"
tokio-rustls = "0.24" tokio-rustls = "0.25.0"
tokio-stream = "0.1" tokio-stream = "0.1"
tokio-tar = "0.3" tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] } tokio-util = { version = "0.7.10", features = ["io", "rt"] }
@@ -202,7 +203,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies ## Build dependencies
criterion = "0.5.1" criterion = "0.5.1"
rcgen = "0.11" rcgen = "0.12"
rstest = "0.18" rstest = "0.18"
camino-tempfile = "1.0.2" camino-tempfile = "1.0.2"
tonic-build = "0.9" tonic-build = "0.9"
@@ -213,6 +214,8 @@ tonic-build = "0.9"
# TODO: we should probably fork `tokio-postgres-rustls` instead. # TODO: we should probably fork `tokio-postgres-rustls` instead.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" } tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
tls-listener = { git = "https://github.com/conradludgate/tls-listener", branch="main" }
################# Binary contents sections ################# Binary contents sections
[profile.release] [profile.release]

View File

@@ -9,10 +9,12 @@ async-trait.workspace = true
anyhow.workspace = true anyhow.workspace = true
bytes.workspace = true bytes.workspace = true
futures.workspace = true futures.workspace = true
ring.workspace = true
rustls.workspace = true rustls.workspace = true
serde.workspace = true serde.workspace = true
thiserror.workspace = true thiserror.workspace = true
tokio.workspace = true tokio.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true tokio-rustls.workspace = true
tracing.workspace = true tracing.workspace = true
@@ -22,5 +24,4 @@ workspace_hack.workspace = true
[dev-dependencies] [dev-dependencies]
once_cell.workspace = true once_cell.workspace = true
rustls-pemfile.workspace = true rustls-pemfile.workspace = true
tokio-postgres.workspace = true # tokio-postgres-rustls.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -6,7 +6,7 @@
#![deny(clippy::undocumented_unsafe_blocks)] #![deny(clippy::undocumented_unsafe_blocks)]
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use futures::pin_mut; use futures::{pin_mut, TryFutureExt, FutureExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::io::ErrorKind; use std::io::ErrorKind;
use std::net::SocketAddr; use std::net::SocketAddr;
@@ -1030,3 +1030,115 @@ pub enum CopyStreamHandlerEnd {
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
#[derive(Clone)]
pub struct MakeRustlsConnect {
config: Arc<rustls::ClientConfig>,
}
impl MakeRustlsConnect {
pub fn new(config: rustls::ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
}
}
impl<S> tokio_postgres::tls::MakeTlsConnect<S> for MakeRustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = RustlsStream<S>;
type TlsConnect = RustlsConnect;
type Error = io::Error;
fn make_tls_connect(&mut self, hostname: &str) -> io::Result<RustlsConnect> {
rustls::pki_types::ServerName::try_from(hostname)
.map(|dns_name| {
RustlsConnect(Some(RustlsConnectData {
hostname: dns_name.to_owned(),
connector: Arc::clone(&self.config).into(),
}))
})
.or(Ok(RustlsConnect(None)))
}
}
pub struct RustlsConnect(Option<RustlsConnectData>);
struct RustlsConnectData {
hostname: rustls::pki_types::ServerName<'static>,
connector: tokio_rustls::TlsConnector,
}
impl<S> tokio_postgres::tls::TlsConnect<S> for RustlsConnect
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Stream = RustlsStream<S>;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
fn connect(self, stream: S) -> Self::Future {
match self.0 {
None => Box::pin(core::future::ready(Err(io::ErrorKind::InvalidInput.into()))),
Some(c) => c
.connector
.connect(c.hostname, stream)
.map_ok(|s| RustlsStream(Box::pin(s)))
.boxed(),
}
}
}
pub struct RustlsStream<S>(Pin<Box<tokio_rustls::client:: TlsStream<S>>>);
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
let (_, session) = self.0.get_ref();
match session.peer_certificates() {
Some(certs) if !certs.is_empty() => {
let sha256 = ring::digest::digest(&ring::digest::SHA256, certs[0].as_ref());
tokio_postgres::tls::ChannelBinding::tls_server_end_point(sha256.as_ref().into())
}
_ => tokio_postgres::tls::ChannelBinding::none(),
}
}
}
impl<S> AsyncRead for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task:: Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
self.0.as_mut().poll_read(cx, buf)
}
}
impl<S> AsyncWrite for RustlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task:: Context,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
self.0.as_mut().poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut std::task:: Context) -> Poll<tokio::io::Result<()>> {
self.0.as_mut().poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut std::task:: Context) -> Poll<tokio::io::Result<()>> {
self.0.as_mut().poll_shutdown(cx)
}
}

View File

@@ -1,5 +1,6 @@
/// Test postgres_backend_async with tokio_postgres /// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use postgres_backend::MakeRustlsConnect;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor}; use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor; use std::io::Cursor;
@@ -9,7 +10,6 @@ use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode; use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams // generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) { async fn make_tcp_pair() -> (TcpStream, TcpStream) {
@@ -72,14 +72,21 @@ async fn simple_select() {
} }
} }
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| { static KEY: Lazy<rustls::pki_types::PrivatePkcs1KeyDer<'static>> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem")); let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
let key = rustls_pemfile::rsa_private_keys(&mut cursor)
.next()
.unwrap()
.unwrap();
key.secret_pkcs1_der().to_owned().into()
}); });
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| { static CERT: Lazy<rustls::pki_types::CertificateDer<'static>> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem")); let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone()) let cert = rustls_pemfile::certs(&mut cursor).next().unwrap().unwrap();
cert.into_owned()
}); });
// test that basic select with ssl works // test that basic select with ssl works
@@ -87,10 +94,10 @@ static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
async fn simple_select_ssl() { async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await; let (client_sock, server_sock) = make_tcp_pair().await;
let key = rustls::pki_types::PrivateKeyDer::Pkcs1(KEY.secret_pkcs1_der().to_owned().into());
let server_cfg = rustls::ServerConfig::builder() let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone()) .with_single_cert(vec![CERT.clone()], key)
.unwrap(); .unwrap();
let tls_config = Some(Arc::new(server_cfg)); let tls_config = Some(Arc::new(server_cfg));
let pgbackend = let pgbackend =
@@ -102,14 +109,13 @@ async fn simple_select_ssl() {
}); });
let client_cfg = rustls::ClientConfig::builder() let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({ .with_root_certificates({
let mut store = rustls::RootCertStore::empty(); let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap(); store.add(CERT.clone()).unwrap();
store store
}) })
.with_no_client_auth(); .with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg); let mut make_tls_connect = MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect( let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect, &mut make_tls_connect,
"localhost", "localhost",

View File

@@ -6,7 +6,6 @@
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use futures::future::Either; use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint; use proxy::config::TlsServerEndPoint;
use proxy::proxy::run_until_cancelled; use proxy::proxy::run_until_cancelled;
use tokio::net::TcpListener; use tokio::net::TcpListener;
@@ -76,10 +75,12 @@ async fn main() -> anyhow::Result<()> {
let key = { let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?; let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.collect::<Result<Vec<_>, _>>()
.context(format!("Failed to read TLS keys at '{key_path}'"))?; .context(format!("Failed to read TLS keys at '{key_path}'"))?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().map(rustls::PrivateKey).unwrap() let bytes = keys.pop().unwrap().secret_pkcs8_der().to_owned();
rustls::pki_types::PrivateKeyDer::Pkcs1(bytes.into())
}; };
let cert_chain_bytes = std::fs::read(cert_path) let cert_chain_bytes = std::fs::read(cert_path)
@@ -87,25 +88,23 @@ async fn main() -> anyhow::Result<()> {
let cert_chain = { let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..]) rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.collect::<Result<Vec<_>,_>>()
.context(format!( .context(format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'." "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
))? ))?
.into_iter()
.map(rustls::Certificate)
.collect_vec()
}; };
// needed for channel bindings // needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?; let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder() let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
.with_safe_default_cipher_suites() &rustls::version::TLS13,
.with_safe_default_kx_groups() &rustls::version::TLS12,
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? ])
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(cert_chain, key)? .with_single_cert(cert_chain, key)?
.into(); .into();
(tls_config, tls_server_end_point) (tls_config, tls_server_end_point)
} }

View File

@@ -1,6 +1,9 @@
use crate::{auth, rate_limiter::RateBucketInfo}; use crate::{auth, rate_limiter::RateBucketInfo};
use anyhow::{bail, ensure, Context, Ok}; use anyhow::{bail, ensure, Context, Ok};
use rustls::{sign, Certificate, PrivateKey}; use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@@ -85,14 +88,14 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver); let cert_resolver = Arc::new(cert_resolver);
let config = rustls::ServerConfig::builder() // allow TLS 1.2 to be compatible with older client libraries
.with_safe_default_cipher_suites() let config = rustls::ServerConfig::builder_with_protocol_versions(&[
.with_safe_default_kx_groups() &rustls::version::TLS13,
// allow TLS 1.2 to be compatible with older client libraries &rustls::version::TLS12,
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])? ])
.with_no_client_auth() .with_no_client_auth()
.with_cert_resolver(cert_resolver.clone()) .with_cert_resolver(cert_resolver.clone())
.into(); .into();
Ok(TlsConfig { Ok(TlsConfig {
config, config,
@@ -130,14 +133,14 @@ pub enum TlsServerEndPoint {
} }
impl TlsServerEndPoint { impl TlsServerEndPoint {
pub fn new(cert: &Certificate) -> anyhow::Result<Self> { pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
let sha256_oids = [ let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad. // I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256, oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA, oid_registry::OID_PKCS1_SHA256WITHRSA,
]; ];
let pem = x509_parser::parse_x509_certificate(&cert.0) let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")? .context("Failed to parse PEM object from cerficiate")?
.1; .1;
@@ -147,8 +150,7 @@ impl TlsServerEndPoint {
let oid = pem.signature_algorithm.oid(); let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid); let alg = reg.get(oid);
if sha256_oids.contains(oid) { if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
Sha256::new().chain_update(&cert.0).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point)) Ok(Self::Sha256(tls_server_end_point))
} else { } else {
@@ -162,7 +164,7 @@ impl TlsServerEndPoint {
} }
} }
#[derive(Default)] #[derive(Default, Debug)]
pub struct CertResolver { pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>, certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>, default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
@@ -182,11 +184,12 @@ impl CertResolver {
let priv_key = { let priv_key = {
let key_bytes = std::fs::read(key_path) let key_bytes = std::fs::read(key_path)
.context(format!("Failed to read TLS keys at '{key_path}'"))?; .context(format!("Failed to read TLS keys at '{key_path}'"))?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) let keys: Result<Vec<_>, _> =
.context(format!("Failed to parse TLS keys at '{key_path}'"))?; rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect();
let mut keys = keys.context(format!("Failed to parse TLS keys at '{key_path}'"))?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().map(rustls::PrivateKey).unwrap() keys.pop().unwrap()
}; };
let cert_chain_bytes = std::fs::read(cert_path) let cert_chain_bytes = std::fs::read(cert_path)
@@ -194,30 +197,28 @@ impl CertResolver {
let cert_chain = { let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..]) rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.collect::<Result<Vec<_>, _>>()
.with_context(|| { .with_context(|| {
format!( format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'." "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
) )
})? })?
.into_iter()
.map(rustls::Certificate)
.collect()
}; };
self.add_cert(priv_key, cert_chain, is_default) self.add_cert(PrivateKeyDer::Pkcs8(priv_key), cert_chain, is_default)
} }
pub fn add_cert( pub fn add_cert(
&mut self, &mut self,
priv_key: PrivateKey, priv_key: PrivateKeyDer,
cert_chain: Vec<Certificate>, cert_chain: Vec<CertificateDer<'static>>,
is_default: bool, is_default: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?; let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0]; let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let pem = x509_parser::parse_x509_certificate(&first_cert.0) let pem = x509_parser::parse_x509_certificate(first_cert)
.context("Failed to parse PEM object from cerficiate")? .context("Failed to parse PEM object from cerficiate")?
.1; .1;

View File

@@ -328,19 +328,23 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
impl AsyncAccept for ProxyProtocolAccept { impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithClientIp<AddrStream>; type Connection = WithClientIp<AddrStream>;
type Address = std::net::SocketAddr;
type Error = io::Error; type Error = io::Error;
fn poll_accept( fn poll_accept(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> { ) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
use hyper::server::accept::Accept;
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
let Some(conn) = conn else { let Some(conn) = conn else {
return Poll::Ready(None); return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"no incoming connection?",
)));
}; };
let addr = conn.remote_addr();
Poll::Ready(Some(Ok(WithClientIp::new(conn)))) Poll::Ready(Ok((WithClientIp::new(conn), addr)))
} }
} }

View File

@@ -11,16 +11,21 @@ use crate::console::{CachedNodeInfo, NodeInfo};
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram}; use crate::{auth, http, sasl, scram};
use async_trait::async_trait; use async_trait::async_trait;
use postgres_backend::{MakeRustlsConnect, RustlsStream};
use rstest::rstest; use rstest::rstest;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
use tokio_postgres::config::SslMode; use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server. /// Generate a set of TLS certificates: CA + server.
fn generate_certs( fn generate_certs(
hostname: &str, hostname: &str,
common_name: &str, common_name: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { ) -> anyhow::Result<(
CertificateDer<'static>,
CertificateDer<'static>,
PrivateKeyDer<'static>,
)> {
let ca = rcgen::Certificate::from_params({ let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default(); let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
@@ -37,9 +42,9 @@ fn generate_certs(
})?; })?;
Ok(( Ok((
rustls::Certificate(ca.serialize_der()?), CertificateDer::from(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?), CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
rustls::PrivateKey(cert.serialize_private_key_der()), PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(cert.serialize_private_key_der())),
)) ))
} }
@@ -73,10 +78,10 @@ fn generate_tls_config<'a>(
let (ca, cert, key) = generate_certs(hostname, common_name)?; let (ca, cert, key) = generate_certs(hostname, common_name)?;
let tls_config = { let tls_config = {
let key_clone = rustls::pki_types::PrivateKeyDer::Pkcs1(key.secret_der().to_owned().into());
let config = rustls::ServerConfig::builder() let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone())? .with_single_cert(vec![cert.clone()], key_clone)?
.into(); .into();
let mut cert_resolver = CertResolver::new(); let mut cert_resolver = CertResolver::new();
@@ -93,10 +98,9 @@ fn generate_tls_config<'a>(
let client_config = { let client_config = {
let config = rustls::ClientConfig::builder() let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({ .with_root_certificates({
let mut store = rustls::RootCertStore::empty(); let mut store = rustls::RootCertStore::empty();
store.add(&ca)?; store.add(ca)?;
store store
}) })
.with_no_client_auth(); .with_no_client_auth();

View File

@@ -77,14 +77,19 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete` ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { let tls_listener = TlsListener::new(tls_acceptor, addr_incoming)
if let Err(err) = conn { .map(|x| match x {
error!("failed to accept TLS connection for websockets: {err:?}"); Ok((conn, _)) => Ok(conn),
ready(false) Err(e) => Err(e),
} else { })
ready(true) .filter(|conn| {
} if let Err(err) = conn {
}); error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let make_svc = hyper::service::make_service_fn( let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| { |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {

View File

@@ -33,11 +33,8 @@ dashmap = { version = "5", default-features = false, features = ["raw-api"] }
either = { version = "1" } either = { version = "1" }
fail = { version = "0.5", default-features = false, features = ["failpoints"] } fail = { version = "0.5", default-features = false, features = ["failpoints"] }
futures = { version = "0.3" } futures = { version = "0.3" }
futures-channel = { version = "0.3", features = ["sink"] }
futures-core = { version = "0.3" }
futures-executor = { version = "0.3" } futures-executor = { version = "0.3" }
futures-io = { version = "0.3" } futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
hex = { version = "0.4", features = ["serde"] } hex = { version = "0.4", features = ["serde"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] } hmac = { version = "0.12", default-features = false, features = ["reset"] }
@@ -56,7 +53,6 @@ regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" } regex-syntax = { version = "0.8" }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "default-tls", "json", "multipart", "rustls-tls", "stream"] } reqwest = { version = "0.11", default-features = false, features = ["blocking", "default-tls", "json", "multipart", "rustls-tls", "stream"] }
ring = { version = "0.16" }
rustls = { version = "0.21", features = ["dangerous_configuration"] } rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" } scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] } serde = { version = "1", features = ["alloc", "derive"] }