From c5cd8577ff6d96e8153dd22af17373c4351e52e4 Mon Sep 17 00:00:00 2001
From: Folke Behrens
Date: Wed, 18 Sep 2024 13:58:51 +0200
Subject: [PATCH] proxy: make sql-over-http max request/response sizes
 configurable (#9029)

---
 proxy/src/bin/local_proxy.rs          |   8 ++
 proxy/src/bin/proxy.rs                |   8 ++
 proxy/src/config.rs                   |   2 +
 proxy/src/serverless/conn_pool.rs     |   2 +
 proxy/src/serverless/sql_over_http.rs | 114 +++++++++++++++-----------
 5 files changed, 85 insertions(+), 49 deletions(-)

diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs
index 6eba71df1b..94365ddf05 100644
--- a/proxy/src/bin/local_proxy.rs
+++ b/proxy/src/bin/local_proxy.rs
@@ -92,6 +92,12 @@ struct SqlOverHttpArgs {
 
     #[clap(long, default_value_t = 16)]
     sql_over_http_cancel_set_shards: usize,
+
+    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+    sql_over_http_max_request_size_bytes: u64,
+
+    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+    sql_over_http_max_response_size_bytes: usize,
 }
 
 #[tokio::main]
@@ -208,6 +214,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
         },
         cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
         client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
+        max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
+        max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
     };
 
     Ok(Box::leak(Box::new(ProxyConfig {
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index ca9aeb04d8..e5c5b47795 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -268,6 +268,12 @@ struct SqlOverHttpArgs {
 
     #[clap(long, default_value_t = 64)]
     sql_over_http_cancel_set_shards: usize,
+
+    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+    sql_over_http_max_request_size_bytes: u64,
+
+    #[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+    sql_over_http_max_response_size_bytes: usize,
 }
 
 #[tokio::main]
@@ -679,6 +685,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
         },
         cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
         client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
+        max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
+        max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
     };
     let authentication_config = AuthenticationConfig {
         thread_pool,
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 1cda6d200c..373e4cf650 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -56,6 +56,8 @@ pub struct HttpConfig {
     pub pool_options: GlobalConnPoolOptions,
     pub cancel_set: CancelSet,
     pub client_conn_threshold: u64,
+    pub max_request_size_bytes: u64,
+    pub max_response_size_bytes: usize,
 }
 
 pub struct AuthenticationConfig {
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index bea599e9b9..6c32d5df0e 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -776,6 +776,8 @@ mod tests {
             },
             cancel_set: CancelSet::new(0),
             client_conn_threshold: u64::MAX,
+            max_request_size_bytes: u64::MAX,
+            max_response_size_bytes: usize::MAX,
         }));
         let pool = GlobalConnPool::new(config);
         let conn_info = ConnInfo {
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 2188edc8c5..06e540d149 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -87,9 +87,6 @@ enum Payload {
     Batch(BatchQueryData),
 }
 
-const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
-const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
-
 static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
 static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
 static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -366,10 +363,10 @@ pub(crate) enum SqlOverHttpError {
     ConnectCompute(#[from] HttpConnError),
     #[error("{0}")]
     ConnInfo(#[from] ConnInfoError),
-    #[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
-    RequestTooLarge,
-    #[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
-    ResponseTooLarge,
+    #[error("request is too large (max is {0} bytes)")]
+    RequestTooLarge(u64),
+    #[error("response is too large (max is {0} bytes)")]
+    ResponseTooLarge(usize),
     #[error("invalid isolation level")]
     InvalidIsolationLevel,
     #[error("{0}")]
@@ -386,8 +383,8 @@ impl ReportableError for SqlOverHttpError {
             SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
             SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
             SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
-            SqlOverHttpError::RequestTooLarge => ErrorKind::User,
-            SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
+            SqlOverHttpError::RequestTooLarge(_) => ErrorKind::User,
+            SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
             SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
             SqlOverHttpError::Postgres(p) => p.get_error_kind(),
             SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
@@ -402,8 +399,8 @@ impl UserFacingError for SqlOverHttpError {
             SqlOverHttpError::ReadPayload(p) => p.to_string(),
             SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
             SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
-            SqlOverHttpError::RequestTooLarge => self.to_string(),
-            SqlOverHttpError::ResponseTooLarge => self.to_string(),
+            SqlOverHttpError::RequestTooLarge(_) => self.to_string(),
+            SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
             SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
             SqlOverHttpError::Postgres(p) => p.to_string(),
             SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
@@ -537,7 +534,7 @@ async fn handle_inner(
 
     let request_content_length = match request.body().size_hint().upper() {
         Some(v) => v,
-        None => MAX_REQUEST_SIZE + 1,
+        None => config.http_config.max_request_size_bytes + 1,
     };
     info!(request_content_length, "request size in bytes");
     Metrics::get()
@@ -547,8 +544,10 @@
 
     // we don't have a streaming request support yet so this is to prevent OOM
     // from a malicious user sending an extremely large request body
-    if request_content_length > MAX_REQUEST_SIZE {
-        return Err(SqlOverHttpError::RequestTooLarge);
+    if request_content_length > config.http_config.max_request_size_bytes {
+        return Err(SqlOverHttpError::RequestTooLarge(
+            config.http_config.max_request_size_bytes,
+        ));
     }
 
     let fetch_and_process_request = Box::pin(
@@ -612,7 +611,10 @@
     // Now execute the query and return the result.
     let json_output = match payload {
-        Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
+        Payload::Single(stmt) => {
+            stmt.process(config, cancel, &mut client, parsed_headers)
+                .await?
+        }
         Payload::Batch(statements) => {
             if parsed_headers.txn_read_only {
                 response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
             }
@@ -628,7 +630,7 @@
             }
 
             statements
-                .process(cancel, &mut client, parsed_headers)
+                .process(config, cancel, &mut client, parsed_headers)
                 .await?
         }
     };
@@ -656,6 +658,7 @@
 impl QueryData {
     async fn process(
         self,
+        config: &'static ProxyConfig,
         cancel: CancellationToken,
         client: &mut Client,
         parsed_headers: HttpHeaders,
@@ -664,7 +667,7 @@ impl QueryData {
         let cancel_token = inner.cancel_token();
 
         let res = match select(
-            pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
+            pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
             pin!(cancel.cancelled()),
         )
         .await
@@ -727,6 +730,7 @@ impl QueryData {
 impl BatchQueryData {
     async fn process(
         self,
+        config: &'static ProxyConfig,
         cancel: CancellationToken,
         client: &mut Client,
         parsed_headers: HttpHeaders,
@@ -751,44 +755,52 @@ impl BatchQueryData {
             discard.discard();
         })?;
 
-        let json_output =
-            match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
-                Ok(json_output) => {
-                    info!("commit");
-                    let status = transaction.commit().await.inspect_err(|_| {
-                        // if we cannot commit - for now don't return connection to pool
-                        // TODO: get a query status from the error
-                        discard.discard();
-                    })?;
-                    discard.check_idle(status);
-                    json_output
-                }
-                Err(SqlOverHttpError::Cancelled(_)) => {
-                    if let Err(err) = cancel_token.cancel_query(NoTls).await {
-                        tracing::error!(?err, "could not cancel query");
-                    }
-                    // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
+        let json_output = match query_batch(
+            config,
+            cancel.child_token(),
+            &transaction,
+            self,
+            parsed_headers,
+        )
+        .await
+        {
+            Ok(json_output) => {
+                info!("commit");
+                let status = transaction.commit().await.inspect_err(|_| {
+                    // if we cannot commit - for now don't return connection to pool
+                    // TODO: get a query status from the error
                    discard.discard();
+                })?;
+                discard.check_idle(status);
+                json_output
+            }
+            Err(SqlOverHttpError::Cancelled(_)) => {
+                if let Err(err) = cancel_token.cancel_query(NoTls).await {
+                    tracing::error!(?err, "could not cancel query");
+                }
+                // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
+                discard.discard();
-                    return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-                }
-                Err(err) => {
-                    info!("rollback");
-                    let status = transaction.rollback().await.inspect_err(|_| {
-                        // if we cannot rollback - for now don't return connection to pool
-                        // TODO: get a query status from the error
-                        discard.discard();
-                    })?;
-                    discard.check_idle(status);
-                    return Err(err);
-                }
-            };
+                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
+            }
+            Err(err) => {
+                info!("rollback");
+                let status = transaction.rollback().await.inspect_err(|_| {
+                    // if we cannot rollback - for now don't return connection to pool
+                    // TODO: get a query status from the error
+                    discard.discard();
+                })?;
+                discard.check_idle(status);
+                return Err(err);
+            }
+        };
 
         Ok(json_output)
     }
 }
 
 async fn query_batch(
+    config: &'static ProxyConfig,
     cancel: CancellationToken,
     transaction: &Transaction<'_>,
     queries: BatchQueryData,
@@ -798,6 +810,7 @@ async fn query_batch(
     let mut current_size = 0;
     for stmt in queries.queries {
         let query = pin!(query_to_json(
+            config,
             transaction,
             stmt,
             &mut current_size,
@@ -826,6 +839,7 @@
 }
 
 async fn query_to_json(
+    config: &'static ProxyConfig,
     client: &T,
     data: QueryData,
     current_size: &mut usize,
@@ -846,8 +860,10 @@
         rows.push(row);
         // we don't have a streaming response support yet so this is to prevent OOM
         // from a malicious query (eg a cross join)
-        if *current_size > MAX_RESPONSE_SIZE {
-            return Err(SqlOverHttpError::ResponseTooLarge);
+        if *current_size > config.http_config.max_response_size_bytes {
+            return Err(SqlOverHttpError::ResponseTooLarge(
+                config.http_config.max_response_size_bytes,
+            ));
         }
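
Note: both guards this patch parameterizes share one shape: compare an observed
byte count against the configured maximum and fail early, embedding the
effective limit in the error so the client can see the cap it hit. Below is a
minimal standalone sketch of that shape, not the proxy's actual code; the names
Limits, SizeError, check_request and check_response are invented for
illustration.

// Illustrative sketch only; the real checks live in handle_inner() and
// query_to_json() and use HttpConfig / SqlOverHttpError.
struct Limits {
    max_request_size_bytes: u64,
    max_response_size_bytes: usize,
}

#[derive(Debug)]
enum SizeError {
    RequestTooLarge(u64),
    ResponseTooLarge(usize),
}

// Mirrors the request guard: an unknown Content-Length is treated as
// "over the limit" via the `max + 1` fallback, because the body cannot
// be streamed yet and must fit in memory.
fn check_request(size_hint_upper: Option<u64>, limits: &Limits) -> Result<(), SizeError> {
    let len = size_hint_upper.unwrap_or(limits.max_request_size_bytes + 1);
    if len > limits.max_request_size_bytes {
        return Err(SizeError::RequestTooLarge(limits.max_request_size_bytes));
    }
    Ok(())
}

// Mirrors the response guard: `current_size` is accumulated row by row
// while the JSON body is built, and re-checked after every row.
fn check_response(current_size: usize, limits: &Limits) -> Result<(), SizeError> {
    if current_size > limits.max_response_size_bytes {
        return Err(SizeError::ResponseTooLarge(limits.max_response_size_bytes));
    }
    Ok(())
}

fn main() {
    let limits = Limits {
        max_request_size_bytes: 10 * 1024 * 1024,  // the patched default
        max_response_size_bytes: 10 * 1024 * 1024, // the patched default
    };
    assert!(check_request(Some(1024), &limits).is_ok());
    assert!(check_request(None, &limits).is_err()); // unknown size is rejected
    assert!(check_response(11 * 1024 * 1024, &limits).is_err());
}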
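
Usage note: with the #[clap(long)] attributes above, clap derives kebab-case
long flags from the field names, so the limits become operator-tunable on both
the proxy and local_proxy binaries. The exact invocation below is an assumption
based on clap's standard field-name-to-flag mapping, shown here raising both
limits to 32 MiB:

    proxy --sql-over-http-max-request-size-bytes=33554432 \
          --sql-over-http-max-response-size-bytes=33554432

The default_value_t expressions keep the previously hard-coded 10 MiB constants
as defaults, so behavior is unchanged unless the flags are set; the
conn_pool.rs test config passes u64::MAX / usize::MAX to effectively disable
both checks.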