mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: add jwks endpoint to control plane and mock providers (#9165)
This commit is contained in:
@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Clone for Box<dyn TestBackend> {
|
||||
fn clone(&self) -> Self {
|
||||
TestBackend::dyn_clone(&**self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend<'_, (), ()> {
|
||||
@@ -585,6 +593,14 @@ mod tests {
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: crate::EndpointId,
|
||||
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
|
||||
@@ -5,7 +5,10 @@ pub mod neon;
|
||||
use super::messages::{ConsoleError, MetricsAuxInfo};
|
||||
use crate::{
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
backend::{
|
||||
jwt::{AuthRule, FetchAuthRules},
|
||||
ComputeCredentialKeys, ComputeUserInfo,
|
||||
},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
@@ -16,7 +19,7 @@ use crate::{
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey,
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
@@ -334,6 +337,12 @@ pub(crate) trait Api {
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
@@ -343,6 +352,7 @@ pub(crate) trait Api {
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ConsoleBackend {
|
||||
/// Current Cloud API (V2).
|
||||
Console(neon::Api),
|
||||
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
match self {
|
||||
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_api) => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
@@ -552,3 +576,13 @@ impl WakeComputePermit {
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl FetchAuthRules for ConsoleBackend {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{
|
||||
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use crate::{
|
||||
@@ -118,6 +120,39 @@ impl Api {
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
let connection = tokio::spawn(connection);
|
||||
|
||||
let res = client.query(
|
||||
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
|
||||
&[&endpoint.as_str()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut rows = vec![];
|
||||
for row in res {
|
||||
rows.push(AuthRule {
|
||||
id: row.get("id"),
|
||||
jwks_url: url::Url::parse(row.get("jwks_url"))?,
|
||||
audience: row.get("audience"),
|
||||
role_names: row
|
||||
.get::<_, Vec<String>>("role_names")
|
||||
.into_iter()
|
||||
.map(RoleName::from)
|
||||
.map(|s| RoleNameInt::from(&s))
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
|
||||
drop(client);
|
||||
connection.await??;
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
@@ -185,6 +220,14 @@ impl super::Api for Api {
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
|
||||
@@ -7,27 +7,33 @@ use super::{
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
auth::backend::{jwt::AuthRule, ComputeUserInfo},
|
||||
compute,
|
||||
console::messages::{ColdStartInfo, Reason},
|
||||
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey,
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use ::http::{header::AUTHORIZATION, HeaderName};
|
||||
use anyhow::bail;
|
||||
use futures::TryFutureExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
jwt: String,
|
||||
// put in a shared ref so we don't copy secrets all over in memory
|
||||
jwt: Arc<str>,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
@@ -38,7 +44,9 @@ impl Api {
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
@@ -71,9 +79,9 @@ impl Api {
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.get_path("proxy_get_role_secret")
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
@@ -125,6 +133,61 @@ impl Api {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &endpoint.normalize())
|
||||
.await
|
||||
{
|
||||
bail!("endpoint not found");
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_with_url(|url| {
|
||||
url.path_segments_mut()
|
||||
.push("endpoints")
|
||||
.push(endpoint.as_str())
|
||||
.push("jwks");
|
||||
})
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||
|
||||
let rules = body
|
||||
.jwks
|
||||
.into_iter()
|
||||
.map(|jwks| AuthRule {
|
||||
id: jwks.id,
|
||||
jwks_url: jwks.jwks_url,
|
||||
audience: jwks.jwt_audience,
|
||||
role_names: jwks.role_names,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(rules)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
@@ -135,7 +198,7 @@ impl Api {
|
||||
async {
|
||||
let mut request_builder = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.get_path("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
@@ -262,6 +325,15 @@ impl super::Api for Api {
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
|
||||
@@ -86,9 +86,17 @@ impl Endpoint {
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// appending a single `path` segment to the base endpoint URL.
|
||||
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
|
||||
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
|
||||
self.get_with_url(|u| {
|
||||
u.path_segments_mut().push(path);
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// accepting a closure to modify the url path segments for more complex paths queries.
|
||||
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push(path);
|
||||
f(&mut url);
|
||||
self.client.get(url.into_inner())
|
||||
}
|
||||
|
||||
@@ -144,7 +152,7 @@ mod tests {
|
||||
|
||||
// Validate that this pattern makes sense.
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.get_path("frobnicate")
|
||||
.query(&[
|
||||
("foo", Some("10")), // should be just `foo=10`
|
||||
("bar", None), // shouldn't be passed at all
|
||||
@@ -162,7 +170,7 @@ mod tests {
|
||||
let endpoint = Endpoint::new(url, Client::new());
|
||||
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.get_path("frobnicate")
|
||||
.query(&[("session_id", uuid::Uuid::nil())])
|
||||
.build()?;
|
||||
|
||||
|
||||
@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
|
||||
{
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
|
||||
Reference in New Issue
Block a user