From fe8b93ab9dac674012468d3d1483837a3f203c63 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 14 Oct 2024 14:02:46 +0100 Subject: [PATCH] fix Payload deser --- proxy/src/serverless/json.rs | 6 +- proxy/src/serverless/sql_over_http.rs | 101 +++++++++++++++++++------- 2 files changed, 78 insertions(+), 29 deletions(-) diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 961645dc36..88d479aa06 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -14,9 +14,9 @@ use super::json_raw_value::LazyValue; // as parameters. // pub(crate) fn json_to_pg_text( - json: Vec<&RawValue>, + json: &[&RawValue], ) -> Result>, serde_json::Error> { - json.into_iter().map(json_value_to_pg_text).try_collect() + json.iter().copied().map(json_value_to_pg_text).try_collect() } fn json_value_to_pg_text(value: &RawValue) -> Result, serde_json::Error> { @@ -287,7 +287,7 @@ mod tests { fn json_to_pg_text_test(json: Vec) -> Vec> { let json = serde_json::Value::Array(json).to_string(); let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap(); - json_to_pg_text(json).unwrap() + json_to_pg_text(&json).unwrap() } #[test] diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f02f70fd9a..fc7a5a3c95 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::pin::pin; use std::sync::Arc; @@ -76,24 +77,28 @@ use super::local_conn_pool; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] -struct QueryData { - query: String, - #[serde(deserialize_with = "bytes_to_pg_text")] - params: Vec>, +#[serde(bound = "'de: 'a")] +struct QueryData<'a> { + #[serde(borrow)] + query: Cow<'a, str>, + + #[serde(borrow)] + params: Vec<&'a RawValue>, + #[serde(default)] array_mode: Option, } #[derive(serde::Deserialize)] -struct BatchQueryData { - queries: Vec, +#[serde(rename_all = "camelCase")] +#[serde(bound = "'de: 'a")] +struct BatchQueryData<'a> { + queries: Vec>, } -#[derive(serde::Deserialize)] -#[serde(untagged)] -enum Payload { - Single(QueryData), - Batch(BatchQueryData), +enum Payload<'a> { + Batch(BatchQueryData<'a>), + Single(QueryData<'a>), } static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); @@ -106,13 +111,18 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); -fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> -where - D: serde::de::Deserializer<'de>, -{ - // TODO: consider avoiding the allocation here. - let json: Vec<&RawValue> = serde::de::Deserialize::deserialize(deserializer)?; - json_to_pg_text(json).map_err(serde::de::Error::custom) +fn parse_pg_params(params: &[&RawValue]) -> Result>, ReadPayloadError> { + json_to_pg_text(params).map_err(ReadPayloadError::Parse) +} + +fn parse_payload(body: &[u8]) -> Result, ReadPayloadError> { + // RawValue doesn't work via untagged enums + // so instead we try parse each individually + if let Ok(batch) = serde_json::from_slice(body) { + Ok(Payload::Batch(batch)) + } else { + Ok(Payload::Single(serde_json::from_slice(body)?)) + } } #[derive(Debug, thiserror::Error)] @@ -615,8 +625,7 @@ async fn handle_db_inner( async { let body = request.into_body().collect().await?.to_bytes(); info!(length = body.len(), "request payload read"); - let payload: Payload = serde_json::from_slice(&body)?; - Ok::(payload) // Adjust error type accordingly + Ok::(body) } .map_err(SqlOverHttpError::from), ); @@ -660,7 +669,7 @@ async fn handle_db_inner( .map_err(SqlOverHttpError::from), ); - let (payload, mut client) = match run_until_cancelled( + let (body, mut client) = match run_until_cancelled( // Run both operations in parallel try_join( pin!(fetch_and_process_request), @@ -674,6 +683,8 @@ async fn handle_db_inner( None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)), }; + let payload = parse_payload(&body)?; + let mut response = Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "application/json"); @@ -781,7 +792,7 @@ async fn handle_auth_broker_inner( .map(|b| b.boxed())) } -impl QueryData { +impl QueryData<'_> { async fn process( self, config: &'static HttpConfig, @@ -855,7 +866,7 @@ impl QueryData { } } -impl BatchQueryData { +impl BatchQueryData<'_> { async fn process( self, config: &'static HttpConfig, @@ -931,7 +942,7 @@ async fn query_batch( config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, - queries: BatchQueryData, + queries: BatchQueryData<'_>, parsed_headers: HttpHeaders, ) -> Result { let mut results = Vec::with_capacity(queries.queries.len()); @@ -969,12 +980,12 @@ async fn query_batch( async fn query_to_json( config: &'static HttpConfig, client: &T, - data: QueryData, + data: QueryData<'_>, current_size: &mut usize, parsed_headers: HttpHeaders, ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> { info!("executing query"); - let query_params = data.params; + let query_params = parse_pg_params(&data.params)?; let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?); info!("finished executing query"); @@ -1100,3 +1111,41 @@ impl Discard<'_> { } } } + +#[cfg(test)] +mod tests { + use typed_json::json; + + use super::parse_payload; + use super::Payload; + + #[test] + fn raw_single_payload() { + let body = json! { + {"query":"select $1","params":["1"]} + } + .to_string(); + + let Payload::Single(query) = parse_payload(body.as_bytes()).unwrap() else { + panic!("expected single") + }; + assert_eq!(&*query.query, "select $1"); + assert_eq!(query.params[0].get(), "\"1\""); + } + + #[test] + fn raw_batch_payload() { + let body = json! {{ + "queries": [ + {"query":"select $1","params":["1"]}, + {"query":"select $1","params":["2"]}, + ] + }} + .to_string(); + + let Payload::Batch(query) = parse_payload(body.as_bytes()).unwrap() else { + panic!("expected batch") + }; + assert_eq!(query.queries.len(), 2); + } +}