implement local auth backend for proxy and remove control plane hacks

This commit is contained in:
Ruslan Talpa
2025-06-30 16:00:43 +03:00
parent 9480d17de7
commit 7e3f64b309
3 changed files with 176 additions and 38 deletions

View File

@@ -6,7 +6,6 @@ use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
@@ -17,12 +16,13 @@ use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
use tracing::{Instrument, error, info, warn, debug};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
use crate::config::{
@@ -41,9 +41,17 @@ use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::rest::DbSchemaCache;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
use camino::{Utf8Path, Utf8PathBuf};
use thiserror::Error;
use tokio::sync::Notify;
use compute_api::spec::LocalProxySpec;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::types::RoleName;
use crate::intern::RoleNameInt;
use crate::ext::TaskExt;
use std::str::FromStr;
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -61,6 +69,9 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[clap(alias("local"))]
Local,
}
/// Neon proxy/router
@@ -75,6 +86,9 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -436,6 +450,22 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
if let auth::Backend::Local(_) = &auth_backend {
// trigger the first config load **after** setting up the signal hook
// to avoid the race condition where:
// 1. No config file registered when local_proxy starts up
// 2. The config file is written but the signal hook is not yet received
// 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
let refresh_config_notify = Arc::new(Notify::new());
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify,
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -876,6 +906,17 @@ fn build_auth_backend(
let config = Box::leak(Box::new(backend));
Ok(Either::Right(config))
},
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(crate::auth::backend::MaybeOwned::Owned(
LocalBackend::new(postgres, compute_ctl),
));
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}
@@ -934,6 +975,136 @@ async fn configure_redis(
Ok((regional_redis_client, redis_notifications_client))
}
#[derive(Error, Debug)]
enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use std::time::Duration;

View File

@@ -389,21 +389,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
ctx: &RequestContext,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
if true {
return Ok(vec![AuthRule {
id: "1".into(),
jwks_url: "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json"
.parse()
.expect("url is valid"),
audience: None,
role_names: vec![
(&RoleName::from("authenticator")).into(),
(&RoleName::from("authenticated")).into(),
(&RoleName::from("anon")).into(),
],
}]);
}
self.do_get_endpoint_jwks(ctx, endpoint).await
}
@@ -413,24 +398,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
if true {
return Ok(CachedNodeInfo::new_uncached(NodeInfo {
conn_info: ConnectInfo {
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
host: "localhost".into(),
port: 7432,
ssl_mode: SslMode::Disable,
},
aux: MetricsAuxInfo {
endpoint_id: EndpointId::from("foo").into(),
project_id: ProjectId::from("foo").into(),
branch_id: BranchId::from("foo").into(),
compute_id: "foo".into(),
cold_start_info: ColdStartInfo::Warm,
},
}));
}
let key = user_info.endpoint_cache_key();
macro_rules! check_cache {

View File

@@ -652,7 +652,7 @@ async fn handle_inner(
// we always use the authenticator role to connect to the database
let autheticator_role = "authenticator";
let connection_string = format!(
"postgresql://{}@{}.local.neon.build/database",
"postgresql://{}@{}.local.neon.build/database", //FIXME: how do we get the database name knowing only the endpoint id?
autheticator_role, endpoint_id
);
@@ -684,7 +684,7 @@ async fn handle_inner(
.await
}
_ => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
Credentials::Password,
Credentials::BearerJwt,
))),
}
}