proxy http error classification (#7098)

## Problem

Missing error classification for SQL-over-HTTP queries.
Not respecting `UserFacingError` for SQL-over-HTTP queries.

## Summary of changes

Adds error classification.
Adds user facing errors.
This commit is contained in:
Conrad Ludgate
2024-03-13 06:35:49 +00:00
committed by GitHub
parent 1b41db8bdd
commit 83855a907c
3 changed files with 204 additions and 74 deletions

View File

@@ -12,6 +12,7 @@ use crate::{
CachedNodeInfo,
},
context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError},
proxy::connect_compute::ConnectMechanism,
};
@@ -117,6 +118,30 @@ pub enum HttpConnError {
WakeCompute(#[from] WakeComputeError),
}
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
}
}
}
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,

View File

@@ -119,16 +119,12 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
}
}
fn put(
pool: &RwLock<Self>,
conn_info: &ConnInfo,
client: ClientInner<C>,
) -> anyhow::Result<()> {
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return Ok(());
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
@@ -138,7 +134,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return Ok(());
return;
}
// return connection to the pool
@@ -172,8 +168,6 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
Ok(())
}
}
@@ -653,7 +647,7 @@ impl<C: ClientInnerExt> Client<C> {
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None

View File

@@ -1,11 +1,11 @@
use std::pin::pin;
use std::sync::Arc;
use anyhow::bail;
use futures::future::select;
use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
@@ -37,9 +37,13 @@ use crate::auth::ComputeUserInfoParseError;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::DbName;
use crate::RoleName;
@@ -47,6 +51,7 @@ use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -117,6 +122,18 @@ pub enum ConnInfoError {
MalformedEndpoint,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
@@ -212,17 +229,41 @@ pub async fn handle(
handle.abort();
let mut response = match result {
Ok(Ok(r)) => {
Ok(r) => {
ctx.set_success();
r
}
Err(e) => {
// TODO: ctx.set_error_kind(e.get_error_type());
Err(e @ SqlOverHttpError::Cancelled(_)) => {
let error_kind = e.get_error_kind();
ctx.set_error_kind(error_kind);
let mut message = format!("{:?}", e);
let db_error = e
.downcast_ref::<tokio_postgres::Error>()
.and_then(|e| e.as_db_error());
let message = format!(
"Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
config.http_config.request_timeout.as_secs_f64()
);
tracing::info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
json_response(
StatusCode::BAD_REQUEST,
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
)?
}
Err(e) => {
let error_kind = e.get_error_kind();
ctx.set_error_kind(error_kind);
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
fn get<'a, T: serde::Serialize>(
db: Option<&'a DbError>,
x: impl FnOnce(&'a DbError) -> T,
@@ -265,10 +306,13 @@ pub async fn handle(
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
let routine = get(db_error, |db| db.routine());
error!(
?code,
"sql-over-http per-client task finished with an error: {e:#}"
tracing::info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
// TODO: this shouldn't always be bad request.
json_response(
StatusCode::BAD_REQUEST,
@@ -293,21 +337,6 @@ pub async fn handle(
}),
)?
}
Ok(Err(Cancelled())) => {
// TODO: when http error classification is done, distinguish between
// timeout on sql vs timeout in proxy/cplane
// ctx.set_error_kind(crate::error::ErrorKind::RateLimit);
let message = format!(
"Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
config.http_config.request_timeout.as_secs_f64()
);
error!(message);
json_response(
StatusCode::BAD_REQUEST,
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
)?
}
};
response.headers_mut().insert(
@@ -317,7 +346,93 @@ pub async fn handle(
Ok(response)
}
struct Cancelled();
#[derive(Debug, thiserror::Error)]
pub enum SqlOverHttpError {
#[error("{0}")]
ReadPayload(#[from] ReadPayloadError),
#[error("{0}")]
ConnectCompute(#[from] HttpConnError),
#[error("{0}")]
ConnInfo(#[from] ConnInfoError),
#[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
RequestTooLarge,
#[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
ResponseTooLarge,
#[error("invalid isolation level")]
InvalidIsolationLevel,
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
#[error("{0}")]
JsonConversion(#[from] JsonConversionError),
#[error("{0}")]
Cancelled(SqlOverHttpCancel),
}
impl ReportableError for SqlOverHttpError {
fn get_error_kind(&self) -> ErrorKind {
match self {
SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
SqlOverHttpError::RequestTooLarge => ErrorKind::User,
SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
SqlOverHttpError::Postgres(p) => p.get_error_kind(),
SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
}
}
}
impl UserFacingError for SqlOverHttpError {
fn to_string_client(&self) -> String {
match self {
SqlOverHttpError::ReadPayload(p) => p.to_string(),
SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
SqlOverHttpError::RequestTooLarge => self.to_string(),
SqlOverHttpError::ResponseTooLarge => self.to_string(),
SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
SqlOverHttpError::Postgres(p) => p.to_string(),
SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
SqlOverHttpError::Cancelled(_) => self.to_string(),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum SqlOverHttpCancel {
#[error("query was cancelled")]
Postgres,
#[error("query was cancelled while stuck trying to connect to the database")]
Connect,
}
impl ReportableError for SqlOverHttpCancel {
fn get_error_kind(&self) -> ErrorKind {
match self {
SqlOverHttpCancel::Postgres => ErrorKind::RateLimit,
SqlOverHttpCancel::Connect => ErrorKind::ServiceRateLimit,
}
}
}
async fn handle_inner(
cancel: CancellationToken,
@@ -325,7 +440,7 @@ async fn handle_inner(
ctx: &mut RequestMonitoring,
request: Request<Body>,
backend: Arc<PoolingBackend>,
) -> Result<Result<Response<Body>, Cancelled>, anyhow::Error> {
) -> Result<Response<Body>, SqlOverHttpError> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
@@ -358,7 +473,7 @@ async fn handle_inner(
b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
b"ReadCommitted" => IsolationLevel::ReadCommitted,
b"RepeatableRead" => IsolationLevel::RepeatableRead,
_ => bail!("invalid isolation level"),
_ => return Err(SqlOverHttpError::InvalidIsolationLevel),
}),
None => None,
};
@@ -376,19 +491,16 @@ async fn handle_inner(
// we don't have a streaming request support yet so this is to prevent OOM
// from a malicious user sending an extremely large request body
if request_content_length > MAX_REQUEST_SIZE {
return Err(anyhow::anyhow!(
"request is too large (max is {MAX_REQUEST_SIZE} bytes)"
));
return Err(SqlOverHttpError::RequestTooLarge);
}
let fetch_and_process_request = async {
let body = hyper::body::to_bytes(request.into_body())
.await
.map_err(anyhow::Error::from)?;
let body = hyper::body::to_bytes(request.into_body()).await?;
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
};
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from);
let authenticate_and_connect = async {
let keys = backend.authenticate(ctx, &conn_info).await?;
@@ -398,8 +510,9 @@ async fn handle_inner(
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.latency_timer.success();
Ok::<_, anyhow::Error>(client)
};
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from);
// Run both operations in parallel
let (payload, mut client) = match select(
@@ -412,7 +525,9 @@ async fn handle_inner(
.await
{
Either::Left((result, _cancelled)) => result?,
Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())),
Either::Right((_cancelled, _)) => {
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect))
}
};
let mut response = Response::builder()
@@ -456,20 +571,24 @@ async fn handle_inner(
results
}
Ok(Err(error)) => {
let db_error = error
.downcast_ref::<tokio_postgres::Error>()
.and_then(|e| e.as_db_error());
let db_error = match &error {
SqlOverHttpError::ConnectCompute(
HttpConnError::ConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
// if errored for some other reason, it might not be safe to return
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
discard.discard();
}
return Ok(Err(Cancelled()));
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(_timeout) => {
discard.discard();
return Ok(Err(Cancelled()));
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
}
@@ -507,7 +626,7 @@ async fn handle_inner(
)
.await
{
Ok(Ok(results)) => {
Ok(results) => {
info!("commit");
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
@@ -518,14 +637,14 @@ async fn handle_inner(
discard.check_idle(status);
results
}
Ok(Err(Cancelled())) => {
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
discard.discard();
return Ok(Err(Cancelled()));
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(err) => {
info!("rollback");
@@ -541,16 +660,10 @@ async fn handle_inner(
};
if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
@@ -574,7 +687,7 @@ async fn handle_inner(
// moving this later in the stack is going to be a lot of effort and ehhhh
metrics.record_egress(len as u64);
Ok(Ok(response))
Ok(response)
}
async fn query_batch(
@@ -584,7 +697,7 @@ async fn query_batch(
total_size: &mut usize,
raw_output: bool,
array_mode: bool,
) -> anyhow::Result<Result<Vec<Value>, Cancelled>> {
) -> Result<Vec<Value>, SqlOverHttpError> {
let mut results = Vec::with_capacity(queries.queries.len());
let mut current_size = 0;
for stmt in queries.queries {
@@ -606,12 +719,12 @@ async fn query_batch(
return Err(e);
}
Either::Right((_cancelled, _)) => {
return Ok(Err(Cancelled()));
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
}
*total_size += current_size;
Ok(Ok(results))
Ok(results)
}
async fn query_to_json<T: GenericClient>(
@@ -620,7 +733,7 @@ async fn query_to_json<T: GenericClient>(
current_size: &mut usize,
raw_output: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
info!("executing query");
let query_params = data.params;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
@@ -637,9 +750,7 @@ async fn query_to_json<T: GenericClient>(
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > MAX_RESPONSE_SIZE {
return Err(anyhow::anyhow!(
"response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
));
return Err(SqlOverHttpError::ResponseTooLarge);
}
}