From ab5bbb445bcd76410d884f3431a4dcba3ec8fb37 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 11 Oct 2024 21:14:52 +0200 Subject: [PATCH] proxy: refactor auth backends (#9271) preliminary for #9270 The auth::Backend didn't need to be in the mega ProxyConfig object, so I split it off and passed it manually in the few places it was necessary. I've also refined some of the uses of config I saw while doing this small refactor. I've also followed the trend and make the console redirect backend it's own struct, same as LocalBackend and ControlPlaneBackend. --- proxy/src/auth/backend/console_redirect.rs | 25 +++- proxy/src/auth/backend/mod.rs | 19 ++- proxy/src/bin/local_proxy.rs | 25 +++- proxy/src/bin/proxy.rs | 154 +++++++++++---------- proxy/src/config.rs | 6 +- proxy/src/proxy/mod.rs | 7 +- proxy/src/serverless/backend.rs | 60 ++++---- proxy/src/serverless/mod.rs | 3 + proxy/src/serverless/sql_over_http.rs | 44 ++---- proxy/src/serverless/websocket.rs | 2 + 10 files changed, 186 insertions(+), 159 deletions(-) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index a7cc678187..127be545e1 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -25,6 +25,10 @@ pub(crate) enum WebAuthError { Io(#[from] std::io::Error), } +pub struct ConsoleRedirectBackend { + console_uri: reqwest::Url, +} + impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() @@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub(super) async fn authenticate( +impl ConsoleRedirectBackend { + pub fn new(console_uri: reqwest::Url) -> Self { + Self { console_uri } + } + + pub(super) fn url(&self) -> &reqwest::Url { + &self.console_uri + } + + pub(crate) async fn authenticate( + &self, + ctx: &RequestMonitoring, + auth_config: &'static AuthenticationConfig, + client: &mut PqStream, + ) -> auth::Result { + authenticate(ctx, auth_config, &self.console_uri, client).await + } +} + +async fn authenticate( ctx: &RequestMonitoring, auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index c9aa5b7e61..27c9f1876e 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -8,6 +8,7 @@ use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::WebAuthError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; @@ -36,7 +37,7 @@ use crate::{ provider::{CachedAllowedIps, CachedNodeInfo}, Api, }, - stream, url, + stream, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; @@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> { /// Cloud API (V2). ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T), /// Authentication via a web browser. - ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D), + ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D), /// Local proxy uses configured auth credentials and does not wake compute Local(MaybeOwned<'a, LocalBackend>), } @@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> { #[cfg(test)] ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), }, - Self::ConsoleRedirect(url, ()) => fmt + Self::ConsoleRedirect(backend, ()) => fmt .debug_tuple("ConsoleRedirect") - .field(&url.as_str()) + .field(&backend.url().as_str()) .finish(), Self::Local(_) => fmt.debug_tuple("Local").finish(), } @@ -241,7 +242,6 @@ impl AuthenticationConfig { pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, is_cleartext: bool, @@ -265,7 +265,7 @@ impl AuthenticationConfig { let limit_not_exceeded = self.rate_limiter.check( ( endpoint_int, - MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet), + MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), ), password_weight, ); @@ -339,7 +339,6 @@ async fn auth_quirks( let secret = if let Some(secret) = secret { config.check_rate_limit( ctx, - config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, @@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { Backend::ControlPlane(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Self::ConsoleRedirect(url, ()) => { + Self::ConsoleRedirect(backend, ()) => { info!("performing web authentication"); - let info = console_redirect::authenticate(ctx, config, &url, client).await?; + let info = backend.authenticate(ctx, config, client).await?; - Backend::ConsoleRedirect(url, info) + Backend::ConsoleRedirect(backend, info) } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index ae8a7f0841..c781af846a 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec; use dashmap::DashMap; use futures::future::Either; use proxy::{ - auth::backend::{ - jwt::JwkCache, - local::{LocalBackend, JWKS_ROLE_MAP}, + auth::{ + self, + backend::{ + jwt::JwkCache, + local::{LocalBackend, JWKS_ROLE_MAP}, + }, }, cancellation::CancellationHandlerMain, config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, @@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> { let args = LocalProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; // before we bind to any ports, write the process ID to a file // so that compute-ctl can find our process later @@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> { let task = serverless::task_main( config, + auth_backend, http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( @@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, - auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), - )), metric_collection: None, allow_self_signed_compute: false, http_config, @@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }))) } +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &LocalProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { + let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( + LocalBackend::new(args.compute), + )); + + Ok(Box::leak(Box::new(auth_backend))) +} + async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { loop { rx.notified().await; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 7488cce3c4..3f4c2df809 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -10,6 +10,7 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::AuthRateLimiter; +use proxy::auth::backend::ConsoleRedirectBackend; use proxy::auth::backend::MaybeOwned; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; @@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> { let args = ProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; - info!("Authentication backend: {}", config.auth_backend); + info!("Authentication backend: {}", auth_backend); info!("Using region: {}", args.aws_region); let region_provider = @@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> { if let Some(proxy_listener) = proxy_listener { client_tasks.spawn(proxy::proxy::task_main( config, + auth_backend, proxy_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> { if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, + auth_backend, serverless_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::Backend::ControlPlane(api, _) = &config.auth_backend { + if let auth::Backend::ControlPlane(api, _) = auth_backend { if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} @@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { bail!("dynamic rate limiter should be disabled"); } + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.connect_compute_lock.parse()?; + info!( + ?limiter, + shards, + ?epoch, + "Using NodeLocks (connect_compute)" + ); + let connect_compute_locks = control_plane::locks::ApiLocks::new( + "connect_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().proxy.connect_compute_lock, + )?; + + let http_config = HttpConfig { + accept_websockets: !args.is_auth_broker, + pool_options: GlobalConnPoolOptions { + max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, + gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, + pool_shards: args.sql_over_http.sql_over_http_pool_shards, + idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, + opt_in: args.sql_over_http.sql_over_http_pool_opt_in, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, + }, + cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), + client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, + max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, + max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, + }; + let authentication_config = AuthenticationConfig { + jwks_cache: JwkCache::default(), + thread_pool, + scram_protocol_timeout: args.scram_protocol_timeout, + rate_limiter_enabled: args.auth_rate_limit_enabled, + rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, + ip_allowlist_check_enabled: !args.is_private_access_proxy, + is_auth_broker: args.is_auth_broker, + accept_jwts: args.is_auth_broker, + webauth_confirmation_timeout: args.webauth_confirmation_timeout, + }; + + let config = Box::leak(Box::new(ProxyConfig { + tls_config, + metric_collection, + allow_self_signed_compute: args.allow_self_signed_compute, + http_config, + authentication_config, + proxy_protocol_v2: args.proxy_protocol_v2, + handshake_timeout: args.handshake_timeout, + region: args.region.clone(), + wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, + connect_compute_locks, + connect_to_compute_retry_config: config::RetryConfig::parse( + &args.connect_to_compute_retry, + )?, + })); + + tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); + + Ok(config) +} + +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &ProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { let auth_backend = match &args.auth_backend { AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; @@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { AuthBackendType::Web => { let url = args.uri.parse()?; - auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ()) + auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ()) } #[cfg(feature = "testing")] @@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } }; - let config::ConcurrencyLockOptions { - shards, - limiter, - epoch, - timeout, - } = args.connect_compute_lock.parse()?; - info!( - ?limiter, - shards, - ?epoch, - "Using NodeLocks (connect_compute)" - ); - let connect_compute_locks = control_plane::locks::ApiLocks::new( - "connect_compute_lock", - limiter, - shards, - timeout, - epoch, - &Metrics::get().proxy.connect_compute_lock, - )?; - - let http_config = HttpConfig { - accept_websockets: !args.is_auth_broker, - pool_options: GlobalConnPoolOptions { - max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, - gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, - pool_shards: args.sql_over_http.sql_over_http_pool_shards, - idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, - opt_in: args.sql_over_http.sql_over_http_pool_opt_in, - max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, - }, - cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), - client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, - max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, - max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, - }; - let authentication_config = AuthenticationConfig { - jwks_cache: JwkCache::default(), - thread_pool, - scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, - ip_allowlist_check_enabled: !args.is_private_access_proxy, - is_auth_broker: args.is_auth_broker, - accept_jwts: args.is_auth_broker, - webauth_confirmation_timeout: args.webauth_confirmation_timeout, - }; - - let config = Box::leak(Box::new(ProxyConfig { - tls_config, - auth_backend, - metric_collection, - allow_self_signed_compute: args.allow_self_signed_compute, - http_config, - authentication_config, - proxy_protocol_v2: args.proxy_protocol_v2, - handshake_timeout: args.handshake_timeout, - region: args.region.clone(), - wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, - connect_compute_locks, - connect_to_compute_retry_config: config::RetryConfig::parse( - &args.connect_to_compute_retry, - )?, - })); - - tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); - - Ok(config) + Ok(Box::leak(Box::new(auth_backend))) } #[cfg(test)] diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 55d0b6374c..c068fc50fb 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,8 +1,5 @@ use crate::{ - auth::{ - self, - backend::{jwt::JwkCache, AuthRateLimiter}, - }, + auth::backend::{jwt::JwkCache, AuthRateLimiter}, control_plane::locks::ApiLocks, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, scram::threadpool::ThreadPool, @@ -29,7 +26,6 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 9e1af88f41..3a43ccb74a 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -61,6 +61,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -129,6 +130,7 @@ pub async fn task_main( let startup = Box::pin( handle_client( config, + auth_backend, &ctx, cancellation_handler, socket, @@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, @@ -285,8 +289,7 @@ pub(crate) async fn handle_client( let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to use for auth. - let result = config - .auth_backend + let result = auth_backend .as_ref() .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) .transpose(); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index f54476b51d..9e49478cf3 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -13,7 +13,7 @@ use crate::{ check_peer_addr_is_in_list, AuthError, }, compute, - config::{AuthenticationConfig, ProxyConfig}, + config::ProxyConfig, context::RequestMonitoring, control_plane::{ errors::{GetAuthInfoError, WakeComputeError}, @@ -42,6 +42,7 @@ pub(crate) struct PoolingBackend { pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, + pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>, pub(crate) endpoint_rate_limiter: Arc, } @@ -49,18 +50,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, password: &[u8], ) -> Result { let user_info = user_info.clone(); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| user_info.clone()); + let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if config.ip_allowlist_check_enabled + if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); @@ -79,7 +75,6 @@ impl PoolingBackend { let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, - config, secret, &user_info.endpoint, true, @@ -91,9 +86,13 @@ impl PoolingBackend { } }; let ep = EndpointIdInt::from(&user_info.endpoint); - let auth_outcome = - crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) - .await?; + let auth_outcome = crate::auth::validate_password_and_exchange( + &self.config.authentication_config.thread_pool, + ep, + password, + secret, + ) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -113,13 +112,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, jwt: String, ) -> Result { - match &self.config.auth_backend { + match &self.auth_backend { crate::auth::Backend::ControlPlane(console, ()) => { - config + self.config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -140,7 +139,9 @@ impl PoolingBackend { "JWT login over web auth proxy is not supported", )), crate::auth::Backend::Local(_) => { - let keys = config + let keys = self + .config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -185,7 +186,7 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.config.auth_backend.as_ref().map(|()| keys); + let backend = self.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -217,21 +218,14 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| ComputeCredentials { - info: ComputeUserInfo { - user: conn_info.user_info.user.clone(), - endpoint: EndpointId::from(format!( - "{}-local-proxy", - conn_info.user_info.endpoint - )), - options: conn_info.user_info.options.clone(), - }, - keys: crate::auth::backend::ComputeCredentialKeys::None, - }); + let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { + info: ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), + options: conn_info.user_info.options.clone(), + }, + keys: crate::auth::backend::ComputeCredentialKeys::None, + }); crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { @@ -269,7 +263,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = match &self.config.auth_backend { + let mut node_info = match &self.auth_backend { auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => { unreachable!("only local_proxy can connect to local postgres") } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index b5820b0535..95f64e972c 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -55,6 +55,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ws_listener: TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -110,6 +111,7 @@ pub async fn task_main( local_pool, pool: Arc::clone(&conn_pool), config, + auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); let tls_acceptor: Arc = match config.tls_config.as_ref() { @@ -397,6 +399,7 @@ async fn request_handler( async move { if let Err(e) = websocket::serve_websocket( config, + backend.auth_backend, ctx, websocket, cancellation_handler, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 646e7f8a52..cf3324926c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; use crate::config::AuthenticationConfig; +use crate::config::HttpConfig; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -554,7 +555,7 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await + handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await } auth => { handle_db_inner( @@ -622,28 +623,17 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { - let is_local_proxy = - matches!(backend.config.auth_backend, crate::auth::Backend::Local(_)); + let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_)); let keys = match auth { AuthData::Password(pw) => { backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.user_info, - &pw, - ) + .authenticate_with_password(ctx, &conn_info.user_info, &pw) .await? } AuthData::Jwt(jwt) => { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await? } }; @@ -691,7 +681,7 @@ async fn handle_db_inner( // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { - stmt.process(config, cancel, &mut client, parsed_headers) + stmt.process(&config.http_config, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { @@ -709,7 +699,7 @@ async fn handle_db_inner( } statements - .process(config, cancel, &mut client, parsed_headers) + .process(&config.http_config, cancel, &mut client, parsed_headers) .await? } }; @@ -749,7 +739,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[ ]; async fn handle_auth_broker_inner( - config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, @@ -757,12 +746,7 @@ async fn handle_auth_broker_inner( backend: Arc, ) -> Result>, SqlOverHttpError> { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await .map_err(HttpConnError::from)?; @@ -800,7 +784,7 @@ async fn handle_auth_broker_inner( impl QueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -874,7 +858,7 @@ impl QueryData { impl BatchQueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -944,7 +928,7 @@ impl BatchQueryData { } async fn query_batch( - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, @@ -983,7 +967,7 @@ async fn query_batch( } async fn query_to_json( - config: &'static ProxyConfig, + config: &'static HttpConfig, client: &T, data: QueryData, current_size: &mut usize, @@ -1004,9 +988,9 @@ async fn query_to_json( rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) - if *current_size > config.http_config.max_response_size_bytes { + if *current_size > config.max_response_size_bytes { return Err(SqlOverHttpError::ResponseTooLarge( - config.http_config.max_response_size_bytes, + config.max_response_size_bytes, )); } } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 08d5da9bef..fd0f0cac7f 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -129,6 +129,7 @@ impl AsyncBufRead for WebSocketRw { pub(crate) async fn serve_websocket( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc, @@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket( let res = Box::pin(handle_client( config, + auth_backend, &ctx, cancellation_handler, WebSocketRw::new(websocket),