From 37221f32521504412d3e4eea5ee24a623b072dd0 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 12 Sep 2024 17:55:27 +0100 Subject: [PATCH] properly handle tls-server-end-point --- proxy/src/auth_proxy/flow.rs | 9 ++++++--- proxy/src/proxy.rs | 22 ++++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/proxy/src/auth_proxy/flow.rs b/proxy/src/auth_proxy/flow.rs index fdf1e6bdae..0f9af13f03 100644 --- a/proxy/src/auth_proxy/flow.rs +++ b/proxy/src/auth_proxy/flow.rs @@ -11,6 +11,7 @@ use crate::{ stream::AuthProxyStreamExt, }; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; +use tokio::task_local; use std::{io, sync::Arc}; use tracing::info; @@ -76,13 +77,15 @@ pub(crate) struct AuthFlow<'a, State> { tls_server_end_point: TlsServerEndPoint, } +task_local! { + pub(crate) static TLS_SERVER_END_POINT: TlsServerEndPoint; +} + /// Initial state of the stream wrapper. impl<'a> AuthFlow<'a, Begin> { /// Create a new wrapper for client authentication. pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self { - // TODO: - // let tls_server_end_point = stream.get_ref().tls_server_end_point(); - let tls_server_end_point = TlsServerEndPoint::Undefined; + let tls_server_end_point = TLS_SERVER_END_POINT.get(); Self { stream, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 14381fdbf9..0996c18efc 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,6 +10,7 @@ pub(crate) mod wake_compute; use connect_compute::ComputeConnectBackend; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; +use futures::SinkExt; use futures::TryStreamExt; use pq_proto::FeStartupPacket; use quinn::RecvStream; @@ -18,6 +19,7 @@ use tokio::io::join; use tokio_util::codec::Framed; use crate::auth_proxy::AuthProxyStream; +use crate::auth_proxy::TLS_SERVER_END_POINT; use crate::stream::AuthProxyStreamExt; use crate::PglbControlMessage; use crate::PglbMessage; @@ -41,6 +43,7 @@ use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; use smol_str::{format_smolstr, SmolStr}; +use std::net::SocketAddr; use std::sync::Arc; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -473,15 +476,30 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r }; let user_info = config.backend.as_ref().map(|()| user_info); - let user_info = match user_info.authenticate(&mut stream, &config.auth).await { + let res = TLS_SERVER_END_POINT + .scope( + first_msg.tls_server_end_point, + user_info.authenticate(&mut stream, &config.auth), + ) + .await; + let user_info = match res { Ok(auth_result) => auth_result, Err(e) => { return stream.throw_error(e).await.unwrap(); } }; - user_info + let node_info = user_info .wake_compute(&RequestMonitoring::test()) .await .unwrap(); + + let socket: SocketAddr = node_info.config.get_host().unwrap().parse().unwrap(); + + stream + .send(PglbMessage::Control(PglbControlMessage::ConnectToCompute { + socket, + })) + .await + .unwrap(); }