Compare commits

...

12 Commits

Author SHA1 Message Date
Conrad Ludgate
77e431d80a abstract out listen-accept loops 2024-10-07 16:52:37 +01:00
Conrad Ludgate
784571eac7 remove redundant generic 2024-10-07 16:52:37 +01:00
Conrad Ludgate
ffd0d875cf minor changes 2024-10-07 16:52:37 +01:00
Conrad Ludgate
dbf34a17ce rename to serverless backend 2024-10-07 16:52:37 +01:00
Conrad Ludgate
4001e24745 proxy: continue streamlining auth::Backend 2024-10-07 16:52:37 +01:00
Conrad Ludgate
d9a59c5a3f deduplicate some mode 2024-10-07 16:52:37 +01:00
Conrad Ludgate
c7afbe55c9 remove some duplication 2024-10-07 16:52:37 +01:00
Conrad Ludgate
d0930c9d1d remove console-redirect from mega-backend object 2024-10-07 16:52:37 +01:00
Conrad Ludgate
addfff61b5 create console_redirect_proxy solo-path 2024-10-07 16:52:37 +01:00
Conrad Ludgate
cb28721eee make well-defined console request backend 2024-10-07 16:52:37 +01:00
Conrad Ludgate
07076e88a9 shrink some uses of config 2024-10-07 16:52:37 +01:00
Conrad Ludgate
2feba8a3da remove auth backend from proxy config 2024-10-07 16:52:37 +01:00
16 changed files with 849 additions and 614 deletions

View File

@@ -1,18 +1,24 @@
use crate::{
auth, compute,
auth,
cache::Cached,
compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{self, provider::NodeInfo},
control_plane::{self, provider::NodeInfo, CachedNodeInfo},
error::{ReportableError, UserFacingError},
proxy::connect_compute::ComputeConnectBackend,
stream::PqStream,
waiters,
};
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
#[derive(Debug, Error)]
pub(crate) enum WebAuthError {
#[error(transparent)]
@@ -25,6 +31,11 @@ pub(crate) enum WebAuthError {
Io(#[from] std::io::Error),
}
#[derive(Debug)]
pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
}
impl UserFacingError for WebAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
@@ -57,7 +68,40 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub(super) async fn authenticate(
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
}
pub(crate) async fn authenticate(
&self,
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<ConsoleRedirectNodeInfo> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(ConsoleRedirectNodeInfo)
}
}
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,

View File

@@ -8,6 +8,7 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::WebAuthError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
@@ -19,9 +20,8 @@ use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::provider::{CachedRoleSecret, ControlPlaneBackend};
use crate::control_plane::{AuthSecret, NodeInfo};
use crate::control_plane::provider::ControlPlaneBackend;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
@@ -31,49 +31,23 @@ use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
control_plane::{
self,
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
control_plane::{self, provider::CachedNodeInfo, Api},
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T, D> {
/// The [crate::serverless] module can authenticate either using control-plane
/// to get authentication state, or by using JWKs stored in the filesystem.
pub enum ServerlessBackend<'a> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
/// Authentication via a web browser.
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
ControlPlane(&'a ControlPlaneBackend),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
Local(&'a LocalBackend),
}
#[cfg(test)]
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
#[cfg(test)]
pub(crate) trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
@@ -90,63 +64,20 @@ impl Clone for Box<dyn TestBackend> {
}
}
impl std::fmt::Display for Backend<'_, (), ()> {
impl std::fmt::Display for ControlPlaneBackend {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ControlPlane(api, ()) => match &**api {
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::ConsoleRedirect(url, ()) => fmt
.debug_tuple("ConsoleRedirect")
.field(&url.as_str())
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
.finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
}
}
impl<T, D> Backend<'_, T, D> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(MaybeOwned::Borrowed(c), x),
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
}
}
}
impl<'a, T, D> Backend<'a, T, D> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(c, x),
Self::Local(l) => Backend::Local(l),
}
}
}
impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T, D>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::ConsoleRedirect(c, x) => Ok(Backend::ConsoleRedirect(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
}
}
}
@@ -239,7 +170,6 @@ impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
@@ -263,7 +193,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
@@ -337,7 +267,6 @@ async fn auth_quirks(
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
@@ -413,133 +342,79 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::ConsoleRedirect(_, ()) => "web",
Self::Local(_) => "local",
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
impl ControlPlaneBackend {
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
self,
&self,
ctx: &RequestMonitoring,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<Backend<'a, ComputeCredentials, NodeInfo>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<ControlPlaneComputeBackend> {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Backend::ControlPlane(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::ConsoleRedirect(url, ()) => {
info!("performing web authentication");
let info = console_redirect::authenticate(ctx, config, &url, client).await?;
Backend::ConsoleRedirect(url, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
}
};
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
info!("user successfully authenticated");
Ok(res)
Ok(ControlPlaneComputeBackend {
api: self,
creds: credentials,
})
}
pub(crate) fn attach_to_credentials(
&self,
creds: ComputeCredentials,
) -> ControlPlaneComputeBackend {
ControlPlaneComputeBackend { api: self, creds }
}
}
impl Backend<'_, ComputeUserInfo, &()> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::ConsoleRedirect(_, ()) => Ok(Cached::new_uncached(None)),
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_allowed_ips_and_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_ips_and_secret(ctx, user_info).await
}
Self::ConsoleRedirect(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
pub struct ControlPlaneComputeBackend<'a> {
api: &'a ControlPlaneBackend,
creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> {
impl ComputeConnectBackend for ControlPlaneComputeBackend<'_> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::ConsoleRedirect(_, info) => Ok(Cached::new_uncached(info.clone())),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
self.api.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::ConsoleRedirect(_, _) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
&self.creds.keys
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> {
impl ComputeConnectBackend for LocalBackend {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
_ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::ConsoleRedirect(_, ()) => {
unreachable!("web auth flow doesn't support waking the compute")
}
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
Ok(Cached::new_uncached(self.node_info.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::ConsoleRedirect(_, ()) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
&ComputeCredentialKeys::None
}
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
pub use backend::ServerlessBackend;
mod credentials;
pub(crate) use credentials::{

View File

@@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
@@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
@@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> {
let task = serverless::task_main(
config,
auth::ServerlessBackend::Local(auth_backend),
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
@@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
@@ -286,6 +288,13 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> {
let auth_backend = LocalBackend::new(args.compute);
Ok(Box::leak(Box::new(auth_backend)))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;

View File

@@ -10,7 +10,7 @@ use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::remote_storage_from_toml;
@@ -21,6 +21,7 @@ use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::control_plane;
use proxy::control_plane::provider::ControlPlaneBackend;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -311,8 +312,12 @@ async fn main() -> anyhow::Result<()> {
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
info!("Authentication backend: {}", config.auth_backend);
match auth_backend {
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
};
info!("Using region: {}", args.aws_region);
let region_provider =
@@ -459,24 +464,41 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth::ServerlessBackend::ControlPlane(auth_backend),
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
}
}
}
client_tasks.spawn(proxy::context::parquet::worker(
@@ -506,40 +528,38 @@ async fn main() -> anyhow::Result<()> {
));
}
if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
@@ -610,73 +630,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}
let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::provider::ControlPlaneBackend::Management(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
}
#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
};
let config::ConcurrencyLockOptions {
shards,
limiter,
@@ -726,9 +679,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
let config = ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
@@ -741,13 +693,97 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
}));
};
let config = Box::leak(Box::new(config));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>> {
match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
AuthBackendType::Web => {
let url = args.uri.parse()?;
let backend = ConsoleRedirectBackend::new(url);
let config = Box::leak(Box::new(backend));
Ok(Either::Right(config))
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;

View File

@@ -1,8 +1,5 @@
use crate::{
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
auth::backend::{jwt::JwkCache, AuthRateLimiter},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -29,7 +26,6 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::Backend<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,

View File

@@ -0,0 +1,161 @@
use crate::auth::backend::ConsoleRedirectBackend;
use crate::config::ProxyConfig;
use crate::metrics::Protocol;
use crate::proxy::{prepare_client_connection, transition_connection, ClientRequestError};
use crate::{
cancellation::CancellationHandlerMain,
context::RequestMonitoring,
metrics::{Metrics, NumClientConnectionsGuard},
proxy::handshake::{handshake, HandshakeData},
};
use futures::TryFutureExt;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{info, Instrument};
use crate::proxy::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
super::connection_loop(
config,
listener,
cancellation_token,
Protocol::Tcp,
C {
config,
backend,
cancellation_handler,
},
)
.await
}
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
cancellation_handler: Arc<CancellationHandlerMain>,
}
impl super::ConnHandler for C {
async fn handle(
self,
session_id: uuid::Uuid,
peer_addr: IpAddr,
socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
let ctx = RequestMonitoring::new(session_id, peer_addr, Protocol::Tcp, &self.config.region);
let span = ctx.span();
let startup = Box::pin(
handle_client(
self.config,
self.backend,
&ctx,
self.cancellation_handler,
socket,
conn_gauge,
)
.instrument(span.clone()),
);
let res = startup.await;
transition_connection(ctx, res).await;
}
}
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
.await
.map(|()| None)?)
}
};
drop(pause);
ctx.set_db_options(params.clone());
let user_info = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await?;
}
};
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.allow_self_signed_compute,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
_req: request_gauge,
_conn: conn_gauge,
_cancel: session,
}))
}

View File

@@ -82,13 +82,16 @@
impl_trait_overcaptures,
)]
use std::convert::Infallible;
use std::{convert::Infallible, future::Future, net::IpAddr};
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use tokio::task::JoinError;
use protocol2::{get_client_conn_info, ChainRW};
use proxy::run_until_cancelled;
use tokio::{net::TcpStream, task::JoinError};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use tracing::{error, warn};
use uuid::Uuid;
extern crate hyper0 as hyper;
@@ -97,6 +100,7 @@ pub mod cache;
pub mod cancellation;
pub mod compute;
pub mod config;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;
pub mod error;
@@ -275,3 +279,81 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
pub(crate) trait ConnHandler: Clone + Send + 'static {
fn handle(
self,
session_id: Uuid,
peer_addr: IpAddr,
stream: ChainRW<TcpStream>,
conn_gauge: metrics::NumClientConnectionsGuard<'static>,
) -> impl Future<Output = ()> + Send;
}
/// Accept connections, parse the proxy-protocol v2 header and spawn a tracked connection task.
pub(crate) async fn connection_loop<C>(
config: &'static config::ProxyConfig,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
protocol: metrics::Protocol,
conn_handler: C,
) -> anyhow::Result<()>
where
C: ConnHandler,
{
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = metrics::Metrics::get()
.proxy
.client_connections
.guard(protocol);
let session_id = uuid::Uuid::new_v4();
let conn_handler = conn_handler.clone();
tracing::info!(protocol = protocol.as_str(), %session_id, "accepted new TCP connection");
connections.spawn(async move {
let (socket, peer_addr) = match get_client_conn_info(socket, config.proxy_protocol_v2).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Ok((socket, Some(addr))) => (socket, addr),
Ok((socket, None)) => (socket, peer_addr.ip()),
};
match socket.inner.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
return;
}
};
conn_handler.handle(
session_id,
peer_addr,
socket,
conn_gauge,
).await;
});
}
connections.close();
drop(listener);
// Drain connections
connections.wait().await;
Ok(())
}

View File

@@ -2,15 +2,18 @@
use std::{
io,
net::SocketAddr,
net::{IpAddr, SocketAddr},
pin::Pin,
task::{Context, Poll},
};
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use crate::config::ProxyProtocolV2;
pin_project! {
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
pub(crate) struct ChainRW<T> {
@@ -60,7 +63,23 @@ const HEADER: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
pub(crate) async fn get_client_conn_info<T: AsyncRead + Unpin>(
socket: T,
proxy_protocol_v2: ProxyProtocolV2,
) -> anyhow::Result<(ChainRW<T>, Option<IpAddr>)> {
match read_proxy_protocol(socket).await? {
(_socket, None) if proxy_protocol_v2 == ProxyProtocolV2::Required => {
bail!("missing required proxy protocol header");
}
(_socket, Some(_)) if proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
bail!("proxy protocol header not supported");
}
(socket, Some(addr)) => Ok((socket, Some(addr.ip()))),
(socket, None) => Ok((socket, None)),
}
}
async fn read_proxy_protocol<T: AsyncRead + Unpin>(
mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
let mut buf = BytesMut::with_capacity(128);

View File

@@ -10,16 +10,16 @@ pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use crate::config::ProxyProtocolV2;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::metrics::Protocol;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
cancellation::{self, CancellationHandlerMain},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
@@ -31,6 +31,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -61,6 +62,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -70,109 +72,91 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
super::connection_loop(
config,
listener,
cancellation_token,
Protocol::Tcp,
C {
config,
auth_backend,
cancellation_handler,
endpoint_rate_limiter,
},
)
.await
}
let connections = tokio_util::task::task_tracker::TaskTracker::new();
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
impl super::ConnHandler for C {
async fn handle(
self,
session_id: uuid::Uuid,
peer_addr: IpAddr,
socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&self.config.region,
);
let span = ctx.span();
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let startup = Box::pin(
handle_client(
self.config,
self.auth_backend,
&ctx,
self.cancellation_handler,
socket,
ClientMode::Tcp,
self.endpoint_rate_limiter,
conn_gauge,
)
.instrument(span.clone()),
);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let res = startup.await;
transition_connection(ctx, res).await;
}
}
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
};
match socket.inner.set_nodelay(true) {
pub(crate) async fn transition_connection<S: AsyncRead + AsyncWrite + Unpin>(
ctx: RequestMonitoring,
res: Result<Option<ProxyPassthrough<S>>, ClientRequestError>,
) {
let span = ctx.span();
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
return;
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
};
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let startup = Box::pin(
handle_client(
config,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone()),
);
let res = startup.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
}
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
}
}
});
}
}
connections.close();
drop(listener);
// Drain connections
connections.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
@@ -243,15 +227,17 @@ impl ReportableError for ClientRequestError {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
@@ -285,21 +271,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
let user = user_info.user.clone();
let user_info = match auth_backend
.authenticate(
ctx,
user_info,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
@@ -353,7 +335,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection<P>(
pub(crate) async fn prepare_client_connection<P>(
node: &compute::PostgresConnection,
session: &cancellation::Session<P>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation,
cancellation::{self, CancellationHandlerMainInternal},
compute::PostgresConnection,
control_plane::messages::MetricsAuxInfo,
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
@@ -57,17 +57,17 @@ pub(crate) async fn proxy_pass(
Ok(())
}
pub(crate) struct ProxyPassthrough<P, S> {
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
pub(crate) _cancel: cancellation::Session<P>,
pub(crate) _cancel: cancellation::Session<CancellationHandlerMainInternal>,
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {

View File

@@ -8,18 +8,20 @@ use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend,
};
use crate::config::{CertResolver, RetryConfig};
use crate::config::{CertResolver, ProxyProtocolV2, RetryConfig};
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
use crate::control_plane::provider::{
CachedAllowedIps, CachedRoleSecret, ControlPlaneBackend, NodeInfoCache,
};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::protocol2::get_client_conn_info;
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use auth::backend::ControlPlaneComputeBackend;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
@@ -176,7 +178,7 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let (client, _) = get_client_conn_info(client, ProxyProtocolV2::Supported).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
@@ -552,19 +554,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeCredentials, &()> {
let user_info = auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
) -> ControlPlaneComputeBackend<'static> {
let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new(
mechanism.clone(),
))));
api.attach_to_credentials(ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
);
user_info
keys: ComputeCredentialKeys::Password("password".into()),
})
}
#[tokio::test]

View File

@@ -8,16 +8,16 @@ use tracing::{field::display, info};
use crate::{
auth::{
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
check_peer_addr_is_in_list, AuthError, ServerlessBackend,
},
compute,
config::{AuthenticationConfig, ProxyConfig},
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks,
provider::ApiLockError,
CachedNodeInfo,
Api, CachedNodeInfo,
},
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
@@ -38,6 +38,7 @@ pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: ServerlessBackend<'static>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -45,18 +46,20 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if config.ip_allowlist_check_enabled
let cplane = match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => cplane,
ServerlessBackend::Local(_local) => {
return Err(AuthError::bad_auth_method(
"password authentication not supported by local_proxy",
))
}
};
let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
@@ -69,13 +72,12 @@ impl PoolingBackend {
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,
None => cplane.get_role_secret(ctx, user_info).await?,
};
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,
@@ -87,9 +89,13 @@ impl PoolingBackend {
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -101,7 +107,7 @@ impl PoolingBackend {
}
};
res.map(|key| ComputeCredentials {
info: user_info,
info: user_info.clone(),
keys: key,
})
}
@@ -109,13 +115,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
match &self.auth_backend {
ServerlessBackend::ControlPlane(console) => {
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -129,11 +135,9 @@ impl PoolingBackend {
Ok(())
}
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
ServerlessBackend::Local(_) => {
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -176,21 +180,41 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&cplane.attach_to_credentials(keys),
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
ServerlessBackend::Local(local_proxy) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&**local_proxy,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
}
// Wake up the destination if needed
@@ -200,6 +224,13 @@ impl PoolingBackend {
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
let cplane = match &self.auth_backend {
ServerlessBackend::Local(_) => {
panic!("connect to local_proxy should not be called if we are already local_proxy")
}
ServerlessBackend::ControlPlane(cplane) => cplane,
};
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -208,14 +239,11 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
let backend = cplane.attach_to_credentials(ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {

View File

@@ -16,7 +16,6 @@ use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
@@ -32,28 +31,29 @@ use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::auth::ServerlessBackend;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::proxy::run_until_cancelled;
use crate::metrics::{Metrics, Protocol};
use crate::protocol2::ChainRW;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use tracing::{error, info, instrument, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -107,6 +107,7 @@ pub async fn task_main(
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -122,81 +123,100 @@ pub async fn task_main(
}
};
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let requests = TaskTracker::new();
requests.close(); // allows `requests.wait to complete`
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
crate::connection_loop(
config,
ws_listener,
cancellation_token.clone(),
Protocol::Http,
C {
config,
backend,
cancellation_handler,
endpoint_rate_limiter,
tls_acceptor,
requests: requests.clone(),
cancellation_token,
},
)
.await?;
requests.wait().await;
Ok(())
}
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
requests: TaskTracker,
cancellation_token: CancellationToken,
}
impl super::ConnHandler for C {
#[instrument(name = "http_conn", skip_all, fields(conn_id))]
async fn handle(
self,
conn_id: uuid::Uuid,
peer_addr: IpAddr,
stream: ChainRW<TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
// try and close an old HTTP connection.
// picked at random
let n_connections = Metrics::get()
.proxy
.client_connections
.sample(crate::metrics::Protocol::Http);
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
if n_connections > config.http_config.client_conn_threshold {
tracing::trace!(?n_connections, threshold = ?self.config.http_config.client_conn_threshold, "check");
if n_connections > self.config.http_config.client_conn_threshold {
tracing::trace!("attempting to cancel a random connection");
if let Some(token) = config.http_config.cancel_set.take() {
if let Some(token) = self.config.http_config.cancel_set.take() {
tracing::debug!("cancelling a random connection");
token.cancel();
}
}
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let conn_token = self.cancellation_token.child_token();
let _cancel_guard = self
.config
.http_config
.cancel_set
.insert(conn_id, conn_token.clone());
let session_id = uuid::Uuid::new_v4();
let startup_result = Box::pin(connection_startup(
self.config,
self.tls_acceptor,
conn_id,
stream,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
Box::pin(connection_handler(
self.config,
self.backend,
self.requests,
self.cancellation_handler,
self.endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
conn_id,
))
.await;
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};
Box::pin(connection_handler(
config,
backend,
connections2,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
session_id,
))
.await;
}
.instrument(http_conn_span),
);
drop(conn_gauge);
}
connections.wait().await;
Ok(())
}
pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
@@ -224,26 +244,14 @@ impl MaybeTlsAcceptor for NoTls {
}
}
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
/// Handles the TLS startup handshake.
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
conn: ChainRW<TcpStream>,
peer_addr: IpAddr,
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
};
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
@@ -377,6 +385,10 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -394,6 +406,7 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
auth_backend,
ctx,
websocket,
cancellation_handler,

View File

@@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -552,7 +553,7 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
@@ -623,22 +624,12 @@ async fn handle_db_inner(
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?;
ComputeCredentials {
@@ -680,7 +671,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -698,7 +689,7 @@ async fn handle_db_inner(
}
statements
.process(config, cancel, &mut client, parsed_headers)
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -738,7 +729,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
@@ -746,12 +736,7 @@ async fn handle_auth_broker_inner(
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
@@ -789,7 +774,7 @@ async fn handle_auth_broker_inner(
impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -863,7 +848,7 @@ impl QueryData {
impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -933,7 +918,7 @@ impl BatchQueryData {
}
async fn query_batch(
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -972,7 +957,7 @@ async fn query_batch(
}
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
config: &'static HttpConfig,
client: &T,
data: QueryData,
current_size: &mut usize,
@@ -993,9 +978,9 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > config.http_config.max_response_size_bytes {
if *current_size > config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge(
config.http_config.max_response_size_bytes,
config.max_response_size_bytes,
));
}
}

View File

@@ -1,3 +1,4 @@
use crate::control_plane::provider::ControlPlaneBackend;
use crate::proxy::ErrorSource;
use crate::{
cancellation::CancellationHandlerMain,
@@ -129,6 +130,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -145,6 +147,7 @@ pub(crate) async fn serve_websocket(
let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),