fix(proxy): delete prepared statements when discarding (#11165)

Fixes https://github.com/neondatabase/serverless/issues/144

When tables contain enum columns, the driver has to run extra type queries to
decode that data, and it caches the resulting prepared statements for
performance. In Neon RLS, we run "discard all" between sessions for security,
which deallocates every prepared statement on the server. The next time we
need to look up a type, the cached statements are no longer valid.

This fixes it by discarding the cached statements as well.
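For context, the failure mode is easy to reproduce with plain tokio-postgres
(the crate the vendored postgres_client is patched from). A minimal sketch,
assuming the tokio and tokio-postgres crates and a reachable local Postgres;
not part of this commit:

    // Repro sketch: DISCARD ALL invalidates server-side prepared statements,
    // but the driver's client-side cache still hands out the stale handle.
    use tokio_postgres::NoTls;

    #[tokio::main]
    async fn main() -> Result<(), tokio_postgres::Error> {
        let (client, conn) =
            tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
        tokio::spawn(conn);

        // cached as a named prepared statement on the server
        let stmt = client.prepare("select $1::int4").await?;

        // deallocates every prepared statement server-side
        client.batch_execute("discard all").await?;

        // fails with "prepared statement ... does not exist"
        let err = client.query(&stmt, &[&1i32]).await.unwrap_err();
        println!("{err}");
        Ok(())
    }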

I've also added some new logs and error types to monitor this; currently the
prepared statement errors don't show up in our logs at all.
Conrad Ludgate (committed via GitHub) · 2025-03-11 10:48:50 +00:00
parent 7c462b3417 · commit d1b60fa0b6
3 changed files with 87 additions and 34 deletions


@@ -35,6 +35,7 @@ use super::conn_pool_lib::{
     Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
     EndpointConnPool,
 };
+use super::sql_over_http::SqlOverHttpError;
 use crate::context::RequestContext;
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
 use crate::metrics::Metrics;
@@ -274,18 +275,23 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
 }
 
 impl ClientInnerCommon<postgres_client::Client> {
-    pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
+    pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), SqlOverHttpError> {
         if let ClientDataEnum::Local(local_data) = &mut self.data {
             local_data.jti += 1;
             let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
 
             // discard all cannot run in a transaction. must be executed alone.
-            self.inner.batch_execute("discard all").await?;
+            self.inner
+                .discard_all()
+                .await
+                .map_err(SqlOverHttpError::InternalPostgres)?;
 
             // initiates the auth session
             // this is safe from query injections as the jwt format is free of any escape characters.
             let query = format!("select auth.jwt_session_init('{token}')");
-            self.inner.batch_execute(&query).await?;
+            self.inner
+                .batch_execute(&query)
+                .await
+                .map_err(SqlOverHttpError::InternalPostgres)?;
 
             let pid = self.inner.get_process_id();
             info!(pid, jti = local_data.jti, "user session state init");

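The body of the new discard_all helper lives in the vendored postgres_client
crate, so it isn't visible in this diff. Going by the commit message, it has
to pair the server-side DISCARD ALL with dropping the client-side statement
cache. A hypothetical sketch of that pairing, written against plain
tokio-postgres (whose clear_type_cache is the analogous cache-clearing hook);
not the vendored implementation:

    use tokio_postgres::{Client, Error};

    // Hypothetical sketch: keep the client cache and the server in
    // agreement about which prepared statements still exist.
    async fn discard_all(client: &Client) -> Result<(), Error> {
        // forget cached typeinfo/prepared statements client-side...
        client.clear_type_cache();
        // ...then deallocate everything server-side
        client.batch_execute("discard all").await
    }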

@@ -412,8 +412,12 @@ pub(crate) enum SqlOverHttpError {
     ResponseTooLarge(usize),
     #[error("invalid isolation level")]
     InvalidIsolationLevel,
+    /// for queries our customers choose to run
     #[error("{0}")]
-    Postgres(#[from] postgres_client::Error),
+    Postgres(#[source] postgres_client::Error),
+    /// for queries we choose to run
+    #[error("{0}")]
+    InternalPostgres(#[source] postgres_client::Error),
     #[error("{0}")]
     JsonConversion(#[from] JsonConversionError),
     #[error("{0}")]
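Dropping #[from] on Postgres is forced by the new variant: thiserror derives a
blanket From impl for #[from], and two variants wrapping the same
postgres_client::Error would make that conversion ambiguous, so call sites now
pick a variant explicitly via map_err. A minimal illustration of the pattern,
using io::Error as a stand-in for postgres_client::Error:

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum DemoError {
        /// for queries the customer runs
        #[error("{0}")]
        Postgres(#[source] std::io::Error),
        /// for queries the proxy itself runs
        #[error("{0}")]
        InternalPostgres(#[source] std::io::Error),
    }

    fn main() {
        // two variants wrap the same source type, so `#[from]` (a blanket
        // `From` impl) cannot exist for both; call sites choose explicitly:
        let res: Result<(), DemoError> = Err(std::io::Error::other("boom"))
            .map_err(DemoError::InternalPostgres);
        println!("{}", res.unwrap_err());
    }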
@@ -429,6 +433,13 @@ impl ReportableError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
             SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
             SqlOverHttpError::Postgres(p) => p.get_error_kind(),
+            SqlOverHttpError::InternalPostgres(p) => {
+                if p.as_db_error().is_some() {
+                    ErrorKind::Service
+                } else {
+                    ErrorKind::Compute
+                }
+            }
             SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
             SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
         }
@@ -444,6 +455,7 @@ impl UserFacingError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => self.to_string(),
             SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
             SqlOverHttpError::Postgres(p) => p.to_string(),
+            SqlOverHttpError::InternalPostgres(p) => p.to_string(),
             SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
             SqlOverHttpError::Cancelled(_) => self.to_string(),
         }
@@ -462,6 +474,7 @@ impl HttpCodeError for SqlOverHttpError {
             SqlOverHttpError::ResponseTooLarge(_) => StatusCode::INSUFFICIENT_STORAGE,
             SqlOverHttpError::InvalidIsolationLevel => StatusCode::BAD_REQUEST,
             SqlOverHttpError::Postgres(_) => StatusCode::BAD_REQUEST,
+            SqlOverHttpError::InternalPostgres(_) => StatusCode::INTERNAL_SERVER_ERROR,
             SqlOverHttpError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
             SqlOverHttpError::Cancelled(_) => StatusCode::INTERNAL_SERVER_ERROR,
         }
@@ -671,16 +684,14 @@ async fn handle_db_inner(
     let authenticate_and_connect = Box::pin(
         async {
             let keys = match auth {
-                AuthData::Password(pw) => {
-                    backend
-                        .authenticate_with_password(ctx, &conn_info.user_info, &pw)
-                        .await?
-                }
-                AuthData::Jwt(jwt) => {
-                    backend
-                        .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
-                        .await?
-                }
+                AuthData::Password(pw) => backend
+                    .authenticate_with_password(ctx, &conn_info.user_info, &pw)
+                    .await
+                    .map_err(HttpConnError::AuthError)?,
+                AuthData::Jwt(jwt) => backend
+                    .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
+                    .await
+                    .map_err(HttpConnError::AuthError)?,
             };
 
             let client = match keys.keys {
@@ -703,7 +714,7 @@ async fn handle_db_inner(
             // not strictly necessary to mark success here,
             // but it's just insurance for if we forget it somewhere else
             ctx.success();
-            Ok::<_, HttpConnError>(client)
+            Ok::<_, SqlOverHttpError>(client)
         }
         .map_err(SqlOverHttpError::from),
     );
@@ -933,11 +944,15 @@ impl BatchQueryData {
             builder = builder.deferrable(true);
         }
 
-        let transaction = builder.start().await.inspect_err(|_| {
-            // if we cannot start a transaction, we should return immediately
-            // and not return to the pool. connection is clearly broken
-            discard.discard();
-        })?;
+        let transaction = builder
+            .start()
+            .await
+            .inspect_err(|_| {
+                // if we cannot start a transaction, we should return immediately
+                // and not return to the pool. connection is clearly broken
+                discard.discard();
+            })
+            .map_err(SqlOverHttpError::Postgres)?;
 
         let json_output = match query_batch(
             config,
@@ -950,11 +965,15 @@ impl BatchQueryData {
         {
             Ok(json_output) => {
                 info!("commit");
-                let status = transaction.commit().await.inspect_err(|_| {
-                    // if we cannot commit - for now don't return connection to pool
-                    // TODO: get a query status from the error
-                    discard.discard();
-                })?;
+                let status = transaction
+                    .commit()
+                    .await
+                    .inspect_err(|_| {
+                        // if we cannot commit - for now don't return connection to pool
+                        // TODO: get a query status from the error
+                        discard.discard();
+                    })
+                    .map_err(SqlOverHttpError::Postgres)?;
                 discard.check_idle(status);
                 json_output
             }
@@ -969,11 +988,15 @@ impl BatchQueryData {
             }
             Err(err) => {
                 info!("rollback");
-                let status = transaction.rollback().await.inspect_err(|_| {
-                    // if we cannot rollback - for now don't return connection to pool
-                    // TODO: get a query status from the error
-                    discard.discard();
-                })?;
+                let status = transaction
+                    .rollback()
+                    .await
+                    .inspect_err(|_| {
+                        // if we cannot rollback - for now don't return connection to pool
+                        // TODO: get a query status from the error
+                        discard.discard();
+                    })
+                    .map_err(SqlOverHttpError::Postgres)?;
                 discard.check_idle(status);
                 return Err(err);
             }
@@ -1032,7 +1055,12 @@ async fn query_to_json<T: GenericClient>(
     let query_start = Instant::now();
 
     let query_params = data.params;
-    let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
+    let mut row_stream = std::pin::pin!(
+        client
+            .query_raw_txt(&data.query, query_params)
+            .await
+            .map_err(SqlOverHttpError::Postgres)?
+    );
 
     let query_acknowledged = Instant::now();
@@ -1040,7 +1068,7 @@ async fn query_to_json<T: GenericClient>(
     // big.
     let mut rows: Vec<postgres_client::Row> = Vec::new();
     while let Some(row) = row_stream.next().await {
-        let row = row?;
+        let row = row.map_err(SqlOverHttpError::Postgres)?;
         *current_size += row.body_len();
         rows.push(row);
         // we don't have streaming response support yet so this is to prevent OOM
@@ -1091,7 +1119,14 @@ async fn query_to_json<T: GenericClient>(
             "dataTypeModifier": c.type_modifier(),
             "format": "text",
         }));
-        columns.push(client.get_type(c.type_oid()).await?);
+        match client.get_type(c.type_oid()).await {
+            Ok(t) => columns.push(t),
+            Err(err) => {
+                tracing::warn!(?err, "unable to query type information");
+                return Err(SqlOverHttpError::InternalPostgres(err));
+            }
+        }
     }
 
     let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
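The get_type calls above are exactly the cached "type queries" from the commit
message: for OIDs the driver doesn't recognize (enums, domains, composites),
it prepares and caches a pg_catalog lookup. In tokio-postgres, which the
vendored client is based on, that lookup is roughly the statement below; after
DISCARD ALL, the cached prepared statement for it no longer exists
server-side, which is the error this commit now logs and classifies. A sketch
of that query (the vendored postgres_client may differ in detail):

    // The typeinfo lookup the driver prepares once and then reuses from its
    // client-side cache. DISCARD ALL deallocates it server-side while the
    // cache keeps handing out the stale statement handle.
    const TYPEINFO_QUERY: &str =
        "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype,
                n.nspname, t.typrelid
         FROM pg_catalog.pg_type t
         LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
         INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
         WHERE t.oid = $1";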