diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 2e63ad6c99..d0f155165d 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -12,6 +12,7 @@ use crate::{
         CachedNodeInfo,
     },
     context::RequestMonitoring,
+    error::{ErrorKind, ReportableError, UserFacingError},
     proxy::connect_compute::ConnectMechanism,
 };
 
@@ -117,6 +118,30 @@ pub enum HttpConnError {
     WakeCompute(#[from] WakeComputeError),
 }
 
+impl ReportableError for HttpConnError {
+    fn get_error_kind(&self) -> ErrorKind {
+        match self {
+            HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
+            HttpConnError::ConnectionError(p) => p.get_error_kind(),
+            HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
+            HttpConnError::AuthError(a) => a.get_error_kind(),
+            HttpConnError::WakeCompute(w) => w.get_error_kind(),
+        }
+    }
+}
+
+impl UserFacingError for HttpConnError {
+    fn to_string_client(&self) -> String {
+        match self {
+            HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
+            HttpConnError::ConnectionError(p) => p.to_string(),
+            HttpConnError::GetAuthInfo(c) => c.to_string_client(),
+            HttpConnError::AuthError(c) => c.to_string_client(),
+            HttpConnError::WakeCompute(c) => c.to_string_client(),
+        }
+    }
+}
+
 struct TokioMechanism {
     pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
     conn_info: ConnInfo,
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index 901e30224b..c7e8eaef76 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -119,16 +119,12 @@ impl EndpointConnPool {
         }
     }
 
-    fn put(
-        pool: &RwLock<Self>,
-        conn_info: &ConnInfo,
-        client: ClientInner,
-    ) -> anyhow::Result<()> {
+    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner) {
         let conn_id = client.conn_id;
 
         if client.is_closed() {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
-            return Ok(());
+            return;
         }
         let global_max_conn = pool.read().global_pool_size_max_conns;
         if pool
             .read()
             .global_connections_count
             .load(atomic::Ordering::Relaxed)
@@ -138,7 +134,7 @@
             >= global_max_conn
         {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
-            return Ok(());
+            return;
         }
 
         // return connection to the pool
@@ -172,8 +168,6 @@
         } else {
            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
         }
-
-        Ok(())
     }
 }
 
@@ -653,7 +647,7 @@ impl Client {
             // return connection to the pool
             return Some(move || {
                 let _span = current_span.enter();
-                let _ = EndpointConnPool::put(&conn_pool, &conn_info, client);
+                EndpointConnPool::put(&conn_pool, &conn_info, client);
             });
         }
         None
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 20d9795b47..86c278030f 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -1,11 +1,11 @@
 use std::pin::pin;
 use std::sync::Arc;
 
-use anyhow::bail;
 use futures::future::select;
 use futures::future::try_join;
 use futures::future::Either;
 use futures::StreamExt;
+use futures::TryFutureExt;
 use hyper::body::HttpBody;
 use hyper::header;
 use hyper::http::HeaderName;
@@ -37,9 +37,13 @@ use crate::auth::ComputeUserInfoParseError;
 use crate::config::ProxyConfig;
 use crate::config::TlsConfig;
 use crate::context::RequestMonitoring;
+use crate::error::ErrorKind;
+use crate::error::ReportableError;
+use crate::error::UserFacingError;
 use crate::metrics::HTTP_CONTENT_LENGTH;
 use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
 use crate::proxy::NeonOptions;
+use crate::serverless::backend::HttpConnError;
 use crate::DbName;
 use crate::RoleName;
 
@@ -47,6 +51,7 @@ use super::backend::PoolingBackend;
 use super::conn_pool::ConnInfo;
 use super::json::json_to_pg_text;
 use super::json::pg_text_row_to_json;
+use super::json::JsonConversionError;
 
 #[derive(serde::Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -117,6 +122,18 @@ pub enum ConnInfoError {
     MalformedEndpoint,
 }
 
+impl ReportableError for ConnInfoError {
+    fn get_error_kind(&self) -> ErrorKind {
+        ErrorKind::User
+    }
+}
+
+impl UserFacingError for ConnInfoError {
+    fn to_string_client(&self) -> String {
+        self.to_string()
+    }
+}
+
 fn get_conn_info(
     ctx: &mut RequestMonitoring,
     headers: &HeaderMap,
@@ -212,17 +229,41 @@ pub async fn handle(
     handle.abort();
 
     let mut response = match result {
-        Ok(Ok(r)) => {
+        Ok(r) => {
             ctx.set_success();
             r
         }
-        Err(e) => {
-            // TODO: ctx.set_error_kind(e.get_error_type());
+        Err(e @ SqlOverHttpError::Cancelled(_)) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
 
-            let mut message = format!("{:?}", e);
-            let db_error = e
-                .downcast_ref::<tokio_postgres::Error>()
-                .and_then(|e| e.as_db_error());
+            let message = format!(
+                "Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
+                config.http_config.request_timeout.as_secs_f64()
+            );
+
+            tracing::info!(
+                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg=message,
+                "forwarding error to user"
+            );
+
+            json_response(
+                StatusCode::BAD_REQUEST,
+                json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
+            )?
+        }
+        Err(e) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
+
+            let mut message = e.to_string_client();
+            let db_error = match &e {
+                SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
+                | SqlOverHttpError::Postgres(e) => e.as_db_error(),
+                _ => None,
+            };
             fn get<'a, T: serde::Serialize>(
                 db: Option<&'a DbError>,
                 x: impl FnOnce(&'a DbError) -> T,
@@ -265,10 +306,13 @@ pub async fn handle(
             let line = get(db_error, |db| db.line().map(|l| l.to_string()));
             let routine = get(db_error, |db| db.routine());
 
-            error!(
-                ?code,
-                "sql-over-http per-client task finished with an error: {e:#}"
+            tracing::info!(
                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg=message,
+                "forwarding error to user"
             );
+            // TODO: this shouldn't always be bad request.
 
             json_response(
                 StatusCode::BAD_REQUEST,
@@ -293,21 +337,6 @@ pub async fn handle(
                 }),
             )?
         }
-        Ok(Err(Cancelled())) => {
-            // TODO: when http error classification is done, distinguish between
-            // timeout on sql vs timeout in proxy/cplane
-            // ctx.set_error_kind(crate::error::ErrorKind::RateLimit);
-
-            let message = format!(
-                "Query cancelled, runtime exceeded. SQL queries over HTTP must not exceed {} seconds of runtime. Please consider using our websocket based connections",
-                config.http_config.request_timeout.as_secs_f64()
-            );
-            error!(message);
-            json_response(
-                StatusCode::BAD_REQUEST,
-                json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
-            )?
-        }
     };
 
     response.headers_mut().insert(
@@ -317,7 +346,93 @@ pub async fn handle(
     Ok(response)
 }
 
-struct Cancelled();
+#[derive(Debug, thiserror::Error)]
+pub enum SqlOverHttpError {
+    #[error("{0}")]
+    ReadPayload(#[from] ReadPayloadError),
+    #[error("{0}")]
+    ConnectCompute(#[from] HttpConnError),
+    #[error("{0}")]
+    ConnInfo(#[from] ConnInfoError),
+    #[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
+    RequestTooLarge,
+    #[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
+    ResponseTooLarge,
+    #[error("invalid isolation level")]
+    InvalidIsolationLevel,
+    #[error("{0}")]
+    Postgres(#[from] tokio_postgres::Error),
+    #[error("{0}")]
+    JsonConversion(#[from] JsonConversionError),
+    #[error("{0}")]
+    Cancelled(SqlOverHttpCancel),
+}
+
+impl ReportableError for SqlOverHttpError {
+    fn get_error_kind(&self) -> ErrorKind {
+        match self {
+            SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
+            SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
+            SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
+            SqlOverHttpError::RequestTooLarge => ErrorKind::User,
+            SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
+            SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
+            SqlOverHttpError::Postgres(p) => p.get_error_kind(),
+            SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
+            SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
+        }
+    }
+}
+
+impl UserFacingError for SqlOverHttpError {
+    fn to_string_client(&self) -> String {
+        match self {
+            SqlOverHttpError::ReadPayload(p) => p.to_string(),
+            SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
+            SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
+            SqlOverHttpError::RequestTooLarge => self.to_string(),
+            SqlOverHttpError::ResponseTooLarge => self.to_string(),
+            SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
+            SqlOverHttpError::Postgres(p) => p.to_string(),
+            SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
+            SqlOverHttpError::Cancelled(_) => self.to_string(),
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ReadPayloadError {
+    #[error("could not read the HTTP request body: {0}")]
+    Read(#[from] hyper::Error),
+    #[error("could not parse the HTTP request body: {0}")]
+    Parse(#[from] serde_json::Error),
+}
+
+impl ReportableError for ReadPayloadError {
+    fn get_error_kind(&self) -> ErrorKind {
+        match self {
+            ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
+            ReadPayloadError::Parse(_) => ErrorKind::User,
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SqlOverHttpCancel {
+    #[error("query was cancelled")]
+    Postgres,
+    #[error("query was cancelled while stuck trying to connect to the database")]
+    Connect,
+}
+
+impl ReportableError for SqlOverHttpCancel {
+    fn get_error_kind(&self) -> ErrorKind {
+        match self {
+            SqlOverHttpCancel::Postgres => ErrorKind::RateLimit,
+            SqlOverHttpCancel::Connect => ErrorKind::ServiceRateLimit,
+        }
+    }
+}
 
 async fn handle_inner(
     cancel: CancellationToken,
@@ -325,7 +440,7 @@ async fn handle_inner(
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
     request: Request<Body>,
     backend: Arc<PoolingBackend>,
-) -> Result<Result<Response<Body>, Cancelled>, anyhow::Error> {
+) -> Result<Response<Body>, SqlOverHttpError> {
     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
         .with_label_values(&[ctx.protocol])
         .guard();
@@ -358,7 +473,7 @@ async fn handle_inner(
                 b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
                 b"ReadCommitted" => IsolationLevel::ReadCommitted,
                 b"RepeatableRead" => IsolationLevel::RepeatableRead,
-                _ => bail!("invalid isolation level"),
+                _ => return Err(SqlOverHttpError::InvalidIsolationLevel),
             }),
             None => None,
         };
@@ -376,19 +491,16 @@ async fn handle_inner(
     // we don't have a streaming request support yet so this is to prevent OOM
     // from a malicious user sending an extremely large request body
     if request_content_length > MAX_REQUEST_SIZE {
-        return Err(anyhow::anyhow!(
-            "request is too large (max is {MAX_REQUEST_SIZE} bytes)"
-        ));
+        return Err(SqlOverHttpError::RequestTooLarge);
     }
 
     let fetch_and_process_request = async {
-        let body = hyper::body::to_bytes(request.into_body())
-            .await
-            .map_err(anyhow::Error::from)?;
+        let body = hyper::body::to_bytes(request.into_body()).await?;
         info!(length = body.len(), "request payload read");
         let payload: Payload = serde_json::from_slice(&body)?;
-        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
-    };
+        Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
+    }
+    .map_err(SqlOverHttpError::from);
 
     let authenticate_and_connect = async {
         let keys = backend.authenticate(ctx, &conn_info).await?;
@@ -398,8 +510,9 @@ async fn handle_inner(
         let client = backend
             .connect_to_compute(ctx, conn_info, keys, !allow_pool)
             .await?;
         // not strictly necessary to mark success here,
         // but it's just insurance for if we forget it somewhere else
         ctx.latency_timer.success();
-        Ok::<_, anyhow::Error>(client)
-    };
+        Ok::<_, HttpConnError>(client)
+    }
+    .map_err(SqlOverHttpError::from);
 
     // Run both operations in parallel
     let (payload, mut client) = match select(
@@ -412,7 +525,9 @@ async fn handle_inner(
         pin!(fetch_and_process_request),
         pin!(authenticate_and_connect),
     )
     .await
     {
         Either::Left((result, _cancelled)) => result?,
-        Either::Right((_cancelled, _)) => return Ok(Err(Cancelled())),
+        Either::Right((_cancelled, _)) => {
+            return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect))
+        }
     };
 
     let mut response = Response::builder()
@@ -456,20 +571,24 @@ async fn handle_inner(
                     results
                 }
                 Ok(Err(error)) => {
-                    let db_error = error
-                        .downcast_ref::<tokio_postgres::Error>()
-                        .and_then(|e| e.as_db_error());
+                    let db_error = match &error {
+                        SqlOverHttpError::ConnectCompute(
+                            HttpConnError::ConnectionError(e),
+                        )
+                        | SqlOverHttpError::Postgres(e) => e.as_db_error(),
+                        _ => None,
+                    };
 
                     // if errored for some other reason, it might not be safe to return
                     if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
                         discard.discard();
                     }
-                    return Ok(Err(Cancelled()));
+                    return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
                 }
                 Err(_timeout) => {
                     discard.discard();
-                    return Ok(Err(Cancelled()));
+                    return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
                 }
             }
         }
@@ -507,7 +626,7 @@ async fn handle_inner(
         )
         .await
         {
-            Ok(Ok(results)) => {
+            Ok(results) => {
                 info!("commit");
                 let status = transaction.commit().await.map_err(|e| {
                     // if we cannot commit - for now don't return connection to pool
@@ -518,14 +637,14 @@ async fn handle_inner(
                 discard.check_idle(status);
                 results
             }
-            Ok(Err(Cancelled())) => {
+            Err(SqlOverHttpError::Cancelled(_)) => {
                 if let Err(err) = cancel_token.cancel_query(NoTls).await {
                     tracing::error!(?err, "could not cancel query");
                 }
                 // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
                 discard.discard();
-                return Ok(Err(Cancelled()));
+                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
             }
             Err(err) => {
                 info!("rollback");
@@ -541,16 +660,10 @@ async fn handle_inner(
     };
 
     if txn_read_only {
-        response = response.header(
-            TXN_READ_ONLY.clone(),
-            HeaderValue::try_from(txn_read_only.to_string())?,
-        );
+        response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
     }
     if txn_deferrable {
-        response = response.header(
-            TXN_DEFERRABLE.clone(),
-            HeaderValue::try_from(txn_deferrable.to_string())?,
-        );
+        response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
     }
     if let Some(txn_isolation_level) = txn_isolation_level_raw {
         response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
     }
@@ -574,7 +687,7 @@ async fn handle_inner(
     // moving this later in the stack is going to be a lot of effort and ehhhh
     metrics.record_egress(len as u64);
 
-    Ok(Ok(response))
+    Ok(response)
 }
 
 async fn query_batch(
@@ -584,7 +697,7 @@ async fn query_batch(
     total_size: &mut usize,
     raw_output: bool,
     array_mode: bool,
-) -> anyhow::Result<Result<Vec<Value>, Cancelled>> {
+) -> Result<Vec<Value>, SqlOverHttpError> {
     let mut results = Vec::with_capacity(queries.queries.len());
     let mut current_size = 0;
     for stmt in queries.queries {
@@ -606,12 +719,12 @@ async fn query_batch(
                 return Err(e);
             }
             Either::Right((_cancelled, _)) => {
-                return Ok(Err(Cancelled()));
+                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
             }
         }
     }
     *total_size += current_size;
-    Ok(Ok(results))
+    Ok(results)
 }
 
 async fn query_to_json(
@@ -620,7 +733,7 @@ async fn query_to_json(
     current_size: &mut usize,
     raw_output: bool,
     default_array_mode: bool,
-) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
+) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
     info!("executing query");
     let query_params = data.params;
     let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
@@ -637,9 +750,7 @@ async fn query_to_json(
         // we don't have a streaming response support yet so this is to prevent OOM
         // from a malicious query (eg a cross join)
         if *current_size > MAX_RESPONSE_SIZE {
-            return Err(anyhow::anyhow!(
-                "response is too large (max is {MAX_RESPONSE_SIZE} bytes)"
-            ));
+            return Err(SqlOverHttpError::ResponseTooLarge);
         }
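Note on the classification pattern this diff introduces: `ReportableError` assigns every error an `ErrorKind` for metrics (the new log lines use `error_kind.to_metric_label()`), `UserFacingError::to_string_client` controls what is safe to show the client, and wrapper errors delegate both to their inner error where one exists. Below is a minimal, self-contained sketch of that shape. The trimmed-down `ErrorKind` and the `InnerError`/`OuterError` types are hypothetical stand-ins; the trait signatures are inferred from this diff, not copied from `proxy/src/error.rs`.

```rust
use std::fmt;

/// Trimmed-down stand-in for the proxy's ErrorKind (the real enum has more variants).
#[derive(Clone, Copy, Debug)]
enum ErrorKind {
    User,
    Compute,
    RateLimit,
}

impl ErrorKind {
    /// The diff logs `kind=error_kind.to_metric_label()`; assume it maps each
    /// variant to a static label usable as a metrics dimension.
    fn to_metric_label(self) -> &'static str {
        match self {
            ErrorKind::User => "user",
            ErrorKind::Compute => "compute",
            ErrorKind::RateLimit => "ratelimit",
        }
    }
}

/// Classification for metrics and alerting.
trait ReportableError: fmt::Display {
    fn get_error_kind(&self) -> ErrorKind;
}

/// What the client is allowed to see; defaults to the Display output.
trait UserFacingError: ReportableError {
    fn to_string_client(&self) -> String {
        self.to_string()
    }
}

/// Hypothetical leaf error, playing the role of e.g. WakeComputeError.
#[derive(Debug, thiserror::Error)]
enum InnerError {
    #[error("compute node went away")]
    ComputeGone,
    #[error("too many requests")]
    TooManyRequests,
}

impl ReportableError for InnerError {
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            InnerError::ComputeGone => ErrorKind::Compute,
            InnerError::TooManyRequests => ErrorKind::RateLimit,
        }
    }
}

impl UserFacingError for InnerError {}

/// Hypothetical wrapper, playing the role of HttpConnError/SqlOverHttpError:
/// it classifies its own variants and delegates wrapped ones to the source.
#[derive(Debug, thiserror::Error)]
enum OuterError {
    #[error("{0}")]
    Inner(#[from] InnerError),
    #[error("bad request")]
    BadInput,
}

impl ReportableError for OuterError {
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            OuterError::Inner(e) => e.get_error_kind(),
            OuterError::BadInput => ErrorKind::User,
        }
    }
}

impl UserFacingError for OuterError {
    fn to_string_client(&self) -> String {
        match self {
            OuterError::Inner(e) => e.to_string_client(),
            OuterError::BadInput => self.to_string(),
        }
    }
}

fn main() {
    let err = OuterError::from(InnerError::TooManyRequests);
    // Prints: kind=ratelimit client_message=too many requests
    println!(
        "kind={} client_message={}",
        err.get_error_kind().to_metric_label(),
        err.to_string_client()
    );
}
```

The point of the delegation arms is that a classification made at the leaf survives arbitrarily deep wrapping, which is exactly what the `HttpConnError` and `SqlOverHttpError` impls above rely on.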
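The diff also replaces the old `struct Cancelled();` marker with typed cancellation: the work future is raced against a `CancellationToken` via `futures::future::select`, and a loss is recorded as `SqlOverHttpCancel::Connect` or `SqlOverHttpCancel::Postgres` so the two cases classify differently (`ServiceRateLimit` vs `RateLimit`). Here is a compressed sketch of that control flow, using only crates the diff already imports (`futures`, `tokio`, `tokio-util`, `thiserror`); `do_connect` is a stub standing in for the authenticate-and-connect work, and the error enums are abbreviated.

```rust
use std::pin::pin;

use futures::future::{select, Either};
use futures::TryFutureExt;
use tokio_util::sync::CancellationToken;

#[derive(Debug, thiserror::Error)]
enum SqlOverHttpCancel {
    #[error("query was cancelled")]
    Postgres,
    #[error("query was cancelled while stuck trying to connect to the database")]
    Connect,
}

#[derive(Debug, thiserror::Error)]
enum SqlOverHttpError {
    #[error("{0}")]
    Cancelled(SqlOverHttpCancel),
    #[error("connect failed: {0}")]
    Connect(String),
}

/// Stub standing in for the real connect path; pretend it can fail.
async fn do_connect() -> Result<u32, String> {
    Ok(42)
}

async fn connect_or_cancel(cancel: CancellationToken) -> Result<u32, SqlOverHttpError> {
    // Same trick as the diff's `.map_err(SqlOverHttpError::from)`: lift the
    // future's concrete error into the shared error type before racing.
    let connect = do_connect().map_err(SqlOverHttpError::Connect);
    match select(pin!(connect), pin!(cancel.cancelled())).await {
        // The work finished first; surface its result as-is.
        Either::Left((result, _not_cancelled)) => result,
        // The token fired first; record which phase was cancelled.
        Either::Right(((), _connect)) => {
            Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect))
        }
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    println!("{:?}", connect_or_cancel(cancel).await);
}
```

This is also why the new `use futures::TryFutureExt;` import appears in the diff: `map_err` on a future converts the async block's concrete error type (`ReadPayloadError`, `HttpConnError`) into `SqlOverHttpError` up front, so both sides of the `select` agree on a single error type.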