Refactor common parts of handle_client and handle_ws_client to function.

There was a lot of duplicated code.

The resulting shared function now uses two tracing spans, one for
establishing the connections, and a separate span for forwarding the
traffic after that. This makes for nicer traces in the future, because
you can dig into how long the startup phase takes, and where the time
is spent.
This commit is contained in:
Heikki Linnakangas
2023-01-24 21:37:57 +02:00
parent 3ebca60517
commit 95fd68d76e

View File

@@ -14,7 +14,7 @@ use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, instrument};
use tracing::{error, info, info_span, instrument, Instrument};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -71,7 +71,7 @@ pub async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, session_id, socket).await
handle_postgres_client(config, &cancel_map, session_id, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -81,6 +81,25 @@ pub async fn task_main(
}
}
/// Handle an incoming PostgreSQL connection
pub async fn handle_postgres_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
) -> anyhow::Result<()> {
handle_client(
config,
cancel_map,
session_id,
stream,
HostnameMethod::Sni,
false,
)
.await
}
/// Handle an incoming Postgres connection that's wrapped in websocket
pub async fn handle_ws_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
@@ -88,43 +107,20 @@ pub async fn handle_ws_client(
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
hostname: Option<String>,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let (mut stream, params) = match handshake(stream, None, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let conn = EstablishedConnection::connect_to_db(
stream,
creds,
&params,
handle_client(
config,
cancel_map,
session_id,
cancel_map.new_session()?,
stream,
HostnameMethod::Param(hostname),
true,
)
.await?;
conn.handle_connection().await
.await
}
enum HostnameMethod {
Param(Option<String>),
Sni,
}
/// Handle an incoming client connection, handshake and authentication.
@@ -134,7 +130,9 @@ async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
session_id: uuid::Uuid,
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
raw_stream: impl AsyncRead + AsyncWrite + Unpin + Send,
hostname_method: HostnameMethod,
use_cleartext_password_flow: bool,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
@@ -142,36 +140,66 @@ async fn handle_client(
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
// Process postgres startup packet and upgrade to TLS (if applicable)
let tls = config.tls_config.as_ref();
let (mut stream, params) = match handshake(stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
// Accept the connection from the client, authenticate it, and establish
// connection to the database.
//
// We cover all these activities in a one tracing span, so that they are
// traced as one request. That makes it convenient to investigate where
// the time is spent when establishing a new connection. After the
// connection has been established, we exit the span, and use a separate
// span for the (rest of the) duration of the connection.
let conn = async {
// Process postgres startup packet and upgrade to TLS (if applicable)
let tls = config.tls_config.as_ref();
let (mut stream, params) = match handshake(raw_stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok::<_, anyhow::Error>(None), // it's a cancellation request
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
// Extract credentials which we're going to use for auth.
let creds = {
let sni = match &hostname_method {
HostnameMethod::Param(hostname) => hostname.as_deref(),
HostnameMethod::Sni => stream.get_ref().sni_hostname(),
};
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let conn = EstablishedConnection::connect_to_db(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
false,
)
Ok(Some(
EstablishedConnection::connect_to_db(
stream,
creds,
&params,
session_id,
use_cleartext_password_flow,
cancel_map,
)
.await?,
))
}
.instrument(info_span!("establish_connection", session_id=%session_id))
.await?;
conn.handle_connection().await
match conn {
Some(conn) => {
// Connection established in both ways. Forward all traffic until the
// either connection is lost.
conn.handle_connection()
.instrument(info_span!("forward", session_id=%session_id))
.await
}
None => {
// It was a cancellation request. It was handled in 'handshake' already.
Ok(())
}
}
}
/// Establish a (most probably, secure) connection with the client.
@@ -264,12 +292,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> EstablishedConnection<'_, S> {
#[instrument(skip_all)]
async fn connect_to_db<'a>(
mut stream: PqStream<S>,
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
creds: auth::BackendType<'a, auth::ClientCredentials<'_>>,
params: &'_ StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
use_cleartext_password_flow: bool,
) -> anyhow::Result<EstablishedConnection<S>> {
cancel_map: &'a CancelMap,
) -> anyhow::Result<EstablishedConnection<'a, S>> {
let session = cancel_map.new_session()?;
let extra = auth::ConsoleReqExtra {
session_id, // aka this connection's id
application_name: params.get("application_name"),