mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-05 15:20:39 +00:00
Compare commits
5 Commits
diko/baseb
...
conrad/ano
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4d42d1b2e | ||
|
|
eab0be6fa8 | ||
|
|
f64a240888 | ||
|
|
97d0147ed9 | ||
|
|
d4eedb4069 |
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -4478,6 +4478,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8022,6 +8024,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"parquet",
|
||||
"postgres-types",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"prost",
|
||||
@@ -8048,6 +8051,7 @@ dependencies = [
|
||||
"time",
|
||||
"time-macros",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
|
||||
@@ -121,4 +121,4 @@ rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
walkdir.workspace = true
|
||||
rand_distr = "0.4"
|
||||
tokio-postgres.workspace = true
|
||||
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
|
||||
|
||||
@@ -94,7 +94,7 @@ impl BackendIpAllowlist for ConsoleRedirectBackend {
|
||||
self.api
|
||||
.get_allowed_ips_and_secret(ctx, user_info)
|
||||
.await
|
||||
.map(|(ips, _)| ips.as_ref().clone())
|
||||
.map(|(ips, _)| ips.0.clone())
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,18 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
|
||||
) -> impl Future<Output = Result<Vec<AuthRule>, FetchAuthRulesError>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct StaticAuthRules(pub Vec<AuthRule>);
|
||||
impl FetchAuthRules for StaticAuthRules {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
|
||||
Ok(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum FetchAuthRulesError {
|
||||
#[error(transparent)]
|
||||
@@ -53,7 +65,7 @@ pub(crate) enum FetchAuthRulesError {
|
||||
RoleJwksNotConfigured,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct AuthRule {
|
||||
pub(crate) id: String,
|
||||
pub(crate) jwks_url: url::Url,
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::sync::Arc;
|
||||
pub use console_redirect::ConsoleRedirectBackend;
|
||||
pub(crate) use console_redirect::ConsoleRedirectError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use jwt::{JwkCache, StaticAuthRules};
|
||||
use local::LocalBackend;
|
||||
use postgres_client::config::AuthKeys;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -259,6 +260,7 @@ pub(crate) trait BackendIpAllowlist {
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn auth_quirks(
|
||||
ctx: &RequestContext,
|
||||
api: &impl control_plane::ControlPlaneApi,
|
||||
@@ -267,6 +269,7 @@ async fn auth_quirks(
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
jwks_cache: Arc<JwkCache>,
|
||||
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
@@ -282,11 +285,54 @@ async fn auth_quirks(
|
||||
};
|
||||
|
||||
debug!("fetching user's authentication info");
|
||||
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
let (x, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
let (allowed_ips, auth_rules) = &**x;
|
||||
|
||||
// we expect a jwt in the options field
|
||||
if !auth_rules.is_empty() {
|
||||
match info.options.get("jwt") {
|
||||
Some(jwt) => {
|
||||
let creds = jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
info.endpoint.clone(),
|
||||
&info.user,
|
||||
&StaticAuthRules(auth_rules.clone()),
|
||||
&jwt,
|
||||
)
|
||||
.await?;
|
||||
let token = match creds {
|
||||
ComputeCredentialKeys::JwtPayload(payload) => {
|
||||
serde_json::from_slice::<serde_json::Value>(&payload)
|
||||
.expect("jwt payload is valid json")
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// the token has a required IP claim.
|
||||
if let Some(expected_ip) = token.get("ip") {
|
||||
// todo: don't panic here, obviously.
|
||||
let allowed_ips: Vec<IpPattern> = expected_ip
|
||||
.as_str()
|
||||
.expect("jwt should not have an invalid IP claim")
|
||||
.split(',')
|
||||
.map(|s| s.parse().expect("jwt should not have an invalid IP claim"))
|
||||
.collect();
|
||||
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(AuthError::bad_auth_method("needs jwt"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check allowed list
|
||||
if config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
|
||||
{
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
@@ -326,7 +372,7 @@ async fn auth_quirks(
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
|
||||
Ok(keys) => Ok((keys, Some(allowed_ips.clone()))),
|
||||
Err(e) => {
|
||||
if e.is_password_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
@@ -396,6 +442,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
jwks_cache: Arc<JwkCache>,
|
||||
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
|
||||
let res = match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
@@ -413,6 +460,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
jwks_cache,
|
||||
)
|
||||
.await?;
|
||||
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
|
||||
@@ -447,7 +495,7 @@ impl Backend<'_, ComputeUserInfo> {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
api.get_allowed_ips_and_secret(ctx, user_info).await
|
||||
}
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new((vec![], vec![]))), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -461,11 +509,11 @@ impl BackendIpAllowlist for Backend<'_, ()> {
|
||||
) -> auth::Result<Vec<auth::IpPattern>> {
|
||||
let auth_data = match self {
|
||||
Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new((vec![], vec![]))), None)),
|
||||
};
|
||||
|
||||
auth_data
|
||||
.map(|(ips, _)| ips.as_ref().clone())
|
||||
.map(|(ips, _)| ips.0.clone())
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
@@ -543,7 +591,7 @@ mod tests {
|
||||
control_plane::errors::GetAuthInfoError,
|
||||
> {
|
||||
Ok((
|
||||
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
|
||||
CachedAllowedIps::new_uncached(Arc::new((self.ips.clone(), vec![]))),
|
||||
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
|
||||
))
|
||||
}
|
||||
@@ -703,6 +751,7 @@ mod tests {
|
||||
false,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
Arc::new(JwkCache::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -758,6 +807,7 @@ mod tests {
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
Arc::new(JwkCache::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -811,6 +861,7 @@ mod tests {
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
Arc::new(JwkCache::default()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
35
proxy/src/cache/project_info.rs
vendored
35
proxy/src/cache/project_info.rs
vendored
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::type_complexity)]
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
@@ -13,6 +15,7 @@ use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::config::ProjectInfoCacheOptions;
|
||||
use crate::control_plane::AuthSecret;
|
||||
@@ -50,7 +53,7 @@ impl<T> From<T> for Entry<T> {
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
|
||||
allowed_ips: Option<Entry<Arc<(Vec<IpPattern>, Vec<AuthRule>)>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
@@ -81,7 +84,7 @@ impl EndpointInfo {
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
|
||||
) -> Option<(Arc<(Vec<IpPattern>, Vec<AuthRule>)>, bool)> {
|
||||
if let Some(allowed_ips) = &self.allowed_ips {
|
||||
if valid_since < allowed_ips.created_at {
|
||||
return Some((
|
||||
@@ -211,7 +214,7 @@ impl ProjectInfoCacheImpl {
|
||||
pub(crate) fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
|
||||
) -> Option<Cached<&Self, Arc<(Vec<IpPattern>, Vec<AuthRule>)>>> {
|
||||
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(&endpoint_id)?;
|
||||
@@ -247,7 +250,7 @@ impl ProjectInfoCacheImpl {
|
||||
&self,
|
||||
project_id: ProjectIdInt,
|
||||
endpoint_id: EndpointIdInt,
|
||||
allowed_ips: Arc<Vec<IpPattern>>,
|
||||
allowed_ips: Arc<(Vec<IpPattern>, Vec<AuthRule>)>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
@@ -386,10 +389,10 @@ mod tests {
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = None;
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
let allowed_ips = Arc::new((
|
||||
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
|
||||
vec![],
|
||||
));
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
@@ -457,10 +460,10 @@ mod tests {
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
let allowed_ips = Arc::new((
|
||||
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
|
||||
vec![],
|
||||
));
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
@@ -520,10 +523,10 @@ mod tests {
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
let allowed_ips = Arc::new((
|
||||
vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
|
||||
vec![],
|
||||
));
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
|
||||
@@ -153,9 +153,22 @@ impl NeonControlPlaneClient {
|
||||
.proxy
|
||||
.allowed_ips_number
|
||||
.observe(allowed_ips.len() as f64);
|
||||
|
||||
let auth_rules = body
|
||||
.jwks
|
||||
.into_iter()
|
||||
.map(|jwks| AuthRule {
|
||||
id: jwks.id,
|
||||
jwks_url: jwks.jwks_url,
|
||||
audience: jwks.jwt_audience,
|
||||
role_names: jwks.role_names,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
auth_rules,
|
||||
project_id: body.project_id,
|
||||
})
|
||||
}
|
||||
@@ -310,7 +323,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
Arc::new((auth_info.allowed_ips, auth_info.auth_rules)),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
@@ -336,7 +349,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Miss);
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_ips = Arc::new((auth_info.allowed_ips, auth_info.auth_rules));
|
||||
let user = &user_info.user;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
|
||||
@@ -5,8 +5,9 @@ use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::types::Json;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
@@ -17,7 +18,7 @@ use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::messages::{JwksSettings, MetricsAuxInfo};
|
||||
use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::io_error;
|
||||
use crate::intern::RoleNameInt;
|
||||
@@ -65,61 +66,70 @@ impl MockControlPlane {
|
||||
&self,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
tokio::spawn(connection);
|
||||
|
||||
let secret = if let Some(entry) = get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*user_info.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
} else {
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
None
|
||||
};
|
||||
let secret = if let Some(entry) = get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*user_info.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
} else {
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
None
|
||||
};
|
||||
|
||||
let allowed_ips = if self.ip_allowlist_check_enabled {
|
||||
match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&user_info.endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',')
|
||||
.map(|s| {
|
||||
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
let (allowed_ips, auth_rules) = if self.ip_allowlist_check_enabled {
|
||||
let row = client.query_opt("select allowed_ips, jwks from neon_control_plane.endpoints where endpoint_id = $1", &[&user_info.endpoint.as_str()]).await?;
|
||||
match row {
|
||||
Some(row) => {
|
||||
let allowed_ips: String = row
|
||||
.try_get("allowed_ips")
|
||||
.map_err(MockApiError::PasswordNotSet)?;
|
||||
let jwks: Json<Vec<JwksSettings>> =
|
||||
row.try_get("jwks").map_err(MockApiError::PasswordNotSet)?;
|
||||
|
||||
info!("got allowed_ips: {allowed_ips}");
|
||||
let allowed_ips = allowed_ips
|
||||
.split(',')
|
||||
.map(|s| {
|
||||
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
|
||||
})
|
||||
.collect();
|
||||
|
||||
let auth_rules = jwks
|
||||
.0
|
||||
.into_iter()
|
||||
.map(|jwks| AuthRule {
|
||||
id: jwks.id,
|
||||
jwks_url: jwks.jwks_url,
|
||||
audience: jwks.jwt_audience,
|
||||
role_names: jwks.role_names,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(allowed_ips, auth_rules)
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
None => (vec![], vec![]),
|
||||
}
|
||||
} else {
|
||||
(vec![], vec![])
|
||||
};
|
||||
|
||||
Ok((secret, allowed_ips))
|
||||
}
|
||||
.inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}"))
|
||||
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
auth_rules,
|
||||
allowed_ips,
|
||||
project_id: None,
|
||||
})
|
||||
@@ -203,7 +213,7 @@ async fn get_execute_postgres_query(
|
||||
}
|
||||
|
||||
impl super::ControlPlaneApi for MockControlPlane {
|
||||
#[tracing::instrument(skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
@@ -214,19 +224,20 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
let res = self.do_get_auth_info(user_info).await?;
|
||||
Ok((
|
||||
Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)),
|
||||
Cached::new_uncached(Arc::new((res.allowed_ips, res.auth_rules))),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
@@ -235,7 +246,7 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
self.do_get_endpoint_jwks(endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
#[tracing::instrument(skip_all, fields(url = self.endpoint.as_str()))]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
|
||||
@@ -229,6 +229,8 @@ pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
|
||||
#[serde(default)]
|
||||
pub(crate) jwks: Vec<JwksSettings>,
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
|
||||
@@ -54,6 +54,7 @@ pub(crate) struct AuthInfo {
|
||||
pub(crate) allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
pub(crate) auth_rules: Vec<AuthRule>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -99,7 +100,8 @@ pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
pub(crate) type CachedAllowedIps =
|
||||
Cached<&'static ProjectInfoCacheImpl, Arc<(Vec<IpPattern>, Vec<AuthRule>)>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
|
||||
@@ -23,6 +23,7 @@ use tracing::{debug, error, info, warn, Instrument};
|
||||
|
||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
@@ -71,6 +72,8 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
let jwks_cache = Arc::new(JwkCache::default());
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
@@ -84,6 +87,7 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
let jwks_cache = jwks_cache.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
@@ -136,6 +140,7 @@ pub async fn task_main(
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
jwks_cache,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -249,6 +254,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
jwks_cache: Arc<JwkCache>,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -319,6 +325,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
jwks_cache,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -57,9 +57,10 @@ impl PoolingBackend {
|
||||
|
||||
let user_info = user_info.clone();
|
||||
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
let (x, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
let (allowed_ips, _) = &**x;
|
||||
if self.config.authentication_config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
|
||||
{
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ use tokio_util::task::TaskTracker;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
@@ -331,6 +332,8 @@ async fn connection_handler(
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let jwks_cache = Arc::new(JwkCache::default());
|
||||
|
||||
let conn_info2 = conn_info.clone();
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
@@ -371,6 +374,7 @@ async fn connection_handler(
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
jwks_cache.clone(),
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -419,6 +423,7 @@ async fn request_handler(
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
jwks_cache: Arc<JwkCache>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -456,6 +461,7 @@ async fn request_handler(
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
cancellations,
|
||||
jwks_cache,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -12,6 +12,7 @@ use pin_project_lite::pin_project;
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
@@ -133,6 +134,7 @@ pub(crate) async fn serve_websocket(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
jwks_cache: Arc<JwkCache>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
@@ -152,6 +154,7 @@ pub(crate) async fn serve_websocket(
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
jwks_cache,
|
||||
))
|
||||
.await;
|
||||
|
||||
|
||||
@@ -3645,7 +3645,7 @@ def static_proxy(
|
||||
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
|
||||
vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane")
|
||||
vanilla_pg.safe_psql(
|
||||
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
|
||||
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255), jwks jsonb default '[]'::jsonb)"
|
||||
)
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
|
||||
@@ -65,6 +65,7 @@ num-rational = { version = "0.4", default-features = false, features = ["num-big
|
||||
num-traits = { version = "0.2", features = ["i128", "libm"] }
|
||||
once_cell = { version = "1" }
|
||||
parquet = { version = "53", default-features = false, features = ["zstd"] }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", default-features = false, features = ["with-serde_json-1"] }
|
||||
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
regex = { version = "1" }
|
||||
@@ -86,6 +87,7 @@ tikv-jemalloc-ctl = { version = "0.6", features = ["stats", "use_std"] }
|
||||
tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
|
||||
time = { version = "0.3", features = ["macros", "serde-well-known"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", features = ["with-serde_json-1"] }
|
||||
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] }
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
|
||||
|
||||
Reference in New Issue
Block a user