mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-09 14:32:57 +00:00
implement local auth backend for proxy and remove control plane hacks
This commit is contained in:
@@ -6,7 +6,6 @@ use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use anyhow::Context;
|
||||
use anyhow::{bail, ensure};
|
||||
use arc_swap::ArcSwapOption;
|
||||
@@ -17,12 +16,13 @@ use remote_storage::RemoteStorageConfig;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info, warn};
|
||||
use tracing::{Instrument, error, info, warn, debug};
|
||||
use utils::sentry_init::init_sentry;
|
||||
use utils::{project_build_tag, project_git_version};
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
|
||||
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
|
||||
use crate::batch::BatchQueue;
|
||||
use crate::cancellation::{CancellationHandler, CancellationProcessor};
|
||||
use crate::config::{
|
||||
@@ -41,9 +41,17 @@ use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::serverless::rest::DbSchemaCache;
|
||||
use crate::tls::client_config::compute_client_config_with_root_certs;
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{auth, control_plane, http, serverless, usage_metrics};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Notify;
|
||||
use compute_api::spec::LocalProxySpec;
|
||||
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use crate::types::RoleName;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::ext::TaskExt;
|
||||
use std::str::FromStr;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
@@ -61,6 +69,9 @@ enum AuthBackendType {
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres,
|
||||
|
||||
#[clap(alias("local"))]
|
||||
Local,
|
||||
}
|
||||
|
||||
/// Neon proxy/router
|
||||
@@ -75,6 +86,9 @@ struct ProxyCliArgs {
|
||||
proxy: SocketAddr,
|
||||
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
|
||||
auth_backend: AuthBackendType,
|
||||
/// Path of the local proxy config file (used for local-file auth backend)
|
||||
#[clap(long, default_value = "./local_proxy.json")]
|
||||
config_path: Utf8PathBuf,
|
||||
/// listen for management callback connection on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:7000")]
|
||||
mgmt: SocketAddr,
|
||||
@@ -436,6 +450,22 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
// if auth backend is local, we need to load the config file
|
||||
if let auth::Backend::Local(_) = &auth_backend {
|
||||
// trigger the first config load **after** setting up the signal hook
|
||||
// to avoid the race condition where:
|
||||
// 1. No config file registered when local_proxy starts up
|
||||
// 2. The config file is written but the signal hook is not yet received
|
||||
// 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
|
||||
let refresh_config_notify = Arc::new(Notify::new());
|
||||
refresh_config_notify.notify_one();
|
||||
tokio::spawn(refresh_config_loop(
|
||||
config,
|
||||
args.config_path,
|
||||
refresh_config_notify,
|
||||
));
|
||||
}
|
||||
}
|
||||
Either::Right(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
@@ -876,6 +906,17 @@ fn build_auth_backend(
|
||||
let config = Box::leak(Box::new(backend));
|
||||
|
||||
Ok(Either::Right(config))
|
||||
},
|
||||
AuthBackendType::Local => {
|
||||
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
|
||||
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
|
||||
let auth_backend = crate::auth::Backend::Local(crate::auth::backend::MaybeOwned::Owned(
|
||||
LocalBackend::new(postgres, compute_ctl),
|
||||
));
|
||||
|
||||
let config = Box::leak(Box::new(auth_backend));
|
||||
|
||||
Ok(Either::Left(config))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -934,6 +975,136 @@ async fn configure_redis(
|
||||
Ok((regional_redis_client, redis_notifications_client))
|
||||
}
|
||||
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum RefreshConfigError {
|
||||
#[error(transparent)]
|
||||
Read(#[from] std::io::Error),
|
||||
#[error(transparent)]
|
||||
Parse(#[from] serde_json::Error),
|
||||
#[error(transparent)]
|
||||
Validate(anyhow::Error),
|
||||
#[error(transparent)]
|
||||
Tls(anyhow::Error),
|
||||
}
|
||||
|
||||
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
|
||||
let mut init = true;
|
||||
loop {
|
||||
rx.notified().await;
|
||||
|
||||
match refresh_config_inner(config, &path).await {
|
||||
Ok(()) => {}
|
||||
// don't log for file not found errors if this is the first time we are checking
|
||||
// for computes that don't use local_proxy, this is not an error.
|
||||
Err(RefreshConfigError::Read(e))
|
||||
if init && e.kind() == std::io::ErrorKind::NotFound =>
|
||||
{
|
||||
debug!(error=?e, ?path, "could not read config file");
|
||||
}
|
||||
Err(RefreshConfigError::Tls(e)) => {
|
||||
error!(error=?e, ?path, "could not read TLS certificates");
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error=?e, ?path, "could not read config file");
|
||||
}
|
||||
}
|
||||
|
||||
init = false;
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_config_inner(
|
||||
config: &ProxyConfig,
|
||||
path: &Utf8Path,
|
||||
) -> Result<(), RefreshConfigError> {
|
||||
let bytes = tokio::fs::read(&path).await?;
|
||||
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
|
||||
|
||||
let mut jwks_set = vec![];
|
||||
|
||||
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
|
||||
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
|
||||
|
||||
ensure!(
|
||||
jwks_url.has_authority()
|
||||
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
|
||||
"Invalid JWKS url. Must be HTTP",
|
||||
);
|
||||
|
||||
ensure!(
|
||||
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
|
||||
"Invalid JWKS url. No domain listed",
|
||||
);
|
||||
|
||||
// clear username, password and ports
|
||||
jwks_url
|
||||
.set_username("")
|
||||
.expect("url can be a base and has a valid host and is not a file. should not error");
|
||||
jwks_url
|
||||
.set_password(None)
|
||||
.expect("url can be a base and has a valid host and is not a file. should not error");
|
||||
// local testing is hard if we need to have a specific restricted port
|
||||
if cfg!(not(feature = "testing")) {
|
||||
jwks_url.set_port(None).expect(
|
||||
"url can be a base and has a valid host and is not a file. should not error",
|
||||
);
|
||||
}
|
||||
|
||||
// clear query params
|
||||
jwks_url.set_fragment(None);
|
||||
jwks_url.query_pairs_mut().clear().finish();
|
||||
|
||||
if jwks_url.scheme() != "https" {
|
||||
// local testing is hard if we need to set up https support.
|
||||
if cfg!(not(feature = "testing")) {
|
||||
jwks_url
|
||||
.set_scheme("https")
|
||||
.expect("should not error to set the scheme to https if it was http");
|
||||
} else {
|
||||
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(JwksSettings {
|
||||
id: jwks.id,
|
||||
jwks_url,
|
||||
_provider_name: jwks.provider_name,
|
||||
jwt_audience: jwks.jwt_audience,
|
||||
role_names: jwks
|
||||
.role_names
|
||||
.into_iter()
|
||||
.map(RoleName::from)
|
||||
.map(|s| RoleNameInt::from(&s))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
for jwks in data.jwks.into_iter().flatten() {
|
||||
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
|
||||
}
|
||||
|
||||
info!("successfully loaded new config");
|
||||
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
|
||||
|
||||
if let Some(tls_config) = data.tls {
|
||||
let tls_config = tokio::task::spawn_blocking(move || {
|
||||
crate::tls::server_config::configure_tls(
|
||||
tls_config.key_path.as_ref(),
|
||||
tls_config.cert_path.as_ref(),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.propagate_task_panic()
|
||||
.map_err(RefreshConfigError::Tls)?;
|
||||
config.tls_config.store(Some(Arc::new(tls_config)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -389,21 +389,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
if true {
|
||||
return Ok(vec![AuthRule {
|
||||
id: "1".into(),
|
||||
jwks_url: "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json"
|
||||
.parse()
|
||||
.expect("url is valid"),
|
||||
audience: None,
|
||||
role_names: vec![
|
||||
(&RoleName::from("authenticator")).into(),
|
||||
(&RoleName::from("authenticated")).into(),
|
||||
(&RoleName::from("anon")).into(),
|
||||
],
|
||||
}]);
|
||||
}
|
||||
|
||||
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
|
||||
@@ -413,24 +398,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
if true {
|
||||
return Ok(CachedNodeInfo::new_uncached(NodeInfo {
|
||||
conn_info: ConnectInfo {
|
||||
host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
host: "localhost".into(),
|
||||
port: 7432,
|
||||
ssl_mode: SslMode::Disable,
|
||||
},
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: EndpointId::from("foo").into(),
|
||||
project_id: ProjectId::from("foo").into(),
|
||||
branch_id: BranchId::from("foo").into(),
|
||||
compute_id: "foo".into(),
|
||||
cold_start_info: ColdStartInfo::Warm,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
macro_rules! check_cache {
|
||||
|
||||
@@ -652,7 +652,7 @@ async fn handle_inner(
|
||||
// we always use the authenticator role to connect to the database
|
||||
let autheticator_role = "authenticator";
|
||||
let connection_string = format!(
|
||||
"postgresql://{}@{}.local.neon.build/database",
|
||||
"postgresql://{}@{}.local.neon.build/database", //FIXME: how do we get the database name knowing only the endpoint id?
|
||||
autheticator_role, endpoint_id
|
||||
);
|
||||
|
||||
@@ -684,7 +684,7 @@ async fn handle_inner(
|
||||
.await
|
||||
}
|
||||
_ => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
|
||||
Credentials::Password,
|
||||
Credentials::BearerJwt,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user