mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-22 07:30:37 +00:00
some cleanup
This commit is contained in:
@@ -6,7 +6,7 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use indexmap::IndexMap;
|
||||
use itertools::Itertools;
|
||||
@@ -19,7 +19,7 @@ use quinn::{Connection, Endpoint, RecvStream, SendStream};
|
||||
use rand::Rng;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use tokio::{
|
||||
io::{join, AsyncReadExt, AsyncWriteExt, Join},
|
||||
io::{copy_bidirectional, join, AsyncReadExt, AsyncWriteExt, Join},
|
||||
net::{TcpListener, TcpStream},
|
||||
select,
|
||||
time::timeout,
|
||||
@@ -56,7 +56,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
let quinn_handle = tokio::spawn(quinn_server(auth_endpoint, auth_connections.clone()));
|
||||
|
||||
let frontend_config = frontent_tls_config("*.localtest.me", "*.localtest.me")?;
|
||||
let frontend_config = frontent_tls_config()?;
|
||||
|
||||
let _frontend_handle = tokio::spawn(start_frontend(
|
||||
"0.0.0.0:5432".parse()?,
|
||||
@@ -132,27 +132,7 @@ async fn quinn_server(ep: Endpoint, state: Arc<AuthConnState>) {
|
||||
}
|
||||
}
|
||||
|
||||
fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
|
||||
// let ca = rcgen::Certificate::from_params({
|
||||
// let mut params = rcgen::CertificateParams::default();
|
||||
// params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
|
||||
// params
|
||||
// })?;
|
||||
|
||||
// let cert = rcgen::Certificate::from_params({
|
||||
// let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
|
||||
// params.distinguished_name = rcgen::DistinguishedName::new();
|
||||
// params
|
||||
// .distinguished_name
|
||||
// .push(rcgen::DnType::CommonName, common_name);
|
||||
// params
|
||||
// })?;
|
||||
|
||||
// let (cert, key) = (
|
||||
// CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
|
||||
// PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
|
||||
// );
|
||||
|
||||
fn frontent_tls_config() -> Result<TlsConfig> {
|
||||
let (cert, key) = (
|
||||
rustls_pemfile::certs(&mut &*std::fs::read("proxy.crt").unwrap())
|
||||
.collect_vec()
|
||||
@@ -441,17 +421,10 @@ impl PglbConn<AuthPassthrough> {
|
||||
biased;
|
||||
|
||||
msg = auth_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("auth proxy disconnected");
|
||||
};
|
||||
match msg? {
|
||||
PglbMessage::Postgres(mut payload) => {
|
||||
match msg.context("auth proxy disconnected")?? {
|
||||
PglbMessage::Postgres(payload) => {
|
||||
println!("msg {payload:?}");
|
||||
let Some(msg) = PgRawMessage::decode(&mut payload, false)? else {
|
||||
bail!("auth proxy sent invalid message");
|
||||
};
|
||||
println!("parsed msg {msg:?}");
|
||||
client_stream.send(msg).await?;
|
||||
client_stream.send(PgRawMessage::Generic { payload }).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
@@ -474,16 +447,11 @@ impl PglbConn<AuthPassthrough> {
|
||||
}
|
||||
|
||||
msg = client_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("client disconnected");
|
||||
};
|
||||
match msg? {
|
||||
match msg.context("client disconnected")?? {
|
||||
PgRawMessage::SslRequest => bail!("protocol violation"),
|
||||
msg => {
|
||||
let mut buf = BytesMut::new();
|
||||
msg.encode(&mut buf)?;
|
||||
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
|
||||
auth_stream.send(proxy::PglbMessage::Postgres(
|
||||
buf
|
||||
payload
|
||||
)).await?;
|
||||
}
|
||||
}
|
||||
@@ -502,20 +470,15 @@ struct ComputeConnect {
|
||||
|
||||
impl PglbConn<ComputeConnect> {
|
||||
async fn handle_compute_connect(self) -> Result<PglbConn<ComputePassthrough>> {
|
||||
println!("connecting to compute...");
|
||||
let ComputeConnect {
|
||||
client_stream,
|
||||
mut auth_stream,
|
||||
compute_socket,
|
||||
} = self.state;
|
||||
let compute_stream = TcpStream::connect(compute_socket).await?;
|
||||
println!("connected to compute");
|
||||
match compute_stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
bail!("socket option error: {e}");
|
||||
}
|
||||
};
|
||||
compute_stream
|
||||
.set_nodelay(true)
|
||||
.context("socket option error")?;
|
||||
|
||||
let mut compute_stream = Framed::new(
|
||||
compute_stream,
|
||||
@@ -525,26 +488,13 @@ impl PglbConn<ComputeConnect> {
|
||||
);
|
||||
|
||||
let mut resps = 4;
|
||||
let mut first_auth_proxy_request = true;
|
||||
loop {
|
||||
select! {
|
||||
msg = auth_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("auth proxy disconnected");
|
||||
};
|
||||
match msg? {
|
||||
PglbMessage::Postgres(mut payload) => {
|
||||
let Some(msg) = PgRawMessage::decode(&mut payload, first_auth_proxy_request)? else {
|
||||
bail!("auth proxy sent invalid message");
|
||||
};
|
||||
first_auth_proxy_request = false;
|
||||
compute_stream.send(msg).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
match msg.context("auth proxy disconnected")?? {
|
||||
PglbMessage::Postgres(payload) => {
|
||||
println!("msg {payload:?}");
|
||||
compute_stream.send(PgRawMessage::Generic { payload } ).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
|
||||
println!("establish");
|
||||
@@ -556,21 +506,20 @@ impl PglbConn<ComputeConnect> {
|
||||
},
|
||||
});
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) |
|
||||
PglbMessage::Control(PglbControlMessage::ConnectToCompute { .. }) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg = compute_stream.next(), if resps > 0 => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("compute disconnected");
|
||||
};
|
||||
match msg? {
|
||||
match msg.context("compute disconnected")?? {
|
||||
PgRawMessage::SslRequest => bail!("protocol violation"),
|
||||
msg => {
|
||||
PgRawMessage::Start(payload) | PgRawMessage::Generic { payload } => {
|
||||
resps -= 1;
|
||||
let mut buf = BytesMut::new();
|
||||
msg.encode(&mut buf)?;
|
||||
auth_stream.send(proxy::PglbMessage::Postgres(
|
||||
buf
|
||||
payload
|
||||
)).await?;
|
||||
}
|
||||
}
|
||||
@@ -589,31 +538,29 @@ struct ComputePassthrough {
|
||||
impl PglbConn<ComputePassthrough> {
|
||||
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
|
||||
let ComputePassthrough {
|
||||
mut client_stream,
|
||||
mut compute_stream,
|
||||
client_stream,
|
||||
compute_stream,
|
||||
} = self.state;
|
||||
|
||||
loop {
|
||||
select! {
|
||||
msg = client_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("compute disconnected");
|
||||
};
|
||||
let msg = msg?;
|
||||
dbg!(&msg);
|
||||
compute_stream.send(msg).await?;
|
||||
}
|
||||
let mut client_parts = client_stream.into_parts();
|
||||
let mut compute_parts = compute_stream.into_parts();
|
||||
|
||||
msg = compute_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("compute disconnected");
|
||||
};
|
||||
let msg = msg?;
|
||||
dbg!(&msg);
|
||||
client_stream.send(msg).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(compute_parts.write_buf.is_empty());
|
||||
assert!(client_parts.write_buf.is_empty());
|
||||
|
||||
client_parts.io.write_all(&compute_parts.read_buf).await?;
|
||||
compute_parts.io.write_all(&client_parts.read_buf).await?;
|
||||
|
||||
drop(client_parts.read_buf);
|
||||
drop(client_parts.write_buf);
|
||||
drop(compute_parts.read_buf);
|
||||
drop(compute_parts.write_buf);
|
||||
|
||||
copy_bidirectional(&mut client_parts.io, &mut compute_parts.io).await?;
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: End,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,8 +609,8 @@ impl tokio_util::codec::Decoder for PgRawCodec {
|
||||
#[derive(Debug)]
|
||||
pub enum PgRawMessage {
|
||||
SslRequest,
|
||||
Start(Bytes),
|
||||
Generic { tag: u8, payload: Bytes },
|
||||
Start(BytesMut),
|
||||
Generic { payload: BytesMut },
|
||||
}
|
||||
|
||||
impl PgRawMessage {
|
||||
@@ -673,13 +620,7 @@ impl PgRawMessage {
|
||||
dst.put_u32(8);
|
||||
dst.put_u32(80877103);
|
||||
}
|
||||
Self::Start(payload) => {
|
||||
// dst.put_u32(payload.len() as u32 + 4);
|
||||
dst.put_slice(payload);
|
||||
}
|
||||
Self::Generic { tag, payload } => {
|
||||
// dst.put_u8(*tag);
|
||||
// dst.put_u32(payload.len() as u32 + 4);
|
||||
Self::Start(payload) | Self::Generic { payload } => {
|
||||
dst.put_slice(payload);
|
||||
}
|
||||
}
|
||||
@@ -687,35 +628,25 @@ impl PgRawMessage {
|
||||
}
|
||||
|
||||
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
|
||||
if start {
|
||||
if src.remaining() < 4 {
|
||||
src.reserve(4);
|
||||
return Ok(None);
|
||||
}
|
||||
let length = u32::from_be_bytes(src[0..4].try_into().unwrap()) as usize;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
if length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
|
||||
Ok(Some(PgRawMessage::SslRequest))
|
||||
} else {
|
||||
Ok(Some(PgRawMessage::Start(src.split_to(length).freeze())))
|
||||
}
|
||||
let extra = if start { 0 } else { 1 };
|
||||
|
||||
if src.remaining() < 4 + extra {
|
||||
src.reserve(4 + extra);
|
||||
return Ok(None);
|
||||
}
|
||||
let length = u32::from_be_bytes(src[extra..4 + extra].try_into().unwrap()) as usize + extra;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if start && length == 8 && src[4..8] == 80877103u32.to_be_bytes() {
|
||||
Ok(Some(PgRawMessage::SslRequest))
|
||||
} else if start {
|
||||
Ok(Some(PgRawMessage::Start(src.split_to(length))))
|
||||
} else {
|
||||
if src.remaining() < 5 {
|
||||
src.reserve(5);
|
||||
return Ok(None);
|
||||
}
|
||||
let tag = src[0];
|
||||
let length = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(PgRawMessage::Generic {
|
||||
tag,
|
||||
payload: src.split_to(length).freeze(),
|
||||
payload: src.split_to(length),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user