Merge branch 'cloneable/pglb-msg-codec' into pglb

This commit is contained in:
Folke Behrens
2024-09-12 23:49:22 +01:00

View File

@@ -1,17 +1,17 @@
use std::{
convert::Infallible,
net::SocketAddr,
net::{IpAddr, SocketAddr},
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::BytesMut;
use futures::sink::SinkExt;
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use proxy::{
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
ConnectionInitiatedPayload, PglbCodec,
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
@@ -19,6 +19,7 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::{join, AsyncReadExt, Join},
net::{TcpListener, TcpStream},
select,
time::timeout,
};
use tokio_rustls::server::TlsStream;
@@ -388,7 +389,12 @@ impl PglbConn<ClientConnect> {
Ok(PglbConn {
inner: self.inner,
state: AuthPassthrough {
client_stream: self.state.client_stream,
client_stream: Framed::new(
self.state.client_stream,
PgRawCodec {
start_or_ssl_request: true,
},
),
auth_stream,
},
})
@@ -397,20 +403,74 @@ impl PglbConn<ClientConnect> {
#[derive(Debug)]
struct AuthPassthrough {
client_stream: TlsStream<TcpStream>,
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
}
impl PglbConn<AuthPassthrough> {
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
todo!()
let mut client_stream = self.state.client_stream;
let mut auth_stream = self.state.auth_stream;
loop {
select! {
biased;
msg = auth_stream.next() => {
let Some(msg) = msg else {
bail!("auth proxy disconnected");
};
match msg? {
PglbMessage::Postgres(mut payload) => {
let Some(msg) = PgRawMessage::decode(&mut payload, false)? else {
bail!("auth proxy sent invalid message");
};
client_stream.send(msg).await?;
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
bail!("auth proxy sent unexpected message");
}
PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
return Ok(PglbConn {
inner: self.inner,
state: ComputeConnect {
client_stream,
auth_stream,
compute_socket:socket,
},
});
}
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
bail!("auth proxy sent unexpected message");
}
}
}
msg = client_stream.next() => {
let Some(msg) = msg else {
bail!("client disconnected");
};
match msg? {
PgRawMessage::SslRequest => bail!("protocol violation"),
msg => {
let mut buf = BytesMut::new();
msg.encode(&mut buf)?;
auth_stream.send(proxy::PglbMessage::Postgres(
buf
)).await?;
}
}
}
}
}
}
}
#[derive(Debug)]
struct ComputeConnect {
client_stream: TlsStream<TcpStream>,
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
compute_socket: SocketAddr,
}
impl PglbConn<ComputeConnect> {
@@ -433,3 +493,103 @@ impl PglbConn<ComputePassthrough> {
#[derive(Debug)]
struct End;
#[derive(Debug)]
struct PgRawCodec {
start_or_ssl_request: bool,
}
impl tokio_util::codec::Encoder<PgRawMessage> for PgRawCodec {
type Error = anyhow::Error;
fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
item.encode(dst)
}
}
impl tokio_util::codec::Decoder for PgRawCodec {
type Item = PgRawMessage;
type Error = anyhow::Error;
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.start_or_ssl_request {
match PgRawMessage::decode(src, true)? {
msg @ Some(PgRawMessage::Start(..)) => {
self.start_or_ssl_request = false;
Ok(msg)
}
msg @ Some(PgRawMessage::SslRequest) => Ok(msg),
Some(PgRawMessage::Generic { .. }) => unreachable!(),
None => Ok(None),
}
} else {
match PgRawMessage::decode(src, false)? {
Some(PgRawMessage::Start(..)) => unreachable!(),
Some(PgRawMessage::SslRequest) => unreachable!(),
msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg),
None => Ok(None),
}
}
}
}
pub enum PgRawMessage {
SslRequest,
Start(Vec<u8>),
Generic { tag: u8, payload: Vec<u8> },
}
impl PgRawMessage {
fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> {
match self {
Self::SslRequest => {
dst.put_u32(8);
dst.put_u32(80877103);
}
Self::Start(payload) => {
dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(&payload);
}
Self::Generic { tag, payload } => {
dst.put_u8(*tag);
dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(&payload);
}
}
Ok(())
}
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
if start {
if src.remaining() < 4 {
src.reserve(4);
return Ok(None);
}
let length = src.get_u32() as usize - 4;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
if length == 4 && src.starts_with(&80877103u32.to_be_bytes()) {
Ok(Some(PgRawMessage::SslRequest))
} else {
Ok(Some(PgRawMessage::Start(src.split_off(length).to_vec())))
}
} else {
if src.remaining() < 5 {
src.reserve(5);
return Ok(None);
}
let tag = src.get_u8();
let length = src.get_u32() as usize - 4;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
Ok(Some(PgRawMessage::Generic {
tag,
payload: src.split_off(length).to_vec(),
}))
}
}
}