diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index d2011f6f98..634b576417 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -1,17 +1,17 @@ use std::{ convert::Infallible, - net::SocketAddr, + net::{IpAddr, SocketAddr}, sync::{Arc, Mutex}, time::Duration, }; use anyhow::{anyhow, bail, Context, Result}; -use bytes::BytesMut; -use futures::sink::SinkExt; +use bytes::{Buf, BufMut, BytesMut}; +use futures::{SinkExt, StreamExt}; use indexmap::IndexMap; use proxy::{ config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL}, - ConnectionInitiatedPayload, PglbCodec, + ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage, }; use quinn::{Connection, Endpoint, RecvStream, SendStream}; use rand::Rng; @@ -19,6 +19,7 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use tokio::{ io::{join, AsyncReadExt, Join}, net::{TcpListener, TcpStream}, + select, time::timeout, }; use tokio_rustls::server::TlsStream; @@ -388,7 +389,12 @@ impl PglbConn { Ok(PglbConn { inner: self.inner, state: AuthPassthrough { - client_stream: self.state.client_stream, + client_stream: Framed::new( + self.state.client_stream, + PgRawCodec { + start_or_ssl_request: true, + }, + ), auth_stream, }, }) @@ -397,20 +403,74 @@ impl PglbConn { #[derive(Debug)] struct AuthPassthrough { - client_stream: TlsStream, + client_stream: Framed, PgRawCodec>, auth_stream: Framed, PglbCodec>, } impl PglbConn { async fn handle_auth_passthrough(self) -> Result> { - todo!() + let mut client_stream = self.state.client_stream; + let mut auth_stream = self.state.auth_stream; + + loop { + select! { + biased; + + msg = auth_stream.next() => { + let Some(msg) = msg else { + bail!("auth proxy disconnected"); + }; + match msg? { + PglbMessage::Postgres(mut payload) => { + let Some(msg) = PgRawMessage::decode(&mut payload, false)? else { + bail!("auth proxy sent invalid message"); + }; + client_stream.send(msg).await?; + } + PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => { + bail!("auth proxy sent unexpected message"); + } + PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => { + return Ok(PglbConn { + inner: self.inner, + state: ComputeConnect { + client_stream, + auth_stream, + compute_socket:socket, + }, + }); + } + PglbMessage::Control(PglbControlMessage::ComputeEstablish) => { + bail!("auth proxy sent unexpected message"); + } + } + } + + msg = client_stream.next() => { + let Some(msg) = msg else { + bail!("client disconnected"); + }; + match msg? { + PgRawMessage::SslRequest => bail!("protocol violation"), + msg => { + let mut buf = BytesMut::new(); + msg.encode(&mut buf)?; + auth_stream.send(proxy::PglbMessage::Postgres( + buf + )).await?; + } + } + } + } + } } } #[derive(Debug)] struct ComputeConnect { - client_stream: TlsStream, + client_stream: Framed, PgRawCodec>, auth_stream: Framed, PglbCodec>, + compute_socket: SocketAddr, } impl PglbConn { @@ -433,3 +493,103 @@ impl PglbConn { #[derive(Debug)] struct End; + +#[derive(Debug)] +struct PgRawCodec { + start_or_ssl_request: bool, +} + +impl tokio_util::codec::Encoder for PgRawCodec { + type Error = anyhow::Error; + + fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + item.encode(dst) + } +} + +impl tokio_util::codec::Decoder for PgRawCodec { + type Item = PgRawMessage; + type Error = anyhow::Error; + + fn decode(&mut self, src: &mut bytes::BytesMut) -> Result, Self::Error> { + if self.start_or_ssl_request { + match PgRawMessage::decode(src, true)? { + msg @ Some(PgRawMessage::Start(..)) => { + self.start_or_ssl_request = false; + Ok(msg) + } + msg @ Some(PgRawMessage::SslRequest) => Ok(msg), + Some(PgRawMessage::Generic { .. }) => unreachable!(), + None => Ok(None), + } + } else { + match PgRawMessage::decode(src, false)? { + Some(PgRawMessage::Start(..)) => unreachable!(), + Some(PgRawMessage::SslRequest) => unreachable!(), + msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg), + None => Ok(None), + } + } + } +} + +pub enum PgRawMessage { + SslRequest, + Start(Vec), + Generic { tag: u8, payload: Vec }, +} + +impl PgRawMessage { + fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> { + match self { + Self::SslRequest => { + dst.put_u32(8); + dst.put_u32(80877103); + } + Self::Start(payload) => { + dst.put_u32(payload.len() as u32 + 4); + dst.put_slice(&payload); + } + Self::Generic { tag, payload } => { + dst.put_u8(*tag); + dst.put_u32(payload.len() as u32 + 4); + dst.put_slice(&payload); + } + } + Ok(()) + } + + fn decode(src: &mut bytes::BytesMut, start: bool) -> Result> { + if start { + if src.remaining() < 4 { + src.reserve(4); + return Ok(None); + } + let length = src.get_u32() as usize - 4; + if src.remaining() < length { + src.reserve(length - src.remaining()); + return Ok(None); + } + if length == 4 && src.starts_with(&80877103u32.to_be_bytes()) { + Ok(Some(PgRawMessage::SslRequest)) + } else { + Ok(Some(PgRawMessage::Start(src.split_off(length).to_vec()))) + } + } else { + if src.remaining() < 5 { + src.reserve(5); + return Ok(None); + } + let tag = src.get_u8(); + let length = src.get_u32() as usize - 4; + if src.remaining() < length { + src.reserve(length - src.remaining()); + return Ok(None); + } + Ok(Some(PgRawMessage::Generic { + tag, + payload: src.split_off(length).to_vec(), + })) + } + } +}