From 05f7fc4a06a8fec92d87533ba2af2636cb4604ab Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 16 Sep 2024 15:25:34 +0100 Subject: [PATCH] split out --- proxy/src/serverless/json.rs | 382 ++++++++++++++++++++++++- proxy/src/serverless/sql_over_http.rs | 390 +------------------------- 2 files changed, 387 insertions(+), 385 deletions(-) diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index d2d6776f37..136f343290 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,3 +1,6 @@ +use std::fmt; + +use serde::de; use serde::Deserialize; use serde::Deserializer; use serde_json::Map; @@ -6,8 +9,381 @@ use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::Row; -pub(crate) struct PgText { - pub(crate) value: Vec>, +use super::sql_over_http::BatchQueryData; +use super::sql_over_http::Payload; +use super::sql_over_http::QueryData; + +impl<'de> Deserialize<'de> for QueryData { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + Query, + Params, + ArrayMode, + Ignore, + } + + enum States { + Empty, + HasPartialQueryData { + query: Option, + params: Option>>, + #[allow(clippy::option_option)] + array_mode: Option>, + }, + } + + struct FieldVisitor; + + impl<'de> de::Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(r#"a JSON object string of either "query", "params", or "arrayMode"."#) + } + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + self.visit_bytes(v.as_bytes()) + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + match v { + b"query" => Ok(Field::Query), + b"params" => Ok(Field::Params), + b"arrayMode" => Ok(Field::ArrayMode), + _ => Ok(Field::Ignore), + } + } + } + impl<'de> Deserialize<'de> for Field { + #[inline] + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + d.deserialize_identifier(FieldVisitor) + } + } + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = QueryData; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str( + "a json object containing either a query object, or a list of query objects", + ) + } + #[inline] + fn visit_map(self, mut m: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut state = States::Empty; + + while let Some(key) = m.next_key()? { + match key { + Field::Query => { + let (params, array_mode) = match state { + States::HasPartialQueryData { query: Some(_), .. } => { + return Err(::duplicate_field("query")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query: None, + params, + array_mode, + } => (params, array_mode), + }; + state = States::HasPartialQueryData { + query: Some(m.next_value()?), + params, + array_mode, + }; + } + Field::Params => { + let (query, array_mode) = match state { + States::HasPartialQueryData { + params: Some(_), .. + } => { + return Err(::duplicate_field("params")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params: None, + array_mode, + } => (query, array_mode), + }; + state = States::HasPartialQueryData { + query, + params: Some(m.next_value::()?.value), + array_mode, + }; + } + Field::ArrayMode => { + let (query, params) = match state { + States::HasPartialQueryData { + array_mode: Some(_), + .. + } => { + return Err(::duplicate_field( + "arrayMode", + )) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params, + array_mode: None, + } => (query, params), + }; + state = States::HasPartialQueryData { + query, + params, + array_mode: Some(m.next_value()?), + }; + } + Field::Ignore => { + let _ = m.next_value::()?; + } + } + } + match state { + States::HasPartialQueryData { + query: Some(query), + params: Some(params), + array_mode, + } => Ok(QueryData { + query, + params, + array_mode: array_mode.unwrap_or_default(), + }), + States::Empty | States::HasPartialQueryData { query: None, .. } => { + Err(::missing_field("query")) + } + States::HasPartialQueryData { params: None, .. } => { + Err(::missing_field("params")) + } + } + } + } + + Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor) + } +} + +impl<'de> Deserialize<'de> for Payload { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + Queries, + Query, + Params, + ArrayMode, + Ignore, + } + + enum States { + Empty, + HasQueries(Vec), + HasPartialQueryData { + query: Option, + params: Option>>, + #[allow(clippy::option_option)] + array_mode: Option>, + }, + } + + struct FieldVisitor; + + impl<'de> de::Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#) + } + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + self.visit_bytes(v.as_bytes()) + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + match v { + b"queries" => Ok(Field::Queries), + b"query" => Ok(Field::Query), + b"params" => Ok(Field::Params), + b"arrayMode" => Ok(Field::ArrayMode), + _ => Ok(Field::Ignore), + } + } + } + impl<'de> Deserialize<'de> for Field { + #[inline] + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + d.deserialize_identifier(FieldVisitor) + } + } + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = Payload; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str( + "a json object containing either a query object, or a list of query objects", + ) + } + #[inline] + fn visit_map(self, mut m: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut state = States::Empty; + + while let Some(key) = m.next_key()? { + match key { + Field::Queries => match state { + States::Empty => state = States::HasQueries(m.next_value()?), + States::HasQueries(_) => { + return Err(::duplicate_field("queries")) + } + States::HasPartialQueryData { .. } => { + return Err(::unknown_field( + "queries", + &["query", "params", "arrayMode"], + )) + } + }, + Field::Query => { + let (params, array_mode) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "query", + &["queries"], + )) + } + States::HasPartialQueryData { query: Some(_), .. } => { + return Err(::duplicate_field("query")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query: None, + params, + array_mode, + } => (params, array_mode), + }; + state = States::HasPartialQueryData { + query: Some(m.next_value()?), + params, + array_mode, + }; + } + Field::Params => { + let (query, array_mode) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "params", + &["queries"], + )) + } + States::HasPartialQueryData { + params: Some(_), .. + } => { + return Err(::duplicate_field("params")) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params: None, + array_mode, + } => (query, array_mode), + }; + state = States::HasPartialQueryData { + query, + params: Some(m.next_value::()?.value), + array_mode, + }; + } + Field::ArrayMode => { + let (query, params) = match state { + States::HasQueries(_) => { + return Err(::unknown_field( + "arrayMode", + &["queries"], + )) + } + States::HasPartialQueryData { + array_mode: Some(_), + .. + } => { + return Err(::duplicate_field( + "arrayMode", + )) + } + States::Empty => (None, None), + States::HasPartialQueryData { + query, + params, + array_mode: None, + } => (query, params), + }; + state = States::HasPartialQueryData { + query, + params, + array_mode: Some(m.next_value()?), + }; + } + Field::Ignore => { + let _ = m.next_value::()?; + } + } + } + match state { + States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })), + States::HasPartialQueryData { + query: Some(query), + params: Some(params), + array_mode, + } => Ok(Payload::Single(QueryData { + query, + params, + array_mode: array_mode.unwrap_or_default(), + })), + States::Empty | States::HasPartialQueryData { query: None, .. } => { + Err(::missing_field("query")) + } + States::HasPartialQueryData { params: None, .. } => { + Err(::missing_field("params")) + } + } + } + } + + Deserializer::deserialize_struct( + d, + "Payload", + &["queries", "query", "params", "arrayMode"], + Visitor, + ) + } +} + +struct PgText { + value: Vec>, } impl<'de> Deserialize<'de> for PgText { fn deserialize(__deserializer: D) -> Result @@ -26,7 +402,7 @@ impl<'de> Deserialize<'de> for PgText { // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -pub(crate) fn json_to_pg_text(json: Vec) -> Vec> { +fn json_to_pg_text(json: Vec) -> Vec> { json.iter().map(json_value_to_pg_text).collect() } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index b4e96d9ef8..15f4ee5639 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,4 +1,3 @@ -use std::fmt; use std::pin::pin; use std::sync::Arc; @@ -22,9 +21,6 @@ use hyper1::Response; use hyper1::StatusCode; use hyper1::{HeaderMap, Request}; use pq_proto::StartupMessageParamsBuilder; -use serde::de; -use serde::Deserialize; -use serde::Deserializer; use serde::Serialize; use tokio::time; use tokio_postgres::error::DbError; @@ -58,12 +54,11 @@ use crate::metrics::HttpDirection; use crate::metrics::Metrics; use crate::proxy::run_until_cancelled; use crate::proxy::NeonOptions; -use crate::serverless::backend::HttpConnError; -use crate::serverless::json::PgText; use crate::usage_metrics::MetricCounterRecorder; use crate::DbName; use crate::RoleName; +use super::backend::HttpConnError; use super::backend::LocalProxyConnError; use super::backend::PoolingBackend; use super::conn_pool::AuthData; @@ -74,386 +69,17 @@ use super::http_util::json_response; use super::json::pg_text_row_to_json; use super::json::JsonConversionError; -struct QueryData { - query: String, - params: Vec>, - array_mode: Option, +pub(crate) struct QueryData { + pub(crate) query: String, + pub(crate) params: Vec>, + pub(crate) array_mode: Option, } -impl<'de> Deserialize<'de> for QueryData { - fn deserialize(d: D) -> Result - where - D: Deserializer<'de>, - { - enum Field { - Query, - Params, - ArrayMode, - Ignore, - } - - enum States { - Empty, - HasPartialQueryData { - query: Option, - params: Option>>, - #[allow(clippy::option_option)] - array_mode: Option>, - }, - } - - struct FieldVisitor; - - impl<'de> de::Visitor<'de> for FieldVisitor { - type Value = Field; - - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(r#"a JSON object string of either "query", "params", or "arrayMode"."#) - } - fn visit_str(self, v: &str) -> Result - where - E: de::Error, - { - self.visit_bytes(v.as_bytes()) - } - fn visit_bytes(self, v: &[u8]) -> Result - where - E: de::Error, - { - match v { - b"query" => Ok(Field::Query), - b"params" => Ok(Field::Params), - b"arrayMode" => Ok(Field::ArrayMode), - _ => Ok(Field::Ignore), - } - } - } - impl<'de> Deserialize<'de> for Field { - #[inline] - fn deserialize(d: D) -> Result - where - D: Deserializer<'de>, - { - d.deserialize_identifier(FieldVisitor) - } - } - - struct Visitor; - impl<'de> de::Visitor<'de> for Visitor { - type Value = QueryData; - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str( - "a json object containing either a query object, or a list of query objects", - ) - } - #[inline] - fn visit_map(self, mut m: A) -> Result - where - A: de::MapAccess<'de>, - { - let mut state = States::Empty; - - while let Some(key) = m.next_key()? { - match key { - Field::Query => { - let (params, array_mode) = match state { - States::HasPartialQueryData { query: Some(_), .. } => { - return Err(::duplicate_field("query")) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query: None, - params, - array_mode, - } => (params, array_mode), - }; - state = States::HasPartialQueryData { - query: Some(m.next_value()?), - params, - array_mode, - }; - } - Field::Params => { - let (query, array_mode) = match state { - States::HasPartialQueryData { - params: Some(_), .. - } => { - return Err(::duplicate_field("params")) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query, - params: None, - array_mode, - } => (query, array_mode), - }; - state = States::HasPartialQueryData { - query, - params: Some(m.next_value::()?.value), - array_mode, - }; - } - Field::ArrayMode => { - let (query, params) = match state { - States::HasPartialQueryData { - array_mode: Some(_), - .. - } => { - return Err(::duplicate_field( - "arrayMode", - )) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query, - params, - array_mode: None, - } => (query, params), - }; - state = States::HasPartialQueryData { - query, - params, - array_mode: Some(m.next_value()?), - }; - } - Field::Ignore => { - let _ = m.next_value::()?; - } - } - } - match state { - States::HasPartialQueryData { - query: Some(query), - params: Some(params), - array_mode, - } => Ok(QueryData { - query, - params, - array_mode: array_mode.unwrap_or_default(), - }), - States::Empty | States::HasPartialQueryData { query: None, .. } => { - Err(::missing_field("query")) - } - States::HasPartialQueryData { params: None, .. } => { - Err(::missing_field("params")) - } - } - } - } - - Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor) - } +pub(crate) struct BatchQueryData { + pub(crate) queries: Vec, } -struct BatchQueryData { - queries: Vec, -} - -impl<'de> Deserialize<'de> for Payload { - fn deserialize(d: D) -> Result - where - D: Deserializer<'de>, - { - enum Field { - Queries, - Query, - Params, - ArrayMode, - Ignore, - } - - enum States { - Empty, - HasQueries(Vec), - HasPartialQueryData { - query: Option, - params: Option>>, - #[allow(clippy::option_option)] - array_mode: Option>, - }, - } - - struct FieldVisitor; - - impl<'de> de::Visitor<'de> for FieldVisitor { - type Value = Field; - - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#) - } - fn visit_str(self, v: &str) -> Result - where - E: de::Error, - { - self.visit_bytes(v.as_bytes()) - } - fn visit_bytes(self, v: &[u8]) -> Result - where - E: de::Error, - { - match v { - b"queries" => Ok(Field::Queries), - b"query" => Ok(Field::Query), - b"params" => Ok(Field::Params), - b"arrayMode" => Ok(Field::ArrayMode), - _ => Ok(Field::Ignore), - } - } - } - impl<'de> Deserialize<'de> for Field { - #[inline] - fn deserialize(d: D) -> Result - where - D: Deserializer<'de>, - { - d.deserialize_identifier(FieldVisitor) - } - } - - struct Visitor; - impl<'de> de::Visitor<'de> for Visitor { - type Value = Payload; - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str( - "a json object containing either a query object, or a list of query objects", - ) - } - #[inline] - fn visit_map(self, mut m: A) -> Result - where - A: de::MapAccess<'de>, - { - let mut state = States::Empty; - - while let Some(key) = m.next_key()? { - match key { - Field::Queries => match state { - States::Empty => state = States::HasQueries(m.next_value()?), - States::HasQueries(_) => { - return Err(::duplicate_field("queries")) - } - States::HasPartialQueryData { .. } => { - return Err(::unknown_field( - "queries", - &["query", "params", "arrayMode"], - )) - } - }, - Field::Query => { - let (params, array_mode) = match state { - States::HasQueries(_) => { - return Err(::unknown_field( - "query", - &["queries"], - )) - } - States::HasPartialQueryData { query: Some(_), .. } => { - return Err(::duplicate_field("query")) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query: None, - params, - array_mode, - } => (params, array_mode), - }; - state = States::HasPartialQueryData { - query: Some(m.next_value()?), - params, - array_mode, - }; - } - Field::Params => { - let (query, array_mode) = match state { - States::HasQueries(_) => { - return Err(::unknown_field( - "params", - &["queries"], - )) - } - States::HasPartialQueryData { - params: Some(_), .. - } => { - return Err(::duplicate_field("params")) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query, - params: None, - array_mode, - } => (query, array_mode), - }; - state = States::HasPartialQueryData { - query, - params: Some(m.next_value::()?.value), - array_mode, - }; - } - Field::ArrayMode => { - let (query, params) = match state { - States::HasQueries(_) => { - return Err(::unknown_field( - "arrayMode", - &["queries"], - )) - } - States::HasPartialQueryData { - array_mode: Some(_), - .. - } => { - return Err(::duplicate_field( - "arrayMode", - )) - } - States::Empty => (None, None), - States::HasPartialQueryData { - query, - params, - array_mode: None, - } => (query, params), - }; - state = States::HasPartialQueryData { - query, - params, - array_mode: Some(m.next_value()?), - }; - } - Field::Ignore => { - let _ = m.next_value::()?; - } - } - } - match state { - States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })), - States::HasPartialQueryData { - query: Some(query), - params: Some(params), - array_mode, - } => Ok(Payload::Single(QueryData { - query, - params, - array_mode: array_mode.unwrap_or_default(), - })), - States::Empty | States::HasPartialQueryData { query: None, .. } => { - Err(::missing_field("query")) - } - States::HasPartialQueryData { params: None, .. } => { - Err(::missing_field("params")) - } - } - } - } - - Deserializer::deserialize_struct( - d, - "Payload", - &["queries", "query", "params", "arrayMode"], - Visitor, - ) - } -} - -enum Payload { +pub(crate) enum Payload { Single(QueryData), Batch(BatchQueryData), }