DO NOT MERGE [proxy]: Add geo-based routing for replicated projects

This commit is contained in:
Ivan Efremov
2025-05-23 22:25:29 +03:00
parent 136eaeb74a
commit bea1580af2
5 changed files with 180 additions and 1 deletions

View File

@@ -27,6 +27,7 @@ use crate::config::{
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
};
use crate::context::parquet::ParquetUploadArgs;
use crate::control_plane::client::cplane_proxy_v1::{GeoProximity, RegionProximityMap};
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::rate_limiter::{
@@ -766,12 +767,21 @@ fn build_auth_backend(
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let geo_map = Box::leak(Box::new(RegionProximityMap::from([(
args.region.clone(),
GeoProximity {
_weight: 1,
_distance: 0,
},
)])));
let api = control_plane::client::cplane_proxy_v1::NeonControlPlaneClient::new(
endpoint,
args.control_plane_token.clone(),
caches,
locks,
wake_compute_endpoint_rate_limiter,
geo_map,
);
let api = control_plane::client::ControlPlaneClient::ProxyV1(api);
@@ -845,6 +855,14 @@ fn build_auth_backend(
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let geo_map = Box::leak(Box::new(RegionProximityMap::from([(
args.region.clone(),
GeoProximity {
_weight: 1,
_distance: 0,
},
)])));
// Since we use only get_allowed_ips_and_secret() wake_compute_endpoint_rate_limiter
// and locks are not used in ConsoleRedirectBackend,
// but they are required by the NeonControlPlaneClient
@@ -854,6 +872,7 @@ fn build_auth_backend(
caches,
locks,
wake_compute_endpoint_rate_limiter,
geo_map,
);
let backend = ConsoleRedirectBackend::new(url, api);

View File

@@ -296,6 +296,10 @@ impl RequestContext {
.has_private_peer_addr()
}
/// Reports whether the region recorded on this request context is the literal `"global"`.
///
/// # Panics
/// Panics with "should not deadlock" if the inner lock cannot be taken immediately.
pub fn is_global(&self) -> bool {
    let inner = self.0.try_lock().expect("should not deadlock");
    inner.region == "global"
}
pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
let mut this = self.0.try_lock().expect("should not deadlock");
// Do not record errors from the private address to metrics.

View File

@@ -1,5 +1,6 @@
//! Production console backend.
use std::collections::HashMap;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
@@ -12,7 +13,10 @@ use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use super::super::messages::{
ControlPlaneErrorMessage, GetEndpointAccessControl, GetEndpointAccessControlReplicated,
WakeCompute,
};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
@@ -34,6 +38,15 @@ use crate::{compute, http, scram};
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Region identifier used as the key of [`RegionProximityMap`].
pub type RegionId = String;

/// Proximity parameters for every region known to this proxy, keyed by region id.
pub type RegionProximityMap = HashMap<RegionId, GeoProximity>;

/// Routing parameters describing one region relative to the current proxy.
///
/// The fields are not consumed by the routing logic yet (hence the `_` prefix);
/// they are kept public so callers can already populate the map.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GeoProximity {
    /// Load or preference-based parameter.
    pub _weight: u64,
    /// Approximate distance from the region to the current proxy.
    pub _distance: u64,
}
#[derive(Clone)]
pub struct NeonControlPlaneClient {
endpoint: http::Endpoint,
@@ -42,6 +55,7 @@ pub struct NeonControlPlaneClient {
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
geo_map: &'static RegionProximityMap,
}
impl NeonControlPlaneClient {
@@ -52,6 +66,7 @@ impl NeonControlPlaneClient {
caches: &'static ApiCaches,
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
geo_map: &'static RegionProximityMap,
) -> Self {
Self {
endpoint,
@@ -59,6 +74,7 @@ impl NeonControlPlaneClient {
locks,
wake_compute_endpoint_rate_limiter,
jwt,
geo_map,
}
}
@@ -81,10 +97,126 @@ impl NeonControlPlaneClient {
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
if ctx.is_global() {
return self
.do_get_auth_req_replicated(user_info, &ctx.session_id(), Some(ctx))
.await;
}
self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
.await
}
/// Fetches access-control info for a replicated endpoint from the control plane.
///
/// Calls `get_endpoint_access_control_replicated` and uses the first returned
/// endpoint whose `region_id` is present in this client's `geo_map`.
///
/// * `user_info` - endpoint and role being authenticated.
/// * `session_id` - sent as the `x-request-id` header and the `session_id` query parameter.
/// * `ctx` - when present, provides the application name and has its latency timer
///   paused while waiting on the control plane; when absent, `"auth_cancellation"`
///   is used as the application name.
///
/// # Errors
/// * [`GetAuthInfoError::RegionNotFound`] if none of the returned endpoints belongs
///   to a region found in `geo_map`.
/// * [`GetAuthInfoError::BadSecret`] if the selected endpoint has a non-empty role
///   secret that cannot be parsed.
/// * Any error from building/executing the request or parsing the response, except
///   "not found", which yields `Ok(AuthInfo::default())`.
async fn do_get_auth_req_replicated(
    &self,
    user_info: &ComputeUserInfo,
    session_id: &uuid::Uuid,
    ctx: Option<&RequestContext>,
) -> Result<AuthInfo, GetAuthInfoError> {
    let request_id: String = session_id.to_string();
    let application_name = if let Some(ctx) = ctx {
        ctx.console_application_name()
    } else {
        "auth_cancellation".to_string()
    };

    async {
        let request = self
            .endpoint
            .get_path("get_endpoint_access_control_replicated")
            .header(X_REQUEST_ID, &request_id)
            .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
            .query(&[("session_id", session_id)])
            .query(&[
                ("application_name", application_name.as_str()),
                ("endpointish", user_info.endpoint.as_str()),
                ("role", user_info.user.as_str()),
            ])
            .build()?;

        debug!(url = request.url().as_str(), "sending http request");
        let start = Instant::now();
        let response = match ctx {
            Some(ctx) => {
                // Time spent waiting on the control plane is excluded from the latency timer.
                let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
                let rsp = self.endpoint.execute(request).await;
                drop(pause);
                rsp?
            }
            None => self.endpoint.execute(request).await?,
        };

        info!(duration = ?start.elapsed(), "received http response");
        let body = match parse_body::<GetEndpointAccessControlReplicated>(response).await {
            Ok(body) => body,
            // Error 404 is special: it's ok not to have a secret.
            // TODO(anna): retry
            Err(e) => {
                return if e.get_reason().is_not_found() {
                    // TODO: refactor this because it's weird
                    // this is a failure to authenticate but we return Ok.
                    Ok(AuthInfo::default())
                } else {
                    Err(e.into())
                };
            }
        };

        // Pick the first endpoint located in a region this proxy knows about.
        // Entries with no region, or with a region missing from `geo_map`, are
        // skipped; giving up on the first unknown region would hide a usable
        // endpoint that appears later in the list.
        // TODO: use the proximity values to choose the closest replica and reroute.
        let Some(endpoint) = body.endpoints.into_iter().find(|endpoint| {
            endpoint
                .region_id
                .as_ref()
                .is_some_and(|region_id| self.geo_map.contains_key(region_id))
        }) else {
            return Err(GetAuthInfoError::RegionNotFound);
        };

        // An empty secret means "no secret"; anything else must parse as SCRAM.
        let secret = if endpoint.role_secret.is_empty() {
            None
        } else {
            let secret = scram::ServerSecret::parse(&endpoint.role_secret)
                .map(AuthSecret::Scram)
                .ok_or(GetAuthInfoError::BadSecret)?;
            Some(secret)
        };

        let allowed_ips = endpoint.allowed_ips.unwrap_or_default();
        Metrics::get()
            .proxy
            .allowed_ips_number
            .observe(allowed_ips.len() as f64);

        let allowed_vpc_endpoint_ids = endpoint.allowed_vpc_endpoint_ids.unwrap_or_default();
        Metrics::get()
            .proxy
            .allowed_vpc_endpoint_ids
            .observe(allowed_vpc_endpoint_ids.len() as f64);

        let block_public_connections = endpoint.block_public_connections.unwrap_or_default();
        let block_vpc_connections = endpoint.block_vpc_connections.unwrap_or_default();

        Ok(AuthInfo {
            secret,
            allowed_ips,
            allowed_vpc_endpoint_ids,
            project_id: endpoint.project_id,
            account_id: endpoint.account_id,
            access_blocker_flags: AccessBlockerFlags {
                public_access_blocked: block_public_connections,
                vpc_access_blocked: block_vpc_connections,
            },
        })
    }
    .inspect_err(|e| tracing::debug!(error = ?e))
    .instrument(info_span!("do_get_auth_info"))
    .await
}
async fn do_get_auth_req(
&self,
user_info: &ComputeUserInfo,

View File

@@ -99,6 +99,9 @@ pub(crate) enum GetAuthInfoError {
#[error(transparent)]
ApiError(ControlPlaneError),
#[error("No endpoint found in this region")]
RegionNotFound,
}
// This allows more useful interactions than `#[from]`.
@@ -115,6 +118,7 @@ impl UserFacingError for GetAuthInfoError {
Self::BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
Self::ApiError(e) => e.to_string_client(),
Self::RegionNotFound => "No endpoint found in this region".to_owned(),
}
}
}
@@ -124,6 +128,7 @@ impl ReportableError for GetAuthInfoError {
match self {
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
Self::RegionNotFound => crate::error::ErrorKind::User,
}
}
}

View File

@@ -222,6 +222,25 @@ pub(crate) struct UserFacingMessage {
pub(crate) message: Box<str>,
}
/// Response of the `get_endpoint_access_control_replicated` API method.
///
/// Carries a list of access-control entries, each of which may hold a client's
/// auth secret, e.g. [`crate::scram::ServerSecret`].
#[derive(Deserialize)]
pub(crate) struct GetEndpointAccessControlReplicated {
/// Access-control entries returned for the requested endpoint.
pub(crate) endpoints: Vec<EndpointAccessControlReplicated>,
}
/// One entry of [`GetEndpointAccessControlReplicated`]: the access-control settings
/// of a single endpoint together with the region it is reported for.
#[derive(Deserialize)]
pub(crate) struct EndpointAccessControlReplicated {
/// Role's auth secret. The replicated auth lookup treats an empty string as
/// "no secret" and otherwise parses it as a SCRAM server secret.
pub(crate) role_secret: Box<str>,
/// Allowed IP patterns; the replicated auth lookup uses an empty list when absent.
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
/// Allowed VPC endpoint ids; the replicated auth lookup uses an empty list when absent.
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
/// Project id, passed through unchanged into the resulting auth info.
pub(crate) project_id: Option<ProjectIdInt>,
/// Account id, passed through unchanged into the resulting auth info.
pub(crate) account_id: Option<AccountIdInt>,
/// Whether public connections are blocked; treated as `false` when absent.
pub(crate) block_public_connections: Option<bool>,
/// Whether VPC connections are blocked; treated as `false` when absent.
pub(crate) block_vpc_connections: Option<bool>,
/// Region of this entry; looked up in the proxy's region proximity map.
/// Entries without a region are not eligible for selection.
pub(crate) region_id: Option<String>,
}
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/get_endpoint_access_control` API method.
#[derive(Deserialize)]