From d1b60fa0b69dde210ec449062b0565cb4c1889a8 Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Tue, 11 Mar 2025 10:48:50 +0000
Subject: [PATCH] fix(proxy): delete prepared statements when discarding
 (#11165)

Fixes https://github.com/neondatabase/serverless/issues/144

When tables have enums, we need to perform type queries for that data.
We cache these type-query statements for performance reasons. In Neon
RLS, we run "discard all" for security reasons, which deallocates all
prepared statements on the server. When we next need to type check, the
cached statements are no longer valid. This fixes it by also clearing
the client-side statement cache when we run "discard all".

I've also added some new logs and error types to monitor this, since we
currently don't see these prepared-statement errors in our logs.
---
 libs/proxy/tokio-postgres2/src/client.rs | 12 +++
 proxy/src/serverless/local_conn_pool.rs  | 14 +++-
 proxy/src/serverless/sql_over_http.rs    | 95 ++++++++++++++++--------
 3 files changed, 87 insertions(+), 34 deletions(-)

diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs
index c70cb598de..08a06163e1 100644
--- a/libs/proxy/tokio-postgres2/src/client.rs
+++ b/libs/proxy/tokio-postgres2/src/client.rs
@@ -284,6 +284,18 @@ impl Client {
         simple_query::batch_execute(self.inner(), query).await
     }
 
+    pub async fn discard_all(&self) -> Result<(), Error> {
+        // clear the prepared statements that are about to be nuked from the postgres session
+        {
+            let mut typeinfo = self.inner.cached_typeinfo.lock();
+            typeinfo.typeinfo = None;
+            typeinfo.typeinfo_composite = None;
+            typeinfo.typeinfo_enum = None;
+        }
+
+        self.batch_execute("discard all").await
+    }
+
     /// Begins a new database transaction.
     ///
     /// The transaction will roll back by default - use the `commit` method to commit it.
diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs
index 8426a0810e..c958d077fc 100644
--- a/proxy/src/serverless/local_conn_pool.rs
+++ b/proxy/src/serverless/local_conn_pool.rs
@@ -35,6 +35,7 @@ use super::conn_pool_lib::{
     Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
     EndpointConnPool,
 };
+use super::sql_over_http::SqlOverHttpError;
 use crate::context::RequestContext;
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
 use crate::metrics::Metrics;
@@ -274,18 +275,23 @@ pub(crate) fn poll_client(
 }
 
 impl ClientInnerCommon {
-    pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
+    pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), SqlOverHttpError> {
         if let ClientDataEnum::Local(local_data) = &mut self.data {
             local_data.jti += 1;
             let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
 
-            // discard all cannot run in a transaction. must be executed alone.
-            self.inner.batch_execute("discard all").await?;
+            self.inner
+                .discard_all()
+                .await
+                .map_err(SqlOverHttpError::InternalPostgres)?;
 
             // initiates the auth session
             // this is safe from query injections as the jwt format free of any escape characters.
             let query = format!("select auth.jwt_session_init('{token}')");
-            self.inner.batch_execute(&query).await?;
+            self.inner
+                .batch_execute(&query)
+                .await
+                .map_err(SqlOverHttpError::InternalPostgres)?;
 
             let pid = self.inner.get_process_id();
             info!(pid, jti = local_data.jti, "user session state init");
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 93dd531f70..612702231f 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -412,8 +412,12 @@
     ResponseTooLarge(usize),
     #[error("invalid isolation level")]
     InvalidIsolationLevel,
+    /// for queries our customers choose to run
     #[error("{0}")]
-    Postgres(#[from] postgres_client::Error),
+    Postgres(#[source] postgres_client::Error),
+    /// for queries we choose to run
+    #[error("{0}")]
+    InternalPostgres(#[source] postgres_client::Error),
     #[error("{0}")]
     JsonConversion(#[from] JsonConversionError),
     #[error("{0}")]
@@ -429,6 +433,13 @@ impl ReportableError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
             SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
             SqlOverHttpError::Postgres(p) => p.get_error_kind(),
+            SqlOverHttpError::InternalPostgres(p) => {
+                if p.as_db_error().is_some() {
+                    ErrorKind::Service
+                } else {
+                    ErrorKind::Compute
+                }
+            }
             SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
             SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
         }
@@ -444,6 +455,7 @@ impl UserFacingError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
             SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
             SqlOverHttpError::Postgres(p) => p.to_string(),
+            SqlOverHttpError::InternalPostgres(p) => p.to_string(),
             SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
             SqlOverHttpError::Cancelled(_) => self.to_string(),
         }
@@ -462,6 +474,7 @@ impl HttpCodeError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
             SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
             SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST,
+            SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
             SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
             SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
         }
@@ -671,16 +684,14 @@ async fn handle_db_inner(
     let authenticate_and_connect = Box::pin(
         async {
             let keys = match auth {
-                AuthData::Password(pw) => {
-                    backend
-                        .authenticate_with_password(ctx, &conn_info.user_info, &pw)
-                        .await?
-                }
-                AuthData::Jwt(jwt) => {
-                    backend
-                        .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
-                        .await?
-                }
+                AuthData::Password(pw) => backend
+                    .authenticate_with_password(ctx, &conn_info.user_info, &pw)
+                    .await
+                    .map_err(HttpConnError::AuthError)?,
+                AuthData::Jwt(jwt) => backend
+                    .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
+                    .await
+                    .map_err(HttpConnError::AuthError)?,
             };
 
             let client = match keys.keys {
@@ -703,7 +714,7 @@
             // not strictly necessary to mark success here,
             // but it's just insurance for if we forget it somewhere else
             ctx.success();
-            Ok::<_, HttpConnError>(client)
+            Ok::<_, SqlOverHttpError>(client)
         }
         .map_err(SqlOverHttpError::from),
     );
@@ -933,11 +944,15 @@ impl BatchQueryData {
             builder = builder.deferrable(true);
         }
 
-        let transaction = builder.start().await.inspect_err(|_| {
-            // if we cannot start a transaction, we should return immediately
-            // and not return to the pool. connection is clearly broken
-            discard.discard();
-        })?;
+        let transaction = builder
+            .start()
+            .await
+            .inspect_err(|_| {
+                // if we cannot start a transaction, we should return immediately
+                // and not return to the pool. connection is clearly broken
+                discard.discard();
+            })
+            .map_err(SqlOverHttpError::Postgres)?;
 
         let json_output = match query_batch(
             config,
@@ -950,11 +965,15 @@
        {
            Ok(json_output) => {
                info!("commit");
-                let status = transaction.commit().await.inspect_err(|_| {
-                    // if we cannot commit - for now don't return connection to pool
-                    // TODO: get a query status from the error
-                    discard.discard();
-                })?;
+                let status = transaction
+                    .commit()
+                    .await
+                    .inspect_err(|_| {
+                        // if we cannot commit - for now don't return connection to pool
+                        // TODO: get a query status from the error
+                        discard.discard();
+                    })
+                    .map_err(SqlOverHttpError::Postgres)?;
                discard.check_idle(status);
                json_output
            }
@@ -969,11 +988,15 @@
            }
            Err(err) => {
                info!("rollback");
-                let status = transaction.rollback().await.inspect_err(|_| {
-                    // if we cannot rollback - for now don't return connection to pool
-                    // TODO: get a query status from the error
-                    discard.discard();
-                })?;
+                let status = transaction
+                    .rollback()
+                    .await
+                    .inspect_err(|_| {
+                        // if we cannot rollback - for now don't return connection to pool
+                        // TODO: get a query status from the error
+                        discard.discard();
+                    })
+                    .map_err(SqlOverHttpError::Postgres)?;
                discard.check_idle(status);
                return Err(err);
            }
@@ -1032,7 +1055,12 @@ async fn query_to_json(
 
     let query_start = Instant::now();
     let query_params = data.params;
-    let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
+    let mut row_stream = std::pin::pin!(
+        client
+            .query_raw_txt(&data.query, query_params)
+            .await
+            .map_err(SqlOverHttpError::Postgres)?
+    );
     let query_acknowledged = Instant::now();
 
     // Manually drain the stream into a vector to leave row_stream hanging
     // out of scope to be able to call get_type and not receive an error that we
     // are busy between rows iteration. Bind, executes, rows collecting are
     // big.
     let mut rows: Vec<postgres_client::Row> = Vec::new();
     while let Some(row) = row_stream.next().await {
-        let row = row?;
+        let row = row.map_err(SqlOverHttpError::Postgres)?;
         *current_size += row.body_len();
         rows.push(row);
         // we don't have a streaming response support yet so this is to prevent OOM
@@ -1091,7 +1119,14 @@
             "dataTypeModifier": c.type_modifier(),
             "format": "text",
         }));
-        columns.push(client.get_type(c.type_oid()).await?);
+
+        match client.get_type(c.type_oid()).await {
+            Ok(t) => columns.push(t),
+            Err(err) => {
+                tracing::warn!(?err, "unable to query type information");
+                return Err(SqlOverHttpError::InternalPostgres(err));
+            }
+        }
     }
 
     let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
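
For reviewers, here is the failure mode in miniature. The sketch below is hand-written against vanilla tokio-postgres and is not part of this patch: the `TypeQueryCache` type and the query it prepares are hypothetical stand-ins for the typeinfo statements held in `cached_typeinfo`. The point is the ordering: client-held statement handles have to be dropped before "discard all" deallocates them on the server, otherwise the next type lookup reuses a stale handle and the server answers with a "prepared statement ... does not exist" error.

```rust
// Sketch only: mirrors the invariant enforced by the new Client::discard_all,
// written against vanilla tokio-postgres. Names below are illustrative, not proxy code.
use tokio_postgres::{Client, Error, Statement};

/// Hypothetical stand-in for the proxy's cached typeinfo statements.
struct TypeQueryCache {
    typeinfo: Option<Statement>,
}

impl TypeQueryCache {
    /// Reuse the cached lookup statement, preparing it on first use.
    async fn typeinfo(&mut self, client: &Client) -> Result<Statement, Error> {
        if let Some(s) = &self.typeinfo {
            // If "discard all" ran since this was prepared, the server has
            // already deallocated it; executing it fails with
            // "prepared statement ... does not exist".
            return Ok(s.clone());
        }
        let s = client
            .prepare("select typname from pg_type where oid = $1")
            .await?;
        self.typeinfo = Some(s.clone());
        Ok(s)
    }

    /// The fix in miniature: drop cached handles *before* the server
    /// deallocates them, so the next lookup re-prepares the statement.
    async fn discard_all(&mut self, client: &Client) -> Result<(), Error> {
        self.typeinfo = None;
        // "discard all" cannot run inside a transaction; issue it on its own.
        client.batch_execute("discard all").await
    }
}
```

In the proxy itself the same ordering lives in the new `Client::discard_all` helper, which clears `cached_typeinfo` before sending the command, and `set_jwt_session` now calls that helper instead of issuing "discard all" directly.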