some start to using arenas

This commit is contained in:
Conrad Ludgate
2024-09-17 10:57:50 +01:00
parent 5a9138a764
commit b5ad693a87
2 changed files with 205 additions and 25 deletions

View File

@@ -1,7 +1,10 @@
use std::fmt;
use std::marker::PhantomData;
use std::ops::Range;
use itertools::Itertools;
use serde::de;
use serde::de::DeserializeSeed;
use serde::Deserialize;
use serde::Deserializer;
use serde_json::Map;
@@ -14,12 +17,115 @@ use super::sql_over_http::BatchQueryData;
use super::sql_over_http::Payload;
use super::sql_over_http::QueryData;
#[derive(Clone, Copy)]
pub struct Slice {
pub start: u32,
pub len: u32,
}
impl Slice {
pub fn into_range(self) -> Range<usize> {
let start = self.start as usize;
let end = start + self.len as usize;
start..end
}
}
#[derive(Default)]
pub struct Arena {
pub str_arena: String,
pub params_arena: Vec<Option<Slice>>,
}
impl Arena {
fn alloc_str(&mut self, s: &str) -> Slice {
let start = self.str_arena.len() as u32;
let len = s.len() as u32;
self.str_arena.push_str(s);
Slice { start, len }
}
}
pub struct SerdeArena<'a, T> {
pub arena: &'a mut Arena,
pub _t: PhantomData<T>,
}
impl<'a, T> SerdeArena<'a, T> {
fn alloc_str(&mut self, s: &str) -> Slice {
self.arena.alloc_str(s)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Vec<QueryData>> {
type Value = Vec<QueryData>;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor<'a>(SerdeArena<'a, Vec<QueryData>>);
impl<'a, 'de> de::Visitor<'de> for VecVisitor<'a> {
type Value = Vec<QueryData>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut values = Vec::new();
while let Some(value) = seq.next_element_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<QueryData>,
})? {
values.push(value);
}
Ok(values)
}
}
d.deserialize_seq(VecVisitor(self))
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Slice> {
type Value = Slice;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, Slice>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Slice;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E>(mut self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(self.0.alloc_str(v))
}
}
d.deserialize_seq(Visitor(self))
}
}
enum States {
Empty,
HasQueries(Vec<QueryData>),
HasPartialQueryData {
query: Option<String>,
params: Option<Vec<Option<String>>>,
query: Option<Slice>,
params: Option<Slice>,
#[allow(clippy::option_option)]
array_mode: Option<Option<bool>>,
},
@@ -73,13 +179,14 @@ impl<'de> Deserialize<'de> for Field {
}
}
impl<'de> Deserialize<'de> for QueryData {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, QueryData> {
type Value = QueryData;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
struct Visitor<'a>(SerdeArena<'a, QueryData>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = QueryData;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
@@ -109,7 +216,10 @@ impl<'de> Deserialize<'de> for QueryData {
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value()?),
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
@@ -129,9 +239,23 @@ impl<'de> Deserialize<'de> for QueryData {
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(m.next_value::<PgText>()?.value),
params: Some(Slice { start, len }),
array_mode,
};
}
@@ -185,17 +309,23 @@ impl<'de> Deserialize<'de> for QueryData {
}
}
Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor)
Deserializer::deserialize_struct(
d,
"QueryData",
&["query", "params", "arrayMode"],
Visitor(self),
)
}
}
impl<'de> Deserialize<'de> for Payload {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Payload> {
type Value = Payload;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
struct Visitor<'a>(SerdeArena<'a, Payload>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Payload;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
@@ -212,7 +342,12 @@ impl<'de> Deserialize<'de> for Payload {
while let Some(key) = m.next_key()? {
match key {
Field::Queries => match state {
States::Empty => state = States::HasQueries(m.next_value()?),
States::Empty => {
state = States::HasQueries(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Vec<QueryData>>,
})?);
}
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::duplicate_field("queries"))
}
@@ -242,7 +377,10 @@ impl<'de> Deserialize<'de> for Payload {
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value()?),
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
@@ -267,9 +405,23 @@ impl<'de> Deserialize<'de> for Payload {
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(m.next_value::<PgText>()?.value),
params: Some(Slice { start, len }),
array_mode,
};
}
@@ -332,7 +484,7 @@ impl<'de> Deserialize<'de> for Payload {
d,
"Payload",
&["queries", "query", "params", "arrayMode"],
Visitor,
Visitor(self),
)
}
}

View File

@@ -56,6 +56,8 @@ use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::json::Arena;
use crate::serverless::json::SerdeArena;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
@@ -70,10 +72,11 @@ use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::json::Slice;
pub(crate) struct QueryData {
pub(crate) query: String,
pub(crate) params: Vec<Option<String>>,
pub(crate) query: Slice,
pub(crate) params: Slice,
pub(crate) array_mode: Option<bool>,
}
@@ -604,9 +607,15 @@ async fn handle_db_inner(
));
}
let mut arena = Arena::default();
let fetch_and_process_request = Box::pin(async {
let seed = SerdeArena {
arena: &mut arena,
_t: PhantomData::<Payload>,
};
let payload = parse_json_body_with_limit(
PhantomData,
seed,
request.into_body(),
config.http_config.max_request_size_bytes as usize,
)
@@ -676,7 +685,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(config, &arena, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -694,7 +703,7 @@ async fn handle_db_inner(
}
statements
.process(config, cancel, &mut client, parsed_headers)
.process(config, &arena, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -786,6 +795,7 @@ impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -794,7 +804,14 @@ impl QueryData {
let cancel_token = inner.cancel_token();
let res = match select(
pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
pin!(query_to_json(
config,
arena,
&*inner,
self,
&mut 0,
parsed_headers
)),
pin!(cancel.cancelled()),
)
.await
@@ -860,6 +877,7 @@ impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -886,6 +904,7 @@ impl BatchQueryData {
let json_output = match query_batch(
config,
arena,
cancel.child_token(),
&transaction,
self,
@@ -930,6 +949,7 @@ impl BatchQueryData {
async fn query_batch(
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -940,6 +960,7 @@ async fn query_batch(
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
arena,
transaction,
stmt,
&mut current_size,
@@ -969,14 +990,21 @@ async fn query_batch(
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
arena: &Arena,
client: &T,
data: QueryData,
current_size: &mut usize,
parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
info!("executing query");
let query_params = data.params;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
let query_params = arena.params_arena[data.params.into_range()]
.iter()
.map(|p| p.map(|p| &arena.str_arena[p.into_range()]));
let query = &arena.str_arena[data.query.into_range()];
let mut row_stream = std::pin::pin!(client.query_raw_txt(query, query_params).await?);
info!("finished executing query");
// Manually drain the stream into a vector to leave row_stream hanging