From 33151e87fc7e4f483b2a5dbde6c9bdc184dd293d Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 8 Jul 2025 08:35:03 +0100 Subject: [PATCH] do not roundtrip params via serde_json::Value --- proxy/src/serverless/json.rs | 101 +++++++++++++++++++++----- proxy/src/serverless/sql_over_http.rs | 24 ++---- 2 files changed, 90 insertions(+), 35 deletions(-) diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 0e0a8c1a08..3a18ad0726 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,29 +1,86 @@ use postgres_client::Row; use postgres_client::types::{Kind, Type}; +use serde::de::{Deserializer, Visitor}; +use serde_json::value::RawValue; use serde_json::{Map, Value}; +struct PgTextVisitor<'de>(&'de RawValue); +impl<'de> Visitor<'de> for PgTextVisitor<'de> { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("any valid JSON value") + } + + // special care for nulls + fn visit_none(self) -> Result { + Ok(None) + } + fn visit_unit(self) -> Result { + Ok(None) + } + + // convert to text with escaping + fn visit_bool(self, _: bool) -> Result { + Ok(Some(self.0.get().to_owned())) + } + fn visit_i64(self, _: i64) -> Result { + Ok(Some(self.0.get().to_owned())) + } + fn visit_u64(self, _: u64) -> Result { + Ok(Some(self.0.get().to_owned())) + } + fn visit_f64(self, _: f64) -> Result { + Ok(Some(self.0.get().to_owned())) + } + fn visit_map>(self, _: A) -> Result { + Ok(Some(self.0.get().to_owned())) + } + + // avoid escaping here, as we pass this as a parameter + fn visit_str(self, v: &str) -> Result { + Ok(Some(v.to_string())) + } + fn visit_string(self, v: String) -> Result { + Ok(Some(v)) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut output = String::new(); + output.push('{'); + let mut comma = false; + while let Some(val) = seq.next_element::()? { + if comma { + output.push(','); + } + comma = true; + + let val = match val { + Value::Null => "NULL".to_string(), + + // convert to text with escaping + // here string needs to be escaped, as it is part of the array + v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(), + v @ Value::Object(_) => Value::String(v.to_string()).to_string(), + + // recurse into array + Value::Array(arr) => json_array_to_pg_array(&arr), + }; + output.push_str(&val); + } + output.push('}'); + Ok(Some(output)) + } +} + // // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -pub(crate) fn json_to_pg_text(json: Vec) -> Vec> { - json.iter().map(json_value_to_pg_text).collect() -} - -fn json_value_to_pg_text(value: &Value) -> Option { - match value { - // special care for nulls - Value::Null => None, - - // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), - - // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s.clone()), - - // special care for arrays - Value::Array(arr) => Some(json_array_to_pg_array(arr)), - } +pub(crate) fn json_to_pg_text(json: Vec>) -> Vec> { + json.into_iter() + .map(|raw| raw.deserialize_any(PgTextVisitor(&raw)).unwrap_or(None)) + .collect() } // @@ -385,6 +442,14 @@ mod tests { use super::*; + fn json_to_pg_text(json: Vec) -> Vec> { + let json = json + .into_iter() + .map(|value| serde_json::from_str(&value.to_string()).unwrap()) + .collect(); + super::json_to_pg_text(json) + } + #[test] fn test_atomic_types_to_pg_params() { let json = vec![Value::Bool(true), Value::Bool(false)]; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 41922172c8..c89794bc00 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -18,7 +18,6 @@ use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; use serde::Serialize; -use serde_json::Value; use serde_json::value::RawValue; use tokio::time::{self, Instant}; use tokio_util::sync::CancellationToken; @@ -48,9 +47,8 @@ use crate::util::run_until_cancelled; #[serde(rename_all = "camelCase")] struct QueryData { query: String, - #[serde(deserialize_with = "bytes_to_pg_text")] #[serde(default)] - params: Vec>, + params: Vec>, #[serde(default)] array_mode: Option, } @@ -67,15 +65,6 @@ enum Payload { static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); -fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> -where - D: serde::de::Deserializer<'de>, -{ - // TODO: consider avoiding the allocation here. - let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; - Ok(json_to_pg_text(json)) -} - pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestContext, @@ -498,7 +487,7 @@ async fn handle_db_inner( debug!(length = body.len(), "request payload read"); - // try batched, then try unbatched. + // try unbatched, then try batched. let payload = if let Ok(batch) = serde_json::from_slice(&body) { Payload::Batch(batch) } else { @@ -892,7 +881,8 @@ async fn query_to_json( ) -> Result<(ReadyForQueryStatus, impl Serialize + use), SqlOverHttpError> { let query_start = Instant::now(); - let query_params = data.params; + let query_params = json_to_pg_text(data.params); + let mut row_stream = client .query_raw_txt(&data.query, query_params) .await @@ -1045,7 +1035,7 @@ mod tests { } = serde_json::from_str(payload).unwrap(); assert_eq!(query, "SELECT * FROM users WHERE name = ?"); - assert_eq!(params, vec![Some(String::from("test"))]); + assert_eq!(params[0].get(), "\"test\""); assert!(array_mode.unwrap()); let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}"; @@ -1057,7 +1047,7 @@ mod tests { query.query, format!("SELECT * FROM users{i} WHERE name = ?") ); - assert_eq!(query.params, vec![Some(format!("test{i}"))]); + assert_eq!(query.params[0].get(), &format!("\"test{i}\"")); assert_eq!(query.array_mode.unwrap(), i > 0); } @@ -1069,7 +1059,7 @@ mod tests { } = serde_json::from_str(payload).unwrap(); assert_eq!(query, "SELECT 1"); - assert_eq!(params, vec![]); + assert!(params.is_empty()); assert!(array_mode.is_none()); } }