state machines for better error management, small futures, and clearer flows

Conrad Ludgate
2024-01-07 09:03:41 +00:00
parent 1905f0bced
commit 7fed0ba44d
25 changed files with 1324 additions and 710 deletions
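The commit replaces the monolithic `handle_client` / `connect_to_db` futures with a chain of small owned stages: `NeedsHandshake` -> `NeedsAuthentication` -> (`NeedsPasswordHack` ->) `NeedsAuthSecret` -> `NeedsWakeCompute` -> `NeedsComputeConnection` -> `ProxyPass` -> `Finished`, with link auth branching through `NeedsLinkAuthentication` -> `NeedsLinkAuthenticationResponse`. The new `proxy/src/state_machine.rs` module that ties these together is added by this commit but is not shown in the hunks below; what follows is a minimal sketch of its assumed shape, inferred from the call sites, not the actual code.

// Sketch only: all names and signatures here are inferred from usage elsewhere
// in this commit; the real state_machine module may differ.
use tokio::task::JoinHandle;

/// One protocol step. Each stage runs as its own small, span-instrumented
/// future ("small futures") and returns the next stage to drive.
pub trait Stage: Sized + Send + 'static {
    fn span(&self) -> tracing::Span;
    fn run(self) -> impl std::future::Future<Output = Result<DynStage, StageError>> + Send;
}

/// Type-erased stage, so the driver loop in `handle_client` does not need to
/// know which concrete step comes next.
pub type DynStage = Box<dyn ErasedStage>;

/// What the driver loop in `handle_client` (later in this diff) matches on:
/// `while let StageResult::Run(handle) = stage.run() { ... }`.
pub enum StageResult {
    /// The stage was spawned as a task; await the handle to get the next stage.
    Run(JoinHandle<Result<DynStage, StageError>>),
    /// Terminal: nothing left to run.
    Done,
}

pub trait ErasedStage: Send {
    fn run(self: Box<Self>) -> StageResult;
}

/// Terminal stage returned after a cancellation request or a finished proxy pass.
pub struct Finished;

impl ErasedStage for Finished {
    fn run(self: Box<Self>) -> StageResult {
        StageResult::Done
    }
}

/// Returned when a stage fails. Inferred variants only.
pub enum StageError {
    /// The error was already handled or logged; just stop.
    Done,
    // presumably other variants carrying a user-facing error plus the client
    // stream to flush it to (see `e.finish().await` in handle_client below)
}

// Presumably the module also provides a bridge so that boxing a concrete stage
// yields a DynStage whose run() spawns the stage's future inside its span, as
// well as the `user_facing_error` helper and the `ResultExt` extension trait
// used throughout the hunks below.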


@@ -82,6 +82,19 @@ impl<S> Framed<S> {
write_buf: self.write_buf,
})
}
/// Return new Framed with stream type transformed by f. For dynamic dispatch.
pub fn map_stream_sync<S2, F>(self, f: F) -> Framed<S2>
where
F: FnOnce(S) -> S2,
{
let stream = f(self.stream);
Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
}
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
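A hedged usage sketch of the new helper (the alias and target type below are illustrative, not from this commit): as the doc comment says, it lets callers erase the concrete stream type without an async round-trip, e.g. to box it behind a trait object.

// Illustrative only: erase the concrete stream behind a boxed trait object.
type DynStream = Box<dyn tokio::io::AsyncRead + Send + Unpin>;

fn erase(framed: Framed<tokio::net::TcpStream>) -> Framed<DynStream> {
    framed.map_stream_sync(|stream| Box::new(stream) as DynStream)
}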


@@ -13,7 +13,7 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use crate::{console, error::UserFacingError};
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
@@ -23,15 +23,6 @@ pub type Result<T> = std::result::Result<T, AuthError>;
/// Common authentication error.
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
#[error(transparent)]
Link(#[from] backend::LinkAuthError),
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -99,13 +90,25 @@ impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
}
}
impl ReportableError for AuthError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self.0.as_ref() {
AuthErrorImpl::Sasl(s) => s.get_error_type(),
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::Io(_) => crate::error::ErrorKind::Disconnect,
AuthErrorImpl::IpAddressNotAllowed => crate::error::ErrorKind::User,
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
}
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),


@@ -2,22 +2,27 @@ mod classic;
mod hacks;
mod link;
pub use link::LinkAuthError;
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use tokio_postgres::config::AuthKeys;
use crate::auth::backend::link::NeedsLinkAuthentication;
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::cancellation::Session;
use crate::config::ProxyConfig;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::ConsoleBackend;
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::wake_compute::NeedsWakeCompute;
use crate::proxy::ClientMode;
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
use crate::scram;
use crate::stream::Stream;
use crate::state_machine::{user_facing_error, DynStage, ResultExt, Stage, StageError};
use crate::stream::{PqStream, Stream};
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
@@ -30,10 +35,11 @@ use crate::{
};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use tracing::info;
use self::hacks::NeedsPasswordHack;
/// This type serves two purposes:
///
@@ -170,66 +176,94 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
struct NeedsAuthSecret<S> {
stream: PqStream<Stream<S>>,
api: Cow<'static, ConsoleBackend>,
params: StartupMessageParams,
allow_self_signed_compute: bool,
allow_cleartext: bool,
info: ComputeUserInfo,
unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
(res.info, Some(res.keys))
}
Ok(info) => (info, None),
};
info!("fetching user's authentication info");
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
// monitoring
ctx: RequestMonitoring,
cancel_session: Session,
}
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed());
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthSecret<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("get_auth_secret")
}
let cached_secret = api.get_role_secret(ctx, &info).await?;
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
api,
params,
allow_cleartext,
allow_self_signed_compute,
info,
unauthenticated_password,
config,
mut ctx,
cancel_session,
} = self;
let secret = cached_secret.value.clone().unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
});
match authenticate_with_secret(
ctx,
secret,
info,
client,
unauthenticated_password,
allow_cleartext,
config,
)
.await
{
Ok(keys) => Ok(keys),
Err(e) => {
info!("fetching user's authentication info");
let (allowed_ips, stream) = api
.get_allowed_ips(&mut ctx, &info)
.await
.send_error_to_user(&mut ctx, stream)?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(user_facing_error(
auth::AuthError::ip_address_not_allowed(),
&mut ctx,
stream,
));
}
let (cached_secret, mut stream) = api
.get_role_secret(&mut ctx, &info)
.await
.send_error_to_user(&mut ctx, stream)?;
let secret = cached_secret.value.clone().unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
});
let (keys, stream) = authenticate_with_secret(
&mut ctx,
secret,
info,
&mut stream,
unauthenticated_password,
allow_cleartext,
config,
)
.await
.map_err(|e| {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_secret.invalidate();
}
Err(e)
}
e
})
.send_error_to_user(&mut ctx, stream)?;
Ok(Box::new(NeedsWakeCompute {
stream,
api,
params,
allow_self_signed_compute,
creds: keys,
ctx,
cancel_session,
}))
}
}
@@ -270,49 +304,6 @@ async fn authenticate_with_secret(
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
/// only if authentication was successful.
async fn auth_and_wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(e.into());
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
Ok((node, compute_credentials.info))
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<SmolStr> {
@@ -337,50 +328,96 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
Test(_) => "test",
}
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub async fn authenticate(
self,
ctx: &mut RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;
pub struct NeedsAuthentication<S> {
pub stream: PqStream<Stream<S>>,
pub creds: BackendType<'static, auth::ComputeUserInfoMaybeEndpoint>,
pub params: StartupMessageParams,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
pub mode: ClientMode,
pub config: &'static ProxyConfig,
let res = match self {
Console(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.project(),
"performing authentication using the console"
);
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Console(api, user_info))
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthentication<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("authenticate")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
creds,
params,
endpoint_rate_limiter,
mode,
config,
mut ctx,
cancel_session,
} = self;
// check rate limit
if let Some(ep) = creds.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return Err(user_facing_error(
auth::AuthError::too_many_connections(),
&mut ctx,
stream,
));
}
}
let allow_self_signed_compute = mode.allow_self_signed_compute(config);
let allow_cleartext = mode.allow_cleartext();
match creds {
BackendType::Console(api, creds) => {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
match creds.try_into() {
Err(info) => Ok(Box::new(NeedsPasswordHack {
stream,
api,
params,
allow_self_signed_compute,
info,
allow_cleartext,
config: &config.authentication_config,
ctx,
cancel_session,
})),
Ok(info) => Ok(Box::new(NeedsAuthSecret {
stream,
api,
params,
allow_self_signed_compute,
info,
unauthenticated_password: None,
allow_cleartext,
config: &config.authentication_config,
ctx,
cancel_session,
})),
}
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
info!("performing link authentication");
let node_info = link::authenticate(&url, client).await?;
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
}
BackendType::Link(link) => Ok(Box::new(NeedsLinkAuthentication {
stream,
link,
params,
allow_self_signed_compute,
ctx,
cancel_session,
})),
#[cfg(test)]
Test(_) => {
BackendType::Test(_) => {
unreachable!("this function should never be called in the test backend")
}
};
info!("user successfully authenticated");
Ok(res)
}
}
}
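The `ResultExt` helpers and `user_facing_error` used above come from the new `state_machine` module, which is not part of this excerpt. Their assumed signatures, inferred from the call sites (a sketch, not the actual definitions):

// Assumed shapes only; presumably implemented for Result<T, E> where E is a
// reportable / user-facing error type.
pub trait ResultExt<T> {
    /// Record the error in monitoring and arrange for a user-facing message to
    /// be written to the client stream, converting it into a StageError.
    fn send_error_to_user<S>(
        self,
        ctx: &mut RequestMonitoring,
        stream: PqStream<Stream<S>>,
    ) -> Result<(T, PqStream<Stream<S>>), StageError>;

    /// Record the error in monitoring with the given kind, without sending
    /// anything extra to the client.
    fn no_user_error(
        self,
        ctx: &mut RequestMonitoring,
        kind: crate::error::ErrorKind,
    ) -> Result<T, StageError>;
}

// Assumed: records the error in monitoring and wraps it, together with the
// client stream it should be reported on, into a StageError.
pub fn user_facing_error<S, E: UserFacingError>(
    err: E,
    ctx: &mut RequestMonitoring,
    stream: PqStream<Stream<S>>,
) -> StageError {
    // body elided in this sketch
    todo!()
}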


@@ -1,13 +1,21 @@
use std::borrow::Cow;
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
NeedsAuthSecret,
};
use crate::{
auth::{self, AuthFlow},
console::AuthSecret,
cancellation::Session,
config::AuthenticationConfig,
console::{provider::ConsoleBackend, AuthSecret},
context::RequestMonitoring,
metrics::LatencyTimer,
sasl,
stream::{self, Stream},
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{self, PqStream, Stream},
};
use pq_proto::StartupMessageParams;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -46,7 +54,7 @@ pub async fn authenticate_cleartext(
/// Workaround for clients which don't provide an endpoint (project) name.
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
@@ -74,3 +82,47 @@ pub async fn password_hack_no_authentication(
keys: payload.password,
})
}
pub struct NeedsPasswordHack<S> {
pub stream: PqStream<Stream<S>>,
pub api: Cow<'static, ConsoleBackend>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
pub allow_cleartext: bool,
pub info: ComputeUserInfoNoEndpoint,
pub config: &'static AuthenticationConfig,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsPasswordHack<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("password_hack")
}
async fn run(mut self) -> Result<DynStage, StageError> {
let (res, stream) = password_hack_no_authentication(
self.info,
&mut self.stream,
&mut self.ctx.latency_timer,
)
.await
.send_error_to_user(&mut self.ctx, self.stream)?;
self.ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
Ok(Box::new(NeedsAuthSecret {
stream,
info: res.info,
unauthenticated_password: Some(res.keys),
api: self.api,
params: self.params,
allow_self_signed_compute: self.allow_self_signed_compute,
allow_cleartext: self.allow_cleartext,
ctx: self.ctx,
cancel_session: self.cancel_session,
config: self.config,
}))
}
}


@@ -1,41 +1,20 @@
use std::borrow::Cow;
use crate::{
auth, compute,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
auth::BackendType,
cancellation::Session,
compute,
console::{self, mgmt::ComputeReady, provider::NodeInfo, CachedNodeInfo},
context::RequestMonitoring,
proxy::connect_compute::{NeedsComputeConnection, TcpMechanism},
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{PqStream, Stream},
waiters::Waiter,
};
use pq_proto::BeMessage as Be;
use thiserror::Error;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
#[error(transparent)]
WaiterWait(#[from] waiters::WaitError),
#[error(transparent)]
Io(#[from] std::io::Error),
}
impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
use tracing::info;
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
format!(
@@ -53,64 +32,146 @@ pub fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {
let psql_session_id = new_psql_session_id();
pub struct NeedsLinkAuthentication<S> {
pub stream: PqStream<Stream<S>>,
pub link: Cow<'static, crate::url::ApiUrl>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
match console::mgmt::get_waiter(&psql_session_id) {
Ok(waiter) => break (psql_session_id, waiter),
Err(_e) => continue,
}
};
let span = info_span!("link", psql_session_id = &psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id);
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
// Wait for web console response (see `mgmt`).
info!(parent: &span, "waiting for console's reply...");
let db_info = waiter.await.map_err(LinkAuthError::from)?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
})
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsLinkAuthentication<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("link", psql_session_id = tracing::field::Empty)
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
mut stream,
link,
params,
allow_self_signed_compute,
mut ctx,
cancel_session,
} = self;
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {
let psql_session_id = new_psql_session_id();
match console::mgmt::get_waiter(&psql_session_id) {
Ok(waiter) => break (psql_session_id, waiter),
Err(_e) => continue,
}
};
tracing::Span::current().record("psql_session_id", &psql_session_id);
let greeting = hello_message(&link, &psql_session_id);
info!("sending the auth URL to the user");
stream
.write_message_noflush(&Be::AuthenticationOk)
.and_then(|s| s.write_message_noflush(&Be::CLIENT_ENCODING))
.and_then(|s| s.write_message_noflush(&Be::NoticeResponse(&greeting)))
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?
.flush()
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
Ok(Box::new(NeedsLinkAuthenticationResponse {
stream,
link,
params,
allow_self_signed_compute,
waiter,
psql_session_id,
ctx,
cancel_session,
}))
}
}
struct NeedsLinkAuthenticationResponse<S> {
stream: PqStream<Stream<S>>,
link: Cow<'static, crate::url::ApiUrl>,
params: StartupMessageParams,
allow_self_signed_compute: bool,
waiter: Waiter<'static, ComputeReady>,
psql_session_id: String,
// monitoring
ctx: RequestMonitoring,
cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage
for NeedsLinkAuthenticationResponse<S>
{
fn span(&self) -> tracing::Span {
tracing::info_span!("link_wait", psql_session_id = self.psql_session_id)
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
mut stream,
link,
params,
allow_self_signed_compute,
waiter,
psql_session_id: _,
mut ctx,
cancel_session,
} = self;
// Wait for web console response (see `mgmt`).
info!("waiting for console's reply...");
let db_info = waiter
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
stream
.write_message_noflush(&Be::NoticeResponse("Connecting to database."))
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
let node_info = CachedNodeInfo::new_uncached(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute,
});
let user_info = BackendType::Link(link);
Ok(Box::new(NeedsComputeConnection {
stream,
user_info,
mechanism: TcpMechanism { params },
node_info,
ctx,
cancel_session,
}))
}
}


@@ -1,8 +1,11 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -33,7 +36,24 @@ pub enum ComputeUserInfoParseError {
MalformedProjectName(SmolStr),
}
impl UserFacingError for ComputeUserInfoParseError {}
impl ReportableError for ComputeUserInfoParseError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ComputeUserInfoParseError::MissingKey(_) => crate::error::ErrorKind::User,
ComputeUserInfoParseError::InconsistentProjectNames { .. } => {
crate::error::ErrorKind::User
}
ComputeUserInfoParseError::UnknownCommonName { .. } => crate::error::ErrorKind::User,
ComputeUserInfoParseError::MalformedProjectName(_) => crate::error::ErrorKind::User,
}
}
}
impl UserFacingError for ComputeUserInfoParseError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.

View File

@@ -164,6 +164,13 @@ async fn task_main(
let tls_config = Arc::clone(&tls_config);
let dest_suffix = Arc::clone(&dest_suffix);
let root_span = tracing::info_span!(
"handle_client",
?session_id,
endpoint = tracing::field::Empty
);
let root_span2 = root_span.clone();
connections.spawn(
async move {
socket
@@ -171,8 +178,13 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let mut ctx =
RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
let mut ctx = RequestMonitoring::new(
session_id,
peer_addr.ip(),
"sni_router",
"sni",
root_span2,
);
handle_client(
&mut ctx,
dest_suffix,
@@ -186,7 +198,7 @@ async fn task_main(
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(tracing::info_span!("handle_client", ?session_id)),
.instrument(root_span),
);
}
@@ -271,6 +283,7 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
ctx.log();
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
proxy::proxy::pass::proxy_pass(tls_stream, client, metrics_aux).await
}


@@ -1,7 +1,7 @@
use anyhow::{bail, Context};
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
@@ -25,39 +25,33 @@ impl CancelMap {
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
pub fn get_session(self: Arc<Self>) -> Session {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
let key = loop {
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
bail!("query cancellation key already exists: {key}")
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
continue;
}
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.remove(&key);
info!("dropped query cancellation key {key}");
}
break key;
};
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
Session {
key,
cancel_map: self,
}
}
#[cfg(test)]
@@ -98,23 +92,17 @@ impl CancelClosure {
}
/// Helper for registering query cancellation tokens.
pub struct Session<'a> {
pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: &'a CancelMap,
cancel_map: Arc<CancelMap>,
}
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
}
impl Session<'_> {
impl Session {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
@@ -122,37 +110,26 @@ impl Session<'_> {
}
}
impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
#[cfg(test)]
mod tests {
use super::*;
use once_cell::sync::Lazy;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;
// Drop the session's entry by cancelling the task.
task.abort();
let error = task.await.expect_err("task should have failed");
if !error.is_cancelled() {
anyhow::bail!(error);
}
let cancel_map: Arc<CancelMap> = Default::default();
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(CANCEL_MAP.is_empty());
assert!(cancel_map.is_empty());
Ok(())
}
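Design note on the hunk above: the closure-scoped `with_session` API is gone; `Session` is now an owned value whose `Drop` impl removes the cancellation key, so cleanup follows the value through whichever stage the connection ends in (every stage struct in this commit carries a `cancel_session` field, and `ProxyPass::run` drops it last). A minimal sketch using only types from this diff:

// The key is registered by get_session() and removed when the Session value is
// dropped, regardless of which stage the connection died in.
let cancel_map: Arc<CancelMap> = Default::default();
let session = cancel_map.clone().get_session();
// ... pass `session` from stage to stage ...
drop(session); // the Drop impl above removes the key from the map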


@@ -1,6 +1,10 @@
use crate::{
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -32,6 +36,17 @@ pub enum ConnectionError {
WakeComputeError(#[from] WakeComputeError),
}
impl ReportableError for ConnectionError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;


@@ -21,7 +21,7 @@ use tracing::info;
pub mod errors {
use crate::{
error::{io_error, UserFacingError},
error::{io_error, ReportableError, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -56,6 +56,15 @@ pub mod errors {
}
}
impl ReportableError for ApiError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
@@ -140,6 +149,15 @@ pub mod errors {
}
}
impl ReportableError for GetAuthInfoError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
@@ -181,6 +199,16 @@ pub mod errors {
}
}
impl ReportableError for WakeComputeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_type(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;


@@ -38,6 +38,7 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
pub latency_timer: LatencyTimer,
root_span: tracing::Span,
}
impl RequestMonitoring {
@@ -46,6 +47,7 @@ impl RequestMonitoring {
peer_addr: IpAddr,
protocol: &'static str,
region: &'static str,
root_span: tracing::Span,
) -> Self {
Self {
peer_addr,
@@ -64,12 +66,19 @@ impl RequestMonitoring {
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
root_span,
}
}
#[cfg(test)]
pub fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
RequestMonitoring::new(
Uuid::now_v7(),
[127, 0, 0, 1].into(),
"test",
"test",
tracing::Span::none(),
)
}
pub fn console_application_name(&self) -> String {
@@ -87,7 +96,10 @@ impl RequestMonitoring {
}
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
if let (None, Some(ep)) = (self.endpoint_id.as_ref(), endpoint_id) {
self.root_span.record("ep", &*ep);
self.endpoint_id = Some(ep)
}
}
pub fn set_application(&mut self, app: Option<SmolStr>) {
@@ -102,6 +114,10 @@ impl RequestMonitoring {
self.success = true;
}
pub fn error(&mut self, err: ErrorKind) {
self.error_kind = Some(err);
}
pub fn log(&mut self) {
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(self.clone());


@@ -17,19 +17,16 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: fmt::Display {
pub trait UserFacingError: ReportableError {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
/// recommended to override the default impl in case error type
/// contains anything sensitive: various IDs, IP addresses etc.
#[inline(always)]
fn to_string_client(&self) -> String {
self.to_string()
}
fn to_string_client(&self) -> String;
}
#[derive(Clone)]
#[derive(Clone, Copy)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
@@ -62,3 +59,7 @@ impl ErrorKind {
}
}
}
pub trait ReportableError: fmt::Display + Send + 'static {
fn get_error_type(&self) -> ErrorKind;
}
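The hunk above only shows the head of `ErrorKind`; the variants referenced elsewhere in this commit are at least the following (reconstructed from the call sites, not necessarily the complete enum):

#[derive(Clone, Copy)]
pub enum ErrorKind {
    /// Wrong password, unknown endpoint, protocol violation, etc...
    User,
    /// Client disconnected or client-side IO failed.
    Disconnect,
    /// Too many connections, or wake_compute timed out (treated as throttling).
    RateLimit,
    /// Internal proxy/service failure.
    Service,
    /// Failed to connect to or talk to the compute node.
    Compute,
    /// Console / control-plane API errors.
    ControlPlane,
}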


@@ -26,6 +26,7 @@ pub mod redis;
pub mod sasl;
pub mod scram;
pub mod serverless;
pub mod state_machine;
pub mod stream;
pub mod url;
pub mod usage_metrics;


@@ -2,38 +2,32 @@
mod tests;
pub mod connect_compute;
pub mod handshake;
pub mod pass;
pub mod retry;
pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancelMap},
compute,
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
console::messages::MetricsAuxInfo,
cancellation::CancelMap,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
metrics::{
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
},
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::handshake::NeedsHandshake,
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
state_machine::{DynStage, StageResult},
stream::Stream,
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use anyhow::Context;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use pq_proto::StartupMessageParams;
use regex::Regex;
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use utils::measured_stream::MeasuredStream;
use self::connect_compute::{connect_to_compute, TcpMechanism};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -79,45 +73,64 @@ pub async fn task_main(
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let root_span = info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty,
ep = tracing::field::Empty,
);
let root_span2 = root_span.clone();
connections.spawn(
async move {
info!("accepted postgres client connection");
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr.ip();
if let Some(addr) = socket.wait_for_addr().await? {
peer_addr = addr.ip();
tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
} else if config.require_client_ip {
bail!("missing required client IP");
}
match socket.wait_for_addr().await {
Err(e) => {
error!("IO error: {e:#}");
return;
}
Ok(Some(addr)) => {
peer_addr = addr.ip();
root_span2.record("peer_addr", &tracing::field::display(addr));
}
Ok(None) if config.require_client_ip => {
error!("missing required client IP");
return;
}
Ok(None) => {}
};
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
"tcp",
&config.region,
root_span2,
);
socket
if let Err(e) = socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
.context("failed to set socket option")
{
error!("could not set nodelay: {e:#}");
return;
}
handle_client(
config,
&mut ctx,
&cancel_map,
ctx,
cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await
.await;
}
.instrument(info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty
))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
.instrument(root_span),
);
}
@@ -137,14 +150,14 @@ pub enum ClientMode {
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
fn allow_cleartext(&self) -> bool {
pub fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
@@ -167,14 +180,14 @@ impl ClientMode {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + 'static + Send>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancel_map: &CancelMap,
ctx: RequestMonitoring,
cancel_map: Arc<CancelMap>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
) {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -188,308 +201,23 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.with_label_values(&[proto])
.guard();
let tls = config.tls_config.as_ref();
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);
// Extract credentials which we're going to use for auth.
let user_info = {
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
let result = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
})
.transpose();
match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
}
};
ctx.set_endpoint_id(user_info.get_endpoint());
let client = Client::new(
let mut stage = Box::new(NeedsHandshake {
stream,
user_info,
&params,
mode.allow_self_signed_compute(config),
config,
cancel_map,
mode,
endpoint_rate_limiter,
);
cancel_map
.with_session(|session| {
client.connect_to_db(ctx, session, mode, &config.authentication_config)
})
.await
}
ctx,
}) as DynStage;
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: cancellation::Session<'_>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficient (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
}
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
/// Rate limiter for endpoints
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
params: &'a StartupMessageParams,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
// Instrumentation logs endpoint name everywhere. Doesn't work for link
// auth; strictly speaking we don't know endpoint name in its case.
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
async fn connect_to_db(
self,
ctx: &mut RequestMonitoring,
session: cancellation::Session<'_>,
mode: ClientMode,
config: &'static AuthenticationConfig,
) -> anyhow::Result<()> {
let Self {
mut stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}
let user = user_info.get_user().to_owned();
let auth_result = match user_info
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
.await
{
Ok(auth_result) => auth_result,
while let StageResult::Run(handle) = stage.run() {
stage = match handle.await.expect("tasks should not panic") {
Ok(s) => s,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
e.finish().await;
break;
}
};
let (mut node_info, user_info) = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
}
}
}


@@ -1,20 +1,120 @@
use crate::{
auth,
cancellation::{self, Session},
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, Api},
context::RequestMonitoring,
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
proxy::retry::{retry_after, ShouldRetry},
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{PqStream, Stream},
};
use async_trait::async_trait;
use hyper::StatusCode;
use pq_proto::StartupMessageParams;
use std::ops::ControlFlow;
use tokio::time;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
};
use tracing::{error, info, warn};
use pq_proto::BeMessage as Be;
use super::{
pass::ProxyPass,
retry::{retry_after, ShouldRetry},
};
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
pub struct NeedsComputeConnection<S> {
pub stream: PqStream<Stream<S>>,
pub user_info: auth::BackendType<'static, auth::backend::ComputeUserInfo>,
pub mechanism: TcpMechanism,
pub node_info: console::CachedNodeInfo,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S> Stage for NeedsComputeConnection<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn span(&self) -> tracing::Span {
tracing::info_span!("connect_to_compute")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
user_info,
mechanism,
node_info,
mut ctx,
cancel_session,
} = self;
let aux = node_info.aux.clone();
let (mut node, mut stream) =
connect_to_compute(&mut ctx, &mechanism, node_info, &user_info)
.await
.send_error_to_user(&mut ctx, stream)?;
prepare_client_connection(&node, &cancel_session, &mut stream)
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream
.write_all(&read_buf)
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
Ok(Box::new(ProxyPass {
client: stream,
compute: node.stream,
aux,
cancel_session,
}))
}
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> std::io::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficient (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
}
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -63,13 +163,13 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
pub struct TcpMechanism<'a> {
pub struct TcpMechanism {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
pub params: StartupMessageParams,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
impl ConnectMechanism for TcpMechanism {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -84,7 +184,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params);
config.set_startup_params(&self.params);
}
}


@@ -0,0 +1,203 @@
use crate::{
auth::{self, backend::NeedsAuthentication},
cancellation::CancelMap,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
rate_limiter::EndpointRateLimiter,
state_machine::{DynStage, Finished, ResultExt, Stage, StageError},
stream::{PqStream, Stream, StreamUpgradeError},
};
use anyhow::{anyhow, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::{io, sync::Arc};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info};
use super::ClientMode;
pub struct NeedsHandshake<S> {
pub stream: S,
pub config: &'static ProxyConfig,
pub cancel_map: Arc<CancelMap>,
pub mode: ClientMode,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// monitoring
pub ctx: RequestMonitoring,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsHandshake<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("handshake")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
config,
cancel_map,
mode,
endpoint_rate_limiter,
mut ctx,
} = self;
let tls = config.tls_config.as_ref();
let pause_timer = ctx.latency_timer.pause();
let handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map).await;
drop(pause_timer);
let (stream, params) = match handshake {
Err(err) => {
// TODO: proper handling
error!("could not complete handshake: {err:#}");
return Err(StageError::Done);
}
// cancellation
Ok(None) => return Ok(Box::new(Finished)),
Ok(Some(s)) => s,
};
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
let (creds, stream) = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &params, hostname, common_names)
})
.transpose()
.send_error_to_user(&mut ctx, stream)?;
ctx.set_endpoint_id(creds.get_endpoint());
Ok(Box::new(NeedsAuthentication {
stream,
creds,
params,
endpoint_rate_limiter,
mode,
config,
ctx,
cancel_session: cancel_map.get_session(),
}))
}
}
#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("client disconnected: {0}")]
ClientIO(#[from] io::Error),
#[error("protocol violation: {0}")]
ProtocolError(#[from] anyhow::Error),
#[error("could not initiate tls connection: {0}")]
TLSError(#[from] StreamUpgradeError),
#[error("could not cancel connection: {0}")]
Cancel(anyhow::Error),
}
impl ReportableError for HandshakeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::ClientIO(_) => crate::error::ErrorKind::Disconnect,
HandshakeError::ProtocolError(_) => crate::error::ErrorKind::User,
HandshakeError::TLSError(_) => crate::error::ErrorKind::User,
HandshakeError::Cancel(_) => crate::error::ErrorKind::Compute,
}
}
}
type SuccessfulHandshake<S> = (PqStream<Stream<S>>, StartupMessageParams);
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> Result<Option<SuccessfulHandshake<S>>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
return Err(HandshakeError::ProtocolError(anyhow!(
"data is sent before server replied with EncryptionResponse"
)));
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map
.cancel_session(cancel_key_data)
.await
.map_err(HandshakeError::Cancel)?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}
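For context, the accept loop that drives `handshake` is not part of this hunk; a hypothetical call site (names like `accept_one` and `client_socket` are illustrative) might branch on the result like this, assuming the same imports as this module:
// Hypothetical caller: `client_socket` stands in for an accepted TCP stream.
async fn accept_one<S>(
    client_socket: S,
    tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> Result<(), HandshakeError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    match handshake(client_socket, tls, cancel_map).await? {
        // Normal session: pass the framed stream and startup params on to auth.
        Some((stream, params)) => {
            info!("handshake complete, startup params: {params:?}");
            let _ = (stream, params);
        }
        // A cancel request was already serviced inside `handshake`.
        None => info!("cancellation request handled"),
    }
    Ok(())
}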

proxy/src/proxy/pass.rs (new file, 82 lines)
View File

@@ -0,0 +1,82 @@
use crate::{
cancellation::Session,
console::messages::MetricsAuxInfo,
metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
state_machine::{DynStage, Finished, Stage, StageError},
stream::Stream,
usage_metrics::{Ids, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info};
use utils::measured_stream::MeasuredStream;
pub struct ProxyPass<Client, Compute> {
pub client: Stream<Client>,
pub compute: Compute,
// monitoring
pub aux: MetricsAuxInfo,
pub cancel_session: Session,
}
impl<Client, Compute> Stage for ProxyPass<Client, Compute>
where
Client: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Compute: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn span(&self) -> tracing::Span {
tracing::info_span!("proxy_pass")
}
async fn run(self) -> Result<DynStage, StageError> {
if let Err(e) = proxy_pass(self.client, self.compute, self.aux).await {
error!("{e:#}")
}
drop(self.cancel_session);
Ok(Box::new(Finished))
}
}
/// Forward bytes in both directions (client <-> compute).
pub async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
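Judging from the usage above, `MeasuredStream::new` takes the inner stream, a read callback, and a write callback, each receiving the number of bytes transferred. A self-contained sketch of the same counting pattern over an in-memory duplex pipe (the callback order is inferred from `proxy_pass`, not confirmed elsewhere):
// Sketch only: counts bytes written through a MeasuredStream wrapper.
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use utils::measured_stream::MeasuredStream;
async fn count_written_bytes() -> std::io::Result<()> {
    static WRITTEN: AtomicU64 = AtomicU64::new(0);
    let (a, mut b) = duplex(64);
    let mut a = MeasuredStream::new(
        a,
        |_read| {},
        |written| {
            WRITTEN.fetch_add(written as u64, Ordering::Relaxed);
        },
    );
    a.write_all(b"hello").await?;
    let mut buf = [0u8; 5];
    b.read_exact(&mut buf).await?;
    assert_eq!(WRITTEN.load(Ordering::Relaxed), 5);
    Ok(())
}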

View File

@@ -3,14 +3,19 @@
mod mitm;
use super::connect_compute::ConnectMechanism;
use super::handshake::handshake;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::config::CertResolver;
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::proxy::connect_compute::connect_to_compute;
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use crate::stream::PqStream;
use crate::{auth, compute, http, sasl, scram};
use anyhow::bail;
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use rstest::rstest;
use smol_str::SmolStr;
use tokio_postgres::config::SslMode;
@@ -202,7 +207,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(client_err.to_string().contains(&server_err.to_string()));
assert!(server_err.to_string().contains(ERR_INSECURE_CONNECTION));
Ok(())
}

View File

@@ -10,7 +10,7 @@ use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};

View File

@@ -0,0 +1,89 @@
use std::{borrow::Cow, ops::ControlFlow};
use pq_proto::StartupMessageParams;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, warn};
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeCredentials},
BackendType,
},
cancellation::Session,
console::{provider::ConsoleBackend, Api},
context::RequestMonitoring,
state_machine::{user_facing_error, DynStage, Stage, StageError},
stream::{PqStream, Stream},
};
use super::{
connect_compute::{handle_try_wake, NeedsComputeConnection, TcpMechanism},
retry::retry_after,
};
pub struct NeedsWakeCompute<S> {
pub stream: PqStream<Stream<S>>,
pub api: Cow<'static, ConsoleBackend>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
pub creds: ComputeCredentials<ComputeCredentialKeys>,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsWakeCompute<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("wake_compute")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
api,
params,
allow_self_signed_compute,
creds,
mut ctx,
cancel_session,
} = self;
let mut num_retries = 0;
let mut node_info = loop {
let wake_res = api.wake_compute(&mut ctx, &creds.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(user_facing_error(e, &mut ctx, stream));
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
ctx.set_project(node_info.aux.clone());
node_info.allow_self_signed_compute = allow_self_signed_compute;
match creds.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
};
Ok(Box::new(NeedsComputeConnection {
stream,
user_info: BackendType::Console(api, creds.info),
mechanism: TcpMechanism { params },
node_info,
ctx,
cancel_session,
}))
}
}
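The loop above retries only when `handle_try_wake` classifies the failure as retriable. As a standalone illustration of the same `ControlFlow` retry shape (the real backoff schedule lives in `retry_after` and may differ from the fixed doubling used here; `attempt` is a made-up closure standing in for `api.wake_compute` plus `handle_try_wake`):
// Generic sketch of the wake/retry pattern.
use std::ops::ControlFlow;
use std::time::Duration;
async fn retry_until_ready<T>(
    mut attempt: impl FnMut(u32) -> Result<ControlFlow<T, String>, String>,
) -> Result<T, String> {
    let mut num_retries = 0;
    loop {
        match attempt(num_retries)? {
            // Success: we have a usable result (a woken compute node, say).
            ControlFlow::Break(ready) => return Ok(ready),
            // Retriable failure: log it and back off before trying again.
            ControlFlow::Continue(err) => {
                tracing::warn!(error = %err, num_retries, retriable = true, "retrying");
            }
        }
        // Stand-in for `retry_after(num_retries)`: fixed doubling, capped.
        let wait = Duration::from_millis(100 * (1u64 << num_retries.min(6)));
        num_retries += 1;
        tokio::time::sleep(wait).await;
    }
}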

View File

@@ -10,7 +10,7 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::UserFacingError;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
@@ -37,6 +37,25 @@ pub enum Error {
Io(#[from] io::Error),
}
impl ReportableError for Error {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Io(io) => match io.kind() {
// tokio-postgres uses these error kinds for various SCRAM failures
io::ErrorKind::InvalidInput
| io::ErrorKind::UnexpectedEof
| io::ErrorKind::Other => crate::error::ErrorKind::User,
// all other IO errors are likely disconnects.
_ => crate::error::ErrorKind::Disconnect,
},
}
}
}
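For illustration, the `io::Error` mapping above could be exercised like this; the test below is hypothetical and not part of this change:
#[test]
fn classify_scram_io_errors() {
    use crate::error::{ErrorKind, ReportableError};
    use std::io;
    // tokio-postgres surfaces SCRAM failures as InvalidInput/UnexpectedEof/Other,
    // which we classify as user errors rather than disconnects...
    let scram_failure = Error::Io(io::Error::new(io::ErrorKind::Other, "bad proof"));
    assert!(matches!(scram_failure.get_error_type(), ErrorKind::User));
    // ...while a genuine socket error counts as a disconnect.
    let disconnect = Error::Io(io::Error::from(io::ErrorKind::BrokenPipe));
    assert!(matches!(disconnect.get_error_type(), ErrorKind::Disconnect));
}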
impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;

View File

@@ -124,6 +124,12 @@ pub async fn task_main(
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
let root_span = info_span!(
"serverless",
session = %session_id,
%peer_addr,
);
request_handler(
req,
config,
@@ -135,12 +141,9 @@ pub async fn task_main(
sni_name,
peer_addr.ip(),
endpoint_rate_limiter,
root_span.clone(),
)
.instrument(info_span!(
"serverless",
session = %session_id,
%peer_addr,
))
.instrument(root_span)
.await
}
},
@@ -205,6 +208,7 @@ async fn request_handler(
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
root_span: tracing::Span,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
@@ -215,27 +219,33 @@ async fn request_handler(
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
info!(session_id = ?session_id, "performing websocket upgrade");
info!("performing websocket upgrade");
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
let ctx =
RequestMonitoring::new(session_id, peer_addr, "ws", &config.region, root_span);
if let Err(e) = websocket::serve_websocket(
let websocket = match websocket.await {
Err(e) => {
error!("error in websocket connection: {e:#}");
return;
}
Ok(ws) => ws,
};
websocket::serve_websocket(
config,
&mut ctx,
ctx,
websocket,
&cancel_map,
cancel_map,
host,
endpoint_rate_limiter,
)
.await
{
error!(session_id = ?session_id, "error in websocket connection: {e:#}");
}
}
.in_current_span(),
);
@@ -243,7 +253,8 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let mut ctx =
RequestMonitoring::new(session_id, peer_addr, "http", &config.region, root_span);
sql_over_http::handle(
tls,

View File

@@ -9,7 +9,7 @@ use crate::{
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use pin_project_lite::pin_project;
use std::{
@@ -131,13 +131,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: &CancelMap,
ctx: RequestMonitoring,
websocket: WebSocketStream<Upgraded>,
cancel_map: Arc<CancelMap>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
) {
handle_client(
config,
ctx,
@@ -146,8 +145,7 @@ pub async fn serve_websocket(
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
)
.await?;
Ok(())
.await
}
#[cfg(test)]

proxy/src/state_machine.rs (new file, 149 lines)
View File

@@ -0,0 +1,149 @@
use futures::Future;
use pq_proto::{framed::Framed, BeMessage};
use tokio::{io::AsyncWrite, task::JoinHandle};
use tracing::{info, warn, Instrument};
pub trait Captures<T> {}
impl<T, U> Captures<T> for U {}
#[must_use]
pub enum StageError {
Flush(Framed<Box<dyn AsyncWrite + Unpin + Send + 'static>>),
Done,
}
impl StageError {
pub async fn finish(self) {
match self {
StageError::Flush(mut f) => {
// Ignore the result: we can't do anything about a failed flush,
// and we're already on the error path anyway.
if let Err(e) = f.flush().await {
warn!("could not send message to user: {e:?}")
}
}
StageError::Done => {}
}
info!("task finished");
}
}
pub type DynStage = Box<dyn StageSpawn>;
/// Stage represents a single stage in a state machine.
pub trait Stage: 'static + Send {
/// The span this stage should be run inside.
fn span(&self) -> tracing::Span;
/// Run the current stage, returning a new [`DynStage`], or an error
///
/// Can be implemented as `async fn run(self) -> Result<DynStage, StageError>`
fn run(self) -> impl 'static + Send + Future<Output = Result<DynStage, StageError>>;
}
pub enum StageResult {
Finished,
Run(JoinHandle<Result<DynStage, StageError>>),
}
pub trait StageSpawn: 'static + Send {
fn run(self: Box<Self>) -> StageResult;
}
/// `StageSpawn` is a helper trait for the state machine: it spawns each stage as a tokio task.
impl<S: Stage> StageSpawn for S {
fn run(self: Box<Self>) -> StageResult {
let span = self.span();
StageResult::Run(tokio::spawn(S::run(*self).instrument(span)))
}
}
pub struct Finished;
impl StageSpawn for Finished {
fn run(self: Box<Self>) -> StageResult {
StageResult::Finished
}
}
use crate::{
context::RequestMonitoring,
error::{ErrorKind, UserFacingError},
stream::PqStream,
};
pub trait ResultExt<T, E> {
fn send_error_to_user<S>(
self,
ctx: &mut RequestMonitoring,
stream: PqStream<S>,
) -> Result<(T, PqStream<S>), StageError>
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError;
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
where
E: std::fmt::Display;
}
impl<T, E> ResultExt<T, E> for Result<T, E> {
fn send_error_to_user<S>(
self,
ctx: &mut RequestMonitoring,
stream: PqStream<S>,
) -> Result<(T, PqStream<S>), StageError>
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError,
{
match self {
Ok(t) => Ok((t, stream)),
Err(e) => Err(user_facing_error(e, ctx, stream)),
}
}
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
where
E: std::fmt::Display,
{
match self {
Ok(t) => Ok(t),
Err(e) => {
tracing::error!(
kind = kind.to_str(),
user_msg = "",
"task finished with error: {e}"
);
ctx.error(kind);
ctx.log();
Err(StageError::Done)
}
}
}
}
pub fn user_facing_error<S, E>(
err: E,
ctx: &mut RequestMonitoring,
mut stream: PqStream<S>,
) -> StageError
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError,
{
let kind = err.get_error_type();
ctx.error(kind);
ctx.log();
let msg = err.to_string_client();
tracing::error!(
kind = kind.to_str(),
user_msg = msg,
"task finished with error: {err}"
);
if let Err(err) = stream.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)) {
warn!("could not process error message: {err:?}")
}
StageError::Flush(stream.framed.map_stream_sync(|f| Box::new(f) as Box<_>))
}
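The dispatcher that walks these stages is not shown in this hunk. Assuming one exists elsewhere in the proxy, it would look roughly like the sketch below (the name `drive` and its placement are assumptions): spawn the current stage, follow the returned `DynStage` chain, and flush any buffered error to the client when a stage fails.
// Hypothetical driver for the stage state machine.
pub async fn drive(mut stage: DynStage) {
    loop {
        match stage.run() {
            // A stage signalled that the session is over.
            StageResult::Finished => break,
            StageResult::Run(handle) => match handle.await {
                // The stage produced its successor; continue the chain.
                Ok(Ok(next)) => stage = next,
                // The stage failed; flush the buffered ErrorResponse (if any).
                Ok(Err(err)) => {
                    err.finish().await;
                    break;
                }
                // The spawned task panicked or was aborted.
                Err(join_err) => {
                    tracing::error!("stage task failed: {join_err}");
                    break;
                }
            },
        }
    }
}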

View File

@@ -1,5 +1,5 @@
use crate::config::TlsServerEndPoint;
use crate::error::UserFacingError;
use crate::error::ErrorKind;
use anyhow::bail;
use bytes::BytesMut;
@@ -99,24 +99,17 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Allowing string literals is safe under the assumption that they don't contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
tracing::info!("forwarding error to user: {error}");
let kind = ErrorKind::User;
tracing::error!(
kind = kind.to_str(),
full_msg = error,
user_msg = error,
"task finished with error"
);
self.write_message(&BeMessage::ErrorResponse(error, None))
.await?;
bail!(error)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
where
E: UserFacingError + Into<anyhow::Error>,
{
let msg = error.to_string_client();
tracing::info!("forwarding error to user: {msg}");
self.write_message(&BeMessage::ErrorResponse(&msg, None))
.await?;
bail!(error)
}
}
/// Wrapper for upgrading raw streams into secure streams.