diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 200b8da714..696ef07814 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -14,7 +14,7 @@ use once_cell::sync::Lazy; use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, instrument}; +use tracing::{error, info, info_span, instrument, Instrument}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; const ERR_PROTO_VIOLATION: &str = "protocol violation"; @@ -71,7 +71,7 @@ pub async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(config, &cancel_map, session_id, socket).await + handle_postgres_client(config, &cancel_map, session_id, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -81,6 +81,25 @@ pub async fn task_main( } } +/// Handle an incoming PostgreSQL connection +pub async fn handle_postgres_client( + config: &ProxyConfig, + cancel_map: &CancelMap, + session_id: uuid::Uuid, + stream: impl AsyncRead + AsyncWrite + Unpin + Send, +) -> anyhow::Result<()> { + handle_client( + config, + cancel_map, + session_id, + stream, + HostnameMethod::Sni, + false, + ) + .await +} + +/// Handle an incoming Postgres connection that's wrapped in websocket pub async fn handle_ws_client( config: &ProxyConfig, cancel_map: &CancelMap, @@ -88,43 +107,20 @@ pub async fn handle_ws_client( stream: impl AsyncRead + AsyncWrite + Unpin + Send, hostname: Option, ) -> anyhow::Result<()> { - // The `closed` counter will increase when this future is destroyed. - NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); - scopeguard::defer! { - NUM_CONNECTIONS_CLOSED_COUNTER.inc(); - } - - let tls = config.tls_config.as_ref(); - let hostname = hostname.as_deref(); - - // TLS is None here, because the connection is already encrypted. - let (mut stream, params) = match handshake(stream, None, cancel_map).await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; - - // Extract credentials which we're going to use for auth. - let creds = { - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); - let result = config - .auth_backend - .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name)) - .transpose(); - - async { result }.or_else(|e| stream.throw_error(e)).await? - }; - - let conn = EstablishedConnection::connect_to_db( - stream, - creds, - ¶ms, + handle_client( + config, + cancel_map, session_id, - cancel_map.new_session()?, + stream, + HostnameMethod::Param(hostname), true, ) - .await?; - conn.handle_connection().await + .await +} + +enum HostnameMethod { + Param(Option), + Sni, } /// Handle an incoming client connection, handshake and authentication. @@ -134,7 +130,9 @@ async fn handle_client( config: &ProxyConfig, cancel_map: &CancelMap, session_id: uuid::Uuid, - stream: impl AsyncRead + AsyncWrite + Unpin + Send, + raw_stream: impl AsyncRead + AsyncWrite + Unpin + Send, + hostname_method: HostnameMethod, + use_cleartext_password_flow: bool, ) -> anyhow::Result<()> { // The `closed` counter will increase when this future is destroyed. NUM_CONNECTIONS_ACCEPTED_COUNTER.inc(); @@ -142,36 +140,66 @@ async fn handle_client( NUM_CONNECTIONS_CLOSED_COUNTER.inc(); } - // Process postgres startup packet and upgrade to TLS (if applicable) - let tls = config.tls_config.as_ref(); - let (mut stream, params) = match handshake(stream, tls, cancel_map).await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; + // Accept the connection from the client, authenticate it, and establish + // connection to the database. + // + // We cover all these activities in a one tracing span, so that they are + // traced as one request. That makes it convenient to investigate where + // the time is spent when establishing a new connection. After the + // connection has been established, we exit the span, and use a separate + // span for the (rest of the) duration of the connection. + let conn = async { + // Process postgres startup packet and upgrade to TLS (if applicable) + let tls = config.tls_config.as_ref(); + let (mut stream, params) = match handshake(raw_stream, tls, cancel_map).await? { + Some(x) => x, + None => return Ok::<_, anyhow::Error>(None), // it's a cancellation request + }; - // Extract credentials which we're going to use for auth. - let creds = { - let sni = stream.get_ref().sni_hostname(); - let common_name = tls.and_then(|tls| tls.common_name.as_deref()); - let result = config - .auth_backend - .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) - .transpose(); + // Extract credentials which we're going to use for auth. + let creds = { + let sni = match &hostname_method { + HostnameMethod::Param(hostname) => hostname.as_deref(), + HostnameMethod::Sni => stream.get_ref().sni_hostname(), + }; + let common_name = tls.and_then(|tls| tls.common_name.as_deref()); + let result = config + .auth_backend + .as_ref() + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) + .transpose(); - async { result }.or_else(|e| stream.throw_error(e)).await? - }; + async { result }.or_else(|e| stream.throw_error(e)).await? + }; - let conn = EstablishedConnection::connect_to_db( - stream, - creds, - ¶ms, - session_id, - cancel_map.new_session()?, - false, - ) + Ok(Some( + EstablishedConnection::connect_to_db( + stream, + creds, + ¶ms, + session_id, + use_cleartext_password_flow, + cancel_map, + ) + .await?, + )) + } + .instrument(info_span!("establish_connection", session_id=%session_id)) .await?; - conn.handle_connection().await + + match conn { + Some(conn) => { + // Connection established in both ways. Forward all traffic until the + // either connection is lost. + conn.handle_connection() + .instrument(info_span!("forward", session_id=%session_id)) + .await + } + None => { + // It was a cancellation request. It was handled in 'handshake' already. + Ok(()) + } + } } /// Establish a (most probably, secure) connection with the client. @@ -264,12 +292,14 @@ impl EstablishedConnection<'_, S> { #[instrument(skip_all)] async fn connect_to_db<'a>( mut stream: PqStream, - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, - params: &'a StartupMessageParams, + creds: auth::BackendType<'a, auth::ClientCredentials<'_>>, + params: &'_ StartupMessageParams, session_id: uuid::Uuid, - session: cancellation::Session<'a>, use_cleartext_password_flow: bool, - ) -> anyhow::Result> { + cancel_map: &'a CancelMap, + ) -> anyhow::Result> { + let session = cancel_map.new_session()?; + let extra = auth::ConsoleReqExtra { session_id, // aka this connection's id application_name: params.get("application_name"),