diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index fb24d3d3aa..77ee2496c0 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,30 +1,40 @@ +use itertools::Itertools; +use serde_json::value::RawValue; use serde_json::Map; use serde_json::Value; use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::Row; +use typed_json::json; + +use super::json_raw_value::LazyValue; // // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -pub(crate) fn json_to_pg_text(json: Vec) -> Vec> { - json.into_iter().map(json_value_to_pg_text).collect() +pub(crate) fn json_to_pg_text( + json: Vec<&RawValue>, +) -> Result>, serde_json::Error> { + json.into_iter().map(json_value_to_pg_text).try_collect() } -fn json_value_to_pg_text(value: Value) -> Option { +fn json_value_to_pg_text(value: &RawValue) -> Result, serde_json::Error> { + let value = serde_json::from_str(value.get())?; match value { // special care for nulls - Value::Null => None, + LazyValue::Null => Ok(None), // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), + v @ (LazyValue::Bool(_) | LazyValue::Number(_) | LazyValue::Object(_)) => { + Ok(Some(v.to_string())) + } // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s), + LazyValue::String(s) => Ok(Some(s.into_owned())), // special care for arrays - Value::Array(arr) => Some(json_array_to_pg_array(arr)), + LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)), } } @@ -36,7 +46,7 @@ fn json_value_to_pg_text(value: Value) -> Option { // // Example of the same escaping in node-postgres: packages/pg/lib/utils.js // -fn json_array_to_pg_array(arr: Vec) -> String { +fn json_array_to_pg_array(arr: Vec<&RawValue>) -> Result { let mut output = String::new(); let mut first = true; @@ -48,27 +58,30 @@ fn json_array_to_pg_array(arr: Vec) -> String { } first = false; - let value = json_array_to_pg_array_inner(value); + let value = json_array_to_pg_array_inner(value)?; output.push_str(value.as_deref().unwrap_or("NULL")); } output.push('}'); - output + Ok(output) } -fn json_array_to_pg_array_inner(value: Value) -> Option { +fn json_array_to_pg_array_inner(value: &RawValue) -> Result, serde_json::Error> { + let value = serde_json::from_str(value.get())?; match value { // special care for nulls - Value::Null => None, + LazyValue::Null => Ok(None), // convert to text with escaping // here string needs to be escaped, as it is part of the array - v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()), - v @ Value::Object(_) => json_array_to_pg_array_inner(Value::String(v.to_string())), + v @ (LazyValue::Bool(_) | LazyValue::Number(_) | LazyValue::String(_)) => { + Ok(Some(v.to_string())) + } + v @ LazyValue::Object(_) => Ok(Some(json!(v.to_string()).to_string())), // recurse into array - Value::Array(arr) => Some(json_array_to_pg_array(arr)), + LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)), } } @@ -271,8 +284,10 @@ mod tests { use super::*; use serde_json::json; - fn json_to_pg_text_test(json: Vec) -> Vec> { - json_to_pg_text(json) + fn json_to_pg_text_test(json: Vec) -> Vec> { + let json = serde_json::Value::Array(json).to_string(); + let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap(); + json_to_pg_text(json).unwrap() } #[test] diff --git a/proxy/src/serverless/json_raw_value.rs b/proxy/src/serverless/json_raw_value.rs new file mode 100644 index 0000000000..79caa706bb --- /dev/null +++ b/proxy/src/serverless/json_raw_value.rs @@ -0,0 +1,234 @@ +//! [`serde_json::Value`] but uses RawValue internally +//! +//! This code forks from the serde_json code, but replaces internal Value with RawValue where possible. +//! +//! Taken from +//! Licensed from serde-rs under MIT or APACHE-2.0, with modifications by Conrad Ludgate + +use core::fmt; +use std::borrow::Cow; + +use indexmap::IndexMap; +use serde::{ + de::{MapAccess, SeqAccess, Visitor}, + Deserialize, Serialize, +}; +use serde_json::{value::RawValue, Number}; + +pub enum LazyValue<'de> { + Null, + Bool(bool), + Number(Number), + String(Cow<'de, str>), + Array(Vec<&'de RawValue>), + Object(IndexMap), +} + +impl<'de> Deserialize<'de> for LazyValue<'de> { + #[inline] + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + struct ValueVisitor; + + impl<'de> Visitor<'de> for ValueVisitor { + type Value = LazyValue<'de>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("any valid JSON value") + } + + #[inline] + fn visit_bool(self, value: bool) -> Result, E> { + Ok(LazyValue::Bool(value)) + } + + #[inline] + fn visit_i64(self, value: i64) -> Result, E> { + Ok(LazyValue::Number(value.into())) + } + + #[inline] + fn visit_u64(self, value: u64) -> Result, E> { + Ok(LazyValue::Number(value.into())) + } + + #[inline] + fn visit_f64(self, value: f64) -> Result, E> { + Ok(Number::from_f64(value).map_or(LazyValue::Null, LazyValue::Number)) + } + + #[inline] + fn visit_str(self, value: &str) -> Result, E> + where + E: serde::de::Error, + { + self.visit_string(String::from(value)) + } + + #[inline] + fn visit_borrowed_str(self, value: &'de str) -> Result, E> + where + E: serde::de::Error, + { + Ok(LazyValue::String(Cow::Borrowed(value))) + } + + #[inline] + fn visit_string(self, value: String) -> Result, E> { + Ok(LazyValue::String(Cow::Owned(value))) + } + + #[inline] + fn visit_none(self) -> Result, E> { + Ok(LazyValue::Null) + } + + #[inline] + fn visit_some(self, deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + Deserialize::deserialize(deserializer) + } + + #[inline] + fn visit_unit(self) -> Result, E> { + Ok(LazyValue::Null) + } + + #[inline] + fn visit_seq(self, mut visitor: V) -> Result, V::Error> + where + V: SeqAccess<'de>, + { + let mut vec = Vec::new(); + + while let Some(elem) = visitor.next_element()? { + vec.push(elem); + } + + Ok(LazyValue::Array(vec)) + } + + fn visit_map(self, mut visitor: V) -> Result, V::Error> + where + V: MapAccess<'de>, + { + let mut values = IndexMap::new(); + + while let Some((key, value)) = visitor.next_entry()? { + values.insert(key, value); + } + + Ok(LazyValue::Object(values)) + } + } + + deserializer.deserialize_any(ValueVisitor) + } +} + +impl Serialize for LazyValue<'_> { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + match self { + LazyValue::Null => serializer.serialize_unit(), + LazyValue::Bool(b) => serializer.serialize_bool(*b), + LazyValue::Number(n) => n.serialize(serializer), + LazyValue::String(s) => serializer.serialize_str(s), + LazyValue::Array(v) => v.serialize(serializer), + LazyValue::Object(m) => { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(m.len()))?; + for (k, v) in m { + map.serialize_entry(k, v)?; + } + map.end() + } + } + } +} + +#[allow(clippy::to_string_trait_impl)] +impl ToString for LazyValue<'_> { + fn to_string(&self) -> String { + serde_json::to_string(self).expect("json encoding a LazyValue should never error") + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use typed_json::json; + + use super::LazyValue; + + #[test] + fn object() { + let json = json! {{ + "foo": { + "bar": 1 + }, + "baz": [2, 3], + }} + .to_string(); + + let lazy: LazyValue = serde_json::from_str(&json).unwrap(); + + let LazyValue::Object(object) = lazy else { + panic!("expected object") + }; + assert_eq!(object.len(), 2); + + assert_eq!(object["foo"].get(), r#"{"bar":1}"#); + assert_eq!(object["baz"].get(), r#"[2,3]"#); + } + + #[test] + fn array() { + let json = json! {[ + { + "bar": 1 + }, + [2, 3], + ]} + .to_string(); + + let lazy: LazyValue = serde_json::from_str(&json).unwrap(); + + let LazyValue::Array(array) = lazy else { + panic!("expected array") + }; + assert_eq!(array.len(), 2); + + assert_eq!(array[0].get(), r#"{"bar":1}"#); + assert_eq!(array[1].get(), r#"[2,3]"#); + } + + #[test] + fn string() { + let json = json! { "hello world" }.to_string(); + + let lazy: LazyValue = serde_json::from_str(&json).unwrap(); + + let LazyValue::String(Cow::Borrowed(string)) = lazy else { + panic!("expected borrowed string") + }; + assert_eq!(string, "hello world"); + + let json = json! { "hello \n world" }.to_string(); + + let lazy: LazyValue = serde_json::from_str(&json).unwrap(); + + let LazyValue::String(Cow::Owned(string)) = lazy else { + panic!("expected owned string") + }; + assert_eq!(string, "hello \n world"); + } +} diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 95f64e972c..e7aa8decfa 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -8,6 +8,7 @@ mod conn_pool; mod http_conn_pool; mod http_util; mod json; +mod json_raw_value; mod local_conn_pool; mod sql_over_http; mod websocket; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index cf3324926c..f02f70fd9a 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -22,7 +22,7 @@ use hyper::StatusCode; use hyper::{HeaderMap, Request}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; -use serde_json::Value; +use serde_json::value::RawValue; use tokio::time; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; @@ -111,8 +111,8 @@ where D: serde::de::Deserializer<'de>, { // TODO: consider avoiding the allocation here. - let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; - Ok(json_to_pg_text(json)) + let json: Vec<&RawValue> = serde::de::Deserialize::deserialize(deserializer)?; + json_to_pg_text(json).map_err(serde::de::Error::custom) } #[derive(Debug, thiserror::Error)]