mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-04 22:10:39 +00:00
feat(proxy): Implement access control with VPC endpoint checks and block for public internet / VPC (#10143)
- Wired up filtering on VPC endpoints - Wired up block access from public internet / VPC depending on per project flag - Added cache invalidation for VPC endpoints (partially based on PR from Raphael) - Removed BackendIpAllowlist trait --------- Co-authored-by: Ivan Efremov <ivan@neon.tech>
This commit is contained in:
@@ -7,8 +7,8 @@ use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use super::{ComputeCredentialKeys, ControlPlaneApi};
|
||||
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
|
||||
use super::ComputeCredentialKeys;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
@@ -84,26 +84,15 @@ pub(crate) fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackendIpAllowlist for ConsoleRedirectBackend {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> auth::Result<Vec<auth::IpPattern>> {
|
||||
self.api
|
||||
.get_allowed_ips_and_secret(ctx, user_info)
|
||||
.await
|
||||
.map(|(ips, _)| ips.as_ref().clone())
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl ConsoleRedirectBackend {
|
||||
pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self {
|
||||
Self { console_uri, api }
|
||||
}
|
||||
|
||||
pub(crate) fn get_api(&self) -> &cplane_proxy_v1::NeonControlPlaneClient {
|
||||
&self.api
|
||||
}
|
||||
|
||||
pub(crate) async fn authenticate(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -191,6 +180,15 @@ async fn authenticate(
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the access over the public internet is allowed, otherwise block. Note that
|
||||
// the console redirect is not behind the VPC service endpoint, so we don't need to check
|
||||
// the VPC endpoint ID.
|
||||
if let Some(public_access_allowed) = db_info.public_access_allowed {
|
||||
if !public_access_allowed {
|
||||
return Err(auth::AuthError::NetworkNotAllowed);
|
||||
}
|
||||
}
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
|
||||
@@ -26,10 +26,12 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::control_plane::errors::GetAuthInfoError;
|
||||
use crate::control_plane::{
|
||||
self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
|
||||
self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
|
||||
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
|
||||
@@ -99,6 +101,13 @@ impl<T> Backend<'_, T> {
|
||||
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_api(&self) -> &ControlPlaneClient {
|
||||
match self {
|
||||
Self::ControlPlane(api, _) => api,
|
||||
Self::Local(_) => panic!("Local backend has no API"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Backend<'a, T> {
|
||||
@@ -247,15 +256,6 @@ impl AuthenticationConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub(crate) trait BackendIpAllowlist {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> auth::Result<Vec<auth::IpPattern>>;
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
///
|
||||
@@ -282,23 +282,51 @@ async fn auth_quirks(
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
|
||||
debug!("fetching user's authentication info");
|
||||
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
debug!("fetching authentication info and allowlists");
|
||||
|
||||
// check allowed list
|
||||
if config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
{
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
let allowed_ips = if config.ip_allowlist_check_enabled {
|
||||
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
allowed_ips
|
||||
} else {
|
||||
Cached::new_uncached(Arc::new(vec![]))
|
||||
};
|
||||
|
||||
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
|
||||
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
|
||||
if config.is_vpc_acccess_proxy {
|
||||
if access_blocks.vpc_access_blocked {
|
||||
return Err(AuthError::NetworkNotAllowed);
|
||||
}
|
||||
|
||||
let incoming_vpc_endpoint_id = match ctx.extra() {
|
||||
None => return Err(AuthError::MissingEndpointName),
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
|
||||
// Convert the vcpe_id to a string
|
||||
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
|
||||
}
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
|
||||
};
|
||||
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
|
||||
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
|
||||
if !allowed_vpc_endpoint_ids.is_empty()
|
||||
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
|
||||
{
|
||||
return Err(AuthError::vpc_endpoint_id_not_allowed(
|
||||
incoming_vpc_endpoint_id,
|
||||
));
|
||||
}
|
||||
} else if access_blocks.public_access_blocked {
|
||||
return Err(AuthError::NetworkNotAllowed);
|
||||
}
|
||||
|
||||
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => api.get_role_secret(ctx, &info).await?,
|
||||
};
|
||||
let cached_secret = api.get_role_secret(ctx, &info).await?;
|
||||
let (cached_entry, secret) = cached_secret.take_value();
|
||||
|
||||
let secret = if let Some(secret) = secret {
|
||||
@@ -440,34 +468,38 @@ impl Backend<'_, ComputeUserInfo> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get_allowed_ips_and_secret(
|
||||
pub(crate) async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
api.get_allowed_ips_and_secret(ctx, user_info).await
|
||||
}
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
|
||||
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl BackendIpAllowlist for Backend<'_, ()> {
|
||||
async fn get_allowed_ips(
|
||||
pub(crate) async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> auth::Result<Vec<auth::IpPattern>> {
|
||||
let auth_data = match self {
|
||||
Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
};
|
||||
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
|
||||
}
|
||||
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
|
||||
}
|
||||
}
|
||||
|
||||
auth_data
|
||||
.map(|(ips, _)| ips.as_ref().clone())
|
||||
.map_err(|e| e.into())
|
||||
pub(crate) async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
api.get_block_public_or_vpc_access(ctx, user_info).await
|
||||
}
|
||||
Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -514,7 +546,10 @@ mod tests {
|
||||
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
|
||||
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
|
||||
};
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
@@ -523,6 +558,8 @@ mod tests {
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
vpc_endpoint_ids: Vec<String>,
|
||||
access_blocker_flags: AccessBlockerFlags,
|
||||
secret: AuthSecret,
|
||||
}
|
||||
|
||||
@@ -535,17 +572,31 @@ mod tests {
|
||||
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<
|
||||
(CachedAllowedIps, Option<CachedRoleSecret>),
|
||||
control_plane::errors::GetAuthInfoError,
|
||||
> {
|
||||
Ok((
|
||||
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
|
||||
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
|
||||
) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
|
||||
Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
|
||||
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
|
||||
self.vpc_endpoint_ids.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
|
||||
Ok(CachedAccessBlockerFlags::new_uncached(
|
||||
self.access_blocker_flags.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -575,6 +626,7 @@ mod tests {
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: false,
|
||||
console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
|
||||
@@ -642,6 +694,8 @@ mod tests {
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
vpc_endpoint_ids: vec![],
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
@@ -722,6 +776,8 @@ mod tests {
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
vpc_endpoint_ids: vec![],
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
@@ -774,6 +830,8 @@ mod tests {
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
vpc_endpoint_ids: vec![],
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
|
||||
@@ -55,6 +55,12 @@ pub(crate) enum AuthError {
|
||||
)]
|
||||
MissingEndpointName,
|
||||
|
||||
#[error(
|
||||
"VPC endpoint ID is not specified. \
|
||||
This endpoint requires a VPC endpoint ID to connect."
|
||||
)]
|
||||
MissingVPCEndpointId,
|
||||
|
||||
#[error("password authentication failed for user '{0}'")]
|
||||
PasswordFailed(Box<str>),
|
||||
|
||||
@@ -69,6 +75,15 @@ pub(crate) enum AuthError {
|
||||
)]
|
||||
IpAddressNotAllowed(IpAddr),
|
||||
|
||||
#[error("This connection is trying to access this endpoint from a blocked network.")]
|
||||
NetworkNotAllowed,
|
||||
|
||||
#[error(
|
||||
"This VPC endpoint id {0} is not allowed to connect to this endpoint. \
|
||||
Please add it to the allowed list in the Neon console."
|
||||
)]
|
||||
VpcEndpointIdNotAllowed(String),
|
||||
|
||||
#[error("Too many connections to this endpoint. Please try again later.")]
|
||||
TooManyConnections,
|
||||
|
||||
@@ -95,6 +110,10 @@ impl AuthError {
|
||||
AuthError::IpAddressNotAllowed(ip)
|
||||
}
|
||||
|
||||
pub(crate) fn vpc_endpoint_id_not_allowed(id: String) -> Self {
|
||||
AuthError::VpcEndpointIdNotAllowed(id)
|
||||
}
|
||||
|
||||
pub(crate) fn too_many_connections() -> Self {
|
||||
AuthError::TooManyConnections
|
||||
}
|
||||
@@ -122,8 +141,11 @@ impl UserFacingError for AuthError {
|
||||
Self::BadAuthMethod(_) => self.to_string(),
|
||||
Self::MalformedPassword(_) => self.to_string(),
|
||||
Self::MissingEndpointName => self.to_string(),
|
||||
Self::MissingVPCEndpointId => self.to_string(),
|
||||
Self::Io(_) => "Internal error".to_string(),
|
||||
Self::IpAddressNotAllowed(_) => self.to_string(),
|
||||
Self::NetworkNotAllowed => self.to_string(),
|
||||
Self::VpcEndpointIdNotAllowed(_) => self.to_string(),
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
Self::UserTimeout(_) => self.to_string(),
|
||||
Self::ConfirmationTimeout(_) => self.to_string(),
|
||||
@@ -142,8 +164,11 @@ impl ReportableError for AuthError {
|
||||
Self::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
Self::MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
Self::MissingEndpointName => crate::error::ErrorKind::User,
|
||||
Self::MissingVPCEndpointId => crate::error::ErrorKind::User,
|
||||
Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
Self::NetworkNotAllowed => crate::error::ErrorKind::User,
|
||||
Self::VpcEndpointIdNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -284,6 +284,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: true,
|
||||
console_redirect_confirmation_timeout: Duration::ZERO,
|
||||
|
||||
@@ -630,6 +630,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_vpc_acccess_proxy: args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
|
||||
224
proxy/src/cache/project_info.rs
vendored
224
proxy/src/cache/project_info.rs
vendored
@@ -15,13 +15,16 @@ use tracing::{debug, info};
|
||||
use super::{Cache, Cached};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::config::ProjectInfoCacheOptions;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ProjectInfoCache {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
|
||||
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
|
||||
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
|
||||
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
|
||||
async fn decrement_active_listeners(&self);
|
||||
async fn increment_active_listeners(&self);
|
||||
@@ -51,6 +54,8 @@ impl<T> From<T> for Entry<T> {
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
|
||||
block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
|
||||
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
@@ -92,9 +97,52 @@ impl EndpointInfo {
|
||||
}
|
||||
None
|
||||
}
|
||||
pub(crate) fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<String>>, bool)> {
|
||||
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
|
||||
if valid_since < allowed_vpc_endpoint_ids.created_at {
|
||||
return Some((
|
||||
allowed_vpc_endpoint_ids.value.clone(),
|
||||
Self::check_ignore_cache(
|
||||
ignore_cache_since,
|
||||
allowed_vpc_endpoint_ids.created_at,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
pub(crate) fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(AccessBlockerFlags, bool)> {
|
||||
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
|
||||
if valid_since < block_public_or_vpc_access.created_at {
|
||||
return Some((
|
||||
block_public_or_vpc_access.value.clone(),
|
||||
Self::check_ignore_cache(
|
||||
ignore_cache_since,
|
||||
block_public_or_vpc_access.created_at,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn invalidate_allowed_ips(&mut self) {
|
||||
self.allowed_ips = None;
|
||||
}
|
||||
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
|
||||
self.allowed_vpc_endpoint_ids = None;
|
||||
}
|
||||
pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
|
||||
self.block_public_or_vpc_access = None;
|
||||
}
|
||||
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
|
||||
self.secret.remove(&role_name);
|
||||
}
|
||||
@@ -111,6 +159,8 @@ pub struct ProjectInfoCacheImpl {
|
||||
cache: ClashMap<EndpointIdInt, EndpointInfo>,
|
||||
|
||||
project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
|
||||
// FIXME(stefan): we need a way to GC the account2ep map.
|
||||
account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
|
||||
config: ProjectInfoCacheOptions,
|
||||
|
||||
start_time: Instant,
|
||||
@@ -120,6 +170,63 @@ pub struct ProjectInfoCacheImpl {
|
||||
|
||||
#[async_trait]
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
|
||||
info!(
|
||||
"invalidating allowed vpc endpoint ids for projects `{}`",
|
||||
project_ids
|
||||
.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
);
|
||||
for project_id in project_ids {
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(&project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
|
||||
info!(
|
||||
"invalidating allowed vpc endpoint ids for org `{}`",
|
||||
account_id
|
||||
);
|
||||
let endpoints = self
|
||||
.account2ep
|
||||
.get(&account_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
|
||||
info!(
|
||||
"invalidating block public or vpc access for project `{}`",
|
||||
project_id
|
||||
);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(&project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_block_public_or_vpc_access();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
|
||||
info!("invalidating allowed ips for project `{}`", project_id);
|
||||
let endpoints = self
|
||||
@@ -178,6 +285,7 @@ impl ProjectInfoCacheImpl {
|
||||
Self {
|
||||
cache: ClashMap::new(),
|
||||
project2ep: ClashMap::new(),
|
||||
account2ep: ClashMap::new(),
|
||||
config,
|
||||
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
|
||||
start_time: Instant::now(),
|
||||
@@ -226,6 +334,49 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub(crate) fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<Cached<&Self, Arc<Vec<String>>>> {
|
||||
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(&endpoint_id)?;
|
||||
let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
|
||||
let (value, ignore_cache) = value?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((
|
||||
self,
|
||||
CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
|
||||
)),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub(crate) fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<Cached<&Self, AccessBlockerFlags>> {
|
||||
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(&endpoint_id)?;
|
||||
let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
|
||||
let (value, ignore_cache) = value?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((
|
||||
self,
|
||||
CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
|
||||
)),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
|
||||
pub(crate) fn insert_role_secret(
|
||||
&self,
|
||||
project_id: ProjectIdInt,
|
||||
@@ -256,6 +407,43 @@ impl ProjectInfoCacheImpl {
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
|
||||
}
|
||||
pub(crate) fn insert_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
account_id: Option<AccountIdInt>,
|
||||
project_id: ProjectIdInt,
|
||||
endpoint_id: EndpointIdInt,
|
||||
allowed_vpc_endpoint_ids: Arc<Vec<String>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
if let Some(account_id) = account_id {
|
||||
self.insert_account2endpoint(account_id, endpoint_id);
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.cache
|
||||
.entry(endpoint_id)
|
||||
.or_default()
|
||||
.allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
|
||||
}
|
||||
pub(crate) fn insert_block_public_or_vpc_access(
|
||||
&self,
|
||||
project_id: ProjectIdInt,
|
||||
endpoint_id: EndpointIdInt,
|
||||
access_blockers: AccessBlockerFlags,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.cache
|
||||
.entry(endpoint_id)
|
||||
.or_default()
|
||||
.block_public_or_vpc_access = Some(access_blockers.into());
|
||||
}
|
||||
|
||||
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
|
||||
if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
|
||||
endpoints.insert(endpoint_id);
|
||||
@@ -264,6 +452,14 @@ impl ProjectInfoCacheImpl {
|
||||
.insert(project_id, HashSet::from([endpoint_id]));
|
||||
}
|
||||
}
|
||||
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
|
||||
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
|
||||
endpoints.insert(endpoint_id);
|
||||
} else {
|
||||
self.account2ep
|
||||
.insert(account_id, HashSet::from([endpoint_id]));
|
||||
}
|
||||
}
|
||||
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
|
||||
let mut valid_since = Instant::now() - self.config.ttl;
|
||||
// Only ignore cache if ttl is disabled.
|
||||
@@ -334,11 +530,25 @@ impl CachedLookupInfo {
|
||||
lookup_type: LookupType::AllowedIps,
|
||||
}
|
||||
}
|
||||
pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::AllowedVpcEndpointIds,
|
||||
}
|
||||
}
|
||||
pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::BlockPublicOrVpcAccess,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum LookupType {
|
||||
RoleSecret(RoleNameInt),
|
||||
AllowedIps,
|
||||
AllowedVpcEndpointIds,
|
||||
BlockPublicOrVpcAccess,
|
||||
}
|
||||
|
||||
impl Cache for ProjectInfoCacheImpl {
|
||||
@@ -360,6 +570,16 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
}
|
||||
LookupType::AllowedVpcEndpointIds => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
|
||||
}
|
||||
}
|
||||
LookupType::BlockPublicOrVpcAccess => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_block_public_or_vpc_access();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,15 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::{check_peer_addr_is_in_list, AuthError};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::ControlPlaneApi;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
use crate::redis::kv_ops::RedisKVClient;
|
||||
@@ -133,6 +135,9 @@ pub(crate) enum CancelError {
|
||||
#[error("IP is not allowed")]
|
||||
IpNotAllowed,
|
||||
|
||||
#[error("VPC endpoint id is not allowed to connect")]
|
||||
VpcEndpointIdNotAllowed,
|
||||
|
||||
#[error("Authentication backend error")]
|
||||
AuthError(#[from] AuthError),
|
||||
|
||||
@@ -152,8 +157,9 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
|
||||
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
|
||||
CancelError::NotFound => crate::error::ErrorKind::User,
|
||||
CancelError::IpNotAllowed
|
||||
| CancelError::VpcEndpointIdNotAllowed
|
||||
| CancelError::NotFound => crate::error::ErrorKind::User,
|
||||
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
CancelError::InternalError => crate::error::ErrorKind::Service,
|
||||
}
|
||||
@@ -265,11 +271,12 @@ impl CancellationHandler {
|
||||
/// Will fetch IP allowlist internally.
|
||||
///
|
||||
/// return Result primarily for tests
|
||||
pub(crate) async fn cancel_session<T: BackendIpAllowlist>(
|
||||
pub(crate) async fn cancel_session<T: ControlPlaneApi>(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
ctx: RequestContext,
|
||||
check_allowed: bool,
|
||||
check_ip_allowed: bool,
|
||||
check_vpc_allowed: bool,
|
||||
auth_backend: &T,
|
||||
) -> Result<(), CancelError> {
|
||||
let subnet_key = match ctx.peer_addr() {
|
||||
@@ -304,11 +311,11 @@ impl CancellationHandler {
|
||||
return Err(CancelError::NotFound);
|
||||
};
|
||||
|
||||
if check_allowed {
|
||||
if check_ip_allowed {
|
||||
let ip_allowlist = auth_backend
|
||||
.get_allowed_ips(&ctx, &cancel_closure.user_info)
|
||||
.await
|
||||
.map_err(CancelError::AuthError)?;
|
||||
.map_err(|e| CancelError::AuthError(e.into()))?;
|
||||
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
@@ -320,6 +327,40 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
|
||||
let access_blocks = auth_backend
|
||||
.get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
|
||||
.await
|
||||
.map_err(|e| CancelError::AuthError(e.into()))?;
|
||||
|
||||
if check_vpc_allowed {
|
||||
if access_blocks.vpc_access_blocked {
|
||||
return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
|
||||
}
|
||||
|
||||
let incoming_vpc_endpoint_id = match ctx.extra() {
|
||||
None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
|
||||
// Convert the vcpe_id to a string
|
||||
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
|
||||
}
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
|
||||
};
|
||||
|
||||
let allowed_vpc_endpoint_ids = auth_backend
|
||||
.get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
|
||||
.await
|
||||
.map_err(|e| CancelError::AuthError(e.into()))?;
|
||||
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
|
||||
if !allowed_vpc_endpoint_ids.is_empty()
|
||||
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
|
||||
{
|
||||
return Err(CancelError::VpcEndpointIdNotAllowed);
|
||||
}
|
||||
} else if access_blocks.public_access_blocked {
|
||||
return Err(CancelError::VpcEndpointIdNotAllowed);
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
|
||||
@@ -68,6 +68,7 @@ pub struct AuthenticationConfig {
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub is_vpc_acccess_proxy: bool,
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
pub accept_jwts: bool,
|
||||
|
||||
@@ -182,7 +182,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
backend,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
|
||||
};
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
pub mod parquet;
|
||||
@@ -312,6 +312,15 @@ impl RequestContext {
|
||||
.ip()
|
||||
}
|
||||
|
||||
pub(crate) fn extra(&self) -> Option<ConnectionInfoExtra> {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.conn_info
|
||||
.extra
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
|
||||
self.0
|
||||
.try_lock()
|
||||
|
||||
@@ -22,7 +22,8 @@ use crate::control_plane::errors::{
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
|
||||
use crate::control_plane::{
|
||||
AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
|
||||
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
|
||||
};
|
||||
use crate::metrics::{CacheOutcome, Metrics};
|
||||
use crate::rate_limiter::WakeComputeRateLimiter;
|
||||
@@ -137,9 +138,6 @@ impl NeonControlPlaneClient {
|
||||
}
|
||||
};
|
||||
|
||||
// Ivan: don't know where it will be used, so I leave it here
|
||||
let _endpoint_vpc_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
|
||||
|
||||
let secret = if body.role_secret.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -153,10 +151,23 @@ impl NeonControlPlaneClient {
|
||||
.proxy
|
||||
.allowed_ips_number
|
||||
.observe(allowed_ips.len() as f64);
|
||||
let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_vpc_endpoint_ids
|
||||
.observe(allowed_vpc_endpoint_ids.len() as f64);
|
||||
let block_public_connections = body.block_public_connections.unwrap_or_default();
|
||||
let block_vpc_connections = body.block_vpc_connections.unwrap_or_default();
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
allowed_vpc_endpoint_ids,
|
||||
project_id: body.project_id,
|
||||
account_id: body.account_id,
|
||||
access_blocker_flags: AccessBlockerFlags {
|
||||
public_access_blocked: block_public_connections,
|
||||
vpc_access_blocked: block_vpc_connections,
|
||||
},
|
||||
})
|
||||
}
|
||||
.inspect_err(|e| tracing::debug!(error = ?e))
|
||||
@@ -299,6 +310,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
return Ok(role_secret);
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
@@ -312,24 +324,35 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
auth_info.access_blocker_flags,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
// When we just got a secret, we don't need to invalidate it.
|
||||
Ok(Cached::new_uncached(auth_info.secret))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok((allowed_ips, None));
|
||||
return Ok(allowed_ips);
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
@@ -337,7 +360,10 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
.inc(CacheOutcome::Miss);
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
@@ -351,12 +377,136 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok((
|
||||
Cached::new_uncached(allowed_ips),
|
||||
Some(Cached::new_uncached(auth_info.secret)),
|
||||
))
|
||||
Ok(Cached::new_uncached(allowed_ips))
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_vpc_endpoint_ids) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_allowed_vpc_endpoint_ids(normalized_ep)
|
||||
{
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.vpc_endpoint_id_cache_stats
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok(allowed_vpc_endpoint_ids);
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.vpc_endpoint_id_cache_stats
|
||||
.inc(CacheOutcome::Miss);
|
||||
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(access_blocker_flags) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_block_public_or_vpc_access(normalized_ep)
|
||||
{
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.access_blocker_flags_cache_stats
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok(access_blocker_flags);
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.access_blocker_flags_cache_stats
|
||||
.inc(CacheOutcome::Miss);
|
||||
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok(Cached::new_uncached(access_blocker_flags))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
|
||||
@@ -13,12 +13,14 @@ use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::Cached;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::control_plane::client::{
|
||||
CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
|
||||
};
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::io_error;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
@@ -121,7 +123,10 @@ impl MockControlPlane {
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
allowed_vpc_endpoint_ids: vec![],
|
||||
project_id: None,
|
||||
account_id: None,
|
||||
access_blocker_flags: AccessBlockerFlags::default(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -214,16 +219,35 @@ impl super::ControlPlaneApi for MockControlPlane {
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
Ok((
|
||||
Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)),
|
||||
None,
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)))
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info)
|
||||
.await?
|
||||
.allowed_vpc_endpoint_ids,
|
||||
)))
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(
|
||||
self.do_get_auth_info(user_info).await?.access_blocker_flags,
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::{
|
||||
errors, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
|
||||
errors, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds,
|
||||
CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, NodeInfoCache,
|
||||
};
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::ApiLockMetrics;
|
||||
@@ -55,17 +56,45 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_ips_and_secret(),
|
||||
Self::Test(api) => api.get_allowed_ips(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_block_public_or_vpc_access(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,9 +131,15 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
|
||||
fn get_allowed_ips_and_secret(
|
||||
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
|
||||
|
||||
fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
|
||||
|
||||
fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use measured::FixedCardinalityLabel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
@@ -227,8 +227,11 @@ pub(crate) struct UserFacingMessage {
|
||||
pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) role_secret: Box<str>,
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
pub(crate) block_public_connections: Option<bool>,
|
||||
pub(crate) block_vpc_connections: Option<bool>,
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
@@ -282,6 +285,10 @@ pub(crate) struct DatabaseInfo {
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
#[serde(default)]
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
#[serde(default)]
|
||||
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
pub(crate) public_access_allowed: Option<bool>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
@@ -293,6 +300,7 @@ impl fmt::Debug for DatabaseInfo {
|
||||
.field("dbname", &self.dbname)
|
||||
.field("user", &self.user)
|
||||
.field("allowed_ips", &self.allowed_ips)
|
||||
.field("allowed_vpc_endpoint_ids", &self.allowed_vpc_endpoint_ids)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -457,7 +465,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_get_role_secret() -> anyhow::Result<()> {
|
||||
// Empty `allowed_ips` field.
|
||||
// Empty `allowed_ips` and `allowed_vpc_endpoint_ids` field.
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
});
|
||||
@@ -467,9 +475,21 @@ mod tests {
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
});
|
||||
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
|
||||
});
|
||||
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
"allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
|
||||
});
|
||||
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
"allowed_vpc_endpoint_ids": ["vpce-0abcd1234567890ef"],
|
||||
"project_id": "project",
|
||||
});
|
||||
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::AccountIdInt;
|
||||
use crate::intern::ProjectIdInt;
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::{compute, scram};
|
||||
@@ -52,8 +53,14 @@ pub(crate) struct AuthInfo {
|
||||
pub(crate) secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub(crate) allowed_ips: Vec<IpPattern>,
|
||||
/// List of VPC endpoints allowed for the autorization.
|
||||
pub(crate) allowed_vpc_endpoint_ids: Vec<String>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
/// Account ID. This is used for cache invalidation.
|
||||
pub(crate) account_id: Option<AccountIdInt>,
|
||||
/// Are public connections or VPC connections blocked?
|
||||
pub(crate) access_blocker_flags: AccessBlockerFlags,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -95,11 +102,21 @@ impl NodeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Eq, PartialEq, Debug)]
|
||||
pub(crate) struct AccessBlockerFlags {
|
||||
pub public_access_blocked: bool,
|
||||
pub vpc_access_blocked: bool,
|
||||
}
|
||||
|
||||
pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
pub(crate) type CachedAllowedVpcEndpointIds =
|
||||
Cached<&'static ProjectInfoCacheImpl, Arc<Vec<String>>>;
|
||||
pub(crate) type CachedAccessBlockerFlags =
|
||||
Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
@@ -113,11 +130,23 @@ pub(crate) trait ControlPlaneApi {
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::OnceLock;
|
||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||
use rustc_hash::FxHasher;
|
||||
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName};
|
||||
|
||||
pub trait InternId: Sized + 'static {
|
||||
fn get_interner() -> &'static StringInterner<Self>;
|
||||
@@ -206,6 +206,26 @@ impl From<ProjectId> for ProjectIdInt {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct AccountIdTag;
|
||||
impl InternId for AccountIdTag {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
static ROLE_NAMES: OnceLock<StringInterner<AccountIdTag>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
pub type AccountIdInt = InternedString<AccountIdTag>;
|
||||
impl From<&AccountId> for AccountIdInt {
|
||||
fn from(value: &AccountId) -> Self {
|
||||
AccountIdTag::get_interner().get_or_intern(value)
|
||||
}
|
||||
}
|
||||
impl From<AccountId> for AccountIdInt {
|
||||
fn from(value: AccountId) -> Self {
|
||||
AccountIdTag::get_interner().get_or_intern(&value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
|
||||
@@ -96,6 +96,16 @@ pub struct ProxyMetrics {
|
||||
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))]
|
||||
pub allowed_ips_number: Histogram<10>,
|
||||
|
||||
/// Number of cache hits/misses for VPC endpoint IDs.
|
||||
pub vpc_endpoint_id_cache_stats: CounterVec<StaticLabelSet<CacheOutcome>>,
|
||||
|
||||
/// Number of cache hits/misses for access blocker flags.
|
||||
pub access_blocker_flags_cache_stats: CounterVec<StaticLabelSet<CacheOutcome>>,
|
||||
|
||||
/// Number of allowed VPC endpoints IDs
|
||||
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))]
|
||||
pub allowed_vpc_endpoint_ids: Histogram<10>,
|
||||
|
||||
/// Number of connections (per sni).
|
||||
pub accepted_connections_by_sni: CounterVec<StaticLabelSet<SniKind>>,
|
||||
|
||||
@@ -570,6 +580,9 @@ pub enum RedisEventsCount {
|
||||
CancelSession,
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
AllowedVpcEndpointIdsUpdateForProjects,
|
||||
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
|
||||
BlockPublicOrVpcAccessUpdate,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
|
||||
@@ -283,7 +283,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
auth_backend,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{
|
||||
self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache,
|
||||
self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
|
||||
};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
@@ -526,9 +526,19 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips_and_secret(
|
||||
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), control_plane::errors::GetAuthInfoError>
|
||||
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
) -> Result<control_plane::CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError>
|
||||
{
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::intern::{ProjectIdInt, RoleNameInt};
|
||||
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
@@ -86,9 +86,7 @@ pub(crate) struct BlockPublicOrVpcAccessUpdated {
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
|
||||
// TODO: change type once the implementation is more fully fledged.
|
||||
// See e.g. https://github.com/neondatabase/neon/pull/10073.
|
||||
account_id: ProjectIdInt,
|
||||
account_id: AccountIdInt,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
@@ -205,6 +203,24 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::PasswordUpdate);
|
||||
} else if matches!(
|
||||
msg,
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
|
||||
) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
|
||||
} else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
|
||||
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
|
||||
}
|
||||
// TODO: add additional metrics for the other event types.
|
||||
|
||||
@@ -230,20 +246,26 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
Notification::AllowedIpsUpdate { allowed_ips_update } => {
|
||||
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
|
||||
}
|
||||
Notification::BlockPublicOrVpcAccessUpdated {
|
||||
block_public_or_vpc_access_updated,
|
||||
} => cache.invalidate_block_public_or_vpc_access_for_project(
|
||||
block_public_or_vpc_access_updated.project_id,
|
||||
),
|
||||
Notification::AllowedVpcEndpointsUpdatedForOrg {
|
||||
allowed_vpc_endpoints_updated_for_org,
|
||||
} => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
|
||||
allowed_vpc_endpoints_updated_for_org.account_id,
|
||||
),
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects {
|
||||
allowed_vpc_endpoints_updated_for_projects,
|
||||
} => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
|
||||
allowed_vpc_endpoints_updated_for_projects.project_ids,
|
||||
),
|
||||
Notification::PasswordUpdate { password_update } => cache
|
||||
.invalidate_role_secret_for_project(
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
|
||||
// https://github.com/neondatabase/neon/pull/10073
|
||||
}
|
||||
Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
|
||||
// https://github.com/neondatabase/neon/pull/10073
|
||||
}
|
||||
Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
|
||||
// https://github.com/neondatabase/neon/pull/10073
|
||||
}
|
||||
Notification::UnknownTopic => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -57,23 +58,52 @@ impl PoolingBackend {
|
||||
|
||||
let user_info = user_info.clone();
|
||||
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
let allowed_ips = backend.get_allowed_ips(ctx).await?;
|
||||
|
||||
if self.config.authentication_config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
{
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
|
||||
let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?;
|
||||
if self.config.authentication_config.is_vpc_acccess_proxy {
|
||||
if access_blocker_flags.vpc_access_blocked {
|
||||
return Err(AuthError::NetworkNotAllowed);
|
||||
}
|
||||
|
||||
let extra = ctx.extra();
|
||||
let incoming_endpoint_id = match extra {
|
||||
None => String::new(),
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
|
||||
// Convert the vcpe_id to a string
|
||||
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
|
||||
}
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
|
||||
};
|
||||
|
||||
if incoming_endpoint_id.is_empty() {
|
||||
return Err(AuthError::MissingVPCEndpointId);
|
||||
}
|
||||
|
||||
let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?;
|
||||
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
|
||||
if !allowed_vpc_endpoint_ids.is_empty()
|
||||
&& !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id)
|
||||
{
|
||||
return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id));
|
||||
}
|
||||
} else if access_blocker_flags.public_access_blocked {
|
||||
return Err(AuthError::NetworkNotAllowed);
|
||||
}
|
||||
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
.check(user_info.endpoint.clone().into(), 1)
|
||||
{
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => backend.get_role_secret(ctx).await?,
|
||||
};
|
||||
|
||||
let cached_secret = backend.get_role_secret(ctx).await?;
|
||||
let secret = match cached_secret.value.clone() {
|
||||
Some(secret) => self.config.authentication_config.check_rate_limit(
|
||||
ctx,
|
||||
|
||||
@@ -97,6 +97,8 @@ smol_str_wrapper!(EndpointId);
|
||||
smol_str_wrapper!(BranchId);
|
||||
// 90% of project strings are 23 characters or less.
|
||||
smol_str_wrapper!(ProjectId);
|
||||
// 90% of account strings are 23 characters or less.
|
||||
smol_str_wrapper!(AccountId);
|
||||
|
||||
// will usually equal endpoint ID
|
||||
smol_str_wrapper!(EndpointCacheKey);
|
||||
|
||||
Reference in New Issue
Block a user