proxy: make sql-over-http max request/response sizes configurable (#9029)

Folke Behrens
2024-09-18 13:58:51 +02:00
committed by GitHub
parent 3454ef7507
commit c5cd8577ff
5 changed files with 85 additions and 49 deletions
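This change turns the previously hard-coded 10 MiB limits on SQL-over-HTTP request and response bodies into CLI flags on both the proxy and local-proxy binaries, keeping the old values as defaults. A minimal sketch of how the new flags parse, assuming clap v4 derive (the struct below is trimmed to the two new fields from the diff; the override value is illustrative):

```rust
use clap::Parser;

#[derive(Parser)]
struct SqlOverHttpArgs {
    /// Defaults to the old hard-coded MAX_REQUEST_SIZE (10 MiB).
    #[clap(long, default_value_t = 10 * 1024 * 1024)]
    sql_over_http_max_request_size_bytes: u64,

    /// Defaults to the old hard-coded MAX_RESPONSE_SIZE (10 MiB).
    #[clap(long, default_value_t = 10 * 1024 * 1024)]
    sql_over_http_max_response_size_bytes: usize,
}

fn main() {
    // Override the request cap to 5 MiB; the response cap keeps its default.
    let args = SqlOverHttpArgs::parse_from([
        "proxy",
        "--sql-over-http-max-request-size-bytes",
        "5242880",
    ]);
    assert_eq!(args.sql_over_http_max_request_size_bytes, 5 * 1024 * 1024);
    assert_eq!(args.sql_over_http_max_response_size_bytes, 10 * 1024 * 1024);
}
```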

View File

@@ -92,6 +92,12 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 16)]
sql_over_http_cancel_set_shards: usize,
+#[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+sql_over_http_max_request_size_bytes: u64,
+#[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+sql_over_http_max_response_size_bytes: usize,
}
#[tokio::main]
@@ -208,6 +214,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
+max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
+max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
Ok(Box::leak(Box::new(ProxyConfig {
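The parsed flag values are wired into the process-lifetime ProxyConfig, which build_config leaks to obtain a &'static reference. A minimal sketch of that Box::leak pattern under the same assumptions (types cut down to the two new fields; names mirror the diff but the function body is illustrative):

```rust
struct HttpConfig {
    max_request_size_bytes: u64,
    max_response_size_bytes: usize,
}

fn build_config() -> &'static HttpConfig {
    // Leaking the box yields a &'static reference that can be shared
    // freely across tasks for the lifetime of the process, without Arc.
    Box::leak(Box::new(HttpConfig {
        max_request_size_bytes: 10 * 1024 * 1024,
        max_response_size_bytes: 10 * 1024 * 1024,
    }))
}

fn main() {
    let cfg: &'static HttpConfig = build_config();
    assert_eq!(cfg.max_request_size_bytes, 10 * 1024 * 1024);
}
```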

View File

@@ -268,6 +268,12 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 64)]
sql_over_http_cancel_set_shards: usize,
+#[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+sql_over_http_max_request_size_bytes: u64,
+#[clap(long, default_value_t = 10 * 1024 * 1024)] // 10 MiB
+sql_over_http_max_response_size_bytes: usize,
}
#[tokio::main]
@@ -679,6 +685,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
+max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
+max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
thread_pool,

View File

@@ -56,6 +56,8 @@ pub struct HttpConfig {
pub pool_options: GlobalConnPoolOptions,
pub cancel_set: CancelSet,
pub client_conn_threshold: u64,
+pub max_request_size_bytes: u64,
+pub max_response_size_bytes: usize,
}
pub struct AuthenticationConfig {
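Note the type asymmetry between the two caps, which appears to follow from how each is consumed later in the diff: the request cap is a u64 because it is compared against the body's advertised size hint (a u64), while the response cap is a usize because it bounds an in-memory byte counter. Illustrative comments only:

```rust
pub struct HttpConfig {
    // Compared against the HTTP body's size hint, which reports Option<u64>.
    pub max_request_size_bytes: u64,
    // Bounds the in-memory `current_size: usize` counter while rows accumulate.
    pub max_response_size_bytes: usize,
}
```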

View File

@@ -776,6 +776,8 @@ mod tests {
},
cancel_set: CancelSet::new(0),
client_conn_threshold: u64::MAX,
+max_request_size_bytes: u64::MAX,
+max_response_size_bytes: usize::MAX,
}));
let pool = GlobalConnPool::new(config);
let conn_info = ConnInfo {

View File

@@ -87,9 +87,6 @@ enum Payload {
Batch(BatchQueryData),
}
-const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
-const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -366,10 +363,10 @@ pub(crate) enum SqlOverHttpError {
ConnectCompute(#[from] HttpConnError),
#[error("{0}")]
ConnInfo(#[from] ConnInfoError),
#[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
RequestTooLarge,
#[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
ResponseTooLarge,
#[error("request is too large (max is {0} bytes)")]
RequestTooLarge(u64),
#[error("response is too large (max is {0} bytes)")]
ResponseTooLarge(usize),
#[error("invalid isolation level")]
InvalidIsolationLevel,
#[error("{0}")]
@@ -386,8 +383,8 @@ impl ReportableError for SqlOverHttpError {
SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
-SqlOverHttpError::RequestTooLarge => ErrorKind::User,
-SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
+SqlOverHttpError::RequestTooLarge(_) => ErrorKind::User,
+SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
SqlOverHttpError::Postgres(p) => p.get_error_kind(),
SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
@@ -402,8 +399,8 @@ impl UserFacingError for SqlOverHttpError {
SqlOverHttpError::ReadPayload(p) => p.to_string(),
SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
-SqlOverHttpError::RequestTooLarge => self.to_string(),
-SqlOverHttpError::ResponseTooLarge => self.to_string(),
+SqlOverHttpError::RequestTooLarge(_) => self.to_string(),
+SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
SqlOverHttpError::Postgres(p) => p.to_string(),
SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
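Since the limits are no longer global constants, the error variants now carry the configured maximum as a payload so the message can still report it. A self-contained sketch of that pattern, assuming thiserror (which the #[error(...)] attributes suggest):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum SizeLimitError {
    // `{0}` interpolates the first tuple field into the message.
    #[error("request is too large (max is {0} bytes)")]
    RequestTooLarge(u64),
    #[error("response is too large (max is {0} bytes)")]
    ResponseTooLarge(usize),
}

fn main() {
    let err = SizeLimitError::RequestTooLarge(10 * 1024 * 1024);
    assert_eq!(
        err.to_string(),
        "request is too large (max is 10485760 bytes)"
    );
}
```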
@@ -537,7 +534,7 @@ async fn handle_inner(
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
-None => MAX_REQUEST_SIZE + 1,
+None => config.http_config.max_request_size_bytes + 1,
};
info!(request_content_length, "request size in bytes");
Metrics::get()
@@ -547,8 +544,10 @@ async fn handle_inner(
// we don't have a streaming request support yet so this is to prevent OOM
// from a malicious user sending an extremely large request body
-if request_content_length > MAX_REQUEST_SIZE {
-return Err(SqlOverHttpError::RequestTooLarge);
+if request_content_length > config.http_config.max_request_size_bytes {
+return Err(SqlOverHttpError::RequestTooLarge(
+config.http_config.max_request_size_bytes,
+));
}
let fetch_and_process_request = Box::pin(
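Because there is no streaming request support yet, the guard above rejects oversized requests up front from the body's advertised length, and a body with no known upper bound is treated as over the limit. A minimal sketch of that logic, assuming the http-body crate's Body trait (the helper name is hypothetical):

```rust
use http_body::Body;

/// Effective request length used for the limit check: if the body
/// advertises no upper bound, report one byte over the cap so the
/// request is rejected rather than buffered blindly.
fn effective_request_len<B: Body>(body: &B, max_request_size_bytes: u64) -> u64 {
    body.size_hint()
        .upper()
        .unwrap_or(max_request_size_bytes + 1)
}
```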
@@ -612,7 +611,10 @@ async fn handle_inner(
// Now execute the query and return the result.
let json_output = match payload {
-Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
+Payload::Single(stmt) => {
+stmt.process(config, cancel, &mut client, parsed_headers)
+.await?
+}
Payload::Batch(statements) => {
if parsed_headers.txn_read_only {
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
@@ -628,7 +630,7 @@ async fn handle_inner(
}
statements
-.process(cancel, &mut client, parsed_headers)
+.process(config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -656,6 +658,7 @@ async fn handle_inner(
impl QueryData {
async fn process(
self,
+config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -664,7 +667,7 @@ impl QueryData {
let cancel_token = inner.cancel_token();
let res = match select(
-pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
+pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
pin!(cancel.cancelled()),
)
.await
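QueryData::process races the JSON-producing future against the cancellation token, as the pin!/select call above shows. A stripped-down sketch of that pattern, assuming futures-util and tokio-util (the work future here is a stand-in for query_to_json):

```rust
use futures_util::future::{select, Either};
use std::pin::pin;
use tokio_util::sync::CancellationToken;

async fn run_cancellable(cancel: CancellationToken) -> Result<String, &'static str> {
    let work = pin!(async { "json output".to_string() });
    let cancelled = pin!(cancel.cancelled());
    match select(work, cancelled).await {
        // The query finished first: return its output.
        Either::Left((json, _)) => Ok(json),
        // Cancellation won the race: surface an error to the caller.
        Either::Right(((), _)) => Err("query was cancelled"),
    }
}
```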
@@ -727,6 +730,7 @@ impl QueryData {
impl BatchQueryData {
async fn process(
self,
+config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -751,44 +755,52 @@ impl BatchQueryData {
discard.discard();
})?;
-let json_output =
-match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
-Ok(json_output) => {
-info!("commit");
-let status = transaction.commit().await.inspect_err(|_| {
-// if we cannot commit - for now don't return connection to pool
-// TODO: get a query status from the error
-discard.discard();
-})?;
-discard.check_idle(status);
-json_output
-}
-Err(SqlOverHttpError::Cancelled(_)) => {
-if let Err(err) = cancel_token.cancel_query(NoTls).await {
-tracing::error!(?err, "could not cancel query");
-}
-// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
-discard.discard();
-return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-}
-Err(err) => {
-info!("rollback");
-let status = transaction.rollback().await.inspect_err(|_| {
-// if we cannot rollback - for now don't return connection to pool
-// TODO: get a query status from the error
-discard.discard();
-})?;
-discard.check_idle(status);
-return Err(err);
-}
-};
+let json_output = match query_batch(
+config,
+cancel.child_token(),
+&transaction,
+self,
+parsed_headers,
+)
+.await
+{
+Ok(json_output) => {
+info!("commit");
+let status = transaction.commit().await.inspect_err(|_| {
+// if we cannot commit - for now don't return connection to pool
+// TODO: get a query status from the error
+discard.discard();
+})?;
+discard.check_idle(status);
+json_output
+}
+Err(SqlOverHttpError::Cancelled(_)) => {
+if let Err(err) = cancel_token.cancel_query(NoTls).await {
+tracing::error!(?err, "could not cancel query");
+}
+// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
+discard.discard();
+return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
+}
+Err(err) => {
+info!("rollback");
+let status = transaction.rollback().await.inspect_err(|_| {
+// if we cannot rollback - for now don't return connection to pool
+// TODO: get a query status from the error
+discard.discard();
+})?;
+discard.check_idle(status);
+return Err(err);
+}
+};
Ok(json_output)
}
}
async fn query_batch(
+config: &'static ProxyConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -798,6 +810,7 @@ async fn query_batch(
let mut current_size = 0;
for stmt in queries.queries {
let query = pin!(query_to_json(
+config,
transaction,
stmt,
&mut current_size,
@@ -826,6 +839,7 @@ async fn query_batch(
}
async fn query_to_json<T: GenericClient>(
+config: &'static ProxyConfig,
client: &T,
data: QueryData,
current_size: &mut usize,
@@ -846,8 +860,10 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
-if *current_size > MAX_RESPONSE_SIZE {
-return Err(SqlOverHttpError::ResponseTooLarge);
+if *current_size > config.http_config.max_response_size_bytes {
+return Err(SqlOverHttpError::ResponseTooLarge(
+config.http_config.max_response_size_bytes,
+));
}
}
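On the response side the cap is enforced incrementally: a running byte counter, shared across all statements in a batch via &mut current_size, grows as rows accumulate, and the query fails as soon as it crosses the configured maximum. A simplified sketch of that accounting (sizing a row by its serialized JSON length is an approximation for illustration, not necessarily how the proxy measures rows):

```rust
fn push_row(
    rows: &mut Vec<serde_json::Value>,
    row: serde_json::Value,
    current_size: &mut usize,
    max_response_size_bytes: usize,
) -> Result<(), String> {
    // Approximate the row's contribution by its serialized length.
    *current_size += row.to_string().len();
    rows.push(row);

    // No streaming response support yet, so bail out before an unbounded
    // result (e.g. from a cross join) can exhaust memory.
    if *current_size > max_response_size_bytes {
        return Err(format!(
            "response is too large (max is {max_response_size_bytes} bytes)"
        ));
    }
    Ok(())
}
```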