proxy/http: switch to typed_json (#8377)

## Summary of changes

This switches the JSON rendering logic from `serde_json::json!` to
`typed_json` in order to reduce the number of allocations in the HTTP
responder path.

Follow-up from
https://github.com/neondatabase/neon/pull/8319#issuecomment-2216991760.
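
For context, `serde_json::json!` eagerly builds a `serde_json::Value`
tree, heap-allocating a node for every object, array, and string it
contains, while `typed_json::json!` returns an anonymous type that
implements `serde::Serialize` and writes the JSON out in a single pass
only when it is actually serialized. A minimal sketch of the difference
(illustrative values only; assumes `typed-json = "0.1"` as pinned
below):

```rust
fn main() {
    // serde_json::json! allocates an intermediate Value tree up front.
    let tree: serde_json::Value = serde_json::json!({
        "command": "SELECT",
        "rowCount": 1,
        "rows": [["42"]],
    });
    let body_a = serde_json::to_string(&tree).expect("serialization should not fail");

    // typed_json::json! builds no tree: it returns an opaque
    // `impl serde::Serialize`, so the string is produced in one pass.
    let lazy = typed_json::json!({
        "command": "SELECT",
        "rowCount": 1,
        "rows": [["42"]],
    });
    let body_b = serde_json::to_string(&lazy).expect("serialization should not fail");

    // Both render the same JSON; only the allocation pattern differs.
    assert_eq!(body_a, body_b);
}
```

This is also why `query_to_json` below can return `impl Serialize`
rather than a `Value`, deferring the single `serde_json::to_string`
call to the point where the response body is assembled.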

---------

Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
Authored by Luca Bruno on 2024-07-15 13:38:52 +02:00, committed by GitHub.
Commit 8da3b547f8 (parent b329b1c610); 4 changed files with 59 additions and 51 deletions.

Cargo.lock (generated)

@@ -4404,6 +4404,7 @@ dependencies = [
  "tracing-opentelemetry",
  "tracing-subscriber",
  "tracing-utils",
+ "typed-json",
  "url",
  "urlencoding",
  "utils",
@@ -6665,6 +6666,16 @@ dependencies = [
  "static_assertions",
 ]
 
+[[package]]
+name = "typed-json"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6024a8d0025400b3f6b189366e9aa92012cf9c4fe1cd2620848dd61425c49eed"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "typenum"
 version = "1.16.0"

Cargo.toml

@@ -184,6 +184,7 @@ tracing-error = "0.2.0"
 tracing-opentelemetry = "0.21.0"
 tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json", "ansi"] }
 twox-hash = { version = "1.6.3", default-features = false }
+typed-json = "0.1"
 url = "2.2"
 urlencoding = "2.1"
 uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }

proxy/Cargo.toml

@@ -92,6 +92,7 @@ tracing-opentelemetry.workspace = true
 tracing-subscriber.workspace = true
 tracing-utils.workspace = true
 tracing.workspace = true
+typed-json.workspace = true
 url.workspace = true
 urlencoding.workspace = true
 utils.workspace = true

proxy/src/serverless/sql_over_http.rs

@@ -18,7 +18,7 @@ use hyper1::Response;
 use hyper1::StatusCode;
 use hyper1::{HeaderMap, Request};
 use pq_proto::StartupMessageParamsBuilder;
-use serde_json::json;
+use serde::Serialize;
 use serde_json::Value;
 use tokio::time;
 use tokio_postgres::error::DbError;
@@ -32,6 +32,7 @@ use tokio_postgres::Transaction;
 use tokio_util::sync::CancellationToken;
 use tracing::error;
 use tracing::info;
+use typed_json::json;
 use url::Url;
 use utils::http::error::ApiError;
@@ -263,13 +264,8 @@ pub async fn handle(
         | SqlOverHttpError::Postgres(e) => e.as_db_error(),
         _ => None,
     };
-    fn get<'a, T: serde::Serialize>(
-        db: Option<&'a DbError>,
-        x: impl FnOnce(&'a DbError) -> T,
-    ) -> Value {
-        db.map(x)
-            .and_then(|t| serde_json::to_value(t).ok())
-            .unwrap_or_default()
+    fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
+        db.map(x).unwrap_or_default()
     }
     if let Some(db_error) = db_error {
@@ -278,17 +274,11 @@ pub async fn handle(
     let position = db_error.and_then(|db| db.position());
     let (position, internal_position, internal_query) = match position {
-        Some(ErrorPosition::Original(position)) => (
-            Value::String(position.to_string()),
-            Value::Null,
-            Value::Null,
-        ),
-        Some(ErrorPosition::Internal { position, query }) => (
-            Value::Null,
-            Value::String(position.to_string()),
-            Value::String(query.clone()),
-        ),
-        None => (Value::Null, Value::Null, Value::Null),
+        Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
+        Some(ErrorPosition::Internal { position, query }) => {
+            (None, Some(position.to_string()), Some(query.clone()))
+        }
+        None => (None, None, None),
     };
     let code = get(db_error, |db| db.code().code());
@@ -578,10 +568,8 @@ async fn handle_inner(
         .status(StatusCode::OK)
         .header(header::CONTENT_TYPE, "application/json");
-    //
-    // Now execute the query and return the result
-    //
-    let result = match payload {
+    // Now execute the query and return the result.
+    let json_output = match payload {
         Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
         Payload::Batch(statements) => {
             if parsed_headers.txn_read_only {
@@ -605,11 +593,9 @@
     let metrics = client.metrics();
-    // how could this possibly fail
-    let body = serde_json::to_string(&result).expect("json serialization should not fail");
-    let len = body.len();
+    let len = json_output.len();
     let response = response
-        .body(Full::new(Bytes::from(body)))
+        .body(Full::new(Bytes::from(json_output)))
         // only fails if invalid status code or invalid header/values are given.
         // these are not user configurable so it cannot fail dynamically
         .expect("building response payload should not fail");
@@ -631,7 +617,7 @@ impl QueryData {
         cancel: CancellationToken,
         client: &mut Client<tokio_postgres::Client>,
         parsed_headers: HttpHeaders,
-    ) -> Result<Value, SqlOverHttpError> {
+    ) -> Result<String, SqlOverHttpError> {
         let (inner, mut discard) = client.inner();
         let cancel_token = inner.cancel_token();
@@ -644,7 +630,10 @@
             // The query successfully completed.
             Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
                 discard.check_idle(status);
-                Ok(results)
+                let json_output =
+                    serde_json::to_string(&results).expect("json serialization should not fail");
+                Ok(json_output)
             }
             // The query failed with an error
             Either::Left((Err(e), __not_yet_cancelled)) => {
@@ -662,7 +651,10 @@
                 // query successed before it was cancelled.
                 Ok(Ok((status, results))) => {
                     discard.check_idle(status);
-                    Ok(results)
+                    let json_output = serde_json::to_string(&results)
+                        .expect("json serialization should not fail");
+                    Ok(json_output)
                 }
                 // query failed or was cancelled.
                 Ok(Err(error)) => {
@@ -696,7 +688,7 @@ impl BatchQueryData {
         cancel: CancellationToken,
         client: &mut Client<tokio_postgres::Client>,
         parsed_headers: HttpHeaders,
-    ) -> Result<Value, SqlOverHttpError> {
+    ) -> Result<String, SqlOverHttpError> {
         info!("starting transaction");
         let (inner, mut discard) = client.inner();
         let cancel_token = inner.cancel_token();
@@ -718,9 +710,9 @@
             e
         })?;
-        let results =
+        let json_output =
            match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
-                Ok(results) => {
+                Ok(json_output) => {
                     info!("commit");
                     let status = transaction.commit().await.map_err(|e| {
                         // if we cannot commit - for now don't return connection to pool
@@ -729,7 +721,7 @@
                         e
                     })?;
                     discard.check_idle(status);
-                    results
+                    json_output
                 }
                 Err(SqlOverHttpError::Cancelled(_)) => {
                     if let Err(err) = cancel_token.cancel_query(NoTls).await {
@@ -753,7 +745,7 @@
             }
         };
-        Ok(json!({ "results": results }))
+        Ok(json_output)
     }
 }
@@ -762,7 +754,7 @@ async fn query_batch(
     transaction: &Transaction<'_>,
     queries: BatchQueryData,
     parsed_headers: HttpHeaders,
-) -> Result<Vec<Value>, SqlOverHttpError> {
+) -> Result<String, SqlOverHttpError> {
     let mut results = Vec::with_capacity(queries.queries.len());
     let mut current_size = 0;
     for stmt in queries.queries {
@@ -787,7 +779,11 @@
             }
         }
     }
-    Ok(results)
+
+    let results = json!({ "results": results });
+    let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
+    Ok(json_output)
 }
 
 async fn query_to_json<T: GenericClient>(
@@ -795,7 +791,7 @@ async fn query_to_json<T: GenericClient>(
     data: QueryData,
     current_size: &mut usize,
     parsed_headers: HttpHeaders,
-) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
+) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
     info!("executing query");
     let query_params = data.params;
     let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
@@ -844,8 +840,8 @@
     for c in row_stream.columns() {
         fields.push(json!({
-            "name": Value::String(c.name().to_owned()),
-            "dataTypeID": Value::Number(c.type_().oid().into()),
+            "name": c.name().to_owned(),
+            "dataTypeID": c.type_().oid(),
             "tableID": c.table_oid(),
             "columnID": c.column_id(),
             "dataTypeSize": c.type_size(),
@@ -863,15 +859,14 @@
         .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
         .collect::<Result<Vec<_>, _>>()?;
-    // resulting JSON format is based on the format of node-postgres result
-    Ok((
-        ready,
-        json!({
-            "command": command_tag_name,
-            "rowCount": command_tag_count,
-            "rows": rows,
-            "fields": fields,
-            "rowAsArray": array_mode,
-        }),
-    ))
+    // Resulting JSON format is based on the format of node-postgres result.
+    let results = json!({
+        "command": command_tag_name.to_string(),
+        "rowCount": command_tag_count,
+        "rows": rows,
+        "fields": fields,
+        "rowAsArray": array_mode,
+    });
+    Ok((ready, results))
 }