Merge branch 'cloneable/pglb-workers' into pglb

This commit is contained in:
Folke Behrens
2024-09-12 18:00:37 +01:00
2 changed files with 229 additions and 41 deletions

View File

@@ -7,23 +7,32 @@ use std::{
use anyhow::{anyhow, bail, Context, Result};
use bytes::BytesMut;
use futures::sink::SinkExt;
use indexmap::IndexMap;
use proxy::config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL};
use quinn::{Connection, Endpoint};
use proxy::{
config::{CertResolver, TlsServerEndPoint, PG_ALPN_PROTOCOL},
ConnectionInitiatedPayload, PglbCodec,
};
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use rand::Rng;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::{
io::AsyncReadExt,
io::{join, AsyncReadExt, Join},
net::{TcpListener, TcpStream},
time::timeout,
};
use tokio_rustls::server::TlsStream;
use tokio_util::codec::Framed;
use tracing::{error, warn};
type AuthConnId = usize;
#[derive(Debug)]
struct AuthConnState {
conns: Mutex<IndexMap<AuthConnId, AuthConn>>,
}
#[derive(Clone, Debug)]
struct AuthConn {
conn: Connection,
// latency info...
@@ -46,7 +55,11 @@ async fn main() -> Result<()> {
let frontend_config = frontent_tls_config("pglb-fe", "pglb-fe")?;
let _frontend_handle = tokio::spawn(start_frontend("0.0.0.0:5432".parse()?, frontend_config));
let _frontend_handle = tokio::spawn(start_frontend(
"0.0.0.0:5432".parse()?,
frontend_config,
auth_connections.clone(),
));
quinn_handle.await.unwrap();
@@ -151,17 +164,28 @@ fn frontent_tls_config(hostname: &str, common_name: &str) -> Result<TlsConfig> {
})
}
async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible> {
async fn start_frontend(
addr: SocketAddr,
tls: TlsConfig,
state: Arc<AuthConnState>,
) -> Result<Infallible> {
let listener = TcpListener::bind(addr).await?;
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let workers = tokio_util::task::task_tracker::TaskTracker::new();
loop {
match listener.accept().await {
Ok((socket, peer_addr)) => {
let tls = tls.clone();
connections.spawn_local(handle_frontend_connection(socket, peer_addr, tls));
Ok((socket, client_addr)) => {
let w = Worker {
state: Some(WorkerState::Start {
client_stream: socket,
client_addr,
}),
tls: tls.clone(),
auth_conns: Arc::clone(&state),
};
workers.spawn(w.start());
}
Err(e) => {
error!("connection accept error: {e}");
@@ -170,38 +194,207 @@ async fn start_frontend(addr: SocketAddr, tls: TlsConfig) -> Result<Infallible>
}
}
async fn handle_frontend_connection(mut stream: TcpStream, _peer_addr: SocketAddr, tls: TlsConfig) {
match stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("socket option error: {e}");
return;
#[derive(Debug)]
struct Worker {
state: Option<WorkerState>,
tls: TlsConfig,
auth_conns: Arc<AuthConnState>,
}
#[derive(Debug)]
enum WorkerState {
Start {
client_stream: TcpStream,
client_addr: SocketAddr,
},
AuthConnect {
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
},
AuthPassthrough {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
},
ComputeConnect {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
},
ComputePassthrough {
client_stream: TlsStream<TcpStream>,
compute_conn: (),
},
End,
}
impl Worker {
async fn start(mut self) -> Result<()> {
loop {
match self.state.take().expect("state should be specified") {
WorkerState::Start {
client_stream,
client_addr,
} => {
self.state = Some(self.handle_start(client_stream, client_addr).await?);
}
WorkerState::AuthConnect {
client_stream,
payload,
} => {
self.state = Some(self.handle_auth_connect(client_stream, payload).await?);
}
WorkerState::AuthPassthrough {
client_stream,
auth_stream,
} => {
self.state = Some(
self.handle_auth_passthrough(client_stream, auth_stream)
.await?,
);
}
WorkerState::ComputeConnect {
client_stream,
auth_stream,
} => {
self.state = Some(
self.handle_compute_connect(client_stream, auth_stream)
.await?,
);
}
WorkerState::ComputePassthrough {
client_stream,
compute_conn,
} => {
self.state = Some(
self.handle_compute_passthrough(client_stream, compute_conn)
.await?,
);
}
WorkerState::End => {
return Ok(());
}
}
}
};
}
// TODO: HAProxy protocol?
let tls_requested = match handle_ssl_request_message(&mut stream).await {
Ok(tls_requested) => tls_requested,
Err(e) => {
error!("check_for_ssl_request: {e}");
return;
}
};
if tls_requested {
let (stream, ep, sn) = match tls_upgrade(stream, tls).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
async fn handle_start(
&self,
mut client_stream: TcpStream,
peer_addr: SocketAddr,
) -> Result<WorkerState> {
match client_stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("tls_upgrade: {e}");
return;
bail!("socket option error: {e}");
}
};
// TODO: send auth msg with tls ep and server name
} else {
// TODO: send auth msg without server name
// TODO: HAProxy protocol
let tls_requested = match handle_ssl_request_message(&mut client_stream).await {
Ok(tls_requested) => tls_requested,
Err(e) => {
bail!("check_for_ssl_request: {e}");
}
};
let (client_stream, payload) = if tls_requested {
let (stream, tls_server_end_point, server_name) =
match tls_upgrade(client_stream, self.tls.clone()).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
Err(e) => {
bail!("tls_upgrade: {e}");
}
};
(
stream,
ConnectionInitiatedPayload {
tls_server_end_point,
server_name,
ip_addr: peer_addr.ip(),
},
)
} else {
// TODO: support unsecured connections?
bail!("closing non-TLS connection");
};
Ok(WorkerState::AuthConnect {
client_stream,
payload,
})
}
async fn handle_auth_connect(
&self,
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
) -> Result<WorkerState> {
let auth_conn = {
let conns = self.auth_conns.conns.lock().unwrap();
if conns.is_empty() {
bail!("no auth proxies avaiable");
}
let mut rng = rand::thread_rng();
conns
.get_index(rng.gen_range(0..conns.len()))
.unwrap()
.1
.clone()
// TODO: check closed?
};
let (send, recv) = auth_conn.conn.open_bi().await?;
let mut auth_stream = Framed::new(join(recv, send), PglbCodec);
auth_stream
.send(proxy::PglbMessage::Control(
proxy::PglbControlMessage::ConnectionInitiated(payload),
))
.await?;
Ok(WorkerState::AuthPassthrough {
client_stream,
auth_stream,
})
}
async fn handle_auth_passthrough(
&self,
_client_stream: TlsStream<TcpStream>,
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
) -> Result<WorkerState> {
todo!()
}
async fn handle_compute_connect(
&self,
_client_stream: TlsStream<TcpStream>,
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
) -> Result<WorkerState> {
todo!()
}
async fn handle_compute_passthrough(
&self,
_client_stream: TlsStream<TcpStream>,
_compute_conn: (),
) -> Result<WorkerState> {
todo!()
}
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
}
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
@@ -253,9 +446,3 @@ async fn tls_upgrade(
Ok((tls_stream, tls_server_end_point, server_name))
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
}

View File

@@ -283,6 +283,7 @@ impl EndpointId {
}
}
#[derive(Debug)]
pub struct PglbCodec;
impl tokio_util::codec::Encoder<PglbMessage> for PglbCodec {
@@ -404,7 +405,7 @@ pub enum PglbControlMessage {
ComputeEstablish,
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
pub struct ConnectionInitiatedPayload {
pub tls_server_end_point: TlsServerEndPoint,
pub server_name: Option<String>,