impr(proxy): Decouple ip_allowlist from the CancelClosure (#10199)

This PR removes the direct dependency of the IP allowlist from
CancelClosure, allowing for more scalable and flexible IP restrictions
and enabling the future use of Redis-based CancelMap storage.

Changes:
- Introduce a new BackendAuth async trait that retrieves the IP
allowlist through existing authentication methods;
- Improve cancellation error handling by instrument() async
cancel_sesion() rather than dropping it.
- Set and store IP allowlist for SCRAM Proxy to consistently perform IP
allowance check
 
 Relates to #9660
This commit is contained in:
Ivan Efremov
2025-01-08 21:34:53 +02:00
committed by GitHub
parent 0ad0db6ff8
commit fcfff72454
8 changed files with 307 additions and 60 deletions

View File

@@ -1,16 +1,18 @@
use async_trait::async_trait;
use postgres_client::config::SslMode;
use pq_proto::BeMessage as Be;
use std::fmt;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use super::{ComputeCredentialKeys, ControlPlaneApi};
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::control_plane::{self, client::cplane_proxy_v1, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
@@ -31,6 +33,13 @@ pub(crate) enum ConsoleRedirectError {
#[derive(Debug)]
pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
api: cplane_proxy_v1::NeonControlPlaneClient,
}
impl fmt::Debug for cplane_proxy_v1::NeonControlPlaneClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "NeonControlPlaneClient")
}
}
impl UserFacingError for ConsoleRedirectError {
@@ -71,9 +80,24 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
#[async_trait]
impl BackendIpAllowlist for ConsoleRedirectBackend {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>> {
self.api
.get_allowed_ips_and_secret(ctx, user_info)
.await
.map(|(ips, _)| ips.as_ref().clone())
.map_err(|e| e.into())
}
}
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
pub fn new(console_uri: reqwest::Url, api: cplane_proxy_v1::NeonControlPlaneClient) -> Self {
Self { console_uri, api }
}
pub(crate) async fn authenticate(

View File

@@ -16,7 +16,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
use crate::auth::{
self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern,
};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
@@ -131,7 +133,7 @@ pub(crate) struct ComputeUserInfoNoEndpoint {
pub(crate) options: NeonOptions,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub(crate) struct ComputeUserInfo {
pub(crate) endpoint: EndpointId,
pub(crate) user: RoleName,
@@ -244,6 +246,15 @@ impl AuthenticationConfig {
}
}
#[async_trait::async_trait]
pub(crate) trait BackendIpAllowlist {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>>;
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
@@ -256,7 +267,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -315,7 +326,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok(keys),
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -385,7 +396,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
@@ -394,7 +405,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let credentials = auth_quirks(
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
@@ -404,7 +415,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
endpoint_rate_limiter,
)
.await?;
Backend::ControlPlane(api, credentials)
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
@@ -413,7 +424,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
// TODO: replace with some metric
info!("user successfully authenticated");
Ok(res)
res
}
}
@@ -441,6 +452,24 @@ impl Backend<'_, ComputeUserInfo> {
}
}
#[async_trait::async_trait]
impl BackendIpAllowlist for Backend<'_, ()> {
async fn get_allowed_ips(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> auth::Result<Vec<auth::IpPattern>> {
let auth_data = match self {
Self::ControlPlane(api, ()) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
};
auth_data
.map(|(ips, _)| ips.as_ref().clone())
.map_err(|e| e.into())
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
@@ -786,7 +815,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
assert_eq!(creds.0.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -744,9 +744,59 @@ fn build_auth_backend(
}
AuthBackendType::ConsoleRedirect => {
let url = args.uri.parse()?;
let backend = ConsoleRedirectBackend::new(url);
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
let url = args.uri.clone().parse()?;
let ep_url: proxy::url::ApiUrl = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(ep_url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
// Since we use only get_allowed_ips_and_secret() wake_compute_endpoint_rate_limiter
// and locks are not used in ConsoleRedirectBackend,
// but they are required by the NeonControlPlaneClient
let api = control_plane::client::cplane_proxy_v1::NeonControlPlaneClient::new(
endpoint,
args.control_plane_token.clone(),
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let backend = ConsoleRedirectBackend::new(url, api);
let config = Box::leak(Box::new(backend));
Ok(Either::Right(config))

View File

@@ -12,8 +12,10 @@ use tokio::sync::Mutex;
use tracing::{debug, info};
use uuid::Uuid;
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
use crate::auth::{check_peer_addr_is_in_list, AuthError, IpPattern};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
@@ -56,6 +58,9 @@ pub(crate) enum CancelError {
#[error("IP is not allowed")]
IpNotAllowed,
#[error("Authentication backend error")]
AuthError(#[from] AuthError),
}
impl ReportableError for CancelError {
@@ -68,6 +73,7 @@ impl ReportableError for CancelError {
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
@@ -102,10 +108,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query
/// return Result primarily for tests
/// Cancelling only in notification, will be removed
pub(crate) async fn cancel_session(
&self,
key: CancelKeyData,
@@ -134,7 +137,8 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
let cancel_state = self.map.get(&key).and_then(|x| x.clone());
let Some(cancel_closure) = cancel_state else {
tracing::warn!("query cancellation key not found: {key}");
Metrics::get()
.proxy
@@ -185,6 +189,96 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
cancel_closure.try_cancel_query(self.compute_config).await
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query.
/// Will fetch IP allowlist internally.
///
/// return Result primarily for tests
pub(crate) async fn cancel_session_auth<T: BackendIpAllowlist>(
&self,
key: CancelKeyData,
ctx: RequestContext,
check_allowed: bool,
auth_backend: &T,
) -> Result<(), CancelError> {
// TODO: check for unspecified address is only for backward compatibility, should be removed
if !ctx.peer_addr().is_unspecified() {
let subnet_key = match ctx.peer_addr() {
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
// log only the subnet part of the IP address to know which subnet is rate limited
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
});
return Err(CancelError::RateLimit);
}
}
// NB: we should immediately release the lock after cloning the token.
let cancel_state = self.map.get(&key).and_then(|x| x.clone());
let Some(cancel_closure) = cancel_state else {
tracing::warn!("query cancellation key not found: {key}");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::NotFound,
});
if ctx.session_id() == Uuid::nil() {
// was already published, do not publish it again
return Ok(());
}
match self
.client
.try_publish(key, ctx.session_id(), ctx.peer_addr())
.await
{
Ok(()) => {} // do nothing
Err(e) => {
// log it here since cancel_session could be spawned in a task
tracing::error!("failed to publish cancellation key: {key}, error: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
return Ok(());
};
let ip_allowlist = auth_backend
.get_allowed_ips(&ctx, &cancel_closure.user_info)
.await
.map_err(CancelError::AuthError)?;
if check_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
// log it here since cancel_session could be spawned in a task
tracing::warn!("IP is not allowed to cancel the query: {key}");
return Err(CancelError::IpNotAllowed);
}
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::Found,
});
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query(self.compute_config).await
}
#[cfg(test)]
fn contains(&self, session: &Session<P>) -> bool {
self.map.contains_key(&session.key)
@@ -248,6 +342,7 @@ pub struct CancelClosure {
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}
impl CancelClosure {
@@ -256,12 +351,14 @@ impl CancelClosure {
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
Self {
socket_addr,
cancel_token,
ip_allowlist,
hostname,
user_info,
}
}
/// Cancels the query running on user's compute node.
@@ -288,6 +385,8 @@ impl CancelClosure {
debug!("query was cancelled");
Ok(())
}
/// Obsolete (will be removed after moving CancelMap to Redis), only for notifications
pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
self.ip_allowlist = ip_allowlist;
}

View File

@@ -13,6 +13,7 @@ use thiserror::Error;
use tokio::net::TcpStream;
use tracing::{debug, error, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::config::ComputeConfig;
@@ -23,8 +24,10 @@ use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::proxy::neon_option;
use crate::proxy::NeonOptions;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;
use crate::types::{EndpointId, RoleName};
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -284,6 +287,28 @@ impl ConnCfg {
self.0.get_ssl_mode()
);
let compute_info = match parameters.get("user") {
Some(user) => {
match parameters.get("database") {
Some(database) => {
ComputeUserInfo {
user: RoleName::from(user),
options: NeonOptions::default(), // just a shim, we don't need options
endpoint: EndpointId::from(database),
}
}
None => {
warn!("compute node didn't return database name");
ComputeUserInfo::default()
}
}
}
None => {
warn!("compute node didn't return user name");
ComputeUserInfo::default()
}
};
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
@@ -294,8 +319,9 @@ impl ConnCfg {
process_id,
secret_key,
},
vec![],
vec![], // TODO: deprecated, will be removed
host.to_string(),
compute_info,
);
let connection = PostgresConnection {

View File

@@ -159,6 +159,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
@@ -171,23 +172,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?session_id);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.instrument(cancel_span)
.await,
);
}
cancellation_handler_clone
.cancel_session_auth(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
backend,
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);

View File

@@ -29,7 +29,7 @@ use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId};
use crate::{compute, http, scram};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct NeonControlPlaneClient {
@@ -78,15 +78,30 @@ impl NeonControlPlaneClient {
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id().to_string();
let application_name = ctx.console_application_name();
self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
.await
}
async fn do_get_auth_req(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
ctx: Option<&RequestContext>,
) -> Result<AuthInfo, GetAuthInfoError> {
let request_id: String = session_id.to_string();
let application_name = if let Some(ctx) = ctx {
ctx.console_application_name()
} else {
"auth_cancellation".to_string()
};
async {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[("session_id", session_id)])
.query(&[
("application_name", application_name.as_str()),
("endpointish", user_info.endpoint.as_str()),
@@ -96,9 +111,16 @@ impl NeonControlPlaneClient {
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
let response = match ctx {
Some(ctx) => {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let rsp = self.endpoint.execute(request).await;
drop(pause);
rsp?
}
None => self.endpoint.execute(request).await?,
};
info!(duration = ?start.elapsed(), "received http response");
let body = match parse_body::<GetEndpointAccessControl>(response).await {
Ok(body) => body,

View File

@@ -273,23 +273,20 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?session_id);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.instrument(cancel_span)
.await,
);
}
cancellation_handler_clone
.cancel_session_auth(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
auth_backend,
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
@@ -315,7 +312,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
let (user_info, ip_allowlist) = match user_info
.authenticate(
ctx,
&mut stream,
@@ -356,6 +353,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.or_else(|e| stream.throw_error(e))
.await?;
node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;