share json parse fn

This commit is contained in:
Conrad Ludgate
2024-09-17 08:43:59 +01:00
parent f11254f2c5
commit 1466767571
3 changed files with 47 additions and 21 deletions

View File

@@ -147,14 +147,14 @@ impl JwkCacheEntryLock {
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
match parse_json_body_with_limit::<jose_jwk::JwkSet, _>(
resp.into_body(),
MAX_JWK_BODY_SIZE,
)
.await
{
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
tracing::warn!(url=?rule.jwks_url, error=%e, "could not decode JWKs");
}
Ok(jwks) => {
key_sets.insert(

View File

@@ -6,7 +6,6 @@ pub mod health_server;
use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper1::body::Body;
@@ -113,10 +112,20 @@ impl Endpoint {
}
}
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError<E> {
#[error("could not read the HTTP body: {0}")]
Read(E),
#[error("could not parse the HTTP body: {0}")]
Parse(#[from] serde_json::Error),
#[error("could not parse the HTTP body: content length exceeds limit of {0} bytes")]
LengthExceeded(usize),
}
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned, E>(
mut b: impl Body<Data = Bytes, Error = E> + Unpin,
limit: usize,
) -> anyhow::Result<D> {
) -> Result<D, ReadPayloadError<E>> {
// We could use `b.limited().collect().await.to_bytes()` here
// but this ends up being slightly more efficient as far as I can tell.
@@ -124,14 +133,19 @@ pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
// in reqwest, this value is influenced by the Content-Length header.
let lower_bound = match usize::try_from(b.size_hint().lower()) {
Ok(bound) if bound <= limit => bound,
_ => bail!("Content length exceeds limit of {limit} bytes"),
_ => return Err(ReadPayloadError::LengthExceeded(limit)),
};
let mut bytes = Vec::with_capacity(lower_bound);
while let Some(frame) = b.frame().await.transpose()? {
while let Some(frame) = b
.frame()
.await
.transpose()
.map_err(ReadPayloadError::Read)?
{
if let Ok(data) = frame.into_data() {
if bytes.len() + data.len() > limit {
bail!("Content length exceeds limit of {limit} bytes")
return Err(ReadPayloadError::LengthExceeded(limit));
}
bytes.extend_from_slice(&data);
}

View File

@@ -50,6 +50,7 @@ use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::http::parse_json_body_with_limit;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
@@ -363,7 +364,7 @@ pub(crate) async fn handle(
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpError {
#[error("{0}")]
ReadPayload(#[from] ReadPayloadError),
ReadPayload(ReadPayloadError),
#[error("{0}")]
ConnectCompute(#[from] HttpConnError),
#[error("{0}")]
@@ -417,9 +418,9 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper1::Error),
Read(hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
Parse(serde_json::Error),
}
impl ReportableError for ReadPayloadError {
@@ -431,6 +432,18 @@ impl ReportableError for ReadPayloadError {
}
}
impl From<crate::http::ReadPayloadError<hyper1::Error>> for SqlOverHttpError {
fn from(value: crate::http::ReadPayloadError<hyper1::Error>) -> Self {
match value {
crate::http::ReadPayloadError::Read(e) => Self::ReadPayload(ReadPayloadError::Read(e)),
crate::http::ReadPayloadError::Parse(e) => {
Self::ReadPayload(ReadPayloadError::Parse(e))
}
crate::http::ReadPayloadError::LengthExceeded(x) => Self::RequestTooLarge(x as u64),
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -590,15 +603,14 @@ async fn handle_db_inner(
));
}
let fetch_and_process_request = Box::pin(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
);
let fetch_and_process_request = Box::pin(async {
let payload = parse_json_body_with_limit(
request.into_body(),
config.http_config.max_request_size_bytes as usize,
)
.await?;
Ok::<Payload, SqlOverHttpError>(payload) // Adjust error type accordingly
});
let authenticate_and_connect = Box::pin(
async {