diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index dcc500f2c8..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -25,19 +25,15 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { - AuthFlow::new(client, scram) - .authenticate() - .await - .inspect_err(|error| { - warn!(?error, "error processing scram messages"); - }) - }) + let auth_outcome = tokio::time::timeout( + config.scram_protocol_timeout, + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), + ) .await .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) - .map_err(auth::AuthError::user_timeout)??; + .map_err(auth::AuthError::user_timeout)? + .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 9499aba61b..7fb84b5ee5 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -159,7 +159,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 93f4ea6cf7..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -72,28 +74,34 @@ impl NeonControlPlaneClient { role: &RoleName, ) -> Result { async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, ctx.session_id().to_string()) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id())]) - .query(&[ - ("application_name", ctx.console_application_name().as_str()), - ("endpointish", endpoint.as_str()), - ("role", role.as_str()), - ]) - .build()?; - - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); let response = { - let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); - self.endpoint.execute(request).await? - }; - info!(duration = ?start.elapsed(), "received http response"); + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - let body = match parse_body::(response).await { + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response + }; + + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -184,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -236,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. let (host, port) = match parse_host_port(&body.address) { @@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. - let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). - pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index d68d9f9474..43074bf208 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -186,7 +186,7 @@ where pub async fn read_message<'a, S>( stream: &mut S, buf: &'a mut Vec, - max: usize, + max: u32, ) -> io::Result<(u8, &'a mut [u8])> where S: AsyncRead + Unpin, @@ -206,7 +206,7 @@ where let header = read!(stream => Header); // as described above, the length must be at least 4. - let Some(len) = (header.len.get() as usize).checked_sub(4) else { + let Some(len) = header.len.get().checked_sub(4) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 4.", header.len, @@ -222,7 +222,7 @@ where } // read in our entire message. - buf.resize(len, 0); + buf.resize(len as usize, 0); stream.read_exact(buf).await?; Ok((header.tag, buf)) diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 13ee8c7dd2..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,3 +1,4 @@ +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -57,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -108,7 +109,9 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; @@ -146,7 +149,7 @@ pub(crate) async fn handshake( tls.cert_resolver.resolve(conn_info.server_name()); let tls = Stream::Tls { - tls: Box::new(tls_stream), + tls: tls_stream, tls_server_end_point, }; (stream, msg) = PqStream::parse_startup(tls).await?; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index ac0aca1176..0ffc54aa88 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..55ab5f4dba 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index cb15132673..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -30,52 +30,53 @@ where F: FnOnce(&str) -> super::Result, M: Mechanism, { - let sasl = { + let (mut mechanism, mut input) = { // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); // Initial client message contains the chosen auth method's name. let msg = stream.read_password_message().await?; - super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) }; - let mut mechanism = mechanism(sasl.method)?; - let mut input = sasl.message; loop { - let step = mechanism - .exchange(input) - .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; - - match step { - Step::Continue(moved_mechanism, reply) => { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { mechanism = moved_mechanism; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // write reply let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); - - // get next input - stream.flush().await?; - let msg = stream.read_password_message().await?; - input = std::str::from_utf8(msg) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + drop(reply); } - Step::Success(result, reply) => { - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - + Ok(Step::Success(result, reply)) => { // write reply let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); stream.write_message(BeMessage::AuthenticationOk); + // exit with success break Ok(Outcome::Success(result)); } // exit with failure - Step::Failure(reason) => break Ok(Outcome::Failure(reason)), + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 7126430a85..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -72,7 +72,7 @@ impl PqStream { impl PqStream { /// Read a raw postgres packet, which will respect the max length requested. /// This is not cancel safe. - async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; if actual_tag != tag { return Err(io::Error::other(format!( @@ -89,7 +89,7 @@ impl PqStream { // passwords are usually pretty short // and SASL SCRAM messages are no longer than 256 bytes in my testing // (a few hashes and random bytes, encoded into base64). - const MAX_PASSWORD_LENGTH: usize = 512; + const MAX_PASSWORD_LENGTH: u32 = 512; self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) .await } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where