From b41070ba536288578db861c05d581bf726582893 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 16 Sep 2024 15:15:34 +0100 Subject: [PATCH] proxy: refactor untagged enum parsing with manually implemented deserialize --- proxy/src/serverless/sql_over_http.rs | 414 +++++++++++++++++++++++++- 1 file changed, 405 insertions(+), 9 deletions(-) diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f3a7ed9329..9beb6349e2 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::pin::pin; use std::sync::Arc; @@ -21,6 +22,9 @@ use hyper1::Response; use hyper1::StatusCode; use hyper1::{HeaderMap, Request}; use pq_proto::StartupMessageParamsBuilder; +use serde::de; +use serde::Deserialize; +use serde::Deserializer; use serde::Serialize; use serde_json::Value; use tokio::time; @@ -71,23 +75,415 @@ use super::json::json_to_pg_text; use super::json::pg_text_row_to_json; use super::json::JsonConversionError; -#[derive(serde::Deserialize)] -#[serde(rename_all = "camelCase")] struct QueryData { query: String, - #[serde(deserialize_with = "bytes_to_pg_text")] params: Vec>, - #[serde(default)] array_mode: Option, } -#[derive(serde::Deserialize)] +impl<'de> Deserialize<'de> for QueryData { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + Query, + Params, + ArrayMode, + Ignore, + } + + enum States { + Empty, + HasPartialQueryData { + query: Option, + params: Option>>, + #[allow(clippy::option_option)] + array_mode: Option>, + }, + } + + struct FieldVisitor; + + impl<'de> de::Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(r#"a JSON object string of either "query", "params", or "arrayMode"."#) + } + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + self.visit_bytes(v.as_bytes()) + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + match v { + b"query" => Ok(Field::Query), + b"params" => Ok(Field::Params), + b"arrayMode" => Ok(Field::ArrayMode), + _ => Ok(Field::Ignore), + } + } + } + impl<'de> Deserialize<'de> for Field { + #[inline] + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + d.deserialize_identifier(FieldVisitor) + } + } + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = QueryData; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str( + "a json object containing either a query object, or a list of query objects", + ) + } + #[inline] + fn visit_map(self, mut m: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut state = States::Empty; + + while let Some(key) = m.next_key()? { + match key { + Field::Query => { + let (params, array_mode) = match state { + States::HasPartialQueryData { query: Some(_), .. } => { + return Err(::duplicate_field("query")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query: None, + params, + array_mode, + } => (params, array_mode), + }; + state = States::HasPartialQueryData { + query: Some(m.next_value()?), + params, + array_mode, + }; + } + Field::Params => { + #[doc(hidden)] + struct PgText { + value: Vec>, + } + impl<'de> Deserialize<'de> for PgText { + fn deserialize(__deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(PgText { + value: bytes_to_pg_text(__deserializer)?, + }) + } + } + + let (query, array_mode) = match state { + States::HasPartialQueryData { + params: Some(_), .. + } => { + return Err(::duplicate_field("params")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params: None, + array_mode, + } => (query, array_mode), + }; + state = States::HasPartialQueryData { + query, + params: Some(m.next_value::()?.value), + array_mode, + }; + } + Field::ArrayMode => { + let (query, params) = match state { + States::HasPartialQueryData { + array_mode: Some(_), + .. + } => { + return Err(::duplicate_field( + "arrayMode", + )) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params, + array_mode: None, + } => (query, params), + }; + state = States::HasPartialQueryData { + query, + params, + array_mode: Some(m.next_value()?), + }; + } + Field::Ignore => { + let _ = m.next_value::()?; + } + } + } + match state { + States::HasPartialQueryData { + query: Some(query), + params: Some(params), + array_mode, + } => Ok(QueryData { + query, + params, + array_mode: array_mode.unwrap_or_default(), + }), + States::Empty | States::HasPartialQueryData { query: None, .. } => { + Err(::missing_field("query")) + } + States::HasPartialQueryData { params: None, .. } => { + Err(::missing_field("params")) + } + } + } + } + + Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor) + } +} + struct BatchQueryData { queries: Vec, } -#[derive(serde::Deserialize)] -#[serde(untagged)] +impl<'de> Deserialize<'de> for Payload { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + Queries, + Query, + Params, + ArrayMode, + Ignore, + } + + enum States { + Empty, + HasQueries(Vec), + HasPartialQueryData { + query: Option, + params: Option>>, + #[allow(clippy::option_option)] + array_mode: Option>, + }, + } + + struct FieldVisitor; + + impl<'de> de::Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#) + } + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + self.visit_bytes(v.as_bytes()) + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + match v { + b"queries" => Ok(Field::Queries), + b"query" => Ok(Field::Query), + b"params" => Ok(Field::Params), + b"arrayMode" => Ok(Field::ArrayMode), + _ => Ok(Field::Ignore), + } + } + } + impl<'de> Deserialize<'de> for Field { + #[inline] + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + d.deserialize_identifier(FieldVisitor) + } + } + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = Payload; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str( + "a json object containing either a query object, or a list of query objects", + ) + } + #[inline] + fn visit_map(self, mut m: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut state = States::Empty; + + while let Some(key) = m.next_key()? { + match key { + Field::Queries => match state { + States::Empty => state = States::HasQueries(m.next_value()?), + States::HasQueries(_) => { + return Err(::duplicate_field("queries")) + } + States::HasPartialQueryData { .. } => { + return Err(::unknown_field( + "queries", + &["query", "params", "arrayMode"], + )) + } + }, + Field::Query => { + let (params, array_mode) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "query", + &["queries"], + )) + } + States::HasPartialQueryData { query: Some(_), .. } => { + return Err(::duplicate_field("query")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query: None, + params, + array_mode, + } => (params, array_mode), + }; + state = States::HasPartialQueryData { + query: Some(m.next_value()?), + params, + array_mode, + }; + } + Field::Params => { + #[doc(hidden)] + struct PgText { + value: Vec>, + } + impl<'de> Deserialize<'de> for PgText { + fn deserialize(__deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(PgText { + value: bytes_to_pg_text(__deserializer)?, + }) + } + } + + let (query, array_mode) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "params", + &["queries"], + )) + } + States::HasPartialQueryData { + params: Some(_), .. + } => { + return Err(::duplicate_field("params")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params: None, + array_mode, + } => (query, array_mode), + }; + state = States::HasPartialQueryData { + query, + params: Some(m.next_value::()?.value), + array_mode, + }; + } + Field::ArrayMode => { + let (query, params) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "arrayMode", + &["queries"], + )) + } + States::HasPartialQueryData { + array_mode: Some(_), + .. + } => { + return Err(::duplicate_field( + "arrayMode", + )) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params, + array_mode: None, + } => (query, params), + }; + state = States::HasPartialQueryData { + query, + params, + array_mode: Some(m.next_value()?), + }; + } + Field::Ignore => { + let _ = m.next_value::()?; + } + } + } + match state { + States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })), + States::HasPartialQueryData { + query: Some(query), + params: Some(params), + array_mode, + } => Ok(Payload::Single(QueryData { + query, + params, + array_mode: array_mode.unwrap_or_default(), + })), + States::Empty | States::HasPartialQueryData { query: None, .. } => { + Err(::missing_field("query")) + } + States::HasPartialQueryData { params: None, .. } => { + Err(::missing_field("params")) + } + } + } + } + + Deserializer::deserialize_struct( + d, + "Payload", + &["queries", "query", "params", "arrayMode"], + Visitor, + ) + } +} + enum Payload { Single(QueryData), Batch(BatchQueryData), @@ -105,10 +501,10 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> where - D: serde::de::Deserializer<'de>, + D: Deserializer<'de>, { // TODO: consider avoiding the allocation here. - let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; + let json: Vec = Deserialize::deserialize(deserializer)?; Ok(json_to_pg_text(json)) }