diff --git a/proxy/src/bin/pglb.rs b/proxy/src/bin/pglb.rs index 37b3d2b15c..d2011f6f98 100644 --- a/proxy/src/bin/pglb.rs +++ b/proxy/src/bin/pglb.rs @@ -172,20 +172,13 @@ async fn start_frontend( let listener = TcpListener::bind(addr).await?; socket2::SockRef::from(&listener).set_keepalive(true)?; - let workers = tokio_util::task::task_tracker::TaskTracker::new(); + let connections = tokio_util::task::task_tracker::TaskTracker::new(); loop { match listener.accept().await { Ok((socket, client_addr)) => { - let w = Worker { - state: Some(WorkerState::Start { - client_stream: socket, - client_addr, - }), - tls: tls.clone(), - auth_conns: Arc::clone(&state), - }; - workers.spawn(w.start()); + let conn = PglbConn::new(&state, &tls)?; + connections.spawn(conn.handle(socket, client_addr)); } Err(e) => { error!("connection accept error: {e}"); @@ -194,98 +187,70 @@ async fn start_frontend( } } +#[derive(Clone, Debug)] +struct TlsConfig { + config: Arc, + cert_resolver: Arc, +} + #[derive(Debug)] -struct Worker { - state: Option, - tls: TlsConfig, +struct PglbConn { + inner: PglbConnInner, + state: S, +} + +#[derive(Debug)] +struct PglbConnInner { + tls_config: TlsConfig, auth_conns: Arc, } +trait PglbConnState: std::fmt::Debug {} +impl PglbConnState for Start {} +impl PglbConnState for ClientConnect {} +impl PglbConnState for AuthPassthrough {} +impl PglbConnState for ComputeConnect {} +impl PglbConnState for ComputePassthrough {} +impl PglbConnState for End {} + #[derive(Debug)] -enum WorkerState { - Start { - client_stream: TcpStream, - client_addr: SocketAddr, - }, - AuthConnect { - client_stream: TlsStream, - payload: ConnectionInitiatedPayload, - }, - AuthPassthrough { - client_stream: TlsStream, - auth_stream: Framed, PglbCodec>, - }, - ComputeConnect { - client_stream: TlsStream, - auth_stream: Framed, PglbCodec>, - }, - ComputePassthrough { - client_stream: TlsStream, - compute_conn: (), - }, - End, -} +struct Start; -impl Worker { - async fn start(mut self) -> Result<()> { - loop { - match self.state.take().expect("state should be specified") { - WorkerState::Start { - client_stream, - client_addr, - } => { - self.state = Some(self.handle_start(client_stream, client_addr).await?); - } - - WorkerState::AuthConnect { - client_stream, - payload, - } => { - self.state = Some(self.handle_auth_connect(client_stream, payload).await?); - } - - WorkerState::AuthPassthrough { - client_stream, - auth_stream, - } => { - self.state = Some( - self.handle_auth_passthrough(client_stream, auth_stream) - .await?, - ); - } - - WorkerState::ComputeConnect { - client_stream, - auth_stream, - } => { - self.state = Some( - self.handle_compute_connect(client_stream, auth_stream) - .await?, - ); - } - - WorkerState::ComputePassthrough { - client_stream, - compute_conn, - } => { - self.state = Some( - self.handle_compute_passthrough(client_stream, compute_conn) - .await?, - ); - } - - WorkerState::End => { - return Ok(()); - } - } - } +impl PglbConn { + fn new(auth_conns: &Arc, tls_config: &TlsConfig) -> Result { + Ok(PglbConn { + inner: PglbConnInner { + auth_conns: Arc::clone(auth_conns), + tls_config: tls_config.clone(), + }, + state: Start, + }) } + async fn handle( + self, + client_stream: TcpStream, + client_addr: SocketAddr, + ) -> Result> { + self.handle_start(client_stream, client_addr) + .await? + .handle_client_connect() + .await? + .handle_auth_passthrough() + .await? + .handle_connect_connect() + .await? + .handle_compute_passthrough() + .await + } +} + +impl PglbConn { async fn handle_start( - &self, + self, mut client_stream: TcpStream, - peer_addr: SocketAddr, - ) -> Result { + client_addr: SocketAddr, + ) -> Result> { match client_stream.set_nodelay(true) { Ok(()) => {} Err(e) => { @@ -295,7 +260,7 @@ impl Worker { // TODO: HAProxy protocol - let tls_requested = match handle_ssl_request_message(&mut client_stream).await { + let tls_requested = match Self::handle_ssl_request_message(&mut client_stream).await { Ok(tls_requested) => tls_requested, Err(e) => { bail!("check_for_ssl_request: {e}"); @@ -304,7 +269,7 @@ impl Worker { let (client_stream, payload) = if tls_requested { let (stream, tls_server_end_point, server_name) = - match tls_upgrade(client_stream, self.tls.clone()).await { + match Self::tls_upgrade(client_stream, self.inner.tls_config.clone()).await { Ok((stream, ep, sn)) => (stream, ep, sn), Err(e) => { bail!("tls_upgrade: {e}"); @@ -316,7 +281,7 @@ impl Worker { ConnectionInitiatedPayload { tls_server_end_point, server_name, - ip_addr: peer_addr.ip(), + ip_addr: client_addr.ip(), }, ) } else { @@ -324,19 +289,79 @@ impl Worker { bail!("closing non-TLS connection"); }; - Ok(WorkerState::AuthConnect { - client_stream, - payload, + Ok(PglbConn { + inner: self.inner, + state: ClientConnect { + client_stream, + payload, + }, }) } - async fn handle_auth_connect( - &self, - client_stream: TlsStream, - payload: ConnectionInitiatedPayload, - ) -> Result { + async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result { + let mut buf = BytesMut::with_capacity(8); + + let n_peek = stream.peek(&mut buf).await?; + if n_peek == 0 { + bail!("EOF"); + } + + assert_eq!(buf.len(), 8); // TODO: loop, read more + + if buf.len() != 8 + || buf[0..4] != 8u32.to_be_bytes() + || buf[4..8] != 80877103u32.to_be_bytes() + { + return Ok(false); + } + + buf.clear(); + let n_read = stream.read(&mut buf).await?; + + assert_eq!(n_peek, n_read); // TODO: loop, read more + + Ok(true) + } + + async fn tls_upgrade( + stream: TcpStream, + tls: TlsConfig, + ) -> Result<(TlsStream, TlsServerEndPoint, Option)> { + let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config) + .accept(stream) + .await?; + + let conn_info = tls_stream.get_ref().1; + let server_name = conn_info.server_name().map(|s| s.to_string()); + + match conn_info.alpn_protocol() { + None | Some(PG_ALPN_PROTOCOL) => {} + Some(other) => { + let alpn = String::from_utf8_lossy(other); + warn!(%alpn, "unexpected ALPN"); + bail!("protocol violation"); + } + } + + let (_, tls_server_end_point) = tls + .cert_resolver + .resolve(server_name.as_deref()) + .ok_or(anyhow!("missing cert"))?; + + Ok((tls_stream, tls_server_end_point, server_name)) + } +} + +#[derive(Debug)] +struct ClientConnect { + client_stream: TlsStream, + payload: ConnectionInitiatedPayload, +} + +impl PglbConn { + async fn handle_client_connect(self) -> Result> { let auth_conn = { - let conns = self.auth_conns.conns.lock().unwrap(); + let conns = self.inner.auth_conns.conns.lock().unwrap(); if conns.is_empty() { bail!("no auth proxies avaiable"); } @@ -356,93 +381,55 @@ impl Worker { auth_stream .send(proxy::PglbMessage::Control( - proxy::PglbControlMessage::ConnectionInitiated(payload), + proxy::PglbControlMessage::ConnectionInitiated(self.state.payload), )) .await?; - Ok(WorkerState::AuthPassthrough { - client_stream, - auth_stream, + Ok(PglbConn { + inner: self.inner, + state: AuthPassthrough { + client_stream: self.state.client_stream, + auth_stream, + }, }) } +} - async fn handle_auth_passthrough( - &self, - _client_stream: TlsStream, - _auth_stream: Framed, PglbCodec>, - ) -> Result { - todo!() - } +#[derive(Debug)] +struct AuthPassthrough { + client_stream: TlsStream, + auth_stream: Framed, PglbCodec>, +} - async fn handle_compute_connect( - &self, - _client_stream: TlsStream, - _auth_stream: Framed, PglbCodec>, - ) -> Result { - todo!() - } - - async fn handle_compute_passthrough( - &self, - _client_stream: TlsStream, - _compute_conn: (), - ) -> Result { +impl PglbConn { + async fn handle_auth_passthrough(self) -> Result> { todo!() } } -#[derive(Clone, Debug)] -struct TlsConfig { - config: Arc, - cert_resolver: Arc, +#[derive(Debug)] +struct ComputeConnect { + client_stream: TlsStream, + auth_stream: Framed, PglbCodec>, } -async fn handle_ssl_request_message(stream: &mut TcpStream) -> Result { - let mut buf = BytesMut::with_capacity(8); - - let n_peek = stream.peek(&mut buf).await?; - if n_peek == 0 { - bail!("EOF"); +impl PglbConn { + async fn handle_connect_connect(self) -> Result> { + todo!() } - - assert_eq!(buf.len(), 8); // TODO: loop, read more - - if buf.len() != 8 || buf[0..4] != 8u32.to_be_bytes() || buf[4..8] != 80877103u32.to_be_bytes() { - return Ok(false); - } - - buf.clear(); - let n_read = stream.read(&mut buf).await?; - - assert_eq!(n_peek, n_read); // TODO: loop, read more - - Ok(true) } -async fn tls_upgrade( - stream: TcpStream, - tls: TlsConfig, -) -> Result<(TlsStream, TlsServerEndPoint, Option)> { - let tls_stream = tokio_rustls::TlsAcceptor::from(tls.config) - .accept(stream) - .await?; - - let conn_info = tls_stream.get_ref().1; - let server_name = conn_info.server_name().map(|s| s.to_string()); - - match conn_info.alpn_protocol() { - None | Some(PG_ALPN_PROTOCOL) => {} - Some(other) => { - let alpn = String::from_utf8_lossy(other); - warn!(%alpn, "unexpected ALPN"); - bail!("protocol violation"); - } - } - - let (_, tls_server_end_point) = tls - .cert_resolver - .resolve(server_name.as_deref()) - .ok_or(anyhow!("missing cert"))?; - - Ok((tls_stream, tls_server_end_point, server_name)) +#[derive(Debug)] +struct ComputePassthrough { + client_stream: TlsStream, + compute_conn: (), } + +impl PglbConn { + async fn handle_compute_passthrough(self) -> Result> { + todo!() + } +} + +#[derive(Debug)] +struct End;