mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-12 07:00:36 +00:00
Compare commits
1 Commits
split-prox
...
actorsssss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fed0ba44d |
@@ -82,6 +82,19 @@ impl<S> Framed<S> {
|
|||||||
write_buf: self.write_buf,
|
write_buf: self.write_buf,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return new Framed with stream type transformed by f. For dynamic dispatch.
|
||||||
|
pub fn map_stream_sync<S2, F>(self, f: F) -> Framed<S2>
|
||||||
|
where
|
||||||
|
F: FnOnce(S) -> S2,
|
||||||
|
{
|
||||||
|
let stream = f(self.stream);
|
||||||
|
Framed {
|
||||||
|
stream,
|
||||||
|
read_buf: self.read_buf,
|
||||||
|
write_buf: self.write_buf,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
impl<S: AsyncRead + Unpin> Framed<S> {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use password_hack::PasswordHackPayload;
|
|||||||
mod flow;
|
mod flow;
|
||||||
pub use flow::*;
|
pub use flow::*;
|
||||||
|
|
||||||
use crate::{console, error::UserFacingError};
|
use crate::error::{ReportableError, UserFacingError};
|
||||||
use std::io;
|
use std::io;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
@@ -23,15 +23,6 @@ pub type Result<T> = std::result::Result<T, AuthError>;
|
|||||||
/// Common authentication error.
|
/// Common authentication error.
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum AuthErrorImpl {
|
pub enum AuthErrorImpl {
|
||||||
#[error(transparent)]
|
|
||||||
Link(#[from] backend::LinkAuthError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
WakeCompute(#[from] console::errors::WakeComputeError),
|
|
||||||
|
|
||||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Sasl(#[from] crate::sasl::Error),
|
Sasl(#[from] crate::sasl::Error),
|
||||||
@@ -99,13 +90,25 @@ impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for AuthError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self.0.as_ref() {
|
||||||
|
AuthErrorImpl::Sasl(s) => s.get_error_type(),
|
||||||
|
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||||
|
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||||
|
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
|
||||||
|
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
|
||||||
|
AuthErrorImpl::Io(_) => crate::error::ErrorKind::Disconnect,
|
||||||
|
AuthErrorImpl::IpAddressNotAllowed => crate::error::ErrorKind::User,
|
||||||
|
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for AuthError {
|
impl UserFacingError for AuthError {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use AuthErrorImpl::*;
|
use AuthErrorImpl::*;
|
||||||
match self.0.as_ref() {
|
match self.0.as_ref() {
|
||||||
Link(e) => e.to_string_client(),
|
|
||||||
GetAuthInfo(e) => e.to_string_client(),
|
|
||||||
WakeCompute(e) => e.to_string_client(),
|
|
||||||
Sasl(e) => e.to_string_client(),
|
Sasl(e) => e.to_string_client(),
|
||||||
AuthFailed(_) => self.to_string(),
|
AuthFailed(_) => self.to_string(),
|
||||||
BadAuthMethod(_) => self.to_string(),
|
BadAuthMethod(_) => self.to_string(),
|
||||||
|
|||||||
@@ -2,22 +2,27 @@ mod classic;
|
|||||||
mod hacks;
|
mod hacks;
|
||||||
mod link;
|
mod link;
|
||||||
|
|
||||||
pub use link::LinkAuthError;
|
use pq_proto::StartupMessageParams;
|
||||||
use smol_str::SmolStr;
|
use smol_str::SmolStr;
|
||||||
use tokio_postgres::config::AuthKeys;
|
use tokio_postgres::config::AuthKeys;
|
||||||
|
|
||||||
|
use crate::auth::backend::link::NeedsLinkAuthentication;
|
||||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||||
use crate::auth::validate_password_and_exchange;
|
use crate::auth::validate_password_and_exchange;
|
||||||
use crate::cache::Cached;
|
use crate::cache::Cached;
|
||||||
|
use crate::cancellation::Session;
|
||||||
|
use crate::config::ProxyConfig;
|
||||||
use crate::console::errors::GetAuthInfoError;
|
use crate::console::errors::GetAuthInfoError;
|
||||||
use crate::console::provider::ConsoleBackend;
|
use crate::console::provider::ConsoleBackend;
|
||||||
use crate::console::AuthSecret;
|
use crate::console::AuthSecret;
|
||||||
use crate::context::RequestMonitoring;
|
use crate::context::RequestMonitoring;
|
||||||
use crate::proxy::connect_compute::handle_try_wake;
|
use crate::proxy::wake_compute::NeedsWakeCompute;
|
||||||
use crate::proxy::retry::retry_after;
|
use crate::proxy::ClientMode;
|
||||||
use crate::proxy::NeonOptions;
|
use crate::proxy::NeonOptions;
|
||||||
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
use crate::scram;
|
use crate::scram;
|
||||||
use crate::stream::Stream;
|
use crate::state_machine::{user_facing_error, DynStage, ResultExt, Stage, StageError};
|
||||||
|
use crate::stream::{PqStream, Stream};
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, ComputeUserInfoMaybeEndpoint},
|
auth::{self, ComputeUserInfoMaybeEndpoint},
|
||||||
config::AuthenticationConfig,
|
config::AuthenticationConfig,
|
||||||
@@ -30,10 +35,11 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use futures::TryFutureExt;
|
use futures::TryFutureExt;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::ops::ControlFlow;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::{error, info, warn};
|
use tracing::info;
|
||||||
|
|
||||||
|
use self::hacks::NeedsPasswordHack;
|
||||||
|
|
||||||
/// This type serves two purposes:
|
/// This type serves two purposes:
|
||||||
///
|
///
|
||||||
@@ -170,66 +176,94 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
struct NeedsAuthSecret<S> {
|
||||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
stream: PqStream<Stream<S>>,
|
||||||
///
|
api: Cow<'static, ConsoleBackend>,
|
||||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
params: StartupMessageParams,
|
||||||
async fn auth_quirks(
|
allow_self_signed_compute: bool,
|
||||||
ctx: &mut RequestMonitoring,
|
|
||||||
api: &impl console::Api,
|
|
||||||
user_info: ComputeUserInfoMaybeEndpoint,
|
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
|
||||||
allow_cleartext: bool,
|
allow_cleartext: bool,
|
||||||
|
info: ComputeUserInfo,
|
||||||
|
unauthenticated_password: Option<Vec<u8>>,
|
||||||
config: &'static AuthenticationConfig,
|
config: &'static AuthenticationConfig,
|
||||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
|
||||||
// If there's no project so far, that entails that client doesn't
|
|
||||||
// support SNI or other means of passing the endpoint (project) name.
|
|
||||||
// We now expect to see a very specific payload in the place of password.
|
|
||||||
let (info, unauthenticated_password) = match user_info.try_into() {
|
|
||||||
Err(info) => {
|
|
||||||
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
|
|
||||||
.await?;
|
|
||||||
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
|
|
||||||
(res.info, Some(res.keys))
|
|
||||||
}
|
|
||||||
Ok(info) => (info, None),
|
|
||||||
};
|
|
||||||
|
|
||||||
info!("fetching user's authentication info");
|
// monitoring
|
||||||
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
|
ctx: RequestMonitoring,
|
||||||
|
cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
// check allowed list
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthSecret<S> {
|
||||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
fn span(&self) -> tracing::Span {
|
||||||
return Err(auth::AuthError::ip_address_not_allowed());
|
tracing::info_span!("get_auth_secret")
|
||||||
}
|
}
|
||||||
let cached_secret = api.get_role_secret(ctx, &info).await?;
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
stream,
|
||||||
|
api,
|
||||||
|
params,
|
||||||
|
allow_cleartext,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
info,
|
||||||
|
unauthenticated_password,
|
||||||
|
config,
|
||||||
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
|
} = self;
|
||||||
|
|
||||||
let secret = cached_secret.value.clone().unwrap_or_else(|| {
|
info!("fetching user's authentication info");
|
||||||
// If we don't have an authentication secret, we mock one to
|
let (allowed_ips, stream) = api
|
||||||
// prevent malicious probing (possible due to missing protocol steps).
|
.get_allowed_ips(&mut ctx, &info)
|
||||||
// This mocked secret will never lead to successful authentication.
|
.await
|
||||||
info!("authentication info not found, mocking it");
|
.send_error_to_user(&mut ctx, stream)?;
|
||||||
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
|
|
||||||
});
|
// check allowed list
|
||||||
match authenticate_with_secret(
|
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||||
ctx,
|
return Err(user_facing_error(
|
||||||
secret,
|
auth::AuthError::ip_address_not_allowed(),
|
||||||
info,
|
&mut ctx,
|
||||||
client,
|
stream,
|
||||||
unauthenticated_password,
|
));
|
||||||
allow_cleartext,
|
}
|
||||||
config,
|
let (cached_secret, mut stream) = api
|
||||||
)
|
.get_role_secret(&mut ctx, &info)
|
||||||
.await
|
.await
|
||||||
{
|
.send_error_to_user(&mut ctx, stream)?;
|
||||||
Ok(keys) => Ok(keys),
|
|
||||||
Err(e) => {
|
let secret = cached_secret.value.clone().unwrap_or_else(|| {
|
||||||
|
// If we don't have an authentication secret, we mock one to
|
||||||
|
// prevent malicious probing (possible due to missing protocol steps).
|
||||||
|
// This mocked secret will never lead to successful authentication.
|
||||||
|
info!("authentication info not found, mocking it");
|
||||||
|
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
|
||||||
|
});
|
||||||
|
|
||||||
|
let (keys, stream) = authenticate_with_secret(
|
||||||
|
&mut ctx,
|
||||||
|
secret,
|
||||||
|
info,
|
||||||
|
&mut stream,
|
||||||
|
unauthenticated_password,
|
||||||
|
allow_cleartext,
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
if e.is_auth_failed() {
|
if e.is_auth_failed() {
|
||||||
// The password could have been changed, so we invalidate the cache.
|
// The password could have been changed, so we invalidate the cache.
|
||||||
cached_secret.invalidate();
|
cached_secret.invalidate();
|
||||||
}
|
}
|
||||||
Err(e)
|
e
|
||||||
}
|
})
|
||||||
|
.send_error_to_user(&mut ctx, stream)?;
|
||||||
|
|
||||||
|
Ok(Box::new(NeedsWakeCompute {
|
||||||
|
stream,
|
||||||
|
api,
|
||||||
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
creds: keys,
|
||||||
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,49 +304,6 @@ async fn authenticate_with_secret(
|
|||||||
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
|
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
|
|
||||||
/// only if authentication was successfuly.
|
|
||||||
async fn auth_and_wake_compute(
|
|
||||||
ctx: &mut RequestMonitoring,
|
|
||||||
api: &impl console::Api,
|
|
||||||
user_info: ComputeUserInfoMaybeEndpoint,
|
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
|
||||||
allow_cleartext: bool,
|
|
||||||
config: &'static AuthenticationConfig,
|
|
||||||
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
|
|
||||||
let compute_credentials =
|
|
||||||
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
|
|
||||||
|
|
||||||
let mut num_retries = 0;
|
|
||||||
let mut node = loop {
|
|
||||||
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
|
|
||||||
match handle_try_wake(wake_res, num_retries) {
|
|
||||||
Err(e) => {
|
|
||||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
|
||||||
return Err(e.into());
|
|
||||||
}
|
|
||||||
Ok(ControlFlow::Continue(e)) => {
|
|
||||||
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
|
|
||||||
}
|
|
||||||
Ok(ControlFlow::Break(n)) => break n,
|
|
||||||
}
|
|
||||||
|
|
||||||
let wait_duration = retry_after(num_retries);
|
|
||||||
num_retries += 1;
|
|
||||||
tokio::time::sleep(wait_duration).await;
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.set_project(node.aux.clone());
|
|
||||||
|
|
||||||
match compute_credentials.keys {
|
|
||||||
#[cfg(feature = "testing")]
|
|
||||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
|
||||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((node, compute_credentials.info))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||||
/// Get compute endpoint name from the credentials.
|
/// Get compute endpoint name from the credentials.
|
||||||
pub fn get_endpoint(&self) -> Option<SmolStr> {
|
pub fn get_endpoint(&self) -> Option<SmolStr> {
|
||||||
@@ -337,50 +328,96 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
|||||||
Test(_) => "test",
|
Test(_) => "test",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
pub struct NeedsAuthentication<S> {
|
||||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
pub stream: PqStream<Stream<S>>,
|
||||||
pub async fn authenticate(
|
pub creds: BackendType<'static, auth::ComputeUserInfoMaybeEndpoint>,
|
||||||
self,
|
pub params: StartupMessageParams,
|
||||||
ctx: &mut RequestMonitoring,
|
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
pub mode: ClientMode,
|
||||||
allow_cleartext: bool,
|
pub config: &'static ProxyConfig,
|
||||||
config: &'static AuthenticationConfig,
|
|
||||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
|
||||||
use BackendType::*;
|
|
||||||
|
|
||||||
let res = match self {
|
// monitoring
|
||||||
Console(api, user_info) => {
|
pub ctx: RequestMonitoring,
|
||||||
info!(
|
pub cancel_session: Session,
|
||||||
user = &*user_info.user,
|
}
|
||||||
project = user_info.project(),
|
|
||||||
"performing authentication using the console"
|
|
||||||
);
|
|
||||||
|
|
||||||
let (cache_info, user_info) =
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthentication<S> {
|
||||||
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
|
fn span(&self) -> tracing::Span {
|
||||||
.await?;
|
tracing::info_span!("authenticate")
|
||||||
(cache_info, BackendType::Console(api, user_info))
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
stream,
|
||||||
|
creds,
|
||||||
|
params,
|
||||||
|
endpoint_rate_limiter,
|
||||||
|
mode,
|
||||||
|
config,
|
||||||
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
// check rate limit
|
||||||
|
if let Some(ep) = creds.get_endpoint() {
|
||||||
|
if !endpoint_rate_limiter.check(ep) {
|
||||||
|
return Err(user_facing_error(
|
||||||
|
auth::AuthError::too_many_connections(),
|
||||||
|
&mut ctx,
|
||||||
|
stream,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let allow_self_signed_compute = mode.allow_self_signed_compute(config);
|
||||||
|
let allow_cleartext = mode.allow_cleartext();
|
||||||
|
|
||||||
|
match creds {
|
||||||
|
BackendType::Console(api, creds) => {
|
||||||
|
// If there's no project so far, that entails that client doesn't
|
||||||
|
// support SNI or other means of passing the endpoint (project) name.
|
||||||
|
// We now expect to see a very specific payload in the place of password.
|
||||||
|
match creds.try_into() {
|
||||||
|
Err(info) => Ok(Box::new(NeedsPasswordHack {
|
||||||
|
stream,
|
||||||
|
api,
|
||||||
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
info,
|
||||||
|
allow_cleartext,
|
||||||
|
config: &config.authentication_config,
|
||||||
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
})),
|
||||||
|
Ok(info) => Ok(Box::new(NeedsAuthSecret {
|
||||||
|
stream,
|
||||||
|
api,
|
||||||
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
info,
|
||||||
|
unauthenticated_password: None,
|
||||||
|
allow_cleartext,
|
||||||
|
config: &config.authentication_config,
|
||||||
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
})),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// NOTE: this auth backend doesn't use client credentials.
|
// NOTE: this auth backend doesn't use client credentials.
|
||||||
Link(url) => {
|
BackendType::Link(link) => Ok(Box::new(NeedsLinkAuthentication {
|
||||||
info!("performing link authentication");
|
stream,
|
||||||
|
link,
|
||||||
let node_info = link::authenticate(&url, client).await?;
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
(
|
ctx,
|
||||||
CachedNodeInfo::new_uncached(node_info),
|
cancel_session,
|
||||||
BackendType::Link(url),
|
})),
|
||||||
)
|
|
||||||
}
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
Test(_) => {
|
BackendType::Test(_) => {
|
||||||
unreachable!("this function should never be called in the test backend")
|
unreachable!("this function should never be called in the test backend")
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
info!("user successfully authenticated");
|
|
||||||
Ok(res)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||||
|
NeedsAuthSecret,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, AuthFlow},
|
auth::{self, AuthFlow},
|
||||||
console::AuthSecret,
|
cancellation::Session,
|
||||||
|
config::AuthenticationConfig,
|
||||||
|
console::{provider::ConsoleBackend, AuthSecret},
|
||||||
|
context::RequestMonitoring,
|
||||||
metrics::LatencyTimer,
|
metrics::LatencyTimer,
|
||||||
sasl,
|
sasl,
|
||||||
stream::{self, Stream},
|
state_machine::{DynStage, ResultExt, Stage, StageError},
|
||||||
|
stream::{self, PqStream, Stream},
|
||||||
};
|
};
|
||||||
|
use pq_proto::StartupMessageParams;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
@@ -46,7 +54,7 @@ pub async fn authenticate_cleartext(
|
|||||||
/// Workaround for clients which don't provide an endpoint (project) name.
|
/// Workaround for clients which don't provide an endpoint (project) name.
|
||||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||||
pub async fn password_hack_no_authentication(
|
async fn password_hack_no_authentication(
|
||||||
info: ComputeUserInfoNoEndpoint,
|
info: ComputeUserInfoNoEndpoint,
|
||||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||||
latency_timer: &mut LatencyTimer,
|
latency_timer: &mut LatencyTimer,
|
||||||
@@ -74,3 +82,47 @@ pub async fn password_hack_no_authentication(
|
|||||||
keys: payload.password,
|
keys: payload.password,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct NeedsPasswordHack<S> {
|
||||||
|
pub stream: PqStream<Stream<S>>,
|
||||||
|
pub api: Cow<'static, ConsoleBackend>,
|
||||||
|
pub params: StartupMessageParams,
|
||||||
|
pub allow_self_signed_compute: bool,
|
||||||
|
pub allow_cleartext: bool,
|
||||||
|
pub info: ComputeUserInfoNoEndpoint,
|
||||||
|
pub config: &'static AuthenticationConfig,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
pub ctx: RequestMonitoring,
|
||||||
|
pub cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsPasswordHack<S> {
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("password_hack")
|
||||||
|
}
|
||||||
|
async fn run(mut self) -> Result<DynStage, StageError> {
|
||||||
|
let (res, stream) = password_hack_no_authentication(
|
||||||
|
self.info,
|
||||||
|
&mut self.stream,
|
||||||
|
&mut self.ctx.latency_timer,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.send_error_to_user(&mut self.ctx, self.stream)?;
|
||||||
|
|
||||||
|
self.ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
|
||||||
|
Ok(Box::new(NeedsAuthSecret {
|
||||||
|
stream,
|
||||||
|
info: res.info,
|
||||||
|
unauthenticated_password: Some(res.keys),
|
||||||
|
|
||||||
|
api: self.api,
|
||||||
|
params: self.params,
|
||||||
|
allow_self_signed_compute: self.allow_self_signed_compute,
|
||||||
|
allow_cleartext: self.allow_cleartext,
|
||||||
|
ctx: self.ctx,
|
||||||
|
cancel_session: self.cancel_session,
|
||||||
|
config: self.config,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,41 +1,20 @@
|
|||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth, compute,
|
auth::BackendType,
|
||||||
console::{self, provider::NodeInfo},
|
cancellation::Session,
|
||||||
error::UserFacingError,
|
compute,
|
||||||
stream::PqStream,
|
console::{self, mgmt::ComputeReady, provider::NodeInfo, CachedNodeInfo},
|
||||||
waiters,
|
context::RequestMonitoring,
|
||||||
|
proxy::connect_compute::{NeedsComputeConnection, TcpMechanism},
|
||||||
|
state_machine::{DynStage, ResultExt, Stage, StageError},
|
||||||
|
stream::{PqStream, Stream},
|
||||||
|
waiters::Waiter,
|
||||||
};
|
};
|
||||||
use pq_proto::BeMessage as Be;
|
use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio_postgres::config::SslMode;
|
use tokio_postgres::config::SslMode;
|
||||||
use tracing::{info, info_span};
|
use tracing::info;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum LinkAuthError {
|
|
||||||
/// Authentication error reported by the console.
|
|
||||||
#[error("Authentication failed: {0}")]
|
|
||||||
AuthFailed(String),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
WaiterRegister(#[from] waiters::RegisterError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
WaiterWait(#[from] waiters::WaitError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Io(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UserFacingError for LinkAuthError {
|
|
||||||
fn to_string_client(&self) -> String {
|
|
||||||
use LinkAuthError::*;
|
|
||||||
match self {
|
|
||||||
AuthFailed(_) => self.to_string(),
|
|
||||||
_ => "Internal error".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
|
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
|
||||||
format!(
|
format!(
|
||||||
@@ -53,64 +32,146 @@ pub fn new_psql_session_id() -> String {
|
|||||||
hex::encode(rand::random::<[u8; 8]>())
|
hex::encode(rand::random::<[u8; 8]>())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn authenticate(
|
pub struct NeedsLinkAuthentication<S> {
|
||||||
link_uri: &reqwest::Url,
|
pub stream: PqStream<Stream<S>>,
|
||||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
pub link: Cow<'static, crate::url::ApiUrl>,
|
||||||
) -> auth::Result<NodeInfo> {
|
pub params: StartupMessageParams,
|
||||||
// registering waiter can fail if we get unlucky with rng.
|
pub allow_self_signed_compute: bool,
|
||||||
// just try again.
|
|
||||||
let (psql_session_id, waiter) = loop {
|
|
||||||
let psql_session_id = new_psql_session_id();
|
|
||||||
|
|
||||||
match console::mgmt::get_waiter(&psql_session_id) {
|
// monitoring
|
||||||
Ok(waiter) => break (psql_session_id, waiter),
|
pub ctx: RequestMonitoring,
|
||||||
Err(_e) => continue,
|
pub cancel_session: Session,
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsLinkAuthentication<S> {
|
||||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
fn span(&self) -> tracing::Span {
|
||||||
let greeting = hello_message(link_uri, &psql_session_id);
|
tracing::info_span!("link", psql_session_id = tracing::field::Empty)
|
||||||
|
}
|
||||||
// Give user a URL to spawn a new database.
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
info!(parent: &span, "sending the auth URL to the user");
|
let Self {
|
||||||
client
|
mut stream,
|
||||||
.write_message_noflush(&Be::AuthenticationOk)?
|
link,
|
||||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
params,
|
||||||
.write_message(&Be::NoticeResponse(&greeting))
|
allow_self_signed_compute,
|
||||||
.await?;
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
// Wait for web console response (see `mgmt`).
|
} = self;
|
||||||
info!(parent: &span, "waiting for console's reply...");
|
|
||||||
let db_info = waiter.await.map_err(LinkAuthError::from)?;
|
// registering waiter can fail if we get unlucky with rng.
|
||||||
|
// just try again.
|
||||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
let (psql_session_id, waiter) = loop {
|
||||||
|
let psql_session_id = new_psql_session_id();
|
||||||
// This config should be self-contained, because we won't
|
|
||||||
// take username or dbname from client's startup message.
|
match console::mgmt::get_waiter(&psql_session_id) {
|
||||||
let mut config = compute::ConnCfg::new();
|
Ok(waiter) => break (psql_session_id, waiter),
|
||||||
config
|
Err(_e) => continue,
|
||||||
.host(&db_info.host)
|
}
|
||||||
.port(db_info.port)
|
};
|
||||||
.dbname(&db_info.dbname)
|
tracing::Span::current().record("psql_session_id", &psql_session_id);
|
||||||
.user(&db_info.user);
|
let greeting = hello_message(&link, &psql_session_id);
|
||||||
|
|
||||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
info!("sending the auth URL to the user");
|
||||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
|
||||||
// everywhere, we can remove this.
|
stream
|
||||||
if db_info.host.contains("--") {
|
.write_message_noflush(&Be::AuthenticationOk)
|
||||||
// we need TLS connection with SNI info to properly route it
|
.and_then(|s| s.write_message_noflush(&Be::CLIENT_ENCODING))
|
||||||
config.ssl_mode(SslMode::Require);
|
.and_then(|s| s.write_message_noflush(&Be::NoticeResponse(&greeting)))
|
||||||
} else {
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?
|
||||||
config.ssl_mode(SslMode::Disable);
|
.flush()
|
||||||
}
|
.await
|
||||||
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
|
||||||
if let Some(password) = db_info.password {
|
|
||||||
config.password(password.as_ref());
|
Ok(Box::new(NeedsLinkAuthenticationResponse {
|
||||||
}
|
stream,
|
||||||
|
link,
|
||||||
Ok(NodeInfo {
|
params,
|
||||||
config,
|
allow_self_signed_compute,
|
||||||
aux: db_info.aux,
|
waiter,
|
||||||
allow_self_signed_compute: false, // caller may override
|
psql_session_id,
|
||||||
})
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NeedsLinkAuthenticationResponse<S> {
|
||||||
|
stream: PqStream<Stream<S>>,
|
||||||
|
link: Cow<'static, crate::url::ApiUrl>,
|
||||||
|
params: StartupMessageParams,
|
||||||
|
allow_self_signed_compute: bool,
|
||||||
|
waiter: Waiter<'static, ComputeReady>,
|
||||||
|
psql_session_id: String,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
ctx: RequestMonitoring,
|
||||||
|
cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage
|
||||||
|
for NeedsLinkAuthenticationResponse<S>
|
||||||
|
{
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("link_wait", psql_session_id = self.psql_session_id)
|
||||||
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
mut stream,
|
||||||
|
link,
|
||||||
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
waiter,
|
||||||
|
psql_session_id: _,
|
||||||
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
// Wait for web console response (see `mgmt`).
|
||||||
|
info!("waiting for console's reply...");
|
||||||
|
let db_info = waiter
|
||||||
|
.await
|
||||||
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
|
||||||
|
|
||||||
|
stream
|
||||||
|
.write_message_noflush(&Be::NoticeResponse("Connecting to database."))
|
||||||
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
|
||||||
|
|
||||||
|
// This config should be self-contained, because we won't
|
||||||
|
// take username or dbname from client's startup message.
|
||||||
|
let mut config = compute::ConnCfg::new();
|
||||||
|
config
|
||||||
|
.host(&db_info.host)
|
||||||
|
.port(db_info.port)
|
||||||
|
.dbname(&db_info.dbname)
|
||||||
|
.user(&db_info.user);
|
||||||
|
|
||||||
|
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||||
|
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||||
|
// everywhere, we can remove this.
|
||||||
|
if db_info.host.contains("--") {
|
||||||
|
// we need TLS connection with SNI info to properly route it
|
||||||
|
config.ssl_mode(SslMode::Require);
|
||||||
|
} else {
|
||||||
|
config.ssl_mode(SslMode::Disable);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(password) = db_info.password {
|
||||||
|
config.password(password.as_ref());
|
||||||
|
}
|
||||||
|
|
||||||
|
let node_info = CachedNodeInfo::new_uncached(NodeInfo {
|
||||||
|
config,
|
||||||
|
aux: db_info.aux,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
});
|
||||||
|
let user_info = BackendType::Link(link);
|
||||||
|
|
||||||
|
Ok(Box::new(NeedsComputeConnection {
|
||||||
|
stream,
|
||||||
|
user_info,
|
||||||
|
mechanism: TcpMechanism { params },
|
||||||
|
node_info,
|
||||||
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
//! User credentials used in authentication.
|
//! User credentials used in authentication.
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
auth::password_hack::parse_endpoint_param,
|
||||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
|
context::RequestMonitoring,
|
||||||
|
error::{ReportableError, UserFacingError},
|
||||||
|
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
|
||||||
|
proxy::NeonOptions,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use pq_proto::StartupMessageParams;
|
use pq_proto::StartupMessageParams;
|
||||||
@@ -33,7 +36,24 @@ pub enum ComputeUserInfoParseError {
|
|||||||
MalformedProjectName(SmolStr),
|
MalformedProjectName(SmolStr),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserFacingError for ComputeUserInfoParseError {}
|
impl ReportableError for ComputeUserInfoParseError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
ComputeUserInfoParseError::MissingKey(_) => crate::error::ErrorKind::User,
|
||||||
|
ComputeUserInfoParseError::InconsistentProjectNames { .. } => {
|
||||||
|
crate::error::ErrorKind::User
|
||||||
|
}
|
||||||
|
ComputeUserInfoParseError::UnknownCommonName { .. } => crate::error::ErrorKind::User,
|
||||||
|
ComputeUserInfoParseError::MalformedProjectName(_) => crate::error::ErrorKind::User,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserFacingError for ComputeUserInfoParseError {
|
||||||
|
fn to_string_client(&self) -> String {
|
||||||
|
self.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Various client credentials which we use for authentication.
|
/// Various client credentials which we use for authentication.
|
||||||
/// Note that we don't store any kind of client key or password here.
|
/// Note that we don't store any kind of client key or password here.
|
||||||
|
|||||||
@@ -164,6 +164,13 @@ async fn task_main(
|
|||||||
let tls_config = Arc::clone(&tls_config);
|
let tls_config = Arc::clone(&tls_config);
|
||||||
let dest_suffix = Arc::clone(&dest_suffix);
|
let dest_suffix = Arc::clone(&dest_suffix);
|
||||||
|
|
||||||
|
let root_span = tracing::info_span!(
|
||||||
|
"handle_client",
|
||||||
|
?session_id,
|
||||||
|
endpoint = tracing::field::Empty
|
||||||
|
);
|
||||||
|
let root_span2 = root_span.clone();
|
||||||
|
|
||||||
connections.spawn(
|
connections.spawn(
|
||||||
async move {
|
async move {
|
||||||
socket
|
socket
|
||||||
@@ -171,8 +178,13 @@ async fn task_main(
|
|||||||
.context("failed to set socket option")?;
|
.context("failed to set socket option")?;
|
||||||
|
|
||||||
info!(%peer_addr, "serving");
|
info!(%peer_addr, "serving");
|
||||||
let mut ctx =
|
let mut ctx = RequestMonitoring::new(
|
||||||
RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
|
session_id,
|
||||||
|
peer_addr.ip(),
|
||||||
|
"sni_router",
|
||||||
|
"sni",
|
||||||
|
root_span2,
|
||||||
|
);
|
||||||
handle_client(
|
handle_client(
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
dest_suffix,
|
dest_suffix,
|
||||||
@@ -186,7 +198,7 @@ async fn task_main(
|
|||||||
// Acknowledge that the task has finished with an error.
|
// Acknowledge that the task has finished with an error.
|
||||||
error!("per-client task finished with an error: {e:#}");
|
error!("per-client task finished with an error: {e:#}");
|
||||||
})
|
})
|
||||||
.instrument(tracing::info_span!("handle_client", ?session_id)),
|
.instrument(root_span),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,6 +283,7 @@ async fn handle_client(
|
|||||||
|
|
||||||
let client = tokio::net::TcpStream::connect(destination).await?;
|
let client = tokio::net::TcpStream::connect(destination).await?;
|
||||||
|
|
||||||
|
ctx.log();
|
||||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||||
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
|
proxy::proxy::pass::proxy_pass(tls_stream, client, metrics_aux).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::{bail, Context};
|
use anyhow::Context;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use pq_proto::CancelKeyData;
|
use pq_proto::CancelKeyData;
|
||||||
use std::net::SocketAddr;
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_postgres::{CancelToken, NoTls};
|
use tokio_postgres::{CancelToken, NoTls};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -25,39 +25,33 @@ impl CancelMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||||
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
|
pub fn get_session(self: Arc<Self>) -> Session {
|
||||||
where
|
|
||||||
F: FnOnce(Session<'a>) -> R,
|
|
||||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
|
||||||
{
|
|
||||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||||
// expose it and we don't want to do another roundtrip to query
|
// expose it and we don't want to do another roundtrip to query
|
||||||
// for it. The client will be able to notice that this is not the
|
// for it. The client will be able to notice that this is not the
|
||||||
// actual backend_pid, but backend_pid is not used for anything
|
// actual backend_pid, but backend_pid is not used for anything
|
||||||
// so it doesn't matter.
|
// so it doesn't matter.
|
||||||
let key = rand::random();
|
let key = loop {
|
||||||
|
let key = rand::random();
|
||||||
|
|
||||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||||
// which is why we have to take care not to rewrite an existing key.
|
// which is why we have to take care not to rewrite an existing key.
|
||||||
match self.0.entry(key) {
|
match self.0.entry(key) {
|
||||||
dashmap::mapref::entry::Entry::Occupied(_) => {
|
dashmap::mapref::entry::Entry::Occupied(_) => {
|
||||||
bail!("query cancellation key already exists: {key}")
|
continue;
|
||||||
|
}
|
||||||
|
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||||
|
e.insert(None);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
break key;
|
||||||
e.insert(None);
|
};
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This will guarantee that the session gets dropped
|
|
||||||
// as soon as the future is finished.
|
|
||||||
scopeguard::defer! {
|
|
||||||
self.0.remove(&key);
|
|
||||||
info!("dropped query cancellation key {key}");
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("registered new query cancellation key {key}");
|
info!("registered new query cancellation key {key}");
|
||||||
let session = Session::new(key, self);
|
Session {
|
||||||
f(session).await
|
key,
|
||||||
|
cancel_map: self,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -98,23 +92,17 @@ impl CancelClosure {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Helper for registering query cancellation tokens.
|
/// Helper for registering query cancellation tokens.
|
||||||
pub struct Session<'a> {
|
pub struct Session {
|
||||||
/// The user-facing key identifying this session.
|
/// The user-facing key identifying this session.
|
||||||
key: CancelKeyData,
|
key: CancelKeyData,
|
||||||
/// The [`CancelMap`] this session belongs to.
|
/// The [`CancelMap`] this session belongs to.
|
||||||
cancel_map: &'a CancelMap,
|
cancel_map: Arc<CancelMap>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Session<'a> {
|
impl Session {
|
||||||
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
|
|
||||||
Self { key, cancel_map }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Session<'_> {
|
|
||||||
/// Store the cancel token for the given session.
|
/// Store the cancel token for the given session.
|
||||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||||
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||||
info!("enabling query cancellation for this session");
|
info!("enabling query cancellation for this session");
|
||||||
self.cancel_map.0.insert(self.key, Some(cancel_closure));
|
self.cancel_map.0.insert(self.key, Some(cancel_closure));
|
||||||
|
|
||||||
@@ -122,37 +110,26 @@ impl Session<'_> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for Session {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.cancel_map.0.remove(&self.key);
|
||||||
|
info!("dropped query cancellation key {}", &self.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn check_session_drop() -> anyhow::Result<()> {
|
async fn check_session_drop() -> anyhow::Result<()> {
|
||||||
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
|
let cancel_map: Arc<CancelMap> = Default::default();
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
||||||
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
|
|
||||||
assert!(CANCEL_MAP.contains(&session));
|
|
||||||
|
|
||||||
tx.send(()).expect("failed to send");
|
|
||||||
futures::future::pending::<()>().await; // sleep forever
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Wait until the task has been spawned.
|
|
||||||
rx.await.context("failed to hear from the task")?;
|
|
||||||
|
|
||||||
// Drop the session's entry by cancelling the task.
|
|
||||||
task.abort();
|
|
||||||
let error = task.await.expect_err("task should have failed");
|
|
||||||
if !error.is_cancelled() {
|
|
||||||
anyhow::bail!(error);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
let session = cancel_map.clone().get_session();
|
||||||
|
assert!(cancel_map.contains(&session));
|
||||||
|
drop(session);
|
||||||
// Check that the session has been dropped.
|
// Check that the session has been dropped.
|
||||||
assert!(CANCEL_MAP.is_empty());
|
assert!(cancel_map.is_empty());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
|
auth::parse_endpoint_param,
|
||||||
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
|
cancellation::CancelClosure,
|
||||||
|
console::errors::WakeComputeError,
|
||||||
|
context::RequestMonitoring,
|
||||||
|
error::{ReportableError, UserFacingError},
|
||||||
|
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||||
proxy::neon_option,
|
proxy::neon_option,
|
||||||
};
|
};
|
||||||
use futures::{FutureExt, TryFutureExt};
|
use futures::{FutureExt, TryFutureExt};
|
||||||
@@ -32,6 +36,17 @@ pub enum ConnectionError {
|
|||||||
WakeComputeError(#[from] WakeComputeError),
|
WakeComputeError(#[from] WakeComputeError),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for ConnectionError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||||
|
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||||
|
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||||
|
ConnectionError::WakeComputeError(_) => crate::error::ErrorKind::ControlPlane,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for ConnectionError {
|
impl UserFacingError for ConnectionError {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use ConnectionError::*;
|
use ConnectionError::*;
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ use tracing::info;
|
|||||||
|
|
||||||
pub mod errors {
|
pub mod errors {
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{io_error, UserFacingError},
|
error::{io_error, ReportableError, UserFacingError},
|
||||||
http,
|
http,
|
||||||
proxy::retry::ShouldRetry,
|
proxy::retry::ShouldRetry,
|
||||||
};
|
};
|
||||||
@@ -56,6 +56,15 @@ pub mod errors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for ApiError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
|
||||||
|
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for ApiError {
|
impl UserFacingError for ApiError {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use ApiError::*;
|
use ApiError::*;
|
||||||
@@ -140,6 +149,15 @@ pub mod errors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for GetAuthInfoError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||||
|
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for GetAuthInfoError {
|
impl UserFacingError for GetAuthInfoError {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use GetAuthInfoError::*;
|
use GetAuthInfoError::*;
|
||||||
@@ -181,6 +199,16 @@ pub mod errors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for WakeComputeError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||||
|
WakeComputeError::ApiError(e) => e.get_error_type(),
|
||||||
|
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for WakeComputeError {
|
impl UserFacingError for WakeComputeError {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use WakeComputeError::*;
|
use WakeComputeError::*;
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ pub struct RequestMonitoring {
|
|||||||
// This sender is here to keep the request monitoring channel open while requests are taking place.
|
// This sender is here to keep the request monitoring channel open while requests are taking place.
|
||||||
sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
|
sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
|
||||||
pub latency_timer: LatencyTimer,
|
pub latency_timer: LatencyTimer,
|
||||||
|
root_span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RequestMonitoring {
|
impl RequestMonitoring {
|
||||||
@@ -46,6 +47,7 @@ impl RequestMonitoring {
|
|||||||
peer_addr: IpAddr,
|
peer_addr: IpAddr,
|
||||||
protocol: &'static str,
|
protocol: &'static str,
|
||||||
region: &'static str,
|
region: &'static str,
|
||||||
|
root_span: tracing::Span,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
peer_addr,
|
peer_addr,
|
||||||
@@ -64,12 +66,19 @@ impl RequestMonitoring {
|
|||||||
|
|
||||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||||
latency_timer: LatencyTimer::new(protocol),
|
latency_timer: LatencyTimer::new(protocol),
|
||||||
|
root_span,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn test() -> Self {
|
pub fn test() -> Self {
|
||||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
|
RequestMonitoring::new(
|
||||||
|
Uuid::now_v7(),
|
||||||
|
[127, 0, 0, 1].into(),
|
||||||
|
"test",
|
||||||
|
"test",
|
||||||
|
tracing::Span::none(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn console_application_name(&self) -> String {
|
pub fn console_application_name(&self) -> String {
|
||||||
@@ -87,7 +96,10 @@ impl RequestMonitoring {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
|
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
|
||||||
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
|
if let (None, Some(ep)) = (self.endpoint_id.as_ref(), endpoint_id) {
|
||||||
|
self.root_span.record("ep", &*ep);
|
||||||
|
self.endpoint_id = Some(ep)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_application(&mut self, app: Option<SmolStr>) {
|
pub fn set_application(&mut self, app: Option<SmolStr>) {
|
||||||
@@ -102,6 +114,10 @@ impl RequestMonitoring {
|
|||||||
self.success = true;
|
self.success = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn error(&mut self, err: ErrorKind) {
|
||||||
|
self.error_kind = Some(err);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn log(&mut self) {
|
pub fn log(&mut self) {
|
||||||
if let Some(tx) = self.sender.take() {
|
if let Some(tx) = self.sender.take() {
|
||||||
let _: Result<(), _> = tx.send(self.clone());
|
let _: Result<(), _> = tx.send(self.clone());
|
||||||
|
|||||||
@@ -17,19 +17,16 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
|
|||||||
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
|
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
|
||||||
/// is way too convenient and tends to proliferate all across the codebase,
|
/// is way too convenient and tends to proliferate all across the codebase,
|
||||||
/// ultimately leading to accidental leaks of sensitive data.
|
/// ultimately leading to accidental leaks of sensitive data.
|
||||||
pub trait UserFacingError: fmt::Display {
|
pub trait UserFacingError: ReportableError {
|
||||||
/// Format the error for client, stripping all sensitive info.
|
/// Format the error for client, stripping all sensitive info.
|
||||||
///
|
///
|
||||||
/// Although this might be a no-op for many types, it's highly
|
/// Although this might be a no-op for many types, it's highly
|
||||||
/// recommended to override the default impl in case error type
|
/// recommended to override the default impl in case error type
|
||||||
/// contains anything sensitive: various IDs, IP addresses etc.
|
/// contains anything sensitive: various IDs, IP addresses etc.
|
||||||
#[inline(always)]
|
fn to_string_client(&self) -> String;
|
||||||
fn to_string_client(&self) -> String {
|
|
||||||
self.to_string()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Copy)]
|
||||||
pub enum ErrorKind {
|
pub enum ErrorKind {
|
||||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||||
User,
|
User,
|
||||||
@@ -62,3 +59,7 @@ impl ErrorKind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait ReportableError: fmt::Display + Send + 'static {
|
||||||
|
fn get_error_type(&self) -> ErrorKind;
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ pub mod redis;
|
|||||||
pub mod sasl;
|
pub mod sasl;
|
||||||
pub mod scram;
|
pub mod scram;
|
||||||
pub mod serverless;
|
pub mod serverless;
|
||||||
|
pub mod state_machine;
|
||||||
pub mod stream;
|
pub mod stream;
|
||||||
pub mod url;
|
pub mod url;
|
||||||
pub mod usage_metrics;
|
pub mod usage_metrics;
|
||||||
|
|||||||
@@ -2,38 +2,32 @@
|
|||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
pub mod connect_compute;
|
pub mod connect_compute;
|
||||||
|
pub mod handshake;
|
||||||
|
pub mod pass;
|
||||||
pub mod retry;
|
pub mod retry;
|
||||||
|
pub mod wake_compute;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
cancellation::CancelMap,
|
||||||
cancellation::{self, CancelMap},
|
config::{ProxyConfig, TlsConfig},
|
||||||
compute,
|
|
||||||
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
|
|
||||||
console::messages::MetricsAuxInfo,
|
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
metrics::{
|
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
|
||||||
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
|
|
||||||
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
|
|
||||||
},
|
|
||||||
protocol2::WithClientIp,
|
protocol2::WithClientIp,
|
||||||
|
proxy::handshake::NeedsHandshake,
|
||||||
rate_limiter::EndpointRateLimiter,
|
rate_limiter::EndpointRateLimiter,
|
||||||
stream::{PqStream, Stream},
|
state_machine::{DynStage, StageResult},
|
||||||
usage_metrics::{Ids, USAGE_METRICS},
|
stream::Stream,
|
||||||
};
|
};
|
||||||
use anyhow::{bail, Context};
|
use anyhow::Context;
|
||||||
use futures::TryFutureExt;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
use pq_proto::StartupMessageParams;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use smol_str::SmolStr;
|
use smol_str::SmolStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, info_span, Instrument};
|
use tracing::{error, info, info_span, Instrument};
|
||||||
use utils::measured_stream::MeasuredStream;
|
|
||||||
|
|
||||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
|
||||||
|
|
||||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||||
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
const ERR_PROTO_VIOLATION: &str = "protocol violation";
|
||||||
@@ -79,45 +73,64 @@ pub async fn task_main(
|
|||||||
let cancel_map = Arc::clone(&cancel_map);
|
let cancel_map = Arc::clone(&cancel_map);
|
||||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||||
|
|
||||||
|
let root_span = info_span!(
|
||||||
|
"handle_client",
|
||||||
|
?session_id,
|
||||||
|
peer_addr = tracing::field::Empty,
|
||||||
|
ep = tracing::field::Empty,
|
||||||
|
);
|
||||||
|
let root_span2 = root_span.clone();
|
||||||
|
|
||||||
connections.spawn(
|
connections.spawn(
|
||||||
async move {
|
async move {
|
||||||
info!("accepted postgres client connection");
|
info!("accepted postgres client connection");
|
||||||
|
|
||||||
let mut socket = WithClientIp::new(socket);
|
let mut socket = WithClientIp::new(socket);
|
||||||
let mut peer_addr = peer_addr.ip();
|
let mut peer_addr = peer_addr.ip();
|
||||||
if let Some(addr) = socket.wait_for_addr().await? {
|
match socket.wait_for_addr().await {
|
||||||
peer_addr = addr.ip();
|
Err(e) => {
|
||||||
tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
|
error!("IO error: {e:#}");
|
||||||
} else if config.require_client_ip {
|
return;
|
||||||
bail!("missing required client IP");
|
}
|
||||||
}
|
Ok(Some(addr)) => {
|
||||||
|
peer_addr = addr.ip();
|
||||||
|
root_span2.record("peer_addr", &tracing::field::display(addr));
|
||||||
|
}
|
||||||
|
Ok(None) if config.require_client_ip => {
|
||||||
|
error!("missing required client IP");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(None) => {}
|
||||||
|
};
|
||||||
|
|
||||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
|
let ctx = RequestMonitoring::new(
|
||||||
|
session_id,
|
||||||
|
peer_addr,
|
||||||
|
"tcp",
|
||||||
|
&config.region,
|
||||||
|
root_span2,
|
||||||
|
);
|
||||||
|
|
||||||
socket
|
if let Err(e) = socket
|
||||||
.inner
|
.inner
|
||||||
.set_nodelay(true)
|
.set_nodelay(true)
|
||||||
.context("failed to set socket option")?;
|
.context("failed to set socket option")
|
||||||
|
{
|
||||||
|
error!("could not set nodelay: {e:#}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
handle_client(
|
handle_client(
|
||||||
config,
|
config,
|
||||||
&mut ctx,
|
ctx,
|
||||||
&cancel_map,
|
cancel_map,
|
||||||
socket,
|
socket,
|
||||||
ClientMode::Tcp,
|
ClientMode::Tcp,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
)
|
)
|
||||||
.await
|
.await;
|
||||||
}
|
}
|
||||||
.instrument(info_span!(
|
.instrument(root_span),
|
||||||
"handle_client",
|
|
||||||
?session_id,
|
|
||||||
peer_addr = tracing::field::Empty
|
|
||||||
))
|
|
||||||
.unwrap_or_else(move |e| {
|
|
||||||
// Acknowledge that the task has finished with an error.
|
|
||||||
error!(?session_id, "per-client task finished with an error: {e:#}");
|
|
||||||
}),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,14 +150,14 @@ pub enum ClientMode {
|
|||||||
|
|
||||||
/// Abstracts the logic of handling TCP vs WS clients
|
/// Abstracts the logic of handling TCP vs WS clients
|
||||||
impl ClientMode {
|
impl ClientMode {
|
||||||
fn allow_cleartext(&self) -> bool {
|
pub fn allow_cleartext(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
ClientMode::Tcp => false,
|
ClientMode::Tcp => false,
|
||||||
ClientMode::Websockets { .. } => true,
|
ClientMode::Websockets { .. } => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||||
match self {
|
match self {
|
||||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||||
ClientMode::Websockets { .. } => false,
|
ClientMode::Websockets { .. } => false,
|
||||||
@@ -167,14 +180,14 @@ impl ClientMode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + 'static + Send>(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: RequestMonitoring,
|
||||||
cancel_map: &CancelMap,
|
cancel_map: Arc<CancelMap>,
|
||||||
stream: S,
|
stream: S,
|
||||||
mode: ClientMode,
|
mode: ClientMode,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
) -> anyhow::Result<()> {
|
) {
|
||||||
info!(
|
info!(
|
||||||
protocol = ctx.protocol,
|
protocol = ctx.protocol,
|
||||||
"handling interactive connection from client"
|
"handling interactive connection from client"
|
||||||
@@ -188,308 +201,23 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
|||||||
.with_label_values(&[proto])
|
.with_label_values(&[proto])
|
||||||
.guard();
|
.guard();
|
||||||
|
|
||||||
let tls = config.tls_config.as_ref();
|
let mut stage = Box::new(NeedsHandshake {
|
||||||
|
|
||||||
let pause = ctx.latency_timer.pause();
|
|
||||||
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
|
|
||||||
let (mut stream, params) = match do_handshake.await? {
|
|
||||||
Some(x) => x,
|
|
||||||
None => return Ok(()), // it's a cancellation request
|
|
||||||
};
|
|
||||||
drop(pause);
|
|
||||||
|
|
||||||
// Extract credentials which we're going to use for auth.
|
|
||||||
let user_info = {
|
|
||||||
let hostname = mode.hostname(stream.get_ref());
|
|
||||||
|
|
||||||
let common_names = tls.map(|tls| &tls.common_names);
|
|
||||||
let result = config
|
|
||||||
.auth_backend
|
|
||||||
.as_ref()
|
|
||||||
.map(|_| {
|
|
||||||
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
|
|
||||||
})
|
|
||||||
.transpose();
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(user_info) => user_info,
|
|
||||||
Err(e) => stream.throw_error(e).await?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.set_endpoint_id(user_info.get_endpoint());
|
|
||||||
|
|
||||||
let client = Client::new(
|
|
||||||
stream,
|
stream,
|
||||||
user_info,
|
config,
|
||||||
&params,
|
cancel_map,
|
||||||
mode.allow_self_signed_compute(config),
|
mode,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
);
|
ctx,
|
||||||
cancel_map
|
}) as DynStage;
|
||||||
.with_session(|session| {
|
|
||||||
client.connect_to_db(ctx, session, mode, &config.authentication_config)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Establish a (most probably, secure) connection with the client.
|
while let StageResult::Run(handle) = stage.run() {
|
||||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
stage = match handle.await.expect("tasks should not panic") {
|
||||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
Ok(s) => s,
|
||||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
|
||||||
stream: S,
|
|
||||||
mut tls: Option<&TlsConfig>,
|
|
||||||
cancel_map: &CancelMap,
|
|
||||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
|
|
||||||
// Client may try upgrading to each protocol only once
|
|
||||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
|
||||||
|
|
||||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
|
||||||
loop {
|
|
||||||
let msg = stream.read_startup_packet().await?;
|
|
||||||
info!("received {msg:?}");
|
|
||||||
|
|
||||||
use FeStartupPacket::*;
|
|
||||||
match msg {
|
|
||||||
SslRequest => match stream.get_ref() {
|
|
||||||
Stream::Raw { .. } if !tried_ssl => {
|
|
||||||
tried_ssl = true;
|
|
||||||
|
|
||||||
// We can't perform TLS handshake without a config
|
|
||||||
let enc = tls.is_some();
|
|
||||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
|
||||||
if let Some(tls) = tls.take() {
|
|
||||||
// Upgrade raw stream into a secure TLS-backed stream.
|
|
||||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
|
||||||
|
|
||||||
let (raw, read_buf) = stream.into_inner();
|
|
||||||
// TODO: Normally, client doesn't send any data before
|
|
||||||
// server says TLS handshake is ok and read_buf is empty.
|
|
||||||
// However, you could imagine pipelining of postgres
|
|
||||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
|
||||||
// pipelining in our node js driver. We should probably
|
|
||||||
// support that by chaining read_buf with the stream.
|
|
||||||
if !read_buf.is_empty() {
|
|
||||||
bail!("data is sent before server replied with EncryptionResponse");
|
|
||||||
}
|
|
||||||
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
|
||||||
|
|
||||||
let (_, tls_server_end_point) = tls
|
|
||||||
.cert_resolver
|
|
||||||
.resolve(tls_stream.get_ref().1.server_name())
|
|
||||||
.context("missing certificate")?;
|
|
||||||
|
|
||||||
stream = PqStream::new(Stream::Tls {
|
|
||||||
tls: Box::new(tls_stream),
|
|
||||||
tls_server_end_point,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => bail!(ERR_PROTO_VIOLATION),
|
|
||||||
},
|
|
||||||
GssEncRequest => match stream.get_ref() {
|
|
||||||
Stream::Raw { .. } if !tried_gss => {
|
|
||||||
tried_gss = true;
|
|
||||||
|
|
||||||
// Currently, we don't support GSSAPI
|
|
||||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
|
||||||
}
|
|
||||||
_ => bail!(ERR_PROTO_VIOLATION),
|
|
||||||
},
|
|
||||||
StartupMessage { params, .. } => {
|
|
||||||
// Check that the config has been consumed during upgrade
|
|
||||||
// OR we didn't provide it at all (for dev purposes).
|
|
||||||
if tls.is_some() {
|
|
||||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
info!(session_type = "normal", "successful handshake");
|
|
||||||
break Ok(Some((stream, params)));
|
|
||||||
}
|
|
||||||
CancelRequest(cancel_key_data) => {
|
|
||||||
cancel_map.cancel_session(cancel_key_data).await?;
|
|
||||||
|
|
||||||
info!(session_type = "cancellation", "successful handshake");
|
|
||||||
break Ok(None);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
async fn prepare_client_connection(
|
|
||||||
node: &compute::PostgresConnection,
|
|
||||||
session: cancellation::Session<'_>,
|
|
||||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
// Register compute's query cancellation token and produce a new, unique one.
|
|
||||||
// The new token (cancel_key_data) will be sent to the client.
|
|
||||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
|
||||||
|
|
||||||
// Forward all postgres connection params to the client.
|
|
||||||
// Right now the implementation is very hacky and inefficent (ideally,
|
|
||||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
|
||||||
for (name, value) in &node.params {
|
|
||||||
// TODO: Theoretically, this could result in a big pile of params...
|
|
||||||
stream.write_message_noflush(&Be::ParameterStatus {
|
|
||||||
name: name.as_bytes(),
|
|
||||||
value: value.as_bytes(),
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
stream
|
|
||||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
|
||||||
.write_message(&Be::ReadyForQuery)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Forward bytes in both directions (client <-> compute).
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
pub async fn proxy_pass(
|
|
||||||
ctx: &mut RequestMonitoring,
|
|
||||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
|
||||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
|
||||||
aux: MetricsAuxInfo,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
ctx.set_success();
|
|
||||||
ctx.log();
|
|
||||||
|
|
||||||
let usage = USAGE_METRICS.register(Ids {
|
|
||||||
endpoint_id: aux.endpoint_id.clone(),
|
|
||||||
branch_id: aux.branch_id.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
|
|
||||||
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
|
|
||||||
let mut client = MeasuredStream::new(
|
|
||||||
client,
|
|
||||||
|_| {},
|
|
||||||
|cnt| {
|
|
||||||
// Number of bytes we sent to the client (outbound).
|
|
||||||
m_sent.inc_by(cnt as u64);
|
|
||||||
m_sent2.inc_by(cnt as u64);
|
|
||||||
usage.record_egress(cnt as u64);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
|
|
||||||
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
|
|
||||||
let mut compute = MeasuredStream::new(
|
|
||||||
compute,
|
|
||||||
|_| {},
|
|
||||||
|cnt| {
|
|
||||||
// Number of bytes the client sent to the compute node (inbound).
|
|
||||||
m_recv.inc_by(cnt as u64);
|
|
||||||
m_recv2.inc_by(cnt as u64);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// Starting from here we only proxy the client's traffic.
|
|
||||||
info!("performing the proxy pass...");
|
|
||||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Thin connection context.
|
|
||||||
struct Client<'a, S> {
|
|
||||||
/// The underlying libpq protocol stream.
|
|
||||||
stream: PqStream<Stream<S>>,
|
|
||||||
/// Client credentials that we care about.
|
|
||||||
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
|
|
||||||
/// KV-dictionary with PostgreSQL connection params.
|
|
||||||
params: &'a StartupMessageParams,
|
|
||||||
/// Allow self-signed certificates (for testing).
|
|
||||||
allow_self_signed_compute: bool,
|
|
||||||
/// Rate limiter for endpoints
|
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, S> Client<'a, S> {
|
|
||||||
/// Construct a new connection context.
|
|
||||||
fn new(
|
|
||||||
stream: PqStream<Stream<S>>,
|
|
||||||
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
|
|
||||||
params: &'a StartupMessageParams,
|
|
||||||
allow_self_signed_compute: bool,
|
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
stream,
|
|
||||||
user_info,
|
|
||||||
params,
|
|
||||||
allow_self_signed_compute,
|
|
||||||
endpoint_rate_limiter,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
|
||||||
/// Let the client authenticate and connect to the designated compute node.
|
|
||||||
// Instrumentation logs endpoint name everywhere. Doesn't work for link
|
|
||||||
// auth; strictly speaking we don't know endpoint name in its case.
|
|
||||||
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
|
|
||||||
async fn connect_to_db(
|
|
||||||
self,
|
|
||||||
ctx: &mut RequestMonitoring,
|
|
||||||
session: cancellation::Session<'_>,
|
|
||||||
mode: ClientMode,
|
|
||||||
config: &'static AuthenticationConfig,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let Self {
|
|
||||||
mut stream,
|
|
||||||
user_info,
|
|
||||||
params,
|
|
||||||
allow_self_signed_compute,
|
|
||||||
endpoint_rate_limiter,
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
// check rate limit
|
|
||||||
if let Some(ep) = user_info.get_endpoint() {
|
|
||||||
if !endpoint_rate_limiter.check(ep) {
|
|
||||||
return stream
|
|
||||||
.throw_error(auth::AuthError::too_many_connections())
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let user = user_info.get_user().to_owned();
|
|
||||||
let auth_result = match user_info
|
|
||||||
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(auth_result) => auth_result,
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let db = params.get("database");
|
e.finish().await;
|
||||||
let app = params.get("application_name");
|
break;
|
||||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
|
||||||
|
|
||||||
return stream.throw_error(e).instrument(params_span).await;
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
let (mut node_info, user_info) = auth_result;
|
|
||||||
|
|
||||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
|
||||||
|
|
||||||
let aux = node_info.aux.clone();
|
|
||||||
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
|
|
||||||
.or_else(|e| stream.throw_error(e))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
prepare_client_connection(&node, session, &mut stream).await?;
|
|
||||||
// Before proxy passing, forward to compute whatever data is left in the
|
|
||||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
|
||||||
// driver in pipeline mode sends startup, password and first query
|
|
||||||
// immediately after opening the connection.
|
|
||||||
let (stream, read_buf) = stream.into_inner();
|
|
||||||
node.stream.write_all(&read_buf).await?;
|
|
||||||
proxy_pass(ctx, stream, node.stream, aux).await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,120 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
auth,
|
||||||
|
cancellation::{self, Session},
|
||||||
compute::{self, PostgresConnection},
|
compute::{self, PostgresConnection},
|
||||||
console::{self, errors::WakeComputeError, Api},
|
console::{self, errors::WakeComputeError, Api},
|
||||||
context::RequestMonitoring,
|
context::RequestMonitoring,
|
||||||
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
|
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
|
||||||
proxy::retry::{retry_after, ShouldRetry},
|
state_machine::{DynStage, ResultExt, Stage, StageError},
|
||||||
|
stream::{PqStream, Stream},
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use pq_proto::StartupMessageParams;
|
use pq_proto::StartupMessageParams;
|
||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
use tokio::time;
|
use tokio::{
|
||||||
|
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
|
||||||
|
time,
|
||||||
|
};
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use pq_proto::BeMessage as Be;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
pass::ProxyPass,
|
||||||
|
retry::{retry_after, ShouldRetry},
|
||||||
|
};
|
||||||
|
|
||||||
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||||
|
|
||||||
|
pub struct NeedsComputeConnection<S> {
|
||||||
|
pub stream: PqStream<Stream<S>>,
|
||||||
|
pub user_info: auth::BackendType<'static, auth::backend::ComputeUserInfo>,
|
||||||
|
pub mechanism: TcpMechanism,
|
||||||
|
pub node_info: console::CachedNodeInfo,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
pub ctx: RequestMonitoring,
|
||||||
|
pub cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Stage for NeedsComputeConnection<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("connect_to_compute")
|
||||||
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
stream,
|
||||||
|
user_info,
|
||||||
|
mechanism,
|
||||||
|
node_info,
|
||||||
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
let aux = node_info.aux.clone();
|
||||||
|
let (mut node, mut stream) =
|
||||||
|
connect_to_compute(&mut ctx, &mechanism, node_info, &user_info)
|
||||||
|
.await
|
||||||
|
.send_error_to_user(&mut ctx, stream)?;
|
||||||
|
|
||||||
|
prepare_client_connection(&node, &cancel_session, &mut stream)
|
||||||
|
.await
|
||||||
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
|
||||||
|
|
||||||
|
// Before proxy passing, forward to compute whatever data is left in the
|
||||||
|
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||||
|
// driver in pipeline mode sends startup, password and first query
|
||||||
|
// immediately after opening the connection.
|
||||||
|
let (stream, read_buf) = stream.into_inner();
|
||||||
|
|
||||||
|
node.stream
|
||||||
|
.write_all(&read_buf)
|
||||||
|
.await
|
||||||
|
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
|
||||||
|
|
||||||
|
Ok(Box::new(ProxyPass {
|
||||||
|
client: stream,
|
||||||
|
compute: node.stream,
|
||||||
|
aux,
|
||||||
|
cancel_session,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
|
async fn prepare_client_connection(
|
||||||
|
node: &compute::PostgresConnection,
|
||||||
|
session: &cancellation::Session,
|
||||||
|
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
// Register compute's query cancellation token and produce a new, unique one.
|
||||||
|
// The new token (cancel_key_data) will be sent to the client.
|
||||||
|
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||||
|
|
||||||
|
// Forward all postgres connection params to the client.
|
||||||
|
// Right now the implementation is very hacky and inefficent (ideally,
|
||||||
|
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||||
|
for (name, value) in &node.params {
|
||||||
|
// TODO: Theoretically, this could result in a big pile of params...
|
||||||
|
stream.write_message_noflush(&Be::ParameterStatus {
|
||||||
|
name: name.as_bytes(),
|
||||||
|
value: value.as_bytes(),
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
stream
|
||||||
|
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||||
|
.write_message(&Be::ReadyForQuery)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// If we couldn't connect, a cached connection info might be to blame
|
/// If we couldn't connect, a cached connection info might be to blame
|
||||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||||
@@ -63,13 +163,13 @@ pub trait ConnectMechanism {
|
|||||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TcpMechanism<'a> {
|
pub struct TcpMechanism {
|
||||||
/// KV-dictionary with PostgreSQL connection params.
|
/// KV-dictionary with PostgreSQL connection params.
|
||||||
pub params: &'a StartupMessageParams,
|
pub params: StartupMessageParams,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ConnectMechanism for TcpMechanism<'_> {
|
impl ConnectMechanism for TcpMechanism {
|
||||||
type Connection = PostgresConnection;
|
type Connection = PostgresConnection;
|
||||||
type ConnectError = compute::ConnectionError;
|
type ConnectError = compute::ConnectionError;
|
||||||
type Error = compute::ConnectionError;
|
type Error = compute::ConnectionError;
|
||||||
@@ -84,7 +184,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||||
config.set_startup_params(self.params);
|
config.set_startup_params(&self.params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
203
proxy/src/proxy/handshake.rs
Normal file
203
proxy/src/proxy/handshake.rs
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
use crate::{
|
||||||
|
auth::{self, backend::NeedsAuthentication},
|
||||||
|
cancellation::CancelMap,
|
||||||
|
config::{ProxyConfig, TlsConfig},
|
||||||
|
context::RequestMonitoring,
|
||||||
|
error::ReportableError,
|
||||||
|
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
|
||||||
|
rate_limiter::EndpointRateLimiter,
|
||||||
|
state_machine::{DynStage, Finished, ResultExt, Stage, StageError},
|
||||||
|
stream::{PqStream, Stream, StreamUpgradeError},
|
||||||
|
};
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||||
|
use std::{io, sync::Arc};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
use super::ClientMode;
|
||||||
|
|
||||||
|
pub struct NeedsHandshake<S> {
|
||||||
|
pub stream: S,
|
||||||
|
pub config: &'static ProxyConfig,
|
||||||
|
pub cancel_map: Arc<CancelMap>,
|
||||||
|
pub mode: ClientMode,
|
||||||
|
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
pub ctx: RequestMonitoring,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsHandshake<S> {
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("handshake")
|
||||||
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
stream,
|
||||||
|
config,
|
||||||
|
cancel_map,
|
||||||
|
mode,
|
||||||
|
endpoint_rate_limiter,
|
||||||
|
mut ctx,
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
let tls = config.tls_config.as_ref();
|
||||||
|
|
||||||
|
let pause_timer = ctx.latency_timer.pause();
|
||||||
|
let handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map).await;
|
||||||
|
drop(pause_timer);
|
||||||
|
|
||||||
|
let (stream, params) = match handshake {
|
||||||
|
Err(err) => {
|
||||||
|
// TODO: proper handling
|
||||||
|
error!("could not complete handshake: {err:#}");
|
||||||
|
return Err(StageError::Done);
|
||||||
|
}
|
||||||
|
// cancellation
|
||||||
|
Ok(None) => return Ok(Box::new(Finished)),
|
||||||
|
Ok(Some(s)) => s,
|
||||||
|
};
|
||||||
|
|
||||||
|
let hostname = mode.hostname(stream.get_ref());
|
||||||
|
|
||||||
|
let common_names = tls.map(|tls| &tls.common_names);
|
||||||
|
let (creds, stream) = config
|
||||||
|
.auth_backend
|
||||||
|
.as_ref()
|
||||||
|
.map(|_| {
|
||||||
|
auth::ComputeUserInfoMaybeEndpoint::parse(&mut ctx, ¶ms, hostname, common_names)
|
||||||
|
})
|
||||||
|
.transpose()
|
||||||
|
.send_error_to_user(&mut ctx, stream)?;
|
||||||
|
|
||||||
|
ctx.set_endpoint_id(creds.get_endpoint());
|
||||||
|
|
||||||
|
Ok(Box::new(NeedsAuthentication {
|
||||||
|
stream,
|
||||||
|
creds,
|
||||||
|
params,
|
||||||
|
endpoint_rate_limiter,
|
||||||
|
mode,
|
||||||
|
config,
|
||||||
|
|
||||||
|
ctx,
|
||||||
|
cancel_session: cancel_map.get_session(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum HandshakeError {
|
||||||
|
#[error("client disconnected: {0}")]
|
||||||
|
ClientIO(#[from] io::Error),
|
||||||
|
#[error("protocol violation: {0}")]
|
||||||
|
ProtocolError(#[from] anyhow::Error),
|
||||||
|
#[error("could not initiate tls connection: {0}")]
|
||||||
|
TLSError(#[from] StreamUpgradeError),
|
||||||
|
#[error("could not cancel connection: {0}")]
|
||||||
|
Cancel(anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReportableError for HandshakeError {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
HandshakeError::ClientIO(_) => crate::error::ErrorKind::Disconnect,
|
||||||
|
HandshakeError::ProtocolError(_) => crate::error::ErrorKind::User,
|
||||||
|
HandshakeError::TLSError(_) => crate::error::ErrorKind::User,
|
||||||
|
HandshakeError::Cancel(_) => crate::error::ErrorKind::Compute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SuccessfulHandshake<S> = (PqStream<Stream<S>>, StartupMessageParams);
|
||||||
|
|
||||||
|
/// Establish a (most probably, secure) connection with the client.
|
||||||
|
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||||
|
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||||
|
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||||
|
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
|
stream: S,
|
||||||
|
mut tls: Option<&TlsConfig>,
|
||||||
|
cancel_map: &CancelMap,
|
||||||
|
) -> Result<Option<SuccessfulHandshake<S>>, HandshakeError> {
|
||||||
|
// Client may try upgrading to each protocol only once
|
||||||
|
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||||
|
|
||||||
|
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||||
|
loop {
|
||||||
|
let msg = stream.read_startup_packet().await?;
|
||||||
|
info!("received {msg:?}");
|
||||||
|
|
||||||
|
use FeStartupPacket::*;
|
||||||
|
match msg {
|
||||||
|
SslRequest => match stream.get_ref() {
|
||||||
|
Stream::Raw { .. } if !tried_ssl => {
|
||||||
|
tried_ssl = true;
|
||||||
|
|
||||||
|
// We can't perform TLS handshake without a config
|
||||||
|
let enc = tls.is_some();
|
||||||
|
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||||
|
if let Some(tls) = tls.take() {
|
||||||
|
// Upgrade raw stream into a secure TLS-backed stream.
|
||||||
|
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||||
|
|
||||||
|
let (raw, read_buf) = stream.into_inner();
|
||||||
|
// TODO: Normally, client doesn't send any data before
|
||||||
|
// server says TLS handshake is ok and read_buf is empy.
|
||||||
|
// However, you could imagine pipelining of postgres
|
||||||
|
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||||
|
// pipelining in our node js driver. We should probably
|
||||||
|
// support that by chaining read_buf with the stream.
|
||||||
|
if !read_buf.is_empty() {
|
||||||
|
return Err(HandshakeError::ProtocolError(anyhow!(
|
||||||
|
"data is sent before server replied with EncryptionResponse"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
|
||||||
|
|
||||||
|
let (_, tls_server_end_point) = tls
|
||||||
|
.cert_resolver
|
||||||
|
.resolve(tls_stream.get_ref().1.server_name())
|
||||||
|
.context("missing certificate")?;
|
||||||
|
|
||||||
|
stream = PqStream::new(Stream::Tls {
|
||||||
|
tls: Box::new(tls_stream),
|
||||||
|
tls_server_end_point,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
|
||||||
|
},
|
||||||
|
GssEncRequest => match stream.get_ref() {
|
||||||
|
Stream::Raw { .. } if !tried_gss => {
|
||||||
|
tried_gss = true;
|
||||||
|
|
||||||
|
// Currently, we don't support GSSAPI
|
||||||
|
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||||
|
}
|
||||||
|
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
|
||||||
|
},
|
||||||
|
StartupMessage { params, .. } => {
|
||||||
|
// Check that the config has been consumed during upgrade
|
||||||
|
// OR we didn't provide it at all (for dev purposes).
|
||||||
|
if tls.is_some() {
|
||||||
|
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(session_type = "normal", "successful handshake");
|
||||||
|
break Ok(Some((stream, params)));
|
||||||
|
}
|
||||||
|
CancelRequest(cancel_key_data) => {
|
||||||
|
cancel_map
|
||||||
|
.cancel_session(cancel_key_data)
|
||||||
|
.await
|
||||||
|
.map_err(HandshakeError::Cancel)?;
|
||||||
|
|
||||||
|
info!(session_type = "cancellation", "successful handshake");
|
||||||
|
break Ok(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
82
proxy/src/proxy/pass.rs
Normal file
82
proxy/src/proxy/pass.rs
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
use crate::{
|
||||||
|
cancellation::Session,
|
||||||
|
console::messages::MetricsAuxInfo,
|
||||||
|
metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
|
||||||
|
state_machine::{DynStage, Finished, Stage, StageError},
|
||||||
|
stream::Stream,
|
||||||
|
usage_metrics::{Ids, USAGE_METRICS},
|
||||||
|
};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tracing::{error, info};
|
||||||
|
use utils::measured_stream::MeasuredStream;
|
||||||
|
|
||||||
|
pub struct ProxyPass<Client, Compute> {
|
||||||
|
pub client: Stream<Client>,
|
||||||
|
pub compute: Compute,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
pub aux: MetricsAuxInfo,
|
||||||
|
pub cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Client, Compute> Stage for ProxyPass<Client, Compute>
|
||||||
|
where
|
||||||
|
Client: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
Compute: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("proxy_pass")
|
||||||
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
if let Err(e) = proxy_pass(self.client, self.compute, self.aux).await {
|
||||||
|
error!("{e:#}")
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(self.cancel_session);
|
||||||
|
|
||||||
|
Ok(Box::new(Finished))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward bytes in both directions (client <-> compute).
|
||||||
|
pub async fn proxy_pass(
|
||||||
|
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||||
|
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||||
|
aux: MetricsAuxInfo,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let usage = USAGE_METRICS.register(Ids {
|
||||||
|
endpoint_id: aux.endpoint_id.clone(),
|
||||||
|
branch_id: aux.branch_id.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
|
||||||
|
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
|
||||||
|
let mut client = MeasuredStream::new(
|
||||||
|
client,
|
||||||
|
|_| {},
|
||||||
|
|cnt| {
|
||||||
|
// Number of bytes we sent to the client (outbound).
|
||||||
|
m_sent.inc_by(cnt as u64);
|
||||||
|
m_sent2.inc_by(cnt as u64);
|
||||||
|
usage.record_egress(cnt as u64);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
|
||||||
|
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
|
||||||
|
let mut compute = MeasuredStream::new(
|
||||||
|
compute,
|
||||||
|
|_| {},
|
||||||
|
|cnt| {
|
||||||
|
// Number of bytes the client sent to the compute node (inbound).
|
||||||
|
m_recv.inc_by(cnt as u64);
|
||||||
|
m_recv2.inc_by(cnt as u64);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Starting from here we only proxy the client's traffic.
|
||||||
|
info!("performing the proxy pass...");
|
||||||
|
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -3,14 +3,19 @@
|
|||||||
mod mitm;
|
mod mitm;
|
||||||
|
|
||||||
use super::connect_compute::ConnectMechanism;
|
use super::connect_compute::ConnectMechanism;
|
||||||
|
use super::handshake::handshake;
|
||||||
use super::retry::ShouldRetry;
|
use super::retry::ShouldRetry;
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||||
use crate::config::CertResolver;
|
use crate::config::CertResolver;
|
||||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||||
|
use crate::proxy::connect_compute::connect_to_compute;
|
||||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||||
use crate::{auth, http, sasl, scram};
|
use crate::stream::PqStream;
|
||||||
|
use crate::{auth, compute, http, sasl, scram};
|
||||||
|
use anyhow::bail;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use pq_proto::BeMessage as Be;
|
||||||
use rstest::rstest;
|
use rstest::rstest;
|
||||||
use smol_str::SmolStr;
|
use smol_str::SmolStr;
|
||||||
use tokio_postgres::config::SslMode;
|
use tokio_postgres::config::SslMode;
|
||||||
@@ -202,7 +207,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
|||||||
.err() // -> Option<E>
|
.err() // -> Option<E>
|
||||||
.context("server shouldn't accept client")?;
|
.context("server shouldn't accept client")?;
|
||||||
|
|
||||||
assert!(client_err.to_string().contains(&server_err.to_string()));
|
assert!(server_err.to_string().contains(ERR_INSECURE_CONNECTION));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ use super::*;
|
|||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
use postgres_protocol::message::frontend;
|
use postgres_protocol::message::frontend;
|
||||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||||
use tokio_postgres::config::SslMode;
|
use tokio_postgres::config::SslMode;
|
||||||
use tokio_postgres::tls::TlsConnect;
|
use tokio_postgres::tls::TlsConnect;
|
||||||
use tokio_util::codec::{Decoder, Encoder};
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|||||||
89
proxy/src/proxy/wake_compute.rs
Normal file
89
proxy/src/proxy/wake_compute.rs
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
use std::{borrow::Cow, ops::ControlFlow};
|
||||||
|
|
||||||
|
use pq_proto::StartupMessageParams;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tracing::{error, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{
|
||||||
|
backend::{ComputeCredentialKeys, ComputeCredentials},
|
||||||
|
BackendType,
|
||||||
|
},
|
||||||
|
cancellation::Session,
|
||||||
|
console::{provider::ConsoleBackend, Api},
|
||||||
|
context::RequestMonitoring,
|
||||||
|
state_machine::{user_facing_error, DynStage, Stage, StageError},
|
||||||
|
stream::{PqStream, Stream},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
connect_compute::{handle_try_wake, NeedsComputeConnection, TcpMechanism},
|
||||||
|
retry::retry_after,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct NeedsWakeCompute<S> {
|
||||||
|
pub stream: PqStream<Stream<S>>,
|
||||||
|
pub api: Cow<'static, ConsoleBackend>,
|
||||||
|
pub params: StartupMessageParams,
|
||||||
|
pub allow_self_signed_compute: bool,
|
||||||
|
pub creds: ComputeCredentials<ComputeCredentialKeys>,
|
||||||
|
|
||||||
|
// monitoring
|
||||||
|
pub ctx: RequestMonitoring,
|
||||||
|
pub cancel_session: Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsWakeCompute<S> {
|
||||||
|
fn span(&self) -> tracing::Span {
|
||||||
|
tracing::info_span!("wake_compute")
|
||||||
|
}
|
||||||
|
async fn run(self) -> Result<DynStage, StageError> {
|
||||||
|
let Self {
|
||||||
|
stream,
|
||||||
|
api,
|
||||||
|
params,
|
||||||
|
allow_self_signed_compute,
|
||||||
|
creds,
|
||||||
|
mut ctx,
|
||||||
|
cancel_session,
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
let mut num_retries = 0;
|
||||||
|
let mut node_info = loop {
|
||||||
|
let wake_res = api.wake_compute(&mut ctx, &creds.info).await;
|
||||||
|
match handle_try_wake(wake_res, num_retries) {
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||||
|
return Err(user_facing_error(e, &mut ctx, stream));
|
||||||
|
}
|
||||||
|
Ok(ControlFlow::Continue(e)) => {
|
||||||
|
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
|
||||||
|
}
|
||||||
|
Ok(ControlFlow::Break(n)) => break n,
|
||||||
|
}
|
||||||
|
|
||||||
|
let wait_duration = retry_after(num_retries);
|
||||||
|
num_retries += 1;
|
||||||
|
tokio::time::sleep(wait_duration).await;
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.set_project(node_info.aux.clone());
|
||||||
|
|
||||||
|
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||||
|
|
||||||
|
match creds.keys {
|
||||||
|
#[cfg(feature = "testing")]
|
||||||
|
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
|
||||||
|
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Box::new(NeedsComputeConnection {
|
||||||
|
stream,
|
||||||
|
user_info: BackendType::Console(api, creds.info),
|
||||||
|
mechanism: TcpMechanism { params },
|
||||||
|
node_info,
|
||||||
|
ctx,
|
||||||
|
cancel_session,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ mod channel_binding;
|
|||||||
mod messages;
|
mod messages;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
|
||||||
use crate::error::UserFacingError;
|
use crate::error::{ReportableError, UserFacingError};
|
||||||
use std::io;
|
use std::io;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
@@ -37,6 +37,25 @@ pub enum Error {
|
|||||||
Io(#[from] io::Error),
|
Io(#[from] io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for Error {
|
||||||
|
fn get_error_type(&self) -> crate::error::ErrorKind {
|
||||||
|
match self {
|
||||||
|
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||||
|
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||||
|
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||||
|
Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||||
|
Error::Io(io) => match io.kind() {
|
||||||
|
// tokio postgres uses these for various scram failures
|
||||||
|
io::ErrorKind::InvalidInput
|
||||||
|
| io::ErrorKind::UnexpectedEof
|
||||||
|
| io::ErrorKind::Other => crate::error::ErrorKind::User,
|
||||||
|
// all other IO errors are likely disconnects.
|
||||||
|
_ => crate::error::ErrorKind::Disconnect,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserFacingError for Error {
|
impl UserFacingError for Error {
|
||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
use Error::*;
|
use Error::*;
|
||||||
|
|||||||
@@ -124,6 +124,12 @@ pub async fn task_main(
|
|||||||
let cancel_map = Arc::new(CancelMap::default());
|
let cancel_map = Arc::new(CancelMap::default());
|
||||||
let session_id = uuid::Uuid::new_v4();
|
let session_id = uuid::Uuid::new_v4();
|
||||||
|
|
||||||
|
let root_span = info_span!(
|
||||||
|
"serverless",
|
||||||
|
session = %session_id,
|
||||||
|
%peer_addr,
|
||||||
|
);
|
||||||
|
|
||||||
request_handler(
|
request_handler(
|
||||||
req,
|
req,
|
||||||
config,
|
config,
|
||||||
@@ -135,12 +141,9 @@ pub async fn task_main(
|
|||||||
sni_name,
|
sni_name,
|
||||||
peer_addr.ip(),
|
peer_addr.ip(),
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
|
root_span.clone(),
|
||||||
)
|
)
|
||||||
.instrument(info_span!(
|
.instrument(root_span)
|
||||||
"serverless",
|
|
||||||
session = %session_id,
|
|
||||||
%peer_addr,
|
|
||||||
))
|
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -205,6 +208,7 @@ async fn request_handler(
|
|||||||
sni_hostname: Option<String>,
|
sni_hostname: Option<String>,
|
||||||
peer_addr: IpAddr,
|
peer_addr: IpAddr,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
|
root_span: tracing::Span,
|
||||||
) -> Result<Response<Body>, ApiError> {
|
) -> Result<Response<Body>, ApiError> {
|
||||||
let host = request
|
let host = request
|
||||||
.headers()
|
.headers()
|
||||||
@@ -215,27 +219,33 @@ async fn request_handler(
|
|||||||
|
|
||||||
// Check if the request is a websocket upgrade request.
|
// Check if the request is a websocket upgrade request.
|
||||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||||
info!(session_id = ?session_id, "performing websocket upgrade");
|
info!("performing websocket upgrade");
|
||||||
|
|
||||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||||
|
|
||||||
ws_connections.spawn(
|
ws_connections.spawn(
|
||||||
async move {
|
async move {
|
||||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
|
let ctx =
|
||||||
|
RequestMonitoring::new(session_id, peer_addr, "ws", &config.region, root_span);
|
||||||
|
|
||||||
if let Err(e) = websocket::serve_websocket(
|
let websocket = match websocket.await {
|
||||||
|
Err(e) => {
|
||||||
|
error!("error in websocket connection: {e:#}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(ws) => ws,
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket::serve_websocket(
|
||||||
config,
|
config,
|
||||||
&mut ctx,
|
ctx,
|
||||||
websocket,
|
websocket,
|
||||||
&cancel_map,
|
cancel_map,
|
||||||
host,
|
host,
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
|
||||||
error!(session_id = ?session_id, "error in websocket connection: {e:#}");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
.in_current_span(),
|
.in_current_span(),
|
||||||
);
|
);
|
||||||
@@ -243,7 +253,8 @@ async fn request_handler(
|
|||||||
// Return the response so the spawned future can continue.
|
// Return the response so the spawned future can continue.
|
||||||
Ok(response)
|
Ok(response)
|
||||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
let mut ctx =
|
||||||
|
RequestMonitoring::new(session_id, peer_addr, "http", &config.region, root_span);
|
||||||
|
|
||||||
sql_over_http::handle(
|
sql_over_http::handle(
|
||||||
tls,
|
tls,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use crate::{
|
|||||||
use bytes::{Buf, Bytes};
|
use bytes::{Buf, Bytes};
|
||||||
use futures::{Sink, Stream};
|
use futures::{Sink, Stream};
|
||||||
use hyper::upgrade::Upgraded;
|
use hyper::upgrade::Upgraded;
|
||||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
@@ -131,13 +131,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
|||||||
|
|
||||||
pub async fn serve_websocket(
|
pub async fn serve_websocket(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
ctx: &mut RequestMonitoring,
|
ctx: RequestMonitoring,
|
||||||
websocket: HyperWebsocket,
|
websocket: WebSocketStream<Upgraded>,
|
||||||
cancel_map: &CancelMap,
|
cancel_map: Arc<CancelMap>,
|
||||||
hostname: Option<String>,
|
hostname: Option<String>,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
) -> anyhow::Result<()> {
|
) {
|
||||||
let websocket = websocket.await?;
|
|
||||||
handle_client(
|
handle_client(
|
||||||
config,
|
config,
|
||||||
ctx,
|
ctx,
|
||||||
@@ -146,8 +145,7 @@ pub async fn serve_websocket(
|
|||||||
ClientMode::Websockets { hostname },
|
ClientMode::Websockets { hostname },
|
||||||
endpoint_rate_limiter,
|
endpoint_rate_limiter,
|
||||||
)
|
)
|
||||||
.await?;
|
.await
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
149
proxy/src/state_machine.rs
Normal file
149
proxy/src/state_machine.rs
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
use futures::Future;
|
||||||
|
use pq_proto::{framed::Framed, BeMessage};
|
||||||
|
use tokio::{io::AsyncWrite, task::JoinHandle};
|
||||||
|
use tracing::{info, warn, Instrument};
|
||||||
|
|
||||||
|
/// Marker trait with no methods, generic over `T`.
///
/// NOTE(review): presumably meant for the `impl Trait + Captures<..>` idiom;
/// nothing in this file uses it, so confirm against callers.
pub trait Captures<T> {}

/// Every type `U` implements [`Captures<T>`] for every `T`.
impl<T, U> Captures<T> for U {}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub enum StageError {
|
||||||
|
Flush(Framed<Box<dyn AsyncWrite + Unpin + Send + 'static>>),
|
||||||
|
Done,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StageError {
|
||||||
|
pub async fn finish(self) {
|
||||||
|
match self {
|
||||||
|
StageError::Flush(mut f) => {
|
||||||
|
// ignore result. we can't do anything about it.
|
||||||
|
// this is already the error case anyway...
|
||||||
|
if let Err(e) = f.flush().await {
|
||||||
|
warn!("could not send message to user: {e:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StageError::Done => {}
|
||||||
|
}
|
||||||
|
info!("task finished");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A boxed, type-erased stage: any value implementing [`StageSpawn`].
/// This is what [`Stage::run`] yields as the next stage.
pub type DynStage = Box<dyn StageSpawn>;
|
||||||
|
|
||||||
|
/// Stage represents a single stage in a state machine.
|
||||||
|
pub trait Stage: 'static + Send {
|
||||||
|
/// The span this stage should be run inside.
|
||||||
|
fn span(&self) -> tracing::Span;
|
||||||
|
/// Run the current stage, returning a new [`DynStage`], or an error
|
||||||
|
///
|
||||||
|
/// Can be implemented as `async fn run(self) -> Result<DynStage, StageError>`
|
||||||
|
fn run(self) -> impl 'static + Send + Future<Output = Result<DynStage, StageError>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum StageResult {
|
||||||
|
Finished,
|
||||||
|
Run(JoinHandle<Result<DynStage, StageError>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait StageSpawn: 'static + Send {
|
||||||
|
fn run(self: Box<Self>) -> StageResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stage spawn is a helper trait for the state machine. It spawns the stages as a tokio task
|
||||||
|
impl<S: Stage> StageSpawn for S {
|
||||||
|
fn run(self: Box<Self>) -> StageResult {
|
||||||
|
let span = self.span();
|
||||||
|
StageResult::Run(tokio::spawn(S::run(*self).instrument(span)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Finished;
|
||||||
|
|
||||||
|
impl StageSpawn for Finished {
|
||||||
|
fn run(self: Box<Self>) -> StageResult {
|
||||||
|
StageResult::Finished
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
context::RequestMonitoring,
|
||||||
|
error::{ErrorKind, UserFacingError},
|
||||||
|
stream::PqStream,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub trait ResultExt<T, E> {
|
||||||
|
fn send_error_to_user<S>(
|
||||||
|
self,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
stream: PqStream<S>,
|
||||||
|
) -> Result<(T, PqStream<S>), StageError>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
E: UserFacingError;
|
||||||
|
|
||||||
|
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
|
||||||
|
where
|
||||||
|
E: std::fmt::Display;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> ResultExt<T, E> for Result<T, E> {
|
||||||
|
fn send_error_to_user<S>(
|
||||||
|
self,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
stream: PqStream<S>,
|
||||||
|
) -> Result<(T, PqStream<S>), StageError>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
E: UserFacingError,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Ok(t) => Ok((t, stream)),
|
||||||
|
Err(e) => Err(user_facing_error(e, ctx, stream)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
|
||||||
|
where
|
||||||
|
E: std::fmt::Display,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Ok(t) => Ok(t),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
kind = kind.to_str(),
|
||||||
|
user_msg = "",
|
||||||
|
"task finished with error: {e}"
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.error(kind);
|
||||||
|
ctx.log();
|
||||||
|
Err(StageError::Done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user_facing_error<S, E>(
|
||||||
|
err: E,
|
||||||
|
ctx: &mut RequestMonitoring,
|
||||||
|
mut stream: PqStream<S>,
|
||||||
|
) -> StageError
|
||||||
|
where
|
||||||
|
S: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
E: UserFacingError,
|
||||||
|
{
|
||||||
|
let kind = err.get_error_type();
|
||||||
|
ctx.error(kind);
|
||||||
|
ctx.log();
|
||||||
|
|
||||||
|
let msg = err.to_string_client();
|
||||||
|
tracing::error!(
|
||||||
|
kind = kind.to_str(),
|
||||||
|
user_msg = msg,
|
||||||
|
"task finished with error: {err}"
|
||||||
|
);
|
||||||
|
if let Err(err) = stream.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)) {
|
||||||
|
warn!("could not process error message: {err:?}")
|
||||||
|
}
|
||||||
|
StageError::Flush(stream.framed.map_stream_sync(|f| Box::new(f) as Box<_>))
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::config::TlsServerEndPoint;
|
use crate::config::TlsServerEndPoint;
|
||||||
use crate::error::UserFacingError;
|
use crate::error::ErrorKind;
|
||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
|
|
||||||
@@ -99,24 +99,17 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
|||||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||||
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
||||||
tracing::info!("forwarding error to user: {error}");
|
let kind = ErrorKind::User;
|
||||||
|
tracing::error!(
|
||||||
|
kind = kind.to_str(),
|
||||||
|
full_msg = error,
|
||||||
|
user_msg = error,
|
||||||
|
"task finished with error"
|
||||||
|
);
|
||||||
self.write_message(&BeMessage::ErrorResponse(error, None))
|
self.write_message(&BeMessage::ErrorResponse(error, None))
|
||||||
.await?;
|
.await?;
|
||||||
bail!(error)
|
bail!(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
|
||||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
|
||||||
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
|
|
||||||
where
|
|
||||||
E: UserFacingError + Into<anyhow::Error>,
|
|
||||||
{
|
|
||||||
let msg = error.to_string_client();
|
|
||||||
tracing::info!("forwarding error to user: {msg}");
|
|
||||||
self.write_message(&BeMessage::ErrorResponse(&msg, None))
|
|
||||||
.await?;
|
|
||||||
bail!(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrapper for upgrading raw streams into secure streams.
|
/// Wrapper for upgrading raw streams into secure streams.
|
||||||
|
|||||||
Reference in New Issue
Block a user