diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 848d7fe9fe..55a7c37bde 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -1,11 +1,9 @@ use std::{sync::Arc, time::Duration}; use clap::Parser; -use futures::TryStreamExt; -use pq_proto::FeStartupPacket; use proxy::{ - auth::{self, backend::AuthRateLimiter}, - auth_proxy::{self, backend::MaybeOwned, AuthProxyStream, Backend}, + auth::backend::AuthRateLimiter, + auth_proxy::{backend::MaybeOwned, Backend}, config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions}, console::{ caches::ApiCaches, @@ -14,23 +12,18 @@ use proxy::{ }, http, metrics::Metrics, - proxy::NeonOptions, + proxy::{handle_stream, AuthProxyConfig}, rate_limiter::{RateBucketInfo, WakeComputeRateLimiter}, scram::threadpool::ThreadPool, - stream::AuthProxyStreamExt, - PglbCodec, PglbControlMessage, PglbMessage, -}; -use quinn::{ - crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream, - VarInt, }; +use quinn::{crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, VarInt}; use tokio::{ - io::{join, AsyncWriteExt}, + io::AsyncWriteExt, select, signal::unix::{signal, SignalKind}, time::interval, }; -use tokio_util::{codec::Framed, task::TaskTracker}; +use tokio_util::task::TaskTracker; /// Neon proxy/router #[derive(Parser)] @@ -177,7 +170,7 @@ async fn main() { rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, }; - let config = Box::leak(Box::new(Config { backend, auth })); + let config = Box::leak(Box::new(AuthProxyConfig { backend, auth })); loop { select! { @@ -208,11 +201,6 @@ async fn main() { conn.close(VarInt::from_u32(1), b"graceful shutdown"); } -struct Config { - backend: Backend<'static, ()>, - auth: AuthenticationConfig, -} - #[derive(Copy, Clone, Debug)] struct NoVerify; @@ -250,36 +238,3 @@ impl danger::ServerCertVerifier for NoVerify { vec![quinn::rustls::SignatureScheme::ECDSA_NISTP256_SHA256] } } - -async fn handle_stream(config: &'static Config, send: SendStream, recv: RecvStream) { - let mut stream: AuthProxyStream = Framed::new(join(recv, send), PglbCodec); - - let first_msg = stream.try_next().await.unwrap(); - let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg - else { - panic!("invalid first msg") - }; - - let startup = stream.read_startup_packet().await.unwrap(); - let FeStartupPacket::StartupMessage { version, params } = startup else { - panic!("invalid startup message") - }; - - // Extract credentials which we're going to use for auth. - let user_info = auth::ComputeUserInfoMaybeEndpoint { - user: params.get("user").unwrap().into(), - endpoint_id: first_msg - .server_name - .as_deref() - .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()), - options: NeonOptions::parse_params(¶ms), - }; - - let user_info = config.backend.as_ref().map(|()| user_info); - let user_info = match user_info.authenticate(&mut stream, &config.auth).await { - Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e).await.unwrap(); - } - }; -} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 8806f31f14..14381fdbf9 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -7,9 +7,20 @@ pub(crate) mod handshake; pub(crate) mod passthrough; pub(crate) mod retry; pub(crate) mod wake_compute; +use connect_compute::ComputeConnectBackend; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; +use futures::TryStreamExt; +use pq_proto::FeStartupPacket; +use quinn::RecvStream; +use quinn::SendStream; +use tokio::io::join; +use tokio_util::codec::Framed; +use crate::auth_proxy::AuthProxyStream; +use crate::stream::AuthProxyStreamExt; +use crate::PglbControlMessage; +use crate::PglbMessage; use crate::{ auth, cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, @@ -431,3 +442,46 @@ pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> { let (_, [k, v]) = cap.extract(); Some((k, v)) } + +pub struct AuthProxyConfig { + pub backend: crate::auth_proxy::Backend<'static, ()>, + pub auth: crate::config::AuthenticationConfig, +} + +pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, recv: RecvStream) { + let mut stream: AuthProxyStream = Framed::new(join(recv, send), crate::PglbCodec); + + let first_msg = stream.try_next().await.unwrap(); + let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg + else { + panic!("invalid first msg") + }; + + let startup = stream.read_startup_packet().await.unwrap(); + let FeStartupPacket::StartupMessage { version: _, params } = startup else { + panic!("invalid startup message") + }; + + // Extract credentials which we're going to use for auth. + let user_info = auth::ComputeUserInfoMaybeEndpoint { + user: params.get("user").unwrap().into(), + endpoint_id: first_msg + .server_name + .as_deref() + .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()), + options: NeonOptions::parse_params(¶ms), + }; + + let user_info = config.backend.as_ref().map(|()| user_info); + let user_info = match user_info.authenticate(&mut stream, &config.auth).await { + Ok(auth_result) => auth_result, + Err(e) => { + return stream.throw_error(e).await.unwrap(); + } + }; + + user_info + .wake_compute(&RequestMonitoring::test()) + .await + .unwrap(); +} diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index a1331e5b2c..0b9ad33a87 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -299,7 +299,7 @@ impl AsyncWrite for Stream { } #[allow(async_fn_in_trait)] -pub trait AuthProxyStreamExt { +pub(crate) trait AuthProxyStreamExt { /// Write the message into an internal buffer, but don't flush the underlying stream. fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;