mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-18 19:02:56 +00:00
Compare commits
6 Commits
release
...
proxy-http
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe8b93ab9d | ||
|
|
7e3e7f1cca | ||
|
|
0b0ed662d9 | ||
|
|
50bd65769f | ||
|
|
90534b1745 | ||
|
|
99d52df475 |
@@ -1,30 +1,40 @@
|
||||
use itertools::Itertools;
|
||||
use serde_json::value::RawValue;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use tokio_postgres::types::Kind;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::Row;
|
||||
use typed_json::json;
|
||||
|
||||
use super::json_raw_value::LazyValue;
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
// as parameters.
|
||||
//
|
||||
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
|
||||
json.iter().map(json_value_to_pg_text).collect()
|
||||
pub(crate) fn json_to_pg_text(
|
||||
json: &[&RawValue],
|
||||
) -> Result<Vec<Option<String>>, serde_json::Error> {
|
||||
json.iter().copied().map(json_value_to_pg_text).try_collect()
|
||||
}
|
||||
|
||||
fn json_value_to_pg_text(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
fn json_value_to_pg_text(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
|
||||
let lazy_value = serde_json::from_str(value.get())?;
|
||||
match lazy_value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
LazyValue::Null => Ok(None),
|
||||
|
||||
// convert to text with escaping
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
|
||||
LazyValue::Bool | LazyValue::Number | LazyValue::Object => {
|
||||
Ok(Some(value.get().to_string()))
|
||||
}
|
||||
|
||||
// avoid escaping here, as we pass this as a parameter
|
||||
Value::String(s) => Some(s.to_string()),
|
||||
LazyValue::String(s) => Ok(Some(s.into_owned())),
|
||||
|
||||
// special care for arrays
|
||||
Value::Array(_) => json_array_to_pg_array(value),
|
||||
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,27 +46,42 @@ fn json_value_to_pg_text(value: &Value) -> Option<String> {
|
||||
//
|
||||
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
|
||||
//
|
||||
fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
fn json_array_to_pg_array(arr: Vec<&RawValue>) -> Result<String, serde_json::Error> {
|
||||
let mut output = String::new();
|
||||
let mut first = true;
|
||||
|
||||
output.push('{');
|
||||
|
||||
for value in arr {
|
||||
if !first {
|
||||
output.push(',');
|
||||
}
|
||||
first = false;
|
||||
|
||||
let value = json_array_to_pg_array_inner(value)?;
|
||||
output.push_str(value.as_deref().unwrap_or("NULL"));
|
||||
}
|
||||
|
||||
output.push('}');
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn json_array_to_pg_array_inner(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
|
||||
let lazy_value = serde_json::from_str(value.get())?;
|
||||
match lazy_value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
LazyValue::Null => Ok(None),
|
||||
|
||||
// convert to text with escaping
|
||||
// here string needs to be escaped, as it is part of the array
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
|
||||
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
|
||||
LazyValue::Bool | LazyValue::Number | LazyValue::String(_) => {
|
||||
Ok(Some(value.get().to_string()))
|
||||
}
|
||||
LazyValue::Object => Ok(Some(json!(value.get().to_string()).to_string())),
|
||||
|
||||
// recurse into array
|
||||
Value::Array(arr) => {
|
||||
let vals = arr
|
||||
.iter()
|
||||
.map(json_array_to_pg_array)
|
||||
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
Some(format!("{{{vals}}}"))
|
||||
}
|
||||
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,25 +284,31 @@ mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn json_to_pg_text_test(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
|
||||
let json = serde_json::Value::Array(json).to_string();
|
||||
let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap();
|
||||
json_to_pg_text(&json).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_types_to_pg_params() {
|
||||
let json = vec![Value::Bool(true), Value::Bool(false)];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
let pg_params = json_to_pg_text_test(json);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some("true".to_owned()), Some("false".to_owned())]
|
||||
);
|
||||
|
||||
let json = vec![Value::Number(serde_json::Number::from(42))];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
let pg_params = json_to_pg_text_test(json);
|
||||
assert_eq!(pg_params, vec![Some("42".to_owned())]);
|
||||
|
||||
let json = vec![Value::String("foo\"".to_string())];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
let pg_params = json_to_pg_text_test(json);
|
||||
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
|
||||
|
||||
let json = vec![Value::Null];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
let pg_params = json_to_pg_text_test(json);
|
||||
assert_eq!(pg_params, vec![None]);
|
||||
}
|
||||
|
||||
@@ -286,7 +317,7 @@ mod tests {
|
||||
// atoms and escaping
|
||||
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
let pg_params = json_to_pg_text_test(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(
|
||||
@@ -297,7 +328,7 @@ mod tests {
|
||||
// nested arrays
|
||||
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
let pg_params = json_to_pg_text_test(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(
|
||||
@@ -307,7 +338,7 @@ mod tests {
|
||||
// array of objects
|
||||
let json = r#"[{"foo": 1},{"bar": 2}]"#;
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
let pg_params = json_to_pg_text_test(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
|
||||
|
||||
193
proxy/src/serverless/json_raw_value.rs
Normal file
193
proxy/src/serverless/json_raw_value.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! [`serde_json::Value`] but uses RawValue internally
|
||||
//!
|
||||
//! This code forks from the serde_json code, but replaces internal Value with RawValue where possible.
|
||||
//!
|
||||
//! Taken from <https://github.com/serde-rs/json/blob/faab2e8d2fcf781a3f77f329df836ffb3aaacfba/src/value/de.rs>
|
||||
//! Licensed from serde-rs under MIT or APACHE-2.0, with modifications by Conrad Ludgate
|
||||
|
||||
use core::fmt;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use serde::{
|
||||
de::{IgnoredAny, MapAccess, SeqAccess, Visitor},
|
||||
Deserialize,
|
||||
};
|
||||
use serde_json::value::RawValue;
|
||||
|
||||
pub enum LazyValue<'de> {
|
||||
Null,
|
||||
Bool,
|
||||
Number,
|
||||
String(Cow<'de, str>),
|
||||
Array(Vec<&'de RawValue>),
|
||||
Object,
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for LazyValue<'de> {
|
||||
#[inline]
|
||||
fn deserialize<D>(deserializer: D) -> Result<LazyValue<'de>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct ValueVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ValueVisitor {
|
||||
type Value = LazyValue<'de>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("any valid JSON value")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_bool<E>(self, _value: bool) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Bool)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_i64<E>(self, _value: i64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_u64<E>(self, _value: u64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_f64<E>(self, _value: f64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_str<E>(self, value: &str) -> Result<LazyValue<'de>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
self.visit_string(String::from(value))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_borrowed_str<E>(self, value: &'de str) -> Result<LazyValue<'de>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(LazyValue::String(Cow::Borrowed(value)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_string<E>(self, value: String) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::String(Cow::Owned(value)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_none<E>(self) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Null)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_some<D>(self, deserializer: D) -> Result<LazyValue<'de>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Deserialize::deserialize(deserializer)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_unit<E>(self) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Null)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_seq<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
|
||||
where
|
||||
V: SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
|
||||
while let Some(elem) = visitor.next_element()? {
|
||||
vec.push(elem);
|
||||
}
|
||||
|
||||
Ok(LazyValue::Array(vec))
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
while visitor.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
|
||||
Ok(LazyValue::Object)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(ValueVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::borrow::Cow;
|
||||
|
||||
use typed_json::json;
|
||||
|
||||
use super::LazyValue;
|
||||
|
||||
#[test]
|
||||
fn object() {
|
||||
let json = json! {{
|
||||
"foo": {
|
||||
"bar": 1
|
||||
},
|
||||
"baz": [2, 3],
|
||||
}}
|
||||
.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::Object = lazy else {
|
||||
panic!("expected object")
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array() {
|
||||
let json = json! {[
|
||||
{
|
||||
"bar": 1
|
||||
},
|
||||
[2, 3],
|
||||
]}
|
||||
.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::Array(array) = lazy else {
|
||||
panic!("expected array")
|
||||
};
|
||||
assert_eq!(array.len(), 2);
|
||||
|
||||
assert_eq!(array[0].get(), r#"{"bar":1}"#);
|
||||
assert_eq!(array[1].get(), r#"[2,3]"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn string() {
|
||||
let json = json! { "hello world" }.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::String(Cow::Borrowed(string)) = lazy else {
|
||||
panic!("expected borrowed string")
|
||||
};
|
||||
assert_eq!(string, "hello world");
|
||||
|
||||
let json = json! { "hello \n world" }.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::String(Cow::Owned(string)) = lazy else {
|
||||
panic!("expected owned string")
|
||||
};
|
||||
assert_eq!(string, "hello \n world");
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod json_raw_value;
|
||||
mod local_conn_pool;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::borrow::Cow;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -22,7 +23,7 @@ use hyper::StatusCode;
|
||||
use hyper::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use tokio::time;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::error::ErrorPosition;
|
||||
@@ -76,24 +77,28 @@ use super::local_conn_pool;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
params: Vec<Option<String>>,
|
||||
#[serde(bound = "'de: 'a")]
|
||||
struct QueryData<'a> {
|
||||
#[serde(borrow)]
|
||||
query: Cow<'a, str>,
|
||||
|
||||
#[serde(borrow)]
|
||||
params: Vec<&'a RawValue>,
|
||||
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct BatchQueryData {
|
||||
queries: Vec<QueryData>,
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(bound = "'de: 'a")]
|
||||
struct BatchQueryData<'a> {
|
||||
queries: Vec<QueryData<'a>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Payload {
|
||||
Single(QueryData),
|
||||
Batch(BatchQueryData),
|
||||
enum Payload<'a> {
|
||||
Batch(BatchQueryData<'a>),
|
||||
Single(QueryData<'a>),
|
||||
}
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
@@ -106,13 +111,18 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
fn parse_pg_params(params: &[&RawValue]) -> Result<Vec<Option<String>>, ReadPayloadError> {
|
||||
json_to_pg_text(params).map_err(ReadPayloadError::Parse)
|
||||
}
|
||||
|
||||
fn parse_payload(body: &[u8]) -> Result<Payload<'_>, ReadPayloadError> {
|
||||
// RawValue doesn't work via untagged enums
|
||||
// so instead we try parse each individually
|
||||
if let Ok(batch) = serde_json::from_slice(body) {
|
||||
Ok(Payload::Batch(batch))
|
||||
} else {
|
||||
Ok(Payload::Single(serde_json::from_slice(body)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -615,8 +625,7 @@ async fn handle_db_inner(
|
||||
async {
|
||||
let body = request.into_body().collect().await?.to_bytes();
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
Ok::<Bytes, ReadPayloadError>(body)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
@@ -660,7 +669,7 @@ async fn handle_db_inner(
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
let (body, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
try_join(
|
||||
pin!(fetch_and_process_request),
|
||||
@@ -674,6 +683,8 @@ async fn handle_db_inner(
|
||||
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
|
||||
};
|
||||
|
||||
let payload = parse_payload(&body)?;
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
@@ -781,7 +792,7 @@ async fn handle_auth_broker_inner(
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
impl QueryData<'_> {
|
||||
async fn process(
|
||||
self,
|
||||
config: &'static HttpConfig,
|
||||
@@ -855,7 +866,7 @@ impl QueryData {
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchQueryData {
|
||||
impl BatchQueryData<'_> {
|
||||
async fn process(
|
||||
self,
|
||||
config: &'static HttpConfig,
|
||||
@@ -931,7 +942,7 @@ async fn query_batch(
|
||||
config: &'static HttpConfig,
|
||||
cancel: CancellationToken,
|
||||
transaction: &Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
queries: BatchQueryData<'_>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let mut results = Vec::with_capacity(queries.queries.len());
|
||||
@@ -969,12 +980,12 @@ async fn query_batch(
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
config: &'static HttpConfig,
|
||||
client: &T,
|
||||
data: QueryData,
|
||||
data: QueryData<'_>,
|
||||
current_size: &mut usize,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
|
||||
info!("executing query");
|
||||
let query_params = data.params;
|
||||
let query_params = parse_pg_params(&data.params)?;
|
||||
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
|
||||
info!("finished executing query");
|
||||
|
||||
@@ -1100,3 +1111,41 @@ impl Discard<'_> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use typed_json::json;
|
||||
|
||||
use super::parse_payload;
|
||||
use super::Payload;
|
||||
|
||||
#[test]
|
||||
fn raw_single_payload() {
|
||||
let body = json! {
|
||||
{"query":"select $1","params":["1"]}
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let Payload::Single(query) = parse_payload(body.as_bytes()).unwrap() else {
|
||||
panic!("expected single")
|
||||
};
|
||||
assert_eq!(&*query.query, "select $1");
|
||||
assert_eq!(query.params[0].get(), "\"1\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_batch_payload() {
|
||||
let body = json! {{
|
||||
"queries": [
|
||||
{"query":"select $1","params":["1"]},
|
||||
{"query":"select $1","params":["2"]},
|
||||
]
|
||||
}}
|
||||
.to_string();
|
||||
|
||||
let Payload::Batch(query) = parse_payload(body.as_bytes()).unwrap() else {
|
||||
panic!("expected batch")
|
||||
};
|
||||
assert_eq!(query.queries.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user