mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
Merge branch 'cloneable/pglb-msg-codec' into pglb
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
net::SocketAddr,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use futures::sink::SinkExt;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use indexmap::IndexMap;
|
||||
use proxy::{
|
||||
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
|
||||
ConnectionInitiatedPayload, PglbCodec,
|
||||
ConnectionInitiatedPayload, PglbCodec, PglbControlMessage, PglbMessage,
|
||||
};
|
||||
use quinn::{Connection, Endpoint, RecvStream, SendStream};
|
||||
use rand::Rng;
|
||||
@@ -19,6 +19,7 @@ use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use tokio::{
|
||||
io::{join, AsyncReadExt, Join},
|
||||
net::{TcpListener, TcpStream},
|
||||
select,
|
||||
time::timeout,
|
||||
};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
@@ -388,7 +389,12 @@ impl PglbConn<ClientConnect> {
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: AuthPassthrough {
|
||||
client_stream: self.state.client_stream,
|
||||
client_stream: Framed::new(
|
||||
self.state.client_stream,
|
||||
PgRawCodec {
|
||||
start_or_ssl_request: true,
|
||||
},
|
||||
),
|
||||
auth_stream,
|
||||
},
|
||||
})
|
||||
@@ -397,20 +403,74 @@ impl PglbConn<ClientConnect> {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AuthPassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
}
|
||||
|
||||
impl PglbConn<AuthPassthrough> {
|
||||
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
|
||||
todo!()
|
||||
let mut client_stream = self.state.client_stream;
|
||||
let mut auth_stream = self.state.auth_stream;
|
||||
|
||||
loop {
|
||||
select! {
|
||||
biased;
|
||||
|
||||
msg = auth_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("auth proxy disconnected");
|
||||
};
|
||||
match msg? {
|
||||
PglbMessage::Postgres(mut payload) => {
|
||||
let Some(msg) = PgRawMessage::decode(&mut payload, false)? else {
|
||||
bail!("auth proxy sent invalid message");
|
||||
};
|
||||
client_stream.send(msg).await?;
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_)) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ConnectToCompute { socket }) => {
|
||||
return Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: ComputeConnect {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
compute_socket:socket,
|
||||
},
|
||||
});
|
||||
}
|
||||
PglbMessage::Control(PglbControlMessage::ComputeEstablish) => {
|
||||
bail!("auth proxy sent unexpected message");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg = client_stream.next() => {
|
||||
let Some(msg) = msg else {
|
||||
bail!("client disconnected");
|
||||
};
|
||||
match msg? {
|
||||
PgRawMessage::SslRequest => bail!("protocol violation"),
|
||||
msg => {
|
||||
let mut buf = BytesMut::new();
|
||||
msg.encode(&mut buf)?;
|
||||
auth_stream.send(proxy::PglbMessage::Postgres(
|
||||
buf
|
||||
)).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ComputeConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
client_stream: Framed<TlsStream<TcpStream>, PgRawCodec>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
compute_socket: SocketAddr,
|
||||
}
|
||||
|
||||
impl PglbConn<ComputeConnect> {
|
||||
@@ -433,3 +493,103 @@ impl PglbConn<ComputePassthrough> {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct End;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PgRawCodec {
|
||||
start_or_ssl_request: bool,
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Encoder<PgRawMessage> for PgRawCodec {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn encode(&mut self, item: PgRawMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
|
||||
item.encode(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for PgRawCodec {
|
||||
type Item = PgRawMessage;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if self.start_or_ssl_request {
|
||||
match PgRawMessage::decode(src, true)? {
|
||||
msg @ Some(PgRawMessage::Start(..)) => {
|
||||
self.start_or_ssl_request = false;
|
||||
Ok(msg)
|
||||
}
|
||||
msg @ Some(PgRawMessage::SslRequest) => Ok(msg),
|
||||
Some(PgRawMessage::Generic { .. }) => unreachable!(),
|
||||
None => Ok(None),
|
||||
}
|
||||
} else {
|
||||
match PgRawMessage::decode(src, false)? {
|
||||
Some(PgRawMessage::Start(..)) => unreachable!(),
|
||||
Some(PgRawMessage::SslRequest) => unreachable!(),
|
||||
msg @ Some(PgRawMessage::Generic { .. }) => Ok(msg),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum PgRawMessage {
|
||||
SslRequest,
|
||||
Start(Vec<u8>),
|
||||
Generic { tag: u8, payload: Vec<u8> },
|
||||
}
|
||||
|
||||
impl PgRawMessage {
|
||||
fn encode(&self, dst: &mut bytes::BytesMut) -> Result<()> {
|
||||
match self {
|
||||
Self::SslRequest => {
|
||||
dst.put_u32(8);
|
||||
dst.put_u32(80877103);
|
||||
}
|
||||
Self::Start(payload) => {
|
||||
dst.put_u32(payload.len() as u32 + 4);
|
||||
dst.put_slice(&payload);
|
||||
}
|
||||
Self::Generic { tag, payload } => {
|
||||
dst.put_u8(*tag);
|
||||
dst.put_u32(payload.len() as u32 + 4);
|
||||
dst.put_slice(&payload);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn decode(src: &mut bytes::BytesMut, start: bool) -> Result<Option<Self>> {
|
||||
if start {
|
||||
if src.remaining() < 4 {
|
||||
src.reserve(4);
|
||||
return Ok(None);
|
||||
}
|
||||
let length = src.get_u32() as usize - 4;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
if length == 4 && src.starts_with(&80877103u32.to_be_bytes()) {
|
||||
Ok(Some(PgRawMessage::SslRequest))
|
||||
} else {
|
||||
Ok(Some(PgRawMessage::Start(src.split_off(length).to_vec())))
|
||||
}
|
||||
} else {
|
||||
if src.remaining() < 5 {
|
||||
src.reserve(5);
|
||||
return Ok(None);
|
||||
}
|
||||
let tag = src.get_u8();
|
||||
let length = src.get_u32() as usize - 4;
|
||||
if src.remaining() < length {
|
||||
src.reserve(length - src.remaining());
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(PgRawMessage::Generic {
|
||||
tag,
|
||||
payload: src.split_off(length).to_vec(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user