From 2cca1b3e4e30c0b2d5073753bd0b9d96d752fd86 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 21 Aug 2024 18:44:57 +0100 Subject: [PATCH] fix --- proxy/src/auth/backend.rs | 22 +++++++++++++++------- proxy/src/auth/backend/classic.rs | 7 ++++++- proxy/src/auth/backend/hacks.rs | 11 +++++++++-- proxy/src/auth/flow.rs | 17 +++++++++++------ proxy/src/bin/pg_sni_router.rs | 6 ++++-- proxy/src/bin/proxy.rs | 12 +++++------- proxy/src/config.rs | 7 ++++--- proxy/src/protocol2.rs | 2 +- proxy/src/proxy.rs | 2 +- proxy/src/proxy/handshake.rs | 6 ++++-- proxy/src/proxy/passthrough.rs | 8 +++++--- proxy/src/proxy/tests.rs | 10 +++++++--- proxy/src/serverless.rs | 2 ++ proxy/src/serverless/websocket.rs | 4 ++-- proxy/src/stream.rs | 24 +++++++++++++----------- 15 files changed, 89 insertions(+), 51 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 7592d076ec..56795dc74e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -4,6 +4,7 @@ pub mod jwt; mod link; use std::net::IpAddr; +use std::os::fd::AsRawFd; use std::sync::Arc; use std::time::Duration; @@ -23,6 +24,7 @@ use crate::context::RequestMonitoring; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; use crate::proxy::connect_compute::ComputeConnectBackend; +use crate::proxy::handshake::KtlsAsyncReadReady; use crate::proxy::NeonOptions; use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo}; use crate::stream::Stream; @@ -274,7 +276,9 @@ async fn auth_quirks( ctx: &RequestMonitoring, api: &impl console::Api, user_info: ComputeUserInfoMaybeEndpoint, - client: &mut stream::PqStream>, + client: &mut stream::PqStream< + Stream, + >, allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, @@ -358,7 +362,9 @@ async fn authenticate_with_secret( ctx: &RequestMonitoring, secret: AuthSecret, info: ComputeUserInfo, - client: &mut stream::PqStream>, + client: &mut stream::PqStream< + Stream, + >, unauthenticated_password: Option>, allow_cleartext: bool, config: &'static AuthenticationConfig, @@ -417,7 +423,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { pub async fn authenticate( self, ctx: &RequestMonitoring, - client: &mut stream::PqStream>, + client: &mut stream::PqStream< + Stream, + >, allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, @@ -542,7 +550,7 @@ mod tests { CachedNodeInfo, }, context::RequestMonitoring, - proxy::NeonOptions, + proxy::{tests::DummyClient, NeonOptions}, rate_limiter::{EndpointRateLimiter, RateBucketInfo}, scram::{threadpool::ThreadPool, ServerSecret}, stream::{PqStream, Stream}, @@ -650,7 +658,7 @@ mod tests { #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(DummyClient(server))); let ctx = RequestMonitoring::test(); let api = Auth { @@ -727,7 +735,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(DummyClient(server))); let ctx = RequestMonitoring::test(); let api = Auth { @@ -779,7 +787,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new(Stream::from_raw(DummyClient(server))); let ctx = RequestMonitoring::test(); let api = Auth { diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 285fa29428..f9a3a0ffe3 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -1,3 +1,5 @@ +use std::os::fd::AsRawFd; + use super::{ComputeCredentials, ComputeUserInfo}; use crate::{ auth::{self, backend::ComputeCredentialKeys, AuthFlow}, @@ -5,6 +7,7 @@ use crate::{ config::AuthenticationConfig, console::AuthSecret, context::RequestMonitoring, + proxy::handshake::KtlsAsyncReadReady, sasl, stream::{PqStream, Stream}, }; @@ -14,7 +17,9 @@ use tracing::{info, warn}; pub(super) async fn authenticate( ctx: &RequestMonitoring, creds: ComputeUserInfo, - client: &mut PqStream>, + client: &mut PqStream< + Stream, + >, config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 56921dd949..157781638c 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -1,3 +1,5 @@ +use std::os::fd::AsRawFd; + use super::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint, }; @@ -7,6 +9,7 @@ use crate::{ console::AuthSecret, context::RequestMonitoring, intern::EndpointIdInt, + proxy::handshake::KtlsAsyncReadReady, sasl, stream::{self, Stream}, }; @@ -20,7 +23,9 @@ use tracing::{info, warn}; pub async fn authenticate_cleartext( ctx: &RequestMonitoring, info: ComputeUserInfo, - client: &mut stream::PqStream>, + client: &mut stream::PqStream< + Stream, + >, secret: AuthSecret, config: &'static AuthenticationConfig, ) -> auth::Result { @@ -62,7 +67,9 @@ pub async fn authenticate_cleartext( pub async fn password_hack_no_authentication( ctx: &RequestMonitoring, info: ComputeUserInfoNoEndpoint, - client: &mut stream::PqStream>, + client: &mut stream::PqStream< + Stream, + >, ) -> auth::Result { warn!("project not specified, resorting to the password hack auth flow"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index acf7b4f6b6..0d1b7af67b 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -6,13 +6,14 @@ use crate::{ console::AuthSecret, context::RequestMonitoring, intern::EndpointIdInt, + proxy::handshake::KtlsAsyncReadReady, sasl, scram::{self, threadpool::ThreadPool}, stream::{PqStream, Stream}, }; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; -use std::{io, sync::Arc}; +use std::{io, os::fd::AsRawFd, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -70,7 +71,7 @@ impl AuthMethod for CleartextPassword { /// This wrapper for [`PqStream`] performs client authentication. #[must_use] -pub struct AuthFlow<'a, S, State> { +pub struct AuthFlow<'a, S: AsRawFd, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, /// State might contain ancillary data (see [`Self::begin`]). @@ -79,7 +80,7 @@ pub struct AuthFlow<'a, S, State> { } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'a, S, Begin> { /// Create a new wrapper for client authentication. pub fn new(stream: &'a mut PqStream>) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); @@ -105,7 +106,9 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { } } -impl AuthFlow<'_, S, PasswordHack> { +impl + AuthFlow<'_, S, PasswordHack> +{ /// Perform user authentication. Raise an error in case authentication failed. pub async fn get_password(self) -> super::Result { let msg = self.stream.read_password_message().await?; @@ -124,7 +127,9 @@ impl AuthFlow<'_, S, PasswordHack> { } } -impl AuthFlow<'_, S, CleartextPassword> { +impl + AuthFlow<'_, S, CleartextPassword> +{ /// Perform user authentication. Raise an error in case authentication failed. pub async fn authenticate(self) -> super::Result> { let msg = self.stream.read_password_message().await?; @@ -149,7 +154,7 @@ impl AuthFlow<'_, S, CleartextPassword> { } /// Stream wrapper for handling [SCRAM](crate::scram) auth. -impl AuthFlow<'_, S, Scram<'_>> { +impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 4c67b206b1..f83aeb76a3 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -1,3 +1,4 @@ +use std::os::fd::AsRawFd; /// A stand-alone program that routes connections, e.g. from /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`. /// @@ -9,6 +10,7 @@ use futures::future::Either; use itertools::Itertools; use proxy::context::RequestMonitoring; use proxy::metrics::{Metrics, ThreadPoolMetrics}; +use proxy::proxy::handshake::KtlsAsyncReadReady; use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource}; use rustls::pki_types::PrivateKeyDer; use tokio::net::TcpListener; @@ -197,7 +199,7 @@ async fn task_main( const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; -async fn ssl_handshake( +async fn ssl_handshake( ctx: &RequestMonitoring, raw_stream: S, tls_config: Arc, @@ -248,7 +250,7 @@ async fn handle_client( ctx: RequestMonitoring, dest_suffix: Arc, tls_config: Arc, - stream: impl AsyncRead + AsyncWrite + Unpin, + stream: impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady, ) -> anyhow::Result<()> { let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d83a1f3bcf..b9cc783ee1 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -285,7 +285,7 @@ async fn main() -> anyhow::Result<()> { }; let args = ProxyCliArgs::parse(); - let config = build_config(&args)?; + let config = build_config(&args).await?; info!("Authentication backend: {}", config.auth_backend); info!("Using region: {}", args.aws_region); @@ -529,16 +529,14 @@ async fn main() -> anyhow::Result<()> { } /// ProxyConfig is created at proxy startup, and lives forever. -fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { +async fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); Metrics::install(thread_pool.metrics.clone()); let tls_config = match (&args.tls_key, &args.tls_cert) { - (Some(key_path), Some(cert_path)) => Some(config::configure_tls( - key_path, - cert_path, - args.certs_dir.as_ref(), - )?), + (Some(key_path), Some(cert_path)) => { + Some(config::configure_tls(key_path, cert_path, args.certs_dir.as_ref()).await?) + } (None, None) => None, _ => bail!("either both or neither tls-key and tls-cert must be specified"), }; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index b26ee733d9..eed5632aad 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -76,7 +76,7 @@ impl TlsConfig { pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql"; /// Configure TLS for the main endpoint. -pub fn configure_tls( +pub async fn configure_tls( key_path: &str, cert_path: &str, certs_dir: Option<&String>, @@ -114,8 +114,9 @@ pub fn configure_tls( #[cfg(target_os = "linux")] let provider = { let mut provider = provider; - let compat = ktls::CompatibleCiphers::new()?; - provider.cipher_suites.retain(|s| compat.is_supported(s)); + let compat = ktls::CompatibleCiphers::new().await?; + provider.cipher_suites.retain(|s| compat.is_compatible(*s)); + provider }; // allow TLS 1.2 to be compatible with older client libraries diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index a5616af282..3456b7b9f8 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -28,7 +28,7 @@ impl AsRawFd for ChainRW { } #[cfg(all(target_os = "linux", not(test)))] -impl AsRawFd for ChainRW { +impl ktls::AsyncReadReady for ChainRW { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { if self.buf.is_empty() { self.inner.poll_read_ready(cx) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 88c1fc1ce3..78446a614f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,5 +1,5 @@ #[cfg(test)] -mod tests; +pub mod tests; pub mod connect_compute; mod copy_bidirectional; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index f6fcfe395e..62c173eb10 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -50,6 +50,8 @@ impl ReportableError for HandshakeError { match self { HandshakeError::EarlyData => crate::error::ErrorKind::User, HandshakeError::ProtocolViolation => crate::error::ErrorKind::User, + #[cfg(all(target_os = "linux", not(test)))] + HandshakeError::KtlsUpgradeError(_) => crate::error::ErrorKind::Service, // This error should not happen, but will if we have no default certificate and // the client sends no SNI extension. // If they provide SNI then we can be sure there is a certificate that matches. @@ -64,7 +66,7 @@ impl ReportableError for HandshakeError { } } -pub enum HandshakeData { +pub enum HandshakeData { Startup( PqStream>, Option, @@ -201,7 +203,7 @@ where #[cfg(any(not(target_os = "linux"), test))] tls: Box::pin(tls_stream), #[cfg(all(target_os = "linux", not(test)))] - tls: Box::pin(ktls::config_ktls_server(tls_stream)?), + tls: ktls::config_ktls_server(tls_stream).await?, tls_server_end_point, }, read_buf, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 9942fac383..d3c11ca9ad 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,5 @@ +use std::os::fd::AsRawFd; + use crate::{ cancellation, compute::PostgresConnection, @@ -10,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; use utils::measured_stream::MeasuredStream; -use super::copy_bidirectional::ErrorSource; +use super::{copy_bidirectional::ErrorSource, handshake::KtlsAsyncReadReady}; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] @@ -57,7 +59,7 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { +pub struct ProxyPassthrough { pub client: Stream, pub compute: PostgresConnection, pub aux: MetricsAuxInfo, @@ -67,7 +69,7 @@ pub struct ProxyPassthrough { pub cancel: cancellation::Session

, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 9e4aae17ea..5e78ea70ff 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -64,7 +64,7 @@ fn generate_certs( )) } -struct DummyClient(DuplexStream); +pub struct DummyClient(pub DuplexStream); impl AsRawFd for DummyClient { fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd { @@ -170,7 +170,9 @@ fn generate_tls_config<'a>( #[async_trait] trait TestAuth: Sized { - async fn authenticate( + async fn authenticate< + S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady, + >( self, stream: &mut PqStream>, ) -> anyhow::Result<()> { @@ -199,7 +201,9 @@ impl Scram { #[async_trait] impl TestAuth for Scram { - async fn authenticate( + async fn authenticate< + S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady, + >( self, stream: &mut PqStream>, ) -> anyhow::Result<()> { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index c2d4139fbe..f62145b855 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -197,6 +197,8 @@ impl MaybeTlsAcceptor for rustls::ServerConfig { #[cfg(all(target_os = "linux", not(test)))] return ktls::config_ktls_server(tls) + .await + .map(|s| Box::pin(s) as _) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); #[cfg(any(not(target_os = "linux"), test))] diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 597fa71bf2..e0322b2050 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -52,8 +52,8 @@ impl AsRawFd for WebSocketRw { } } #[cfg(all(target_os = "linux", not(test)))] -impl AsRawFd for ChainRW { - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { +impl ktls::AsyncReadReady for WebSocketRw { + fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll> { unreachable!("ktls should not need to be used for websocket rw") } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index d9bf3b86cb..b25047e95e 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,11 +1,13 @@ use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::proxy::handshake::KtlsAsyncReadReady; use bytes::BytesMut; use pq_proto::framed::{ConnectionError, Framed}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; +use std::os::fd::AsRawFd; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; @@ -172,7 +174,7 @@ impl PqStream { } /// Wrapper for upgrading raw streams into secure streams. -pub enum Stream { +pub enum Stream { /// We always begin with a raw stream, /// which may then be upgraded into a secure stream. Raw { raw: S }, @@ -182,16 +184,16 @@ pub enum Stream { tls: Pin>>, #[cfg(all(target_os = "linux", not(test)))] - tls: Pin>>, + tls: ktls::KtlsStream, /// Channel binding parameter tls_server_end_point: TlsServerEndPoint, }, } -impl Unpin for Stream {} +impl Unpin for Stream {} -impl Stream { +impl Stream { /// Construct a new instance from a raw stream. pub fn from_raw(raw: S) -> Self { Self::Raw { raw } @@ -218,7 +220,7 @@ pub enum StreamUpgradeError { Io(#[from] io::Error), } -impl Stream { +impl Stream { /// If possible, upgrade raw stream into a secure TLS-based stream. pub async fn upgrade( self, @@ -239,7 +241,7 @@ impl Stream { } } -impl AsyncRead for Stream { +impl AsyncRead for Stream { fn poll_read( mut self: Pin<&mut Self>, context: &mut task::Context<'_>, @@ -247,12 +249,12 @@ impl AsyncRead for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_read(context, buf), - Self::Tls { tls, .. } => tls.as_mut().poll_read(context, buf), + Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf), } } } -impl AsyncWrite for Stream { +impl AsyncWrite for Stream { fn poll_write( mut self: Pin<&mut Self>, context: &mut task::Context<'_>, @@ -260,7 +262,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_write(context, buf), - Self::Tls { tls, .. } => tls.as_mut().poll_write(context, buf), + Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf), } } @@ -270,7 +272,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_flush(context), - Self::Tls { tls, .. } => tls.as_mut().poll_flush(context), + Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context), } } @@ -280,7 +282,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_shutdown(context), - Self::Tls { tls, .. } => tls.as_mut().poll_shutdown(context), + Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context), } } }