diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index b6c9980496..21ae1ec15d 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -6,7 +6,7 @@ use std::{ }; use anyhow::{anyhow, bail, Context, Result}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use futures::{SinkExt, StreamExt}; use indexmap::IndexMap; use itertools::Itertools; @@ -19,7 +19,7 @@ use quinn::{Connection, Endpoint, RecvStream, SendStream}; use rand::Rng; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use tokio::{ - io::{join, AsyncReadExt, AsyncWriteExt, Join}, + io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join}, net::{TcpListener, TcpStream}, select, time::timeout, @@ -56,7 +56,7 @@ async fn main() -> Result<()> { let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone())); - let frontend_config = frontent_tls_config("*.localtest.me", "*.localtest.me")?; + let frontend_config = frontent_tls_config()?; let _frontend_handle = tokio::spawn(start_frontend( "0.0.0.0:5432".parse()?, @@ -132,27 +132,7 @@ async fn quinn_server(ep: Endpoint, state: Arc) { } } -fn frontent_tls_config(hostname: &str, common_name: &str) -> Result { - // let ca = rcgen::Certificate::from_params({ - // let mut params = rcgen::CertificateParams::default(); - // params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - // params - // })?; - - // let cert = rcgen::Certificate::from_params({ - // let mut params = rcgen::CertificateParams::new(vec![hostname.into()]); - // params.distinguished_name = rcgen::DistinguishedName::new(); - // params - // .distinguished_name - // .push(rcgen::DnType::CommonName, common_name); - // params - // })?; - - // let (cert, key) = ( - // CertificateDer::from(cert.serialize_der_with_signer(&ca)?), - // PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()), - // ); - +fn frontent_tls_config() -> Result { let (cert, key) = ( rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap()) .collect_vec() @@ -441,17 +421,10 @@ impl PglbConn { biased; msg = auth_stream.next() => { - let Some(msg) = msg else { - bail!("auth proxy disconnected"); - }; - match msg? { - PglbMessage::Postgres(mut payload) => { + match msg.context("auth proxy disconnected")?? { + PglbMessage::Postgres(payload) => { println!("msg {payload:?}"); - let Some(msg) = PgRawMessage::decode(&mut payload, false)? else { - bail!("auth proxy sent invalid message"); - }; - println!("parsed msg {msg:?}"); - client_stream.send(msg).await?; + client_stream.send(PgRawMessage::Generic { payload }).await?; } PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => { bail!("auth proxy sent unexpected message"); @@ -474,16 +447,11 @@ impl PglbConn { } msg = client_stream.next() => { - let Some(msg) = msg else { - bail!("client disconnected"); - }; - match msg? { + match msg.context("client disconnected")?? { PgRawMessage::SslRequest => bail!("protocol violation"), - msg => { - let mut buf = BytesMut::new(); - msg.encode(&mut buf)?; + PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => { auth_stream.send(proxy::PglbMessage::Postgres( - buf + payload )).await?; } } @@ -502,20 +470,15 @@ struct ComputeConnect { impl PglbConn { async fn handle_compute_connect(self) -> Result> { - println!("connecting to compute..."); let ComputeConnect { client_stream, mut auth_stream, compute_socket, } = self.state; let compute_stream = TcpStream::connect(compute_socket).await?; - println!("connected to compute"); - match compute_stream.set_nodelay(true) { - Ok(()) => {} - Err(e) => { - bail!("socket option error: {e}"); - } - }; + compute_stream + .set_nodelay(true) + .context("socket option error")?; let mut compute_stream = Framed::new( compute_stream, @@ -525,26 +488,13 @@ impl PglbConn { ); let mut resps = 4; - let mut first_auth_proxy_request = true; loop { select! { msg = auth_stream.next() => { - let Some(msg) = msg else { - bail!("auth proxy disconnected"); - }; - match msg? { - PglbMessage::Postgres(mut payload) => { - let Some(msg) = PgRawMessage::decode(&mut payload, first_auth_proxy_request)? else { - bail!("auth proxy sent invalid message"); - }; - first_auth_proxy_request = false; - compute_stream.send(msg).await?; - } - PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => { - bail!("auth proxy sent unexpected message"); - } - PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => { - bail!("auth proxy sent unexpected message"); + match msg.context("auth proxy disconnected")?? { + PglbMessage::Postgres(payload) => { + println!("msg {payload:?}"); + compute_stream.send(PgRawMessage::Generic { payload } ).await?; } PglbMessage::Control(PglbControlMessage::ComputeEstablish) => { println!("establish"); @@ -556,21 +506,20 @@ impl PglbConn { }, }); } + PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) | + PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => { + bail!("auth proxy sent unexpected message"); + } } } msg = compute_stream.next(), if resps > 0 => { - let Some(msg) = msg else { - bail!("compute disconnected"); - }; - match msg? { + match msg.context("compute disconnected")?? { PgRawMessage::SslRequest => bail!("protocol violation"), - msg => { + PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => { resps -= 1; - let mut buf = BytesMut::new(); - msg.encode(&mut buf)?; auth_stream.send(proxy::PglbMessage::Postgres( - buf + payload )).await?; } } @@ -589,31 +538,29 @@ struct ComputePassthrough { impl PglbConn { async fn handle_compute_passthrough(self) -> Result> { let ComputePassthrough { - mut client_stream, - mut compute_stream, + client_stream, + compute_stream, } = self.state; - loop { - select! { - msg = client_stream.next() => { - let Some(msg) = msg else { - bail!("compute disconnected"); - }; - let msg = msg?; - dbg!(&msg); - compute_stream.send(msg).await?; - } + let mut client_parts = client_stream.into_parts(); + let mut compute_parts = compute_stream.into_parts(); - msg = compute_stream.next() => { - let Some(msg) = msg else { - bail!("compute disconnected"); - }; - let msg = msg?; - dbg!(&msg); - client_stream.send(msg).await?; - } - } - } + assert!(compute_parts.write_buf.is_empty()); + assert!(client_parts.write_buf.is_empty()); + + client_parts.io.write_all(&compute_parts.read_buf).await?; + compute_parts.io.write_all(&client_parts.read_buf).await?; + + drop(client_parts.read_buf); + drop(client_parts.write_buf); + drop(compute_parts.read_buf); + drop(compute_parts.write_buf); + + copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?; + Ok(PglbConn { + inner: self.inner, + state: End, + }) } } @@ -662,8 +609,8 @@ impl tokio_util::codec::Decoder for PgRawCodec { #[derive(Debug)] pub enum PgRawMessage { SslRequest, - Start(Bytes), - Generic { tag: u8, payload: Bytes }, + Start(BytesMut), + Generic { payload: BytesMut }, } impl PgRawMessage { @@ -673,13 +620,7 @@ impl PgRawMessage { dst.put_u32(8); dst.put_u32(80877103); } - Self::Start(payload) => { - // dst.put_u32(payload.len() as u32 + 4); - dst.put_slice(payload); - } - Self::Generic { tag, payload } => { - // dst.put_u8(*tag); - // dst.put_u32(payload.len() as u32 + 4); + Self::Start(payload) | Self::Generic { payload } => { dst.put_slice(payload); } } @@ -687,35 +628,25 @@ impl PgRawMessage { } fn decode(src: &mut bytes::BytesMut, start: bool) -> Result> { - if start { - if src.remaining() < 4 { - src.reserve(4); - return Ok(None); - } - let length = u32::from_be_bytes(src[0..4].try_into().unwrap()) as usize; - if src.remaining() < length { - src.reserve(length - src.remaining()); - return Ok(None); - } - if length == 8 && src[4..8] == 80877103u32.to_be_bytes() { - Ok(Some(PgRawMessage::SslRequest)) - } else { - Ok(Some(PgRawMessage::Start(src.split_to(length).freeze()))) - } + let extra = if start { 0 } else { 1 }; + + if src.remaining() < 4 + extra { + src.reserve(4 + extra); + return Ok(None); + } + let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra; + if src.remaining() < length { + src.reserve(length - src.remaining()); + return Ok(None); + } + + if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() { + Ok(Some(PgRawMessage::SslRequest)) + } else if start { + Ok(Some(PgRawMessage::Start(src.split_to(length)))) } else { - if src.remaining() < 5 { - src.reserve(5); - return Ok(None); - } - let tag = src[0]; - let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1; - if src.remaining() < length { - src.reserve(length - src.remaining()); - return Ok(None); - } Ok(Some(PgRawMessage::Generic { - tag, - payload: src.split_to(length).freeze(), + payload: src.split_to(length), })) } }