add auth proxy codec

This commit is contained in:
Conrad Ludgate
2024-09-12 15:17:01 +01:00
parent fbc37acfdf
commit 2af5352708
3 changed files with 147 additions and 5 deletions

View File

@@ -1,16 +1,20 @@
use std::{sync::Arc, time::Duration};
use proxy::PglbCodec;
use quinn::{
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
VarInt,
};
use tokio::{
io::AsyncWriteExt,
io::{join, AsyncWriteExt},
select,
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::task::TaskTracker;
use tokio_util::{
codec::{Framed, FramedRead, FramedWrite},
task::TaskTracker,
};
#[tokio::main]
async fn main() {
@@ -102,4 +106,6 @@ impl danger::ServerCertVerifier for NoVerify {
}
}
async fn handle_stream(_send: SendStream, _recv: RecvStream) {}
async fn handle_stream(send: SendStream, recv: RecvStream) {
let _stream = Framed::new(join(recv, send), PglbCodec);
}

View File

@@ -13,6 +13,7 @@ use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
@@ -149,7 +150,7 @@ pub fn configure_tls(
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,

View File

@@ -82,10 +82,17 @@
impl_trait_overcaptures,
)]
use std::{convert::Infallible, future::Future};
use std::{
convert::Infallible,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
use anyhow::{bail, Context};
use bytes::{Buf, BufMut};
use config::TlsServerEndPoint;
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use serde::{Deserialize, Serialize};
use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::warn;
@@ -274,3 +281,131 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
pub struct PglbCodec;
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
type Error = anyhow::Error;
fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
match item {
PglbMessage::Control(ctrl) => {
dst.put_u8(1);
match ctrl {
PglbControlMessage::ConnectionInitiated(msg) => {
let encode = serde_json::to_string(&msg).context("ser")?;
dst.put_u32(1 + encode.len() as u32);
dst.put_u8(0);
dst.put(encode.as_bytes());
}
PglbControlMessage::ConnectToCompute { socket } => match socket {
SocketAddr::V4(v4) => {
dst.put_u32(1 + 4 + 2);
dst.put_u8(1);
dst.put_u32(v4.ip().to_bits());
dst.put_u16(v4.port());
}
SocketAddr::V6(v6) => {
dst.put_u32(1 + 16 + 2);
dst.put_u8(1);
dst.put_u128(v6.ip().to_bits());
dst.put_u16(v6.port());
}
},
PglbControlMessage::ComputeEstablish => {
dst.put_u32(1);
dst.put_u8(2);
}
}
}
PglbMessage::Postgres(pg) => {
dst.put_u8(0);
dst.put_u32(pg.len() as u32);
dst.put(pg);
}
}
Ok(())
}
}
impl tokio_util::codec::Decoder for PglbCodec {
type Item = PglbMessage;
type Error = anyhow::Error;
fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if dst.remaining() < 5 {
dst.reserve(5);
return Ok(None);
}
let msg = dst[0];
let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize;
if len + 5 > dst.remaining() {
dst.reserve(len + 5);
return Ok(None);
}
dst.advance(5);
let mut payload = dst.split_to(len);
match msg {
// postgres
0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))),
// control
1 => {
if payload.is_empty() {
bail!("invalid ctrl message")
}
let ctrl_msg = payload.split_to(1)[0];
let ctrl_msg = match ctrl_msg {
0 => PglbControlMessage::ConnectionInitiated(
serde_json::from_slice(&payload).context("deser")?,
),
// ipv4 socket
1 if len == 7 => PglbControlMessage::ConnectToCompute {
socket: SocketAddr::new(
IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())),
payload.get_u16(),
),
},
// ipv6 socket
1 if len == 19 => PglbControlMessage::ConnectToCompute {
socket: SocketAddr::new(
IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())),
payload.get_u16(),
),
},
2 if len == 1 => PglbControlMessage::ComputeEstablish,
_ => bail!("invalid ctrl message"),
};
Ok(Some(PglbMessage::Control(ctrl_msg)))
}
_ => bail!("invalid message"),
}
}
}
pub enum PglbMessage {
Control(PglbControlMessage),
Postgres(bytes::Bytes),
}
pub enum PglbControlMessage {
// from pglb to auth proxy
ConnectionInitiated(ConnectionInitiatedPayload),
// from auth proxy to pglb
ConnectToCompute { socket: SocketAddr },
// from auth proxy to pglb
ComputeEstablish,
}
#[derive(Serialize, Deserialize)]
pub struct ConnectionInitiatedPayload {
tls_server_end_point: TlsServerEndPoint,
server_name: Option<String>,
ip_addr: IpAddr,
}