mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
refactor to type state pattern
This commit is contained in:
@@ -172,20 +172,13 @@ async fn start_frontend(
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let workers = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, client_addr)) => {
|
||||
let w = Worker {
|
||||
state: Some(WorkerState::Start {
|
||||
client_stream: socket,
|
||||
client_addr,
|
||||
}),
|
||||
tls: tls.clone(),
|
||||
auth_conns: Arc::clone(&state),
|
||||
};
|
||||
workers.spawn(w.start());
|
||||
let conn = PglbConn::new(&state, &tls)?;
|
||||
connections.spawn(conn.handle(socket, client_addr));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("connection accept error: {e}");
|
||||
@@ -194,98 +187,70 @@ async fn start_frontend(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Worker {
|
||||
state: Option<WorkerState>,
|
||||
tls: TlsConfig,
|
||||
struct PglbConn<S: PglbConnState> {
|
||||
inner: PglbConnInner,
|
||||
state: S,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PglbConnInner {
|
||||
tls_config: TlsConfig,
|
||||
auth_conns: Arc<AuthConnState>,
|
||||
}
|
||||
|
||||
trait PglbConnState: std::fmt::Debug {}
|
||||
impl PglbConnState for Start {}
|
||||
impl PglbConnState for ClientConnect {}
|
||||
impl PglbConnState for AuthPassthrough {}
|
||||
impl PglbConnState for ComputeConnect {}
|
||||
impl PglbConnState for ComputePassthrough {}
|
||||
impl PglbConnState for End {}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WorkerState {
|
||||
Start {
|
||||
client_stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
},
|
||||
AuthConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
},
|
||||
AuthPassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
},
|
||||
ComputeConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
},
|
||||
ComputePassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
compute_conn: (),
|
||||
},
|
||||
End,
|
||||
}
|
||||
struct Start;
|
||||
|
||||
impl Worker {
|
||||
async fn start(mut self) -> Result<()> {
|
||||
loop {
|
||||
match self.state.take().expect("state should be specified") {
|
||||
WorkerState::Start {
|
||||
client_stream,
|
||||
client_addr,
|
||||
} => {
|
||||
self.state = Some(self.handle_start(client_stream, client_addr).await?);
|
||||
}
|
||||
|
||||
WorkerState::AuthConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
} => {
|
||||
self.state = Some(self.handle_auth_connect(client_stream, payload).await?);
|
||||
}
|
||||
|
||||
WorkerState::AuthPassthrough {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_auth_passthrough(client_stream, auth_stream)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::ComputeConnect {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_compute_connect(client_stream, auth_stream)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::ComputePassthrough {
|
||||
client_stream,
|
||||
compute_conn,
|
||||
} => {
|
||||
self.state = Some(
|
||||
self.handle_compute_passthrough(client_stream, compute_conn)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
WorkerState::End => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
impl PglbConn<Start> {
|
||||
fn new(auth_conns: &Arc<AuthConnState>, tls_config: &TlsConfig) -> Result<Self> {
|
||||
Ok(PglbConn {
|
||||
inner: PglbConnInner {
|
||||
auth_conns: Arc::clone(auth_conns),
|
||||
tls_config: tls_config.clone(),
|
||||
},
|
||||
state: Start,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
self,
|
||||
client_stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
) -> Result<PglbConn<End>> {
|
||||
self.handle_start(client_stream, client_addr)
|
||||
.await?
|
||||
.handle_client_connect()
|
||||
.await?
|
||||
.handle_auth_passthrough()
|
||||
.await?
|
||||
.handle_connect_connect()
|
||||
.await?
|
||||
.handle_compute_passthrough()
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl PglbConn<Start> {
|
||||
async fn handle_start(
|
||||
&self,
|
||||
self,
|
||||
mut client_stream: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Result<WorkerState> {
|
||||
client_addr: SocketAddr,
|
||||
) -> Result<PglbConn<ClientConnect>> {
|
||||
match client_stream.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
@@ -295,7 +260,7 @@ impl Worker {
|
||||
|
||||
// TODO: HAProxy protocol
|
||||
|
||||
let tls_requested = match handle_ssl_request_message(&mut client_stream).await {
|
||||
let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await {
|
||||
Ok(tls_requested) => tls_requested,
|
||||
Err(e) => {
|
||||
bail!("check_for_ssl_request: {e}");
|
||||
@@ -304,7 +269,7 @@ impl Worker {
|
||||
|
||||
let (client_stream, payload) = if tls_requested {
|
||||
let (stream, tls_server_end_point, server_name) =
|
||||
match tls_upgrade(client_stream, self.tls.clone()).await {
|
||||
match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await {
|
||||
Ok((stream, ep, sn)) => (stream, ep, sn),
|
||||
Err(e) => {
|
||||
bail!("tls_upgrade: {e}");
|
||||
@@ -316,7 +281,7 @@ impl Worker {
|
||||
ConnectionInitiatedPayload {
|
||||
tls_server_end_point,
|
||||
server_name,
|
||||
ip_addr: peer_addr.ip(),
|
||||
ip_addr: client_addr.ip(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
@@ -324,19 +289,79 @@ impl Worker {
|
||||
bail!("closing non-TLS connection");
|
||||
};
|
||||
|
||||
Ok(WorkerState::AuthConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: ClientConnect {
|
||||
client_stream,
|
||||
payload,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_auth_connect(
|
||||
&self,
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
) -> Result<WorkerState> {
|
||||
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
|
||||
let mut buf = BytesMut::with_capacity(8);
|
||||
|
||||
let n_peek = stream.peek(&mut buf).await?;
|
||||
if n_peek == 0 {
|
||||
bail!("EOF");
|
||||
}
|
||||
|
||||
assert_eq!(buf.len(), 8); // TODO: loop, read more
|
||||
|
||||
if buf.len() != 8
|
||||
|| buf[0..4] != 8u32.to_be_bytes()
|
||||
|| buf[4..8] != 80877103u32.to_be_bytes()
|
||||
{
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
buf.clear();
|
||||
let n_read = stream.read(&mut buf).await?;
|
||||
|
||||
assert_eq!(n_peek, n_read); // TODO: loop, read more
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn tls_upgrade(
|
||||
stream: TcpStream,
|
||||
tls: TlsConfig,
|
||||
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
|
||||
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
|
||||
.accept(stream)
|
||||
.await?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
let server_name = conn_info.server_name().map(|s| s.to_string());
|
||||
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
bail!("protocol violation");
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(server_name.as_deref())
|
||||
.ok_or(anyhow!("missing cert"))?;
|
||||
|
||||
Ok((tls_stream, tls_server_end_point, server_name))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClientConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
payload: ConnectionInitiatedPayload,
|
||||
}
|
||||
|
||||
impl PglbConn<ClientConnect> {
|
||||
async fn handle_client_connect(self) -> Result<PglbConn<AuthPassthrough>> {
|
||||
let auth_conn = {
|
||||
let conns = self.auth_conns.conns.lock().unwrap();
|
||||
let conns = self.inner.auth_conns.conns.lock().unwrap();
|
||||
if conns.is_empty() {
|
||||
bail!("no auth proxies avaiable");
|
||||
}
|
||||
@@ -356,93 +381,55 @@ impl Worker {
|
||||
|
||||
auth_stream
|
||||
.send(proxy::PglbMessage::Control(
|
||||
proxy::PglbControlMessage::ConnectionInitiated(payload),
|
||||
proxy::PglbControlMessage::ConnectionInitiated(self.state.payload),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(WorkerState::AuthPassthrough {
|
||||
client_stream,
|
||||
auth_stream,
|
||||
Ok(PglbConn {
|
||||
inner: self.inner,
|
||||
state: AuthPassthrough {
|
||||
client_stream: self.state.client_stream,
|
||||
auth_stream,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_auth_passthrough(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
) -> Result<WorkerState> {
|
||||
todo!()
|
||||
}
|
||||
#[derive(Debug)]
|
||||
struct AuthPassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
}
|
||||
|
||||
async fn handle_compute_connect(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
) -> Result<WorkerState> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn handle_compute_passthrough(
|
||||
&self,
|
||||
_client_stream: TlsStream<TcpStream>,
|
||||
_compute_conn: (),
|
||||
) -> Result<WorkerState> {
|
||||
impl PglbConn<AuthPassthrough> {
|
||||
async fn handle_auth_passthrough(self) -> Result<PglbConn<ComputeConnect>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsConfig {
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
cert_resolver: Arc<CertResolver>,
|
||||
#[derive(Debug)]
|
||||
struct ComputeConnect {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
auth_stream: Framed<Join<RecvStream, SendStream>, PglbCodec>,
|
||||
}
|
||||
|
||||
async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result<bool> {
|
||||
let mut buf = BytesMut::with_capacity(8);
|
||||
|
||||
let n_peek = stream.peek(&mut buf).await?;
|
||||
if n_peek == 0 {
|
||||
bail!("EOF");
|
||||
impl PglbConn<ComputeConnect> {
|
||||
async fn handle_connect_connect(self) -> Result<PglbConn<ComputePassthrough>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
assert_eq!(buf.len(), 8); // TODO: loop, read more
|
||||
|
||||
if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
buf.clear();
|
||||
let n_read = stream.read(&mut buf).await?;
|
||||
|
||||
assert_eq!(n_peek, n_read); // TODO: loop, read more
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn tls_upgrade(
|
||||
stream: TcpStream,
|
||||
tls: TlsConfig,
|
||||
) -> Result<(TlsStream<TcpStream>, TlsServerEndPoint, Option<String>)> {
|
||||
let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config)
|
||||
.accept(stream)
|
||||
.await?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
let server_name = conn_info.server_name().map(|s| s.to_string());
|
||||
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
bail!("protocol violation");
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(server_name.as_deref())
|
||||
.ok_or(anyhow!("missing cert"))?;
|
||||
|
||||
Ok((tls_stream, tls_server_end_point, server_name))
|
||||
#[derive(Debug)]
|
||||
struct ComputePassthrough {
|
||||
client_stream: TlsStream<TcpStream>,
|
||||
compute_conn: (),
|
||||
}
|
||||
|
||||
impl PglbConn<ComputePassthrough> {
|
||||
async fn handle_compute_passthrough(self) -> Result<PglbConn<End>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct End;
|
||||
|
||||
Reference in New Issue
Block a user