proxy: add jwks endpoint to control plane and mock providers (#9165)

This commit is contained in:
Conrad Ludgate
2024-09-27 16:08:43 +01:00
committed by GitHub
parent 42ef08db47
commit 43b2445d0b
6 changed files with 193 additions and 16 deletions

View File

@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
}
impl std::fmt::Display for Backend<'_, (), ()> {
@@ -585,6 +593,14 @@ mod tests {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,

View File

@@ -5,7 +5,10 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
@@ -16,7 +19,7 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
@@ -334,6 +337,12 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
@@ -343,6 +352,7 @@ pub(crate) trait Api {
}
#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -552,3 +576,13 @@ impl WakeComputePermit {
res
}
}
impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}

View File

@@ -4,7 +4,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::context::RequestMonitoring;
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
@@ -118,6 +120,39 @@ impl Api {
})
}
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}
drop(client);
connection.await??;
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
@@ -185,6 +220,14 @@ impl super::Api for Api {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -7,27 +7,33 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::ComputeUserInfo,
auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute,
console::messages::{ColdStartInfo, Reason},
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
jwt: String,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
}
impl Api {
@@ -38,7 +44,9 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
Self {
endpoint,
caches,
@@ -71,9 +79,9 @@ impl Api {
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
@@ -125,6 +133,61 @@ impl Api {
.await
}
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -135,7 +198,7 @@ impl Api {
async {
let mut request_builder = self
.endpoint
.get("proxy_wake_compute")
.get_path("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
@@ -262,6 +325,15 @@ impl super::Api for Api {
))
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -86,9 +86,17 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex paths queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
let mut url = self.endpoint.clone();
url.path_segments_mut().push(path);
f(&mut url);
self.client.get(url.into_inner())
}
@@ -144,7 +152,7 @@ mod tests {
// Validate that this pattern makes sense.
let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[
("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all
@@ -162,7 +170,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new());
let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())])
.build()?;

View File

@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {