refactor to type state pattern

This commit is contained in:
Folke Behrens
2024-09-12 21:32:47 +01:00
parent 3f66c12280
commit 8198a503f2

View File

@@ -172,20 +172,13 @@ async fn start_frontend(
let listener = TcpListener::bind(addr).await?;
socket2::SockRef::from(&listener).set_keepalive(true)?;
let workers = tokio_util::task::task_tracker::TaskTracker::new();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
loop {
match listener.accept().await {
Ok((socket, client_addr)) => {
let w = Worker {
state: Some(WorkerState::Start {
client_stream: socket,
client_addr,
}),
tls: tls.clone(),
auth_conns: Arc::clone(&state),
};
workers.spawn(w.start());
let conn = PglbConn::new(&state, &tls)?;
connections.spawn(conn.handle(socket, client_addr));
}
Err(e) => {
error!("connection accept error: {e}");
@@ -194,98 +187,70 @@ async fn start_frontend(
}
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
}
#[derive(Debug)]
struct Worker {
state: Option<WorkerState>,
tls: TlsConfig,
struct PglbConn<S: PglbConnState> {
inner: PglbConnInner,
state: S,
}
#[derive(Debug)]
struct PglbConnInner {
tls_config: TlsConfig,
auth_conns: Arc<AuthConnState>,
}
trait PglbConnState: std::fmt::Debug {}
impl PglbConnState for Start {}
impl PglbConnState for ClientConnect {}
impl PglbConnState for AuthPassthrough {}
impl PglbConnState for ComputeConnect {}
impl PglbConnState for ComputePassthrough {}
impl PglbConnState for End {}
#[derive(Debug)]
enum WorkerState {
Start {
client_stream: TcpStream,
client_addr: SocketAddr,
},
AuthConnect {
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
},
AuthPassthrough {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
},
ComputeConnect {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
},
ComputePassthrough {
client_stream: TlsStream<TcpStream>,
compute_conn: (),
},
End,
}
struct Start;
impl Worker {
async fn start(mut self) -> Result<()> {
loop {
match self.state.take().expect("state should be specified") {
WorkerState::Start {
client_stream,
client_addr,
} => {
self.state = Some(self.handle_start(client_stream, client_addr).await?);
}
WorkerState::AuthConnect {
client_stream,
payload,
} => {
self.state = Some(self.handle_auth_connect(client_stream, payload).await?);
}
WorkerState::AuthPassthrough {
client_stream,
auth_stream,
} => {
self.state = Some(
self.handle_auth_passthrough(client_stream, auth_stream)
.await?,
);
}
WorkerState::ComputeConnect {
client_stream,
auth_stream,
} => {
self.state = Some(
self.handle_compute_connect(client_stream, auth_stream)
.await?,
);
}
WorkerState::ComputePassthrough {
client_stream,
compute_conn,
} => {
self.state = Some(
self.handle_compute_passthrough(client_stream, compute_conn)
.await?,
);
}
WorkerState::End => {
return Ok(());
}
}
}
impl PglbConn<Start> {
fn new(auth_conns: &Arc<AuthConnState>, tls_config: &TlsConfig) -> Result<Self> {
Ok(PglbConn {
inner: PglbConnInner {
auth_conns: Arc::clone(auth_conns),
tls_config: tls_config.clone(),
},
state: Start,
})
}
async fn handle(
self,
client_stream: TcpStream,
client_addr: SocketAddr,
) -> Result<PglbConn<End>> {
self.handle_start(client_stream, client_addr)
.await?
.handle_client_connect()
.await?
.handle_auth_passthrough()
.await?
.handle_connect_connect()
.await?
.handle_compute_passthrough()
.await
}
}
impl PglbConn<Start> {
async fn handle_start(
&self,
self,
mut client_stream: TcpStream,
peer_addr: SocketAddr,
) -> Result<WorkerState> {
client_addr: SocketAddr,
) -> Result<PglbConn<ClientConnect>> {
match client_stream.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
@@ -295,7 +260,7 @@ impl Worker {
// TODO: HAProxy protocol
let tls_requested = match handle_ssl_request_message(&mut client_stream).await {
let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await {
Ok(tls_requested) => tls_requested,
Err(e) => {
bail!("check_for_ssl_request: {e}");
@@ -304,7 +269,7 @@ impl Worker {
let (client_stream, payload) = if tls_requested {
let (stream, tls_server_end_point, server_name) =
match tls_upgrade(client_stream, self.tls.clone()).await {
match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
Ok((stream, ep, sn)) => (stream, ep, sn),
Err(e) => {
bail!("tls_upgrade: {e}");
@@ -316,7 +281,7 @@ impl Worker {
ConnectionInitiatedPayload {
tls_server_end_point,
server_name,
ip_addr: peer_addr.ip(),
ip_addr: client_addr.ip(),
},
)
} else {
@@ -324,19 +289,79 @@ impl Worker {
bail!("closing non-TLS connection");
};
Ok(WorkerState::AuthConnect {
client_stream,
payload,
Ok(PglbConn {
inner: self.inner,
state: ClientConnect {
client_stream,
payload,
},
})
}
async fn handle_auth_connect(
&self,
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
) -> Result<WorkerState> {
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
let mut buf = BytesMut::with_capacity(8);
let n_peek = stream.peek(&mut buf).await?;
if n_peek == 0 {
bail!("EOF");
}
assert_eq!(buf.len(), 8); // TODO: loop, read more
if buf.len() != 8
|| buf[0..4] != 8u32.to_be_bytes()
|| buf[4..8] != 80877103u32.to_be_bytes()
{
return Ok(false);
}
buf.clear();
let n_read = stream.read(&mut buf).await?;
assert_eq!(n_peek, n_read); // TODO: loop, read more
Ok(true)
}
async fn tls_upgrade(
stream: TcpStream,
tls: TlsConfig,
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
.accept(stream)
.await?;
let conn_info = tls_stream.get_ref().1;
let server_name = conn_info.server_name().map(|s| s.to_string());
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
let alpn = String::from_utf8_lossy(other);
warn!(%alpn, "unexpected ALPN");
bail!("protocol violation");
}
}
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(server_name.as_deref())
.ok_or(anyhow!("missing cert"))?;
Ok((tls_stream, tls_server_end_point, server_name))
}
}
#[derive(Debug)]
struct ClientConnect {
client_stream: TlsStream<TcpStream>,
payload: ConnectionInitiatedPayload,
}
impl PglbConn<ClientConnect> {
async fn handle_client_connect(self) -> Result<PglbConn<AuthPassthrough>> {
let auth_conn = {
let conns = self.auth_conns.conns.lock().unwrap();
let conns = self.inner.auth_conns.conns.lock().unwrap();
if conns.is_empty() {
bail!("no auth proxies avaiable");
}
@@ -356,93 +381,55 @@ impl Worker {
auth_stream
.send(proxy::PglbMessage::Control(
proxy::PglbControlMessage::ConnectionInitiated(payload),
proxy::PglbControlMessage::ConnectionInitiated(self.state.payload),
))
.await?;
Ok(WorkerState::AuthPassthrough {
client_stream,
auth_stream,
Ok(PglbConn {
inner: self.inner,
state: AuthPassthrough {
client_stream: self.state.client_stream,
auth_stream,
},
})
}
}
async fn handle_auth_passthrough(
&self,
_client_stream: TlsStream<TcpStream>,
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
) -> Result<WorkerState> {
todo!()
}
#[derive(Debug)]
struct AuthPassthrough {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
}
async fn handle_compute_connect(
&self,
_client_stream: TlsStream<TcpStream>,
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
) -> Result<WorkerState> {
todo!()
}
async fn handle_compute_passthrough(
&self,
_client_stream: TlsStream<TcpStream>,
_compute_conn: (),
) -> Result<WorkerState> {
impl PglbConn<AuthPassthrough> {
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
todo!()
}
}
#[derive(Clone, Debug)]
struct TlsConfig {
config: Arc<rustls::ServerConfig>,
cert_resolver: Arc<CertResolver>,
#[derive(Debug)]
struct ComputeConnect {
client_stream: TlsStream<TcpStream>,
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
}
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
let mut buf = BytesMut::with_capacity(8);
let n_peek = stream.peek(&mut buf).await?;
if n_peek == 0 {
bail!("EOF");
impl PglbConn<ComputeConnect> {
async fn handle_connect_connect(self) -> Result<PglbConn<ComputePassthrough>> {
todo!()
}
assert_eq!(buf.len(), 8); // TODO: loop, read more
if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() {
return Ok(false);
}
buf.clear();
let n_read = stream.read(&mut buf).await?;
assert_eq!(n_peek, n_read); // TODO: loop, read more
Ok(true)
}
async fn tls_upgrade(
stream: TcpStream,
tls: TlsConfig,
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
.accept(stream)
.await?;
let conn_info = tls_stream.get_ref().1;
let server_name = conn_info.server_name().map(|s| s.to_string());
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
let alpn = String::from_utf8_lossy(other);
warn!(%alpn, "unexpected ALPN");
bail!("protocol violation");
}
}
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(server_name.as_deref())
.ok_or(anyhow!("missing cert"))?;
Ok((tls_stream, tls_server_end_point, server_name))
#[derive(Debug)]
struct ComputePassthrough {
client_stream: TlsStream<TcpStream>,
compute_conn: (),
}
impl PglbConn<ComputePassthrough> {
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
todo!()
}
}
#[derive(Debug)]
struct End;