proxy: refactor untagged enum parsing with manually implemented deserialize

This commit is contained in:
Conrad Ludgate
2024-09-16 15:15:34 +01:00
parent 4391b25d01
commit b41070ba53

View File

@@ -1,3 +1,4 @@
use std::fmt;
use std::pin::pin;
use std::sync::Arc;
@@ -21,6 +22,9 @@ use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::de;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde_json::Value;
use tokio::time;
@@ -71,23 +75,415 @@ use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
}
#[derive(serde::Deserialize)]
impl<'de> Deserialize<'de> for QueryData {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Query,
Params,
ArrayMode,
Ignore,
}
enum States {
Empty,
HasPartialQueryData {
query: Option<String>,
params: Option<Vec<Option<String>>>,
#[allow(clippy::option_option)]
array_mode: Option<Option<bool>>,
},
}
struct FieldVisitor;
impl<'de> de::Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(r#"a JSON object string of either "query", "params", or "arrayMode"."#)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_bytes(v.as_bytes())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
match v {
b"query" => Ok(Field::Query),
b"params" => Ok(Field::Params),
b"arrayMode" => Ok(Field::ArrayMode),
_ => Ok(Field::Ignore),
}
}
}
impl<'de> Deserialize<'de> for Field {
#[inline]
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
d.deserialize_identifier(FieldVisitor)
}
}
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = QueryData;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Query => {
let (params, array_mode) = match state {
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value()?),
params,
array_mode,
};
}
Field::Params => {
#[doc(hidden)]
struct PgText {
value: Vec<Option<String>>,
}
impl<'de> Deserialize<'de> for PgText {
fn deserialize<D>(__deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ok(PgText {
value: bytes_to_pg_text(__deserializer)?,
})
}
}
let (query, array_mode) = match state {
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
state = States::HasPartialQueryData {
query,
params: Some(m.next_value::<PgText>()?.value),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
}),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor)
}
}
struct BatchQueryData {
queries: Vec<QueryData>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
impl<'de> Deserialize<'de> for Payload {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Queries,
Query,
Params,
ArrayMode,
Ignore,
}
enum States {
Empty,
HasQueries(Vec<QueryData>),
HasPartialQueryData {
query: Option<String>,
params: Option<Vec<Option<String>>>,
#[allow(clippy::option_option)]
array_mode: Option<Option<bool>>,
},
}
struct FieldVisitor;
impl<'de> de::Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_bytes(v.as_bytes())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
match v {
b"queries" => Ok(Field::Queries),
b"query" => Ok(Field::Query),
b"params" => Ok(Field::Params),
b"arrayMode" => Ok(Field::ArrayMode),
_ => Ok(Field::Ignore),
}
}
}
impl<'de> Deserialize<'de> for Field {
#[inline]
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
d.deserialize_identifier(FieldVisitor)
}
}
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = Payload;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Queries => match state {
States::Empty => state = States::HasQueries(m.next_value()?),
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::duplicate_field("queries"))
}
States::HasPartialQueryData { .. } => {
return Err(<A::Error as de::Error>::unknown_field(
"queries",
&["query", "params", "arrayMode"],
))
}
},
Field::Query => {
let (params, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"query",
&["queries"],
))
}
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value()?),
params,
array_mode,
};
}
Field::Params => {
#[doc(hidden)]
struct PgText {
value: Vec<Option<String>>,
}
impl<'de> Deserialize<'de> for PgText {
fn deserialize<D>(__deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ok(PgText {
value: bytes_to_pg_text(__deserializer)?,
})
}
}
let (query, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"params",
&["queries"],
))
}
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
state = States::HasPartialQueryData {
query,
params: Some(m.next_value::<PgText>()?.value),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"arrayMode",
&["queries"],
))
}
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })),
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(Payload::Single(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
})),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(
d,
"Payload",
&["queries", "query", "params", "arrayMode"],
Visitor,
)
}
}
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
@@ -105,10 +501,10 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
D: Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
let json: Vec<Value> = Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}