Compare commits

...

5 Commits

Author SHA1 Message Date
Conrad Ludgate
c4d42d1b2e fix tests 2025-01-15 09:33:36 +00:00
Conrad Ludgate
eab0be6fa8 feat: allow multiple IP patterns in JWT 2025-01-14 16:53:36 +00:00
Conrad Ludgate
f64a240888 fmt 2025-01-14 12:32:50 +00:00
Conrad Ludgate
97d0147ed9 validate jwt during auth_quirks 2025-01-14 11:59:45 +00:00
Conrad Ludgate
d4eedb4069 feat: introduce jwks settings in cplane response 2025-01-14 11:26:15 +00:00
16 changed files with 204 additions and 87 deletions

4
Cargo.lock generated
View File

@@ -4478,6 +4478,8 @@ dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol",
"serde",
"serde_json",
]
[[package]]
@@ -8022,6 +8024,7 @@ dependencies = [
"num-traits",
"once_cell",
"parquet",
"postgres-types",
"prettyplease",
"proc-macro2",
"prost",
@@ -8048,6 +8051,7 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-postgres",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",

View File

@@ -121,4 +121,4 @@ rcgen.workspace = true
rstest.workspace = true
walkdir.workspace = true
rand_distr = "0.4"
tokio-postgres.workspace = true
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }

View File

@@ -94,7 +94,7 @@ impl BackendIpAllowlist for ConsoleRedirectBackend {
self.api
.get_allowed_ips_and_secret(ctx, user_info)
.await
.map(|(ips, _)| ips.as_ref().clone())
.map(|(ips, _)| ips.0.clone())
.map_err(|e| e.into())
}
}

View File

@@ -44,6 +44,18 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
}
#[derive(Clone)]
pub(crate) struct StaticAuthRules(pub Vec<AuthRule>);
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(
&self,
_ctx: &RequestContext,
_endpoint: EndpointId,
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
Ok(self.0.clone())
}
}
#[derive(Error, Debug)]
pub(crate) enum FetchAuthRulesError {
#[error(transparent)]
@@ -53,7 +65,7 @@ pub(crate) enum FetchAuthRulesError {
RoleJwksNotConfigured,
}
#[derive(Clone)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct AuthRule {
pub(crate) id: String,
pub(crate) jwks_url: url::Url,

View File

@@ -10,6 +10,7 @@ use std::sync::Arc;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
use ipnet::{Ipv4Net, Ipv6Net};
use jwt::{JwkCache, StaticAuthRules};
use local::LocalBackend;
use postgres_client::config::AuthKeys;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -259,6 +260,7 @@ pub(crate) trait BackendIpAllowlist {
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
#[allow(clippy::too_many_arguments)]
async fn auth_quirks(
ctx: &RequestContext,
api: &impl control_plane::ControlPlaneApi,
@@ -267,6 +269,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
jwks_cache: Arc<JwkCache>,
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
@@ -282,11 +285,54 @@ async fn auth_quirks(
};
debug!("fetching user's authentication info");
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
let (x, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
let (allowed_ips, auth_rules) = &**x;
// we expect a jwt in the options field
if !auth_rules.is_empty() {
match info.options.get("jwt") {
Some(jwt) => {
let creds = jwks_cache
.check_jwt(
ctx,
info.endpoint.clone(),
&info.user,
&StaticAuthRules(auth_rules.clone()),
&jwt,
)
.await?;
let token = match creds {
ComputeCredentialKeys::JwtPayload(payload) => {
serde_json::from_slice::<serde_json::Value>(&payload)
.expect("jwt payload is valid json")
}
_ => unreachable!(),
};
// the token has a required IP claim.
if let Some(expected_ip) = token.get("ip") {
// todo: don't panic here, obviously.
let allowed_ips: Vec<IpPattern> = expected_ip
.as_str()
.expect("jwt should not have an invalid IP claim")
.split(',')
.map(|s| s.parse().expect("jwt should not have an invalid IP claim"))
.collect();
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
}
None => {
return Err(AuthError::bad_auth_method("needs jwt"));
}
}
}
// check allowed list
if config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
@@ -326,7 +372,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Ok(keys) => Ok((keys, Some(allowed_ips.clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -396,6 +442,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
jwks_cache: Arc<JwkCache>,
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
@@ -413,6 +460,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext,
config,
endpoint_rate_limiter,
jwks_cache,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
@@ -447,7 +495,7 @@ impl Backend<'_, ComputeUserInfo> {
Self::ControlPlane(api, user_info) => {
api.get_allowed_ips_and_secret(ctx, user_info).await
}
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new((vec![], vec![]))), None)),
}
}
}
@@ -461,11 +509,11 @@ impl BackendIpAllowlist for Backend<'_, ()> {
) -> auth::Result<Vec<auth::IpPattern>> {
let auth_data = match self {
Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new((vec![], vec![]))), None)),
};
auth_data
.map(|(ips, _)| ips.as_ref().clone())
.map(|(ips, _)| ips.0.clone())
.map_err(|e| e.into())
}
}
@@ -543,7 +591,7 @@ mod tests {
control_plane::errors::GetAuthInfoError,
> {
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
CachedAllowedIps::new_uncached(Arc::new((self.ips.clone(), vec![]))),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
))
}
@@ -703,6 +751,7 @@ mod tests {
false,
&CONFIG,
endpoint_rate_limiter,
Arc::new(JwkCache::default()),
)
.await
.unwrap();
@@ -758,6 +807,7 @@ mod tests {
true,
&CONFIG,
endpoint_rate_limiter,
Arc::new(JwkCache::default()),
)
.await
.unwrap();
@@ -811,6 +861,7 @@ mod tests {
true,
&CONFIG,
endpoint_rate_limiter,
Arc::new(JwkCache::default()),
)
.await
.unwrap();

View File

@@ -1,3 +1,5 @@
#![allow(clippy::type_complexity)]
use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::atomic::AtomicU64;
@@ -13,6 +15,7 @@ use tokio::time::Instant;
use tracing::{debug, info};
use super::{Cache, Cached};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::AuthSecret;
@@ -50,7 +53,7 @@ impl<T> From<T> for Entry<T> {
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
allowed_ips: Option<Entry<Arc<(Vec<IpPattern>, Vec<AuthRule>)>>>,
}
impl EndpointInfo {
@@ -81,7 +84,7 @@ impl EndpointInfo {
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
) -> Option<(Arc<(Vec<IpPattern>, Vec<AuthRule>)>, bool)> {
if let Some(allowed_ips) = &self.allowed_ips {
if valid_since < allowed_ips.created_at {
return Some((
@@ -211,7 +214,7 @@ impl ProjectInfoCacheImpl {
pub(crate) fn get_allowed_ips(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
) -> Option<Cached<&Self, Arc<(Vec<IpPattern>, Vec<AuthRule>)>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
@@ -247,7 +250,7 @@ impl ProjectInfoCacheImpl {
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_ips: Arc<Vec<IpPattern>>,
allowed_ips: Arc<(Vec<IpPattern>, Vec<AuthRule>)>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
@@ -386,10 +389,10 @@ mod tests {
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = None;
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
let allowed_ips = Arc::new((
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
vec![],
));
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
@@ -457,10 +460,10 @@ mod tests {
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
let allowed_ips = Arc::new((
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
vec![],
));
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
@@ -520,10 +523,10 @@ mod tests {
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
let allowed_ips = Arc::new((
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
vec![],
));
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),

View File

@@ -153,9 +153,22 @@ impl NeonControlPlaneClient {
.proxy
.allowed_ips_number
.observe(allowed_ips.len() as f64);
let auth_rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(AuthInfo {
secret,
allowed_ips,
auth_rules,
project_id: body.project_id,
})
}
@@ -310,7 +323,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_ips),
Arc::new((auth_info.allowed_ips, auth_info.auth_rules)),
);
ctx.set_project_id(project_id);
}
@@ -336,7 +349,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_ips = Arc::new((auth_info.allowed_ips, auth_info.auth_rules));
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();

View File

@@ -5,8 +5,9 @@ use std::sync::Arc;
use futures::TryFutureExt;
use thiserror::Error;
use tokio_postgres::types::Json;
use tokio_postgres::Client;
use tracing::{error, info, info_span, warn, Instrument};
use tracing::{error, info, warn};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
@@ -17,7 +18,7 @@ use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret};
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::messages::{JwksSettings, MetricsAuxInfo};
use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
use crate::error::io_error;
use crate::intern::RoleNameInt;
@@ -65,61 +66,70 @@ impl MockControlPlane {
&self,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
let (secret, allowed_ips) = async {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
tokio::spawn(connection);
tokio::spawn(connection);
let secret = if let Some(entry) = get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
"rolpassword",
)
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
None
};
let secret = if let Some(entry) = get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
"rolpassword",
)
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
None
};
let allowed_ips = if self.ip_allowlist_check_enabled {
match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&user_info.endpoint.as_str()],
"allowed_ips",
)
.await?
{
Some(s) => {
info!("got allowed_ips: {s}");
s.split(',')
.map(|s| {
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
})
.collect()
}
None => vec![],
let (allowed_ips, auth_rules) = if self.ip_allowlist_check_enabled {
let row = client.query_opt("select allowed_ips, jwks from neon_control_plane.endpoints where endpoint_id = $1", &[&user_info.endpoint.as_str()]).await?;
match row {
Some(row) => {
let allowed_ips: String = row
.try_get("allowed_ips")
.map_err(MockApiError::PasswordNotSet)?;
let jwks: Json<Vec<JwksSettings>> =
row.try_get("jwks").map_err(MockApiError::PasswordNotSet)?;
info!("got allowed_ips: {allowed_ips}");
let allowed_ips = allowed_ips
.split(',')
.map(|s| {
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
})
.collect();
let auth_rules = jwks
.0
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
(allowed_ips, auth_rules)
}
} else {
vec![]
};
None => (vec![], vec![]),
}
} else {
(vec![], vec![])
};
Ok((secret, allowed_ips))
}
.inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
.await?;
Ok(AuthInfo {
secret,
auth_rules,
allowed_ips,
project_id: None,
})
@@ -203,7 +213,7 @@ async fn get_execute_postgres_query(
}
impl super::ControlPlaneApi for MockControlPlane {
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
async fn get_role_secret(
&self,
_ctx: &RequestContext,
@@ -214,19 +224,20 @@ impl super::ControlPlaneApi for MockControlPlane {
))
}
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
async fn get_allowed_ips_and_secret(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let res = self.do_get_auth_info(user_info).await?;
Ok((
Cached::new_uncached(Arc::new(
self.do_get_auth_info(user_info).await?.allowed_ips,
)),
Cached::new_uncached(Arc::new((res.allowed_ips, res.auth_rules))),
None,
))
}
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
@@ -235,7 +246,7 @@ impl super::ControlPlaneApi for MockControlPlane {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
async fn wake_compute(
&self,
_ctx: &RequestContext,

View File

@@ -229,6 +229,8 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
#[serde(default)]
pub(crate) jwks: Vec<JwksSettings>,
}
/// Response which holds compute node's `host:port` pair.

View File

@@ -54,6 +54,7 @@ pub(crate) struct AuthInfo {
pub(crate) allowed_ips: Vec<IpPattern>,
/// Project ID. This is used for cache invalidation.
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) auth_rules: Vec<AuthRule>,
}
/// Info for establishing a connection to a compute node.
@@ -99,7 +100,8 @@ pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
pub(crate) type CachedAllowedIps =
Cached<&'static ProjectInfoCacheImpl, Arc<(Vec<IpPattern>, Vec<AuthRule>)>>;
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.

View File

@@ -23,6 +23,7 @@ use tracing::{debug, error, info, warn, Instrument};
use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::auth::backend::jwt::JwkCache;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
@@ -71,6 +72,8 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
let jwks_cache = Arc::new(JwkCache::default());
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
@@ -84,6 +87,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
let jwks_cache = jwks_cache.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
@@ -136,6 +140,7 @@ pub async fn task_main(
endpoint_rate_limiter2,
conn_gauge,
cancellations,
jwks_cache,
)
.instrument(ctx.span())
.boxed()
@@ -249,6 +254,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
jwks_cache: Arc<JwkCache>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -319,6 +325,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
jwks_cache,
)
.await
{

View File

@@ -57,9 +57,10 @@ impl PoolingBackend {
let user_info = user_info.clone();
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
let (x, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
let (allowed_ips, _) = &**x;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}

View File

@@ -43,6 +43,7 @@ use tokio_util::task::TaskTracker;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::auth::backend::jwt::JwkCache;
use crate::cancellation::CancellationHandlerMain;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
@@ -331,6 +332,8 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let jwks_cache = Arc::new(JwkCache::default());
let conn_info2 = conn_info.clone();
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
@@ -371,6 +374,7 @@ async fn connection_handler(
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
jwks_cache.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -419,6 +423,7 @@ async fn request_handler(
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
jwks_cache: Arc<JwkCache>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -456,6 +461,7 @@ async fn request_handler(
endpoint_rate_limiter,
host,
cancellations,
jwks_cache,
)
.await
{

View File

@@ -12,6 +12,7 @@ use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
use crate::auth::backend::jwt::JwkCache;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
@@ -133,6 +134,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
jwks_cache: Arc<JwkCache>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
@@ -152,6 +154,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter,
conn_gauge,
cancellations,
jwks_cache,
))
.await;

View File

@@ -3645,7 +3645,7 @@ def static_proxy(
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane")
vanilla_pg.safe_psql(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255), jwks jsonb default '[]'::jsonb)"
)
proxy_port = port_distributor.get_port()

View File

@@ -65,6 +65,7 @@ num-rational = { version = "0.4", default-features = false, features = ["num-big
num-traits = { version = "0.2", features = ["i128", "libm"] }
once_cell = { version = "1" }
parquet = { version = "53", default-features = false, features = ["zstd"] }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", default-features = false, features = ["with-serde_json-1"] }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
@@ -86,6 +87,7 @@ tikv-jemalloc-ctl = { version = "0.6", features = ["stats", "use_std"] }
tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
time = { version = "0.3", features = ["macros", "serde-well-known"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", features = ["with-serde_json-1"] }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] }
tokio-stream = { version = "0.1", features = ["net"] }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }