diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index c70cb598de..08a06163e1 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -284,6 +284,18 @@ impl Client { simple_query::batch_execute(self.inner(), query).await } + pub async fn discard_all(&self) -> Result { + // clear the prepared statements that are about to be nuked from the postgres session + { + let mut typeinfo = self.inner.cached_typeinfo.lock(); + typeinfo.typeinfo = None; + typeinfo.typeinfo_composite = None; + typeinfo.typeinfo_enum = None; + } + + self.batch_execute("discard all").await + } + /// Begins a new database transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 8426a0810e..c958d077fc 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -35,6 +35,7 @@ use super::conn_pool_lib::{ Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn, EndpointConnPool, }; +use super::sql_over_http::SqlOverHttpError; use crate::context::RequestContext; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::Metrics; @@ -274,18 +275,23 @@ pub(crate) fn poll_client( } impl ClientInnerCommon { - pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { + pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), SqlOverHttpError> { if let ClientDataEnum::Local(local_data) = &mut self.data { local_data.jti += 1; let token = resign_jwt(&local_data.key, payload, local_data.jti)?; - // discard all cannot run in a transaction. must be executed alone. - self.inner.batch_execute("discard all").await?; + self.inner + .discard_all() + .await + .map_err(SqlOverHttpError::InternalPostgres)?; // initiates the auth session // this is safe from query injections as the jwt format free of any escape characters. let query = format!("select auth.jwt_session_init('{token}')"); - self.inner.batch_execute(&query).await?; + self.inner + .batch_execute(&query) + .await + .map_err(SqlOverHttpError::InternalPostgres)?; let pid = self.inner.get_process_id(); info!(pid, jti = local_data.jti, "user session state init"); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 93dd531f70..612702231f 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -412,8 +412,12 @@ pub(crate) enum SqlOverHttpError { ResponseTooLarge(usize), #[error("invalid isolation level")] InvalidIsolationLevel, + /// for queries our customers choose to run #[error("{0}")] - Postgres(#[from] postgres_client::Error), + Postgres(#[source] postgres_client::Error), + /// for queries we choose to run + #[error("{0}")] + InternalPostgres(#[source] postgres_client::Error), #[error("{0}")] JsonConversion(#[from] JsonConversionError), #[error("{0}")] @@ -429,6 +433,13 @@ impl ReportableError for SqlOverHttpError { SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User, SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User, SqlOverHttpError::Postgres(p) => p.get_error_kind(), + SqlOverHttpError::InternalPostgres(p) => { + if p.as_db_error().is_some() { + ErrorKind::Service + } else { + ErrorKind::Compute + } + } SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres, SqlOverHttpError::Cancelled(c) => c.get_error_kind(), } @@ -444,6 +455,7 @@ impl UserFacingError for SqlOverHttpError { SqlOverHttpError::ResponseTooLarge(_) => self.to_string(), SqlOverHttpError::InvalidIsolationLevel => self.to_string(), SqlOverHttpError::Postgres(p) => p.to_string(), + SqlOverHttpError::InternalPostgres(p) => p.to_string(), SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(), SqlOverHttpError::Cancelled(_) => self.to_string(), } @@ -462,6 +474,7 @@ impl HttpCodeError for SqlOverHttpError { SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE, SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST, SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST, + SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR, SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR, SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR, } @@ -671,16 +684,14 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { let keys = match auth { - AuthData::Password(pw) => { - backend - .authenticate_with_password(ctx, &conn_info.user_info, &pw) - .await? - } - AuthData::Jwt(jwt) => { - backend - .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) - .await? - } + AuthData::Password(pw) => backend + .authenticate_with_password(ctx, &conn_info.user_info, &pw) + .await + .map_err(HttpConnError::AuthError)?, + AuthData::Jwt(jwt) => backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::AuthError)?, }; let client = match keys.keys { @@ -703,7 +714,7 @@ async fn handle_db_inner( // not strictly necessary to mark success here, // but it's just insurance for if we forget it somewhere else ctx.success(); - Ok::<_, HttpConnError>(client) + Ok::<_, SqlOverHttpError>(client) } .map_err(SqlOverHttpError::from), ); @@ -933,11 +944,15 @@ impl BatchQueryData { builder = builder.deferrable(true); } - let transaction = builder.start().await.inspect_err(|_| { - // if we cannot start a transaction, we should return immediately - // and not return to the pool. connection is clearly broken - discard.discard(); - })?; + let transaction = builder + .start() + .await + .inspect_err(|_| { + // if we cannot start a transaction, we should return immediately + // and not return to the pool. connection is clearly broken + discard.discard(); + }) + .map_err(SqlOverHttpError::Postgres)?; let json_output = match query_batch( config, @@ -950,11 +965,15 @@ impl BatchQueryData { { Ok(json_output) => { info!("commit"); - let status = transaction.commit().await.inspect_err(|_| { - // if we cannot commit - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - })?; + let status = transaction + .commit() + .await + .inspect_err(|_| { + // if we cannot commit - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + }) + .map_err(SqlOverHttpError::Postgres)?; discard.check_idle(status); json_output } @@ -969,11 +988,15 @@ impl BatchQueryData { } Err(err) => { info!("rollback"); - let status = transaction.rollback().await.inspect_err(|_| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - })?; + let status = transaction + .rollback() + .await + .inspect_err(|_| { + // if we cannot rollback - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + }) + .map_err(SqlOverHttpError::Postgres)?; discard.check_idle(status); return Err(err); } @@ -1032,7 +1055,12 @@ async fn query_to_json( let query_start = Instant::now(); let query_params = data.params; - let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?); + let mut row_stream = std::pin::pin!( + client + .query_raw_txt(&data.query, query_params) + .await + .map_err(SqlOverHttpError::Postgres)? + ); let query_acknowledged = Instant::now(); // Manually drain the stream into a vector to leave row_stream hanging @@ -1040,7 +1068,7 @@ async fn query_to_json( // big. let mut rows: Vec = Vec::new(); while let Some(row) = row_stream.next().await { - let row = row?; + let row = row.map_err(SqlOverHttpError::Postgres)?; *current_size += row.body_len(); rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM @@ -1091,7 +1119,14 @@ async fn query_to_json( "dataTypeModifier": c.type_modifier(), "format": "text", })); - columns.push(client.get_type(c.type_oid()).await?); + + match client.get_type(c.type_oid()).await { + Ok(t) => columns.push(t), + Err(err) => { + tracing::warn!(?err, "unable to query type information"); + return Err(SqlOverHttpError::InternalPostgres(err)); + } + } } let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);