From b5ad693a87fea183230889b4a965efe3be5fdad1 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 17 Sep 2024 10:57:50 +0100 Subject: [PATCH] some start to using arenas --- proxy/src/serverless/json.rs | 186 +++++++++++++++++++++++--- proxy/src/serverless/sql_over_http.rs | 44 ++++-- 2 files changed, 205 insertions(+), 25 deletions(-) diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 6362cc4e39..49fa6bd346 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,7 +1,10 @@ use std::fmt; +use std::marker::PhantomData; +use std::ops::Range; use itertools::Itertools; use serde::de; +use serde::de::DeserializeSeed; use serde::Deserialize; use serde::Deserializer; use serde_json::Map; @@ -14,12 +17,115 @@ use super::sql_over_http::BatchQueryData; use super::sql_over_http::Payload; use super::sql_over_http::QueryData; +#[derive(Clone, Copy)] +pub struct Slice { + pub start: u32, + pub len: u32, +} + +impl Slice { + pub fn into_range(self) -> Range { + let start = self.start as usize; + let end = start + self.len as usize; + start..end + } +} + +#[derive(Default)] +pub struct Arena { + pub str_arena: String, + pub params_arena: Vec>, +} + +impl Arena { + fn alloc_str(&mut self, s: &str) -> Slice { + let start = self.str_arena.len() as u32; + let len = s.len() as u32; + self.str_arena.push_str(s); + Slice { start, len } + } +} + +pub struct SerdeArena<'a, T> { + pub arena: &'a mut Arena, + pub _t: PhantomData, +} + +impl<'a, T> SerdeArena<'a, T> { + fn alloc_str(&mut self, s: &str) -> Slice { + self.arena.alloc_str(s) + } +} + +impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Vec> { + type Value = Vec; + fn deserialize(self, d: D) -> Result + where + D: Deserializer<'de>, + { + struct VecVisitor<'a>(SerdeArena<'a, Vec>); + + impl<'a, 'de> de::Visitor<'de> for VecVisitor<'a> { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut values = Vec::new(); + + while let Some(value) = seq.next_element_seed(SerdeArena { + arena: &mut *self.0.arena, + _t: PhantomData::, + })? { + values.push(value); + } + + Ok(values) + } + } + + d.deserialize_seq(VecVisitor(self)) + } +} + +impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Slice> { + type Value = Slice; + fn deserialize(self, d: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor<'a>(SerdeArena<'a, Slice>); + + impl<'a, 'de> de::Visitor<'de> for Visitor<'a> { + type Value = Slice; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string") + } + + fn visit_str(mut self, v: &str) -> Result + where + E: de::Error, + { + Ok(self.0.alloc_str(v)) + } + } + + d.deserialize_seq(Visitor(self)) + } +} + enum States { Empty, HasQueries(Vec), HasPartialQueryData { - query: Option, - params: Option>>, + query: Option, + params: Option, #[allow(clippy::option_option)] array_mode: Option>, }, @@ -73,13 +179,14 @@ impl<'de> Deserialize<'de> for Field { } } -impl<'de> Deserialize<'de> for QueryData { - fn deserialize(d: D) -> Result +impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, QueryData> { + type Value = QueryData; + fn deserialize(self, d: D) -> Result where D: Deserializer<'de>, { - struct Visitor; - impl<'de> de::Visitor<'de> for Visitor { + struct Visitor<'a>(SerdeArena<'a, QueryData>); + impl<'a, 'de> de::Visitor<'de> for Visitor<'a> { type Value = QueryData; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str( @@ -109,7 +216,10 @@ impl<'de> Deserialize<'de> for QueryData { } => (params, array_mode), }; state = States::HasPartialQueryData { - query: Some(m.next_value()?), + query: Some(m.next_value_seed(SerdeArena { + arena: &mut *self.0.arena, + _t: PhantomData::, + })?), params, array_mode, }; @@ -129,9 +239,23 @@ impl<'de> Deserialize<'de> for QueryData { array_mode, } => (query, array_mode), }; + + let params = m.next_value::()?.value; + let start = self.0.arena.params_arena.len() as u32; + let len = params.len() as u32; + for param in params { + match param { + Some(s) => { + let s = self.0.arena.alloc_str(&s); + self.0.arena.params_arena.push(Some(s)); + } + None => self.0.arena.params_arena.push(None), + } + } + state = States::HasPartialQueryData { query, - params: Some(m.next_value::()?.value), + params: Some(Slice { start, len }), array_mode, }; } @@ -185,17 +309,23 @@ impl<'de> Deserialize<'de> for QueryData { } } - Deserializer::deserialize_struct(d, "QueryData", &["query", "params", "arrayMode"], Visitor) + Deserializer::deserialize_struct( + d, + "QueryData", + &["query", "params", "arrayMode"], + Visitor(self), + ) } } -impl<'de> Deserialize<'de> for Payload { - fn deserialize(d: D) -> Result +impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Payload> { + type Value = Payload; + fn deserialize(self, d: D) -> Result where D: Deserializer<'de>, { - struct Visitor; - impl<'de> de::Visitor<'de> for Visitor { + struct Visitor<'a>(SerdeArena<'a, Payload>); + impl<'a, 'de> de::Visitor<'de> for Visitor<'a> { type Value = Payload; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str( @@ -212,7 +342,12 @@ impl<'de> Deserialize<'de> for Payload { while let Some(key) = m.next_key()? { match key { Field::Queries => match state { - States::Empty => state = States::HasQueries(m.next_value()?), + States::Empty => { + state = States::HasQueries(m.next_value_seed(SerdeArena { + arena: &mut *self.0.arena, + _t: PhantomData::>, + })?); + } States::HasQueries(_) => { return Err(::duplicate_field("queries")) } @@ -242,7 +377,10 @@ impl<'de> Deserialize<'de> for Payload { } => (params, array_mode), }; state = States::HasPartialQueryData { - query: Some(m.next_value()?), + query: Some(m.next_value_seed(SerdeArena { + arena: &mut *self.0.arena, + _t: PhantomData::, + })?), params, array_mode, }; @@ -267,9 +405,23 @@ impl<'de> Deserialize<'de> for Payload { array_mode, } => (query, array_mode), }; + + let params = m.next_value::()?.value; + let start = self.0.arena.params_arena.len() as u32; + let len = params.len() as u32; + for param in params { + match param { + Some(s) => { + let s = self.0.arena.alloc_str(&s); + self.0.arena.params_arena.push(Some(s)); + } + None => self.0.arena.params_arena.push(None), + } + } + state = States::HasPartialQueryData { query, - params: Some(m.next_value::()?.value), + params: Some(Slice { start, len }), array_mode, }; } @@ -332,7 +484,7 @@ impl<'de> Deserialize<'de> for Payload { d, "Payload", &["queries", "query", "params", "arrayMode"], - Visitor, + Visitor(self), ) } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 4cfc342a11..55bf75cf76 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -56,6 +56,8 @@ use crate::metrics::HttpDirection; use crate::metrics::Metrics; use crate::proxy::run_until_cancelled; use crate::proxy::NeonOptions; +use crate::serverless::json::Arena; +use crate::serverless::json::SerdeArena; use crate::usage_metrics::MetricCounterRecorder; use crate::DbName; use crate::RoleName; @@ -70,10 +72,11 @@ use super::conn_pool::ConnInfoWithAuth; use super::http_util::json_response; use super::json::pg_text_row_to_json; use super::json::JsonConversionError; +use super::json::Slice; pub(crate) struct QueryData { - pub(crate) query: String, - pub(crate) params: Vec>, + pub(crate) query: Slice, + pub(crate) params: Slice, pub(crate) array_mode: Option, } @@ -604,9 +607,15 @@ async fn handle_db_inner( )); } + let mut arena = Arena::default(); + let fetch_and_process_request = Box::pin(async { + let seed = SerdeArena { + arena: &mut arena, + _t: PhantomData::, + }; let payload = parse_json_body_with_limit( - PhantomData, + seed, request.into_body(), config.http_config.max_request_size_bytes as usize, ) @@ -676,7 +685,7 @@ async fn handle_db_inner( // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { - stmt.process(config, cancel, &mut client, parsed_headers) + stmt.process(config, &arena, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { @@ -694,7 +703,7 @@ async fn handle_db_inner( } statements - .process(config, cancel, &mut client, parsed_headers) + .process(config, &arena, cancel, &mut client, parsed_headers) .await? } }; @@ -786,6 +795,7 @@ impl QueryData { async fn process( self, config: &'static ProxyConfig, + arena: &Arena, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -794,7 +804,14 @@ impl QueryData { let cancel_token = inner.cancel_token(); let res = match select( - pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)), + pin!(query_to_json( + config, + arena, + &*inner, + self, + &mut 0, + parsed_headers + )), pin!(cancel.cancelled()), ) .await @@ -860,6 +877,7 @@ impl BatchQueryData { async fn process( self, config: &'static ProxyConfig, + arena: &Arena, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -886,6 +904,7 @@ impl BatchQueryData { let json_output = match query_batch( config, + arena, cancel.child_token(), &transaction, self, @@ -930,6 +949,7 @@ impl BatchQueryData { async fn query_batch( config: &'static ProxyConfig, + arena: &Arena, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, @@ -940,6 +960,7 @@ async fn query_batch( for stmt in queries.queries { let query = pin!(query_to_json( config, + arena, transaction, stmt, &mut current_size, @@ -969,14 +990,21 @@ async fn query_batch( async fn query_to_json( config: &'static ProxyConfig, + arena: &Arena, client: &T, data: QueryData, current_size: &mut usize, parsed_headers: HttpHeaders, ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> { info!("executing query"); - let query_params = data.params; - let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?); + + let query_params = arena.params_arena[data.params.into_range()] + .iter() + .map(|p| p.map(|p| &arena.str_arena[p.into_range()])); + + let query = &arena.str_arena[data.query.into_range()]; + + let mut row_stream = std::pin::pin!(client.query_raw_txt(query, query_params).await?); info!("finished executing query"); // Manually drain the stream into a vector to leave row_stream hanging