From e06aa4b91dc9851470393109951388d19de3180a Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Fri, 13 Sep 2024 09:49:11 +0100 Subject: [PATCH] impl compute connect --- proxy/src/bin/pglb.rs | 82 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 6 deletions(-) diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index c4f9d5d5b8..d671553ec0 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -1,6 +1,6 @@ use std::{ convert::Infallible, - net::{IpAddr, SocketAddr}, + net::SocketAddr, sync::{Arc, Mutex}, time::Duration, }; @@ -257,7 +257,7 @@ impl PglbConn { .await? .handle_auth_passthrough() .await? - .handle_connect_connect() + .handle_compute_connect() .await? .handle_compute_passthrough() .await @@ -499,15 +499,85 @@ struct ComputeConnect { } impl PglbConn { - async fn handle_connect_connect(self) -> Result> { - todo!() + async fn handle_compute_connect(self) -> Result> { + println!("connecting to compute..."); + let ComputeConnect { + client_stream, + mut auth_stream, + compute_socket, + } = self.state; + let compute_stream = TcpStream::connect(compute_socket).await?; + println!("connected to compute"); + match compute_stream.set_nodelay(true) { + Ok(()) => {} + Err(e) => { + bail!("socket option error: {e}"); + } + }; + + let mut compute_stream = Framed::new( + compute_stream, + PgRawCodec { + start_or_ssl_request: true, + }, + ); + + loop { + select! { + msg = auth_stream.next() => { + let Some(msg) = msg else { + bail!("auth proxy disconnected"); + }; + match msg? { + PglbMessage::Postgres(mut payload) => { + let Some(msg) = PgRawMessage::decode(&mut payload, false)? else { + bail!("auth proxy sent invalid message"); + }; + compute_stream.send(msg).await?; + } + PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => { + bail!("auth proxy sent unexpected message"); + } + PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => { + bail!("auth proxy sent unexpected message"); + } + PglbMessage::Control(PglbControlMessage::ComputeEstablish) => { + println!("establish"); + return Ok(PglbConn { + inner: self.inner, + state: ComputePassthrough { + client_stream, + compute_stream, + }, + }); + } + } + } + + msg = compute_stream.next() => { + let Some(msg) = msg else { + bail!("compute disconnected"); + }; + match msg? { + PgRawMessage::SslRequest => bail!("protocol violation"), + msg => { + let mut buf = BytesMut::new(); + msg.encode(&mut buf)?; + auth_stream.send(proxy::PglbMessage::Postgres( + buf + )).await?; + } + } + } + } + } } } #[derive(Debug)] struct ComputePassthrough { - client_stream: TlsStream, - compute_conn: (), + client_stream: Framed, PgRawCodec>, + compute_stream: Framed, } impl PglbConn {