proxy: merge handle_client (#4740)

## Problem

Second half of #4699. we were maintaining 2 implementations of
handle_client.

## Summary of changes

Merge the handle_client code, but abstract some of the details.

## Checklist before requesting a review

- [X] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist
This commit is contained in:
Conrad Ludgate
2023-07-17 22:20:23 +01:00
committed by GitHub
parent 4580f5085a
commit 2e8a3afab1
2 changed files with 52 additions and 57 deletions

View File

@@ -1,5 +1,8 @@
use crate::{
cancellation::CancelMap, config::ProxyConfig, error::io_error, proxy::handle_ws_client,
cancellation::CancelMap,
config::ProxyConfig,
error::io_error,
proxy::{handle_client, ClientMode},
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
@@ -150,12 +153,12 @@ async fn serve_websocket(
hostname: Option<String>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_ws_client(
handle_client(
config,
cancel_map,
session_id,
WebSocketRw::new(websocket),
hostname,
ClientMode::Websockets { hostname },
)
.await?;
Ok(())

View File

@@ -104,7 +104,8 @@ pub async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket).await
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp)
.await
}
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
@@ -129,14 +130,50 @@ pub async fn task_main(
Ok(())
}
// TODO(tech debt): unite this with its twin below.
pub enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[tracing::instrument(fields(session_id = ?session_id), skip_all)]
pub async fn handle_ws_client(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin,
hostname: Option<String>,
stream: S,
mode: ClientMode,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
@@ -145,10 +182,8 @@ pub async fn handle_ws_client(
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let do_handshake = handshake(stream, None, cancel_map);
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
@@ -156,6 +191,7 @@ pub async fn handle_ws_client(
// Extract credentials which we're going to use for auth.
let creds = {
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.and_then(|tls| tls.common_names.clone());
let result = config
.auth_backend
@@ -169,59 +205,15 @@ pub async fn handle_ws_client(
}
};
let client = Client::new(stream, creds, &params, session_id, false);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
}
#[tracing::instrument(fields(session_id = ?session_id), skip_all)]
async fn handle_client(
config: &'static ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.as_ref();
let do_handshake = handshake(stream, tls, cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_names = tls.and_then(|tls| tls.common_names.clone());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_names))
.transpose();
match result {
Ok(creds) => creds,
Err(e) => stream.throw_error(e).await?,
}
};
let allow_self_signed_compute = config.allow_self_signed_compute;
let client = Client::new(
stream,
creds,
&params,
session_id,
allow_self_signed_compute,
mode.allow_self_signed_compute(config),
);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.with_session(|session| client.connect_to_db(session, mode.allow_cleartext()))
.await
}