From 2af53527082fb0e108f6d32d9d7cebd7f4961f29 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 12 Sep 2024 15:17:01 +0100 Subject: [PATCH] add auth proxy codec --- proxy/src/bin/auth_proxy.rs | 12 +++- proxy/src/config.rs | 3 +- proxy/src/lib.rs | 137 +++++++++++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 5 deletions(-) diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 85e5c8abdc..2f0686f11e 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -1,16 +1,20 @@ use std::{sync::Arc, time::Duration}; +use proxy::PglbCodec; use quinn::{ crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream, VarInt, }; use tokio::{ - io::AsyncWriteExt, + io::{join, AsyncWriteExt}, select, signal::unix::{signal, SignalKind}, time::interval, }; -use tokio_util::task::TaskTracker; +use tokio_util::{ + codec::{Framed, FramedRead, FramedWrite}, + task::TaskTracker, +}; #[tokio::main] async fn main() { @@ -102,4 +106,6 @@ impl danger::ServerCertVerifier for NoVerify { } } -async fn handle_stream(_send: SendStream, _recv: RecvStream) {} +async fn handle_stream(send: SendStream, recv: RecvStream) { + let _stream = Framed::new(join(recv, send), PglbCodec); +} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index d7fc6eee22..87fa05a071 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -13,6 +13,7 @@ use rustls::{ crypto::ring::sign, pki_types::{CertificateDer, PrivateKeyDer}, }; +use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::{ collections::{HashMap, HashSet}, @@ -149,7 +150,7 @@ pub fn configure_tls( /// uses multiple hash functions, then this channel binding type's /// channel bindings are undefined at this time (updates to is channel /// binding type may occur to address this issue if it ever arises). -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum TlsServerEndPoint { Sha256([u8; 32]), Undefined, diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 0070839aa8..0c820c8512 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -82,10 +82,17 @@ impl_trait_overcaptures, )] -use std::{convert::Infallible, future::Future}; +use std::{ + convert::Infallible, + future::Future, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; use anyhow::{bail, Context}; +use bytes::{Buf, BufMut}; +use config::TlsServerEndPoint; use intern::{EndpointIdInt, EndpointIdTag, InternId}; +use serde::{Deserialize, Serialize}; use tokio::task::JoinError; use tokio_util::sync::CancellationToken; use tracing::warn; @@ -274,3 +281,131 @@ impl EndpointId { ProjectId(self.0.clone()) } } + +pub struct PglbCodec; + +impl tokio_util::codec::Encoder for PglbCodec { + type Error = anyhow::Error; + + fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + match item { + PglbMessage::Control(ctrl) => { + dst.put_u8(1); + match ctrl { + PglbControlMessage::ConnectionInitiated(msg) => { + let encode = serde_json::to_string(&msg).context("ser")?; + dst.put_u32(1 + encode.len() as u32); + dst.put_u8(0); + dst.put(encode.as_bytes()); + } + PglbControlMessage::ConnectToCompute { socket } => match socket { + SocketAddr::V4(v4) => { + dst.put_u32(1 + 4 + 2); + dst.put_u8(1); + dst.put_u32(v4.ip().to_bits()); + dst.put_u16(v4.port()); + } + SocketAddr::V6(v6) => { + dst.put_u32(1 + 16 + 2); + dst.put_u8(1); + dst.put_u128(v6.ip().to_bits()); + dst.put_u16(v6.port()); + } + }, + PglbControlMessage::ComputeEstablish => { + dst.put_u32(1); + dst.put_u8(2); + } + } + } + PglbMessage::Postgres(pg) => { + dst.put_u8(0); + dst.put_u32(pg.len() as u32); + dst.put(pg); + } + } + Ok(()) + } +} + +impl tokio_util::codec::Decoder for PglbCodec { + type Item = PglbMessage; + type Error = anyhow::Error; + + fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result, Self::Error> { + if dst.remaining() < 5 { + dst.reserve(5); + return Ok(None); + } + + let msg = dst[0]; + let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize; + + if len + 5 > dst.remaining() { + dst.reserve(len + 5); + return Ok(None); + } + + dst.advance(5); + let mut payload = dst.split_to(len); + + match msg { + // postgres + 0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))), + // control + 1 => { + if payload.is_empty() { + bail!("invalid ctrl message") + } + let ctrl_msg = payload.split_to(1)[0]; + let ctrl_msg = match ctrl_msg { + 0 => PglbControlMessage::ConnectionInitiated( + serde_json::from_slice(&payload).context("deser")?, + ), + // ipv4 socket + 1 if len == 7 => PglbControlMessage::ConnectToCompute { + socket: SocketAddr::new( + IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())), + payload.get_u16(), + ), + }, + + // ipv6 socket + 1 if len == 19 => PglbControlMessage::ConnectToCompute { + socket: SocketAddr::new( + IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())), + payload.get_u16(), + ), + }, + + 2 if len == 1 => PglbControlMessage::ComputeEstablish, + + _ => bail!("invalid ctrl message"), + }; + Ok(Some(PglbMessage::Control(ctrl_msg))) + } + _ => bail!("invalid message"), + } + } +} + +pub enum PglbMessage { + Control(PglbControlMessage), + Postgres(bytes::Bytes), +} + +pub enum PglbControlMessage { + // from pglb to auth proxy + ConnectionInitiated(ConnectionInitiatedPayload), + // from auth proxy to pglb + ConnectToCompute { socket: SocketAddr }, + // from auth proxy to pglb + ComputeEstablish, +} + +#[derive(Serialize, Deserialize)] +pub struct ConnectionInitiatedPayload { + tls_server_end_point: TlsServerEndPoint, + server_name: Option, + ip_addr: IpAddr, +}