diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 81f6eaa6dd..31828a7240 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -34,6 +34,7 @@ use crate::{scram, stream}; /// The [crate::serverless] module can authenticate either using control-plane /// to get authentication state, or by using JWKs stored in the filesystem. +#[derive(Clone, Copy)] pub enum ServerlessBackend<'a> { /// Cloud API (V2). ControlPlane(&'a ControlPlaneBackend), diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 1508de95e7..249eacb098 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -15,9 +15,9 @@ use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool}; use super::http_conn_pool::{self, poll_http2_client, Send}; use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION}; -use crate::auth::backend::local::StaticAuthRules; +use crate::auth::backend::local::{LocalBackend, StaticAuthRules}; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use crate::auth::{self, check_peer_addr_is_in_list, AuthError, ServerlessBackend}; +use crate::auth::{check_peer_addr_is_in_list, AuthError, ServerlessBackend}; use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -26,7 +26,7 @@ use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::provider::ApiLockError; +use crate::control_plane::provider::{ApiLockError, ControlPlaneBackend}; use crate::control_plane::{Api, CachedNodeInfo}; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; @@ -41,7 +41,6 @@ pub(crate) struct PoolingBackend { pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, - pub(crate) auth_backend: ServerlessBackend<'static>, pub(crate) endpoint_rate_limiter: Arc, } @@ -49,19 +48,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, + auth_backend: &ControlPlaneBackend, user_info: &ComputeUserInfo, password: &[u8], ) -> Result { - let cplane = match self.auth_backend { - ServerlessBackend::ControlPlane(cplane) => cplane, - ServerlessBackend::Local(_local) => { - return Err(AuthError::bad_auth_method( - "password authentication not supported by local_proxy", - )) - } - }; - - let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?; + let (allowed_ips, maybe_secret) = auth_backend + .get_allowed_ips_and_secret(ctx, user_info) + .await?; if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { @@ -75,7 +68,7 @@ impl PoolingBackend { } let cached_secret = match maybe_secret { Some(secret) => secret, - None => cplane.get_role_secret(ctx, user_info).await?, + None => auth_backend.get_role_secret(ctx, user_info).await?, }; let secret = match cached_secret.value.clone() { @@ -118,10 +111,11 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, + auth_backend: ServerlessBackend<'static>, user_info: &ComputeUserInfo, jwt: String, ) -> Result { - match &self.auth_backend { + match auth_backend { ServerlessBackend::ControlPlane(console) => { self.config .authentication_config @@ -130,7 +124,7 @@ impl PoolingBackend { ctx, user_info.endpoint.clone(), &user_info.user, - &**console, + console, &jwt, ) .await @@ -171,6 +165,7 @@ impl PoolingBackend { pub(crate) async fn connect_to_compute( &self, ctx: &RequestMonitoring, + auth_backend: ServerlessBackend<'static>, conn_info: ConnInfo, keys: ComputeCredentials, force_new: bool, @@ -190,11 +185,11 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let api = match &self.auth_backend { + let api = match auth_backend { ServerlessBackend::ControlPlane(cplane) => { &cplane.attach_to_credentials(keys) as &dyn ComputeConnectBackend } - ServerlessBackend::Local(local_proxy) => &**local_proxy as &dyn ComputeConnectBackend, + ServerlessBackend::Local(local_proxy) => local_proxy as &dyn ComputeConnectBackend, }; crate::proxy::connect_compute::connect_to_compute( @@ -218,15 +213,9 @@ impl PoolingBackend { pub(crate) async fn connect_to_local_proxy( &self, ctx: &RequestMonitoring, + auth_backend: &'static ControlPlaneBackend, conn_info: ConnInfo, ) -> Result, HttpConnError> { - let cplane = match &self.auth_backend { - ServerlessBackend::Local(_) => { - panic!("connect to local_proxy should not be called if we are already local_proxy") - } - ServerlessBackend::ControlPlane(cplane) => cplane, - }; - info!("pool: looking for an existing connection"); if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -236,7 +225,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = cplane.attach_to_credentials(ComputeCredentials { + let backend = auth_backend.attach_to_credentials(ComputeCredentials { info: ComputeUserInfo { user: conn_info.user_info.user.clone(), endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), @@ -271,26 +260,20 @@ impl PoolingBackend { pub(crate) async fn connect_to_local_postgres( &self, ctx: &RequestMonitoring, + auth_backend: &LocalBackend, conn_info: ConnInfo, ) -> Result, HttpConnError> { if let Some(client) = self.local_pool.get(ctx, &conn_info)? { return Ok(client); } - let local_backend = match &self.auth_backend { - auth::ServerlessBackend::ControlPlane(_) => { - unreachable!("only local_proxy can connect to local postgres") - } - auth::ServerlessBackend::Local(local) => local, - }; - if !self.local_pool.initialized(&conn_info) { // only install and grant usage one at a time. - let _permit = local_backend.initialize.acquire().await.unwrap(); + let _permit = auth_backend.initialize.acquire().await.unwrap(); // check again for race if !self.local_pool.initialized(&conn_info) { - local_backend + auth_backend .compute_ctl .install_extension(&ExtensionInstallRequest { extension: EXT_NAME, @@ -299,7 +282,7 @@ impl PoolingBackend { }) .await?; - local_backend + auth_backend .compute_ctl .grant_role(&SetRoleGrantsRequest { schema: EXT_SCHEMA, @@ -317,7 +300,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = local_backend.node_info.clone(); + let mut node_info = auth_backend.node_info.clone(); let (key, jwk) = create_random_jwk(); diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 6c9616db99..5e369fdd75 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -112,7 +112,6 @@ pub async fn task_main( local_pool, pool: Arc::clone(&conn_pool), config, - auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); let tls_acceptor: Arc = match config.tls_config.as_ref() { @@ -185,6 +184,7 @@ pub async fn task_main( Box::pin(connection_handler( config, + auth_backend, backend, connections2, cancellation_handler, @@ -290,6 +290,7 @@ async fn connection_startup( #[allow(clippy::too_many_arguments)] async fn connection_handler( config: &'static ProxyConfig, + auth_backend: ServerlessBackend<'static>, backend: Arc, connections: TaskTracker, cancellation_handler: Arc, @@ -324,6 +325,7 @@ async fn connection_handler( request_handler( req, config, + auth_backend, backend.clone(), connections.clone(), cancellation_handler.clone(), @@ -363,6 +365,7 @@ async fn connection_handler( async fn request_handler( mut request: hyper::Request, config: &'static ProxyConfig, + auth_backend: ServerlessBackend<'static>, backend: Arc, ws_connections: TaskTracker, cancellation_handler: Arc, @@ -383,7 +386,7 @@ async fn request_handler( if config.http_config.accept_websockets && framed_websockets::upgrade::is_upgrade_request(&request) { - let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else { + let ServerlessBackend::ControlPlane(auth_backend) = auth_backend else { return json_response(StatusCode::BAD_REQUEST, "query is not supported"); }; @@ -430,9 +433,16 @@ async fn request_handler( ); let span = ctx.span(); - sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) - .instrument(span) - .await + sql_over_http::handle( + config, + ctx, + request, + auth_backend, + backend, + http_cancellation_token, + ) + .instrument(span) + .await } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1398a171e5..881cfd2e3f 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -30,10 +30,11 @@ use super::conn_pool_lib::{self, ConnInfo}; use super::http_util::json_response; use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError}; use super::local_conn_pool; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::auth::{endpoint_sni, ComputeUserInfoParseError}; +use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; +use crate::auth::{endpoint_sni, ComputeUserInfoParseError, ServerlessBackend}; use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig}; use crate::context::RequestMonitoring; +use crate::control_plane::provider::ControlPlaneBackend; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{HttpDirection, Metrics}; use crate::proxy::{run_until_cancelled, NeonOptions}; @@ -240,10 +241,11 @@ pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestMonitoring, request: Request, + auth_backend: ServerlessBackend<'static>, backend: Arc, cancel: CancellationToken, ) -> Result>, ApiError> { - let result = handle_inner(cancel, config, &ctx, request, backend).await; + let result = handle_inner(cancel, config, &ctx, request, auth_backend, backend).await; let mut response = match result { Ok(r) => { @@ -498,6 +500,7 @@ async fn handle_inner( config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, + auth_backend: ServerlessBackend<'static>, backend: Arc, ) -> Result>, SqlOverHttpError> { let _requeset_gauge = Metrics::get() @@ -522,7 +525,11 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await + let ServerlessBackend::ControlPlane(cplane) = auth_backend else { + panic!("auth_broker must be configured with a control-plane auth backend.") + }; + + handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, cplane, backend).await } auth => { handle_db_inner( @@ -532,6 +539,7 @@ async fn handle_inner( request, conn_info.conn_info, auth, + auth_backend, backend, ) .await @@ -539,6 +547,7 @@ async fn handle_inner( } } +#[allow(clippy::too_many_arguments)] async fn handle_db_inner( cancel: CancellationToken, config: &'static ProxyConfig, @@ -546,6 +555,7 @@ async fn handle_db_inner( request: Request, conn_info: ConnInfo, auth: AuthData, + auth_backend: ServerlessBackend<'static>, backend: Arc, ) -> Result>, SqlOverHttpError> { // @@ -588,48 +598,58 @@ async fn handle_db_inner( .map_err(SqlOverHttpError::from), ); - let authenticate_and_connect = Box::pin( - async { - let is_local_proxy = matches!( - backend.auth_backend, - crate::auth::ServerlessBackend::Local(_) - ); + let authenticate_and_connect = Box::pin(async { + let creds = match auth { + AuthData::Password(pw) => { + let ServerlessBackend::ControlPlane(cplane) = auth_backend else { + return Err(SqlOverHttpError::ConnInfo( + ConnInfoError::MissingCredentials(Credentials::BearerJwt), + )); + }; - let keys = match auth { - AuthData::Password(pw) => { - backend - .authenticate_with_password(ctx, &conn_info.user_info, &pw) - .await? - } - AuthData::Jwt(jwt) => { - backend - .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) - .await? - } - }; + backend + .authenticate_with_password(ctx, cplane, &conn_info.user_info, &pw) + .await + .map_err(HttpConnError::from)? + } + AuthData::Jwt(jwt) => backend + .authenticate_with_jwt(ctx, auth_backend, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?, + }; - let client = match keys.keys { - ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => { - let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?; - let (cli_inner, _dsc) = client.client_inner(); - cli_inner.set_jwt_session(&payload).await?; - Client::Local(client) - } - _ => { - let client = backend - .connect_to_compute(ctx, conn_info, keys, !allow_pool) - .await?; - Client::Remote(client) - } - }; + let client = match (creds.keys, auth_backend) { + (ComputeCredentialKeys::JwtPayload(payload), ServerlessBackend::Local(local)) => { + let mut client = backend + .connect_to_local_postgres(ctx, local, conn_info) + .await?; + let (cli_inner, _dsc) = client.client_inner(); + cli_inner.set_jwt_session(&payload).await?; + Client::Local(client) + } + (keys, auth_backend) => { + let client = backend + .connect_to_compute( + ctx, + auth_backend, + conn_info, + ComputeCredentials { + keys, + info: creds.info, + }, + !allow_pool, + ) + .await + .map_err(HttpConnError::from)?; + Client::Remote(client) + } + }; - // not strictly necessary to mark success here, - // but it's just insurance for if we forget it somewhere else - ctx.success(); - Ok::<_, HttpConnError>(client) - } - .map_err(SqlOverHttpError::from), - ); + // not strictly necessary to mark success here, + // but it's just insurance for if we forget it somewhere else + ctx.success(); + Ok::<_, SqlOverHttpError>(client) + }); let (payload, mut client) = match run_until_cancelled( // Run both operations in parallel @@ -714,14 +734,22 @@ async fn handle_auth_broker_inner( request: Request, conn_info: ConnInfo, jwt: String, + auth_backend: &'static ControlPlaneBackend, backend: Arc, ) -> Result>, SqlOverHttpError> { backend - .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .authenticate_with_jwt( + ctx, + ServerlessBackend::ControlPlane(auth_backend), + &conn_info.user_info, + jwt, + ) .await .map_err(HttpConnError::from)?; - let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; + let mut client = backend + .connect_to_local_proxy(ctx, auth_backend, conn_info) + .await?; let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");