diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 070c73cdcf..b877aaddef 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pglb::TlsRequired; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute}; +use crate::proxy::{ErrorSource, copy_bidirectional_client_compute}; use crate::stream::{PqStream, Stream}; use crate::util::run_until_cancelled; diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 9ead05d492..2133f33a4d 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -392,7 +392,7 @@ pub async fn run() -> anyhow::Result<()> { match auth_backend { Either::Left(auth_backend) => { if let Some(proxy_listener) = proxy_listener { - client_tasks.spawn(crate::proxy::task_main( + client_tasks.spawn(crate::pglb::task_main( config, auth_backend, proxy_listener, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index d5903286a0..041a56e032 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -11,11 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pglb::ClientRequestError; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection}; +use crate::proxy::{ErrorSource, finish_client_init}; use crate::util::run_until_cancelled; pub async fn task_main( @@ -232,7 +233,7 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - prepare_client_connection(&pg_settings, *session.key(), &mut stream); + finish_client_init(&pg_settings, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; let session_id = ctx.session_id(); diff --git a/proxy/src/pglb/handshake.rs b/proxy/src/pglb/handshake.rs index 6970ab8714..25a2d01b4a 100644 --- a/proxy/src/pglb/handshake.rs +++ b/proxy/src/pglb/handshake.rs @@ -8,10 +8,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; +use crate::pglb::TlsRequired; use crate::pqproto::{ BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, }; -use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index cb82524cf6..c4cab155c5 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -2,3 +2,332 @@ pub mod copy_bidirectional; pub mod handshake; pub mod inprocess; pub mod passthrough; + +use std::sync::Arc; + +use futures::FutureExt; +use smol_str::ToSmolStr; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, debug, error, info, warn}; + +use crate::auth; +use crate::cancellation::{self, CancellationHandler}; +use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; +use crate::context::RequestContext; +use crate::error::{ReportableError, UserFacingError}; +use crate::metrics::{Metrics, NumClientConnectionsGuard}; +pub use crate::pglb::copy_bidirectional::ErrorSource; +use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; +use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; +use crate::proxy::handle_client; +use crate::rate_limiter::EndpointRateLimiter; +use crate::stream::Stream; +use crate::util::run_until_cancelled; + +pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; + +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + +pub async fn task_main( + config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, ()>, + listener: tokio::net::TcpListener, + cancellation_token: CancellationToken, + cancellation_handler: Arc, + endpoint_rate_limiter: Arc, +) -> anyhow::Result<()> { + scopeguard::defer! { + info!("proxy has shut down"); + } + + // When set for the server socket, the keepalive setting + // will be inherited by all accepted client sockets. + socket2::SockRef::from(&listener).set_keepalive(true)?; + + let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); + + while let Some(accept_result) = + run_until_cancelled(listener.accept(), &cancellation_token).await + { + let (socket, peer_addr) = accept_result?; + + let conn_gauge = Metrics::get() + .proxy + .client_connections + .guard(crate::metrics::Protocol::Tcp); + + let session_id = uuid::Uuid::new_v4(); + let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); + + debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); + let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); + + connections.spawn(async move { + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } + } + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( + socket, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), + }; + + match socket.set_nodelay(true) { + Ok(()) => {} + Err(e) => { + error!( + "per-client task finished with an error: failed to set socket option: {e:#}" + ); + return; + } + } + + let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp); + + let res = handle_connection( + config, + auth_backend, + &ctx, + cancellation_handler, + socket, + ClientMode::Tcp, + endpoint_rate_limiter2, + conn_gauge, + cancellations, + ) + .instrument(ctx.span()) + .boxed() + .await; + + match res { + Err(e) => { + ctx.set_error_kind(e.get_error_kind()); + warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); + } + Ok(None) => { + ctx.set_success(); + } + Ok(Some(p)) => { + ctx.set_success(); + let _disconnect = ctx.log_connect(); + match p.proxy_pass().await { + Ok(()) => {} + Err(ErrorSource::Client(e)) => { + warn!( + ?session_id, + "per-client task finished with an IO error from the client: {e:#}" + ); + } + Err(ErrorSource::Compute(e)) => { + error!( + ?session_id, + "per-client task finished with an IO error from the compute: {e:#}" + ); + } + } + } + } + }); + } + + connections.close(); + cancellations.close(); + drop(listener); + + // Drain connections + connections.wait().await; + cancellations.wait().await; + + Ok(()) +} + +pub(crate) enum ClientMode { + Tcp, + Websockets { hostname: Option }, +} + +/// Abstracts the logic of handling TCP vs WS clients +impl ClientMode { + pub fn allow_cleartext(&self) -> bool { + match self { + ClientMode::Tcp => false, + ClientMode::Websockets { .. } => true, + } + } + + pub fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { + match self { + ClientMode::Tcp => s.sni_hostname(), + ClientMode::Websockets { hostname } => hostname.as_deref(), + } + } + + pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { + match self { + ClientMode::Tcp => tls, + // TLS is None here if using websockets, because the connection is already encrypted. + ClientMode::Websockets { .. } => None, + } + } +} + +#[derive(Debug, Error)] +// almost all errors should be reported to the user, but there's a few cases where we cannot +// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons +// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, +// we cannot be sure the client even understands our error message +// 3. PrepareClient: The client disconnected, so we can't tell them anyway... +pub(crate) enum ClientRequestError { + #[error("{0}")] + Cancellation(#[from] cancellation::CancelError), + #[error("{0}")] + Handshake(#[from] HandshakeError), + #[error("{0}")] + HandshakeTimeout(#[from] tokio::time::error::Elapsed), + #[error("{0}")] + PrepareClient(#[from] std::io::Error), + #[error("{0}")] + ReportedError(#[from] crate::stream::ReportedError), +} + +impl ReportableError for ClientRequestError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + ClientRequestError::Cancellation(e) => e.get_error_kind(), + ClientRequestError::Handshake(e) => e.get_error_kind(), + ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, + ClientRequestError::ReportedError(e) => e.get_error_kind(), + ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, + } + } +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_connection( + config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, ()>, + ctx: &RequestContext, + cancellation_handler: Arc, + client: S, + mode: ClientMode, + endpoint_rate_limiter: Arc, + conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, +) -> Result>, ClientRequestError> { + debug!( + protocol = %ctx.protocol(), + "handling interactive connection from client" + ); + + let metrics = &Metrics::get().proxy; + let proto = ctx.protocol(); + let request_gauge = metrics.connection_requests.guard(proto); + + let tls = config.tls_config.load(); + let tls = tls.as_deref(); + + let record_handshake_error = !ctx.has_private_peer_addr(); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error); + + let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) + .await?? + { + HandshakeData::Startup(client, params) => (client, params), + HandshakeData::Cancel(cancel_key_data) => { + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let ctx = ctx.clone(); + let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); + cancel_span.follows_from(tracing::Span::current()); + async move { + cancellation_handler_clone + .cancel_session( + cancel_key_data, + ctx, + config.authentication_config.ip_allowlist_check_enabled, + config.authentication_config.is_vpc_acccess_proxy, + auth_backend.get_api(), + ) + .await + .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); + }.instrument(cancel_span) + }); + + return Ok(None); + } + }; + drop(pause); + + ctx.set_db_options(params.clone()); + + let common_names = tls.map(|tls| &tls.common_names); + + let (node, cancel_on_shutdown) = handle_client( + config, + auth_backend, + ctx, + cancellation_handler, + &mut client, + &mode, + endpoint_rate_limiter, + common_names, + ¶ms, + ) + .await?; + + let client = client.flush_and_into_inner().await?; + + let private_link_id = match ctx.extra() { + Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + None => None, + }; + + Ok(Some(ProxyPassthrough { + client, + compute: node.stream, + + aux: node.aux, + private_link_id, + + _cancel_on_shutdown: cancel_on_shutdown, + + _req: request_gauge, + _conn: conn_gauge, + _db_conn: node.guage, + })) +} diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index d9c0585efb..08c81afa04 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -5,326 +5,64 @@ pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; +use std::collections::HashSet; +use std::convert::Infallible; use std::sync::Arc; -use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; use regex::Regex; use serde::{Deserialize, Serialize}; -use smol_str::{SmolStr, ToSmolStr, format_smolstr}; -use thiserror::Error; +use smol_str::{SmolStr, format_smolstr}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::sync::CancellationToken; -use tracing::{Instrument, debug, error, info, warn}; +use tokio::sync::oneshot; +use tracing::Instrument; use crate::cache::Cache; -use crate::cancellation::{self, CancellationHandler}; -use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; +use crate::cancellation::CancellationHandler; +use crate::compute::ComputeConnection; +use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; -use crate::error::{ReportableError, UserFacingError}; -use crate::metrics::{Metrics, NumClientConnectionsGuard}; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; -use crate::pglb::passthrough::ProxyPassthrough; +use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::proxy::retry::ShouldRetryWakeCompute; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; -use crate::util::run_until_cancelled; use crate::{auth, compute}; -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - -#[derive(Error, Debug)] -#[error("{ERR_INSECURE_CONNECTION}")] -pub struct TlsRequired; - -impl ReportableError for TlsRequired { - fn get_error_kind(&self) -> crate::error::ErrorKind { - crate::error::ErrorKind::User - } -} - -impl UserFacingError for TlsRequired {} - -pub async fn task_main( - config: &'static ProxyConfig, - auth_backend: &'static auth::Backend<'static, ()>, - listener: tokio::net::TcpListener, - cancellation_token: CancellationToken, - cancellation_handler: Arc, - endpoint_rate_limiter: Arc, -) -> anyhow::Result<()> { - scopeguard::defer! { - info!("proxy has shut down"); - } - - // When set for the server socket, the keepalive setting - // will be inherited by all accepted client sockets. - socket2::SockRef::from(&listener).set_keepalive(true)?; - - let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); - - while let Some(accept_result) = - run_until_cancelled(listener.accept(), &cancellation_token).await - { - let (socket, peer_addr) = accept_result?; - - let conn_gauge = Metrics::get() - .proxy - .client_connections - .guard(crate::metrics::Protocol::Tcp); - - let session_id = uuid::Uuid::new_v4(); - let cancellation_handler = Arc::clone(&cancellation_handler); - let cancellations = cancellations.clone(); - - debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); - let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); - - connections.spawn(async move { - let (socket, conn_info) = match config.proxy_protocol_v2 { - ProxyProtocolV2::Required => { - match read_proxy_protocol(socket).await { - Err(e) => { - warn!("per-client task finished with an error: {e:#}"); - return; - } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - } - } - // ignore the header - it cannot be confused for a postgres or http connection so will - // error later. - ProxyProtocolV2::Rejected => ( - socket, - ConnectionInfo { - addr: peer_addr, - extra: None, - }, - ), - }; - - match socket.set_nodelay(true) { - Ok(()) => {} - Err(e) => { - error!( - "per-client task finished with an error: failed to set socket option: {e:#}" - ); - return; - } - } - - let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp); - - let res = handle_client( - config, - auth_backend, - &ctx, - cancellation_handler, - socket, - ClientMode::Tcp, - endpoint_rate_limiter2, - conn_gauge, - cancellations, - ) - .instrument(ctx.span()) - .boxed() - .await; - - match res { - Err(e) => { - ctx.set_error_kind(e.get_error_kind()); - warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); - } - Ok(None) => { - ctx.set_success(); - } - Ok(Some(p)) => { - ctx.set_success(); - let _disconnect = ctx.log_connect(); - match p.proxy_pass().await { - Ok(()) => {} - Err(ErrorSource::Client(e)) => { - warn!( - ?session_id, - "per-client task finished with an IO error from the client: {e:#}" - ); - } - Err(ErrorSource::Compute(e)) => { - error!( - ?session_id, - "per-client task finished with an IO error from the compute: {e:#}" - ); - } - } - } - } - }); - } - - connections.close(); - cancellations.close(); - drop(listener); - - // Drain connections - connections.wait().await; - cancellations.wait().await; - - Ok(()) -} - -pub(crate) enum ClientMode { - Tcp, - Websockets { hostname: Option }, -} - -/// Abstracts the logic of handling TCP vs WS clients -impl ClientMode { - pub(crate) fn allow_cleartext(&self) -> bool { - match self { - ClientMode::Tcp => false, - ClientMode::Websockets { .. } => true, - } - } - - fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { - match self { - ClientMode::Tcp => s.sni_hostname(), - ClientMode::Websockets { hostname } => hostname.as_deref(), - } - } - - fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { - match self { - ClientMode::Tcp => tls, - // TLS is None here if using websockets, because the connection is already encrypted. - ClientMode::Websockets { .. } => None, - } - } -} - -#[derive(Debug, Error)] -// almost all errors should be reported to the user, but there's a few cases where we cannot -// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons -// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, -// we cannot be sure the client even understands our error message -// 3. PrepareClient: The client disconnected, so we can't tell them anyway... -pub(crate) enum ClientRequestError { - #[error("{0}")] - Cancellation(#[from] cancellation::CancelError), - #[error("{0}")] - Handshake(#[from] HandshakeError), - #[error("{0}")] - HandshakeTimeout(#[from] tokio::time::error::Elapsed), - #[error("{0}")] - PrepareClient(#[from] std::io::Error), - #[error("{0}")] - ReportedError(#[from] crate::stream::ReportedError), -} - -impl ReportableError for ClientRequestError { - fn get_error_kind(&self) -> crate::error::ErrorKind { - match self { - ClientRequestError::Cancellation(e) => e.get_error_kind(), - ClientRequestError::Handshake(e) => e.get_error_kind(), - ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, - ClientRequestError::ReportedError(e) => e.get_error_kind(), - ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, - } - } -} - #[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, cancellation_handler: Arc, - stream: S, - mode: ClientMode, + client: &mut PqStream>, + mode: &ClientMode, endpoint_rate_limiter: Arc, - conn_gauge: NumClientConnectionsGuard<'static>, - cancellations: tokio_util::task::task_tracker::TaskTracker, -) -> Result>, ClientRequestError> { - debug!( - protocol = %ctx.protocol(), - "handling interactive connection from client" - ); - - let metrics = &Metrics::get().proxy; - let proto = ctx.protocol(); - let request_gauge = metrics.connection_requests.guard(proto); - - let tls = config.tls_config.load(); - let tls = tls.as_deref(); - - let record_handshake_error = !ctx.has_private_peer_addr(); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); - - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) - .await?? - { - HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data) => { - // spawn a task to cancel the session, but don't wait for it - cancellations.spawn({ - let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let ctx = ctx.clone(); - let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); - cancel_span.follows_from(tracing::Span::current()); - async move { - cancellation_handler_clone - .cancel_session( - cancel_key_data, - ctx, - config.authentication_config.ip_allowlist_check_enabled, - config.authentication_config.is_vpc_acccess_proxy, - auth_backend.get_api(), - ) - .await - .inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok(); - }.instrument(cancel_span) - }); - - return Ok(None); - } - }; - drop(pause); - - ctx.set_db_options(params.clone()); - - let hostname = mode.hostname(stream.get_ref()); - - let common_names = tls.map(|tls| &tls.common_names); - + common_names: Option<&HashSet>, + params: &StartupMessageParams, +) -> Result<(ComputeConnection, oneshot::Sender), ClientRequestError> { + let hostname = mode.hostname(client.get_ref()); // Extract credentials which we're going to use for auth. let result = auth_backend .as_ref() - .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) + .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names)) .transpose(); let user_info = match result { Ok(user_info) => user_info, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( ctx, - &mut stream, + client, mode.allow_cleartext(), &config.authentication_config, endpoint_rate_limiter, @@ -337,7 +75,7 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return Err(stream + return Err(client .throw_error(e, Some(ctx)) .instrument(params_span) .await)?; @@ -350,7 +88,7 @@ pub(crate) async fn handle_client( }; let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); - auth_info.set_startup_params(¶ms, params_compat); + auth_info.set_startup_params(params, params_compat); let mut node; let mut attempt = 0; @@ -370,6 +108,7 @@ pub(crate) async fn handle_client( let pg_settings = loop { attempt += 1; + // TODO: callback to pglb let res = connect_to_compute( ctx, &connect, @@ -381,7 +120,7 @@ pub(crate) async fn handle_client( match res { Ok(n) => node = n, - Err(e) => return Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?, } let auth::Backend::ControlPlane(cplane, user_info) = &backend else { @@ -400,17 +139,16 @@ pub(crate) async fn handle_client( cplane_proxy_v1.caches.node_info.invalidate(&key); } } - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, } }; let session = cancellation_handler.get_key(); - prepare_client_connection(&pg_settings, *session.key(), &mut stream); - let stream = stream.flush_and_into_inner().await?; + finish_client_init(&pg_settings, *session.key(), client); let session_id = ctx.session_id(); - let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); + let (cancel_on_shutdown, cancel) = oneshot::channel(); tokio::spawn(async move { session .maintain_cancel_key( @@ -422,50 +160,32 @@ pub(crate) async fn handle_client( .await; }); - let private_link_id = match ctx.extra() { - Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), - Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), - None => None, - }; - - Ok(Some(ProxyPassthrough { - client: stream, - compute: node.stream, - - aux: node.aux, - private_link_id, - - _cancel_on_shutdown: cancel_on_shutdown, - - _req: request_gauge, - _conn: conn_gauge, - _db_conn: node.guage, - })) + Ok((node, cancel_on_shutdown)) } /// Finish client connection initialization: confirm auth success, send params, etc. -pub(crate) fn prepare_client_connection( +pub(crate) fn finish_client_init( settings: &compute::PostgresSettings, cancel_key_data: CancelKeyData, - stream: &mut PqStream, + client: &mut PqStream, ) { // Forward all deferred notices to the client. for notice in &settings.delayed_notice { - stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + client.write_raw(notice.as_bytes().len(), b'N', |buf| { buf.extend_from_slice(notice.as_bytes()); }); } // Forward all postgres connection params to the client. for (name, value) in &settings.params { - stream.write_message(BeMessage::ParameterStatus { + client.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), }); } - stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); - stream.write_message(BeMessage::ReadyForQuery); + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); + client.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] @@ -475,7 +195,7 @@ impl NeonOptions { // proxy options: /// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute. - const PARAMS_COMPAT: &str = "proxy_params_compat"; + pub const PARAMS_COMPAT: &str = "proxy_params_compat"; // cplane options: diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 67dd0ab522..b09d8edc4c 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -14,6 +14,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; +use crate::config::TlsConfig; +use crate::context::RequestContext; +use crate::pglb::handshake::{HandshakeData, handshake}; enum Intercept { None, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 4f27496019..dd89b05426 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -3,6 +3,7 @@ mod mitm; +use std::sync::Arc; use std::time::Duration; use anyhow::{Context, bail}; @@ -10,26 +11,31 @@ use async_trait::async_trait; use http::StatusCode; use postgres_client::config::SslMode; use postgres_client::tls::{MakeTlsConnect, NoTls}; -use retry::{ShouldRetryWakeCompute, retry_after}; use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; -use tokio::io::DuplexStream; +use tokio::io::{AsyncRead, AsyncWrite, DuplexStream}; use tracing_test::traced_test; use super::retry::CouldRetry; -use super::*; use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, RetryConfig, TlsConfig}; +use crate::context::RequestContext; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; -use crate::error::ErrorKind; -use crate::proxy::connect_compute::ConnectMechanism; +use crate::error::{ErrorKind, ReportableError}; +use crate::pglb::ERR_INSECURE_CONNECTION; +use crate::pglb::handshake::{HandshakeData, handshake}; +use crate::pqproto::BeMessage; +use crate::proxy::NeonOptions; +use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute}; +use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after}; +use crate::stream::{PqStream, Stream}; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; -use crate::{sasl, scram}; +use crate::{auth, compute, sasl, scram}; /// Generate a set of TLS certificates: CA + server. fn generate_certs( diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 0d374e6df2..1960709fba 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -17,7 +17,8 @@ use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::{ClientMode, ErrorSource, handle_client}; +use crate::pglb::{ClientMode, handle_connection}; +use crate::proxy::ErrorSource; use crate::rate_limiter::EndpointRateLimiter; pin_project! { @@ -142,7 +143,7 @@ pub(crate) async fn serve_websocket( .client_connections .guard(crate::metrics::Protocol::Ws); - let res = Box::pin(handle_client( + let res = Box::pin(handle_connection( config, auth_backend, &ctx,