fix Payload deser

This commit is contained in:
Conrad Ludgate
2024-10-14 14:02:46 +01:00
parent 7e3e7f1cca
commit fe8b93ab9d
2 changed files with 78 additions and 29 deletions

View File

@@ -14,9 +14,9 @@ use super::json_raw_value::LazyValue;
// as parameters.
//
pub(crate) fn json_to_pg_text(
json: Vec<&RawValue>,
json: &[&RawValue],
) -> Result<Vec<Option<String>>, serde_json::Error> {
json.into_iter().map(json_value_to_pg_text).try_collect()
json.iter().copied().map(json_value_to_pg_text).try_collect()
}
fn json_value_to_pg_text(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
@@ -287,7 +287,7 @@ mod tests {
fn json_to_pg_text_test(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = serde_json::Value::Array(json).to_string();
let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap();
json_to_pg_text(json).unwrap()
json_to_pg_text(&json).unwrap()
}
#[test]

View File

@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::pin::pin;
use std::sync::Arc;
@@ -76,24 +77,28 @@ use super::local_conn_pool;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(bound = "'de: 'a")]
struct QueryData<'a> {
#[serde(borrow)]
query: Cow<'a, str>,
#[serde(borrow)]
params: Vec<&'a RawValue>,
#[serde(default)]
array_mode: Option<bool>,
}
#[derive(serde::Deserialize)]
struct BatchQueryData {
queries: Vec<QueryData>,
#[serde(rename_all = "camelCase")]
#[serde(bound = "'de: 'a")]
struct BatchQueryData<'a> {
queries: Vec<QueryData<'a>>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
enum Payload<'a> {
Batch(BatchQueryData<'a>),
Single(QueryData<'a>),
}
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
@@ -106,13 +111,18 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<&RawValue> = serde::de::Deserialize::deserialize(deserializer)?;
json_to_pg_text(json).map_err(serde::de::Error::custom)
fn parse_pg_params(params: &[&RawValue]) -> Result<Vec<Option<String>>, ReadPayloadError> {
json_to_pg_text(params).map_err(ReadPayloadError::Parse)
}
fn parse_payload(body: &[u8]) -> Result<Payload<'_>, ReadPayloadError> {
// RawValue doesn't work via untagged enums
// so instead we try parse each individually
if let Ok(batch) = serde_json::from_slice(body) {
Ok(Payload::Batch(batch))
} else {
Ok(Payload::Single(serde_json::from_slice(body)?))
}
}
#[derive(Debug, thiserror::Error)]
@@ -615,8 +625,7 @@ async fn handle_db_inner(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
Ok::<Bytes, ReadPayloadError>(body)
}
.map_err(SqlOverHttpError::from),
);
@@ -660,7 +669,7 @@ async fn handle_db_inner(
.map_err(SqlOverHttpError::from),
);
let (payload, mut client) = match run_until_cancelled(
let (body, mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
@@ -674,6 +683,8 @@ async fn handle_db_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
let payload = parse_payload(&body)?;
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
@@ -781,7 +792,7 @@ async fn handle_auth_broker_inner(
.map(|b| b.boxed()))
}
impl QueryData {
impl QueryData<'_> {
async fn process(
self,
config: &'static HttpConfig,
@@ -855,7 +866,7 @@ impl QueryData {
}
}
impl BatchQueryData {
impl BatchQueryData<'_> {
async fn process(
self,
config: &'static HttpConfig,
@@ -931,7 +942,7 @@ async fn query_batch(
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
queries: BatchQueryData<'_>,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let mut results = Vec::with_capacity(queries.queries.len());
@@ -969,12 +980,12 @@ async fn query_batch(
async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig,
client: &T,
data: QueryData,
data: QueryData<'_>,
current_size: &mut usize,
parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
info!("executing query");
let query_params = data.params;
let query_params = parse_pg_params(&data.params)?;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
info!("finished executing query");
@@ -1100,3 +1111,41 @@ impl Discard<'_> {
}
}
}
#[cfg(test)]
mod tests {
use typed_json::json;
use super::parse_payload;
use super::Payload;
#[test]
fn raw_single_payload() {
let body = json! {
{"query":"select $1","params":["1"]}
}
.to_string();
let Payload::Single(query) = parse_payload(body.as_bytes()).unwrap() else {
panic!("expected single")
};
assert_eq!(&*query.query, "select $1");
assert_eq!(query.params[0].get(), "\"1\"");
}
#[test]
fn raw_batch_payload() {
let body = json! {{
"queries": [
{"query":"select $1","params":["1"]},
{"query":"select $1","params":["2"]},
]
}}
.to_string();
let Payload::Batch(query) = parse_payload(body.as_bytes()).unwrap() else {
panic!("expected batch")
};
assert_eq!(query.queries.len(), 2);
}
}