properly handle tls-server-end-point

This commit is contained in:
Conrad Ludgate
2024-09-12 17:55:27 +01:00
parent f95ddef4e0
commit 37221f3252
2 changed files with 26 additions and 5 deletions

View File

@@ -11,6 +11,7 @@ use crate::{
stream::AuthProxyStreamExt,
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::task_local;
use std::{io, sync::Arc};
use tracing::info;
@@ -76,13 +77,15 @@ pub(crate) struct AuthFlow<'a, State> {
tls_server_end_point: TlsServerEndPoint,
}
task_local! {
pub(crate) static TLS_SERVER_END_POINT: TlsServerEndPoint;
}
/// Initial state of the stream wrapper.
impl<'a> AuthFlow<'a, Begin> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self {
// TODO:
// let tls_server_end_point = stream.get_ref().tls_server_end_point();
let tls_server_end_point = TlsServerEndPoint::Undefined;
let tls_server_end_point = TLS_SERVER_END_POINT.get();
Self {
stream,

View File

@@ -10,6 +10,7 @@ pub(crate) mod wake_compute;
use connect_compute::ComputeConnectBackend;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use futures::SinkExt;
use futures::TryStreamExt;
use pq_proto::FeStartupPacket;
use quinn::RecvStream;
@@ -18,6 +19,7 @@ use tokio::io::join;
use tokio_util::codec::Framed;
use crate::auth_proxy::AuthProxyStream;
use crate::auth_proxy::TLS_SERVER_END_POINT;
use crate::stream::AuthProxyStreamExt;
use crate::PglbControlMessage;
use crate::PglbMessage;
@@ -41,6 +43,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -473,15 +476,30 @@ pub async fn handle_stream(config: &'static AuthProxyConfig, send: SendStream, r
};
let user_info = config.backend.as_ref().map(|()| user_info);
let user_info = match user_info.authenticate(&mut stream, &config.auth).await {
let res = TLS_SERVER_END_POINT
.scope(
first_msg.tls_server_end_point,
user_info.authenticate(&mut stream, &config.auth),
)
.await;
let user_info = match res {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await.unwrap();
}
};
user_info
let node_info = user_info
.wake_compute(&RequestMonitoring::test())
.await
.unwrap();
let socket: SocketAddr = node_info.config.get_host().unwrap().parse().unwrap();
stream
.send(PglbMessage::Control(PglbControlMessage::ConnectToCompute {
socket,
}))
.await
.unwrap();
}