Compare commits

...

4 Commits

Author SHA1 Message Date
Conrad Ludgate
eac1af4e1e optimise array encoding 2025-07-08 09:07:15 +01:00
Conrad Ludgate
33151e87fc do not roundtrip params via serde_json::Value 2025-07-08 08:35:03 +01:00
Conrad Ludgate
7e1979db0d do not use serde untagged 2025-07-07 17:56:16 +01:00
Conrad Ludgate
539150ff64 invert json_array_to_pg_array 2025-07-07 17:53:13 +01:00
2 changed files with 135 additions and 97 deletions

View File

@@ -1,60 +1,112 @@
use postgres_client::Row;
use postgres_client::types::{Kind, Type};
use serde::Deserialize;
use serde::de::{Deserializer, IgnoredAny, Visitor};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
json.into_iter()
.map(|raw| {
match raw.get().as_bytes() {
// special handling for null.
b"null" => None,
// remove the escape characters from the string.
[b'"', ..] => {
Some(String::deserialize(&*raw).expect("json should be a valid string"))
}
[b'[', ..] => {
let mut output = String::with_capacity(raw.get().len());
raw.deserialize_seq(PgArrayVisitor(&raw, &mut output))
.expect("json should be a valid");
Some(output)
}
// write all other values out directly
_ => Some(<Box<str>>::from(raw).into()),
}
})
.collect()
}
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String);
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.clone()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
impl PgArrayVisitor<'_, '_> {
#[inline]
#[allow(clippy::unnecessary_wraps)]
fn raw<E>(self) -> Result<(), E> {
self.1.push_str(self.0.get());
Ok(())
}
}
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> {
type Value = ();
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
// special care for nulls
fn visit_none<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
Some(format!("{{{vals}}}"))
// convert to text with escaping
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
self.raw()
}
// an object needs re-escaping
fn visit_map<A: serde::de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
let s = serde_json::to_string(self.0.get()).expect("a string should be valid json");
self.1.push_str(&s);
Ok(())
}
// write an array
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
self.1.push('{');
let mut comma = false;
while let Some(val) = seq.next_element::<&'de RawValue>()? {
if comma {
self.1.push(',');
}
comma = true;
val.deserialize_any(PgArrayVisitor(val, self.1))
.expect("all json values are valid");
}
self.1.push('}');
Ok(())
}
}
@@ -384,6 +436,14 @@ mod tests {
use super::*;
fn json_to_pg_text(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = json
.into_iter()
.map(|value| serde_json::from_str(&value.to_string()).unwrap())
.collect();
super::json_to_pg_text(json)
}
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];

View File

@@ -18,7 +18,6 @@ use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
@@ -48,9 +47,8 @@ use crate::util::run_until_cancelled;
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
#[serde(default)]
params: Vec<Option<String>>,
params: Vec<Box<RawValue>>,
#[serde(default)]
array_mode: Option<bool>,
}
@@ -60,8 +58,6 @@ struct BatchQueryData {
queries: Vec<QueryData>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
@@ -69,15 +65,6 @@ enum Payload {
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -499,7 +486,14 @@ async fn handle_db_inner(
.observe(HttpDirection::Request, body.len() as f64);
debug!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
// try unbatched, then try batched.
let payload = if let Ok(batch) = serde_json::from_slice(&body) {
Payload::Batch(batch)
} else {
Payload::Single(serde_json::from_slice(&body)?)
};
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
@@ -887,7 +881,8 @@ async fn query_to_json<T: GenericClient>(
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
let query_start = Instant::now();
let query_params = data.params;
let query_params = json_to_pg_text(data.params);
let mut row_stream = client
.query_raw_txt(&data.query, query_params)
.await
@@ -1033,55 +1028,38 @@ mod tests {
#[test]
fn test_payload() {
let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}";
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params, vec![Some(String::from("test"))]);
assert!(array_mode.unwrap());
}
Payload::Batch(_) => {
panic!("deserialization failed: case with single query, one param, and array mode")
}
}
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params[0].get(), "\"test\"");
assert!(array_mode.unwrap());
let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
let BatchQueryData { queries } = serde_json::from_str(payload).unwrap();
match deserialized_payload {
Payload::Batch(BatchQueryData { queries }) => {
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params, vec![Some(format!("test{i}"))]);
assert_eq!(query.array_mode.unwrap(), i > 0);
}
}
Payload::Single(_) => panic!("deserialization failed: case with multiple queries"),
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params[0].get(), &format!("\"test{i}\""));
assert_eq!(query.array_mode.unwrap(), i > 0);
}
let payload = "{\"query\":\"SELECT 1\"}";
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT 1");
assert_eq!(params, vec![]);
assert!(array_mode.is_none());
}
Payload::Batch(_) => panic!("deserialization failed: case with only one query"),
}
assert_eq!(query, "SELECT 1");
assert!(params.is_empty());
assert!(array_mode.is_none());
}
}