Compare commits

...

5 Commits

Author SHA1 Message Date
Conrad Ludgate
b66e545e26 a little more type-safety, a little more verbose... 2024-10-24 12:33:10 +01:00
Conrad Ludgate
c8108a4b84 make ComputeConnectBackend dyn 2024-10-24 11:55:31 +01:00
Conrad Ludgate
2d34fec39b minor changes 2024-10-24 11:48:43 +01:00
Conrad Ludgate
3da4705775 rename to serverless backend 2024-10-24 11:44:15 +01:00
Conrad Ludgate
80c5576816 proxy: continue streamlining auth::Backend 2024-10-24 11:43:46 +01:00
12 changed files with 270 additions and 300 deletions

View File

@@ -21,10 +21,7 @@ use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserIn
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::provider::{
CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
};
use crate::control_plane::provider::{CachedNodeInfo, ControlPlaneBackend};
use crate::control_plane::{self, Api, AuthSecret};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
@@ -35,38 +32,19 @@ use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T> {
/// The [crate::serverless] module can authenticate either using control-plane
/// to get authentication state, or by using JWKs stored in the filesystem.
#[derive(Clone, Copy)]
pub enum ServerlessBackend<'a> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
ControlPlane(&'a ControlPlaneBackend),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
Local(&'a LocalBackend),
}
#[cfg(test)]
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
#[cfg(test)]
pub(crate) trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
@@ -83,56 +61,20 @@ impl Clone for Box<dyn TestBackend> {
}
}
impl std::fmt::Display for Backend<'_, ()> {
impl std::fmt::Display for ControlPlaneBackend {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ControlPlane(api, ()) => match &**api {
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
}
}
impl<T> Backend<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
}
}
}
impl<'a, T> Backend<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
Self::Local(l) => Backend::Local(l),
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
}
}
}
@@ -399,96 +341,79 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
impl ControlPlaneBackend {
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
self,
&self,
ctx: &RequestMonitoring,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<ControlPlaneComputeBackend> {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Backend::ControlPlane(api, credentials)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
}
};
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
info!("user successfully authenticated");
Ok(res)
Ok(ControlPlaneComputeBackend {
api: self,
creds: credentials,
})
}
pub(crate) fn attach_to_credentials(
&self,
creds: ComputeCredentials,
) -> ControlPlaneComputeBackend {
ControlPlaneComputeBackend { api: self, creds }
}
}
impl Backend<'_, ComputeUserInfo> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_allowed_ips_and_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_ips_and_secret(ctx, user_info).await
}
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
pub struct ControlPlaneComputeBackend<'a> {
api: &'a ControlPlaneBackend,
creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
impl ComputeConnectBackend for ControlPlaneComputeBackend<'static> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
self.api.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::Local(_) => &ComputeCredentialKeys::None,
}
&self.creds.keys
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for LocalBackend {
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.node_info.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
pub use backend::ServerlessBackend;
mod credentials;
pub(crate) use credentials::{

View File

@@ -203,7 +203,7 @@ async fn main() -> anyhow::Result<()> {
let task = serverless::task_main(
config,
auth_backend,
auth::ServerlessBackend::Local(auth_backend),
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
@@ -295,12 +295,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &LocalProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, ()>> {
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.postgres, args.compute_ctl.clone()),
));
fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> {
let auth_backend = LocalBackend::new(args.postgres, args.compute_ctl.clone());
Ok(Box::leak(Box::new(auth_backend)))
}

View File

@@ -13,13 +13,14 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
use proxy::control_plane::provider::ControlPlaneBackend;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
use proxy::rate_limiter::{
@@ -467,7 +468,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
auth::ServerlessBackend::ControlPlane(auth_backend),
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
@@ -515,40 +516,38 @@ async fn main() -> anyhow::Result<()> {
));
}
if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
@@ -694,7 +693,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<Either<&'static auth::Backend<'static, ()>, &'static ConsoleRedirectBackend>> {
) -> anyhow::Result<Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>> {
match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
@@ -744,8 +743,7 @@ fn build_auth_backend(
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::provider::ControlPlaneBackend::Management(api);
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api);
let config = Box::leak(Box::new(auth_backend));
@@ -756,9 +754,7 @@ fn build_auth_backend(
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
let config = Box::leak(Box::new(auth_backend));

View File

@@ -56,7 +56,7 @@ pub(crate) trait ConnectMechanism {
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
pub(crate) trait ComputeConnectBackend: Send + Sync + 'static {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -98,10 +98,10 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism>(
ctx: &RequestMonitoring,
mechanism: &M,
user_info: &B,
user_info: &dyn ComputeConnectBackend,
allow_self_signed_compute: bool,
wake_compute_retry_config: RetryConfig,
connect_to_compute_retry_config: RetryConfig,

View File

@@ -26,6 +26,7 @@ use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestMonitoring;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
@@ -54,7 +55,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
auth_backend: &'static ControlPlaneBackend,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -241,7 +242,7 @@ impl ReportableError for ClientRequestError {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
auth_backend: &'static ControlPlaneBackend,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
@@ -282,20 +283,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
let user = user_info.user.clone();
let user_info = match auth_backend
.authenticate(
ctx,
user_info,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,

View File

@@ -6,6 +6,7 @@ use std::time::Duration;
use anyhow::{bail, Context};
use async_trait::async_trait;
use auth::backend::ControlPlaneComputeBackend;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
@@ -19,7 +20,7 @@ use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend,
};
use crate::config::{CertResolver, RetryConfig};
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
@@ -566,19 +567,21 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeCredentials> {
let user_info = auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
) -> ControlPlaneComputeBackend<'static> {
let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new(
mechanism.clone(),
))));
let creds = ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
);
user_info
keys: ComputeCredentialKeys::Password("password".into()),
};
api.attach_to_credentials(creds)
}
#[tokio::test]

View File

@@ -11,10 +11,10 @@ use crate::metrics::{
};
use crate::proxy::retry::{retry_after, should_retry};
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
pub(crate) async fn wake_compute(
num_retries: &mut u32,
ctx: &RequestMonitoring,
api: &B,
api: &dyn ComputeConnectBackend,
config: RetryConfig,
) -> Result<CachedNodeInfo, WakeComputeError> {
let retry_type = RetryType::WakeCompute;

View File

@@ -15,9 +15,9 @@ use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::local::{LocalBackend, StaticAuthRules};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
use crate::auth::{check_peer_addr_is_in_list, AuthError, ServerlessBackend};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -26,11 +26,11 @@ use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::provider::ApiLockError;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::provider::{ApiLockError, ControlPlaneBackend};
use crate::control_plane::{Api, CachedNodeInfo};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::connect_compute::{ComputeConnectBackend, ConnectMechanism};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host};
@@ -41,7 +41,6 @@ pub(crate) struct PoolingBackend {
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -49,12 +48,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
auth_backend: &ControlPlaneBackend,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
let (allowed_ips, maybe_secret) = auth_backend
.get_allowed_ips_and_secret(ctx, user_info)
.await?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
@@ -68,7 +68,7 @@ impl PoolingBackend {
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,
None => auth_backend.get_role_secret(ctx, user_info).await?,
};
let secret = match cached_secret.value.clone() {
@@ -103,7 +103,7 @@ impl PoolingBackend {
}
};
res.map(|key| ComputeCredentials {
info: user_info,
info: user_info.clone(),
keys: key,
})
}
@@ -111,11 +111,12 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
auth_backend: ServerlessBackend<'static>,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<ComputeCredentials, AuthError> {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
match auth_backend {
ServerlessBackend::ControlPlane(console) => {
self.config
.authentication_config
.jwks_cache
@@ -123,7 +124,7 @@ impl PoolingBackend {
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
console,
&jwt,
)
.await
@@ -134,7 +135,7 @@ impl PoolingBackend {
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
crate::auth::Backend::Local(_) => {
ServerlessBackend::Local(_) => {
let keys = self
.config
.authentication_config
@@ -164,6 +165,7 @@ impl PoolingBackend {
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestMonitoring,
auth_backend: ServerlessBackend<'static>,
conn_info: ConnInfo,
keys: ComputeCredentials,
force_new: bool,
@@ -182,7 +184,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys);
let api = match auth_backend {
ServerlessBackend::ControlPlane(cplane) => {
&cplane.attach_to_credentials(keys) as &dyn ComputeConnectBackend
}
ServerlessBackend::Local(local_proxy) => local_proxy as &dyn ComputeConnectBackend,
};
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -191,7 +200,7 @@ impl PoolingBackend {
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
api,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
@@ -204,6 +213,7 @@ impl PoolingBackend {
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
auth_backend: &'static ControlPlaneBackend,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
info!("pool: looking for an existing connection");
@@ -214,7 +224,8 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
let backend = auth_backend.attach_to_credentials(ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
@@ -249,26 +260,20 @@ impl PoolingBackend {
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestMonitoring,
auth_backend: &LocalBackend,
conn_info: ConnInfo,
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
}
let local_backend = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local,
};
if !self.local_pool.initialized(&conn_info) {
// only install and grant usage one at a time.
let _permit = local_backend.initialize.acquire().await.unwrap();
let _permit = auth_backend.initialize.acquire().await.unwrap();
// check again for race
if !self.local_pool.initialized(&conn_info) {
local_backend
auth_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
extension: EXT_NAME,
@@ -277,7 +282,7 @@ impl PoolingBackend {
})
.await?;
local_backend
auth_backend
.compute_ctl
.grant_role(&SetRoleGrantsRequest {
schema: EXT_SCHEMA,
@@ -295,7 +300,7 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let mut node_info = auth_backend.node_info.clone();
let (key, jwk) = create_random_jwk();

View File

@@ -41,6 +41,7 @@ use tokio_util::task::TaskTracker;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::auth::ServerlessBackend;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
@@ -55,7 +56,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
auth_backend: ServerlessBackend<'static>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -111,7 +112,6 @@ pub async fn task_main(
local_pool,
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -184,6 +184,7 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
auth_backend,
backend,
connections2,
cancellation_handler,
@@ -289,6 +290,7 @@ async fn connection_startup(
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -323,6 +325,7 @@ async fn connection_handler(
request_handler(
req,
config,
auth_backend,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
@@ -362,6 +365,7 @@ async fn connection_handler(
async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -382,6 +386,10 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ServerlessBackend::ControlPlane(auth_backend) = auth_backend else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -399,7 +407,7 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
auth_backend,
ctx,
websocket,
cancellation_handler,
@@ -425,9 +433,16 @@ async fn request_handler(
);
let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
sql_over_http::handle(
config,
ctx,
request,
auth_backend,
backend,
http_cancellation_token,
)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

View File

@@ -30,10 +30,11 @@ use super::conn_pool_lib::{self, ConnInfo};
use super::http_util::json_response;
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
use super::local_conn_pool;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError, ServerlessBackend};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::context::RequestMonitoring;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
@@ -240,10 +241,11 @@ pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
request: Request<Incoming>,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let result = handle_inner(cancel, config, &ctx, request, auth_backend, backend).await;
let mut response = match result {
Ok(r) => {
@@ -498,6 +500,7 @@ async fn handle_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
@@ -522,7 +525,11 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
panic!("auth_broker must be configured with a control-plane auth backend.")
};
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, cplane, backend).await
}
auth => {
handle_db_inner(
@@ -532,6 +539,7 @@ async fn handle_inner(
request,
conn_info.conn_info,
auth,
auth_backend,
backend,
)
.await
@@ -539,6 +547,7 @@ async fn handle_inner(
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
@@ -546,6 +555,7 @@ async fn handle_db_inner(
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
//
@@ -588,45 +598,58 @@ async fn handle_db_inner(
.map_err(SqlOverHttpError::from),
);
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
let authenticate_and_connect = Box::pin(async {
let creds = match auth {
AuthData::Password(pw) => {
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
return Err(SqlOverHttpError::ConnInfo(
ConnInfoError::MissingCredentials(Credentials::BearerJwt),
));
};
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
backend
.authenticate_with_password(ctx, cplane, &conn_info.user_info, &pw)
.await
.map_err(HttpConnError::from)?
}
AuthData::Jwt(jwt) => backend
.authenticate_with_jwt(ctx, auth_backend, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?,
};
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
}
};
let client = match (creds.keys, auth_backend) {
(ComputeCredentialKeys::JwtPayload(payload), ServerlessBackend::Local(local)) => {
let mut client = backend
.connect_to_local_postgres(ctx, local, conn_info)
.await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
(keys, auth_backend) => {
let client = backend
.connect_to_compute(
ctx,
auth_backend,
conn_info,
ComputeCredentials {
keys,
info: creds.info,
},
!allow_pool,
)
.await
.map_err(HttpConnError::from)?;
Client::Remote(client)
}
};
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from),
);
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
Ok::<_, SqlOverHttpError>(client)
});
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel
@@ -711,14 +734,22 @@ async fn handle_auth_broker_inner(
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
auth_backend: &'static ControlPlaneBackend,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.authenticate_with_jwt(
ctx,
ServerlessBackend::ControlPlane(auth_backend),
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let mut client = backend
.connect_to_local_proxy(ctx, auth_backend, conn_info)
.await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");

View File

@@ -15,6 +15,7 @@ use tracing::warn;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::error::{io_error, ReportableError};
use crate::metrics::Metrics;
use crate::proxy::{handle_client, ClientMode, ErrorSource};
@@ -125,7 +126,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
auth_backend: &'static ControlPlaneBackend,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,