diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index ebbf3e728e..5b7a87bc11 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -1,5 +1,8 @@ use crate::{ - cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client, + cancellation::CancelMap, + config::ProxyConfig, + error::io_error, + proxy::{handle_client, ClientMode}, }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream, StreamExt}; @@ -150,12 +153,12 @@ async fn serve_websocket( hostname: Option, ) -> anyhow::Result<()> { let websocket = websocket.await?; - handle_ws_client( + handle_client( config, cancel_map, session_id, WebSocketRw::new(websocket), - hostname, + ClientMode::Websockets { hostname }, ) .await?; Ok(()) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index a43192c11e..d317d382a7 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -104,7 +104,8 @@ pub async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(config, &cancel_map, session_id, socket).await + handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp) + .await } .unwrap_or_else(move |e| { // Acknowledge that the task has finished with an error. @@ -129,14 +130,50 @@ pub async fn task_main( Ok(()) } -// TODO(tech debt): unite this with its twin below. +pub enum ClientMode { + Tcp, + Websockets { hostname: Option }, +} + +/// Abstracts the logic of handling TCP vs WS clients +impl ClientMode { + fn allow_cleartext(&self) -> bool { + match self { + ClientMode::Tcp => false, + ClientMode::Websockets { .. } => true, + } + } + + fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + match self { + ClientMode::Tcp => config.allow_self_signed_compute, + ClientMode::Websockets { .. } => false, + } + } + + fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { + match self { + ClientMode::Tcp => s.sni_hostname(), + ClientMode::Websockets { hostname } => hostname.as_deref(), + } + } + + fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { + match self { + ClientMode::Tcp => tls, + // TLS is None here if using websockets, because the connection is already encrypted. + ClientMode::Websockets { .. } => None, + } + } +} + #[tracing::instrument(fields(session_id = ?session_id), skip_all)] -pub async fn handle_ws_client( +pub async fn handle_client( config: &'static ProxyConfig, cancel_map: &CancelMap, session_id: uuid::Uuid, - stream: impl AsyncRead + AsyncWrite + Unpin, - hostname: Option, + stream: S, + mode: ClientMode, ) -> anyhow::Result<()> { // The `closed` counter will increase when this future is destroyed. NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); @@ -145,10 +182,8 @@ pub async fn handle_ws_client( } let tls = config.tls_config.as_ref(); - let hostname = hostname.as_deref(); - // TLS is None here, because the connection is already encrypted. - let do_handshake = handshake(stream, None, cancel_map); + let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map); let (mut stream, params) = match do_handshake.await? { Some(x) => x, None => return Ok(()), // it's a cancellation request @@ -156,6 +191,7 @@ pub async fn handle_ws_client( // Extract credentials which we're going to use for auth. let creds = { + let hostname = mode.hostname(stream.get_ref()); let common_names = tls.and_then(|tls| tls.common_names.clone()); let result = config .auth_backend @@ -169,59 +205,15 @@ pub async fn handle_ws_client( } }; - let client = Client::new(stream, creds, ¶ms, session_id, false); - cancel_map - .with_session(|session| client.connect_to_db(session, true)) - .await -} - -#[tracing::instrument(fields(session_id = ?session_id), skip_all)] -async fn handle_client( - config: &'static ProxyConfig, - cancel_map: &CancelMap, - session_id: uuid::Uuid, - stream: impl AsyncRead + AsyncWrite + Unpin, -) -> anyhow::Result<()> { - // The `closed` counter will increase when this future is destroyed. - NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); - scopeguard::defer! { - NUM_CONNECTIONS_CLOSED_COUNTER.inc(); - } - - let tls = config.tls_config.as_ref(); - let do_handshake = handshake(stream, tls, cancel_map); - let (mut stream, params) = match do_handshake.await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; - - // Extract credentials which we're going to use for auth. - let creds = { - let sni = stream.get_ref().sni_hostname(); - let common_names = tls.and_then(|tls| tls.common_names.clone()); - let result = config - .auth_backend - .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_names)) - .transpose(); - - match result { - Ok(creds) => creds, - Err(e) => stream.throw_error(e).await?, - } - }; - - let allow_self_signed_compute = config.allow_self_signed_compute; - let client = Client::new( stream, creds, ¶ms, session_id, - allow_self_signed_compute, + mode.allow_self_signed_compute(config), ); cancel_map - .with_session(|session| client.connect_to_db(session, false)) + .with_session(|session| client.connect_to_db(session, mode.allow_cleartext())) .await }