mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-04 14:00:38 +00:00
Proxy error reworking (#6453)
## Problem Taking my ideas from https://github.com/neondatabase/neon/pull/6283 and doing a bit less radical changes. smaller commits. We currently don't report error classifications in proxy as the current error handling made it hard to do so. ## Summary of changes 1. Add a `ReportableError` trait that all errors will implement. This provides the error classification functionality. 2. Handle Client requests a strongly typed error * this error is a `ReportableError` and is logged appropriately 3. The handle client error only has a few possible error types, to account for the fact that at this point errors should be returned to the user.
This commit is contained in:
@@ -5,7 +5,8 @@ pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
|
||||
ComputeUserInfoParseError, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
@@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload;
|
||||
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
use tokio::time::error::Elapsed;
|
||||
|
||||
use crate::{console, error::UserFacingError};
|
||||
use crate::{
|
||||
console,
|
||||
error::{ReportableError, UserFacingError},
|
||||
};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -67,6 +72,9 @@ pub enum AuthErrorImpl {
|
||||
|
||||
#[error("Too many connections to this endpoint. Please try again later.")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("Authentication timed out")]
|
||||
UserTimeout(Elapsed),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -93,6 +101,10 @@ impl AuthError {
|
||||
pub fn is_auth_failed(&self) -> bool {
|
||||
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
|
||||
}
|
||||
|
||||
pub fn user_timeout(elapsed: Elapsed) -> Self {
|
||||
AuthErrorImpl::UserTimeout(elapsed).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
@@ -116,6 +128,27 @@ impl UserFacingError for AuthError {
|
||||
Io(_) => "Internal error".to_string(),
|
||||
IpAddressNotAllowed => self.to_string(),
|
||||
TooManyConnections => self.to_string(),
|
||||
UserTimeout(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for AuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.get_error_kind(),
|
||||
GetAuthInfo(e) => e.get_error_kind(),
|
||||
WakeCompute(e) => e.get_error_kind(),
|
||||
Sasl(e) => e.get_error_kind(),
|
||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
MissingEndpointName => crate::error::ErrorKind::User,
|
||||
Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
IpAddressNotAllowed => crate::error::ErrorKind::User,
|
||||
TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,9 +45,9 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
)
|
||||
.await
|
||||
.map_err(|error| {
|
||||
.map_err(|e| {
|
||||
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
|
||||
auth::AuthError::user_timeout(e)
|
||||
})??;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{
|
||||
auth, compute,
|
||||
console::{self, provider::NodeInfo},
|
||||
context::RequestMonitoring,
|
||||
error::UserFacingError,
|
||||
error::{ReportableError, UserFacingError},
|
||||
stream::PqStream,
|
||||
waiters,
|
||||
};
|
||||
@@ -14,10 +14,6 @@ use tracing::{info, info_span};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LinkAuthError {
|
||||
/// Authentication error reported by the console.
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthFailed(String),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
@@ -30,10 +26,16 @@ pub enum LinkAuthError {
|
||||
|
||||
impl UserFacingError for LinkAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use LinkAuthError::*;
|
||||
"Internal error".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for LinkAuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
AuthFailed(_) => self.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
|
||||
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
|
||||
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
|
||||
auth::password_hack::parse_endpoint_param,
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
|
||||
proxy::NeonOptions,
|
||||
serverless::SERVERLESS_DRIVER_SNI,
|
||||
EndpointId, RoleName,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
@@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError {
|
||||
|
||||
impl UserFacingError for ComputeUserInfoParseError {}
|
||||
|
||||
impl ReportableError for ComputeUserInfoParseError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
||||
@@ -240,7 +240,9 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
?unexpected,
|
||||
"unexpected startup packet, rejecting connection"
|
||||
);
|
||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
|
||||
stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -272,5 +274,10 @@ async fn handle_client(
|
||||
let client = tokio::net::TcpStream::connect(destination).await?;
|
||||
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
|
||||
|
||||
// doesn't yet matter as pg-sni-router doesn't report analytics logs
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
|
||||
proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
|
||||
}
|
||||
|
||||
@@ -1,24 +1,45 @@
|
||||
use anyhow::Context;
|
||||
use dashmap::DashMap;
|
||||
use pq_proto::CancelKeyData;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::info;
|
||||
|
||||
use crate::error::ReportableError;
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CancelError {
|
||||
#[error("{0}")]
|
||||
IO(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for CancelError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
CancelError::IO(_) => crate::error::ErrorKind::Compute,
|
||||
CancelError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelMap {
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let cancel_closure = self
|
||||
.0
|
||||
.get(&key)
|
||||
.and_then(|x| x.clone())
|
||||
.with_context(|| format!("query cancellation key not found: {key}"))?;
|
||||
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query().await
|
||||
@@ -81,7 +102,7 @@ impl CancelClosure {
|
||||
}
|
||||
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
|
||||
async fn try_cancel_query(self) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use crate::{
|
||||
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
|
||||
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
auth::parse_endpoint_param,
|
||||
cancellation::CancelClosure,
|
||||
console::errors::WakeComputeError,
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
proxy::neon_option,
|
||||
};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
@@ -58,6 +62,20 @@ impl UserFacingError for ConnectionError {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ConnectionError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ use tracing::info;
|
||||
|
||||
pub mod errors {
|
||||
use crate::{
|
||||
error::{io_error, UserFacingError},
|
||||
error::{io_error, ReportableError, UserFacingError},
|
||||
http,
|
||||
proxy::retry::ShouldRetry,
|
||||
};
|
||||
@@ -81,6 +81,15 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ApiError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
|
||||
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ShouldRetry for ApiError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
@@ -150,6 +159,16 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
@@ -194,6 +213,16 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
WakeComputeError::ApiError(e) => e.get_error_kind(),
|
||||
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
|
||||
@@ -8,8 +8,10 @@ use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
|
||||
EndpointId, ProjectId, RoleName,
|
||||
console::messages::MetricsAuxInfo,
|
||||
error::ErrorKind,
|
||||
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
|
||||
BranchId, EndpointId, ProjectId, RoleName,
|
||||
};
|
||||
|
||||
pub mod parquet;
|
||||
@@ -108,6 +110,18 @@ impl RequestMonitoring {
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
pub fn set_error_kind(&mut self, kind: ErrorKind) {
|
||||
ERROR_BY_KIND
|
||||
.with_label_values(&[kind.to_metric_label()])
|
||||
.inc();
|
||||
if let Some(ep) = &self.endpoint_id {
|
||||
ENDPOINT_ERRORS_BY_KIND
|
||||
.with_label_values(&[kind.to_metric_label()])
|
||||
.measure(ep);
|
||||
}
|
||||
self.error_kind = Some(kind);
|
||||
}
|
||||
|
||||
pub fn set_success(&mut self) {
|
||||
self.success = true;
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ impl From<RequestMonitoring> for RequestData {
|
||||
branch: value.branch.as_deref().map(String::from),
|
||||
protocol: value.protocol,
|
||||
region: value.region,
|
||||
error: value.error_kind.as_ref().map(|e| e.to_str()),
|
||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||
success: value.success,
|
||||
duration_us: SystemTime::from(value.first_packet)
|
||||
.elapsed()
|
||||
|
||||
@@ -17,7 +17,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
|
||||
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
|
||||
/// is way too convenient and tends to proliferate all across the codebase,
|
||||
/// ultimately leading to accidental leaks of sensitive data.
|
||||
pub trait UserFacingError: fmt::Display {
|
||||
pub trait UserFacingError: ReportableError {
|
||||
/// Format the error for client, stripping all sensitive info.
|
||||
///
|
||||
/// Although this might be a no-op for many types, it's highly
|
||||
@@ -29,13 +29,13 @@ pub trait UserFacingError: fmt::Display {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum ErrorKind {
|
||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||
User,
|
||||
|
||||
/// Network error between user and proxy. Not necessarily user error
|
||||
Disconnect,
|
||||
ClientDisconnect,
|
||||
|
||||
/// Proxy self-imposed rate limits
|
||||
RateLimit,
|
||||
@@ -46,6 +46,9 @@ pub enum ErrorKind {
|
||||
/// Error communicating with control plane
|
||||
ControlPlane,
|
||||
|
||||
/// Postgres error
|
||||
Postgres,
|
||||
|
||||
/// Error communicating with compute
|
||||
Compute,
|
||||
}
|
||||
@@ -54,11 +57,36 @@ impl ErrorKind {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
ErrorKind::User => "request failed due to user error",
|
||||
ErrorKind::Disconnect => "client disconnected",
|
||||
ErrorKind::ClientDisconnect => "client disconnected",
|
||||
ErrorKind::RateLimit => "request cancelled due to rate limit",
|
||||
ErrorKind::Service => "internal service error",
|
||||
ErrorKind::ControlPlane => "non-retryable control plane error",
|
||||
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
|
||||
ErrorKind::Postgres => "postgres error",
|
||||
ErrorKind::Compute => {
|
||||
"non-retryable compute connection error (or exhausted retry capacity)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_metric_label(&self) -> &'static str {
|
||||
match self {
|
||||
ErrorKind::User => "user",
|
||||
ErrorKind::ClientDisconnect => "clientdisconnect",
|
||||
ErrorKind::RateLimit => "ratelimit",
|
||||
ErrorKind::Service => "service",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "postgres",
|
||||
ErrorKind::Compute => "compute",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ReportableError: fmt::Display + Send + 'static {
|
||||
fn get_error_kind(&self) -> ErrorKind;
|
||||
}
|
||||
|
||||
impl ReportableError for tokio::time::error::Elapsed {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::RateLimit
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,3 +274,22 @@ pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_errors_total",
|
||||
"Number of errors by a given classification",
|
||||
&["type"],
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
|
||||
register_hll_vec!(
|
||||
32,
|
||||
"proxy_endpoints_affected_by_errors",
|
||||
"Number of endpoints affected by errors of a given classification",
|
||||
&["type"],
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
@@ -13,9 +13,10 @@ use crate::{
|
||||
compute,
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
|
||||
protocol2::WithClientIp,
|
||||
proxy::{handshake::handshake, passthrough::proxy_pass},
|
||||
proxy::handshake::{handshake, HandshakeData},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
EndpointCacheKey,
|
||||
@@ -28,14 +29,17 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
|
||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use self::{
|
||||
connect_compute::{connect_to_compute, TcpMechanism},
|
||||
passthrough::ProxyPassthrough,
|
||||
};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
@@ -98,14 +102,14 @@ pub async fn task_main(
|
||||
bail!("missing required client IP");
|
||||
}
|
||||
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
||||
|
||||
socket
|
||||
.inner
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
@@ -113,7 +117,26 @@ pub async fn task_main(
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
// todo: log and push to ctx the error kind
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
ctx.log();
|
||||
Err(e.into())
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
p.proxy_pass().await
|
||||
}
|
||||
}
|
||||
}
|
||||
.unwrap_or_else(move |e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -169,6 +192,37 @@ impl ClientMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] handshake::HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
@@ -176,7 +230,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
info!(
|
||||
protocol = ctx.protocol,
|
||||
"handling interactive connection from client"
|
||||
@@ -193,11 +247,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let tls = config.tls_config.as_ref();
|
||||
|
||||
let pause = ctx.latency_timer.pause();
|
||||
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
|
||||
let do_handshake = handshake(stream, mode.handshake_tls(tls));
|
||||
let (mut stream, params) =
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancel_map
|
||||
.cancel_session(cancel_key_data)
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
@@ -222,7 +281,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if !endpoint_rate_limiter.check(ep) {
|
||||
return stream
|
||||
.throw_error(auth::AuthError::too_many_connections())
|
||||
.await;
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,7 +301,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return stream.throw_error(e).instrument(params_span).await;
|
||||
return stream.throw_error(e).instrument(params_span).await?;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -268,7 +327,13 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
proxy_pass(ctx, stream, node.stream, aux).await
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
compute: node,
|
||||
aux,
|
||||
req: _request_gauge,
|
||||
conn: _client_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
@@ -277,7 +342,7 @@ async fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
session: &cancellation::Session,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), std::io::Error> {
|
||||
// Register compute's query cancellation token and produce a new, unique one.
|
||||
// The new token (cancel_key_data) will be sent to the client.
|
||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||
|
||||
@@ -1,15 +1,60 @@
|
||||
use anyhow::{bail, Context};
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
cancellation::CancelMap,
|
||||
config::TlsConfig,
|
||||
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
|
||||
stream::{PqStream, Stream},
|
||||
error::ReportableError,
|
||||
proxy::ERR_INSECURE_CONNECTION,
|
||||
stream::{PqStream, Stream, StreamUpgradeError},
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum HandshakeError {
|
||||
#[error("data is sent before server replied with EncryptionResponse")]
|
||||
EarlyData,
|
||||
|
||||
#[error("protocol violation")]
|
||||
ProtocolViolation,
|
||||
|
||||
#[error("missing certificate")]
|
||||
MissingCertificate,
|
||||
|
||||
#[error("{0}")]
|
||||
StreamUpgradeError(#[from] StreamUpgradeError),
|
||||
|
||||
#[error("{0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for HandshakeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
HandshakeError::EarlyData => crate::error::ErrorKind::User,
|
||||
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
|
||||
// This error should not happen, but will if we have no default certificate and
|
||||
// the client sends no SNI extension.
|
||||
// If they provide SNI then we can be sure there is a certificate that matches.
|
||||
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
|
||||
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
|
||||
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
|
||||
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
},
|
||||
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
HandshakeError::ReportedError(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
@@ -18,8 +63,7 @@ use crate::{
|
||||
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
|
||||
) -> Result<HandshakeData<S>, HandshakeError> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
@@ -49,14 +93,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
return Err(HandshakeError::EarlyData);
|
||||
}
|
||||
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(tls_stream.get_ref().1.server_name())
|
||||
.context("missing certificate")?;
|
||||
.ok_or(HandshakeError::MissingCertificate)?;
|
||||
|
||||
stream = PqStream::new(Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
@@ -64,7 +108,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
GssEncRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_gss => {
|
||||
@@ -73,23 +117,23 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
StartupMessage { params, .. } => {
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
||||
return stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(session_type = "normal", "successful handshake");
|
||||
break Ok(Some((stream, params)));
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
cancel_map.cancel_session(cancel_key_data).await?;
|
||||
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(None);
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::{
|
||||
compute::PostgresConnection,
|
||||
console::messages::MetricsAuxInfo,
|
||||
context::RequestMonitoring,
|
||||
metrics::NUM_BYTES_PROXIED_COUNTER,
|
||||
stream::Stream,
|
||||
usage_metrics::{Ids, USAGE_METRICS},
|
||||
};
|
||||
use metrics::IntCounterPairGuard;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
@@ -11,14 +13,10 @@ use utils::measured_stream::MeasuredStream;
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn proxy_pass(
|
||||
ctx: &mut RequestMonitoring,
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
|
||||
let usage = USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id.clone(),
|
||||
branch_id: aux.branch_id.clone(),
|
||||
@@ -51,3 +49,18 @@ pub async fn proxy_pass(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct ProxyPassthrough<S> {
|
||||
pub client: Stream<S>,
|
||||
pub compute: PostgresConnection,
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
pub req: IntCounterPairGuard,
|
||||
pub conn: IntCounterPairGuard,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub async fn proxy_pass(self) -> anyhow::Result<()> {
|
||||
proxy_pass(self.client, self.compute.stream, self.aux).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,11 +163,11 @@ async fn dummy_proxy(
|
||||
tls: Option<TlsConfig>,
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
let client = WithClientIp::new(client);
|
||||
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
|
||||
.await?
|
||||
.context("handshake failed")?;
|
||||
let mut stream = match handshake(client, tls.as_ref()).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
|
||||
@@ -35,12 +35,10 @@ async fn proxy_mitm(
|
||||
tokio::spawn(async move {
|
||||
// begin handshake with end_server
|
||||
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
|
||||
// process handshake with end_client
|
||||
let (end_client, startup) =
|
||||
handshake(client1, Some(&server_config1), &CancelMap::default())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
|
||||
@@ -10,7 +10,7 @@ mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -48,6 +48,18 @@ impl UserFacingError for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
|
||||
@@ -109,10 +109,9 @@ pub async fn task_main(
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
|
||||
let (io, tls) = stream.get_ref();
|
||||
let (io, _) = stream.get_ref();
|
||||
let client_addr = io.client_addr();
|
||||
let remote_addr = io.inner.remote_addr();
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
@@ -125,7 +124,6 @@ pub async fn task_main(
|
||||
};
|
||||
Ok(MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Body>| {
|
||||
let sni_name = sni_name.clone();
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
@@ -141,7 +139,6 @@ pub async fn task_main(
|
||||
ws_connections,
|
||||
cancel_map,
|
||||
session_id,
|
||||
sni_name,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
@@ -210,7 +207,6 @@ async fn request_handler(
|
||||
ws_connections: TaskTracker,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
sni_hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
@@ -230,11 +226,11 @@ async fn request_handler(
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
|
||||
let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
|
||||
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
&mut ctx,
|
||||
ctx,
|
||||
websocket,
|
||||
cancel_map,
|
||||
host,
|
||||
@@ -251,9 +247,9 @@ async fn request_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
||||
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
||||
|
||||
sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await
|
||||
sql_over_http::handle(config, ctx, request, backend).await
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use tracing::info;
|
||||
|
||||
@@ -8,7 +7,10 @@ use crate::{
|
||||
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
|
||||
compute,
|
||||
config::ProxyConfig,
|
||||
console::CachedNodeInfo,
|
||||
console::{
|
||||
errors::{GetAuthInfoError, WakeComputeError},
|
||||
CachedNodeInfo,
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
proxy::connect_compute::ConnectMechanism,
|
||||
};
|
||||
@@ -66,7 +68,7 @@ impl PoolingBackend {
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentialKeys,
|
||||
force_new: bool,
|
||||
) -> anyhow::Result<Client<tokio_postgres::Client>> {
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
let maybe_client = if !force_new {
|
||||
info!("pool: looking for an existing connection");
|
||||
self.pool.get(ctx, &conn_info).await?
|
||||
@@ -90,7 +92,7 @@ impl PoolingBackend {
|
||||
let mut node_info = backend
|
||||
.wake_compute(ctx)
|
||||
.await?
|
||||
.context("missing cache entry from wake_compute")?;
|
||||
.ok_or(HttpConnError::NoComputeInfo)?;
|
||||
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
@@ -114,6 +116,23 @@ impl PoolingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to compute")]
|
||||
ConnectionError(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
#[error("user not authenticated")]
|
||||
AuthError(#[from] AuthError),
|
||||
#[error("wake_compute returned error")]
|
||||
WakeCompute(#[from] WakeComputeError),
|
||||
#[error("wake_compute returned nothing")]
|
||||
NoComputeInfo,
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
@@ -124,7 +143,7 @@ struct TokioMechanism {
|
||||
impl ConnectMechanism for TokioMechanism {
|
||||
type Connection = Client<tokio_postgres::Client>;
|
||||
type ConnectError = tokio_postgres::Error;
|
||||
type Error = anyhow::Error;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
|
||||
@@ -28,6 +28,8 @@ use crate::{
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
|
||||
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -358,7 +360,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
self: &Arc<Self>,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> anyhow::Result<Option<Client<C>>> {
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
|
||||
|
||||
@@ -60,6 +60,20 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum JsonConversionError {
|
||||
#[error("internal error compute returned invalid data: {0}")]
|
||||
AsTextError(tokio_postgres::Error),
|
||||
#[error("parse int error: {0}")]
|
||||
ParseIntError(#[from] std::num::ParseIntError),
|
||||
#[error("parse float error: {0}")]
|
||||
ParseFloatError(#[from] std::num::ParseFloatError),
|
||||
#[error("parse json error: {0}")]
|
||||
ParseJsonError(#[from] serde_json::Error),
|
||||
#[error("unbalanced array")]
|
||||
UnbalancedArray,
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres row with text-encoded values to JSON object
|
||||
//
|
||||
@@ -68,7 +82,7 @@ pub fn pg_text_row_to_json(
|
||||
columns: &[Type],
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> Result<Value, anyhow::Error> {
|
||||
) -> Result<Value, JsonConversionError> {
|
||||
let iter = row
|
||||
.columns()
|
||||
.iter()
|
||||
@@ -76,7 +90,7 @@ pub fn pg_text_row_to_json(
|
||||
.enumerate()
|
||||
.map(|(i, (column, typ))| {
|
||||
let name = column.name();
|
||||
let pg_value = row.as_text(i)?;
|
||||
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
|
||||
let json_value = if raw_output {
|
||||
match pg_value {
|
||||
Some(v) => Value::String(v.to_string()),
|
||||
@@ -92,10 +106,10 @@ pub fn pg_text_row_to_json(
|
||||
// drop keys and aggregate into array
|
||||
let arr = iter
|
||||
.map(|r| r.map(|(_key, val)| val))
|
||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
@@ -103,7 +117,7 @@ pub fn pg_text_row_to_json(
|
||||
//
|
||||
// Convert postgres text-encoded value to JSON value
|
||||
//
|
||||
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
|
||||
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
|
||||
if let Some(val) = pg_value {
|
||||
if let Kind::Array(elem_type) = pg_type.kind() {
|
||||
return pg_array_parse(val, elem_type);
|
||||
@@ -142,7 +156,7 @@ fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyh
|
||||
// values. Unlike postgres we don't check that all nested arrays have the same
|
||||
// dimensions, we just return them as is.
|
||||
//
|
||||
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
|
||||
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
|
||||
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
|
||||
}
|
||||
|
||||
@@ -150,7 +164,7 @@ fn _pg_array_parse(
|
||||
pg_array: &str,
|
||||
elem_type: &Type,
|
||||
nested: bool,
|
||||
) -> Result<(Value, usize), anyhow::Error> {
|
||||
) -> Result<(Value, usize), JsonConversionError> {
|
||||
let mut pg_array_chr = pg_array.char_indices();
|
||||
let mut level = 0;
|
||||
let mut quote = false;
|
||||
@@ -170,7 +184,7 @@ fn _pg_array_parse(
|
||||
entry: &mut String,
|
||||
entries: &mut Vec<Value>,
|
||||
elem_type: &Type,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), JsonConversionError> {
|
||||
if !entry.is_empty() {
|
||||
// While in usual postgres response we get nulls as None and everything else
|
||||
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
|
||||
@@ -234,7 +248,7 @@ fn _pg_array_parse(
|
||||
}
|
||||
|
||||
if level != 0 {
|
||||
return Err(anyhow::anyhow!("unbalanced array"));
|
||||
return Err(JsonConversionError::UnbalancedArray);
|
||||
}
|
||||
|
||||
Ok((Value::Array(entries), 0))
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
@@ -29,9 +28,11 @@ use utils::http::json::json_response;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::HTTP_CONTENT_LENGTH;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
@@ -41,7 +42,6 @@ use super::backend::PoolingBackend;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::SERVERLESS_DRIVER_SNI;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -86,67 +86,70 @@ where
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static str),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
sni_hostname: Option<String>,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, anyhow::Error> {
|
||||
) -> Result<ConnInfo, ConnInfoError> {
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
.ok_or(anyhow::anyhow!("missing connection string"))?
|
||||
.to_str()?;
|
||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(anyhow::anyhow!(
|
||||
"connection string must start with postgres: or postgresql:"
|
||||
));
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(anyhow::anyhow!("missing database name"))?;
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname = url_path
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
|
||||
|
||||
let username = RoleName::from(connection_url.username());
|
||||
if username.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(anyhow::anyhow!("no password"))?;
|
||||
|
||||
// TLS certificate selector now based on SNI hostname, so if we are running here
|
||||
// we are sure that SNI hostname is set to one of the configured domain names.
|
||||
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
|
||||
.ok_or(ConnInfoError::MissingPassword)?;
|
||||
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
.ok_or(anyhow::anyhow!("no host"))?;
|
||||
.ok_or(ConnInfoError::MissingHostname)?;
|
||||
|
||||
let host_header = headers
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next());
|
||||
|
||||
// sni_hostname has to be either the same as hostname or the one used in serverless driver.
|
||||
if !check_matches(&sni_hostname, hostname)? {
|
||||
return Err(anyhow::anyhow!("mismatched SNI hostname and hostname"));
|
||||
} else if let Some(h) = host_header {
|
||||
if h != sni_hostname {
|
||||
return Err(anyhow::anyhow!("mismatched host header and hostname"));
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
|
||||
let endpoint =
|
||||
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
@@ -173,36 +176,27 @@ fn get_conn_info(
|
||||
})
|
||||
}
|
||||
|
||||
fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Error> {
|
||||
if sni_hostname == hostname {
|
||||
return Ok(true);
|
||||
}
|
||||
let (sni_hostname_first, sni_hostname_rest) = sni_hostname
|
||||
.split_once('.')
|
||||
.ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?;
|
||||
let (_, hostname_rest) = hostname
|
||||
.split_once('.')
|
||||
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
|
||||
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
|
||||
}
|
||||
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
mut ctx: RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.http_config.request_timeout,
|
||||
handle_inner(config, ctx, request, sni_hostname, backend),
|
||||
handle_inner(config, &mut ctx, request, backend),
|
||||
)
|
||||
.await;
|
||||
let mut response = match result {
|
||||
Ok(r) => match r {
|
||||
Ok(r) => r,
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
}
|
||||
Err(e) => {
|
||||
// TODO: ctx.set_error_kind(e.get_error_type());
|
||||
|
||||
let mut message = format!("{:?}", e);
|
||||
let db_error = e
|
||||
.downcast_ref::<tokio_postgres::Error>()
|
||||
@@ -278,7 +272,9 @@ pub async fn handle(
|
||||
)?
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
|
||||
let message = format!(
|
||||
"HTTP-Connection timed out, execution time exeeded {} seconds",
|
||||
config.http_config.request_timeout.as_secs()
|
||||
@@ -290,6 +286,7 @@ pub async fn handle(
|
||||
)?
|
||||
}
|
||||
};
|
||||
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
@@ -302,7 +299,6 @@ async fn handle_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
@@ -318,12 +314,7 @@ async fn handle_inner(
|
||||
//
|
||||
let headers = request.headers();
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(
|
||||
ctx,
|
||||
headers,
|
||||
sni_hostname,
|
||||
config.tls_config.as_ref().unwrap(),
|
||||
)?;
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
|
||||
info!(
|
||||
user = conn_info.user_info.user.as_str(),
|
||||
project = conn_info.user_info.endpoint.as_str(),
|
||||
@@ -487,8 +478,6 @@ async fn handle_inner(
|
||||
}
|
||||
};
|
||||
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
let metrics = client.metrics();
|
||||
|
||||
// how could this possibly fail
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{
|
||||
cancellation::CancelMap,
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
error::io_error,
|
||||
error::{io_error, ReportableError},
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
};
|
||||
@@ -131,23 +131,41 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
|
||||
pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
mut ctx: RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
hostname: Option<String>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_client(
|
||||
let res = handle_client(
|
||||
config,
|
||||
ctx,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
// todo: log and push to ctx the error kind
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
ctx.log();
|
||||
Err(e.into())
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log();
|
||||
p.proxy_pass().await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use bytes::BytesMut;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
@@ -73,6 +72,30 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportedError {
|
||||
source: anyhow::Error,
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.source.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ReportedError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
self.source.source()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReportedError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
self.error_kind
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
@@ -98,24 +121,52 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
||||
tracing::info!("forwarding error to user: {error}");
|
||||
self.write_message(&BeMessage::ErrorResponse(error, None))
|
||||
.await?;
|
||||
bail!(error)
|
||||
pub async fn throw_error_str<T>(
|
||||
&mut self,
|
||||
msg: &'static str,
|
||||
error_kind: ErrorKind,
|
||||
) -> Result<T, ReportedError> {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// already error case, ignore client IO error
|
||||
let _: Result<_, std::io::Error> = self
|
||||
.write_message(&BeMessage::ErrorResponse(msg, None))
|
||||
.await;
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(msg),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
|
||||
pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
tracing::info!("forwarding error to user: {msg}");
|
||||
self.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.await?;
|
||||
bail!(error)
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%error,
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// already error case, ignore client IO error
|
||||
let _: Result<_, std::io::Error> = self
|
||||
.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.await;
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(error),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user