From 80c55768163fd24847d1deeb9fbac94af1c1bc84 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 3 Oct 2024 22:43:20 +0100 Subject: [PATCH] proxy: continue streamlining auth::Backend --- proxy/src/auth/backend/mod.rs | 216 ++++++++++-------------------- proxy/src/bin/local_proxy.rs | 10 +- proxy/src/bin/proxy.rs | 76 +++++------ proxy/src/proxy/mod.rs | 16 +-- proxy/src/proxy/tests/mod.rs | 29 ++-- proxy/src/serverless/backend.rs | 85 ++++++++---- proxy/src/serverless/mod.rs | 9 +- proxy/src/serverless/websocket.rs | 3 +- 8 files changed, 200 insertions(+), 244 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 17334b9cbb..079fc36501 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -21,10 +21,7 @@ use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserIn use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestMonitoring; -use crate::control_plane::errors::GetAuthInfoError; -use crate::control_plane::provider::{ - CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend, -}; +use crate::control_plane::provider::{CachedNodeInfo, ControlPlaneBackend}; use crate::control_plane::{self, Api, AuthSecret}; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; @@ -35,38 +32,16 @@ use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; -/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality -pub enum MaybeOwned<'a, T> { - Owned(T), - Borrowed(&'a T), -} - -impl std::ops::Deref for MaybeOwned<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - MaybeOwned::Owned(t) => t, - MaybeOwned::Borrowed(t) => t, - } - } -} - -/// This type serves two purposes: -/// -/// * When `T` is `()`, it's just a regular auth backend selector -/// which we use in [`crate::config::ProxyConfig`]. -/// -/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], -/// this helps us provide the credentials only to those auth -/// backends which require them for the authentication process. -pub enum Backend<'a, T> { +pub enum Backend<'a> { /// Cloud API (V2). - ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T), + ControlPlane(&'a ControlPlaneBackend), /// Local proxy uses configured auth credentials and does not wake compute - Local(MaybeOwned<'a, LocalBackend>), + Local(&'a LocalBackend), } +#[cfg(test)] +use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret}; + #[cfg(test)] pub(crate) trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; @@ -83,56 +58,20 @@ impl Clone for Box { } } -impl std::fmt::Display for Backend<'_, ()> { +impl std::fmt::Display for ControlPlaneBackend { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::ControlPlane(api, ()) => match &**api { - ControlPlaneBackend::Management(endpoint) => fmt - .debug_tuple("ControlPlane::Management") - .field(&endpoint.url()) - .finish(), - #[cfg(any(test, feature = "testing"))] - ControlPlaneBackend::PostgresMock(endpoint) => fmt - .debug_tuple("ControlPlane::PostgresMock") - .field(&endpoint.url()) - .finish(), - #[cfg(test)] - ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), - }, - Self::Local(_) => fmt.debug_tuple("Local").finish(), - } - } -} - -impl Backend<'_, T> { - /// Very similar to [`std::option::Option::as_ref`]. - /// This helps us pass structured config to async tasks. - pub(crate) fn as_ref(&self) -> Backend<'_, &T> { - match self { - Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x), - Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)), - } - } -} - -impl<'a, T> Backend<'a, T> { - /// Very similar to [`std::option::Option::map`]. - /// Maps [`Backend`] to [`Backend`] by applying - /// a function to a contained value. - pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> { - match self { - Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)), - Self::Local(l) => Backend::Local(l), - } - } -} -impl<'a, T, E> Backend<'a, Result> { - /// Very similar to [`std::option::Option::transpose`]. - /// This is most useful for error handling. - pub(crate) fn transpose(self) -> Result, E> { - match self { - Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)), - Self::Local(l) => Ok(Backend::Local(l)), + ControlPlaneBackend::Management(endpoint) => fmt + .debug_tuple("ControlPlane::Management") + .field(&endpoint.url()) + .finish(), + #[cfg(any(test, feature = "testing"))] + ControlPlaneBackend::PostgresMock(endpoint) => fmt + .debug_tuple("ControlPlane::PostgresMock") + .field(&endpoint.url()) + .finish(), + #[cfg(test)] + ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), } } } @@ -399,96 +338,79 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { - /// Get username from the credentials. - pub(crate) fn get_user(&self) -> &str { - match self { - Self::ControlPlane(_, user_info) => &user_info.user, - Self::Local(_) => "local", - } - } - - /// Authenticate the client via the requested backend, possibly using credentials. +impl ControlPlaneBackend { #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] pub(crate) async fn authenticate( - self, + &self, ctx: &RequestMonitoring, + user_info: ComputeUserInfoMaybeEndpoint, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result> { - let res = match self { - Self::ControlPlane(api, user_info) => { - info!( - user = &*user_info.user, - project = user_info.endpoint(), - "performing authentication using the console" - ); + ) -> auth::Result { + info!( + user = &*user_info.user, + project = user_info.endpoint(), + "performing authentication using the console" + ); - let credentials = auth_quirks( - ctx, - &*api, - user_info, - client, - allow_cleartext, - config, - endpoint_rate_limiter, - ) - .await?; - Backend::ControlPlane(api, credentials) - } - Self::Local(_) => { - return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) - } - }; + let credentials = auth_quirks( + ctx, + self, + user_info, + client, + allow_cleartext, + config, + endpoint_rate_limiter, + ) + .await?; info!("user successfully authenticated"); - Ok(res) + Ok(ControlPlaneComputeBackend { + api: self, + creds: credentials, + }) + } + + pub(crate) fn attach_to_credentials( + &self, + creds: ComputeCredentials, + ) -> ControlPlaneComputeBackend { + ControlPlaneComputeBackend { api: self, creds } } } -impl Backend<'_, ComputeUserInfo> { - pub(crate) async fn get_role_secret( - &self, - ctx: &RequestMonitoring, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips_and_secret( - &self, - ctx: &RequestMonitoring, - ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { - match self { - Self::ControlPlane(api, user_info) => { - api.get_allowed_ips_and_secret(ctx, user_info).await - } - Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - } - } +pub struct ControlPlaneComputeBackend<'a> { + api: &'a ControlPlaneBackend, + creds: ComputeCredentials, } #[async_trait::async_trait] -impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { +impl ComputeConnectBackend for ControlPlaneComputeBackend<'_> { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result { - match self { - Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await, - Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), - } + self.api.wake_compute(ctx, &self.creds.info).await } fn get_keys(&self) -> &ComputeCredentialKeys { - match self { - Self::ControlPlane(_, creds) => &creds.keys, - Self::Local(_) => &ComputeCredentialKeys::None, - } + &self.creds.keys + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for LocalBackend { + async fn wake_compute( + &self, + _ctx: &RequestMonitoring, + ) -> Result { + Ok(Cached::new_uncached(self.node_info.clone())) + } + + fn get_keys(&self) -> &ComputeCredentialKeys { + &ComputeCredentialKeys::None } } diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index df3628465f..5862fa6c3d 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -203,7 +203,7 @@ async fn main() -> anyhow::Result<()> { let task = serverless::task_main( config, - auth_backend, + auth::Backend::Local(auth_backend), http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( @@ -295,12 +295,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig } /// auth::Backend is created at proxy startup, and lives forever. -fn build_auth_backend( - args: &LocalProxyCliArgs, -) -> anyhow::Result<&'static auth::Backend<'static, ()>> { - let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.postgres, args.compute_ctl.clone()), - )); +fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> { + let auth_backend = LocalBackend::new(args.postgres, args.compute_ctl.clone()); Ok(Box::leak(Box::new(auth_backend))) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 6e190029aa..b1d8367667 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -13,13 +13,14 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use aws_config::Region; use futures::future::Either; use proxy::auth::backend::jwt::JwkCache; -use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend}; use proxy::cancellation::{CancelMap, CancellationHandler}; use proxy::config::{ self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, }; use proxy::context::parquet::ParquetUploadArgs; +use proxy::control_plane::provider::ControlPlaneBackend; use proxy::http::health_server::AppMetrics; use proxy::metrics::Metrics; use proxy::rate_limiter::{ @@ -467,7 +468,7 @@ async fn main() -> anyhow::Result<()> { if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, - auth_backend, + auth::Backend::ControlPlane(auth_backend), serverless_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -515,40 +516,38 @@ async fn main() -> anyhow::Result<()> { )); } - if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend { - if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api { - match (redis_notifications_client, regional_redis_client.clone()) { - (None, None) => {} - (client1, client2) => { - let cache = api.caches.project_info.clone(); - if let Some(client) = client1 { - maintenance_tasks.spawn(notifications::task_main( - client, - cache.clone(), - cancel_map.clone(), - args.region.clone(), - )); - } - if let Some(client) = client2 { - maintenance_tasks.spawn(notifications::task_main( - client, - cache.clone(), - cancel_map.clone(), - args.region.clone(), - )); - } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend { + match (redis_notifications_client, regional_redis_client.clone()) { + (None, None) => {} + (client1, client2) => { + let cache = api.caches.project_info.clone(); + if let Some(client) = client1 { + maintenance_tasks.spawn(notifications::task_main( + client, + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); } + if let Some(client) = client2 { + maintenance_tasks.spawn(notifications::task_main( + client, + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); + } + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } - if let Some(regional_redis_client) = regional_redis_client { - let cache = api.caches.endpoints_cache.clone(); - let con = regional_redis_client; - let span = tracing::info_span!("endpoints_cache"); - maintenance_tasks.spawn( - async move { cache.do_read(con, cancellation_token.clone()).await } - .instrument(span), - ); - } + } + if let Some(regional_redis_client) = regional_redis_client { + let cache = api.caches.endpoints_cache.clone(); + let con = regional_redis_client; + let span = tracing::info_span!("endpoints_cache"); + maintenance_tasks.spawn( + async move { cache.do_read(con, cancellation_token.clone()).await } + .instrument(span), + ); } } @@ -694,7 +693,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { /// auth::Backend is created at proxy startup, and lives forever. fn build_auth_backend( args: &ProxyCliArgs, -) -> anyhow::Result, &'static ConsoleRedirectBackend>> { +) -> anyhow::Result> { match &args.auth_backend { AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; @@ -744,8 +743,7 @@ fn build_auth_backend( locks, wake_compute_endpoint_rate_limiter, ); - let api = control_plane::provider::ControlPlaneBackend::Management(api); - let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ()); + let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api); let config = Box::leak(Box::new(auth_backend)); @@ -756,9 +754,7 @@ fn build_auth_backend( AuthBackendType::Postgres => { let url = args.auth_endpoint.parse()?; let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy); - let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api); - - let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ()); + let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api); let config = Box::leak(Box::new(auth_backend)); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 2970d93393..f5ce5a1ff5 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -26,6 +26,7 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestMonitoring; +use crate::control_plane::provider::ControlPlaneBackend; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; use crate::protocol2::read_proxy_protocol; @@ -54,7 +55,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, - auth_backend: &'static auth::Backend<'static, ()>, + auth_backend: &'static ControlPlaneBackend, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -241,7 +242,7 @@ impl ReportableError for ClientRequestError { #[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, - auth_backend: &'static auth::Backend<'static, ()>, + auth_backend: &'static ControlPlaneBackend, ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, @@ -282,20 +283,17 @@ pub(crate) async fn handle_client( let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to use for auth. - let result = auth_backend - .as_ref() - .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) - .transpose(); - + let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names); let user_info = match result { Ok(user_info) => user_info, Err(e) => stream.throw_error(e).await?, }; - let user = user_info.get_user().to_owned(); - let user_info = match user_info + let user = user_info.user.clone(); + let user_info = match auth_backend .authenticate( ctx, + user_info, &mut stream, mode.allow_cleartext(), &config.authentication_config, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index fe62fee204..5bc6cb8fa5 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -6,6 +6,7 @@ use std::time::Duration; use anyhow::{bail, Context}; use async_trait::async_trait; +use auth::backend::ControlPlaneComputeBackend; use http::StatusCode; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; @@ -19,7 +20,7 @@ use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; use super::*; use crate::auth::backend::{ - ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend, }; use crate::config::{CertResolver, RetryConfig}; use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status}; @@ -566,19 +567,21 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> auth::Backend<'static, ComputeCredentials> { - let user_info = auth::Backend::ControlPlane( - MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))), - ComputeCredentials { - info: ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), - }, - keys: ComputeCredentialKeys::Password("password".into()), +) -> ControlPlaneComputeBackend<'static> { + let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new( + mechanism.clone(), + )))); + + let creds = ComputeCredentials { + info: ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), }, - ); - user_info + keys: ComputeCredentialKeys::Password("password".into()), + }; + + api.attach_to_credentials(creds) } #[tokio::test] diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 07e0e30148..eea9ec7341 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -27,7 +27,7 @@ use crate::context::RequestMonitoring; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::control_plane::provider::ApiLockError; -use crate::control_plane::CachedNodeInfo; +use crate::control_plane::{Api, CachedNodeInfo}; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; use crate::proxy::connect_compute::ConnectMechanism; @@ -41,7 +41,7 @@ pub(crate) struct PoolingBackend { pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, - pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>, + pub(crate) auth_backend: crate::auth::Backend<'static>, pub(crate) endpoint_rate_limiter: Arc, } @@ -52,9 +52,16 @@ impl PoolingBackend { user_info: &ComputeUserInfo, password: &[u8], ) -> Result { - let user_info = user_info.clone(); - let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; + let cplane = match self.auth_backend { + crate::auth::Backend::ControlPlane(cplane) => cplane, + crate::auth::Backend::Local(_local) => { + return Err(AuthError::bad_auth_method( + "password authentication not supported by local_proxy", + )) + } + }; + + let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?; if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { @@ -68,7 +75,7 @@ impl PoolingBackend { } let cached_secret = match maybe_secret { Some(secret) => secret, - None => backend.get_role_secret(ctx).await?, + None => cplane.get_role_secret(ctx, user_info).await?, }; let secret = match cached_secret.value.clone() { @@ -103,7 +110,7 @@ impl PoolingBackend { } }; res.map(|key| ComputeCredentials { - info: user_info, + info: user_info.clone(), keys: key, }) } @@ -115,7 +122,7 @@ impl PoolingBackend { jwt: String, ) -> Result { match &self.auth_backend { - crate::auth::Backend::ControlPlane(console, ()) => { + crate::auth::Backend::ControlPlane(console) => { self.config .authentication_config .jwks_cache @@ -182,21 +189,41 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.auth_backend.as_ref().map(|()| keys); - crate::proxy::connect_compute::connect_to_compute( - ctx, - &TokioMechanism { - conn_id, - conn_info, - pool: self.pool.clone(), - locks: &self.config.connect_compute_locks, - }, - &backend, - false, // do not allow self signed compute for http flow - self.config.wake_compute_retry_config, - self.config.connect_to_compute_retry_config, - ) - .await + + match &self.auth_backend { + crate::auth::Backend::ControlPlane(cplane) => { + crate::proxy::connect_compute::connect_to_compute( + ctx, + &TokioMechanism { + conn_id, + conn_info, + pool: self.pool.clone(), + locks: &self.config.connect_compute_locks, + }, + &cplane.attach_to_credentials(keys), + false, // do not allow self signed compute for http flow + self.config.wake_compute_retry_config, + self.config.connect_to_compute_retry_config, + ) + .await + } + crate::auth::Backend::Local(local_proxy) => { + crate::proxy::connect_compute::connect_to_compute( + ctx, + &TokioMechanism { + conn_id, + conn_info, + pool: self.pool.clone(), + locks: &self.config.connect_compute_locks, + }, + &**local_proxy, + false, // do not allow self signed compute for http flow + self.config.wake_compute_retry_config, + self.config.connect_to_compute_retry_config, + ) + .await + } + } } // Wake up the destination if needed @@ -206,6 +233,13 @@ impl PoolingBackend { ctx: &RequestMonitoring, conn_info: ConnInfo, ) -> Result, HttpConnError> { + let cplane = match &self.auth_backend { + crate::auth::Backend::Local(_) => { + panic!("connect to local_proxy should not be called if we are already local_proxy") + } + crate::auth::Backend::ControlPlane(cplane) => cplane, + }; + info!("pool: looking for an existing connection"); if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -214,7 +248,8 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { + + let backend = cplane.attach_to_credentials(ComputeCredentials { info: ComputeUserInfo { user: conn_info.user_info.user.clone(), endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), @@ -256,7 +291,7 @@ impl PoolingBackend { } let local_backend = match &self.auth_backend { - auth::Backend::ControlPlane(_, ()) => { + auth::Backend::ControlPlane(_) => { unreachable!("only local_proxy can connect to local postgres") } auth::Backend::Local(local) => local, diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 29ff7b9d91..a302ce31c9 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -41,6 +41,7 @@ use tokio_util::task::TaskTracker; use tracing::{info, warn, Instrument}; use utils::http::error::ApiError; +use crate::auth::Backend; use crate::cancellation::CancellationHandlerMain; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; @@ -55,7 +56,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, - auth_backend: &'static crate::auth::Backend<'static, ()>, + auth_backend: crate::auth::Backend<'static>, ws_listener: TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -382,6 +383,10 @@ async fn request_handler( if config.http_config.accept_websockets && framed_websockets::upgrade::is_upgrade_request(&request) { + let Backend::ControlPlane(auth_backend) = backend.auth_backend else { + return json_response(StatusCode::BAD_REQUEST, "query is not supported"); + }; + let ctx = RequestMonitoring::new( session_id, peer_addr, @@ -399,7 +404,7 @@ async fn request_handler( async move { if let Err(e) = websocket::serve_websocket( config, - backend.auth_backend, + auth_backend, ctx, websocket, cancellation_handler, diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index ba36116c2c..2b4f6a1552 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -15,6 +15,7 @@ use tracing::warn; use crate::cancellation::CancellationHandlerMain; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; +use crate::control_plane::provider::ControlPlaneBackend; use crate::error::{io_error, ReportableError}; use crate::metrics::Metrics; use crate::proxy::{handle_client, ClientMode, ErrorSource}; @@ -125,7 +126,7 @@ impl AsyncBufRead for WebSocketRw { pub(crate) async fn serve_websocket( config: &'static ProxyConfig, - auth_backend: &'static crate::auth::Backend<'static, ()>, + auth_backend: &'static ControlPlaneBackend, ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc,