some cleanup

This commit is contained in:
Conrad Ludgate
2024-09-13 18:27:07 +01:00
parent fe6946e15e
commit 8918b1c872

View File

@@ -6,7 +6,7 @@ use std::{
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use bytes::{Buf, BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use indexmap::IndexMap;
use itertools::Itertools;
@@ -19,7 +19,7 @@ use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::{join, AsyncReadExt, AsyncWriteExt, Join},
io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join},
net::{TcpListener, TcpStream},
select,
time::timeout,
@@ -56,7 +56,7 @@ async fn main() -> Result<()> {
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
let frontend_config = frontent_tls_config("*.localtest.me", "*.localtest.me")?;
let frontend_config = frontent_tls_config()?;
let _frontend_handle = tokio::spawn(start_frontend(
"0.0.0.0:5432".parse()?,
@@ -132,27 +132,7 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
}
}
fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
// let ca = rcgen::Certificate::from_params({
// let mut params = rcgen::CertificateParams::default();
// params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
// params
// })?;
// let cert = rcgen::Certificate::from_params({
// let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
// params.distinguished_name = rcgen::DistinguishedName::new();
// params
// .distinguished_name
// .push(rcgen::DnType::CommonName, common_name);
// params
// })?;
// let (cert, key) = (
// CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
// PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
// );
fn frontent_tls_config() -> Result<TlsConfig> {
let (cert, key) = (
rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
.collect_vec()
@@ -441,17 +421,10 @@ impl PglbConn<AuthPassthrough> {
biased;
msg = auth_stream.next() => {
let Some(msg) = msg else {
bail!("auth proxy disconnected");
};
match msg? {
PglbMessage::Postgres(mut payload) => {
match msg.context("auth proxy disconnected")?? {
PglbMessage::Postgres(payload) => {
println!("msg {payload:?}");
let Some(msg) = PgRawMessage::decode(&mut payload, false)? else {
bail!("auth proxy sent invalid message");
};
println!("parsed msg {msg:?}");
client_stream.send(msg).await?;
client_stream.send(PgRawMessage::Generic { payload }).await?;
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
bail!("auth proxy sent unexpected message");
@@ -474,16 +447,11 @@ impl PglbConn<AuthPassthrough> {
}
msg = client_stream.next() => {
let Some(msg) = msg else {
bail!("client disconnected");
};
match msg? {
match msg.context("client disconnected")?? {
PgRawMessage::SslRequest => bail!("protocol violation"),
msg => {
let mut buf = BytesMut::new();
msg.encode(&mut buf)?;
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
auth_stream.send(proxy::PglbMessage::Postgres(
buf
payload
)).await?;
}
}
@@ -502,20 +470,15 @@ struct ComputeConnect {
impl PglbConn<ComputeConnect> {
async fn handle_compute_connect(self) -> Result<PglbConn<ComputePassthrough>> {
println!("connecting to compute...");
let ComputeConnect {
client_stream,
mut auth_stream,
compute_socket,
} = self.state;
let compute_stream = TcpStream::connect(compute_socket).await?;
println!("connected to compute");
match compute_stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
bail!("socket option error: {e}");
}
};
compute_stream
.set_nodelay(true)
.context("socket option error")?;
let mut compute_stream = Framed::new(
compute_stream,
@@ -525,26 +488,13 @@ impl PglbConn<ComputeConnect> {
);
let mut resps = 4;
let mut first_auth_proxy_request = true;
loop {
select! {
msg = auth_stream.next() => {
let Some(msg) = msg else {
bail!("auth proxy disconnected");
};
match msg? {
PglbMessage::Postgres(mut payload) => {
let Some(msg) = PgRawMessage::decode(&mut payload, first_auth_proxy_request)? else {
bail!("auth proxy sent invalid message");
};
first_auth_proxy_request = false;
compute_stream.send(msg).await?;
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
bail!("auth proxy sent unexpected message");
}
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
bail!("auth proxy sent unexpected message");
match msg.context("auth proxy disconnected")?? {
PglbMessage::Postgres(payload) => {
println!("msg {payload:?}");
compute_stream.send(PgRawMessage::Generic { payload } ).await?;
}
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
println!("establish");
@@ -556,21 +506,20 @@ impl PglbConn<ComputeConnect> {
},
});
}
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) |
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
bail!("auth proxy sent unexpected message");
}
}
}
msg = compute_stream.next(), if resps > 0 => {
let Some(msg) = msg else {
bail!("compute disconnected");
};
match msg? {
match msg.context("compute disconnected")?? {
PgRawMessage::SslRequest => bail!("protocol violation"),
msg => {
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
resps -= 1;
let mut buf = BytesMut::new();
msg.encode(&mut buf)?;
auth_stream.send(proxy::PglbMessage::Postgres(
buf
payload
)).await?;
}
}
@@ -589,31 +538,29 @@ struct ComputePassthrough {
impl PglbConn<ComputePassthrough> {
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
let ComputePassthrough {
mut client_stream,
mut compute_stream,
client_stream,
compute_stream,
} = self.state;
loop {
select! {
msg = client_stream.next() => {
let Some(msg) = msg else {
bail!("compute disconnected");
};
let msg = msg?;
dbg!(&msg);
compute_stream.send(msg).await?;
}
let mut client_parts = client_stream.into_parts();
let mut compute_parts = compute_stream.into_parts();
msg = compute_stream.next() => {
let Some(msg) = msg else {
bail!("compute disconnected");
};
let msg = msg?;
dbg!(&msg);
client_stream.send(msg).await?;
}
}
}
assert!(compute_parts.write_buf.is_empty());
assert!(client_parts.write_buf.is_empty());
client_parts.io.write_all(&compute_parts.read_buf).await?;
compute_parts.io.write_all(&client_parts.read_buf).await?;
drop(client_parts.read_buf);
drop(client_parts.write_buf);
drop(compute_parts.read_buf);
drop(compute_parts.write_buf);
copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?;
Ok(PglbConn {
inner: self.inner,
state: End,
})
}
}
@@ -662,8 +609,8 @@ impl tokio_util::codec::Decoder for PgRawCodec {
#[derive(Debug)]
pub enum PgRawMessage {
SslRequest,
Start(Bytes),
Generic { tag: u8, payload: Bytes },
Start(BytesMut),
Generic { payload: BytesMut },
}
impl PgRawMessage {
@@ -673,13 +620,7 @@ impl PgRawMessage {
dst.put_u32(8);
dst.put_u32(80877103);
}
Self::Start(payload) => {
// dst.put_u32(payload.len() as u32 + 4);
dst.put_slice(payload);
}
Self::Generic { tag, payload } => {
// dst.put_u8(*tag);
// dst.put_u32(payload.len() as u32 + 4);
Self::Start(payload) | Self::Generic { payload } => {
dst.put_slice(payload);
}
}
@@ -687,35 +628,25 @@ impl PgRawMessage {
}
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
if start {
if src.remaining() < 4 {
src.reserve(4);
return Ok(None);
}
let length = u32::from_be_bytes(src[0..4].try_into().unwrap()) as usize;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
if length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
Ok(Some(PgRawMessage::SslRequest))
} else {
Ok(Some(PgRawMessage::Start(src.split_to(length).freeze())))
}
let extra = if start { 0 } else { 1 };
if src.remaining() < 4 + extra {
src.reserve(4 + extra);
return Ok(None);
}
let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
Ok(Some(PgRawMessage::SslRequest))
} else if start {
Ok(Some(PgRawMessage::Start(src.split_to(length))))
} else {
if src.remaining() < 5 {
src.reserve(5);
return Ok(None);
}
let tag = src[0];
let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
if src.remaining() < length {
src.reserve(length - src.remaining());
return Ok(None);
}
Ok(Some(PgRawMessage::Generic {
tag,
payload: src.split_to(length).freeze(),
payload: src.split_to(length),
}))
}
}