diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index 70ada56501..37b3d2b15c 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -7,23 +7,32 @@ use std::{ use anyhow::{anyhow, bail, Context, Result}; use bytes::BytesMut; +use futures::sink::SinkExt; use indexmap::IndexMap; -use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL}; -use quinn::{Connection, Endpoint}; +use proxy::{ + config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL}, + ConnectionInitiatedPayload, PglbCodec, +}; +use quinn::{Connection, Endpoint, RecvStream, SendStream}; +use rand::Rng; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use tokio::{ - io::AsyncReadExt, + io::{join, AsyncReadExt, Join}, net::{TcpListener, TcpStream}, time::timeout, }; use tokio_rustls::server::TlsStream; +use tokio_util::codec::Framed; use tracing::{error, warn}; type AuthConnId = usize; + +#[derive(Debug)] struct AuthConnState { conns: Mutex>, } +#[derive(Clone, Debug)] struct AuthConn { conn: Connection, // latency info... @@ -46,7 +55,11 @@ async fn main() -> Result<()> { let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?; - let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config)); + let _frontend_handle = tokio::spawn(start_frontend( + "0.0.0.0:5432".parse()?, + frontend_config, + auth_connections.clone(), + )); quinn_handle.await.unwrap(); @@ -151,17 +164,28 @@ fn frontent_tls_config(hostname: &str, common_name: &str) -> Result { }) } -async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result { +async fn start_frontend( + addr: SocketAddr, + tls: TlsConfig, + state: Arc, +) -> Result { let listener = TcpListener::bind(addr).await?; socket2::SockRef::from(&listener).set_keepalive(true)?; - let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let workers = tokio_util::task::task_tracker::TaskTracker::new(); loop { match listener.accept().await { - Ok((socket, peer_addr)) => { - let tls = tls.clone(); - connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls)); + Ok((socket, client_addr)) => { + let w = Worker { + state: Some(WorkerState::Start { + client_stream: socket, + client_addr, + }), + tls: tls.clone(), + auth_conns: Arc::clone(&state), + }; + workers.spawn(w.start()); } Err(e) => { error!("connection accept error: {e}"); @@ -170,38 +194,207 @@ async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result } } -async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) { - match stream.set_nodelay(true) { - Ok(()) => {} - Err(e) => { - error!("socket option error: {e}"); - return; +#[derive(Debug)] +struct Worker { + state: Option, + tls: TlsConfig, + auth_conns: Arc, +} + +#[derive(Debug)] +enum WorkerState { + Start { + client_stream: TcpStream, + client_addr: SocketAddr, + }, + AuthConnect { + client_stream: TlsStream, + payload: ConnectionInitiatedPayload, + }, + AuthPassthrough { + client_stream: TlsStream, + auth_stream: Framed, PglbCodec>, + }, + ComputeConnect { + client_stream: TlsStream, + auth_stream: Framed, PglbCodec>, + }, + ComputePassthrough { + client_stream: TlsStream, + compute_conn: (), + }, + End, +} + +impl Worker { + async fn start(mut self) -> Result<()> { + loop { + match self.state.take().expect("state should be specified") { + WorkerState::Start { + client_stream, + client_addr, + } => { + self.state = Some(self.handle_start(client_stream, client_addr).await?); + } + + WorkerState::AuthConnect { + client_stream, + payload, + } => { + self.state = Some(self.handle_auth_connect(client_stream, payload).await?); + } + + WorkerState::AuthPassthrough { + client_stream, + auth_stream, + } => { + self.state = Some( + self.handle_auth_passthrough(client_stream, auth_stream) + .await?, + ); + } + + WorkerState::ComputeConnect { + client_stream, + auth_stream, + } => { + self.state = Some( + self.handle_compute_connect(client_stream, auth_stream) + .await?, + ); + } + + WorkerState::ComputePassthrough { + client_stream, + compute_conn, + } => { + self.state = Some( + self.handle_compute_passthrough(client_stream, compute_conn) + .await?, + ); + } + + WorkerState::End => { + return Ok(()); + } + } } - }; + } - // TODO: HAProxy protocol? - - let tls_requested = match handle_ssl_request_message(&mut stream).await { - Ok(tls_requested) => tls_requested, - Err(e) => { - error!("check_for_ssl_request: {e}"); - return; - } - }; - - if tls_requested { - let (stream, ep, sn) = match tls_upgrade(stream, tls).await { - Ok((stream, ep, sn)) => (stream, ep, sn), + async fn handle_start( + &self, + mut client_stream: TcpStream, + peer_addr: SocketAddr, + ) -> Result { + match client_stream.set_nodelay(true) { + Ok(()) => {} Err(e) => { - error!("tls_upgrade: {e}"); - return; + bail!("socket option error: {e}"); } }; - // TODO: send auth msg with tls ep and server name - } else { - // TODO: send auth msg without server name + // TODO: HAProxy protocol + + let tls_requested = match handle_ssl_request_message(&mut client_stream).await { + Ok(tls_requested) => tls_requested, + Err(e) => { + bail!("check_for_ssl_request: {e}"); + } + }; + + let (client_stream, payload) = if tls_requested { + let (stream, tls_server_end_point, server_name) = + match tls_upgrade(client_stream, self.tls.clone()).await { + Ok((stream, ep, sn)) => (stream, ep, sn), + Err(e) => { + bail!("tls_upgrade: {e}"); + } + }; + + ( + stream, + ConnectionInitiatedPayload { + tls_server_end_point, + server_name, + ip_addr: peer_addr.ip(), + }, + ) + } else { + // TODO: support unsecured connections? + bail!("closing non-TLS connection"); + }; + + Ok(WorkerState::AuthConnect { + client_stream, + payload, + }) } + + async fn handle_auth_connect( + &self, + client_stream: TlsStream, + payload: ConnectionInitiatedPayload, + ) -> Result { + let auth_conn = { + let conns = self.auth_conns.conns.lock().unwrap(); + if conns.is_empty() { + bail!("no auth proxies avaiable"); + } + + let mut rng = rand::thread_rng(); + conns + .get_index(rng.gen_range(0..conns.len())) + .unwrap() + .1 + .clone() + + // TODO: check closed? + }; + + let (send, recv) = auth_conn.conn.open_bi().await?; + let mut auth_stream = Framed::new(join(recv, send), PglbCodec); + + auth_stream + .send(proxy::PglbMessage::Control( + proxy::PglbControlMessage::ConnectionInitiated(payload), + )) + .await?; + + Ok(WorkerState::AuthPassthrough { + client_stream, + auth_stream, + }) + } + + async fn handle_auth_passthrough( + &self, + _client_stream: TlsStream, + _auth_stream: Framed, PglbCodec>, + ) -> Result { + todo!() + } + + async fn handle_compute_connect( + &self, + _client_stream: TlsStream, + _auth_stream: Framed, PglbCodec>, + ) -> Result { + todo!() + } + + async fn handle_compute_passthrough( + &self, + _client_stream: TlsStream, + _compute_conn: (), + ) -> Result { + todo!() + } +} + +#[derive(Clone, Debug)] +struct TlsConfig { + config: Arc, + cert_resolver: Arc, } async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result { @@ -253,9 +446,3 @@ async fn tls_upgrade( Ok((tls_stream, tls_server_end_point, server_name)) } - -#[derive(Clone, Debug)] -struct TlsConfig { - config: Arc, - cert_resolver: Arc, -} diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 82c8a01301..bc953b2217 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -283,6 +283,7 @@ impl EndpointId { } } +#[derive(Debug)] pub struct PglbCodec; impl tokio_util::codec::Encoder for PglbCodec { @@ -404,7 +405,7 @@ pub enum PglbControlMessage { ComputeEstablish, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct ConnectionInitiatedPayload { pub tls_server_end_point: TlsServerEndPoint, pub server_name: Option,