From bea1580af2e8f4f97df2d8a85feba655aae33562 Mon Sep 17 00:00:00 2001
From: Ivan Efremov
Date: Fri, 23 May 2025 22:25:29 +0300
Subject: [PATCH] DO NOT MERGE [proxy]: Add geo-based routing for replicated projects

---
 proxy/src/binary/proxy.rs                     |  19 +++
 proxy/src/context/mod.rs                      |   5 +
 .../control_plane/client/cplane_proxy_v1.rs   | 142 +++++++++++++++++-
 proxy/src/control_plane/errors.rs             |   5 +
 proxy/src/control_plane/messages.rs           |  20 +++
 5 files changed, 190 insertions(+), 1 deletion(-)

diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs
index 5f24940985..ff20ff2932 100644
--- a/proxy/src/binary/proxy.rs
+++ b/proxy/src/binary/proxy.rs
@@ -27,6 +27,7 @@ use crate::config::{
     ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
 };
 use crate::context::parquet::ParquetUploadArgs;
+use crate::control_plane::client::cplane_proxy_v1::{GeoProximity, RegionProximityMap};
 use crate::http::health_server::AppMetrics;
 use crate::metrics::Metrics;
 use crate::rate_limiter::{
@@ -766,12 +767,21 @@ fn build_auth_backend(
             let wake_compute_endpoint_rate_limiter =
                 Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
 
+            let geo_map = Box::leak(Box::new(RegionProximityMap::from([(
+                args.region.clone(),
+                GeoProximity {
+                    _weight: 1,
+                    _distance: 0,
+                },
+            )])));
+
             let api = control_plane::client::cplane_proxy_v1::NeonControlPlaneClient::new(
                 endpoint,
                 args.control_plane_token.clone(),
                 caches,
                 locks,
                 wake_compute_endpoint_rate_limiter,
+                geo_map,
             );
 
             let api = control_plane::client::ControlPlaneClient::ProxyV1(api);
@@ -845,6 +855,14 @@ fn build_auth_backend(
             let wake_compute_endpoint_rate_limiter =
                 Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
 
+            let geo_map = Box::leak(Box::new(RegionProximityMap::from([(
+                args.region.clone(),
+                GeoProximity {
+                    _weight: 1,
+                    _distance: 0,
+                },
+            )])));
+
             // Since we use only get_allowed_ips_and_secret() wake_compute_endpoint_rate_limiter
             // and locks are not used in ConsoleRedirectBackend,
             // but they are required by the NeonControlPlaneClient
@@ -854,6 +872,7 @@ fn build_auth_backend(
                 caches,
                 locks,
                 wake_compute_endpoint_rate_limiter,
+                geo_map,
             );
 
             let backend = ConsoleRedirectBackend::new(url, api);
diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs
index 79aaf22990..6d2bee112b 100644
--- a/proxy/src/context/mod.rs
+++ b/proxy/src/context/mod.rs
@@ -296,6 +296,11 @@ impl RequestContext {
             .has_private_peer_addr()
     }
 
+    /// Returns `true` when the region recorded in this context is exactly `"global"`.
+    pub fn is_global(&self) -> bool {
+        self.0.try_lock().expect("should not deadlock").region == "global"
+    }
+
     pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
         let mut this = self.0.try_lock().expect("should not deadlock");
         // Do not record errors from the private address to metrics.
diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs
index 2765aaa462..9522a62367 100644
--- a/proxy/src/control_plane/client/cplane_proxy_v1.rs
+++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs
@@ -1,5 +1,6 @@
 //! Production console backend.
 
+use std::collections::HashMap;
 use std::net::IpAddr;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -12,7 +13,10 @@ use postgres_client::config::SslMode;
 use tokio::time::Instant;
 use tracing::{Instrument, debug, info, info_span, warn};
 
-use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
+use super::super::messages::{
+    ControlPlaneErrorMessage, GetEndpointAccessControl, GetEndpointAccessControlReplicated,
+    WakeCompute,
+};
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::backend::jwt::AuthRule;
 use crate::cache::Cached;
@@ -34,6 +38,18 @@ use crate::{compute, http, scram};
 
 pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
 
+/// Identifier of a region.
+pub type RegionId = String;
+/// Per-region proximity data, keyed by region id.
+pub type RegionProximityMap = HashMap<RegionId, GeoProximity>;
+
+/// Proximity parameters of a single region.
+#[derive(Clone)]
+pub struct GeoProximity {
+    pub _weight: u64,   // load or preference-based parameter
+    pub _distance: u64, // approximate distance from the region to the current proxy
+}
+
 #[derive(Clone)]
 pub struct NeonControlPlaneClient {
     endpoint: http::Endpoint,
@@ -42,6 +58,7 @@ pub struct NeonControlPlaneClient {
     pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
     // put in a shared ref so we don't copy secrets all over in memory
     jwt: Arc<str>,
+    geo_map: &'static RegionProximityMap,
 }
 
 impl NeonControlPlaneClient {
@@ -52,6 +69,7 @@ impl NeonControlPlaneClient {
         caches: &'static ApiCaches,
         locks: &'static ApiLocks<EndpointCacheKey>,
         wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
+        geo_map: &'static RegionProximityMap,
     ) -> Self {
         Self {
             endpoint,
@@ -59,6 +77,7 @@ impl NeonControlPlaneClient {
             locks,
             wake_compute_endpoint_rate_limiter,
             jwt,
+            geo_map,
         }
     }
 
@@ -81,10 +100,131 @@ impl NeonControlPlaneClient {
             info!("endpoint is not valid, skipping the request");
             return Ok(AuthInfo::default());
         }
+
+        if ctx.is_global() {
+            return self
+                .do_get_auth_req_replicated(user_info, &ctx.session_id(), Some(ctx))
+                .await;
+        }
+
         self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
             .await
     }
 
+    /// Fetches auth info for a replicated ("global") project.
+    ///
+    /// Queries `get_endpoint_access_control_replicated` and returns the auth info of
+    /// the first replica whose region is present in `geo_map`. Replicas without a
+    /// region, or in a region missing from `geo_map`, are skipped.
+    ///
+    /// A "not found" reply from the control plane yields `Ok(AuthInfo::default())`;
+    /// if no replica matches, [`GetAuthInfoError::RegionNotFound`] is returned.
+    async fn do_get_auth_req_replicated(
+        &self,
+        user_info: &ComputeUserInfo,
+        session_id: &uuid::Uuid,
+        ctx: Option<&RequestContext>,
+    ) -> Result<AuthInfo, GetAuthInfoError> {
+        let request_id: String = session_id.to_string();
+        let application_name = if let Some(ctx) = ctx {
+            ctx.console_application_name()
+        } else {
+            "auth_cancellation".to_string()
+        };
+
+        async {
+            let request = self
+                .endpoint
+                .get_path("get_endpoint_access_control_replicated")
+                .header(X_REQUEST_ID, &request_id)
+                .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
+                .query(&[("session_id", session_id)])
+                .query(&[
+                    ("application_name", application_name.as_str()),
+                    ("endpointish", user_info.endpoint.as_str()),
+                    ("role", user_info.user.as_str()),
+                ])
+                .build()?;
+
+            debug!(url = request.url().as_str(), "sending http request");
+            let start = Instant::now();
+            let response = match ctx {
+                Some(ctx) => {
+                    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
+                    let rsp = self.endpoint.execute(request).await;
+                    drop(pause);
+                    rsp?
+                }
+                None => self.endpoint.execute(request).await?,
+            };
+
+            info!(duration = ?start.elapsed(), "received http response");
+            let body = match parse_body::<GetEndpointAccessControlReplicated>(response).await {
+                Ok(body) => body,
+                // Error 404 is special: it's ok not to have a secret.
+                // TODO(anna): retry
+                Err(e) => {
+                    return if e.get_reason().is_not_found() {
+                        // TODO: refactor this because it's weird
+                        // this is a failure to authenticate but we return Ok.
+                        Ok(AuthInfo::default())
+                    } else {
+                        Err(e.into())
+                    };
+                }
+            };
+
+            // Take the first replica located in a region present in `geo_map`.
+            // A replica in an unknown region must not abort the scan, because a
+            // matching one may come later in the list.
+            // TODO: calculate proximity and reroute to the closest replica.
+            let Some(endpoint) = body.endpoints.into_iter().find(|endpoint| {
+                endpoint
+                    .region_id
+                    .as_ref()
+                    .is_some_and(|region_id| self.geo_map.contains_key(region_id))
+            }) else {
+                return Err(GetAuthInfoError::RegionNotFound);
+            };
+
+            let secret = if endpoint.role_secret.is_empty() {
+                None
+            } else {
+                let secret = scram::ServerSecret::parse(&endpoint.role_secret)
+                    .map(AuthSecret::Scram)
+                    .ok_or(GetAuthInfoError::BadSecret)?;
+                Some(secret)
+            };
+            let allowed_ips = endpoint.allowed_ips.unwrap_or_default();
+            Metrics::get()
+                .proxy
+                .allowed_ips_number
+                .observe(allowed_ips.len() as f64);
+            let allowed_vpc_endpoint_ids = endpoint.allowed_vpc_endpoint_ids.unwrap_or_default();
+            Metrics::get()
+                .proxy
+                .allowed_vpc_endpoint_ids
+                .observe(allowed_vpc_endpoint_ids.len() as f64);
+            let block_public_connections = endpoint.block_public_connections.unwrap_or_default();
+            let block_vpc_connections = endpoint.block_vpc_connections.unwrap_or_default();
+
+            Ok(AuthInfo {
+                secret,
+                allowed_ips,
+                allowed_vpc_endpoint_ids,
+                project_id: endpoint.project_id,
+                account_id: endpoint.account_id,
+                access_blocker_flags: AccessBlockerFlags {
+                    public_access_blocked: block_public_connections,
+                    vpc_access_blocked: block_vpc_connections,
+                },
+            })
+        }
+        .inspect_err(|e| tracing::debug!(error = ?e))
+        .instrument(info_span!("do_get_auth_info"))
+        .await
+    }
+
     async fn do_get_auth_req(
         &self,
         user_info: &ComputeUserInfo,
diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs
index 850d061333..3f47b80393 100644
--- a/proxy/src/control_plane/errors.rs
+++ b/proxy/src/control_plane/errors.rs
@@ -99,6 +99,9 @@ pub(crate) enum GetAuthInfoError {
 
     #[error(transparent)]
     ApiError(ControlPlaneError),
+
+    #[error("No endpoint found in this region")]
+    RegionNotFound,
 }
 
 // This allows more useful interactions than `#[from]`.
@@ -115,6 +118,7 @@ impl UserFacingError for GetAuthInfoError {
             Self::BadSecret => REQUEST_FAILED.to_owned(),
             // However, API might return a meaningful error.
             Self::ApiError(e) => e.to_string_client(),
+            Self::RegionNotFound => "No endpoint found in this region".to_owned(),
         }
     }
 }
@@ -124,6 +128,7 @@ impl ReportableError for GetAuthInfoError {
         match self {
             Self::BadSecret => crate::error::ErrorKind::ControlPlane,
             Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
+            Self::RegionNotFound => crate::error::ErrorKind::User,
         }
     }
 }
diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs
index ec4554eab5..b1e5551aca 100644
--- a/proxy/src/control_plane/messages.rs
+++ b/proxy/src/control_plane/messages.rs
@@ -222,6 +222,26 @@ pub(crate) struct UserFacingMessage {
     pub(crate) message: Box<str>,
 }
 
+/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
+/// Returned by the `/get_endpoint_access_control_replicated` API method.
+#[derive(Deserialize)]
+pub(crate) struct GetEndpointAccessControlReplicated {
+    pub(crate) endpoints: Vec<EndpointAccessControlReplicated>,
+}
+
+/// One entry of [`GetEndpointAccessControlReplicated::endpoints`].
+#[derive(Deserialize)]
+pub(crate) struct EndpointAccessControlReplicated {
+    pub(crate) role_secret: Box<str>,
+    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
+    pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
+    pub(crate) project_id: Option<ProjectIdInt>,
+    pub(crate) account_id: Option<AccountIdInt>,
+    pub(crate) block_public_connections: Option<bool>,
+    pub(crate) block_vpc_connections: Option<bool>,
+    pub(crate) region_id: Option<String>,
+}
+
 /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
 /// Returned by the `/get_endpoint_access_control` API method.
 #[derive(Deserialize)]