diff --git a/Cargo.lock b/Cargo.lock index bdf2b08c5c..1f6e3a6083 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4320,6 +4320,7 @@ dependencies = [ "hyper-util", "ipnet", "itertools", + "jsonwebtoken", "lasso", "md5", "metrics", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 12bd67ea36..dc79e9e47c 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -97,6 +97,7 @@ native-tls.workspace = true postgres-native-tls.workspace = true postgres-protocol.workspace = true redis.workspace = true +jsonwebtoken.workspace = true workspace_hack.workspace = true diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 8c44823c98..464d9ca108 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -13,6 +13,8 @@ mod password_hack; pub use password_hack::parse_endpoint_param; use password_hack::PasswordHackPayload; +pub mod caps; + mod flow; pub use flow::*; use tokio::time::error::Elapsed; @@ -71,6 +73,9 @@ pub enum AuthErrorImpl { #[error("Too many connections to this endpoint. Please try again later.")] TooManyConnections, + #[error("neon_caps token is invalid")] + CapsInvalid, + #[error("Authentication timed out")] UserTimeout(Elapsed), } @@ -96,6 +101,10 @@ impl AuthError { AuthErrorImpl::TooManyConnections.into() } + pub fn caps_invalid() -> Self { + AuthErrorImpl::CapsInvalid.into() + } + pub fn is_auth_failed(&self) -> bool { matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_)) } @@ -126,6 +135,7 @@ impl UserFacingError for AuthError { IpAddressNotAllowed(_) => self.to_string(), TooManyConnections => self.to_string(), UserTimeout(_) => self.to_string(), + CapsInvalid => self.to_string(), } } } @@ -145,6 +155,7 @@ impl ReportableError for AuthError { IpAddressNotAllowed(_) => crate::error::ErrorKind::User, TooManyConnections => crate::error::ErrorKind::RateLimit, UserTimeout(_) => crate::error::ErrorKind::User, + CapsInvalid => crate::error::ErrorKind::User, } } } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index e421798067..adcbf668d0 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -28,6 +28,7 @@ use crate::{ stream, url, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; +use std::net::IpAddr; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -251,11 +252,13 @@ async fn auth_quirks( Ok(info) => (info, None), }; + let bypass_ipcheck = apply_caps(&config, &info, &ctx.peer_addr)?; + info!("fetching user's authentication info"); let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?; // check allowed list - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { + if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr)); } let cached_secret = match maybe_secret { @@ -537,6 +540,7 @@ mod tests { scram_protocol_timeout: std::time::Duration::from_secs(5), rate_limiter_enabled: true, rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), + caps: None, }); async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { @@ -695,3 +699,43 @@ mod tests { handle.await.unwrap(); } } + +// It checks that provided JWT capabilities are valid for the connection +// +// if it returns Ok(true), futher peer IP checks has to be disabled +// +// If proxy isn't configured for JWT capabilities or neon_caps option +// isn't set, it skips any checks +pub fn apply_caps( + config: &AuthenticationConfig, + info: &ComputeUserInfo, + peer_addr: &IpAddr, +) -> auth::Result { + match (&config.caps, info.options.caps()) { + (Some(caps_config), Some(caps)) => { + let token = match caps_config.decode(&caps) { + Err(_) => { + return Err(auth::AuthError::caps_invalid()); + } + Ok(token) => token, + }; + + if token.claims.endpoint_id != *info.endpoint { + return Err(auth::AuthError::caps_invalid()); + } + + match token.claims.check_ip(peer_addr) { + None => return Ok(false), + Some(true) => { + return Ok(true); + } + Some(false) => { + return Err(auth::AuthError::ip_address_not_allowed(*peer_addr)); + } + } + } + _ => { + return Ok(false); + } + } +} diff --git a/proxy/src/auth/caps.rs b/proxy/src/auth/caps.rs new file mode 100644 index 0000000000..1eff2d532f --- /dev/null +++ b/proxy/src/auth/caps.rs @@ -0,0 +1,96 @@ +use std::{borrow::Cow, fmt::Display, fs, net::IpAddr}; + +use anyhow::Result; +use camino::Utf8Path; +use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; +use serde::{Deserialize, Serialize}; +use utils::http::error::ApiError; + +use super::{check_peer_addr_is_in_list, IpPattern}; + +const TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA; + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Scope { + Connection, +} + +#[derive(Debug, Deserialize, Clone, PartialEq)] +pub struct Claims { + pub scope: Scope, + pub allowed_ips: Option>, + pub endpoint_id: String, +} + +impl Claims { + pub fn check_ip(&self, ip: &IpAddr) -> Option { + let allowed_ips = match &self.allowed_ips { + None => return None, + Some(allowed_ips) => allowed_ips, + }; + if allowed_ips.is_empty() { + return Some(true); + } + + return Some(check_peer_addr_is_in_list(ip, &allowed_ips)); + } +} + +pub struct CapsValidator { + decoding_key: DecodingKey, + validation: Validation, +} + +impl CapsValidator { + pub fn new(decoding_key: DecodingKey) -> Self { + let mut validation = Validation::default(); + validation.algorithms = vec![TOKEN_ALGORITHM]; + Self { + decoding_key, + validation, + } + } + + pub fn from_key_path(key_path: &Utf8Path) -> Result { + let metadata = key_path.metadata()?; + let decoding_key = if metadata.is_file() { + let public_key = fs::read(key_path)?; + DecodingKey::from_ed_pem(&public_key)? + } else { + anyhow::bail!("path isn't a file") + }; + + Ok(Self::new(decoding_key)) + } + + pub fn decode(&self, token: &str) -> std::result::Result, CapsError> { + return match decode(token, &self.decoding_key, &self.validation) { + Ok(res) => Ok(res), + Err(e) => Err(CapsError(Cow::Owned(e.to_string()))), + }; + } +} + +impl std::fmt::Debug for CapsValidator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CapsValidator") + .field("validation", &self.validation) + .finish() + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct CapsError(pub Cow<'static, str>); + +impl Display for CapsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for ApiError { + fn from(_value: CapsError) -> Self { + ApiError::Forbidden("neon_caps validation error".to_string()) + } +} diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 56a3ef79cd..6716c2bf4c 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -5,9 +5,11 @@ use aws_config::meta::region::RegionProviderChain; use aws_config::profile::ProfileFileCredentialsProvider; use aws_config::provider_config::ProviderConfig; use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; +use camino::Utf8Path; use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; +use proxy::auth::caps::CapsValidator; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; use proxy::config::remote_storage_from_toml; @@ -193,6 +195,9 @@ struct ProxyCliArgs { #[clap(flatten)] parquet_upload: ParquetUploadArgs, + #[clap(long)] + caps_key: Option, + /// interval for backup metric collection #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)] metric_backup_collection_interval: std::time::Duration, @@ -542,10 +547,20 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, }, }; + + let caps; + if let Some(key) = &args.caps_key { + let path = Utf8Path::new(key); + caps = Some(CapsValidator::from_key_path(path)?); + } else { + caps = None; + } + let authentication_config = AuthenticationConfig { scram_protocol_timeout: args.scram_protocol_timeout, rate_limiter_enabled: args.auth_rate_limit_enabled, rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + caps, }; let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); diff --git a/proxy/src/config.rs b/proxy/src/config.rs index fc490c7348..9ec9241ffa 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,5 +1,5 @@ use crate::{ - auth, + auth::{self, caps::CapsValidator}, rate_limiter::{AuthRateLimiter, RateBucketInfo}, serverless::GlobalConnPoolOptions, }; @@ -58,6 +58,7 @@ pub struct AuthenticationConfig { pub scram_protocol_timeout: tokio::time::Duration, pub rate_limiter_enabled: bool, pub rate_limiter: AuthRateLimiter, + pub caps: Option, } impl TlsConfig { diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 6051c0a812..d59ad1b673 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -385,6 +385,10 @@ impl NeonOptions { !self.0.is_empty() } + pub fn caps(&self) -> Option<&str> { + self.0.iter().find(|(k, _)| k == "caps").map(|(_, v)| &**v) + } + fn parse_from_iter<'a>(options: impl Iterator) -> Self { let mut options = options .filter_map(neon_option) @@ -398,7 +402,13 @@ impl NeonOptions { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable std::iter::once(prefix) - .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) + // exclude caps from cache key + .chain( + self.0 + .iter() + .filter(|(k, _)| k != "caps") + .flat_map(|(k, v)| [" ", &**k, ":", &**v]), + ) .collect::() .into() } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8aa5ad4e8a..ada62bbcec 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, + auth::{ + backend::{apply_caps, ComputeCredentials}, + check_peer_addr_is_in_list, AuthError, + }, compute, config::ProxyConfig, console::{ @@ -31,8 +34,15 @@ impl PoolingBackend { ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); + + let bypass_ipcheck = apply_caps( + &&self.config.authentication_config, + &user_info, + &ctx.peer_addr, + )?; + let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { + if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr)); } let cached_secret = match maybe_secret {