diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 3a18ad0726..d2d692947b 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,119 +1,113 @@ use postgres_client::Row; use postgres_client::types::{Kind, Type}; -use serde::de::{Deserializer, Visitor}; +use serde::Deserialize; +use serde::de::{Deserializer, IgnoredAny, Visitor}; use serde_json::value::RawValue; use serde_json::{Map, Value}; -struct PgTextVisitor<'de>(&'de RawValue); -impl<'de> Visitor<'de> for PgTextVisitor<'de> { - type Value = Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("any valid JSON value") - } - - // special care for nulls - fn visit_none(self) -> Result { - Ok(None) - } - fn visit_unit(self) -> Result { - Ok(None) - } - - // convert to text with escaping - fn visit_bool(self, _: bool) -> Result { - Ok(Some(self.0.get().to_owned())) - } - fn visit_i64(self, _: i64) -> Result { - Ok(Some(self.0.get().to_owned())) - } - fn visit_u64(self, _: u64) -> Result { - Ok(Some(self.0.get().to_owned())) - } - fn visit_f64(self, _: f64) -> Result { - Ok(Some(self.0.get().to_owned())) - } - fn visit_map>(self, _: A) -> Result { - Ok(Some(self.0.get().to_owned())) - } - - // avoid escaping here, as we pass this as a parameter - fn visit_str(self, v: &str) -> Result { - Ok(Some(v.to_string())) - } - fn visit_string(self, v: String) -> Result { - Ok(Some(v)) - } - - fn visit_seq>(self, mut seq: A) -> Result { - let mut output = String::new(); - output.push('{'); - let mut comma = false; - while let Some(val) = seq.next_element::()? { - if comma { - output.push(','); - } - comma = true; - - let val = match val { - Value::Null => "NULL".to_string(), - - // convert to text with escaping - // here string needs to be escaped, as it is part of the array - v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(), - v @ Value::Object(_) => Value::String(v.to_string()).to_string(), - - // recurse into array - Value::Array(arr) => json_array_to_pg_array(&arr), - }; - output.push_str(&val); - } - output.push('}'); - Ok(Some(output)) - } -} - // // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // pub(crate) fn json_to_pg_text(json: Vec>) -> Vec> { json.into_iter() - .map(|raw| raw.deserialize_any(PgTextVisitor(&raw)).unwrap_or(None)) + .map(|raw| { + match raw.get().as_bytes() { + // special handling for null. + b"null" => None, + // remove the escape characters from the string. + [b'"', ..] => { + Some(String::deserialize(&*raw).expect("json should be a valid string")) + } + [b'[', ..] => { + let mut output = String::with_capacity(raw.get().len()); + raw.deserialize_seq(PgArrayVisitor(&raw, &mut output)) + .expect("json should be a valid"); + Some(output) + } + // write all other values out directly + _ => Some(>::from(raw).into()), + } + }) .collect() } -// -// Serialize a JSON array to a Postgres array. Contrary to the strings in the params -// in the array we need to escape the strings. Postgres is okay with arrays of form -// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving -// it for Postgres to check. -// -// Example of the same escaping in node-postgres: packages/pg/lib/utils.js -// -fn json_array_to_pg_array(arr: &[Value]) -> String { - let mut output = String::new(); - output.push('{'); - for val in arr { - if output.len() > 1 { - output.push(','); - } +struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String); - let val = match val { - Value::Null => "NULL".to_string(), - - // convert to text with escaping - // here string needs to be escaped, as it is part of the array - v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(), - v @ Value::Object(_) => Value::String(v.to_string()).to_string(), - - // recurse into array - Value::Array(arr) => json_array_to_pg_array(arr), - }; - output.push_str(&val); +impl PgArrayVisitor<'_, '_> { + #[inline] + #[allow(clippy::unnecessary_wraps)] + fn raw(self) -> Result<(), E> { + self.1.push_str(self.0.get()); + Ok(()) + } +} + +impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> { + type Value = (); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("any valid JSON value") + } + + // special care for nulls + fn visit_none(self) -> Result { + self.1.push_str("NULL"); + Ok(()) + } + fn visit_unit(self) -> Result { + self.1.push_str("NULL"); + Ok(()) + } + + // convert to text with escaping + fn visit_bool(self, _: bool) -> Result { + self.raw() + } + fn visit_i64(self, _: i64) -> Result { + self.raw() + } + fn visit_u64(self, _: u64) -> Result { + self.raw() + } + fn visit_i128(self, _: i128) -> Result { + self.raw() + } + fn visit_u128(self, _: u128) -> Result { + self.raw() + } + fn visit_f64(self, _: f64) -> Result { + self.raw() + } + fn visit_str(self, _: &str) -> Result { + self.raw() + } + + // an object needs re-escaping + fn visit_map>(self, mut map: A) -> Result { + while map.next_entry::()?.is_some() {} + + let s = serde_json::to_string(self.0.get()).expect("a string should be valid json"); + self.1.push_str(&s); + Ok(()) + } + + // write an array + fn visit_seq>(self, mut seq: A) -> Result { + self.1.push('{'); + let mut comma = false; + while let Some(val) = seq.next_element::<&'de RawValue>()? { + if comma { + self.1.push(','); + } + comma = true; + + val.deserialize_any(PgArrayVisitor(val, self.1)) + .expect("all json values are valid"); + } + self.1.push('}'); + Ok(()) } - output.push('}'); - output } #[derive(Debug, thiserror::Error)]