From 43b2445d0bc7fd541f10a441c3935eebb6b48e78 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 27 Sep 2024 16:08:43 +0100 Subject: [PATCH] proxy: add jwks endpoint to control plane and mock providers (#9165) --- proxy/src/auth/backend.rs | 16 ++++++ proxy/src/console/provider.rs | 38 ++++++++++++- proxy/src/console/provider/mock.rs | 45 ++++++++++++++- proxy/src/console/provider/neon.rs | 90 +++++++++++++++++++++++++++--- proxy/src/http.rs | 16 ++++-- proxy/src/proxy/tests.rs | 4 ++ 6 files changed, 193 insertions(+), 16 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 4e9f4591ad..5dbfa5cc09 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static { fn get_allowed_ips_and_secret( &self, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; + fn dyn_clone(&self) -> Box; +} + +#[cfg(test)] +impl Clone for Box { + fn clone(&self) -> Self { + TestBackend::dyn_clone(&**self) + } } impl std::fmt::Display for Backend<'_, (), ()> { @@ -585,6 +593,14 @@ mod tests { )) } + async fn get_endpoint_jwks( + &self, + _ctx: &RequestMonitoring, + _endpoint: crate::EndpointId, + ) -> anyhow::Result> { + unimplemented!() + } + async fn wake_compute( &self, _ctx: &RequestMonitoring, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 16e8da605b..95097f2de9 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -5,7 +5,10 @@ pub mod neon; use super::messages::{ConsoleError, MetricsAuxInfo}; use crate::{ auth::{ - backend::{ComputeCredentialKeys, ComputeUserInfo}, + backend::{ + jwt::{AuthRule, FetchAuthRules}, + ComputeCredentialKeys, ComputeUserInfo, + }, IpPattern, }, cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, @@ -16,7 +19,7 @@ use crate::{ intern::ProjectIdInt, metrics::ApiLockMetrics, rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}, - scram, EndpointCacheKey, + scram, EndpointCacheKey, EndpointId, }; use dashmap::DashMap; use std::{hash::Hash, sync::Arc, time::Duration}; @@ -334,6 +337,12 @@ pub(crate) trait Api { user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError>; + async fn get_endpoint_jwks( + &self, + ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result>; + /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( &self, @@ -343,6 +352,7 @@ pub(crate) trait Api { } #[non_exhaustive] +#[derive(Clone)] pub enum ConsoleBackend { /// Current Cloud API (V2). Console(neon::Api), @@ -386,6 +396,20 @@ impl Api for ConsoleBackend { } } + async fn get_endpoint_jwks( + &self, + ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result> { + match self { + Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await, + #[cfg(any(test, feature = "testing"))] + Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await, + #[cfg(test)] + Self::Test(_api) => Ok(vec![]), + } + } + async fn wake_compute( &self, ctx: &RequestMonitoring, @@ -552,3 +576,13 @@ impl WakeComputePermit { res } } + +impl FetchAuthRules for ConsoleBackend { + async fn fetch_auth_rules( + &self, + ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result> { + self.get_endpoint_jwks(ctx, endpoint).await + } +} diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 1b77418de6..b548a0203a 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -4,7 +4,9 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo, }; -use crate::context::RequestMonitoring; +use crate::{ + auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName, +}; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use crate::{auth::IpPattern, cache::Cached}; use crate::{ @@ -118,6 +120,39 @@ impl Api { }) } + async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result> { + let (client, connection) = + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + + let connection = tokio::spawn(connection); + + let res = client.query( + "select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1", + &[&endpoint.as_str()], + ) + .await?; + + let mut rows = vec![]; + for row in res { + rows.push(AuthRule { + id: row.get("id"), + jwks_url: url::Url::parse(row.get("jwks_url"))?, + audience: row.get("audience"), + role_names: row + .get::<_, Vec>("role_names") + .into_iter() + .map(RoleName::from) + .map(|s| RoleNameInt::from(&s)) + .collect(), + }); + } + + drop(client); + connection.await??; + + Ok(rows) + } + async fn do_wake_compute(&self) -> Result { let mut config = compute::ConnCfg::new(); config @@ -185,6 +220,14 @@ impl super::Api for Api { )) } + async fn get_endpoint_jwks( + &self, + _ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result> { + self.do_get_endpoint_jwks(endpoint).await + } + #[tracing::instrument(skip_all)] async fn wake_compute( &self, diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index b004bf4ecf..2d527f378c 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -7,27 +7,33 @@ use super::{ NodeInfo, }; use crate::{ - auth::backend::ComputeUserInfo, + auth::backend::{jwt::AuthRule, ComputeUserInfo}, compute, - console::messages::{ColdStartInfo, Reason}, + console::messages::{ColdStartInfo, EndpointJwksResponse, Reason}, http, metrics::{CacheOutcome, Metrics}, rate_limiter::WakeComputeRateLimiter, - scram, EndpointCacheKey, + scram, EndpointCacheKey, EndpointId, }; use crate::{cache::Cached, context::RequestMonitoring}; +use ::http::{header::AUTHORIZATION, HeaderName}; +use anyhow::bail; use futures::TryFutureExt; use std::{sync::Arc, time::Duration}; use tokio::time::Instant; use tokio_postgres::config::SslMode; use tracing::{debug, error, info, info_span, warn, Instrument}; +const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); + +#[derive(Clone)] pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, pub(crate) locks: &'static ApiLocks, pub(crate) wake_compute_endpoint_rate_limiter: Arc, - jwt: String, + // put in a shared ref so we don't copy secrets all over in memory + jwt: Arc, } impl Api { @@ -38,7 +44,9 @@ impl Api { locks: &'static ApiLocks, wake_compute_endpoint_rate_limiter: Arc, ) -> Self { - let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default(); + let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") + .unwrap_or_default() + .into(); Self { endpoint, caches, @@ -71,9 +79,9 @@ impl Api { async { let request = self .endpoint - .get("proxy_get_role_secret") - .header("X-Request-ID", &request_id) - .header("Authorization", format!("Bearer {}", &self.jwt)) + .get_path("proxy_get_role_secret") + .header(X_REQUEST_ID, &request_id) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) .query(&[("session_id", ctx.session_id())]) .query(&[ ("application_name", application_name.as_str()), @@ -125,6 +133,61 @@ impl Api { .await } + async fn do_get_endpoint_jwks( + &self, + ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result> { + if !self + .caches + .endpoints_cache + .is_valid(ctx, &endpoint.normalize()) + .await + { + bail!("endpoint not found"); + } + let request_id = ctx.session_id().to_string(); + async { + let request = self + .endpoint + .get_with_url(|url| { + url.path_segments_mut() + .push("endpoints") + .push(endpoint.as_str()) + .push("jwks"); + }) + .header(X_REQUEST_ID, &request_id) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .build()?; + + info!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + drop(pause); + info!(duration = ?start.elapsed(), "received http response"); + + let body = parse_body::(response).await?; + + let rules = body + .jwks + .into_iter() + .map(|jwks| AuthRule { + id: jwks.id, + jwks_url: jwks.jwks_url, + audience: jwks.jwt_audience, + role_names: jwks.role_names, + }) + .collect(); + + Ok(rules) + } + .map_err(crate::error::log_error) + .instrument(info_span!("http", id = request_id)) + .await + } + async fn do_wake_compute( &self, ctx: &RequestMonitoring, @@ -135,7 +198,7 @@ impl Api { async { let mut request_builder = self .endpoint - .get("proxy_wake_compute") + .get_path("proxy_wake_compute") .header("X-Request-ID", &request_id) .header("Authorization", format!("Bearer {}", &self.jwt)) .query(&[("session_id", ctx.session_id())]) @@ -262,6 +325,15 @@ impl super::Api for Api { )) } + #[tracing::instrument(skip_all)] + async fn get_endpoint_jwks( + &self, + ctx: &RequestMonitoring, + endpoint: EndpointId, + ) -> anyhow::Result> { + self.do_get_endpoint_jwks(ctx, endpoint).await + } + #[tracing::instrument(skip_all)] async fn wake_compute( &self, diff --git a/proxy/src/http.rs b/proxy/src/http.rs index c77d95f47d..14720b5c6b 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -86,9 +86,17 @@ impl Endpoint { /// Return a [builder](RequestBuilder) for a `GET` request, /// appending a single `path` segment to the base endpoint URL. - pub(crate) fn get(&self, path: &str) -> RequestBuilder { + pub(crate) fn get_path(&self, path: &str) -> RequestBuilder { + self.get_with_url(|u| { + u.path_segments_mut().push(path); + }) + } + + /// Return a [builder](RequestBuilder) for a `GET` request, + /// accepting a closure to modify the url path segments for more complex paths queries. + pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder { let mut url = self.endpoint.clone(); - url.path_segments_mut().push(path); + f(&mut url); self.client.get(url.into_inner()) } @@ -144,7 +152,7 @@ mod tests { // Validate that this pattern makes sense. let req = endpoint - .get("frobnicate") + .get_path("frobnicate") .query(&[ ("foo", Some("10")), // should be just `foo=10` ("bar", None), // shouldn't be passed at all @@ -162,7 +170,7 @@ mod tests { let endpoint = Endpoint::new(url, Client::new()); let req = endpoint - .get("frobnicate") + .get_path("frobnicate") .query(&[("session_id", uuid::Uuid::nil())]) .build()?; diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 752d982726..058ec06e02 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism { { unimplemented!("not used in tests") } + + fn dyn_clone(&self) -> Box { + Box::new(self.clone()) + } } fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {