diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs
index 8d1b861a66..48de4e2353 100644
--- a/proxy/src/auth.rs
+++ b/proxy/src/auth.rs
@@ -5,7 +5,8 @@ pub use backend::BackendType;
 
 mod credentials;
 pub use credentials::{
-    check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
+    check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
+    ComputeUserInfoParseError, IpPattern,
 };
 
 mod password_hack;
@@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload;
 
 mod flow;
 pub use flow::*;
+use tokio::time::error::Elapsed;
 
-use crate::{console, error::UserFacingError};
+use crate::{
+    console,
+    error::{ReportableError, UserFacingError},
+};
 use std::io;
 use thiserror::Error;
 
@@ -67,6 +72,9 @@ pub enum AuthErrorImpl {
 
     #[error("Too many connections to this endpoint. Please try again later.")]
     TooManyConnections,
+
+    #[error("Authentication timed out")]
+    UserTimeout(Elapsed),
 }
 
 #[derive(Debug, Error)]
@@ -93,6 +101,10 @@ impl AuthError {
     pub fn is_auth_failed(&self) -> bool {
         matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
     }
+
+    pub fn user_timeout(elapsed: Elapsed) -> Self {
+        AuthErrorImpl::UserTimeout(elapsed).into()
+    }
 }
 
 impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -116,6 +128,27 @@ impl UserFacingError for AuthError {
             Io(_) => "Internal error".to_string(),
             IpAddressNotAllowed => self.to_string(),
             TooManyConnections => self.to_string(),
+            UserTimeout(_) => self.to_string(),
+        }
+    }
+}
+
+impl ReportableError for AuthError {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        use AuthErrorImpl::*;
+        match self.0.as_ref() {
+            Link(e) => e.get_error_kind(),
+            GetAuthInfo(e) => e.get_error_kind(),
+            WakeCompute(e) => e.get_error_kind(),
+            Sasl(e) => e.get_error_kind(),
+            AuthFailed(_) => crate::error::ErrorKind::User,
+            BadAuthMethod(_) => crate::error::ErrorKind::User,
+            MalformedPassword(_) => crate::error::ErrorKind::User,
+            MissingEndpointName => crate::error::ErrorKind::User,
+            Io(_) => crate::error::ErrorKind::ClientDisconnect,
+            IpAddressNotAllowed => crate::error::ErrorKind::User,
+            TooManyConnections => crate::error::ErrorKind::RateLimit,
+            UserTimeout(_) => crate::error::ErrorKind::User,
         }
     }
 }
diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs
index 384063ceae..745dd75107 100644
--- a/proxy/src/auth/backend/classic.rs
+++ b/proxy/src/auth/backend/classic.rs
@@ -45,9 +45,9 @@ pub(super) async fn authenticate(
         }
     )
     .await
-    .map_err(|error| {
+    .map_err(|e| {
        warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
-        auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
+        auth::AuthError::user_timeout(e)
    })??;
 
     let client_key = match auth_outcome {
diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs
index d8ae362c03..c71637dd1a 100644
--- a/proxy/src/auth/backend/link.rs
+++ b/proxy/src/auth/backend/link.rs
@@ -2,7 +2,7 @@ use crate::{
     auth, compute,
     console::{self, provider::NodeInfo},
     context::RequestMonitoring,
-    error::UserFacingError,
+    error::{ReportableError, UserFacingError},
     stream::PqStream,
     waiters,
 };
@@ -14,10 +14,6 @@ use tracing::{info, info_span};
 
 #[derive(Debug, Error)]
 pub enum LinkAuthError {
-    /// Authentication error reported by the console.
- #[error("Authentication failed: {0}")] - AuthFailed(String), - #[error(transparent)] WaiterRegister(#[from] waiters::RegisterError), @@ -30,10 +26,16 @@ pub enum LinkAuthError { impl UserFacingError for LinkAuthError { fn to_string_client(&self) -> String { - use LinkAuthError::*; + "Internal error".to_string() + } +} + +impl ReportableError for LinkAuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - AuthFailed(_) => self.to_string(), - _ => "Internal error".to_string(), + LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service, + LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service, + LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect, } } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 875baaec47..d32609e44c 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,8 +1,12 @@ //! User credentials used in authentication. use crate::{ - auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI, + auth::password_hack::parse_endpoint_param, + context::RequestMonitoring, + error::{ReportableError, UserFacingError}, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, + proxy::NeonOptions, + serverless::SERVERLESS_DRIVER_SNI, EndpointId, RoleName, }; use itertools::Itertools; @@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError { impl UserFacingError for ComputeUserInfoParseError {} +impl ReportableError for ComputeUserInfoParseError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 471be7af25..43b805e8a1 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -240,7 +240,9 @@ async fn ssl_handshake( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream.throw_error_str(ERR_INSECURE_CONNECTION).await? + stream + .throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User) + .await? } } } @@ -272,5 +274,10 @@ async fn handle_client( let client = tokio::net::TcpStream::connect(destination).await?; let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await + + // doesn't yet matter as pg-sni-router doesn't report analytics logs + ctx.set_success(); + ctx.log(); + + proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d4ee657144..fe614628d8 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,24 +1,45 @@ -use anyhow::Context; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; +use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use crate::error::ReportableError; + /// Enables serving `CancelRequest`s. 
 #[derive(Default)]
 pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
 
+#[derive(Debug, Error)]
+pub enum CancelError {
+    #[error("{0}")]
+    IO(#[from] std::io::Error),
+    #[error("{0}")]
+    Postgres(#[from] tokio_postgres::Error),
+}
+
+impl ReportableError for CancelError {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        match self {
+            CancelError::IO(_) => crate::error::ErrorKind::Compute,
+            CancelError::Postgres(e) if e.as_db_error().is_some() => {
+                crate::error::ErrorKind::Postgres
+            }
+            CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
+        }
+    }
+}
+
 impl CancelMap {
     /// Cancel a running query for the corresponding connection.
-    pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
+    pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
         // NB: we should immediately release the lock after cloning the token.
-        let cancel_closure = self
-            .0
-            .get(&key)
-            .and_then(|x| x.clone())
-            .with_context(|| format!("query cancellation key not found: {key}"))?;
+        let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
+            tracing::warn!("query cancellation key not found: {key}");
+            return Ok(());
+        };
 
         info!("cancelling query per user's request using key {key}");
         cancel_closure.try_cancel_query().await
@@ -81,7 +102,7 @@ impl CancelClosure {
     }
 
     /// Cancels the query running on user's compute node.
-    pub async fn try_cancel_query(self) -> anyhow::Result<()> {
+    async fn try_cancel_query(self) -> Result<(), CancelError> {
         let socket = TcpStream::connect(self.socket_addr).await?;
         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
 
diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs
index aef1aab733..83940d80ec 100644
--- a/proxy/src/compute.rs
+++ b/proxy/src/compute.rs
@@ -1,6 +1,10 @@
 use crate::{
-    auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
-    context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
+    auth::parse_endpoint_param,
+    cancellation::CancelClosure,
+    console::errors::WakeComputeError,
+    context::RequestMonitoring,
+    error::{ReportableError, UserFacingError},
+    metrics::NUM_DB_CONNECTIONS_GAUGE,
     proxy::neon_option,
 };
 use futures::{FutureExt, TryFutureExt};
@@ -58,6 +62,20 @@ impl UserFacingError for ConnectionError {
     }
 }
 
+impl ReportableError for ConnectionError {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        match self {
+            ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
+                crate::error::ErrorKind::Postgres
+            }
+            ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
+            ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
+            ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
+            ConnectionError::WakeComputeError(e) => e.get_error_kind(),
+        }
+    }
+}
+
 /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
 pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
 
diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs
index c53d929470..e5cad42753 100644
--- a/proxy/src/console/provider.rs
+++ b/proxy/src/console/provider.rs
@@ -20,7 +20,7 @@ use tracing::info;
 
 pub mod errors {
     use crate::{
-        error::{io_error, UserFacingError},
+        error::{io_error, ReportableError, UserFacingError},
         http,
         proxy::retry::ShouldRetry,
     };
@@ -81,6 +81,15 @@ pub mod errors {
         }
     }
 
+    impl ReportableError for ApiError {
+        fn get_error_kind(&self) -> crate::error::ErrorKind {
+            match self {
+                ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
+                ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
+            }
+        }
+    }
+
     impl ShouldRetry for ApiError {
         fn could_retry(&self) -> bool {
             match self {
@@ -150,6 +159,16 @@ pub mod errors {
             }
         }
     }
+
+    impl ReportableError for GetAuthInfoError {
+        fn get_error_kind(&self) -> crate::error::ErrorKind {
+            match self {
+                GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
+                GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
+            }
+        }
+    }
+
     #[derive(Debug, Error)]
     pub enum WakeComputeError {
         #[error("Console responded with a malformed compute address: {0}")]
@@ -194,6 +213,16 @@ pub mod errors {
             }
         }
     }
+
+    impl ReportableError for WakeComputeError {
+        fn get_error_kind(&self) -> crate::error::ErrorKind {
+            match self {
+                WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
+                WakeComputeError::ApiError(e) => e.get_error_kind(),
+                WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
+            }
+        }
+    }
 }
 
 /// Auth secret which is managed by the cloud.
diff --git a/proxy/src/context.rs b/proxy/src/context.rs
index fe204534b7..d2bf3f68d3 100644
--- a/proxy/src/context.rs
+++ b/proxy/src/context.rs
@@ -8,8 +8,10 @@ use tokio::sync::mpsc;
 use uuid::Uuid;
 
 use crate::{
-    console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
-    EndpointId, ProjectId, RoleName,
+    console::messages::MetricsAuxInfo,
+    error::ErrorKind,
+    metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
+    BranchId, EndpointId, ProjectId, RoleName,
 };
 
 pub mod parquet;
@@ -108,6 +110,18 @@ impl RequestMonitoring {
         self.user = Some(user);
     }
 
+    pub fn set_error_kind(&mut self, kind: ErrorKind) {
+        ERROR_BY_KIND
+            .with_label_values(&[kind.to_metric_label()])
+            .inc();
+        if let Some(ep) = &self.endpoint_id {
+            ENDPOINT_ERRORS_BY_KIND
+                .with_label_values(&[kind.to_metric_label()])
+                .measure(ep);
+        }
+        self.error_kind = Some(kind);
+    }
+
     pub fn set_success(&mut self) {
         self.success = true;
     }
diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs
index 8510c5c586..0fe46915bc 100644
--- a/proxy/src/context/parquet.rs
+++ b/proxy/src/context/parquet.rs
@@ -108,7 +108,7 @@ impl From for RequestData {
             branch: value.branch.as_deref().map(String::from),
             protocol: value.protocol,
             region: value.region,
-            error: value.error_kind.as_ref().map(|e| e.to_str()),
+            error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
             success: value.success,
             duration_us: SystemTime::from(value.first_packet)
                 .elapsed()
diff --git a/proxy/src/error.rs b/proxy/src/error.rs
index 5b2dd7ecfd..eafe92bf48 100644
--- a/proxy/src/error.rs
+++ b/proxy/src/error.rs
@@ -17,7 +17,7 @@ pub fn log_error(e: E) -> E {
 /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
 /// is way too convenient and tends to proliferate all across the codebase,
 /// ultimately leading to accidental leaks of sensitive data.
-pub trait UserFacingError: fmt::Display {
+pub trait UserFacingError: ReportableError {
     /// Format the error for client, stripping all sensitive info.
     ///
     /// Although this might be a no-op for many types, it's highly
@@ -29,13 +29,13 @@
     }
 }
 
-#[derive(Clone)]
+#[derive(Copy, Clone, Debug)]
 pub enum ErrorKind {
     /// Wrong password, unknown endpoint, protocol violation, etc...
     User,
 
     /// Network error between user and proxy.
Not necessarily user error
-    Disconnect,
+    ClientDisconnect,
 
     /// Proxy self-imposed rate limits
     RateLimit,
@@ -46,6 +46,9 @@ pub enum ErrorKind {
     /// Error communicating with control plane
     ControlPlane,
+
+    /// Postgres error
+    Postgres,
+
     /// Error communicating with compute
     Compute,
 }
@@ -54,11 +57,36 @@ impl ErrorKind {
     pub fn to_str(&self) -> &'static str {
         match self {
             ErrorKind::User => "request failed due to user error",
-            ErrorKind::Disconnect => "client disconnected",
+            ErrorKind::ClientDisconnect => "client disconnected",
             ErrorKind::RateLimit => "request cancelled due to rate limit",
             ErrorKind::Service => "internal service error",
             ErrorKind::ControlPlane => "non-retryable control plane error",
-            ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
+            ErrorKind::Postgres => "postgres error",
+            ErrorKind::Compute => {
+                "non-retryable compute connection error (or exhausted retry capacity)"
+            }
+        }
+    }
+
+    pub fn to_metric_label(&self) -> &'static str {
+        match self {
+            ErrorKind::User => "user",
+            ErrorKind::ClientDisconnect => "clientdisconnect",
+            ErrorKind::RateLimit => "ratelimit",
+            ErrorKind::Service => "service",
+            ErrorKind::ControlPlane => "controlplane",
+            ErrorKind::Postgres => "postgres",
+            ErrorKind::Compute => "compute",
         }
     }
 }
+
+pub trait ReportableError: fmt::Display + Send + 'static {
+    fn get_error_kind(&self) -> ErrorKind;
+}
+
+impl ReportableError for tokio::time::error::Elapsed {
+    fn get_error_kind(&self) -> ErrorKind {
+        ErrorKind::RateLimit
+    }
+}
diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs
index e2d96a9c27..ccf89f9b05 100644
--- a/proxy/src/metrics.rs
+++ b/proxy/src/metrics.rs
@@ -274,3 +274,22 @@ pub static CONNECTING_ENDPOINTS: Lazy> = Lazy::new(|| {
     )
     .unwrap()
 });
+
+pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
+    register_int_counter_vec!(
+        "proxy_errors_total",
+        "Number of errors by a given classification",
+        &["type"],
+    )
+    .unwrap()
+});
+
+pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
+    register_hll_vec!(
+        32,
+        "proxy_endpoints_affected_by_errors",
+        "Number of endpoints affected by errors of a given classification",
+        &["type"],
+    )
+    .unwrap()
+});
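Reviewer note (not part of the diff): the `ReportableError` trait above is the only hook a new error type needs in order to feed `proxy_errors_total` and the per-endpoint HLL metric. A minimal sketch, assuming a hypothetical `MyProxyError` type; the trait, the `ErrorKind` variants and `to_metric_label()` are the ones defined in proxy/src/error.rs:

use crate::error::{ErrorKind, ReportableError};

// Hypothetical error type, for illustration only.
#[derive(Debug, thiserror::Error)]
enum MyProxyError {
    #[error("client hung up: {0}")]
    ClientGone(#[from] std::io::Error),
    #[error("control plane request failed: {0}")]
    ControlPlane(String),
}

impl ReportableError for MyProxyError {
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            // network trouble between client and proxy, not the user's fault
            MyProxyError::ClientGone(_) => ErrorKind::ClientDisconnect,
            // non-retryable control plane failure
            MyProxyError::ControlPlane(_) => ErrorKind::ControlPlane,
        }
    }
}

`RequestMonitoring::set_error_kind` then bumps the counter once per failed request and marks the endpoint in the HLL sketch.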
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index b3b221d3e2..50e22ec72a 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -13,9 +13,10 @@ use crate::{
     compute,
     config::{ProxyConfig, TlsConfig},
     context::RequestMonitoring,
+    error::ReportableError,
     metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
     protocol2::WithClientIp,
-    proxy::{handshake::handshake, passthrough::proxy_pass},
+    proxy::handshake::{handshake, HandshakeData},
     rate_limiter::EndpointRateLimiter,
     stream::{PqStream, Stream},
     EndpointCacheKey,
@@ -28,14 +29,17 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
 use regex::Regex;
 use smol_str::{format_smolstr, SmolStr};
 use std::sync::Arc;
+use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio_util::sync::CancellationToken;
 use tracing::{error, info, info_span, Instrument};
 
-use self::connect_compute::{connect_to_compute, TcpMechanism};
+use self::{
+    connect_compute::{connect_to_compute, TcpMechanism},
+    passthrough::ProxyPassthrough,
+};
 
 const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
-const ERR_PROTO_VIOLATION: &str = "protocol violation";
 
 pub async fn run_until_cancelled<F: std::future::Future>(
     f: F,
@@ -98,14 +102,14 @@ pub async fn task_main(
                 bail!("missing required client IP");
             }
 
-            let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
-
             socket
                 .inner
                 .set_nodelay(true)
                 .context("failed to set socket option")?;
 
-            handle_client(
+            let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
+
+            let res = handle_client(
                 config,
                 &mut ctx,
                 cancel_map,
@@ -113,7 +117,26 @@ pub async fn task_main(
                 ClientMode::Tcp,
                 endpoint_rate_limiter,
             )
-            .await
+            .await;
+
+            match res {
+                Err(e) => {
+                    // todo: log and push to ctx the error kind
+                    ctx.set_error_kind(e.get_error_kind());
+                    ctx.log();
+                    Err(e.into())
+                }
+                Ok(None) => {
+                    ctx.set_success();
+                    ctx.log();
+                    Ok(())
+                }
+                Ok(Some(p)) => {
+                    ctx.set_success();
+                    ctx.log();
+                    p.proxy_pass().await
+                }
+            }
         }
         .unwrap_or_else(move |e| {
             // Acknowledge that the task has finished with an error.
@@ -169,6 +192,37 @@ impl ClientMode {
     }
 }
 
+#[derive(Debug, Error)]
+// almost all errors should be reported to the user, but there's a few cases where we cannot
+// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
+// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
+//    we cannot be sure the client even understands our error message
+// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
+pub enum ClientRequestError {
+    #[error("{0}")]
+    Cancellation(#[from] cancellation::CancelError),
+    #[error("{0}")]
+    Handshake(#[from] handshake::HandshakeError),
+    #[error("{0}")]
+    HandshakeTimeout(#[from] tokio::time::error::Elapsed),
+    #[error("{0}")]
+    PrepareClient(#[from] std::io::Error),
+    #[error("{0}")]
+    ReportedError(#[from] crate::stream::ReportedError),
+}
+
+impl ReportableError for ClientRequestError {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        match self {
+            ClientRequestError::Cancellation(e) => e.get_error_kind(),
+            ClientRequestError::Handshake(e) => e.get_error_kind(),
+            ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
+            ClientRequestError::ReportedError(e) => e.get_error_kind(),
+            ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
+        }
+    }
+}
+
 pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
@@ -176,7 +230,7 @@ pub async fn handle_client(
     stream: S,
     mode: ClientMode,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-) -> anyhow::Result<()> {
+) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
     info!(
         protocol = ctx.protocol,
         "handling interactive connection from client"
@@ -193,11 +247,16 @@ pub async fn handle_client(
     let tls = config.tls_config.as_ref();
 
     let pause = ctx.latency_timer.pause();
-    let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
+    let do_handshake = handshake(stream, mode.handshake_tls(tls));
     let (mut stream, params) =
         match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
-            Some(x) => x,
-            None => return Ok(()), // it's a cancellation request
+            HandshakeData::Startup(stream, params) => (stream, params),
+            HandshakeData::Cancel(cancel_key_data) => {
+                return Ok(cancel_map
+                    .cancel_session(cancel_key_data)
+                    .await
+                    .map(|()| None)?)
+            }
         };
     drop(pause);
 
@@ -222,7 +281,7 @@ pub async fn handle_client(
         if !endpoint_rate_limiter.check(ep) {
             return stream
                 .throw_error(auth::AuthError::too_many_connections())
-                .await;
+                .await?;
         }
     }
 
@@ -242,7 +301,7 @@ pub async fn handle_client(
             let app = params.get("application_name");
             let params_span = tracing::info_span!("", ?user, ?db, ?app);
 
-            return stream.throw_error(e).instrument(params_span).await;
+            return stream.throw_error(e).instrument(params_span).await?;
         }
     };
 
@@ -268,7 +327,13 @@ pub async fn handle_client(
     let (stream, read_buf) = stream.into_inner();
     node.stream.write_all(&read_buf).await?;
 
-    proxy_pass(ctx, stream, node.stream, aux).await
+    Ok(Some(ProxyPassthrough {
+        client: stream,
+        compute: node,
+        aux,
+        req: _request_gauge,
+        conn: _client_gauge,
+    }))
 }
 
 /// Finish client connection initialization: confirm auth success, send params, etc.
@@ -277,7 +342,7 @@ async fn prepare_client_connection(
     node: &compute::PostgresConnection,
     session: &cancellation::Session,
     stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
-) -> anyhow::Result<()> {
+) -> Result<(), std::io::Error> {
     // Register compute's query cancellation token and produce a new, unique one.
     // The new token (cancel_key_data) will be sent to the client.
     let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs
index 1ad8da20d7..4665e07d23 100644
--- a/proxy/src/proxy/handshake.rs
+++ b/proxy/src/proxy/handshake.rs
@@ -1,15 +1,60 @@
-use anyhow::{bail, Context};
-use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
+use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
+use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::info;
 
 use crate::{
-    cancellation::CancelMap,
     config::TlsConfig,
-    proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
-    stream::{PqStream, Stream},
+    error::ReportableError,
+    proxy::ERR_INSECURE_CONNECTION,
+    stream::{PqStream, Stream, StreamUpgradeError},
 };
 
+#[derive(Error, Debug)]
+pub enum HandshakeError {
+    #[error("data is sent before server replied with EncryptionResponse")]
+    EarlyData,
+
+    #[error("protocol violation")]
+    ProtocolViolation,
+
+    #[error("missing certificate")]
+    MissingCertificate,
+
+    #[error("{0}")]
+    StreamUpgradeError(#[from] StreamUpgradeError),
+
+    #[error("{0}")]
+    Io(#[from] std::io::Error),
+
+    #[error("{0}")]
+    ReportedError(#[from] crate::stream::ReportedError),
+}
+
+impl ReportableError for HandshakeError {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        match self {
+            HandshakeError::EarlyData => crate::error::ErrorKind::User,
+            HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
+            // This error should not happen, but will if we have no default certificate and
+            // the client sends no SNI extension.
+            // If they provide SNI then we can be sure there is a certificate that matches.
+            HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
+            HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
+                StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
+                StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
+            },
+            HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
+            HandshakeError::ReportedError(e) => e.get_error_kind(),
+        }
+    }
+}
+
+pub enum HandshakeData<S> {
+    Startup(PqStream<Stream<S>>, StartupMessageParams),
+    Cancel(CancelKeyData),
+}
+
 /// Establish a (most probably, secure) connection with the client.
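Reviewer sketch (not part of the change): with cancellation moved out of `handshake()`, a caller that still wants the old "handle the cancel request inline" behaviour can wrap the new API roughly like this; `ClientRequestError` is the enum added in proxy/src/proxy.rs and converts from both `HandshakeError` and `CancelError` via `#[from]`:

use tokio::io::{AsyncRead, AsyncWrite};

async fn handshake_or_cancel<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
    tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>, ClientRequestError> {
    match handshake(stream, tls).await? {
        // normal startup: hand the upgraded stream back to the caller
        HandshakeData::Startup(stream, params) => Ok(Some((stream, params))),
        // cancel request: serve it here and report "nothing left to proxy"
        HandshakeData::Cancel(key) => {
            cancel_map.cancel_session(key).await?;
            Ok(None)
        }
    }
}

This mirrors the shape used by handle_client later in the diff.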
 /// For better testing experience, `stream` can be any object satisfying the traits.
 /// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
@@ -18,8 +63,7 @@ pub async fn handshake(
     stream: S,
     mut tls: Option<&TlsConfig>,
-    cancel_map: &CancelMap,
-) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
+) -> Result<HandshakeData<S>, HandshakeError> {
     // Client may try upgrading to each protocol only once
     let (mut tried_ssl, mut tried_gss) = (false, false);
 
@@ -49,14 +93,14 @@ pub async fn handshake(
                     // pipelining in our node js driver. We should probably
                     // support that by chaining read_buf with the stream.
                     if !read_buf.is_empty() {
-                        bail!("data is sent before server replied with EncryptionResponse");
+                        return Err(HandshakeError::EarlyData);
                     }
 
                     let tls_stream = raw.upgrade(tls.to_server_config()).await?;
 
                     let (_, tls_server_end_point) = tls
                         .cert_resolver
                         .resolve(tls_stream.get_ref().1.server_name())
-                        .context("missing certificate")?;
+                        .ok_or(HandshakeError::MissingCertificate)?;
 
                     stream = PqStream::new(Stream::Tls {
                         tls: Box::new(tls_stream),
@@ -64,7 +108,7 @@
                     });
                 }
             }
-            _ => bail!(ERR_PROTO_VIOLATION),
+            _ => return Err(HandshakeError::ProtocolViolation),
         },
         GssEncRequest => match stream.get_ref() {
             Stream::Raw { .. } if !tried_gss => {
@@ -73,23 +117,23 @@
                 // Currently, we don't support GSSAPI
                 stream.write_message(&Be::EncryptionResponse(false)).await?;
             }
-            _ => bail!(ERR_PROTO_VIOLATION),
+            _ => return Err(HandshakeError::ProtocolViolation),
         },
         StartupMessage { params, .. } => {
             // Check that the config has been consumed during upgrade
             // OR we didn't provide it at all (for dev purposes).
             if tls.is_some() {
-                stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
+                return stream
+                    .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
+                    .await?;
             }
 
             info!(session_type = "normal", "successful handshake");
-            break Ok(Some((stream, params)));
+            break Ok(HandshakeData::Startup(stream, params));
         }
         CancelRequest(cancel_key_data) => {
-            cancel_map.cancel_session(cancel_key_data).await?;
-
             info!(session_type = "cancellation", "successful handshake");
-            break Ok(None);
+            break Ok(HandshakeData::Cancel(cancel_key_data));
         }
     }
 }
diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs
index 53e0c3c8f3..b7018c6fb5 100644
--- a/proxy/src/proxy/passthrough.rs
+++ b/proxy/src/proxy/passthrough.rs
@@ -1,9 +1,11 @@
 use crate::{
+    compute::PostgresConnection,
     console::messages::MetricsAuxInfo,
-    context::RequestMonitoring,
     metrics::NUM_BYTES_PROXIED_COUNTER,
+    stream::Stream,
     usage_metrics::{Ids, USAGE_METRICS},
 };
+use metrics::IntCounterPairGuard;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::info;
 use utils::measured_stream::MeasuredStream;
@@ -11,14 +13,10 @@ use utils::measured_stream::MeasuredStream;
 /// Forward bytes in both directions (client <-> compute).
 #[tracing::instrument(skip_all)]
 pub async fn proxy_pass(
-    ctx: &mut RequestMonitoring,
     client: impl AsyncRead + AsyncWrite + Unpin,
     compute: impl AsyncRead + AsyncWrite + Unpin,
     aux: MetricsAuxInfo,
 ) -> anyhow::Result<()> {
-    ctx.set_success();
-    ctx.log();
-
     let usage = USAGE_METRICS.register(Ids {
         endpoint_id: aux.endpoint_id.clone(),
         branch_id: aux.branch_id.clone(),
@@ -51,3 +49,18 @@ pub async fn proxy_pass(
 
     Ok(())
 }
+
+pub struct ProxyPassthrough<S> {
+    pub client: Stream<S>,
+    pub compute: PostgresConnection,
+    pub aux: MetricsAuxInfo,
+
+    pub req: IntCounterPairGuard,
+    pub conn: IntCounterPairGuard,
+}
+
+impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
+    pub async fn proxy_pass(self) -> anyhow::Result<()> {
+        proxy_pass(self.client, self.compute.stream, self.aux).await
+    }
+}
diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs
index 656cabac75..3e961afb41 100644
--- a/proxy/src/proxy/tests.rs
+++ b/proxy/src/proxy/tests.rs
@@ -163,11 +163,11 @@ async fn dummy_proxy(
     tls: Option<TlsConfig>,
     auth: impl TestAuth + Send,
 ) -> anyhow::Result<()> {
-    let cancel_map = CancelMap::default();
     let client = WithClientIp::new(client);
-    let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
-        .await?
-        .context("handshake failed")?;
+    let mut stream = match handshake(client, tls.as_ref()).await? {
+        HandshakeData::Startup(stream, _) => stream,
+        HandshakeData::Cancel(_) => bail!("cancellation not supported"),
+    };
 
     auth.authenticate(&mut stream).await?;
 
diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs
index a0a84a1dc0..ed89e51754 100644
--- a/proxy/src/proxy/tests/mitm.rs
+++ b/proxy/src/proxy/tests/mitm.rs
@@ -35,12 +35,10 @@ async fn proxy_mitm(
     tokio::spawn(async move {
         // begin handshake with end_server
         let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
-        // process handshake with end_client
-        let (end_client, startup) =
-            handshake(client1, Some(&server_config1), &CancelMap::default())
-                .await
-                .unwrap()
-                .unwrap();
+        let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
+            HandshakeData::Startup(stream, params) => (stream, params),
+            HandshakeData::Cancel(_) => panic!("cancellation not supported"),
+        };
 
         let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
         let (end_client, buf) = end_client.framed.into_inner();
diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs
index da1cf21c6a..1cf8b53e11 100644
--- a/proxy/src/sasl.rs
+++ b/proxy/src/sasl.rs
@@ -10,7 +10,7 @@ mod channel_binding;
 mod messages;
 mod stream;
 
-use crate::error::UserFacingError;
+use crate::error::{ReportableError, UserFacingError};
 use std::io;
 use thiserror::Error;
 
@@ -48,6 +48,18 @@ impl UserFacingError for Error {
     }
 }
 
+impl ReportableError for Error {
+    fn get_error_kind(&self) -> crate::error::ErrorKind {
+        match self {
+            Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
+            Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
+            Error::BadClientMessage(_) => crate::error::ErrorKind::User,
+            Error::MissingBinding => crate::error::ErrorKind::Service,
+            Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
+        }
+    }
+}
+
 /// A convenient result type for SASL exchange.
 pub type Result<T> = std::result::Result<T, Error>;
 
diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs
index 58aa925a6a..a20600b94a 100644
--- a/proxy/src/serverless.rs
+++ b/proxy/src/serverless.rs
@@ -109,10 +109,9 @@ pub async fn task_main(
 
     let make_svc = hyper::service::make_service_fn(
         |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
-            let (io, tls) = stream.get_ref();
+            let (io, _) = stream.get_ref();
             let client_addr = io.client_addr();
             let remote_addr = io.inner.remote_addr();
-            let sni_name = tls.server_name().map(|s| s.to_string());
             let backend = backend.clone();
             let ws_connections = ws_connections.clone();
             let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -125,7 +124,6 @@ pub async fn task_main(
                 };
                 Ok(MetricService::new(hyper::service::service_fn(
                     move |req: Request<Body>| {
-                        let sni_name = sni_name.clone();
                         let backend = backend.clone();
                         let ws_connections = ws_connections.clone();
                         let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -141,7 +139,6 @@ pub async fn task_main(
                                 ws_connections,
                                 cancel_map,
                                 session_id,
-                                sni_name,
                                 peer_addr.ip(),
                                 endpoint_rate_limiter,
                             )
@@ -210,7 +207,6 @@ async fn request_handler(
     ws_connections: TaskTracker,
     cancel_map: Arc<CancelMap>,
     session_id: uuid::Uuid,
-    sni_hostname: Option<String>,
     peer_addr: IpAddr,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
 ) -> Result<Response<Body>, ApiError> {
@@ -230,11 +226,11 @@ async fn request_handler(
 
         ws_connections.spawn(
             async move {
-                let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
+                let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
 
                 if let Err(e) = websocket::serve_websocket(
                     config,
-                    &mut ctx,
+                    ctx,
                     websocket,
                     cancel_map,
                     host,
@@ -251,9 +247,9 @@ async fn request_handler(
         // Return the response so the spawned future can continue.
         Ok(response)
     } else if request.uri().path() == "/sql" && request.method() == Method::POST {
-        let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
+        let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
 
-        sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await
+        sql_over_http::handle(config, ctx, request, backend).await
     } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
         Response::builder()
             .header("Allow", "OPTIONS, POST")
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 466a74f0ea..03257e9161 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -1,6 +1,5 @@
 use std::{sync::Arc, time::Duration};
 
-use anyhow::Context;
 use async_trait::async_trait;
 use tracing::info;
 
@@ -8,7 +7,10 @@ use crate::{
     auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
     compute,
     config::ProxyConfig,
-    console::CachedNodeInfo,
+    console::{
+        errors::{GetAuthInfoError, WakeComputeError},
+        CachedNodeInfo,
+    },
     context::RequestMonitoring,
     proxy::connect_compute::ConnectMechanism,
 };
@@ -66,7 +68,7 @@ impl PoolingBackend {
         conn_info: ConnInfo,
         keys: ComputeCredentialKeys,
         force_new: bool,
-    ) -> anyhow::Result<Client> {
+    ) -> Result<Client, HttpConnError> {
         let maybe_client = if !force_new {
             info!("pool: looking for an existing connection");
             self.pool.get(ctx, &conn_info).await?
@@ -90,7 +92,7 @@
         let mut node_info = backend
             .wake_compute(ctx)
             .await?
-            .context("missing cache entry from wake_compute")?;
+            .ok_or(HttpConnError::NoComputeInfo)?;
 
         match keys {
             #[cfg(any(test, feature = "testing"))]
@@ -114,6 +116,23 @@ impl PoolingBackend {
     }
 }
 
+#[derive(Debug, thiserror::Error)]
+pub enum HttpConnError {
+    #[error("pooled connection closed at inconsistent state")]
+    ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
+    #[error("could not connection to compute")]
+    ConnectionError(#[from] tokio_postgres::Error),
+
+    #[error("could not get auth info")]
+    GetAuthInfo(#[from] GetAuthInfoError),
+    #[error("user not authenticated")]
+    AuthError(#[from] AuthError),
+    #[error("wake_compute returned error")]
+    WakeCompute(#[from] WakeComputeError),
+    #[error("wake_compute returned nothing")]
+    NoComputeInfo,
+}
+
 struct TokioMechanism {
     pool: Arc<GlobalConnPool>,
     conn_info: ConnInfo,
@@ -124,7 +143,7 @@
 impl ConnectMechanism for TokioMechanism {
     type Connection = Client;
     type ConnectError = tokio_postgres::Error;
-    type Error = anyhow::Error;
+    type Error = HttpConnError;
 
     async fn connect_once(
         &self,
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index a7b2c532d2..f92793096b 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -28,6 +28,8 @@ use crate::{
 use tracing::{debug, error, warn, Span};
 use tracing::{info, info_span, Instrument};
 
+use super::backend::HttpConnError;
+
 pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
 
 #[derive(Debug, Clone)]
@@ -358,7 +360,7 @@ impl GlobalConnPool {
         self: &Arc<Self>,
         ctx: &mut RequestMonitoring,
         conn_info: &ConnInfo,
-    ) -> anyhow::Result<Option<Client>> {
+    ) -> Result<Option<Client>, HttpConnError> {
         let mut client: Option<ClientInner> = None;
 
         let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs
index a089d34040..c22c63e85b 100644
--- a/proxy/src/serverless/json.rs
+++ b/proxy/src/serverless/json.rs
@@ -60,6 +60,20 @@ fn json_array_to_pg_array(value: &Value) -> Option {
     }
 }
 
+#[derive(Debug, thiserror::Error)]
+pub enum JsonConversionError {
+    #[error("internal error compute returned invalid data: {0}")]
+    AsTextError(tokio_postgres::Error),
+    #[error("parse int error: {0}")]
+    ParseIntError(#[from] std::num::ParseIntError),
+    #[error("parse float error: {0}")]
+    ParseFloatError(#[from] std::num::ParseFloatError),
+    #[error("parse json error: {0}")]
+    ParseJsonError(#[from] serde_json::Error),
+    #[error("unbalanced array")]
+    UnbalancedArray,
+}
+
 //
 // Convert postgres row with text-encoded values to JSON object
 //
@@ -68,7 +82,7 @@ pub fn pg_text_row_to_json(
     columns: &[Type],
     raw_output: bool,
     array_mode: bool,
-) -> Result<Value, anyhow::Error> {
+) -> Result<Value, JsonConversionError> {
     let iter = row
         .columns()
         .iter()
@@ -76,7 +90,7 @@ pub fn pg_text_row_to_json(
         .enumerate()
         .map(|(i, (column, typ))| {
             let name = column.name();
-            let pg_value = row.as_text(i)?;
+            let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
             let json_value = if raw_output {
                 match pg_value {
                     Some(v) => Value::String(v.to_string()),
@@ -92,10 +106,10 @@ pub fn pg_text_row_to_json(
         // drop keys and aggregate into array
         let arr = iter
             .map(|r| r.map(|(_key, val)| val))
-            .collect::<Result<Vec<Value>, anyhow::Error>>()?;
+            .collect::<Result<Vec<Value>, JsonConversionError>>()?;
         Ok(Value::Array(arr))
     } else {
-        let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
+        let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
         Ok(Value::Object(obj))
     }
 }
@@ -103,7 +117,7 @@ pub fn pg_text_row_to_json(
 //
 // Convert postgres
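Reviewer sketch: `handle()` in sql_over_http.rs still carries a `// TODO: ctx.set_error_kind(...)` for the HTTP path, and `HttpConnError` does not implement `ReportableError` in this diff. A follow-up impl might look roughly like this (the variant-to-kind mapping is an assumption, not part of the change):

impl ReportableError for HttpConnError {
    fn get_error_kind(&self) -> crate::error::ErrorKind {
        match self {
            HttpConnError::ConnectionClosedAbruptly(_) => crate::error::ErrorKind::Compute,
            HttpConnError::ConnectionError(_) => crate::error::ErrorKind::Compute,
            // these already classify themselves via the impls added in this diff
            HttpConnError::GetAuthInfo(e) => e.get_error_kind(),
            HttpConnError::AuthError(e) => e.get_error_kind(),
            HttpConnError::WakeCompute(e) => e.get_error_kind(),
            HttpConnError::NoComputeInfo => crate::error::ErrorKind::ControlPlane,
        }
    }
}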
 // Convert postgres text-encoded value to JSON value
 //
-fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
+fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
     if let Some(val) = pg_value {
         if let Kind::Array(elem_type) = pg_type.kind() {
             return pg_array_parse(val, elem_type);
@@ -142,7 +156,7 @@ fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result
 //
 // Convert postgres array of text-encoded values to JSON array
 //
-fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
+fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
     _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
 }
 
@@ -150,7 +164,7 @@ fn _pg_array_parse(
     pg_array: &str,
     elem_type: &Type,
     nested: bool,
-) -> Result<(Value, usize), anyhow::Error> {
+) -> Result<(Value, usize), JsonConversionError> {
     let mut pg_array_chr = pg_array.char_indices();
     let mut level = 0;
     let mut quote = false;
@@ -170,7 +184,7 @@ fn _pg_array_parse(
         entry: &mut String,
         entries: &mut Vec<Value>,
         elem_type: &Type,
-    ) -> Result<(), anyhow::Error> {
+    ) -> Result<(), JsonConversionError> {
         if !entry.is_empty() {
             // While in usual postgres response we get nulls as None and everything else
             // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
@@ -234,7 +248,7 @@ fn _pg_array_parse(
     }
 
     if level != 0 {
-        return Err(anyhow::anyhow!("unbalanced array"));
+        return Err(JsonConversionError::UnbalancedArray);
     }
 
     Ok((Value::Array(entries), 0))
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 25e8813625..401022347e 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -1,7 +1,6 @@
 use std::sync::Arc;
 
 use anyhow::bail;
-use anyhow::Context;
 use futures::pin_mut;
 use futures::StreamExt;
 use hyper::body::HttpBody;
@@ -29,9 +28,11 @@ use utils::http::json::json_response;
 
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::endpoint_sni;
+use crate::auth::ComputeUserInfoParseError;
 use crate::config::ProxyConfig;
 use crate::config::TlsConfig;
 use crate::context::RequestMonitoring;
+use crate::error::ReportableError;
 use crate::metrics::HTTP_CONTENT_LENGTH;
 use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
 use crate::proxy::NeonOptions;
@@ -41,7 +42,6 @@ use super::backend::PoolingBackend;
 use super::conn_pool::ConnInfo;
 use super::json::json_to_pg_text;
 use super::json::pg_text_row_to_json;
-use super::SERVERLESS_DRIVER_SNI;
 
 #[derive(serde::Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -86,67 +86,70 @@ where
     Ok(json_to_pg_text(json))
 }
 
+#[derive(Debug, thiserror::Error)]
+pub enum ConnInfoError {
+    #[error("invalid header: {0}")]
+    InvalidHeader(&'static str),
+    #[error("invalid connection string: {0}")]
+    UrlParseError(#[from] url::ParseError),
+    #[error("incorrect scheme")]
+    IncorrectScheme,
+    #[error("missing database name")]
+    MissingDbName,
+    #[error("invalid database name")]
+    InvalidDbName,
+    #[error("missing username")]
+    MissingUsername,
+    #[error("missing password")]
+    MissingPassword,
+    #[error("missing hostname")]
+    MissingHostname,
+    #[error("invalid hostname: {0}")]
+    InvalidEndpoint(#[from] ComputeUserInfoParseError),
+    #[error("malformed endpoint")]
+    MalformedEndpoint,
+}
+
 fn get_conn_info(
     ctx: &mut RequestMonitoring,
     headers: &HeaderMap,
-    sni_hostname: Option<String>,
     tls: &TlsConfig,
-) -> Result<ConnInfo, anyhow::Error> {
+) -> Result<ConnInfo, ConnInfoError> {
     let connection_string = headers
         .get("Neon-Connection-String")
-        .ok_or(anyhow::anyhow!("missing connection string"))?
-        .to_str()?;
+        .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
+        .to_str()
+        .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
 
     let connection_url = Url::parse(connection_string)?;
 
     let protocol = connection_url.scheme();
     if protocol != "postgres" && protocol != "postgresql" {
-        return Err(anyhow::anyhow!(
-            "connection string must start with postgres: or postgresql:"
-        ));
+        return Err(ConnInfoError::IncorrectScheme);
     }
 
     let mut url_path = connection_url
         .path_segments()
-        .ok_or(anyhow::anyhow!("missing database name"))?;
+        .ok_or(ConnInfoError::MissingDbName)?;
 
-    let dbname = url_path
-        .next()
-        .ok_or(anyhow::anyhow!("invalid database name"))?;
+    let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
 
     let username = RoleName::from(connection_url.username());
     if username.is_empty() {
-        return Err(anyhow::anyhow!("missing username"));
+        return Err(ConnInfoError::MissingUsername);
     }
     ctx.set_user(username.clone());
 
     let password = connection_url
         .password()
-        .ok_or(anyhow::anyhow!("no password"))?;
-
-    // TLS certificate selector now based on SNI hostname, so if we are running here
-    // we are sure that SNI hostname is set to one of the configured domain names.
-    let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
+        .ok_or(ConnInfoError::MissingPassword)?;
 
     let hostname = connection_url
         .host_str()
-        .ok_or(anyhow::anyhow!("no host"))?;
+        .ok_or(ConnInfoError::MissingHostname)?;
 
-    let host_header = headers
-        .get("host")
-        .and_then(|h| h.to_str().ok())
-        .and_then(|h| h.split(':').next());
-
-    // sni_hostname has to be either the same as hostname or the one used in serverless driver.
-    if !check_matches(&sni_hostname, hostname)? {
-        return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
-    } else if let Some(h) = host_header {
-        if h != sni_hostname {
-            return Err(anyhow::anyhow!("mismatched host header and hostname"));
-        }
-    }
-
-    let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
+    let endpoint =
+        endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
     ctx.set_endpoint_id(endpoint.clone());
 
     let pairs = connection_url.query_pairs();
@@ -173,36 +176,27 @@ fn get_conn_info(
     })
 }
 
-fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
-    if sni_hostname == hostname {
-        return Ok(true);
-    }
-    let (sni_hostname_first, sni_hostname_rest) = sni_hostname
-        .split_once('.')
-        .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
-    let (_, hostname_rest) = hostname
-        .split_once('.')
-        .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
-    Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
-}
-
 // TODO: return different http error codes
 pub async fn handle(
     config: &'static ProxyConfig,
-    ctx: &mut RequestMonitoring,
+    mut ctx: RequestMonitoring,
     request: Request<Body>,
-    sni_hostname: Option<String>,
     backend: Arc<PoolingBackend>,
 ) -> Result<Response<Body>, ApiError> {
     let result = tokio::time::timeout(
         config.http_config.request_timeout,
-        handle_inner(config, ctx, request, sni_hostname, backend),
+        handle_inner(config, &mut ctx, request, backend),
     )
     .await;
     let mut response = match result {
         Ok(r) => match r {
-            Ok(r) => r,
+            Ok(r) => {
+                ctx.set_success();
+                r
+            }
             Err(e) => {
+                // TODO: ctx.set_error_kind(e.get_error_type());
+
                 let mut message = format!("{:?}", e);
                 let db_error = e
                     .downcast_ref::<tokio_postgres::Error>()
                     .and_then(|e| e.as_db_error());
@@ -278,7 +272,9 @@ pub async fn handle(
             )?
             }
         },
-        Err(_) => {
+        Err(e) => {
+            ctx.set_error_kind(e.get_error_kind());
+
             let message = format!(
                 "HTTP-Connection timed out, execution time exeeded {} seconds",
                 config.http_config.request_timeout.as_secs()
@@ -290,6 +286,7 @@ pub async fn handle(
             )?
         }
     };
+
     response.headers_mut().insert(
         "Access-Control-Allow-Origin",
         hyper::http::HeaderValue::from_static("*"),
@@ -302,7 +299,6 @@ async fn handle_inner(
     config: &'static ProxyConfig,
     ctx: &mut RequestMonitoring,
     request: Request<Body>,
-    sni_hostname: Option<String>,
     backend: Arc<PoolingBackend>,
 ) -> anyhow::Result<Response<Body>> {
     let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
@@ -318,12 +314,7 @@ async fn handle_inner(
     //
     // let headers = request.headers();
     // TLS config should be there.
-    let conn_info = get_conn_info(
-        ctx,
-        headers,
-        sni_hostname,
-        config.tls_config.as_ref().unwrap(),
-    )?;
+    let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
     info!(
         user = conn_info.user_info.user.as_str(),
         project = conn_info.user_info.endpoint.as_str(),
@@ -487,8 +478,6 @@ async fn handle_inner(
         }
     };
 
-    ctx.set_success();
-    ctx.log();
     let metrics = client.metrics();
 
     // how could this possibly fail
diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs
index f68b35010a..062dd440b2 100644
--- a/proxy/src/serverless/websocket.rs
+++ b/proxy/src/serverless/websocket.rs
@@ -2,7 +2,7 @@ use crate::{
     cancellation::CancelMap,
     config::ProxyConfig,
     context::RequestMonitoring,
-    error::io_error,
+    error::{io_error, ReportableError},
     proxy::{handle_client, ClientMode},
     rate_limiter::EndpointRateLimiter,
 };
@@ -131,23 +131,41 @@ impl AsyncBufRead for WebSocketRw {
 
 pub async fn serve_websocket(
     config: &'static ProxyConfig,
-    ctx: &mut RequestMonitoring,
+    mut ctx: RequestMonitoring,
     websocket: HyperWebsocket,
     cancel_map: Arc<CancelMap>,
     hostname: Option<String>,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
 ) -> anyhow::Result<()> {
     let websocket = websocket.await?;
-    handle_client(
+    let res = handle_client(
         config,
-        ctx,
+        &mut ctx,
         cancel_map,
         WebSocketRw::new(websocket),
         ClientMode::Websockets { hostname },
         endpoint_rate_limiter,
     )
-    .await?;
-    Ok(())
+    .await;
+
+    match res {
+        Err(e) => {
+            // todo: log and push to ctx the error kind
+            ctx.set_error_kind(e.get_error_kind());
+            ctx.log();
+            Err(e.into())
+        }
+        Ok(None) => {
+            ctx.set_success();
+            ctx.log();
+            Ok(())
+        }
+        Ok(Some(p)) => {
+            ctx.set_success();
+            ctx.log();
+            p.proxy_pass().await
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs
index f48b3fe39f..0d639d2c07 100644
--- a/proxy/src/stream.rs
+++ b/proxy/src/stream.rs
@@ -1,6 +1,5 @@
 use crate::config::TlsServerEndPoint;
-use crate::error::UserFacingError;
-use anyhow::bail;
+use crate::error::{ErrorKind, ReportableError, UserFacingError};
 use bytes::BytesMut;
 
 use pq_proto::framed::{ConnectionError, Framed};
@@ -73,6 +72,30 @@ impl PqStream {
     }
 }
 
+#[derive(Debug)]
+pub struct ReportedError {
+    source: anyhow::Error,
+    error_kind: ErrorKind,
+}
+
+impl std::fmt::Display for ReportedError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.source.fmt(f)
+    }
+}
+
+impl std::error::Error for ReportedError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.source.source()
+    }
+}
+
+impl ReportableError for ReportedError {
+    fn get_error_kind(&self) -> ErrorKind {
+        self.error_kind
+    }
+}
+
 impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
     /// Write the message into an internal buffer, but don't flush the underlying stream.
     pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
@@ -98,24 +121,52 @@ impl PqStream {
     /// Write the error message using [`Self::write_message`], then re-throw it.
     /// Allowing string literals is safe under the assumption they might not contain any runtime info.
     /// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
-    pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
-        tracing::info!("forwarding error to user: {error}");
-        self.write_message(&BeMessage::ErrorResponse(error, None))
-            .await?;
-        bail!(error)
+    pub async fn throw_error_str<T>(
+        &mut self,
+        msg: &'static str,
+        error_kind: ErrorKind,
+    ) -> Result<T, ReportedError> {
+        tracing::info!(
+            kind = error_kind.to_metric_label(),
+            msg,
+            "forwarding error to user"
+        );
+
+        // already error case, ignore client IO error
+        let _: Result<_, std::io::Error> = self
+            .write_message(&BeMessage::ErrorResponse(msg, None))
+            .await;
+
+        Err(ReportedError {
+            source: anyhow::anyhow!(msg),
+            error_kind,
+        })
     }
 
     /// Write the error message using [`Self::write_message`], then re-throw it.
     /// Trait [`UserFacingError`] acts as an allowlist for error types.
-    pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
+    pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
     where
         E: UserFacingError + Into<anyhow::Error>,
     {
+        let error_kind = error.get_error_kind();
         let msg = error.to_string_client();
-        tracing::info!("forwarding error to user: {msg}");
-        self.write_message(&BeMessage::ErrorResponse(&msg, None))
-            .await?;
-        bail!(error)
+        tracing::info!(
+            kind=error_kind.to_metric_label(),
+            error=%error,
+            msg,
+            "forwarding error to user"
+        );
+
+        // already error case, ignore client IO error
+        let _: Result<_, std::io::Error> = self
+            .write_message(&BeMessage::ErrorResponse(&msg, None))
+            .await;
+
+        Err(ReportedError {
+            source: anyhow::anyhow!(error),
+            error_kind,
+        })
     }
 }
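Reviewer sketch (illustrative, not in the diff): end to end, `throw_error` now returns a `ReportedError` that keeps the classification chosen at the throw site, and `ClientRequestError`'s `#[from]` impl lets callers bubble it up with `?`. For example, a helper shaped like the auth-failure path in handle_client:

async fn reject_with<S, T>(
    stream: &mut PqStream<S>,
    error: auth::AuthError,
) -> Result<T, ClientRequestError>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    // writes the sanitized message to the client, then surfaces a ReportedError whose
    // get_error_kind() is later recorded exactly once by task_main via ctx.set_error_kind.
    Ok(stream.throw_error(error).await?)
}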