Compare commits

...

2 Commits

Author SHA1 Message Date
Andrey Rudenko
cf9d801117 expressions rule! 2024-04-11 14:42:32 +02:00
Andrey Rudenko
f86d98f44f proxy: add neon_caps support 2024-04-11 14:09:01 +02:00
9 changed files with 193 additions and 5 deletions

1
Cargo.lock generated
View File

@@ -4320,6 +4320,7 @@ dependencies = [
"hyper-util",
"ipnet",
"itertools",
"jsonwebtoken",
"lasso",
"md5",
"metrics",

View File

@@ -97,6 +97,7 @@ native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
jsonwebtoken.workspace = true
workspace_hack.workspace = true

View File

@@ -13,6 +13,8 @@ mod password_hack;
pub use password_hack::parse_endpoint_param;
use password_hack::PasswordHackPayload;
pub mod caps;
mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;
@@ -71,6 +73,9 @@ pub enum AuthErrorImpl {
#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,
#[error("neon_caps token is invalid")]
CapsInvalid,
#[error("Authentication timed out")]
UserTimeout(Elapsed),
}
@@ -96,6 +101,10 @@ impl AuthError {
AuthErrorImpl::TooManyConnections.into()
}
pub fn caps_invalid() -> Self {
AuthErrorImpl::CapsInvalid.into()
}
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}
@@ -126,6 +135,7 @@ impl UserFacingError for AuthError {
IpAddressNotAllowed(_) => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
CapsInvalid => self.to_string(),
}
}
}
@@ -145,6 +155,7 @@ impl ReportableError for AuthError {
IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
CapsInvalid => crate::error::ErrorKind::User,
}
}
}

View File

@@ -28,6 +28,7 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -251,11 +252,13 @@ async fn auth_quirks(
Ok(info) => (info, None),
};
let bypass_ipcheck = apply_caps(&config, &info, &ctx.peer_addr)?;
info!("fetching user's authentication info");
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
}
let cached_secret = match maybe_secret {
@@ -537,6 +540,7 @@ mod tests {
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
caps: None,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
@@ -695,3 +699,43 @@ mod tests {
handle.await.unwrap();
}
}
// It checks that provided JWT capabilities are valid for the connection
//
// if it returns Ok(true), futher peer IP checks has to be disabled
//
// If proxy isn't configured for JWT capabilities or neon_caps option
// isn't set, it skips any checks
pub fn apply_caps(
config: &AuthenticationConfig,
info: &ComputeUserInfo,
peer_addr: &IpAddr,
) -> auth::Result<bool> {
match (&config.caps, info.options.caps()) {
(Some(caps_config), Some(caps)) => {
let token = match caps_config.decode(&caps) {
Err(_) => {
return Err(auth::AuthError::caps_invalid());
}
Ok(token) => token,
};
if token.claims.endpoint_id != *info.endpoint {
return Err(auth::AuthError::caps_invalid());
}
match token.claims.check_ip(peer_addr) {
None => return Ok(false),
Some(true) => {
return Ok(true);
}
Some(false) => {
return Err(auth::AuthError::ip_address_not_allowed(*peer_addr));
}
}
}
_ => {
return Ok(false);
}
}
}

96
proxy/src/auth/caps.rs Normal file
View File

@@ -0,0 +1,96 @@
use std::{borrow::Cow, fmt::Display, fs, net::IpAddr};
use anyhow::Result;
use camino::Utf8Path;
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
use serde::{Deserialize, Serialize};
use utils::http::error::ApiError;
use super::{check_peer_addr_is_in_list, IpPattern};
const TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
Connection,
}
#[derive(Debug, Deserialize, Clone, PartialEq)]
pub struct Claims {
pub scope: Scope,
pub allowed_ips: Option<Vec<IpPattern>>,
pub endpoint_id: String,
}
impl Claims {
pub fn check_ip(&self, ip: &IpAddr) -> Option<bool> {
let allowed_ips = match &self.allowed_ips {
None => return None,
Some(allowed_ips) => allowed_ips,
};
if allowed_ips.is_empty() {
return Some(true);
}
return Some(check_peer_addr_is_in_list(ip, &allowed_ips));
}
}
pub struct CapsValidator {
decoding_key: DecodingKey,
validation: Validation,
}
impl CapsValidator {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::default();
validation.algorithms = vec![TOKEN_ALGORITHM];
Self {
decoding_key,
validation,
}
}
pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
let metadata = key_path.metadata()?;
let decoding_key = if metadata.is_file() {
let public_key = fs::read(key_path)?;
DecodingKey::from_ed_pem(&public_key)?
} else {
anyhow::bail!("path isn't a file")
};
Ok(Self::new(decoding_key))
}
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, CapsError> {
return match decode(token, &self.decoding_key, &self.validation) {
Ok(res) => Ok(res),
Err(e) => Err(CapsError(Cow::Owned(e.to_string()))),
};
}
}
impl std::fmt::Debug for CapsValidator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CapsValidator")
.field("validation", &self.validation)
.finish()
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct CapsError(pub Cow<'static, str>);
impl Display for CapsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<CapsError> for ApiError {
fn from(_value: CapsError) -> Self {
ApiError::Forbidden("neon_caps validation error".to_string())
}
}

View File

@@ -5,9 +5,11 @@ use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use camino::Utf8Path;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::auth::caps::CapsValidator;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::remote_storage_from_toml;
@@ -193,6 +195,9 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
#[clap(long)]
caps_key: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -542,10 +547,19 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
};
let caps = if let Some(key) = &args.caps_key {
let path = Utf8Path::new(key);
Some(CapsValidator::from_key_path(path)?);
} else {
None;
};
let authentication_config = AuthenticationConfig {
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
caps,
};
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();

View File

@@ -1,5 +1,5 @@
use crate::{
auth,
auth::{self, caps::CapsValidator},
rate_limiter::{AuthRateLimiter, RateBucketInfo},
serverless::GlobalConnPoolOptions,
};
@@ -58,6 +58,7 @@ pub struct AuthenticationConfig {
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
pub caps: Option<CapsValidator>,
}
impl TlsConfig {

View File

@@ -385,6 +385,10 @@ impl NeonOptions {
!self.0.is_empty()
}
pub fn caps(&self) -> Option<&str> {
self.0.iter().find(|(k, _)| k == "caps").map(|(_, v)| &**v)
}
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
let mut options = options
.filter_map(neon_option)
@@ -398,7 +402,13 @@ impl NeonOptions {
// prefix + format!(" {k}:{v}")
// kinda jank because SmolStr is immutable
std::iter::once(prefix)
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
// exclude caps from cache key
.chain(
self.0
.iter()
.filter(|(k, _)| k != "caps")
.flat_map(|(k, v)| [" ", &**k, ":", &**v]),
)
.collect::<SmolStr>()
.into()
}

View File

@@ -4,7 +4,10 @@ use async_trait::async_trait;
use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
auth::{
backend::{apply_caps, ComputeCredentials},
check_peer_addr_is_in_list, AuthError,
},
compute,
config::ProxyConfig,
console::{
@@ -31,8 +34,15 @@ impl PoolingBackend {
) -> Result<ComputeCredentials, AuthError> {
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let bypass_ipcheck = apply_caps(
&&self.config.authentication_config,
&user_info,
&ctx.peer_addr,
)?;
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
}
let cached_secret = match maybe_secret {