diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 5dd264b35e..f6c58c7459 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -33,12 +33,51 @@ use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] -pub(crate) enum ConnectionError { +pub(crate) enum PostgresError { /// This error doesn't seem to reveal any secrets; for instance, /// `postgres_client::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] Postgres(#[from] postgres_client::Error), +} +impl UserFacingError for PostgresError { + fn to_string_client(&self) -> String { + match self { + // This helps us drop irrelevant library-specific prefixes. + // TODO: propagate severity level and other parameters. + PostgresError::Postgres(err) => match err.as_db_error() { + Some(err) => { + let msg = err.message(); + + if msg.starts_with("unsupported startup parameter: ") + || msg.starts_with("unsupported startup parameter in options: ") + { + format!( + "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter" + ) + } else { + msg.to_owned() + } + } + None => err.to_string(), + }, + } + } +} + +impl ReportableError for PostgresError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + PostgresError::Postgres(e) if e.as_db_error().is_some() => { + crate::error::ErrorKind::Postgres + } + PostgresError::Postgres(_) => crate::error::ErrorKind::Compute, + } + } +} + +#[derive(Debug, Error)] +pub(crate) enum ConnectionError { #[error("{COULD_NOT_CONNECT}: {0}")] TlsError(#[from] TlsError), @@ -52,22 +91,6 @@ pub(crate) enum ConnectionError { impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { match self { - // This helps us drop irrelevant library-specific prefixes. 
- // TODO: propagate severity level and other parameters. - ConnectionError::Postgres(err) => match err.as_db_error() { - Some(err) => { - let msg = err.message(); - - if msg.starts_with("unsupported startup parameter: ") - || msg.starts_with("unsupported startup parameter in options: ") - { - format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter") - } else { - msg.to_owned() - } - } - None => err.to_string(), - }, ConnectionError::WakeComputeError(err) => err.to_string_client(), ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() @@ -80,10 +103,6 @@ impl UserFacingError for ConnectionError { impl ReportableError for ConnectionError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - ConnectionError::Postgres(e) if e.as_db_error().is_some() => { - crate::error::ErrorKind::Postgres - } - ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -206,6 +225,54 @@ impl AuthInfo { } } } + + pub async fn authenticate( + &self, + ctx: &RequestContext, + compute: &mut ComputeConnection, + user_info: ComputeUserInfo, + ) -> Result { + // client config with stubbed connect info. + // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely, + // utilising pqproto.rs. + let mut tmp_config = postgres_client::Config::new(String::new(), 0); + // We have already established SSL if necessary. 
+ tmp_config.ssl_mode(SslMode::Disable); + let tmp_config = self.enrich(tmp_config); + + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let connection = tmp_config.connect_raw(&mut compute.stream, NoTls).await?; + drop(pause); + + let RawConnection { + stream: _, + parameters, + delayed_notice, + process_id, + secret_key, + } = connection; + + tracing::Span::current().record("pid", tracing::field::display(process_id)); + + // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. + // Yet another reason to rework the connection establishing code. + let cancel_closure = CancelClosure::new( + compute.socket_addr, + RawCancelToken { + ssl_mode: compute.ssl_mode, + process_id, + secret_key, + }, + compute.hostname.to_string(), + user_info, + ); + + Ok(PostgresSettings { + params: parameters, + cancel_closure, + delayed_notice, + }) + } } impl ConnectInfo { @@ -268,51 +335,42 @@ impl ConnectInfo { pub type RustlsStream = >::Stream; pub type MaybeRustlsStream = MaybeTlsStream; -pub(crate) struct PostgresConnection { - /// Socket connected to a compute node. - pub(crate) stream: MaybeTlsStream, +// TODO(conrad): we don't need to parse these. +// These are just immediately forwarded back to the client. +// We could instead stream them out instead of reading them into memory. +pub struct PostgresSettings { /// PostgreSQL connection parameters. - pub(crate) params: std::collections::HashMap, + pub params: std::collections::HashMap, /// Query cancellation token. - pub(crate) cancel_closure: CancelClosure, - /// Labels for proxy's metrics. - pub(crate) aux: MetricsAuxInfo, + pub cancel_closure: CancelClosure, /// Notices received from compute after authenticating - pub(crate) delayed_notice: Vec, + pub delayed_notice: Vec, +} - pub(crate) guage: NumDbConnectionsGuard<'static>, +pub struct ComputeConnection { + /// Socket connected to a compute node. + pub stream: MaybeTlsStream, + /// Labels for proxy's metrics. 
+ pub aux: MetricsAuxInfo, + pub hostname: Host, + pub ssl_mode: SslMode, + pub socket_addr: SocketAddr, + pub guage: NumDbConnectionsGuard<'static>, } impl ConnectInfo { /// Connect to a corresponding compute node. - pub(crate) async fn connect( + pub async fn connect( &self, ctx: &RequestContext, - aux: MetricsAuxInfo, - auth: &AuthInfo, + aux: &MetricsAuxInfo, config: &ComputeConfig, - user_info: ComputeUserInfo, - ) -> Result { - let mut tmp_config = auth.enrich(self.to_postgres_client_config()); - // we setup SSL early in `ConnectInfo::connect_raw`. - tmp_config.ssl_mode(SslMode::Disable); - + ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (socket_addr, stream) = self.connect_raw(config).await?; - let connection = tmp_config.connect_raw(stream, NoTls).await?; drop(pause); - let RawConnection { - stream, - parameters, - delayed_notice, - process_id, - secret_key, - } = connection; - - tracing::Span::current().record("pid", tracing::field::display(process_id)); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); - let MaybeTlsStream::Raw(stream) = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) info!( @@ -324,25 +382,12 @@ impl ConnectInfo { ctx.get_testodrome_id().unwrap_or_default(), ); - // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. - // Yet another reason to rework the connection establishing code. 
- let cancel_closure = CancelClosure::new( - socket_addr, - RawCancelToken { - ssl_mode: self.ssl_mode, - process_id, - secret_key, - }, - self.host.to_string(), - user_info, - ); - - let connection = PostgresConnection { + let connection = ComputeConnection { stream, - params: parameters, - delayed_notice, - cancel_closure, - aux, + socket_addr, + hostname: self.host.clone(), + ssl_mode: self.ssl_mode, + aux: aux.clone(), guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), }; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 89adfc9049..113a11beab 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -218,11 +218,9 @@ pub(crate) async fn handle_client( }; auth_info.set_startup_params(&params, true); - let node = connect_to_compute( + let mut node = connect_to_compute( ctx, &TcpMechanism { - user_info, - auth: auth_info, locks: &config.connect_compute_locks, }, &node_info, @@ -232,9 +230,14 @@ pub(crate) async fn handle_client( .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; + let pg_settings = auth_info + .authenticate(ctx, &mut node, user_info) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) + .await?; + let session = cancellation_handler.get_key(); - prepare_client_connection(&node, *session.key(), &mut stream); + prepare_client_connection(&pg_settings, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; let session_id = ctx.session_id(); @@ -244,7 +247,7 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &node.cancel_closure, + &pg_settings.cancel_closure, &config.connect_to_compute, ) .await; diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index ed83e98bfe..a8c59dad0c 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -76,13 +76,9 @@ impl NodeInfo { pub(crate) async fn connect( &self, ctx: 
&RequestContext, - auth: &compute::AuthInfo, config: &ComputeConfig, - user_info: ComputeUserInfo, - ) -> Result { - self.conn_info - .connect(ctx, self.aux.clone(), auth, config, user_info) - .await + ) -> Result { + self.conn_info.connect(ctx, &self.aux, config).await } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 92ed84f50f..aa675a439e 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -2,8 +2,7 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use crate::auth::backend::ComputeUserInfo; -use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection}; +use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; @@ -50,15 +49,13 @@ pub(crate) trait ConnectMechanism { } pub(crate) struct TcpMechanism { - pub(crate) auth: AuthInfo, /// connect_to_compute concurrency lock pub(crate) locks: &'static ApiLocks, - pub(crate) user_info: ComputeUserInfo, } #[async_trait] impl ConnectMechanism for TcpMechanism { - type Connection = PostgresConnection; + type Connection = ComputeConnection; type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; @@ -71,13 +68,9 @@ impl ConnectMechanism for TcpMechanism { ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, - ) -> Result { + ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - permit.release_result( - node_info - .connect(ctx, &self.auth, config, self.user_info.clone()) - .await, - ) + permit.release_result(node_info.connect(ctx, config).await) } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 7da1b8d8fa..6947e07488 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -357,24 +357,28 @@ pub(crate) async fn 
handle_client( let res = connect_to_compute( ctx, &TcpMechanism { - user_info: creds.info.clone(), - auth: auth_info, locks: &config.connect_compute_locks, }, - &auth::Backend::ControlPlane(cplane, creds.info), + &auth::Backend::ControlPlane(cplane, creds.info.clone()), config.wake_compute_retry_config, &config.connect_to_compute, ) .await; - let node = match res { + let mut node = match res { Ok(node) => node, Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; + let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await; + let pg_settings = match pg_settings { + Ok(pg_settings) => pg_settings, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; + let session = cancellation_handler.get_key(); - prepare_client_connection(&node, *session.key(), &mut stream); + prepare_client_connection(&pg_settings, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; let session_id = ctx.session_id(); @@ -384,7 +388,7 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &node.cancel_closure, + &pg_settings.cancel_closure, &config.connect_to_compute, ) .await; @@ -413,19 +417,19 @@ pub(crate) async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. pub(crate) fn prepare_client_connection( - node: &compute::PostgresConnection, + settings: &compute::PostgresSettings, cancel_key_data: CancelKeyData, stream: &mut PqStream, ) { // Forward all deferred notices to the client. - for notice in &node.delayed_notice { + for notice in &settings.delayed_notice { stream.write_raw(notice.as_bytes().len(), b'N', |buf| { buf.extend_from_slice(notice.as_bytes()); }); } // Forward all postgres connection params to the client. 
- for (name, value) in &node.params { + for (name, value) in &settings.params { stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0f19944afa..e9eca95724 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -99,7 +99,6 @@ impl ShouldRetryWakeCompute for postgres_client::Error { impl CouldRetry for compute::ConnectionError { fn could_retry(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.could_retry(), compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), compute::ConnectionError::TooManyConnectionAttempts(_) => false, @@ -109,7 +108,6 @@ impl CouldRetry for compute::ConnectionError { impl ShouldRetryWakeCompute for compute::ConnectionError { fn should_retry_wake_compute(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(), // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, _ => true,