diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 7592d076ec..ae72bc6de3 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -2,6 +2,7 @@ mod classic; mod hacks; pub mod jwt; mod link; +pub mod local; use std::net::IpAddr; use std::sync::Arc; @@ -9,6 +10,7 @@ use std::time::Duration; use ipnet::{Ipv4Net, Ipv6Net}; pub use link::LinkAuthError; +use local::LocalBackend; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; use tracing::{info, warn}; @@ -68,6 +70,8 @@ pub enum BackendType<'a, T, D> { Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. Link(MaybeOwned<'a, url::ApiUrl>, D), + /// Local proxy uses configured auth credentials and does not wake compute + Local(MaybeOwned<'a, LocalBackend>), } pub trait TestBackend: Send + Sync + 'static { @@ -93,6 +97,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> { ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Self::Local(_) => fmt.debug_tuple("Local").finish(), } } } @@ -104,6 +109,7 @@ impl BackendType<'_, T, D> { match self { Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x), Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x), + Self::Local(l) => BackendType::Local(MaybeOwned::Borrowed(l)), } } } @@ -116,6 +122,7 @@ impl<'a, T, D> BackendType<'a, T, D> { match self { Self::Console(c, x) => BackendType::Console(c, f(x)), Self::Link(c, x) => BackendType::Link(c, x), + Self::Local(l) => BackendType::Local(l), } } } @@ -126,6 +133,7 @@ impl<'a, T, D, E> BackendType<'a, Result, D> { match self { Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)), Self::Link(c, x) => Ok(BackendType::Link(c, x)), + Self::Local(l) => Ok(BackendType::Local(l)), } } } @@ -157,6 +165,7 @@ impl ComputeUserInfo { pub enum ComputeCredentialKeys { Password(Vec), AuthKeys(AuthKeys), + None, } impl TryFrom for ComputeUserInfo { @@ -289,7 +298,7 @@ async fn auth_quirks( ctx.set_endpoint_id(res.info.endpoint.clone()); let password = match res.keys { ComputeCredentialKeys::Password(p) => p, - ComputeCredentialKeys::AuthKeys(_) => { + ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => { unreachable!("password hack should return a password") } }; @@ -401,6 +410,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { match self { Self::Console(_, user_info) => user_info.endpoint_id.clone(), Self::Link(_, _) => Some("link".into()), + Self::Local(_) => Some("local".into()), } } @@ -409,6 +419,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { match self { Self::Console(_, user_info) => &user_info.user, Self::Link(_, _) => "link", + Self::Local(_) => "local", } } @@ -450,6 +461,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { BackendType::Link(url, info) } + Self::Local(_) => { + return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) + } }; info!("user successfully authenticated"); @@ -465,6 +479,7 @@ impl BackendType<'_, ComputeUserInfo, &()> { match self { Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await, Self::Link(_, _) => Ok(Cached::new_uncached(None)), + Self::Local(_) => Ok(Cached::new_uncached(None)), } } @@ -475,6 +490,7 @@ impl BackendType<'_, ComputeUserInfo, &()> { match self { Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), } } } @@ -488,13 +504,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { match self { Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())), + Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + fn get_keys(&self) -> &ComputeCredentialKeys { match self { - Self::Console(_, creds) => Some(&creds.keys), - Self::Link(_, _) => None, + Self::Console(_, creds) => &creds.keys, + Self::Link(_, _) => &ComputeCredentialKeys::None, + Self::Local(_) => &ComputeCredentialKeys::None, } } } @@ -508,13 +526,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { match self { Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + fn get_keys(&self) -> &ComputeCredentialKeys { match self { - Self::Console(_, creds) => Some(&creds.keys), - Self::Link(_, _) => None, + Self::Console(_, creds) => &creds.keys, + Self::Link(_, _) => &ComputeCredentialKeys::None, + Self::Local(_) => &ComputeCredentialKeys::None, } } } diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs new file mode 100644 index 0000000000..6d18564dd6 --- /dev/null +++ b/proxy/src/auth/backend/local.rs @@ -0,0 +1,79 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use anyhow::Context; +use arc_swap::ArcSwapOption; + +use crate::{ + compute::ConnCfg, + console::{ + messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}, + NodeInfo, + }, + intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag}, + RoleName, +}; + +use super::jwt::{AuthRule, FetchAuthRules, JwkCache}; + +pub struct LocalBackend { + pub jwks_cache: JwkCache, + pub postgres_addr: SocketAddr, + pub node_info: NodeInfo, +} + +impl LocalBackend { + pub fn new(postgres_addr: SocketAddr) -> Self { + LocalBackend { + jwks_cache: JwkCache::default(), + postgres_addr, + node_info: NodeInfo { + config: { + let mut cfg = ConnCfg::new(); + cfg.host(&postgres_addr.ip().to_string()); + cfg.port(postgres_addr.port()); + cfg + }, + // TODO(conrad): make this better reflect compute info rather than endpoint info. + aux: MetricsAuxInfo { + endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), + project_id: ProjectIdTag::get_interner().get_or_intern("local"), + branch_id: BranchIdTag::get_interner().get_or_intern("local"), + cold_start_info: ColdStartInfo::WarmCached, + }, + allow_self_signed_compute: false, + }, + } + } +} + +#[derive(Clone, Copy)] +pub struct StaticAuthRules; + +pub static JWKS_ROLE_MAP: ArcSwapOption = ArcSwapOption::const_empty(); + +#[derive(Debug, Clone)] +pub struct JwksRoleSettings { + pub roles: HashMap, + pub project_id: ProjectIdInt, + pub branch_id: BranchIdInt, +} + +impl FetchAuthRules for StaticAuthRules { + async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result> { + let mappings = JWKS_ROLE_MAP.load(); + let role_mappings = mappings + .as_deref() + .and_then(|m| m.roles.get(&role_name)) + .context("JWKs settings for this role were not configured")?; + let mut rules = vec![]; + for setting in &role_mappings.jwks { + rules.push(AuthRule { + id: setting.id.clone(), + jwks_url: setting.jwks_url.clone(), + audience: setting.jwt_audience.clone(), + }); + } + + Ok(rules) + } +} diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index ac66e116d0..a7ccf076b0 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,11 +1,13 @@ use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fmt::{self, Display}; use crate::auth::IpPattern; use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}; use crate::proxy::retry::CouldRetry; +use crate::RoleName; /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. @@ -341,6 +343,26 @@ impl ColdStartInfo { } } +#[derive(Debug, Deserialize, Clone)] +pub struct JwksRoleMapping { + pub roles: HashMap, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct EndpointJwksResponse { + pub jwks: Vec, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct JwksSettings { + pub id: String, + pub project_id: ProjectIdInt, + pub branch_id: BranchIdInt, + pub jwks_url: url::Url, + pub provider_name: String, + pub jwt_audience: Option, +} + #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index cc2ee10062..4794527410 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -305,6 +305,7 @@ impl NodeInfo { match keys { ComputeCredentialKeys::Password(password) => self.config.password(password), ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + ComputeCredentialKeys::None => &mut self.config, }; } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index f38e43ba5a..e1a54a9c98 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -61,7 +61,7 @@ pub trait ComputeConnectBackend { ctx: &RequestMonitoring, ) -> Result; - fn get_keys(&self) -> Option<&ComputeCredentialKeys>; + fn get_keys(&self) -> &ComputeCredentialKeys; } pub struct TcpMechanism<'a> { @@ -112,9 +112,8 @@ where let mut num_retries = 0; let mut node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - if let Some(keys) = user_info.get_keys() { - node_info.set_keys(keys); - } + + node_info.set_keys(user_info.get_keys()); node_info.allow_self_signed_compute = allow_self_signed_compute; // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index b2bf93dc6d..ea65867293 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -407,7 +407,7 @@ async fn request_handler( .header("Access-Control-Allow-Origin", "*") .header( "Access-Control-Allow-Headers", - "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", + "Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 295ea1a1c7..b44ecb76e3 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, + auth::{ + backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo}, + check_peer_addr_is_in_list, AuthError, + }, compute, config::{AuthenticationConfig, ProxyConfig}, console::{ @@ -24,7 +27,7 @@ use crate::{ Host, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; +use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool}; pub struct PoolingBackend { pub pool: Arc>, @@ -33,13 +36,14 @@ pub struct PoolingBackend { } impl PoolingBackend { - pub async fn authenticate( + pub async fn authenticate_with_password( &self, ctx: &RequestMonitoring, config: &AuthenticationConfig, - conn_info: &ConnInfo, + user_info: &ComputeUserInfo, + password: &[u8], ) -> Result { - let user_info = conn_info.user_info.clone(); + let user_info = user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { @@ -47,7 +51,7 @@ impl PoolingBackend { } if !self .endpoint_rate_limiter - .check(conn_info.user_info.endpoint.clone().into(), 1) + .check(user_info.endpoint.clone().into(), 1) { return Err(AuthError::too_many_connections()); } @@ -70,14 +74,10 @@ impl PoolingBackend { return Err(AuthError::auth_failed(&*user_info.user)); } }; - let ep = EndpointIdInt::from(&conn_info.user_info.endpoint); - let auth_outcome = crate::auth::validate_password_and_exchange( - &config.thread_pool, - ep, - &conn_info.password, - secret, - ) - .await?; + let ep = EndpointIdInt::from(&user_info.endpoint); + let auth_outcome = + crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -85,7 +85,7 @@ impl PoolingBackend { } crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - Err(AuthError::auth_failed(&*conn_info.user_info.user)) + Err(AuthError::auth_failed(&*user_info.user)) } }; res.map(|key| ComputeCredentials { @@ -94,6 +94,39 @@ impl PoolingBackend { }) } + pub async fn authenticate_with_jwt( + &self, + ctx: &RequestMonitoring, + user_info: &ComputeUserInfo, + jwt: &str, + ) -> Result { + match &self.config.auth_backend { + crate::auth::BackendType::Console(_, _) => { + Err(AuthError::auth_failed("JWT login is not yet supported")) + } + crate::auth::BackendType::Link(_, _) => Err(AuthError::auth_failed( + "JWT login over link proxy is not supported", + )), + crate::auth::BackendType::Local(cache) => { + cache + .jwks_cache + .check_jwt( + ctx, + user_info.endpoint.clone(), + user_info.user.clone(), + &StaticAuthRules, + jwt, + ) + .await + .map_err(|e| AuthError::auth_failed(e.to_string()))?; + Ok(ComputeCredentials { + info: user_info.clone(), + keys: crate::auth::backend::ComputeCredentialKeys::None, + }) + } + } + } + // Wake up the destination if needed. Code here is a bit involved because // we reuse the code from the usual proxy and we need to prepare few structures // that this code expects. @@ -232,10 +265,16 @@ impl ConnectMechanism for TokioMechanism { let mut config = (*node_info.config).clone(); let config = config .user(&self.conn_info.user_info.user) - .password(&*self.conn_info.password) .dbname(&self.conn_info.dbname) .connect_timeout(timeout); + match &self.conn_info.auth { + AuthData::Jwt(_) => {} + AuthData::Password(pw) => { + config.password(pw); + } + } + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let res = config.connect(tokio_postgres::NoTls).await; drop(pause); diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 3478787995..6ed694af58 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -33,7 +33,13 @@ use super::backend::HttpConnError; pub struct ConnInfo { pub user_info: ComputeUserInfo, pub dbname: DbName, - pub password: SmallVec<[u8; 16]>, + pub auth: AuthData, +} + +#[derive(Debug, Clone)] +pub enum AuthData { + Password(SmallVec<[u8; 16]>), + Jwt(String), } impl ConnInfo { @@ -778,7 +784,7 @@ mod tests { options: Default::default(), }, dbname: "dbname".into(), - password: "password".as_bytes().into(), + auth: AuthData::Password("password".as_bytes().into()), }; let ep_pool = Arc::downgrade( &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()), @@ -836,7 +842,7 @@ mod tests { options: Default::default(), }, dbname: "dbname".into(), - password: "password".as_bytes().into(), + auth: AuthData::Password("password".as_bytes().into()), }; let ep_pool = Arc::downgrade( &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()), diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index bbfed90f39..79baef45f6 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -7,6 +7,7 @@ use futures::future::try_join; use futures::future::Either; use futures::StreamExt; use futures::TryFutureExt; +use http::header::AUTHORIZATION; use http_body_util::BodyExt; use http_body_util::Full; use hyper1::body::Body; @@ -56,6 +57,7 @@ use crate::DbName; use crate::RoleName; use super::backend::PoolingBackend; +use super::conn_pool::AuthData; use super::conn_pool::Client; use super::conn_pool::ConnInfo; use super::http_util::json_response; @@ -88,6 +90,7 @@ enum Payload { const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB +static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); @@ -109,7 +112,7 @@ where #[derive(Debug, thiserror::Error)] pub enum ConnInfoError { #[error("invalid header: {0}")] - InvalidHeader(&'static str), + InvalidHeader(&'static HeaderName), #[error("invalid connection string: {0}")] UrlParseError(#[from] url::ParseError), #[error("incorrect scheme")] @@ -153,10 +156,10 @@ fn get_conn_info( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let connection_string = headers - .get("Neon-Connection-String") - .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? + .get(&CONN_STRING) + .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? .to_str() - .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?; + .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; let connection_url = Url::parse(connection_string)?; @@ -179,10 +182,23 @@ fn get_conn_info( } ctx.set_user(username.clone()); - let password = connection_url - .password() - .ok_or(ConnInfoError::MissingPassword)?; - let password = urlencoding::decode_binary(password.as_bytes()); + let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { + let auth = auth + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; + AuthData::Jwt( + auth.strip_prefix("Bearer ") + .ok_or(ConnInfoError::MissingPassword)? + .into(), + ) + } else if let Some(pass) = connection_url.password() { + AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { + std::borrow::Cow::Borrowed(b) => b.into(), + std::borrow::Cow::Owned(b) => b.into(), + }) + } else { + return Err(ConnInfoError::MissingPassword); + }; let endpoint = match connection_url.host() { Some(url::Host::Domain(hostname)) => { @@ -225,10 +241,7 @@ fn get_conn_info( Ok(ConnInfo { user_info, dbname, - password: match password { - std::borrow::Cow::Borrowed(b) => b.into(), - std::borrow::Cow::Owned(b) => b.into(), - }, + auth, }) } @@ -550,9 +563,24 @@ async fn handle_inner( let authenticate_and_connect = Box::pin( async { - let keys = backend - .authenticate(ctx, &config.authentication_config, &conn_info) - .await?; + let keys = match &conn_info.auth { + AuthData::Password(pw) => { + backend + .authenticate_with_password( + ctx, + &config.authentication_config, + &conn_info.user_info, + pw, + ) + .await? + } + AuthData::Jwt(jwt) => { + backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await? + } + }; + let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?;