mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-22 07:30:37 +00:00
add auth proxy codec
This commit is contained in:
@@ -1,16 +1,20 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use proxy::PglbCodec;
|
||||
use quinn::{
|
||||
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
|
||||
VarInt,
|
||||
};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
io::{join, AsyncWriteExt},
|
||||
select,
|
||||
signal::unix::{signal, SignalKind},
|
||||
time::interval,
|
||||
};
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tokio_util::{
|
||||
codec::{Framed, FramedRead, FramedWrite},
|
||||
task::TaskTracker,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@@ -102,4 +106,6 @@ impl danger::ServerCertVerifier for NoVerify {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream(_send: SendStream, _recv: RecvStream) {}
|
||||
async fn handle_stream(send: SendStream, recv: RecvStream) {
|
||||
let _stream = Framed::new(join(recv, send), PglbCodec);
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use rustls::{
|
||||
crypto::ring::sign,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -149,7 +150,7 @@ pub fn configure_tls(
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
|
||||
137
proxy/src/lib.rs
137
proxy/src/lib.rs
@@ -82,10 +82,17 @@
|
||||
impl_trait_overcaptures,
|
||||
)]
|
||||
|
||||
use std::{convert::Infallible, future::Future};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::{Buf, BufMut};
|
||||
use config::TlsServerEndPoint;
|
||||
use intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
@@ -274,3 +281,131 @@ impl EndpointId {
|
||||
ProjectId(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PglbCodec;
|
||||
|
||||
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn encode(&mut self, item: PglbMessage, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
PglbMessage::Control(ctrl) => {
|
||||
dst.put_u8(1);
|
||||
match ctrl {
|
||||
PglbControlMessage::ConnectionInitiated(msg) => {
|
||||
let encode = serde_json::to_string(&msg).context("ser")?;
|
||||
dst.put_u32(1 + encode.len() as u32);
|
||||
dst.put_u8(0);
|
||||
dst.put(encode.as_bytes());
|
||||
}
|
||||
PglbControlMessage::ConnectToCompute { socket } => match socket {
|
||||
SocketAddr::V4(v4) => {
|
||||
dst.put_u32(1 + 4 + 2);
|
||||
dst.put_u8(1);
|
||||
dst.put_u32(v4.ip().to_bits());
|
||||
dst.put_u16(v4.port());
|
||||
}
|
||||
SocketAddr::V6(v6) => {
|
||||
dst.put_u32(1 + 16 + 2);
|
||||
dst.put_u8(1);
|
||||
dst.put_u128(v6.ip().to_bits());
|
||||
dst.put_u16(v6.port());
|
||||
}
|
||||
},
|
||||
PglbControlMessage::ComputeEstablish => {
|
||||
dst.put_u32(1);
|
||||
dst.put_u8(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
PglbMessage::Postgres(pg) => {
|
||||
dst.put_u8(0);
|
||||
dst.put_u32(pg.len() as u32);
|
||||
dst.put(pg);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio_util::codec::Decoder for PglbCodec {
|
||||
type Item = PglbMessage;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn decode(&mut self, dst: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if dst.remaining() < 5 {
|
||||
dst.reserve(5);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let msg = dst[0];
|
||||
let len = u32::from_be_bytes(dst[1..5].try_into().unwrap()) as usize;
|
||||
|
||||
if len + 5 > dst.remaining() {
|
||||
dst.reserve(len + 5);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
dst.advance(5);
|
||||
let mut payload = dst.split_to(len);
|
||||
|
||||
match msg {
|
||||
// postgres
|
||||
0 => Ok(Some(PglbMessage::Postgres(payload.freeze()))),
|
||||
// control
|
||||
1 => {
|
||||
if payload.is_empty() {
|
||||
bail!("invalid ctrl message")
|
||||
}
|
||||
let ctrl_msg = payload.split_to(1)[0];
|
||||
let ctrl_msg = match ctrl_msg {
|
||||
0 => PglbControlMessage::ConnectionInitiated(
|
||||
serde_json::from_slice(&payload).context("deser")?,
|
||||
),
|
||||
// ipv4 socket
|
||||
1 if len == 7 => PglbControlMessage::ConnectToCompute {
|
||||
socket: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::from_bits(payload.get_u32())),
|
||||
payload.get_u16(),
|
||||
),
|
||||
},
|
||||
|
||||
// ipv6 socket
|
||||
1 if len == 19 => PglbControlMessage::ConnectToCompute {
|
||||
socket: SocketAddr::new(
|
||||
IpAddr::V6(Ipv6Addr::from_bits(payload.get_u128())),
|
||||
payload.get_u16(),
|
||||
),
|
||||
},
|
||||
|
||||
2 if len == 1 => PglbControlMessage::ComputeEstablish,
|
||||
|
||||
_ => bail!("invalid ctrl message"),
|
||||
};
|
||||
Ok(Some(PglbMessage::Control(ctrl_msg)))
|
||||
}
|
||||
_ => bail!("invalid message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum PglbMessage {
|
||||
Control(PglbControlMessage),
|
||||
Postgres(bytes::Bytes),
|
||||
}
|
||||
|
||||
pub enum PglbControlMessage {
|
||||
// from pglb to auth proxy
|
||||
ConnectionInitiated(ConnectionInitiatedPayload),
|
||||
// from auth proxy to pglb
|
||||
ConnectToCompute { socket: SocketAddr },
|
||||
// from auth proxy to pglb
|
||||
ComputeEstablish,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ConnectionInitiatedPayload {
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
server_name: Option<String>,
|
||||
ip_addr: IpAddr,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user