diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index b62a11ccb2..fce503a190 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -147,14 +147,14 @@ impl JwkCacheEntryLock { Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), Ok(r) => { let resp: http::Response = r.into(); - match parse_json_body_with_limit::( + match parse_json_body_with_limit::( resp.into_body(), MAX_JWK_BODY_SIZE, ) .await { Err(e) => { - tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); + tracing::warn!(url=?rule.jwks_url, error=%e, "could not decode JWKs"); } Ok(jwks) => { key_sets.insert( diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 14720b5c6b..0a38c70ad6 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -6,7 +6,6 @@ pub mod health_server; use std::time::Duration; -use anyhow::bail; use bytes::Bytes; use http_body_util::BodyExt; use hyper1::body::Body; @@ -113,10 +112,20 @@ impl Endpoint { } } -pub(crate) async fn parse_json_body_with_limit( - mut b: impl Body + Unpin, +#[derive(Debug, thiserror::Error)] +pub(crate) enum ReadPayloadError { + #[error("could not read the HTTP body: {0}")] + Read(E), + #[error("could not parse the HTTP body: {0}")] + Parse(#[from] serde_json::Error), + #[error("could not parse the HTTP body: content length exceeds limit of {0} bytes")] + LengthExceeded(usize), +} + +pub(crate) async fn parse_json_body_with_limit( + mut b: impl Body + Unpin, limit: usize, -) -> anyhow::Result { +) -> Result> { // We could use `b.limited().collect().await.to_bytes()` here // but this ends up being slightly more efficient as far as I can tell. @@ -124,14 +133,19 @@ pub(crate) async fn parse_json_body_with_limit( // in reqwest, this value is influenced by the Content-Length header. let lower_bound = match usize::try_from(b.size_hint().lower()) { Ok(bound) if bound <= limit => bound, - _ => bail!("Content length exceeds limit of {limit} bytes"), + _ => return Err(ReadPayloadError::LengthExceeded(limit)), }; let mut bytes = Vec::with_capacity(lower_bound); - while let Some(frame) = b.frame().await.transpose()? { + while let Some(frame) = b + .frame() + .await + .transpose() + .map_err(ReadPayloadError::Read)? + { if let Ok(data) = frame.into_data() { if bytes.len() + data.len() > limit { - bail!("Content length exceeds limit of {limit} bytes") + return Err(ReadPayloadError::LengthExceeded(limit)); } bytes.extend_from_slice(&data); } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 15f4ee5639..dd017a0da8 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -50,6 +50,7 @@ use crate::context::RequestMonitoring; use crate::error::ErrorKind; use crate::error::ReportableError; use crate::error::UserFacingError; +use crate::http::parse_json_body_with_limit; use crate::metrics::HttpDirection; use crate::metrics::Metrics; use crate::proxy::run_until_cancelled; @@ -363,7 +364,7 @@ pub(crate) async fn handle( #[derive(Debug, thiserror::Error)] pub(crate) enum SqlOverHttpError { #[error("{0}")] - ReadPayload(#[from] ReadPayloadError), + ReadPayload(ReadPayloadError), #[error("{0}")] ConnectCompute(#[from] HttpConnError), #[error("{0}")] @@ -417,9 +418,9 @@ impl UserFacingError for SqlOverHttpError { #[derive(Debug, thiserror::Error)] pub(crate) enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] - Read(#[from] hyper1::Error), + Read(hyper1::Error), #[error("could not parse the HTTP request body: {0}")] - Parse(#[from] serde_json::Error), + Parse(serde_json::Error), } impl ReportableError for ReadPayloadError { @@ -431,6 +432,18 @@ impl ReportableError for ReadPayloadError { } } +impl From> for SqlOverHttpError { + fn from(value: crate::http::ReadPayloadError) -> Self { + match value { + crate::http::ReadPayloadError::Read(e) => Self::ReadPayload(ReadPayloadError::Read(e)), + crate::http::ReadPayloadError::Parse(e) => { + Self::ReadPayload(ReadPayloadError::Parse(e)) + } + crate::http::ReadPayloadError::LengthExceeded(x) => Self::RequestTooLarge(x as u64), + } + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum SqlOverHttpCancel { #[error("query was cancelled")] @@ -590,15 +603,14 @@ async fn handle_db_inner( )); } - let fetch_and_process_request = Box::pin( - async { - let body = request.into_body().collect().await?.to_bytes(); - info!(length = body.len(), "request payload read"); - let payload: Payload = serde_json::from_slice(&body)?; - Ok::(payload) // Adjust error type accordingly - } - .map_err(SqlOverHttpError::from), - ); + let fetch_and_process_request = Box::pin(async { + let payload = parse_json_body_with_limit( + request.into_body(), + config.http_config.max_request_size_bytes as usize, + ) + .await?; + Ok::(payload) // Adjust error type accordingly + }); let authenticate_and_connect = Box::pin( async {