mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-02 13:00:37 +00:00
proxy: refactor untagged enum parsing with manually implemented deserialize
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use std::fmt;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -21,6 +22,9 @@ use hyper1::Response;
|
||||
use hyper1::StatusCode;
|
||||
use hyper1::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::de;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
@@ -71,23 +75,415 @@ use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
params: Vec<Option<String>>,
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
impl<'de> Deserialize<'de> for QueryData {
|
||||
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
enum Field {
|
||||
Query,
|
||||
Params,
|
||||
ArrayMode,
|
||||
Ignore,
|
||||
}
|
||||
|
||||
enum States {
|
||||
Empty,
|
||||
HasPartialQueryData {
|
||||
query: Option<String>,
|
||||
params: Option<Vec<Option<String>>>,
|
||||
#[allow(clippy::option_option)]
|
||||
array_mode: Option<Option<bool>>,
|
||||
},
|
||||
}
|
||||
|
||||
struct FieldVisitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for FieldVisitor {
|
||||
type Value = Field;
|
||||
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(r#"a JSON object string of either "query", "params", or "arrayMode"."#)
|
||||
}
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
self.visit_bytes(v.as_bytes())
|
||||
}
|
||||
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
match v {
|
||||
b"query" => Ok(Field::Query),
|
||||
b"params" => Ok(Field::Params),
|
||||
b"arrayMode" => Ok(Field::ArrayMode),
|
||||
_ => Ok(Field::Ignore),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'de> Deserialize<'de> for Field {
|
||||
#[inline]
|
||||
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
d.deserialize_identifier(FieldVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
struct Visitor;
|
||||
impl<'de> de::Visitor<'de> for Visitor {
|
||||
type Value = QueryData;
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(
|
||||
"a json object containing either a query object, or a list of query objects",
|
||||
)
|
||||
}
|
||||
#[inline]
|
||||
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::MapAccess<'de>,
|
||||
{
|
||||
let mut state = States::Empty;
|
||||
|
||||
while let Some(key) = m.next_key()? {
|
||||
match key {
|
||||
Field::Query => {
|
||||
let (params, array_mode) = match state {
|
||||
States::HasPartialQueryData { query: Some(_), .. } => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field("query"))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query: None,
|
||||
params,
|
||||
array_mode,
|
||||
} => (params, array_mode),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query: Some(m.next_value()?),
|
||||
params,
|
||||
array_mode,
|
||||
};
|
||||
}
|
||||
Field::Params => {
|
||||
#[doc(hidden)]
|
||||
struct PgText {
|
||||
value: Vec<Option<String>>,
|
||||
}
|
||||
impl<'de> Deserialize<'de> for PgText {
|
||||
fn deserialize<D>(__deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(PgText {
|
||||
value: bytes_to_pg_text(__deserializer)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let (query, array_mode) = match state {
|
||||
States::HasPartialQueryData {
|
||||
params: Some(_), ..
|
||||
} => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field("params"))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query,
|
||||
params: None,
|
||||
array_mode,
|
||||
} => (query, array_mode),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query,
|
||||
params: Some(m.next_value::<PgText>()?.value),
|
||||
array_mode,
|
||||
};
|
||||
}
|
||||
Field::ArrayMode => {
|
||||
let (query, params) = match state {
|
||||
States::HasPartialQueryData {
|
||||
array_mode: Some(_),
|
||||
..
|
||||
} => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field(
|
||||
"arrayMode",
|
||||
))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: None,
|
||||
} => (query, params),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: Some(m.next_value()?),
|
||||
};
|
||||
}
|
||||
Field::Ignore => {
|
||||
let _ = m.next_value::<de::IgnoredAny>()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
match state {
|
||||
States::HasPartialQueryData {
|
||||
query: Some(query),
|
||||
params: Some(params),
|
||||
array_mode,
|
||||
} => Ok(QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: array_mode.unwrap_or_default(),
|
||||
}),
|
||||
States::Empty | States::HasPartialQueryData { query: None, .. } => {
|
||||
Err(<A::Error as de::Error>::missing_field("query"))
|
||||
}
|
||||
States::HasPartialQueryData { params: None, .. } => {
|
||||
Err(<A::Error as de::Error>::missing_field("params"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor)
|
||||
}
|
||||
}
|
||||
|
||||
struct BatchQueryData {
|
||||
queries: Vec<QueryData>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
impl<'de> Deserialize<'de> for Payload {
|
||||
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
enum Field {
|
||||
Queries,
|
||||
Query,
|
||||
Params,
|
||||
ArrayMode,
|
||||
Ignore,
|
||||
}
|
||||
|
||||
enum States {
|
||||
Empty,
|
||||
HasQueries(Vec<QueryData>),
|
||||
HasPartialQueryData {
|
||||
query: Option<String>,
|
||||
params: Option<Vec<Option<String>>>,
|
||||
#[allow(clippy::option_option)]
|
||||
array_mode: Option<Option<bool>>,
|
||||
},
|
||||
}
|
||||
|
||||
struct FieldVisitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for FieldVisitor {
|
||||
type Value = Field;
|
||||
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#)
|
||||
}
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
self.visit_bytes(v.as_bytes())
|
||||
}
|
||||
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
match v {
|
||||
b"queries" => Ok(Field::Queries),
|
||||
b"query" => Ok(Field::Query),
|
||||
b"params" => Ok(Field::Params),
|
||||
b"arrayMode" => Ok(Field::ArrayMode),
|
||||
_ => Ok(Field::Ignore),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'de> Deserialize<'de> for Field {
|
||||
#[inline]
|
||||
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
d.deserialize_identifier(FieldVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
struct Visitor;
|
||||
impl<'de> de::Visitor<'de> for Visitor {
|
||||
type Value = Payload;
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str(
|
||||
"a json object containing either a query object, or a list of query objects",
|
||||
)
|
||||
}
|
||||
#[inline]
|
||||
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::MapAccess<'de>,
|
||||
{
|
||||
let mut state = States::Empty;
|
||||
|
||||
while let Some(key) = m.next_key()? {
|
||||
match key {
|
||||
Field::Queries => match state {
|
||||
States::Empty => state = States::HasQueries(m.next_value()?),
|
||||
States::HasQueries(_) => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field("queries"))
|
||||
}
|
||||
States::HasPartialQueryData { .. } => {
|
||||
return Err(<A::Error as de::Error>::unknown_field(
|
||||
"queries",
|
||||
&["query", "params", "arrayMode"],
|
||||
))
|
||||
}
|
||||
},
|
||||
Field::Query => {
|
||||
let (params, array_mode) = match state {
|
||||
States::HasQueries(_) => {
|
||||
return Err(<A::Error as de::Error>::unknown_field(
|
||||
"query",
|
||||
&["queries"],
|
||||
))
|
||||
}
|
||||
States::HasPartialQueryData { query: Some(_), .. } => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field("query"))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query: None,
|
||||
params,
|
||||
array_mode,
|
||||
} => (params, array_mode),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query: Some(m.next_value()?),
|
||||
params,
|
||||
array_mode,
|
||||
};
|
||||
}
|
||||
Field::Params => {
|
||||
#[doc(hidden)]
|
||||
struct PgText {
|
||||
value: Vec<Option<String>>,
|
||||
}
|
||||
impl<'de> Deserialize<'de> for PgText {
|
||||
fn deserialize<D>(__deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(PgText {
|
||||
value: bytes_to_pg_text(__deserializer)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let (query, array_mode) = match state {
|
||||
States::HasQueries(_) => {
|
||||
return Err(<A::Error as de::Error>::unknown_field(
|
||||
"params",
|
||||
&["queries"],
|
||||
))
|
||||
}
|
||||
States::HasPartialQueryData {
|
||||
params: Some(_), ..
|
||||
} => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field("params"))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query,
|
||||
params: None,
|
||||
array_mode,
|
||||
} => (query, array_mode),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query,
|
||||
params: Some(m.next_value::<PgText>()?.value),
|
||||
array_mode,
|
||||
};
|
||||
}
|
||||
Field::ArrayMode => {
|
||||
let (query, params) = match state {
|
||||
States::HasQueries(_) => {
|
||||
return Err(<A::Error as de::Error>::unknown_field(
|
||||
"arrayMode",
|
||||
&["queries"],
|
||||
))
|
||||
}
|
||||
States::HasPartialQueryData {
|
||||
array_mode: Some(_),
|
||||
..
|
||||
} => {
|
||||
return Err(<A::Error as de::Error>::duplicate_field(
|
||||
"arrayMode",
|
||||
))
|
||||
}
|
||||
States::Empty => (None, None),
|
||||
States::HasPartialQueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: None,
|
||||
} => (query, params),
|
||||
};
|
||||
state = States::HasPartialQueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: Some(m.next_value()?),
|
||||
};
|
||||
}
|
||||
Field::Ignore => {
|
||||
let _ = m.next_value::<de::IgnoredAny>()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
match state {
|
||||
States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })),
|
||||
States::HasPartialQueryData {
|
||||
query: Some(query),
|
||||
params: Some(params),
|
||||
array_mode,
|
||||
} => Ok(Payload::Single(QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode: array_mode.unwrap_or_default(),
|
||||
})),
|
||||
States::Empty | States::HasPartialQueryData { query: None, .. } => {
|
||||
Err(<A::Error as de::Error>::missing_field("query"))
|
||||
}
|
||||
States::HasPartialQueryData { params: None, .. } => {
|
||||
Err(<A::Error as de::Error>::missing_field("params"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Deserializer::deserialize_struct(
|
||||
d,
|
||||
"Payload",
|
||||
&["queries", "query", "params", "arrayMode"],
|
||||
Visitor,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
enum Payload {
|
||||
Single(QueryData),
|
||||
Batch(BatchQueryData),
|
||||
@@ -105,10 +501,10 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
let json: Vec<Value> = Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user