fix(proxy): forward notifications from authentication (#9948)

Fixes https://github.com/neondatabase/cloud/issues/20973. 

This refactors `connect_raw` in order to return direct access to the
delayed notices.

I cannot find a way to test this with psycopg2 unfortunately, although
testing it with psql does return the expected results.
This commit is contained in:
Conrad Ludgate
2024-12-02 12:29:57 +00:00
committed by GitHub
parent bd09369198
commit cd1d2d1996
11 changed files with 117 additions and 58 deletions

View File

@@ -6,6 +6,7 @@ use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::client::danger::ServerCertVerifier;
use rustls::crypto::ring;
@@ -13,6 +14,7 @@ use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{CancelToken, RawConnection};
use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
@@ -277,6 +279,8 @@ pub(crate) struct PostgresConnection {
pub(crate) cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
/// Notices received from compute after authenticating
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
_guage: NumDbConnectionsGuard<'static>,
}
@@ -322,10 +326,19 @@ impl ConnCfg {
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
let connection = self.0.connect_raw(stream, tls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
let stream = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
@@ -334,18 +347,23 @@ impl ConnCfg {
self.0.get_ssl_mode()
);
// This is very ugly but as of now there's no better way to
// extract the connection parameters from tokio-postgres' connection.
// TODO: solve this problem in a more elegant manner (e.g. the new library).
let params = connection.parameters;
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.0.get_ssl_mode(),
process_id,
secret_key,
},
vec![],
);
let connection = PostgresConnection {
stream,
params,
params: parameters,
delayed_notice,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),

View File

@@ -384,11 +384,13 @@ pub(crate) async fn prepare_client_connection<P>(
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
}
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficent (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),

View File

@@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
@@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
@@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
Scram::new(password).await?,
));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
@@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")