Proxy error reworking (#6453)

## Problem

Taking the ideas from https://github.com/neondatabase/neon/pull/6283 and
making somewhat less radical changes, split into smaller commits.

We currently don't report error classifications in the proxy because the
existing error handling makes it hard to do so.

## Summary of changes

1. Add a `ReportableError` trait that all errors will implement. This
provides the error classification functionality.
2. Handling client requests now returns a strongly typed error
    * this error is a `ReportableError` and is logged appropriately
3. The `handle_client` error type only has a few possible variants, to
account for the fact that by this point errors should be returned to the
user.
This commit is contained in:
Conrad Ludgate
2024-02-09 15:50:51 +00:00
committed by GitHub
parent 89a5c654bf
commit 96d89cde51
25 changed files with 588 additions and 186 deletions

View File

@@ -5,7 +5,8 @@ pub use backend::BackendType;
mod credentials;
pub use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
};
mod password_hack;
@@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;
use crate::{console, error::UserFacingError};
use crate::{
console,
error::{ReportableError, UserFacingError},
};
use std::io;
use thiserror::Error;
@@ -67,6 +72,9 @@ pub enum AuthErrorImpl {
#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,
#[error("Authentication timed out")]
UserTimeout(Elapsed),
}
#[derive(Debug, Error)]
@@ -93,6 +101,10 @@ impl AuthError {
/// Returns `true` when the wrapped error is the `AuthFailed` variant.
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}
/// Builds an error wrapping the `UserTimeout` variant from a tokio timeout
/// (`Elapsed`) error.
pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -116,6 +128,27 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
WakeCompute(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::ClientDisconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}

View File

@@ -45,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|error| {
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {

View File

@@ -2,7 +2,7 @@ use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::UserFacingError,
error::{ReportableError, UserFacingError},
stream::PqStream,
waiters,
};
@@ -14,10 +14,6 @@ use tracing::{info, info_span};
#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
@@ -30,10 +26,16 @@ pub enum LinkAuthError {
impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
"Internal error".to_string()
}
}
impl ReportableError for LinkAuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

View File

@@ -1,8 +1,12 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
@@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError {
impl UserFacingError for ComputeUserInfoParseError {}
impl ReportableError for ComputeUserInfoParseError {
/// Every `ComputeUserInfoParseError` is classified as a user error.
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@@ -240,7 +240,9 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
stream
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
.await?
}
}
}
@@ -272,5 +274,10 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();
proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
}

View File

@@ -1,24 +1,45 @@
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use crate::error::ReportableError;
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
/// Errors that can occur while forwarding a query cancellation request.
#[derive(Debug, Error)]
pub enum CancelError {
/// An I/O failure, e.g. from opening the TCP connection in `try_cancel_query`.
#[error("{0}")]
IO(#[from] std::io::Error),
/// An error returned by `tokio_postgres`, e.g. from `cancel_query_raw`.
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
}
impl ReportableError for CancelError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
CancelError::IO(_) => crate::error::ErrorKind::Compute,
CancelError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let cancel_closure = self
.0
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
return Ok(());
};
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
@@ -81,7 +102,7 @@ impl CancelClosure {
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;

View File

@@ -1,6 +1,10 @@
use crate::{
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -58,6 +62,20 @@ impl UserFacingError for ConnectionError {
}
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
}
}
}
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;

View File

@@ -20,7 +20,7 @@ use tracing::info;
pub mod errors {
use crate::{
error::{io_error, UserFacingError},
error::{io_error, ReportableError, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -81,6 +81,15 @@ pub mod errors {
}
}
impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl ShouldRetry for ApiError {
fn could_retry(&self) -> bool {
match self {
@@ -150,6 +159,16 @@ pub mod errors {
}
}
}
impl ReportableError for GetAuthInfoError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
@@ -194,6 +213,16 @@ pub mod errors {
}
}
}
impl ReportableError for WakeComputeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
}
/// Auth secret which is managed by the cloud.

View File

@@ -8,8 +8,10 @@ use tokio::sync::mpsc;
use uuid::Uuid;
use crate::{
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -108,6 +110,18 @@ impl RequestMonitoring {
self.user = Some(user);
}
/// Records the error classification for this request.
///
/// Increments the `ERROR_BY_KIND` counter for the kind's metric label and,
/// when an endpoint id is already known, feeds that endpoint into
/// `ENDPOINT_ERRORS_BY_KIND` under the same label. The kind is then stored
/// on the context.
pub fn set_error_kind(&mut self, kind: ErrorKind) {
    let label = kind.to_metric_label();

    ERROR_BY_KIND.with_label_values(&[label]).inc();

    if let Some(endpoint) = self.endpoint_id.as_ref() {
        ENDPOINT_ERRORS_BY_KIND
            .with_label_values(&[label])
            .measure(endpoint);
    }

    self.error_kind = Some(kind);
}
pub fn set_success(&mut self) {
self.success = true;
}

View File

@@ -108,7 +108,7 @@ impl From<RequestMonitoring> for RequestData {
branch: value.branch.as_deref().map(String::from),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_str()),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
duration_us: SystemTime::from(value.first_packet)
.elapsed()

View File

@@ -17,7 +17,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: fmt::Display {
pub trait UserFacingError: ReportableError {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
@@ -29,13 +29,13 @@ pub trait UserFacingError: fmt::Display {
}
}
#[derive(Clone)]
#[derive(Copy, Clone, Debug)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
/// Network error between user and proxy. Not necessarily user error
Disconnect,
ClientDisconnect,
/// Proxy self-imposed rate limits
RateLimit,
@@ -46,6 +46,9 @@ pub enum ErrorKind {
/// Error communicating with control plane
ControlPlane,
/// Postgres error
Postgres,
/// Error communicating with compute
Compute,
}
@@ -54,11 +57,36 @@ impl ErrorKind {
pub fn to_str(&self) -> &'static str {
match self {
ErrorKind::User => "request failed due to user error",
ErrorKind::Disconnect => "client disconnected",
ErrorKind::ClientDisconnect => "client disconnected",
ErrorKind::RateLimit => "request cancelled due to rate limit",
ErrorKind::Service => "internal service error",
ErrorKind::ControlPlane => "non-retryable control plane error",
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
ErrorKind::Postgres => "postgres error",
ErrorKind::Compute => {
"non-retryable compute connection error (or exhausted retry capacity)"
}
}
}
pub fn to_metric_label(&self) -> &'static str {
match self {
ErrorKind::User => "user",
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",
ErrorKind::Compute => "compute",
}
}
}
/// An error that can be classified into an [`ErrorKind`] for reporting.
pub trait ReportableError: fmt::Display + Send + 'static {
/// Returns the classification of this error.
fn get_error_kind(&self) -> ErrorKind;
}
impl ReportableError for tokio::time::error::Elapsed {
/// A bare tokio timeout is always classified as `RateLimit`.
// NOTE(review): callers that want a different classification for a timeout
// (e.g. a user-caused one) presumably wrap `Elapsed` in their own error type — confirm.
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::RateLimit
}
}

View File

@@ -274,3 +274,22 @@ pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
)
.unwrap()
});
/// Counter of errors (`proxy_errors_total`), labelled by error classification
/// through the `type` label.
pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_errors_total",
"Number of errors by a given classification",
&["type"],
)
.unwrap()
});
/// Endpoints affected by errors (`proxy_endpoints_affected_by_errors`),
/// labelled by error classification through the `type` label.
// NOTE(review): presumably an approximate distinct-endpoint count, given the
// `HyperLogLogVec` type — confirm against its implementation.
pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_endpoints_affected_by_errors",
"Number of endpoints affected by errors of a given classification",
&["type"],
)
.unwrap()
});

View File

@@ -13,9 +13,10 @@ use crate::{
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::{handshake::handshake, passthrough::proxy_pass},
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
@@ -28,14 +29,17 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
@@ -98,14 +102,14 @@ pub async fn task_main(
bail!("missing required client IP");
}
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
let res = handle_client(
config,
&mut ctx,
cancel_map,
@@ -113,7 +117,26 @@ pub async fn task_main(
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await
.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
ctx.log();
Err(e.into())
}
Ok(None) => {
ctx.set_success();
ctx.log();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log();
p.proxy_pass().await
}
}
}
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
@@ -169,6 +192,37 @@ impl ClientMode {
}
}
/// Error returned by `handle_client`.
#[derive(Debug, Error)]
// Almost all errors should be reported to the user, but there are a few cases where we cannot:
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons.
// 2. Handshake: handshake reports errors if it can; otherwise, if the handshake fails due to a protocol
//    violation, we cannot be sure the client even understands our error message.
// 3. PrepareClient: the client disconnected, so we can't tell them anyway...
pub enum ClientRequestError {
/// Failure while serving a cancellation request.
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
/// Failure during the initial handshake with the client.
#[error("{0}")]
Handshake(#[from] handshake::HandshakeError),
/// The handshake did not complete within the configured handshake timeout.
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
/// I/O failure while finishing client connection setup.
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
/// NOTE(review): presumably an error that was already sent to the client
/// by the stream's `throw_error*` helpers — confirm in `crate::stream`.
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
@@ -176,7 +230,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -193,11 +247,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let do_handshake = handshake(stream, mode.handshake_tls(tls));
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancel_map
.cancel_session(cancel_key_data)
.await
.map(|()| None)?)
}
};
drop(pause);
@@ -222,7 +281,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
.await?;
}
}
@@ -242,7 +301,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
return stream.throw_error(e).instrument(params_span).await?;
}
};
@@ -268,7 +327,13 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
Ok(Some(ProxyPassthrough {
client: stream,
compute: node,
aux,
req: _request_gauge,
conn: _client_gauge,
}))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -277,7 +342,7 @@ async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
) -> Result<(), std::io::Error> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

View File

@@ -1,15 +1,60 @@
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::{
cancellation::CancelMap,
config::TlsConfig,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
error::ReportableError,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
/// Errors that can occur during the initial handshake with the client.
#[derive(Error, Debug)]
pub enum HandshakeError {
/// The client sent data before the encryption response was processed
/// (the read buffer was not empty at TLS upgrade time).
#[error("data is sent before server replied with EncryptionResponse")]
EarlyData,
/// The client sent a startup packet that is not valid at this point.
#[error("protocol violation")]
ProtocolViolation,
/// The certificate resolver found no certificate for the connection.
#[error("missing certificate")]
MissingCertificate,
/// Upgrading the raw stream to TLS failed.
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
/// I/O failure while talking to the client.
#[error("{0}")]
Io(#[from] std::io::Error),
/// NOTE(review): presumably an error that was already sent to the client
/// by the stream's `throw_error*` helpers — confirm in `crate::stream`.
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for HandshakeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
},
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
HandshakeError::ReportedError(e) => e.get_error_kind(),
}
}
}
/// Successful outcome of `handshake`.
pub enum HandshakeData<S> {
/// A regular session: the client stream together with its startup parameters.
Startup(PqStream<Stream<S>>, StartupMessageParams),
/// The client sent a cancel request carrying this key; the caller is
/// responsible for acting on it.
Cancel(CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
@@ -18,8 +63,7 @@ use crate::{
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
) -> Result<HandshakeData<S>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -49,14 +93,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
return Err(HandshakeError::EarlyData);
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
@@ -64,7 +108,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
@@ -73,23 +117,23 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
return stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
.await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
break Ok(HandshakeData::Startup(stream, params));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
break Ok(HandshakeData::Cancel(cancel_key_data));
}
}
}

View File

@@ -1,9 +1,11 @@
use crate::{
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::NUM_BYTES_PROXIED_COUNTER,
stream::Stream,
usage_metrics::{Ids, USAGE_METRICS},
};
use metrics::IntCounterPairGuard;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
@@ -11,14 +13,10 @@ use utils::measured_stream::MeasuredStream;
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
@@ -51,3 +49,18 @@ pub async fn proxy_pass(
Ok(())
}
/// Everything needed to start forwarding bytes between a client and its
/// compute node, as produced by `handle_client`.
pub struct ProxyPassthrough<S> {
/// Client side of the connection.
pub client: Stream<S>,
/// Established connection to the compute node.
pub compute: PostgresConnection,
/// Auxiliary metrics information passed on to `proxy_pass`.
pub aux: MetricsAuxInfo,
// NOTE(review): `req` and `conn` are filled from `_request_gauge` and
// `_client_gauge` in `handle_client`; presumably held here only to keep
// those gauges alive for the lifetime of the passthrough — confirm.
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
/// Consumes the passthrough and forwards bytes between the client and the
/// compute stream via the free function `proxy_pass`.
// NOTE(review): `req` and `conn` are not moved out of `self`, so they look
// like they stay alive until this future completes — confirm that is the intent.
pub async fn proxy_pass(self) -> anyhow::Result<()> {
proxy_pass(self.client, self.compute.stream, self.aux).await
}
}

View File

@@ -163,11 +163,11 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let client = WithClientIp::new(client);
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
let mut stream = match handshake(client, tls.as_ref()).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;

View File

@@ -35,12 +35,10 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();

View File

@@ -10,7 +10,7 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::UserFacingError;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
@@ -48,6 +48,18 @@ impl UserFacingError for Error {
}
}
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -109,10 +109,9 @@ pub async fn task_main(
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref();
let (io, _) = stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -125,7 +124,6 @@ pub async fn task_main(
};
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -141,7 +139,6 @@ pub async fn task_main(
ws_connections,
cancel_map,
session_id,
sni_name,
peer_addr.ip(),
endpoint_rate_limiter,
)
@@ -210,7 +207,6 @@ async fn request_handler(
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
@@ -230,11 +226,11 @@ async fn request_handler(
ws_connections.spawn(
async move {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
if let Err(e) = websocket::serve_websocket(
config,
&mut ctx,
ctx,
websocket,
cancel_map,
host,
@@ -251,9 +247,9 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await
sql_over_http::handle(config, ctx, request, backend).await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

View File

@@ -1,6 +1,5 @@
use std::{sync::Arc, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use tracing::info;
@@ -8,7 +7,10 @@ use crate::{
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
compute,
config::ProxyConfig,
console::CachedNodeInfo,
console::{
errors::{GetAuthInfoError, WakeComputeError},
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::connect_compute::ConnectMechanism,
};
@@ -66,7 +68,7 @@ impl PoolingBackend {
conn_info: ConnInfo,
keys: ComputeCredentialKeys,
force_new: bool,
) -> anyhow::Result<Client<tokio_postgres::Client>> {
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
let maybe_client = if !force_new {
info!("pool: looking for an existing connection");
self.pool.get(ctx, &conn_info).await?
@@ -90,7 +92,7 @@ impl PoolingBackend {
let mut node_info = backend
.wake_compute(ctx)
.await?
.context("missing cache entry from wake_compute")?;
.ok_or(HttpConnError::NoComputeInfo)?;
match keys {
#[cfg(any(test, feature = "testing"))]
@@ -114,6 +116,23 @@ impl PoolingBackend {
}
}
#[derive(Debug, thiserror::Error)]
pub enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
#[error("user not authenticated")]
AuthError(#[from] AuthError),
#[error("wake_compute returned error")]
WakeCompute(#[from] WakeComputeError),
#[error("wake_compute returned nothing")]
NoComputeInfo,
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
@@ -124,7 +143,7 @@ struct TokioMechanism {
impl ConnectMechanism for TokioMechanism {
type Connection = Client<tokio_postgres::Client>;
type ConnectError = tokio_postgres::Error;
type Error = anyhow::Error;
type Error = HttpConnError;
async fn connect_once(
&self,

View File

@@ -28,6 +28,8 @@ use crate::{
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
#[derive(Debug, Clone)]
@@ -358,7 +360,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
self: &Arc<Self>,
ctx: &mut RequestMonitoring,
conn_info: &ConnInfo,
) -> anyhow::Result<Option<Client<C>>> {
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());

View File

@@ -60,6 +60,20 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
}
}
/// Errors raised while converting text-encoded postgres values into JSON.
///
/// Returned by `pg_text_row_to_json` and the array/value parsing helpers in
/// this module.
#[derive(Debug, thiserror::Error)]
pub enum JsonConversionError {
/// Reading a column as text (`row.as_text(i)`) failed. There is no `#[from]`
/// here; callers wrap the error explicitly via `map_err`.
#[error("internal error compute returned invalid data: {0}")]
AsTextError(tokio_postgres::Error),
/// Converted from a `std::num::ParseIntError`.
#[error("parse int error: {0}")]
ParseIntError(#[from] std::num::ParseIntError),
/// Converted from a `std::num::ParseFloatError`.
#[error("parse float error: {0}")]
ParseFloatError(#[from] std::num::ParseFloatError),
/// Converted from a `serde_json::Error`.
#[error("parse json error: {0}")]
ParseJsonError(#[from] serde_json::Error),
/// The array parser finished with a non-zero nesting level.
#[error("unbalanced array")]
UnbalancedArray,
}
//
// Convert postgres row with text-encoded values to JSON object
//
@@ -68,7 +82,7 @@ pub fn pg_text_row_to_json(
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, anyhow::Error> {
) -> Result<Value, JsonConversionError> {
let iter = row
.columns()
.iter()
@@ -76,7 +90,7 @@ pub fn pg_text_row_to_json(
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i)?;
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
@@ -92,10 +106,10 @@ pub fn pg_text_row_to_json(
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
Ok(Value::Object(obj))
}
}
@@ -103,7 +117,7 @@ pub fn pg_text_row_to_json(
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
@@ -142,7 +156,7 @@ fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyh
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}
@@ -150,7 +164,7 @@ fn _pg_array_parse(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), anyhow::Error> {
) -> Result<(Value, usize), JsonConversionError> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
@@ -170,7 +184,7 @@ fn _pg_array_parse(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), anyhow::Error> {
) -> Result<(), JsonConversionError> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
@@ -234,7 +248,7 @@ fn _pg_array_parse(
}
if level != 0 {
return Err(anyhow::anyhow!("unbalanced array"));
return Err(JsonConversionError::UnbalancedArray);
}
Ok((Value::Array(entries), 0))

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -29,9 +28,11 @@ use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::HTTP_CONTENT_LENGTH;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
@@ -41,7 +42,6 @@ use super::backend::PoolingBackend;
use super::conn_pool::ConnInfo;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -86,67 +86,70 @@ where
Ok(json_to_pg_text(json))
}
/// Errors returned by `get_conn_info` while extracting connection details
/// from the request headers and the connection string they carry.
#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
/// A required header is missing or is not a valid string; the payload is
/// the header name (e.g. `"Neon-Connection-String"`).
#[error("invalid header: {0}")]
InvalidHeader(&'static str),
/// The connection string could not be parsed as a URL.
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
/// The URL scheme is neither `postgres` nor `postgresql`.
#[error("incorrect scheme")]
IncorrectScheme,
/// The URL has no path segments to take the database name from.
#[error("missing database name")]
MissingDbName,
/// The URL path yielded no first segment for the database name.
#[error("invalid database name")]
InvalidDbName,
/// The URL's username component is empty.
#[error("missing username")]
MissingUsername,
/// The URL has no password component.
#[error("missing password")]
MissingPassword,
/// The URL has no host component.
#[error("missing hostname")]
MissingHostname,
/// Converted from a `ComputeUserInfoParseError` (propagated from the
/// `endpoint_sni` call in `get_conn_info`).
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
/// `endpoint_sni` succeeded but returned no endpoint.
#[error("malformed endpoint")]
MalformedEndpoint,
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
sni_hostname: Option<String>,
tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
) -> Result<ConnInfo, ConnInfoError> {
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(anyhow::anyhow!("missing connection string"))?
.to_str()?;
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(anyhow::anyhow!(
"connection string must start with postgres: or postgresql:"
));
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(anyhow::anyhow!("missing database name"))?;
.ok_or(ConnInfoError::MissingDbName)?;
let dbname = url_path
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
let username = RoleName::from(connection_url.username());
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
.ok_or(ConnInfoError::MissingPassword)?;
let hostname = connection_url
.host_str()
.ok_or(anyhow::anyhow!("no host"))?;
.ok_or(ConnInfoError::MissingHostname)?;
let host_header = headers
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next());
// sni_hostname has to be either the same as hostname or the one used in serverless driver.
if !check_matches(&sni_hostname, hostname)? {
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
} else if let Some(h) = host_header {
if h != sni_hostname {
return Err(anyhow::anyhow!("mismatched host header and hostname"));
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
let endpoint =
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
@@ -173,36 +176,27 @@ fn get_conn_info(
})
}
/// Decides whether `sni_hostname` is acceptable for the given `hostname`.
///
/// Returns `Ok(true)` when the two are identical, or when they share
/// everything after the first `.` and the first label of `sni_hostname` equals
/// `SERVERLESS_DRIVER_SNI`. Returns `Ok(false)` for any other pair that could
/// be split.
///
/// # Errors
///
/// Fails if the names differ and either one contains no `.`; `sni_hostname`
/// is checked first.
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
    // Exact match needs no further inspection.
    if sni_hostname == hostname {
        return Ok(true);
    }

    let (sni_label, sni_suffix) = match sni_hostname.split_once('.') {
        Some(parts) => parts,
        None => return Err(anyhow::anyhow!("Unexpected sni format.")),
    };
    let hostname_suffix = match hostname.split_once('.') {
        Some((_, suffix)) => suffix,
        None => return Err(anyhow::anyhow!("Unexpected hostname format.")),
    };

    let same_suffix = sni_suffix == hostname_suffix;
    let is_driver_label = sni_label == SERVERLESS_DRIVER_SNI;
    Ok(same_suffix && is_driver_label)
}
// TODO: return different http error codes
pub async fn handle(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
mut ctx: RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.http_config.request_timeout,
handle_inner(config, ctx, request, sni_hostname, backend),
handle_inner(config, &mut ctx, request, backend),
)
.await;
let mut response = match result {
Ok(r) => match r {
Ok(r) => r,
Ok(r) => {
ctx.set_success();
r
}
Err(e) => {
// TODO: ctx.set_error_kind(e.get_error_type());
let mut message = format!("{:?}", e);
let db_error = e
.downcast_ref::<tokio_postgres::Error>()
@@ -278,7 +272,9 @@ pub async fn handle(
)?
}
},
Err(_) => {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
let message = format!(
"HTTP-Connection timed out, execution time exeeded {} seconds",
config.http_config.request_timeout.as_secs()
@@ -290,6 +286,7 @@ pub async fn handle(
)?
}
};
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
@@ -302,7 +299,6 @@ async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
sni_hostname: Option<String>,
backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
@@ -318,12 +314,7 @@ async fn handle_inner(
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(
ctx,
headers,
sni_hostname,
config.tls_config.as_ref().unwrap(),
)?;
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
info!(
user = conn_info.user_info.user.as_str(),
project = conn_info.user_info.endpoint.as_str(),
@@ -487,8 +478,6 @@ async fn handle_inner(
}
};
ctx.set_success();
ctx.log();
let metrics = client.metrics();
// how could this possibly fail

View File

@@ -2,7 +2,7 @@ use crate::{
cancellation::CancelMap,
config::ProxyConfig,
context::RequestMonitoring,
error::io_error,
error::{io_error, ReportableError},
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
@@ -131,23 +131,41 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: Arc<CancelMap>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_client(
let res = handle_client(
config,
ctx,
&mut ctx,
cancel_map,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
)
.await?;
Ok(())
.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
ctx.log();
Err(e.into())
}
Ok(None) => {
ctx.set_success();
ctx.log();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log();
p.proxy_pass().await
}
}
}
#[cfg(test)]

View File

@@ -1,6 +1,5 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use anyhow::bail;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
@@ -73,6 +72,30 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
}
}
/// An error paired with its [`ErrorKind`] classification.
///
/// Constructed by `PqStream::throw_error_str` and `PqStream::throw_error`
/// after they attempt to write the error message to the client.
#[derive(Debug)]
pub struct ReportedError {
// The underlying error; formatting and `source()` both delegate to it.
source: anyhow::Error,
// The classification handed back by `get_error_kind`.
error_kind: ErrorKind,
}
/// Displays exactly as the wrapped error does.
impl std::fmt::Display for ReportedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.source.fmt(f)
}
}
/// Exposes the wrapped error's own `source()`, not the wrapped error itself.
impl std::error::Error for ReportedError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.source.source()
}
}
/// Reports the kind that was stored when the error was constructed.
impl ReportableError for ReportedError {
fn get_error_kind(&self) -> ErrorKind {
self.error_kind
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
@@ -98,24 +121,52 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
tracing::info!("forwarding error to user: {error}");
self.write_message(&BeMessage::ErrorResponse(error, None))
.await?;
bail!(error)
pub async fn throw_error_str<T>(
&mut self,
msg: &'static str,
error_kind: ErrorKind,
) -> Result<T, ReportedError> {
tracing::info!(
kind = error_kind.to_metric_label(),
msg,
"forwarding error to user"
);
// already error case, ignore client IO error
let _: Result<_, std::io::Error> = self
.write_message(&BeMessage::ErrorResponse(msg, None))
.await;
Err(ReportedError {
source: anyhow::anyhow!(msg),
error_kind,
})
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
tracing::info!("forwarding error to user: {msg}");
self.write_message(&BeMessage::ErrorResponse(&msg, None))
.await?;
bail!(error)
tracing::info!(
kind=error_kind.to_metric_label(),
error=%error,
msg,
"forwarding error to user"
);
// already error case, ignore client IO error
let _: Result<_, std::io::Error> = self
.write_message(&BeMessage::ErrorResponse(&msg, None))
.await;
Err(ReportedError {
source: anyhow::anyhow!(error),
error_kind,
})
}
}