optimise array encoding

This commit is contained in:
Conrad Ludgate
2025-07-08 09:07:15 +01:00
parent 33151e87fc
commit eac1af4e1e

View File

@@ -1,119 +1,113 @@
use postgres_client::Row;
use postgres_client::types::{Kind, Type};
use serde::de::{Deserializer, Visitor};
use serde::Deserialize;
use serde::de::{Deserializer, IgnoredAny, Visitor};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
struct PgTextVisitor<'de>(&'de RawValue);
impl<'de> Visitor<'de> for PgTextVisitor<'de> {
type Value = Option<String>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}
// special care for nulls
fn visit_none<E>(self) -> Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
Ok(None)
}
// convert to text with escaping
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
Ok(Some(self.0.get().to_owned()))
}
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
Ok(Some(self.0.get().to_owned()))
}
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
Ok(Some(self.0.get().to_owned()))
}
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
Ok(Some(self.0.get().to_owned()))
}
fn visit_map<A: serde::de::MapAccess<'de>>(self, _: A) -> Result<Self::Value, A::Error> {
Ok(Some(self.0.get().to_owned()))
}
// avoid escaping here, as we pass this as a parameter
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E> {
Ok(Some(v))
}
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut output = String::new();
output.push('{');
let mut comma = false;
while let Some(val) = seq.next_element::<Value>()? {
if comma {
output.push(',');
}
comma = true;
let val = match val {
Value::Null => "NULL".to_string(),
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(),
v @ Value::Object(_) => Value::String(v.to_string()).to_string(),
// recurse into array
Value::Array(arr) => json_array_to_pg_array(&arr),
};
output.push_str(&val);
}
output.push('}');
Ok(Some(output))
}
}
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
json.into_iter()
.map(|raw| raw.deserialize_any(PgTextVisitor(&raw)).unwrap_or(None))
.map(|raw| {
match raw.get().as_bytes() {
// special handling for null.
b"null" => None,
// remove the escape characters from the string.
[b'"', ..] => {
Some(String::deserialize(&*raw).expect("json should be a valid string"))
}
[b'[', ..] => {
let mut output = String::with_capacity(raw.get().len());
raw.deserialize_seq(PgArrayVisitor(&raw, &mut output))
.expect("json should be a valid");
Some(output)
}
// write all other values out directly
_ => Some(<Box<str>>::from(raw).into()),
}
})
.collect()
}
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(arr: &[Value]) -> String {
let mut output = String::new();
output.push('{');
for val in arr {
if output.len() > 1 {
output.push(',');
}
struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String);
let val = match val {
Value::Null => "NULL".to_string(),
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => v.to_string(),
v @ Value::Object(_) => Value::String(v.to_string()).to_string(),
// recurse into array
Value::Array(arr) => json_array_to_pg_array(arr),
};
output.push_str(&val);
impl PgArrayVisitor<'_, '_> {
#[inline]
#[allow(clippy::unnecessary_wraps)]
fn raw<E>(self) -> Result<(), E> {
self.1.push_str(self.0.get());
Ok(())
}
}
impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}
// special care for nulls
fn visit_none<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
// convert to text with escaping
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
self.raw()
}
// an object needs re-escaping
fn visit_map<A: serde::de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
let s = serde_json::to_string(self.0.get()).expect("a string should be valid json");
self.1.push_str(&s);
Ok(())
}
// write an array
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
self.1.push('{');
let mut comma = false;
while let Some(val) = seq.next_element::<&'de RawValue>()? {
if comma {
self.1.push(',');
}
comma = true;
val.deserialize_any(PgArrayVisitor(val, self.1))
.expect("all json values are valid");
}
self.1.push('}');
Ok(())
}
output.push('}');
output
}
#[derive(Debug, thiserror::Error)]