mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 00:42:54 +00:00
do not roundtrip params via serde_json::Value
This commit is contained in:
@@ -1,29 +1,86 @@
|
||||
use postgres_client::Row;
|
||||
use postgres_client::types::{Kind, Type};
|
||||
use serde::de::{Deserializer, Visitor};
|
||||
use serde_json::value::RawValue;
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
struct PgTextVisitor<'de>(&'de RawValue);
|
||||
impl<'de> Visitor<'de> for PgTextVisitor<'de> {
|
||||
type Value = Option<String>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("any valid JSON value")
|
||||
}
|
||||
|
||||
// special care for nulls
|
||||
fn visit_none<E>(self) -> Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
fn visit_unit<E>(self) -> Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// convert to text with escaping
|
||||
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
|
||||
Ok(Some(self.0.get().to_owned()))
|
||||
}
|
||||
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
|
||||
Ok(Some(self.0.get().to_owned()))
|
||||
}
|
||||
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
|
||||
Ok(Some(self.0.get().to_owned()))
|
||||
}
|
||||
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
|
||||
Ok(Some(self.0.get().to_owned()))
|
||||
}
|
||||
fn visit_map<A: serde::de::MapAccess<'de>>(self, _: A) -> Result<Self::Value, A::Error> {
|
||||
Ok(Some(self.0.get().to_owned()))
|
||||
}
|
||||
|
||||
// avoid escaping here, as we pass this as a parameter
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> {
|
||||
Ok(Some(v.to_string()))
|
||||
}
|
||||
fn visit_string<E>(self, v: String) -> Result<Self::Value, E> {
|
||||
Ok(Some(v))
|
||||
}
|
||||
|
||||
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
|
||||
let mut output = String::new();
|
||||
output.push('{');
|
||||
let mut comma = false;
|
||||
while let Some(val) = seq.next_element::<Value>()? {
|
||||
if comma {
|
||||
output.push(',');
|
||||
}
|
||||
comma = true;
|
||||
|
||||
let val = match val {
|
||||
Value::Null => "NULL".to_string(),
|
||||
|
||||
// convert to text with escaping
|
||||
// here string needs to be escaped, as it is part of the array
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(),
|
||||
v @ Value::Object(_) => Value::String(v.to_string()).to_string(),
|
||||
|
||||
// recurse into array
|
||||
Value::Array(arr) => json_array_to_pg_array(&arr),
|
||||
};
|
||||
output.push_str(&val);
|
||||
}
|
||||
output.push('}');
|
||||
Ok(Some(output))
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
// as parameters.
|
||||
//
|
||||
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
|
||||
json.iter().map(json_value_to_pg_text).collect()
|
||||
}
|
||||
|
||||
fn json_value_to_pg_text(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
|
||||
// convert to text with escaping
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
|
||||
|
||||
// avoid escaping here, as we pass this as a parameter
|
||||
Value::String(s) => Some(s.clone()),
|
||||
|
||||
// special care for arrays
|
||||
Value::Array(arr) => Some(json_array_to_pg_array(arr)),
|
||||
}
|
||||
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
|
||||
json.into_iter()
|
||||
.map(|raw| raw.deserialize_any(PgTextVisitor(&raw)).unwrap_or(None))
|
||||
.collect()
|
||||
}
|
||||
|
||||
//
|
||||
@@ -385,6 +442,14 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn json_to_pg_text(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
|
||||
let json = json
|
||||
.into_iter()
|
||||
.map(|value| serde_json::from_str(&value.to_string()).unwrap())
|
||||
.collect();
|
||||
super::json_to_pg_text(json)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_types_to_pg_params() {
|
||||
let json = vec![Value::Bool(true), Value::Bool(false)];
|
||||
|
||||
@@ -18,7 +18,6 @@ use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -48,9 +47,8 @@ use crate::util::run_until_cancelled;
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
#[serde(default)]
|
||||
params: Vec<Option<String>>,
|
||||
params: Vec<Box<RawValue>>,
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
@@ -67,15 +65,6 @@ enum Payload {
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestContext,
|
||||
@@ -498,7 +487,7 @@ async fn handle_db_inner(
|
||||
|
||||
debug!(length = body.len(), "request payload read");
|
||||
|
||||
// try batched, then try unbatched.
|
||||
// try unbatched, then try batched.
|
||||
let payload = if let Ok(batch) = serde_json::from_slice(&body) {
|
||||
Payload::Batch(batch)
|
||||
} else {
|
||||
@@ -892,7 +881,8 @@ async fn query_to_json<T: GenericClient>(
|
||||
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
|
||||
let query_start = Instant::now();
|
||||
|
||||
let query_params = data.params;
|
||||
let query_params = json_to_pg_text(data.params);
|
||||
|
||||
let mut row_stream = client
|
||||
.query_raw_txt(&data.query, query_params)
|
||||
.await
|
||||
@@ -1045,7 +1035,7 @@ mod tests {
|
||||
} = serde_json::from_str(payload).unwrap();
|
||||
|
||||
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
|
||||
assert_eq!(params, vec![Some(String::from("test"))]);
|
||||
assert_eq!(params[0].get(), "\"test\"");
|
||||
assert!(array_mode.unwrap());
|
||||
|
||||
let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
|
||||
@@ -1057,7 +1047,7 @@ mod tests {
|
||||
query.query,
|
||||
format!("SELECT * FROM users{i} WHERE name = ?")
|
||||
);
|
||||
assert_eq!(query.params, vec![Some(format!("test{i}"))]);
|
||||
assert_eq!(query.params[0].get(), &format!("\"test{i}\""));
|
||||
assert_eq!(query.array_mode.unwrap(), i > 0);
|
||||
}
|
||||
|
||||
@@ -1069,7 +1059,7 @@ mod tests {
|
||||
} = serde_json::from_str(payload).unwrap();
|
||||
|
||||
assert_eq!(query, "SELECT 1");
|
||||
assert_eq!(params, vec![]);
|
||||
assert!(params.is_empty());
|
||||
assert!(array_mode.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user