diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs index 6e97b8c2a0..d3d3981922 100644 --- a/libs/pq_proto/src/framed.rs +++ b/libs/pq_proto/src/framed.rs @@ -82,6 +82,19 @@ impl Framed { write_buf: self.write_buf, }) } + + /// Return new Framed with stream type transformed by f. For dynamic dispatch. + pub fn map_stream_sync(self, f: F) -> Framed + where + F: FnOnce(S) -> S2, + { + let stream = f(self.stream); + Framed { + stream, + read_buf: self.read_buf, + write_buf: self.write_buf, + } + } } impl Framed { diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 0707c1331f..d63228291b 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -13,7 +13,7 @@ use password_hack::PasswordHackPayload; mod flow; pub use flow::*; -use crate::{console, error::UserFacingError}; +use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; @@ -23,15 +23,6 @@ pub type Result = std::result::Result; /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { - #[error(transparent)] - Link(#[from] backend::LinkAuthError), - - #[error(transparent)] - GetAuthInfo(#[from] console::errors::GetAuthInfoError), - - #[error(transparent)] - WakeCompute(#[from] console::errors::WakeComputeError), - /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -99,13 +90,25 @@ impl> From for AuthError { } } +impl ReportableError for AuthError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self.0.as_ref() { + AuthErrorImpl::Sasl(s) => s.get_error_type(), + AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User, + AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User, + AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User, + AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User, + AuthErrorImpl::Io(_) => crate::error::ErrorKind::Disconnect, + AuthErrorImpl::IpAddressNotAllowed => crate::error::ErrorKind::User, + AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit, + } + } +} + impl UserFacingError for AuthError { fn to_string_client(&self) -> String { use AuthErrorImpl::*; match self.0.as_ref() { - Link(e) => e.to_string_client(), - GetAuthInfo(e) => e.to_string_client(), - WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index a6164f7bfb..a6b49c3873 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -2,22 +2,27 @@ mod classic; mod hacks; mod link; -pub use link::LinkAuthError; +use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio_postgres::config::AuthKeys; +use crate::auth::backend::link::NeedsLinkAuthentication; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::validate_password_and_exchange; use crate::cache::Cached; +use crate::cancellation::Session; +use crate::config::ProxyConfig; use crate::console::errors::GetAuthInfoError; use crate::console::provider::ConsoleBackend; use crate::console::AuthSecret; use crate::context::RequestMonitoring; -use crate::proxy::connect_compute::handle_try_wake; -use crate::proxy::retry::retry_after; +use crate::proxy::wake_compute::NeedsWakeCompute; +use crate::proxy::ClientMode; use crate::proxy::NeonOptions; +use crate::rate_limiter::EndpointRateLimiter; use crate::scram; -use crate::stream::Stream; +use crate::state_machine::{user_facing_error, DynStage, ResultExt, Stage, StageError}; +use crate::stream::{PqStream, Stream}; use crate::{ auth::{self, ComputeUserInfoMaybeEndpoint}, config::AuthenticationConfig, @@ -30,10 +35,11 @@ use crate::{ }; use futures::TryFutureExt; use std::borrow::Cow; -use std::ops::ControlFlow; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info, warn}; +use tracing::info; + +use self::hacks::NeedsPasswordHack; /// This type serves two purposes: /// @@ -170,66 +176,94 @@ impl TryFrom for ComputeUserInfo { } } -/// True to its name, this function encapsulates our current auth trade-offs. -/// Here, we choose the appropriate auth flow based on circumstances. -/// -/// All authentication flows will emit an AuthenticationOk message if successful. -async fn auth_quirks( - ctx: &mut RequestMonitoring, - api: &impl console::Api, - user_info: ComputeUserInfoMaybeEndpoint, - client: &mut stream::PqStream>, +struct NeedsAuthSecret { + stream: PqStream>, + api: Cow<'static, ConsoleBackend>, + params: StartupMessageParams, + allow_self_signed_compute: bool, allow_cleartext: bool, + info: ComputeUserInfo, + unauthenticated_password: Option>, config: &'static AuthenticationConfig, -) -> auth::Result> { - // If there's no project so far, that entails that client doesn't - // support SNI or other means of passing the endpoint (project) name. - // We now expect to see a very specific payload in the place of password. - let (info, unauthenticated_password) = match user_info.try_into() { - Err(info) => { - let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer) - .await?; - ctx.set_endpoint_id(Some(res.info.endpoint.clone())); - (res.info, Some(res.keys)) - } - Ok(info) => (info, None), - }; - info!("fetching user's authentication info"); - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; + // monitoring + ctx: RequestMonitoring, + cancel_session: Session, +} - // check allowed list - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed()); +impl Stage for NeedsAuthSecret { + fn span(&self) -> tracing::Span { + tracing::info_span!("get_auth_secret") } - let cached_secret = api.get_role_secret(ctx, &info).await?; + async fn run(self) -> Result { + let Self { + stream, + api, + params, + allow_cleartext, + allow_self_signed_compute, + info, + unauthenticated_password, + config, + mut ctx, + cancel_session, + } = self; - let secret = cached_secret.value.clone().unwrap_or_else(|| { - // If we don't have an authentication secret, we mock one to - // prevent malicious probing (possible due to missing protocol steps). - // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); - AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random())) - }); - match authenticate_with_secret( - ctx, - secret, - info, - client, - unauthenticated_password, - allow_cleartext, - config, - ) - .await - { - Ok(keys) => Ok(keys), - Err(e) => { + info!("fetching user's authentication info"); + let (allowed_ips, stream) = api + .get_allowed_ips(&mut ctx, &info) + .await + .send_error_to_user(&mut ctx, stream)?; + + // check allowed list + if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { + return Err(user_facing_error( + auth::AuthError::ip_address_not_allowed(), + &mut ctx, + stream, + )); + } + let (cached_secret, mut stream) = api + .get_role_secret(&mut ctx, &info) + .await + .send_error_to_user(&mut ctx, stream)?; + + let secret = cached_secret.value.clone().unwrap_or_else(|| { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random())) + }); + + let (keys, stream) = authenticate_with_secret( + &mut ctx, + secret, + info, + &mut stream, + unauthenticated_password, + allow_cleartext, + config, + ) + .await + .map_err(|e| { if e.is_auth_failed() { // The password could have been changed, so we invalidate the cache. cached_secret.invalidate(); } - Err(e) - } + e + }) + .send_error_to_user(&mut ctx, stream)?; + + Ok(Box::new(NeedsWakeCompute { + stream, + api, + params, + allow_self_signed_compute, + creds: keys, + ctx, + cancel_session, + })) } } @@ -270,49 +304,6 @@ async fn authenticate_with_secret( classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await } -/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache) -/// only if authentication was successfuly. -async fn auth_and_wake_compute( - ctx: &mut RequestMonitoring, - api: &impl console::Api, - user_info: ComputeUserInfoMaybeEndpoint, - client: &mut stream::PqStream>, - allow_cleartext: bool, - config: &'static AuthenticationConfig, -) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> { - let compute_credentials = - auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?; - - let mut num_retries = 0; - let mut node = loop { - let wake_res = api.wake_compute(ctx, &compute_credentials.info).await; - match handle_try_wake(wake_res, num_retries) { - Err(e) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); - return Err(e.into()); - } - Ok(ControlFlow::Continue(e)) => { - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); - } - Ok(ControlFlow::Break(n)) => break n, - } - - let wait_duration = retry_after(num_retries); - num_retries += 1; - tokio::time::sleep(wait_duration).await; - }; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(feature = "testing")] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - Ok((node, compute_credentials.info)) -} - impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { /// Get compute endpoint name from the credentials. pub fn get_endpoint(&self) -> Option { @@ -337,50 +328,96 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { Test(_) => "test", } } +} - /// Authenticate the client via the requested backend, possibly using credentials. - #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] - pub async fn authenticate( - self, - ctx: &mut RequestMonitoring, - client: &mut stream::PqStream>, - allow_cleartext: bool, - config: &'static AuthenticationConfig, - ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { - use BackendType::*; +pub struct NeedsAuthentication { + pub stream: PqStream>, + pub creds: BackendType<'static, auth::ComputeUserInfoMaybeEndpoint>, + pub params: StartupMessageParams, + pub endpoint_rate_limiter: Arc, + pub mode: ClientMode, + pub config: &'static ProxyConfig, - let res = match self { - Console(api, user_info) => { - info!( - user = &*user_info.user, - project = user_info.project(), - "performing authentication using the console" - ); + // monitoring + pub ctx: RequestMonitoring, + pub cancel_session: Session, +} - let (cache_info, user_info) = - auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config) - .await?; - (cache_info, BackendType::Console(api, user_info)) +impl Stage for NeedsAuthentication { + fn span(&self) -> tracing::Span { + tracing::info_span!("authenticate") + } + async fn run(self) -> Result { + let Self { + stream, + creds, + params, + endpoint_rate_limiter, + mode, + config, + mut ctx, + cancel_session, + } = self; + + // check rate limit + if let Some(ep) = creds.get_endpoint() { + if !endpoint_rate_limiter.check(ep) { + return Err(user_facing_error( + auth::AuthError::too_many_connections(), + &mut ctx, + stream, + )); + } + } + + let allow_self_signed_compute = mode.allow_self_signed_compute(config); + let allow_cleartext = mode.allow_cleartext(); + + match creds { + BackendType::Console(api, creds) => { + // If there's no project so far, that entails that client doesn't + // support SNI or other means of passing the endpoint (project) name. + // We now expect to see a very specific payload in the place of password. + match creds.try_into() { + Err(info) => Ok(Box::new(NeedsPasswordHack { + stream, + api, + params, + allow_self_signed_compute, + info, + allow_cleartext, + config: &config.authentication_config, + ctx, + cancel_session, + })), + Ok(info) => Ok(Box::new(NeedsAuthSecret { + stream, + api, + params, + allow_self_signed_compute, + info, + unauthenticated_password: None, + allow_cleartext, + config: &config.authentication_config, + ctx, + cancel_session, + })), + } } // NOTE: this auth backend doesn't use client credentials. - Link(url) => { - info!("performing link authentication"); - - let node_info = link::authenticate(&url, client).await?; - - ( - CachedNodeInfo::new_uncached(node_info), - BackendType::Link(url), - ) - } + BackendType::Link(link) => Ok(Box::new(NeedsLinkAuthentication { + stream, + link, + params, + allow_self_signed_compute, + ctx, + cancel_session, + })), #[cfg(test)] - Test(_) => { + BackendType::Test(_) => { unreachable!("this function should never be called in the test backend") } - }; - - info!("user successfully authenticated"); - Ok(res) + } } } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index b6c1a92d3c..387ab2238b 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -1,13 +1,21 @@ +use std::borrow::Cow; + use super::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint, + NeedsAuthSecret, }; use crate::{ auth::{self, AuthFlow}, - console::AuthSecret, + cancellation::Session, + config::AuthenticationConfig, + console::{provider::ConsoleBackend, AuthSecret}, + context::RequestMonitoring, metrics::LatencyTimer, sasl, - stream::{self, Stream}, + state_machine::{DynStage, ResultExt, Stage, StageError}, + stream::{self, PqStream, Stream}, }; +use pq_proto::StartupMessageParams; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -46,7 +54,7 @@ pub async fn authenticate_cleartext( /// Workaround for clients which don't provide an endpoint (project) name. /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) -pub async fn password_hack_no_authentication( +async fn password_hack_no_authentication( info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, latency_timer: &mut LatencyTimer, @@ -74,3 +82,47 @@ pub async fn password_hack_no_authentication( keys: payload.password, }) } + +pub struct NeedsPasswordHack { + pub stream: PqStream>, + pub api: Cow<'static, ConsoleBackend>, + pub params: StartupMessageParams, + pub allow_self_signed_compute: bool, + pub allow_cleartext: bool, + pub info: ComputeUserInfoNoEndpoint, + pub config: &'static AuthenticationConfig, + + // monitoring + pub ctx: RequestMonitoring, + pub cancel_session: Session, +} + +impl Stage for NeedsPasswordHack { + fn span(&self) -> tracing::Span { + tracing::info_span!("password_hack") + } + async fn run(mut self) -> Result { + let (res, stream) = password_hack_no_authentication( + self.info, + &mut self.stream, + &mut self.ctx.latency_timer, + ) + .await + .send_error_to_user(&mut self.ctx, self.stream)?; + + self.ctx.set_endpoint_id(Some(res.info.endpoint.clone())); + Ok(Box::new(NeedsAuthSecret { + stream, + info: res.info, + unauthenticated_password: Some(res.keys), + + api: self.api, + params: self.params, + allow_self_signed_compute: self.allow_self_signed_compute, + allow_cleartext: self.allow_cleartext, + ctx: self.ctx, + cancel_session: self.cancel_session, + config: self.config, + })) + } +} diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index a7ddd257b3..480f9f18ef 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,41 +1,20 @@ +use std::borrow::Cow; + use crate::{ - auth, compute, - console::{self, provider::NodeInfo}, - error::UserFacingError, - stream::PqStream, - waiters, + auth::BackendType, + cancellation::Session, + compute, + console::{self, mgmt::ComputeReady, provider::NodeInfo, CachedNodeInfo}, + context::RequestMonitoring, + proxy::connect_compute::{NeedsComputeConnection, TcpMechanism}, + state_machine::{DynStage, ResultExt, Stage, StageError}, + stream::{PqStream, Stream}, + waiters::Waiter, }; -use pq_proto::BeMessage as Be; -use thiserror::Error; +use pq_proto::{BeMessage as Be, StartupMessageParams}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::SslMode; -use tracing::{info, info_span}; - -#[derive(Debug, Error)] -pub enum LinkAuthError { - /// Authentication error reported by the console. - #[error("Authentication failed: {0}")] - AuthFailed(String), - - #[error(transparent)] - WaiterRegister(#[from] waiters::RegisterError), - - #[error(transparent)] - WaiterWait(#[from] waiters::WaitError), - - #[error(transparent)] - Io(#[from] std::io::Error), -} - -impl UserFacingError for LinkAuthError { - fn to_string_client(&self) -> String { - use LinkAuthError::*; - match self { - AuthFailed(_) => self.to_string(), - _ => "Internal error".to_string(), - } - } -} +use tracing::info; fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String { format!( @@ -53,64 +32,146 @@ pub fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub(super) async fn authenticate( - link_uri: &reqwest::Url, - client: &mut PqStream, -) -> auth::Result { - // registering waiter can fail if we get unlucky with rng. - // just try again. - let (psql_session_id, waiter) = loop { - let psql_session_id = new_psql_session_id(); +pub struct NeedsLinkAuthentication { + pub stream: PqStream>, + pub link: Cow<'static, crate::url::ApiUrl>, + pub params: StartupMessageParams, + pub allow_self_signed_compute: bool, - match console::mgmt::get_waiter(&psql_session_id) { - Ok(waiter) => break (psql_session_id, waiter), - Err(_e) => continue, - } - }; - - let span = info_span!("link", psql_session_id = &psql_session_id); - let greeting = hello_message(link_uri, &psql_session_id); - - // Give user a URL to spawn a new database. - info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; - - // Wait for web console response (see `mgmt`). - info!(parent: &span, "waiting for console's reply..."); - let db_info = waiter.await.map_err(LinkAuthError::from)?; - - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; - - // This config should be self-contained, because we won't - // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(); - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - // Backwards compatibility. pg_sni_proxy uses "--" in domain names - // while direct connections do not. Once we migrate to pg_sni_proxy - // everywhere, we can remove this. - if db_info.host.contains("--") { - // we need TLS connection with SNI info to properly route it - config.ssl_mode(SslMode::Require); - } else { - config.ssl_mode(SslMode::Disable); - } - - if let Some(password) = db_info.password { - config.password(password.as_ref()); - } - - Ok(NodeInfo { - config, - aux: db_info.aux, - allow_self_signed_compute: false, // caller may override - }) + // monitoring + pub ctx: RequestMonitoring, + pub cancel_session: Session, +} + +impl Stage for NeedsLinkAuthentication { + fn span(&self) -> tracing::Span { + tracing::info_span!("link", psql_session_id = tracing::field::Empty) + } + async fn run(self) -> Result { + let Self { + mut stream, + link, + params, + allow_self_signed_compute, + mut ctx, + cancel_session, + } = self; + + // registering waiter can fail if we get unlucky with rng. + // just try again. + let (psql_session_id, waiter) = loop { + let psql_session_id = new_psql_session_id(); + + match console::mgmt::get_waiter(&psql_session_id) { + Ok(waiter) => break (psql_session_id, waiter), + Err(_e) => continue, + } + }; + tracing::Span::current().record("psql_session_id", &psql_session_id); + let greeting = hello_message(&link, &psql_session_id); + + info!("sending the auth URL to the user"); + + stream + .write_message_noflush(&Be::AuthenticationOk) + .and_then(|s| s.write_message_noflush(&Be::CLIENT_ENCODING)) + .and_then(|s| s.write_message_noflush(&Be::NoticeResponse(&greeting))) + .no_user_error(&mut ctx, crate::error::ErrorKind::Service)? + .flush() + .await + .no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?; + + Ok(Box::new(NeedsLinkAuthenticationResponse { + stream, + link, + params, + allow_self_signed_compute, + waiter, + psql_session_id, + ctx, + cancel_session, + })) + } +} + +struct NeedsLinkAuthenticationResponse { + stream: PqStream>, + link: Cow<'static, crate::url::ApiUrl>, + params: StartupMessageParams, + allow_self_signed_compute: bool, + waiter: Waiter<'static, ComputeReady>, + psql_session_id: String, + + // monitoring + ctx: RequestMonitoring, + cancel_session: Session, +} + +impl Stage + for NeedsLinkAuthenticationResponse +{ + fn span(&self) -> tracing::Span { + tracing::info_span!("link_wait", psql_session_id = self.psql_session_id) + } + async fn run(self) -> Result { + let Self { + mut stream, + link, + params, + allow_self_signed_compute, + waiter, + psql_session_id: _, + mut ctx, + cancel_session, + } = self; + + // Wait for web console response (see `mgmt`). + info!("waiting for console's reply..."); + let db_info = waiter + .await + .no_user_error(&mut ctx, crate::error::ErrorKind::Service)?; + + stream + .write_message_noflush(&Be::NoticeResponse("Connecting to database.")) + .no_user_error(&mut ctx, crate::error::ErrorKind::Service)?; + + // This config should be self-contained, because we won't + // take username or dbname from client's startup message. + let mut config = compute::ConnCfg::new(); + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + // Backwards compatibility. pg_sni_proxy uses "--" in domain names + // while direct connections do not. Once we migrate to pg_sni_proxy + // everywhere, we can remove this. + if db_info.host.contains("--") { + // we need TLS connection with SNI info to properly route it + config.ssl_mode(SslMode::Require); + } else { + config.ssl_mode(SslMode::Disable); + } + + if let Some(password) = db_info.password { + config.password(password.as_ref()); + } + + let node_info = CachedNodeInfo::new_uncached(NodeInfo { + config, + aux: db_info.aux, + allow_self_signed_compute, + }); + let user_info = BackendType::Link(link); + + Ok(Box::new(NeedsComputeConnection { + stream, + user_info, + mechanism: TcpMechanism { params }, + node_info, + ctx, + cancel_session, + })) + } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index ada7f3614c..ac8c6f26fc 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,8 +1,11 @@ //! User credentials used in authentication. use crate::{ - auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, + auth::password_hack::parse_endpoint_param, + context::RequestMonitoring, + error::{ReportableError, UserFacingError}, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, + proxy::NeonOptions, }; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -33,7 +36,24 @@ pub enum ComputeUserInfoParseError { MalformedProjectName(SmolStr), } -impl UserFacingError for ComputeUserInfoParseError {} +impl ReportableError for ComputeUserInfoParseError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + ComputeUserInfoParseError::MissingKey(_) => crate::error::ErrorKind::User, + ComputeUserInfoParseError::InconsistentProjectNames { .. } => { + crate::error::ErrorKind::User + } + ComputeUserInfoParseError::UnknownCommonName { .. } => crate::error::ErrorKind::User, + ComputeUserInfoParseError::MalformedProjectName(_) => crate::error::ErrorKind::User, + } + } +} + +impl UserFacingError for ComputeUserInfoParseError { + fn to_string_client(&self) -> String { + self.to_string() + } +} /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 1edbc1e7e7..eaba6ebe93 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -164,6 +164,13 @@ async fn task_main( let tls_config = Arc::clone(&tls_config); let dest_suffix = Arc::clone(&dest_suffix); + let root_span = tracing::info_span!( + "handle_client", + ?session_id, + endpoint = tracing::field::Empty + ); + let root_span2 = root_span.clone(); + connections.spawn( async move { socket @@ -171,8 +178,13 @@ async fn task_main( .context("failed to set socket option")?; info!(%peer_addr, "serving"); - let mut ctx = - RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni"); + let mut ctx = RequestMonitoring::new( + session_id, + peer_addr.ip(), + "sni_router", + "sni", + root_span2, + ); handle_client( &mut ctx, dest_suffix, @@ -186,7 +198,7 @@ async fn task_main( // Acknowledge that the task has finished with an error. error!("per-client task finished with an error: {e:#}"); }) - .instrument(tracing::info_span!("handle_client", ?session_id)), + .instrument(root_span), ); } @@ -271,6 +283,7 @@ async fn handle_client( let client = tokio::net::TcpStream::connect(destination).await?; + ctx.log(); let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await + proxy::proxy::pass::proxy_pass(tls_stream, client, metrics_aux).await } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a5eb3544b4..9f30e864e5 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,7 +1,7 @@ -use anyhow::{bail, Context}; +use anyhow::Context; use dashmap::DashMap; use pq_proto::CancelKeyData; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpStream; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; @@ -25,39 +25,33 @@ impl CancelMap { } /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result - where - F: FnOnce(Session<'a>) -> R, - R: std::future::Future>, - { + pub fn get_session(self: Arc) -> Session { // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the // actual backend_pid, but backend_pid is not used for anything // so it doesn't matter. - let key = rand::random(); + let key = loop { + let key = rand::random(); - // Random key collisions are unlikely to happen here, but they're still possible, - // which is why we have to take care not to rewrite an existing key. - match self.0.entry(key) { - dashmap::mapref::entry::Entry::Occupied(_) => { - bail!("query cancellation key already exists: {key}") + // Random key collisions are unlikely to happen here, but they're still possible, + // which is why we have to take care not to rewrite an existing key. + match self.0.entry(key) { + dashmap::mapref::entry::Entry::Occupied(_) => { + continue; + } + dashmap::mapref::entry::Entry::Vacant(e) => { + e.insert(None); + } } - dashmap::mapref::entry::Entry::Vacant(e) => { - e.insert(None); - } - } - - // This will guarantee that the session gets dropped - // as soon as the future is finished. - scopeguard::defer! { - self.0.remove(&key); - info!("dropped query cancellation key {key}"); - } + break key; + }; info!("registered new query cancellation key {key}"); - let session = Session::new(key, self); - f(session).await + Session { + key, + cancel_map: self, + } } #[cfg(test)] @@ -98,23 +92,17 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session<'a> { +pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancel_map: &'a CancelMap, + cancel_map: Arc, } -impl<'a> Session<'a> { - fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self { - Self { key, cancel_map } - } -} - -impl Session<'_> { +impl Session { /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. - pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { + pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); self.cancel_map.0.insert(self.key, Some(cancel_closure)); @@ -122,37 +110,26 @@ impl Session<'_> { } } +impl Drop for Session { + fn drop(&mut self) { + self.cancel_map.0.remove(&self.key); + info!("dropped query cancellation key {}", &self.key); + } +} + #[cfg(test)] mod tests { use super::*; - use once_cell::sync::Lazy; #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - static CANCEL_MAP: Lazy = Lazy::new(Default::default); - - let (tx, rx) = tokio::sync::oneshot::channel(); - let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move { - assert!(CANCEL_MAP.contains(&session)); - - tx.send(()).expect("failed to send"); - futures::future::pending::<()>().await; // sleep forever - - Ok(()) - })); - - // Wait until the task has been spawned. - rx.await.context("failed to hear from the task")?; - - // Drop the session's entry by cancelling the task. - task.abort(); - let error = task.await.expect_err("task should have failed"); - if !error.is_cancelled() { - anyhow::bail!(error); - } + let cancel_map: Arc = Default::default(); + let session = cancel_map.clone().get_session(); + assert!(cancel_map.contains(&session)); + drop(session); // Check that the session has been dropped. - assert!(CANCEL_MAP.is_empty()); + assert!(cancel_map.is_empty()); Ok(()) } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index aef1aab733..9f662f67ff 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,6 +1,10 @@ use crate::{ - auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError, - context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE, + auth::parse_endpoint_param, + cancellation::CancelClosure, + console::errors::WakeComputeError, + context::RequestMonitoring, + error::{ReportableError, UserFacingError}, + metrics::NUM_DB_CONNECTIONS_GAUGE, proxy::neon_option, }; use futures::{FutureExt, TryFutureExt}; @@ -32,6 +36,17 @@ pub enum ConnectionError { WakeComputeError(#[from] WakeComputeError), } +impl ReportableError for ConnectionError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, + ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, + ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, + ConnectionError::WakeComputeError(_) => crate::error::ErrorKind::ControlPlane, + } + } +} + impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { use ConnectionError::*; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index bbcddae86c..b308fddade 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -21,7 +21,7 @@ use tracing::info; pub mod errors { use crate::{ - error::{io_error, UserFacingError}, + error::{io_error, ReportableError, UserFacingError}, http, proxy::retry::ShouldRetry, }; @@ -56,6 +56,15 @@ pub mod errors { } } + impl ReportableError for ApiError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane, + ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane, + } + } + } + impl UserFacingError for ApiError { fn to_string_client(&self) -> String { use ApiError::*; @@ -140,6 +149,15 @@ pub mod errors { } } + impl ReportableError for GetAuthInfoError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane, + GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane, + } + } + } + impl UserFacingError for GetAuthInfoError { fn to_string_client(&self) -> String { use GetAuthInfoError::*; @@ -181,6 +199,16 @@ pub mod errors { } } + impl ReportableError for WakeComputeError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, + WakeComputeError::ApiError(e) => e.get_error_type(), + WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit, + } + } + } + impl UserFacingError for WakeComputeError { fn to_string_client(&self) -> String { use WakeComputeError::*; diff --git a/proxy/src/context.rs b/proxy/src/context.rs index 8a1aa4aec9..f30731ac6e 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -38,6 +38,7 @@ pub struct RequestMonitoring { // This sender is here to keep the request monitoring channel open while requests are taking place. sender: Option>, pub latency_timer: LatencyTimer, + root_span: tracing::Span, } impl RequestMonitoring { @@ -46,6 +47,7 @@ impl RequestMonitoring { peer_addr: IpAddr, protocol: &'static str, region: &'static str, + root_span: tracing::Span, ) -> Self { Self { peer_addr, @@ -64,12 +66,19 @@ impl RequestMonitoring { sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), latency_timer: LatencyTimer::new(protocol), + root_span, } } #[cfg(test)] pub fn test() -> Self { - RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test") + RequestMonitoring::new( + Uuid::now_v7(), + [127, 0, 0, 1].into(), + "test", + "test", + tracing::Span::none(), + ) } pub fn console_application_name(&self) -> String { @@ -87,7 +96,10 @@ impl RequestMonitoring { } pub fn set_endpoint_id(&mut self, endpoint_id: Option) { - self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone()); + if let (None, Some(ep)) = (self.endpoint_id.as_ref(), endpoint_id) { + self.root_span.record("ep", &*ep); + self.endpoint_id = Some(ep) + } } pub fn set_application(&mut self, app: Option) { @@ -102,6 +114,10 @@ impl RequestMonitoring { self.success = true; } + pub fn error(&mut self, err: ErrorKind) { + self.error_kind = Some(err); + } + pub fn log(&mut self) { if let Some(tx) = self.sender.take() { let _: Result<(), _> = tx.send(self.clone()); diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 5b2dd7ecfd..cbe8656c6f 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -17,19 +17,16 @@ pub fn log_error(e: E) -> E { /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: fmt::Display { +pub trait UserFacingError: ReportableError { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly /// recommended to override the default impl in case error type /// contains anything sensitive: various IDs, IP addresses etc. - #[inline(always)] - fn to_string_client(&self) -> String { - self.to_string() - } + fn to_string_client(&self) -> String; } -#[derive(Clone)] +#[derive(Clone, Copy)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, @@ -62,3 +59,7 @@ impl ErrorKind { } } } + +pub trait ReportableError: fmt::Display + Send + 'static { + fn get_error_type(&self) -> ErrorKind; +} diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a22b2459b8..f8f78947be 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -26,6 +26,7 @@ pub mod redis; pub mod sasl; pub mod scram; pub mod serverless; +pub mod state_machine; pub mod stream; pub mod url; pub mod usage_metrics; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 635d157383..45b2e4d2c7 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -2,38 +2,32 @@ mod tests; pub mod connect_compute; +pub mod handshake; +pub mod pass; pub mod retry; +pub mod wake_compute; use crate::{ - auth, - cancellation::{self, CancelMap}, - compute, - config::{AuthenticationConfig, ProxyConfig, TlsConfig}, - console::messages::MetricsAuxInfo, + cancellation::CancelMap, + config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, - metrics::{ - NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER, - NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE, - }, + metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE}, protocol2::WithClientIp, + proxy::handshake::NeedsHandshake, rate_limiter::EndpointRateLimiter, - stream::{PqStream, Stream}, - usage_metrics::{Ids, USAGE_METRICS}, + state_machine::{DynStage, StageResult}, + stream::Stream, }; -use anyhow::{bail, Context}; -use futures::TryFutureExt; +use anyhow::Context; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; +use pq_proto::StartupMessageParams; use regex::Regex; use smol_str::SmolStr; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, Instrument}; -use utils::measured_stream::MeasuredStream; - -use self::connect_compute::{connect_to_compute, TcpMechanism}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; const ERR_PROTO_VIOLATION: &str = "protocol violation"; @@ -79,45 +73,64 @@ pub async fn task_main( let cancel_map = Arc::clone(&cancel_map); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let root_span = info_span!( + "handle_client", + ?session_id, + peer_addr = tracing::field::Empty, + ep = tracing::field::Empty, + ); + let root_span2 = root_span.clone(); + connections.spawn( async move { info!("accepted postgres client connection"); let mut socket = WithClientIp::new(socket); let mut peer_addr = peer_addr.ip(); - if let Some(addr) = socket.wait_for_addr().await? { - peer_addr = addr.ip(); - tracing::Span::current().record("peer_addr", &tracing::field::display(addr)); - } else if config.require_client_ip { - bail!("missing required client IP"); - } + match socket.wait_for_addr().await { + Err(e) => { + error!("IO error: {e:#}"); + return; + } + Ok(Some(addr)) => { + peer_addr = addr.ip(); + root_span2.record("peer_addr", &tracing::field::display(addr)); + } + Ok(None) if config.require_client_ip => { + error!("missing required client IP"); + return; + } + Ok(None) => {} + }; - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region); + let ctx = RequestMonitoring::new( + session_id, + peer_addr, + "tcp", + &config.region, + root_span2, + ); - socket + if let Err(e) = socket .inner .set_nodelay(true) - .context("failed to set socket option")?; + .context("failed to set socket option") + { + error!("could not set nodelay: {e:#}"); + return; + } handle_client( config, - &mut ctx, - &cancel_map, + ctx, + cancel_map, socket, ClientMode::Tcp, endpoint_rate_limiter, ) - .await + .await; } - .instrument(info_span!( - "handle_client", - ?session_id, - peer_addr = tracing::field::Empty - )) - .unwrap_or_else(move |e| { - // Acknowledge that the task has finished with an error. - error!(?session_id, "per-client task finished with an error: {e:#}"); - }), + .instrument(root_span), ); } @@ -137,14 +150,14 @@ pub enum ClientMode { /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - fn allow_cleartext(&self) -> bool { + pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -167,14 +180,14 @@ impl ClientMode { } } -pub async fn handle_client( +pub async fn handle_client( config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, - cancel_map: &CancelMap, + ctx: RequestMonitoring, + cancel_map: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, -) -> anyhow::Result<()> { +) { info!( protocol = ctx.protocol, "handling interactive connection from client" @@ -188,308 +201,23 @@ pub async fn handle_client( .with_label_values(&[proto]) .guard(); - let tls = config.tls_config.as_ref(); - - let pause = ctx.latency_timer.pause(); - let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map); - let (mut stream, params) = match do_handshake.await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; - drop(pause); - - // Extract credentials which we're going to use for auth. - let user_info = { - let hostname = mode.hostname(stream.get_ref()); - - let common_names = tls.map(|tls| &tls.common_names); - let result = config - .auth_backend - .as_ref() - .map(|_| { - auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names) - }) - .transpose(); - - match result { - Ok(user_info) => user_info, - Err(e) => stream.throw_error(e).await?, - } - }; - - ctx.set_endpoint_id(user_info.get_endpoint()); - - let client = Client::new( + let mut stage = Box::new(NeedsHandshake { stream, - user_info, - ¶ms, - mode.allow_self_signed_compute(config), + config, + cancel_map, + mode, endpoint_rate_limiter, - ); - cancel_map - .with_session(|session| { - client.connect_to_db(ctx, session, mode, &config.authentication_config) - }) - .await -} + ctx, + }) as DynStage; -/// Establish a (most probably, secure) connection with the client. -/// For better testing experience, `stream` can be any object satisfying the traits. -/// It's easier to work with owned `stream` here as we need to upgrade it to TLS; -/// we also take an extra care of propagating only the select handshake errors to client. -#[tracing::instrument(skip_all)] -async fn handshake( - stream: S, - mut tls: Option<&TlsConfig>, - cancel_map: &CancelMap, -) -> anyhow::Result>, StartupMessageParams)>> { - // Client may try upgrading to each protocol only once - let (mut tried_ssl, mut tried_gss) = (false, false); - - let mut stream = PqStream::new(Stream::from_raw(stream)); - loop { - let msg = stream.read_startup_packet().await?; - info!("received {msg:?}"); - - use FeStartupPacket::*; - match msg { - SslRequest => match stream.get_ref() { - Stream::Raw { .. } if !tried_ssl => { - tried_ssl = true; - - // We can't perform TLS handshake without a config - let enc = tls.is_some(); - stream.write_message(&Be::EncryptionResponse(enc)).await?; - if let Some(tls) = tls.take() { - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empy. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - let tls_stream = raw.upgrade(tls.to_server_config()).await?; - - let (_, tls_server_end_point) = tls - .cert_resolver - .resolve(tls_stream.get_ref().1.server_name()) - .context("missing certificate")?; - - stream = PqStream::new(Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }); - } - } - _ => bail!(ERR_PROTO_VIOLATION), - }, - GssEncRequest => match stream.get_ref() { - Stream::Raw { .. } if !tried_gss => { - tried_gss = true; - - // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; - } - _ => bail!(ERR_PROTO_VIOLATION), - }, - StartupMessage { params, .. } => { - // Check that the config has been consumed during upgrade - // OR we didn't provide it at all (for dev purposes). - if tls.is_some() { - stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; - } - - info!(session_type = "normal", "successful handshake"); - break Ok(Some((stream, params))); - } - CancelRequest(cancel_key_data) => { - cancel_map.cancel_session(cancel_key_data).await?; - - info!(session_type = "cancellation", "successful handshake"); - break Ok(None); - } - } - } -} - -/// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -async fn prepare_client_connection( - node: &compute::PostgresConnection, - session: cancellation::Session<'_>, - stream: &mut PqStream, -) -> anyhow::Result<()> { - // Register compute's query cancellation token and produce a new, unique one. - // The new token (cancel_key_data) will be sent to the client. - let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); - - // Forward all postgres connection params to the client. - // Right now the implementation is very hacky and inefficent (ideally, - // we don't need an intermediate hashmap), but at least it should be correct. - for (name, value) in &node.params { - // TODO: Theoretically, this could result in a big pile of params... - stream.write_message_noflush(&Be::ParameterStatus { - name: name.as_bytes(), - value: value.as_bytes(), - })?; - } - - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) -} - -/// Forward bytes in both directions (client <-> compute). -#[tracing::instrument(skip_all)] -pub async fn proxy_pass( - ctx: &mut RequestMonitoring, - client: impl AsyncRead + AsyncWrite + Unpin, - compute: impl AsyncRead + AsyncWrite + Unpin, - aux: MetricsAuxInfo, -) -> anyhow::Result<()> { - ctx.set_success(); - ctx.log(); - - let usage = USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id.clone(), - branch_id: aux.branch_id.clone(), - }); - - let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]); - let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx")); - let mut client = MeasuredStream::new( - client, - |_| {}, - |cnt| { - // Number of bytes we sent to the client (outbound). - m_sent.inc_by(cnt as u64); - m_sent2.inc_by(cnt as u64); - usage.record_egress(cnt as u64); - }, - ); - - let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]); - let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx")); - let mut compute = MeasuredStream::new( - compute, - |_| {}, - |cnt| { - // Number of bytes the client sent to the compute node (inbound). - m_recv.inc_by(cnt as u64); - m_recv2.inc_by(cnt as u64); - }, - ); - - // Starting from here we only proxy the client's traffic. - info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?; - - Ok(()) -} - -/// Thin connection context. -struct Client<'a, S> { - /// The underlying libpq protocol stream. - stream: PqStream>, - /// Client credentials that we care about. - user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>, - /// KV-dictionary with PostgreSQL connection params. - params: &'a StartupMessageParams, - /// Allow self-signed certificates (for testing). - allow_self_signed_compute: bool, - /// Rate limiter for endpoints - endpoint_rate_limiter: Arc, -} - -impl<'a, S> Client<'a, S> { - /// Construct a new connection context. - fn new( - stream: PqStream>, - user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>, - params: &'a StartupMessageParams, - allow_self_signed_compute: bool, - endpoint_rate_limiter: Arc, - ) -> Self { - Self { - stream, - user_info, - params, - allow_self_signed_compute, - endpoint_rate_limiter, - } - } -} - -impl Client<'_, S> { - /// Let the client authenticate and connect to the designated compute node. - // Instrumentation logs endpoint name everywhere. Doesn't work for link - // auth; strictly speaking we don't know endpoint name in its case. - #[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)] - async fn connect_to_db( - self, - ctx: &mut RequestMonitoring, - session: cancellation::Session<'_>, - mode: ClientMode, - config: &'static AuthenticationConfig, - ) -> anyhow::Result<()> { - let Self { - mut stream, - user_info, - params, - allow_self_signed_compute, - endpoint_rate_limiter, - } = self; - - // check rate limit - if let Some(ep) = user_info.get_endpoint() { - if !endpoint_rate_limiter.check(ep) { - return stream - .throw_error(auth::AuthError::too_many_connections()) - .await; - } - } - - let user = user_info.get_user().to_owned(); - let auth_result = match user_info - .authenticate(ctx, &mut stream, mode.allow_cleartext(), config) - .await - { - Ok(auth_result) => auth_result, + while let StageResult::Run(handle) = stage.run() { + stage = match handle.await.expect("tasks should not panic") { + Ok(s) => s, Err(e) => { - let db = params.get("database"); - let app = params.get("application_name"); - let params_span = tracing::info_span!("", ?user, ?db, ?app); - - return stream.throw_error(e).instrument(params_span).await; + e.finish().await; + break; } - }; - - let (mut node_info, user_info) = auth_result; - - node_info.allow_self_signed_compute = allow_self_signed_compute; - - let aux = node_info.aux.clone(); - let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info) - .or_else(|e| stream.throw_error(e)) - .await?; - - prepare_client_connection(&node, session, &mut stream).await?; - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; - proxy_pass(ctx, stream, node.stream, aux).await + } } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 8bbe88aa51..12564f65fa 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,20 +1,120 @@ use crate::{ auth, + cancellation::{self, Session}, compute::{self, PostgresConnection}, console::{self, errors::WakeComputeError, Api}, context::RequestMonitoring, metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES}, - proxy::retry::{retry_after, ShouldRetry}, + state_machine::{DynStage, ResultExt, Stage, StageError}, + stream::{PqStream, Stream}, }; use async_trait::async_trait; use hyper::StatusCode; use pq_proto::StartupMessageParams; use std::ops::ControlFlow; -use tokio::time; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + time, +}; use tracing::{error, info, warn}; +use pq_proto::BeMessage as Be; + +use super::{ + pass::ProxyPass, + retry::{retry_after, ShouldRetry}, +}; + const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); +pub struct NeedsComputeConnection { + pub stream: PqStream>, + pub user_info: auth::BackendType<'static, auth::backend::ComputeUserInfo>, + pub mechanism: TcpMechanism, + pub node_info: console::CachedNodeInfo, + + // monitoring + pub ctx: RequestMonitoring, + pub cancel_session: Session, +} + +impl Stage for NeedsComputeConnection +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + fn span(&self) -> tracing::Span { + tracing::info_span!("connect_to_compute") + } + async fn run(self) -> Result { + let Self { + stream, + user_info, + mechanism, + node_info, + mut ctx, + cancel_session, + } = self; + + let aux = node_info.aux.clone(); + let (mut node, mut stream) = + connect_to_compute(&mut ctx, &mechanism, node_info, &user_info) + .await + .send_error_to_user(&mut ctx, stream)?; + + prepare_client_connection(&node, &cancel_session, &mut stream) + .await + .no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?; + + // Before proxy passing, forward to compute whatever data is left in the + // PqStream input buffer. Normally there is none, but our serverless npm + // driver in pipeline mode sends startup, password and first query + // immediately after opening the connection. + let (stream, read_buf) = stream.into_inner(); + + node.stream + .write_all(&read_buf) + .await + .no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?; + + Ok(Box::new(ProxyPass { + client: stream, + compute: node.stream, + aux, + cancel_session, + })) + } +} + +/// Finish client connection initialization: confirm auth success, send params, etc. +#[tracing::instrument(skip_all)] +async fn prepare_client_connection( + node: &compute::PostgresConnection, + session: &cancellation::Session, + stream: &mut PqStream, +) -> std::io::Result<()> { + // Register compute's query cancellation token and produce a new, unique one. + // The new token (cancel_key_data) will be sent to the client. + let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); + + // Forward all postgres connection params to the client. + // Right now the implementation is very hacky and inefficent (ideally, + // we don't need an intermediate hashmap), but at least it should be correct. + for (name, value) in &node.params { + // TODO: Theoretically, this could result in a big pile of params... + stream.write_message_noflush(&Be::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + })?; + } + + stream + .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? + .write_message(&Be::ReadyForQuery) + .await?; + + Ok(()) +} + /// If we couldn't connect, a cached connection info might be to blame /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. @@ -63,13 +163,13 @@ pub trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } -pub struct TcpMechanism<'a> { +pub struct TcpMechanism { /// KV-dictionary with PostgreSQL connection params. - pub params: &'a StartupMessageParams, + pub params: StartupMessageParams, } #[async_trait] -impl ConnectMechanism for TcpMechanism<'_> { +impl ConnectMechanism for TcpMechanism { type Connection = PostgresConnection; type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; @@ -84,7 +184,7 @@ impl ConnectMechanism for TcpMechanism<'_> { } fn update_connect_config(&self, config: &mut compute::ConnCfg) { - config.set_startup_params(self.params); + config.set_startup_params(&self.params); } } diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs new file mode 100644 index 0000000000..0015cffa05 --- /dev/null +++ b/proxy/src/proxy/handshake.rs @@ -0,0 +1,203 @@ +use crate::{ + auth::{self, backend::NeedsAuthentication}, + cancellation::CancelMap, + config::{ProxyConfig, TlsConfig}, + context::RequestMonitoring, + error::ReportableError, + proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION}, + rate_limiter::EndpointRateLimiter, + state_machine::{DynStage, Finished, ResultExt, Stage, StageError}, + stream::{PqStream, Stream, StreamUpgradeError}, +}; +use anyhow::{anyhow, Context}; +use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; +use std::{io, sync::Arc}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{error, info}; + +use super::ClientMode; + +pub struct NeedsHandshake { + pub stream: S, + pub config: &'static ProxyConfig, + pub cancel_map: Arc, + pub mode: ClientMode, + pub endpoint_rate_limiter: Arc, + + // monitoring + pub ctx: RequestMonitoring, +} + +impl Stage for NeedsHandshake { + fn span(&self) -> tracing::Span { + tracing::info_span!("handshake") + } + async fn run(self) -> Result { + let Self { + stream, + config, + cancel_map, + mode, + endpoint_rate_limiter, + mut ctx, + } = self; + + let tls = config.tls_config.as_ref(); + + let pause_timer = ctx.latency_timer.pause(); + let handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map).await; + drop(pause_timer); + + let (stream, params) = match handshake { + Err(err) => { + // TODO: proper handling + error!("could not complete handshake: {err:#}"); + return Err(StageError::Done); + } + // cancellation + Ok(None) => return Ok(Box::new(Finished)), + Ok(Some(s)) => s, + }; + + let hostname = mode.hostname(stream.get_ref()); + + let common_names = tls.map(|tls| &tls.common_names); + let (creds, stream) = config + .auth_backend + .as_ref() + .map(|_| { + auth::ComputeUserInfoMaybeEndpoint::parse(&mut ctx, ¶ms, hostname, common_names) + }) + .transpose() + .send_error_to_user(&mut ctx, stream)?; + + ctx.set_endpoint_id(creds.get_endpoint()); + + Ok(Box::new(NeedsAuthentication { + stream, + creds, + params, + endpoint_rate_limiter, + mode, + config, + + ctx, + cancel_session: cancel_map.get_session(), + })) + } +} + +#[derive(Error, Debug)] +pub enum HandshakeError { + #[error("client disconnected: {0}")] + ClientIO(#[from] io::Error), + #[error("protocol violation: {0}")] + ProtocolError(#[from] anyhow::Error), + #[error("could not initiate tls connection: {0}")] + TLSError(#[from] StreamUpgradeError), + #[error("could not cancel connection: {0}")] + Cancel(anyhow::Error), +} + +impl ReportableError for HandshakeError { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + HandshakeError::ClientIO(_) => crate::error::ErrorKind::Disconnect, + HandshakeError::ProtocolError(_) => crate::error::ErrorKind::User, + HandshakeError::TLSError(_) => crate::error::ErrorKind::User, + HandshakeError::Cancel(_) => crate::error::ErrorKind::Compute, + } + } +} + +type SuccessfulHandshake = (PqStream>, StartupMessageParams); + +/// Establish a (most probably, secure) connection with the client. +/// For better testing experience, `stream` can be any object satisfying the traits. +/// It's easier to work with owned `stream` here as we need to upgrade it to TLS; +/// we also take an extra care of propagating only the select handshake errors to client. +pub async fn handshake( + stream: S, + mut tls: Option<&TlsConfig>, + cancel_map: &CancelMap, +) -> Result>, HandshakeError> { + // Client may try upgrading to each protocol only once + let (mut tried_ssl, mut tried_gss) = (false, false); + + let mut stream = PqStream::new(Stream::from_raw(stream)); + loop { + let msg = stream.read_startup_packet().await?; + info!("received {msg:?}"); + + use FeStartupPacket::*; + match msg { + SslRequest => match stream.get_ref() { + Stream::Raw { .. } if !tried_ssl => { + tried_ssl = true; + + // We can't perform TLS handshake without a config + let enc = tls.is_some(); + stream.write_message(&Be::EncryptionResponse(enc)).await?; + if let Some(tls) = tls.take() { + // Upgrade raw stream into a secure TLS-backed stream. + // NOTE: We've consumed `tls`; this fact will be used later. + + let (raw, read_buf) = stream.into_inner(); + // TODO: Normally, client doesn't send any data before + // server says TLS handshake is ok and read_buf is empy. + // However, you could imagine pipelining of postgres + // SSLRequest + TLS ClientHello in one hunk similar to + // pipelining in our node js driver. We should probably + // support that by chaining read_buf with the stream. + if !read_buf.is_empty() { + return Err(HandshakeError::ProtocolError(anyhow!( + "data is sent before server replied with EncryptionResponse" + ))); + } + let tls_stream = raw.upgrade(tls.to_server_config()).await?; + + let (_, tls_server_end_point) = tls + .cert_resolver + .resolve(tls_stream.get_ref().1.server_name()) + .context("missing certificate")?; + + stream = PqStream::new(Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, + }); + } + } + _ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))), + }, + GssEncRequest => match stream.get_ref() { + Stream::Raw { .. } if !tried_gss => { + tried_gss = true; + + // Currently, we don't support GSSAPI + stream.write_message(&Be::EncryptionResponse(false)).await?; + } + _ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))), + }, + StartupMessage { params, .. } => { + // Check that the config has been consumed during upgrade + // OR we didn't provide it at all (for dev purposes). + if tls.is_some() { + stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; + } + + info!(session_type = "normal", "successful handshake"); + break Ok(Some((stream, params))); + } + CancelRequest(cancel_key_data) => { + cancel_map + .cancel_session(cancel_key_data) + .await + .map_err(HandshakeError::Cancel)?; + + info!(session_type = "cancellation", "successful handshake"); + break Ok(None); + } + } + } +} diff --git a/proxy/src/proxy/pass.rs b/proxy/src/proxy/pass.rs new file mode 100644 index 0000000000..891d672a12 --- /dev/null +++ b/proxy/src/proxy/pass.rs @@ -0,0 +1,82 @@ +use crate::{ + cancellation::Session, + console::messages::MetricsAuxInfo, + metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER}, + state_machine::{DynStage, Finished, Stage, StageError}, + stream::Stream, + usage_metrics::{Ids, USAGE_METRICS}, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{error, info}; +use utils::measured_stream::MeasuredStream; + +pub struct ProxyPass { + pub client: Stream, + pub compute: Compute, + + // monitoring + pub aux: MetricsAuxInfo, + pub cancel_session: Session, +} + +impl Stage for ProxyPass +where + Client: AsyncRead + AsyncWrite + Unpin + Send + 'static, + Compute: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + fn span(&self) -> tracing::Span { + tracing::info_span!("proxy_pass") + } + async fn run(self) -> Result { + if let Err(e) = proxy_pass(self.client, self.compute, self.aux).await { + error!("{e:#}") + } + + drop(self.cancel_session); + + Ok(Box::new(Finished)) + } +} + +/// Forward bytes in both directions (client <-> compute). +pub async fn proxy_pass( + client: impl AsyncRead + AsyncWrite + Unpin, + compute: impl AsyncRead + AsyncWrite + Unpin, + aux: MetricsAuxInfo, +) -> anyhow::Result<()> { + let usage = USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id.clone(), + branch_id: aux.branch_id.clone(), + }); + + let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]); + let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx")); + let mut client = MeasuredStream::new( + client, + |_| {}, + |cnt| { + // Number of bytes we sent to the client (outbound). + m_sent.inc_by(cnt as u64); + m_sent2.inc_by(cnt as u64); + usage.record_egress(cnt as u64); + }, + ); + + let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]); + let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx")); + let mut compute = MeasuredStream::new( + compute, + |_| {}, + |cnt| { + // Number of bytes the client sent to the compute node (inbound). + m_recv.inc_by(cnt as u64); + m_recv2.inc_by(cnt as u64); + }, + ); + + // Starting from here we only proxy the client's traffic. + info!("performing the proxy pass..."); + let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?; + + Ok(()) +} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 73fde2d7d0..11f6e27659 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -3,14 +3,19 @@ mod mitm; use super::connect_compute::ConnectMechanism; +use super::handshake::handshake; use super::retry::ShouldRetry; use super::*; use crate::auth::backend::{ComputeUserInfo, TestBackend}; use crate::config::CertResolver; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::proxy::connect_compute::connect_to_compute; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; -use crate::{auth, http, sasl, scram}; +use crate::stream::PqStream; +use crate::{auth, compute, http, sasl, scram}; +use anyhow::bail; use async_trait::async_trait; +use pq_proto::BeMessage as Be; use rstest::rstest; use smol_str::SmolStr; use tokio_postgres::config::SslMode; @@ -202,7 +207,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { .err() // -> Option .context("server shouldn't accept client")?; - assert!(client_err.to_string().contains(&server_err.to_string())); + assert!(server_err.to_string().contains(ERR_INSECURE_CONNECTION)); Ok(()) } diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index a0a84a1dc0..c04d8a8f2a 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use super::*; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_postgres::config::SslMode; use tokio_postgres::tls::TlsConnect; use tokio_util::codec::{Decoder, Encoder}; diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs new file mode 100644 index 0000000000..05e3f62bfa --- /dev/null +++ b/proxy/src/proxy/wake_compute.rs @@ -0,0 +1,89 @@ +use std::{borrow::Cow, ops::ControlFlow}; + +use pq_proto::StartupMessageParams; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{error, warn}; + +use crate::{ + auth::{ + backend::{ComputeCredentialKeys, ComputeCredentials}, + BackendType, + }, + cancellation::Session, + console::{provider::ConsoleBackend, Api}, + context::RequestMonitoring, + state_machine::{user_facing_error, DynStage, Stage, StageError}, + stream::{PqStream, Stream}, +}; + +use super::{ + connect_compute::{handle_try_wake, NeedsComputeConnection, TcpMechanism}, + retry::retry_after, +}; + +pub struct NeedsWakeCompute { + pub stream: PqStream>, + pub api: Cow<'static, ConsoleBackend>, + pub params: StartupMessageParams, + pub allow_self_signed_compute: bool, + pub creds: ComputeCredentials, + + // monitoring + pub ctx: RequestMonitoring, + pub cancel_session: Session, +} + +impl Stage for NeedsWakeCompute { + fn span(&self) -> tracing::Span { + tracing::info_span!("wake_compute") + } + async fn run(self) -> Result { + let Self { + stream, + api, + params, + allow_self_signed_compute, + creds, + mut ctx, + cancel_session, + } = self; + + let mut num_retries = 0; + let mut node_info = loop { + let wake_res = api.wake_compute(&mut ctx, &creds.info).await; + match handle_try_wake(wake_res, num_retries) { + Err(e) => { + error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); + return Err(user_facing_error(e, &mut ctx, stream)); + } + Ok(ControlFlow::Continue(e)) => { + warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); + } + Ok(ControlFlow::Break(n)) => break n, + } + + let wait_duration = retry_after(num_retries); + num_retries += 1; + tokio::time::sleep(wait_duration).await; + }; + + ctx.set_project(node_info.aux.clone()); + + node_info.allow_self_signed_compute = allow_self_signed_compute; + + match creds.keys { + #[cfg(feature = "testing")] + ComputeCredentialKeys::Password(password) => node_info.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), + }; + + Ok(Box::new(NeedsComputeConnection { + stream, + user_info: BackendType::Console(api, creds.info), + mechanism: TcpMechanism { params }, + node_info, + ctx, + cancel_session, + })) + } +} diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index da1cf21c6a..14a36dca23 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -10,7 +10,7 @@ mod channel_binding; mod messages; mod stream; -use crate::error::UserFacingError; +use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; @@ -37,6 +37,25 @@ pub enum Error { Io(#[from] io::Error), } +impl ReportableError for Error { + fn get_error_type(&self) -> crate::error::ErrorKind { + match self { + Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, + Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, + Error::BadClientMessage(_) => crate::error::ErrorKind::User, + Error::MissingBinding => crate::error::ErrorKind::Service, + Error::Io(io) => match io.kind() { + // tokio postgres uses these for various scram failures + io::ErrorKind::InvalidInput + | io::ErrorKind::UnexpectedEof + | io::ErrorKind::Other => crate::error::ErrorKind::User, + // all other IO errors are likely disconnects. + _ => crate::error::ErrorKind::Disconnect, + }, + } + } +} + impl UserFacingError for Error { fn to_string_client(&self) -> String { use Error::*; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 8af008394a..1ba387cd2e 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -124,6 +124,12 @@ pub async fn task_main( let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); + let root_span = info_span!( + "serverless", + session = %session_id, + %peer_addr, + ); + request_handler( req, config, @@ -135,12 +141,9 @@ pub async fn task_main( sni_name, peer_addr.ip(), endpoint_rate_limiter, + root_span.clone(), ) - .instrument(info_span!( - "serverless", - session = %session_id, - %peer_addr, - )) + .instrument(root_span) .await } }, @@ -205,6 +208,7 @@ async fn request_handler( sni_hostname: Option, peer_addr: IpAddr, endpoint_rate_limiter: Arc, + root_span: tracing::Span, ) -> Result, ApiError> { let host = request .headers() @@ -215,27 +219,33 @@ async fn request_handler( // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&request) { - info!(session_id = ?session_id, "performing websocket upgrade"); + info!("performing websocket upgrade"); let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) .map_err(|e| ApiError::BadRequest(e.into()))?; ws_connections.spawn( async move { - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region); + let ctx = + RequestMonitoring::new(session_id, peer_addr, "ws", &config.region, root_span); - if let Err(e) = websocket::serve_websocket( + let websocket = match websocket.await { + Err(e) => { + error!("error in websocket connection: {e:#}"); + return; + } + Ok(ws) => ws, + }; + + websocket::serve_websocket( config, - &mut ctx, + ctx, websocket, - &cancel_map, + cancel_map, host, endpoint_rate_limiter, ) .await - { - error!(session_id = ?session_id, "error in websocket connection: {e:#}"); - } } .in_current_span(), ); @@ -243,7 +253,8 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) } else if request.uri().path() == "/sql" && request.method() == Method::POST { - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); + let mut ctx = + RequestMonitoring::new(session_id, peer_addr, "http", &config.region, root_span); sql_over_http::handle( tls, diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index a6529c920a..cf172e90c1 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -9,7 +9,7 @@ use crate::{ use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; use hyper::upgrade::Upgraded; -use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; +use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use pin_project_lite::pin_project; use std::{ @@ -131,13 +131,12 @@ impl AsyncBufRead for WebSocketRw { pub async fn serve_websocket( config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, - websocket: HyperWebsocket, - cancel_map: &CancelMap, + ctx: RequestMonitoring, + websocket: WebSocketStream, + cancel_map: Arc, hostname: Option, endpoint_rate_limiter: Arc, -) -> anyhow::Result<()> { - let websocket = websocket.await?; +) { handle_client( config, ctx, @@ -146,8 +145,7 @@ pub async fn serve_websocket( ClientMode::Websockets { hostname }, endpoint_rate_limiter, ) - .await?; - Ok(()) + .await } #[cfg(test)] diff --git a/proxy/src/state_machine.rs b/proxy/src/state_machine.rs new file mode 100644 index 0000000000..ffbd896b65 --- /dev/null +++ b/proxy/src/state_machine.rs @@ -0,0 +1,149 @@ +use futures::Future; +use pq_proto::{framed::Framed, BeMessage}; +use tokio::{io::AsyncWrite, task::JoinHandle}; +use tracing::{info, warn, Instrument}; + +pub trait Captures {} +impl Captures for U {} + +#[must_use] +pub enum StageError { + Flush(Framed>), + Done, +} + +impl StageError { + pub async fn finish(self) { + match self { + StageError::Flush(mut f) => { + // ignore result. we can't do anything about it. + // this is already the error case anyway... + if let Err(e) = f.flush().await { + warn!("could not send message to user: {e:?}") + } + } + StageError::Done => {} + } + info!("task finished"); + } +} + +pub type DynStage = Box; + +/// Stage represents a single stage in a state machine. +pub trait Stage: 'static + Send { + /// The span this stage should be run inside. + fn span(&self) -> tracing::Span; + /// Run the current stage, returning a new [`DynStage`], or an error + /// + /// Can be implemented as `async fn run(self) -> Result` + fn run(self) -> impl 'static + Send + Future>; +} + +pub enum StageResult { + Finished, + Run(JoinHandle>), +} + +pub trait StageSpawn: 'static + Send { + fn run(self: Box) -> StageResult; +} + +/// Stage spawn is a helper trait for the state machine. It spawns the stages as a tokio task +impl StageSpawn for S { + fn run(self: Box) -> StageResult { + let span = self.span(); + StageResult::Run(tokio::spawn(S::run(*self).instrument(span))) + } +} + +pub struct Finished; + +impl StageSpawn for Finished { + fn run(self: Box) -> StageResult { + StageResult::Finished + } +} + +use crate::{ + context::RequestMonitoring, + error::{ErrorKind, UserFacingError}, + stream::PqStream, +}; + +pub trait ResultExt { + fn send_error_to_user( + self, + ctx: &mut RequestMonitoring, + stream: PqStream, + ) -> Result<(T, PqStream), StageError> + where + S: AsyncWrite + Unpin + Send + 'static, + E: UserFacingError; + + fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result + where + E: std::fmt::Display; +} + +impl ResultExt for Result { + fn send_error_to_user( + self, + ctx: &mut RequestMonitoring, + stream: PqStream, + ) -> Result<(T, PqStream), StageError> + where + S: AsyncWrite + Unpin + Send + 'static, + E: UserFacingError, + { + match self { + Ok(t) => Ok((t, stream)), + Err(e) => Err(user_facing_error(e, ctx, stream)), + } + } + + fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result + where + E: std::fmt::Display, + { + match self { + Ok(t) => Ok(t), + Err(e) => { + tracing::error!( + kind = kind.to_str(), + user_msg = "", + "task finished with error: {e}" + ); + + ctx.error(kind); + ctx.log(); + Err(StageError::Done) + } + } + } +} + +pub fn user_facing_error( + err: E, + ctx: &mut RequestMonitoring, + mut stream: PqStream, +) -> StageError +where + S: AsyncWrite + Unpin + Send + 'static, + E: UserFacingError, +{ + let kind = err.get_error_type(); + ctx.error(kind); + ctx.log(); + + let msg = err.to_string_client(); + tracing::error!( + kind = kind.to_str(), + user_msg = msg, + "task finished with error: {err}" + ); + if let Err(err) = stream.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)) { + warn!("could not process error message: {err:?}") + } + StageError::Flush(stream.framed.map_stream_sync(|f| Box::new(f) as Box<_>)) +} diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index f48b3fe39f..27e0e016ba 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,5 +1,5 @@ use crate::config::TlsServerEndPoint; -use crate::error::UserFacingError; +use crate::error::ErrorKind; use anyhow::bail; use bytes::BytesMut; @@ -99,24 +99,17 @@ impl PqStream { /// Allowing string literals is safe under the assumption they might not contain any runtime info. /// This method exists due to `&str` not implementing `Into`. pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { - tracing::info!("forwarding error to user: {error}"); + let kind = ErrorKind::User; + tracing::error!( + kind = kind.to_str(), + full_msg = error, + user_msg = error, + "task finished with error" + ); self.write_message(&BeMessage::ErrorResponse(error, None)) .await?; bail!(error) } - - /// Write the error message using [`Self::write_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - pub async fn throw_error(&mut self, error: E) -> anyhow::Result - where - E: UserFacingError + Into, - { - let msg = error.to_string_client(); - tracing::info!("forwarding error to user: {msg}"); - self.write_message(&BeMessage::ErrorResponse(&msg, None)) - .await?; - bail!(error) - } } /// Wrapper for upgrading raw streams into secure streams.