diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 2d69ee7435..200b8da714 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -115,14 +115,16 @@ pub async fn handle_ws_client( async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new( + let conn = EstablishedConnection::connect_to_db( stream, creds, ¶ms, session_id, cancel_map.new_session()?, - ); - client.handle_connection(true).await + true, + ) + .await?; + conn.handle_connection().await } /// Handle an incoming client connection, handshake and authentication. @@ -160,14 +162,16 @@ async fn handle_client( async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new( + let conn = EstablishedConnection::connect_to_db( stream, creds, ¶ms, session_id, cancel_map.new_session()?, - ); - client.handle_connection(false).await + false, + ) + .await?; + conn.handle_connection().await } /// Establish a (most probably, secure) connection with the client. @@ -237,62 +241,35 @@ async fn handshake( } /// Thin connection context. -struct Client<'a, S> { - /// The underlying libpq protocol stream. - stream: PqStream, - /// Client credentials that we care about. - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, - /// KV-dictionary with PostgreSQL connection params. - params: &'a StartupMessageParams, - /// Unique connection ID. - session_id: uuid::Uuid, +struct EstablishedConnection<'a, S> { + client_stream: MeasuredStream, + db_stream: MeasuredStream, - session: cancellation::Session<'a>, + /// Hold on to the Session for as long as the connection is alive, so that + /// it can be cancelled. + _session: cancellation::Session<'a>, } -impl<'a, S> Client<'a, S> { - /// Construct a new connection context. - fn new( - stream: PqStream, - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, - params: &'a StartupMessageParams, - session_id: uuid::Uuid, - session: cancellation::Session<'a>, - ) -> Self { - Self { - stream, - creds, - params, - session_id, - session, - } - } -} - -impl Client<'_, S> { - async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> { - let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?; - +impl EstablishedConnection<'_, S> { + async fn handle_connection(mut self) -> anyhow::Result<()> { // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; + let _ = tokio::io::copy_bidirectional(&mut self.client_stream, &mut self.db_stream).await?; Ok(()) } /// Let the client authenticate and connect to the designated compute node. + /// On return, the connection is fully established in both ways, and we can start + /// forwarding the bytes. #[instrument(skip_all)] - async fn connect_to_db( - self, + async fn connect_to_db<'a>( + mut stream: PqStream, + creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, + params: &'a StartupMessageParams, + session_id: uuid::Uuid, + session: cancellation::Session<'a>, use_cleartext_password_flow: bool, - ) -> anyhow::Result<(MeasuredStream, MeasuredStream)> { - let Self { - mut stream, - creds, - params, - session_id, - session, - } = self; - + ) -> anyhow::Result> { let extra = auth::ConsoleReqExtra { session_id, // aka this connection's id application_name: params.get("application_name"), @@ -340,11 +317,15 @@ impl Client<'_, S> { .await?; let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx")); - let client = MeasuredStream::new(stream.into_inner(), m_sent); + let client_stream = MeasuredStream::new(stream.into_inner(), m_sent); let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx")); - let db = MeasuredStream::new(db.stream, m_recv); + let db_stream = MeasuredStream::new(db.stream, m_recv); - Ok((client, db)) + Ok(EstablishedConnection { + client_stream, + db_stream, + _session: session, + }) } }