mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-24 13:50:37 +00:00
Compare commits
2 Commits
fix_aio_pr
...
neon_caps
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf9d801117 | ||
|
|
f86d98f44f |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -4320,6 +4320,7 @@ dependencies = [
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"itertools",
|
||||
"jsonwebtoken",
|
||||
"lasso",
|
||||
"md5",
|
||||
"metrics",
|
||||
|
||||
@@ -97,6 +97,7 @@ native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
use password_hack::PasswordHackPayload;
|
||||
|
||||
pub mod caps;
|
||||
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
use tokio::time::error::Elapsed;
|
||||
@@ -71,6 +73,9 @@ pub enum AuthErrorImpl {
|
||||
#[error("Too many connections to this endpoint. Please try again later.")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("neon_caps token is invalid")]
|
||||
CapsInvalid,
|
||||
|
||||
#[error("Authentication timed out")]
|
||||
UserTimeout(Elapsed),
|
||||
}
|
||||
@@ -96,6 +101,10 @@ impl AuthError {
|
||||
AuthErrorImpl::TooManyConnections.into()
|
||||
}
|
||||
|
||||
pub fn caps_invalid() -> Self {
|
||||
AuthErrorImpl::CapsInvalid.into()
|
||||
}
|
||||
|
||||
pub fn is_auth_failed(&self) -> bool {
|
||||
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
|
||||
}
|
||||
@@ -126,6 +135,7 @@ impl UserFacingError for AuthError {
|
||||
IpAddressNotAllowed(_) => self.to_string(),
|
||||
TooManyConnections => self.to_string(),
|
||||
UserTimeout(_) => self.to_string(),
|
||||
CapsInvalid => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,6 +155,7 @@ impl ReportableError for AuthError {
|
||||
IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
CapsInvalid => crate::error::ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ use crate::{
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
@@ -251,11 +252,13 @@ async fn auth_quirks(
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
|
||||
let bypass_ipcheck = apply_caps(&config, &info, &ctx.peer_addr)?;
|
||||
|
||||
info!("fetching user's authentication info");
|
||||
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
|
||||
// check allowed list
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
@@ -537,6 +540,7 @@ mod tests {
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
caps: None,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
@@ -695,3 +699,43 @@ mod tests {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// It checks that provided JWT capabilities are valid for the connection
|
||||
//
|
||||
// if it returns Ok(true), futher peer IP checks has to be disabled
|
||||
//
|
||||
// If proxy isn't configured for JWT capabilities or neon_caps option
|
||||
// isn't set, it skips any checks
|
||||
pub fn apply_caps(
|
||||
config: &AuthenticationConfig,
|
||||
info: &ComputeUserInfo,
|
||||
peer_addr: &IpAddr,
|
||||
) -> auth::Result<bool> {
|
||||
match (&config.caps, info.options.caps()) {
|
||||
(Some(caps_config), Some(caps)) => {
|
||||
let token = match caps_config.decode(&caps) {
|
||||
Err(_) => {
|
||||
return Err(auth::AuthError::caps_invalid());
|
||||
}
|
||||
Ok(token) => token,
|
||||
};
|
||||
|
||||
if token.claims.endpoint_id != *info.endpoint {
|
||||
return Err(auth::AuthError::caps_invalid());
|
||||
}
|
||||
|
||||
match token.claims.check_ip(peer_addr) {
|
||||
None => return Ok(false),
|
||||
Some(true) => {
|
||||
return Ok(true);
|
||||
}
|
||||
Some(false) => {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(*peer_addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
96
proxy/src/auth/caps.rs
Normal file
96
proxy/src/auth/caps.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use std::{borrow::Cow, fmt::Display, fs, net::IpAddr};
|
||||
|
||||
use anyhow::Result;
|
||||
use camino::Utf8Path;
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use super::{check_peer_addr_is_in_list, IpPattern};
|
||||
|
||||
const TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Scope {
|
||||
Connection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq)]
|
||||
pub struct Claims {
|
||||
pub scope: Scope,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub endpoint_id: String,
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn check_ip(&self, ip: &IpAddr) -> Option<bool> {
|
||||
let allowed_ips = match &self.allowed_ips {
|
||||
None => return None,
|
||||
Some(allowed_ips) => allowed_ips,
|
||||
};
|
||||
if allowed_ips.is_empty() {
|
||||
return Some(true);
|
||||
}
|
||||
|
||||
return Some(check_peer_addr_is_in_list(ip, &allowed_ips));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CapsValidator {
|
||||
decoding_key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl CapsValidator {
|
||||
pub fn new(decoding_key: DecodingKey) -> Self {
|
||||
let mut validation = Validation::default();
|
||||
validation.algorithms = vec![TOKEN_ALGORITHM];
|
||||
Self {
|
||||
decoding_key,
|
||||
validation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
|
||||
let metadata = key_path.metadata()?;
|
||||
let decoding_key = if metadata.is_file() {
|
||||
let public_key = fs::read(key_path)?;
|
||||
DecodingKey::from_ed_pem(&public_key)?
|
||||
} else {
|
||||
anyhow::bail!("path isn't a file")
|
||||
};
|
||||
|
||||
Ok(Self::new(decoding_key))
|
||||
}
|
||||
|
||||
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, CapsError> {
|
||||
return match decode(token, &self.decoding_key, &self.validation) {
|
||||
Ok(res) => Ok(res),
|
||||
Err(e) => Err(CapsError(Cow::Owned(e.to_string()))),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CapsValidator {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("CapsValidator")
|
||||
.field("validation", &self.validation)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct CapsError(pub Cow<'static, str>);
|
||||
|
||||
impl Display for CapsError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CapsError> for ApiError {
|
||||
fn from(_value: CapsError) -> Self {
|
||||
ApiError::Forbidden("neon_caps validation error".to_string())
|
||||
}
|
||||
}
|
||||
@@ -5,9 +5,11 @@ use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_config::profile::ProfileFileCredentialsProvider;
|
||||
use aws_config::provider_config::ProviderConfig;
|
||||
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use camino::Utf8Path;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::auth::caps::CapsValidator;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::remote_storage_from_toml;
|
||||
@@ -193,6 +195,9 @@ struct ProxyCliArgs {
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
|
||||
#[clap(long)]
|
||||
caps_key: Option<String>,
|
||||
|
||||
/// interval for backup metric collection
|
||||
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
|
||||
metric_backup_collection_interval: std::time::Duration,
|
||||
@@ -542,10 +547,19 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
|
||||
},
|
||||
};
|
||||
|
||||
let caps = if let Some(key) = &args.caps_key {
|
||||
let path = Utf8Path::new(key);
|
||||
Some(CapsValidator::from_key_path(path)?);
|
||||
} else {
|
||||
None;
|
||||
};
|
||||
|
||||
let authentication_config = AuthenticationConfig {
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
caps,
|
||||
};
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
auth,
|
||||
auth::{self, caps::CapsValidator},
|
||||
rate_limiter::{AuthRateLimiter, RateBucketInfo},
|
||||
serverless::GlobalConnPoolOptions,
|
||||
};
|
||||
@@ -58,6 +58,7 @@ pub struct AuthenticationConfig {
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub rate_limiter_enabled: bool,
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub caps: Option<CapsValidator>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
|
||||
@@ -385,6 +385,10 @@ impl NeonOptions {
|
||||
!self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn caps(&self) -> Option<&str> {
|
||||
self.0.iter().find(|(k, _)| k == "caps").map(|(_, v)| &**v)
|
||||
}
|
||||
|
||||
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut options = options
|
||||
.filter_map(neon_option)
|
||||
@@ -398,7 +402,13 @@ impl NeonOptions {
|
||||
// prefix + format!(" {k}:{v}")
|
||||
// kinda jank because SmolStr is immutable
|
||||
std::iter::once(prefix)
|
||||
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
|
||||
// exclude caps from cache key
|
||||
.chain(
|
||||
self.0
|
||||
.iter()
|
||||
.filter(|(k, _)| k != "caps")
|
||||
.flat_map(|(k, v)| [" ", &**k, ":", &**v]),
|
||||
)
|
||||
.collect::<SmolStr>()
|
||||
.into()
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ use async_trait::async_trait;
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||
auth::{
|
||||
backend::{apply_caps, ComputeCredentials},
|
||||
check_peer_addr_is_in_list, AuthError,
|
||||
},
|
||||
compute,
|
||||
config::ProxyConfig,
|
||||
console::{
|
||||
@@ -31,8 +34,15 @@ impl PoolingBackend {
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
|
||||
let bypass_ipcheck = apply_caps(
|
||||
&&self.config.authentication_config,
|
||||
&user_info,
|
||||
&ctx.peer_addr,
|
||||
)?;
|
||||
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
if !bypass_ipcheck && !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
|
||||
Reference in New Issue
Block a user