From d336b8b5d93274802ab13d2df0ef013cd09e709c Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Tue, 24 Jan 2023 20:58:41 +0200 Subject: [PATCH] Refactor Client into EstablishedConnection. The name "Client" was a bit ambiguous. Instead of encapsulating all the data needed to establish the connection, change it so that it encapsulates the streams, after the connection has been established. With that, "EstablishedConnection" is a fitting name for it. --- proxy/src/proxy.rs | 89 ++++++++++++++++++---------------------------- 1 file changed, 35 insertions(+), 54 deletions(-) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 2d69ee7435..200b8da714 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -115,14 +115,16 @@ pub async fn handle_ws_client( async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new( + let conn = EstablishedConnection::connect_to_db( stream, creds, ¶ms, session_id, cancel_map.new_session()?, - ); - client.handle_connection(true).await + true, + ) + .await?; + conn.handle_connection().await } /// Handle an incoming client connection, handshake and authentication. @@ -160,14 +162,16 @@ async fn handle_client( async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new( + let conn = EstablishedConnection::connect_to_db( stream, creds, ¶ms, session_id, cancel_map.new_session()?, - ); - client.handle_connection(false).await + false, + ) + .await?; + conn.handle_connection().await } /// Establish a (most probably, secure) connection with the client. @@ -237,62 +241,35 @@ async fn handshake( } /// Thin connection context. -struct Client<'a, S> { - /// The underlying libpq protocol stream. - stream: PqStream, - /// Client credentials that we care about. - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, - /// KV-dictionary with PostgreSQL connection params. - params: &'a StartupMessageParams, - /// Unique connection ID. - session_id: uuid::Uuid, +struct EstablishedConnection<'a, S> { + client_stream: MeasuredStream, + db_stream: MeasuredStream, - session: cancellation::Session<'a>, + /// Hold on to the Session for as long as the connection is alive, so that + /// it can be cancelled. + _session: cancellation::Session<'a>, } -impl<'a, S> Client<'a, S> { - /// Construct a new connection context. - fn new( - stream: PqStream, - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, - params: &'a StartupMessageParams, - session_id: uuid::Uuid, - session: cancellation::Session<'a>, - ) -> Self { - Self { - stream, - creds, - params, - session_id, - session, - } - } -} - -impl Client<'_, S> { - async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> { - let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?; - +impl EstablishedConnection<'_, S> { + async fn handle_connection(mut self) -> anyhow::Result<()> { // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; + let _ = tokio::io::copy_bidirectional(&mut self.client_stream, &mut self.db_stream).await?; Ok(()) } /// Let the client authenticate and connect to the designated compute node. + /// On return, the connection is fully established in both ways, and we can start + /// forwarding the bytes. #[instrument(skip_all)] - async fn connect_to_db( - self, + async fn connect_to_db<'a>( + mut stream: PqStream, + creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, + params: &'a StartupMessageParams, + session_id: uuid::Uuid, + session: cancellation::Session<'a>, use_cleartext_password_flow: bool, - ) -> anyhow::Result<(MeasuredStream, MeasuredStream)> { - let Self { - mut stream, - creds, - params, - session_id, - session, - } = self; - + ) -> anyhow::Result> { let extra = auth::ConsoleReqExtra { session_id, // aka this connection's id application_name: params.get("application_name"), @@ -340,11 +317,15 @@ impl Client<'_, S> { .await?; let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx")); - let client = MeasuredStream::new(stream.into_inner(), m_sent); + let client_stream = MeasuredStream::new(stream.into_inner(), m_sent); let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx")); - let db = MeasuredStream::new(db.stream, m_recv); + let db_stream = MeasuredStream::new(db.stream, m_recv); - Ok((client, db)) + Ok(EstablishedConnection { + client_stream, + db_stream, + _session: session, + }) } }