//! Production console backend. use std::net::IpAddr; use std::str::FromStr; use std::sync::Arc; use ::http::HeaderName; use ::http::header::AUTHORIZATION; use bytes::Bytes; use futures::TryFutureExt; use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse}; use crate::control_plane::{ AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::metrics::Metrics; use crate::proxy::retry::CouldRetry; use crate::rate_limiter::WakeComputeRateLimiter; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); #[derive(Clone)] pub struct NeonControlPlaneClient { endpoint: http::Endpoint, pub caches: &'static ApiCaches, pub(crate) locks: &'static ApiLocks, pub(crate) wake_compute_endpoint_rate_limiter: Arc, // put in a shared ref so we don't copy secrets all over in memory jwt: Arc, } impl NeonControlPlaneClient { /// Construct an API object containing the auth parameters. pub fn new( endpoint: http::Endpoint, jwt: Arc, caches: &'static ApiCaches, locks: &'static ApiLocks, wake_compute_endpoint_rate_limiter: Arc, ) -> Self { Self { endpoint, caches, locks, wake_compute_endpoint_rate_limiter, jwt, } } pub(crate) fn url(&self) -> &str { self.endpoint.url().as_str() } async fn get_and_cache_auth_info( &self, ctx: &RequestContext, endpoint: &EndpointId, role: &RoleName, cache_key: &EndpointId, extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T, ) -> Result { match self.do_get_auth_req(ctx, endpoint, role).await { Ok(auth_info) => { let control = EndpointAccessControl { allowed_ips: Arc::new(auth_info.allowed_ips), allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), flags: auth_info.access_blocker_flags, rate_limits: auth_info.rate_limits, }; let role_control = RoleAccessControl { secret: auth_info.secret, }; let res = extract(&control, &role_control); self.caches.project_info.insert_endpoint_access( auth_info.account_id, auth_info.project_id, cache_key.into(), role.into(), control, role_control, ); if let Some(project_id) = auth_info.project_id { ctx.set_project_id(project_id); } Ok(res) } Err(err) => match err { GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => { let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info); // If we can retry this error, do not cache it, // unless we were given a retry delay. if msg.could_retry() && retry_info.is_none() { return Err(err); } self.caches.project_info.insert_endpoint_access_err( cache_key.into(), role.into(), msg.clone(), ); Err(err) } err => Err(err), }, } } async fn do_get_auth_req( &self, ctx: &RequestContext, endpoint: &EndpointId, role: &RoleName, ) -> Result { async { let response = { let request = self .endpoint .get_path("get_endpoint_access_control") .header(X_REQUEST_ID, ctx.session_id().to_string()) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) .query(&[("session_id", ctx.session_id())]) .query(&[ ("application_name", ctx.console_application_name().as_str()), ("endpointish", endpoint.as_str()), ("role", role.as_str()), ]) .build()?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); let response = self.endpoint.execute(request).await?; info!(duration = ?start.elapsed(), "received http response"); response }; let body = match parse_body::( response.status(), response.bytes().await?, ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry Err(e) => { return if e.get_reason().is_not_found() { // TODO: refactor this because it's weird // this is a failure to authenticate but we return Ok. Ok(AuthInfo::default()) } else { Err(e.into()) }; } }; let secret = if body.role_secret.is_empty() { None } else { let secret = scram::ServerSecret::parse(&body.role_secret) .map(AuthSecret::Scram) .ok_or(GetAuthInfoError::BadSecret)?; Some(secret) }; let allowed_ips = body.allowed_ips.unwrap_or_default(); Metrics::get() .proxy .allowed_ips_number .observe(allowed_ips.len() as f64); let allowed_vpc_endpoint_ids = body.allowed_vpc_endpoint_ids.unwrap_or_default(); Metrics::get() .proxy .allowed_vpc_endpoint_ids .observe(allowed_vpc_endpoint_ids.len() as f64); let block_public_connections = body.block_public_connections.unwrap_or_default(); let block_vpc_connections = body.block_vpc_connections.unwrap_or_default(); Ok(AuthInfo { secret, allowed_ips, allowed_vpc_endpoint_ids, project_id: body.project_id, account_id: body.account_id, access_blocker_flags: AccessBlockerFlags { public_access_blocked: block_public_connections, vpc_access_blocked: block_vpc_connections, }, rate_limits: body.rate_limits, }) } .inspect_err(|e| tracing::debug!(error = ?e)) .instrument(info_span!("do_get_auth_info")) .await } async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let request_id = ctx.session_id().to_string(); async { let request = self .endpoint .get_with_url(|url| { url.path_segments_mut() .push("endpoints") .push(endpoint.as_str()) .push("jwks"); }) .header(X_REQUEST_ID, &request_id) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) .query(&[("session_id", ctx.session_id())]) .build() .map_err(GetEndpointJwksError::RequestBuild)?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); let response = self .endpoint .execute(request) .await .map_err(GetEndpointJwksError::RequestExecute)?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); let body = parse_body::( response.status(), response.bytes().await.map_err(ControlPlaneError::from)?, )?; let rules = body .jwks .into_iter() .map(|jwks| AuthRule { id: jwks.id, jwks_url: jwks.jwks_url, audience: jwks.jwt_audience, role_names: jwks.role_names, }) .collect(); Ok(rules) } .inspect_err(|e| tracing::debug!(error = ?e)) .instrument(info_span!("do_get_endpoint_jwks")) .await } async fn do_wake_compute( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, ) -> Result { let request_id = ctx.session_id().to_string(); let application_name = ctx.console_application_name(); async { let mut request_builder = self .endpoint .get_path("wake_compute") .header("X-Request-ID", &request_id) .header("Authorization", format!("Bearer {}", &self.jwt)) .query(&[("session_id", ctx.session_id())]) .query(&[ ("application_name", application_name.as_str()), ("endpointish", user_info.endpoint.as_str()), ]); let options = user_info.options.to_deep_object(); if !options.is_empty() { request_builder = request_builder.query(&options); } let request = request_builder.build()?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); let body = parse_body::(response.status(), response.bytes().await?)?; let Some((host, port)) = parse_host_port(&body.address) else { return Err(WakeComputeError::BadComputeAddress(body.address)); }; let host_addr = IpAddr::from_str(host).ok(); let ssl_mode = match &body.server_name { Some(_) => SslMode::Require, None => SslMode::Disable, }; let host = match body.server_name { Some(host) => host.into(), None => host.into(), }; let node = NodeInfo { conn_info: compute::ConnectInfo { host_addr, host, port, ssl_mode, }, aux: body.aux, }; Ok(node) } .inspect_err(|e| tracing::debug!(error = ?e)) .instrument(info_span!("do_wake_compute")) .await } } impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] async fn get_role_access_control( &self, ctx: &RequestContext, endpoint: &EndpointId, role: &RoleName, ) -> Result { let key = endpoint.normalize(); if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) { return match role_control { Err(msg) => { info!(key = &*key, "found cached get_role_access_control error"); Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(role_control) => { debug!(key = &*key, "found cached role access control"); Ok(role_control) } }; } self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| { role_control.clone() }) .await } #[tracing::instrument(skip_all)] async fn get_endpoint_access_control( &self, ctx: &RequestContext, endpoint: &EndpointId, role: &RoleName, ) -> Result { let key = endpoint.normalize(); if let Some(control) = self.caches.project_info.get_endpoint_access(&key) { return match control { Err(msg) => { info!( key = &*key, "found cached get_endpoint_access_control error" ); Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(control) => { debug!(key = &*key, "found cached endpoint access control"); Ok(control) } }; } self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone()) .await } #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, ctx: &RequestContext, endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } #[tracing::instrument(skip_all)] async fn wake_compute( &self, ctx: &RequestContext, user_info: &ComputeUserInfo, ) -> Result { let key = user_info.endpoint_cache_key(); macro_rules! check_cache { () => { if let Some(info) = self.caches.node_info.get_entry(&key) { return match info { Err(msg) => { info!(key = &*key, "found cached wake_compute error"); Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( msg, ))) } Ok(info) => { debug!(key = &*key, "found cached compute node info"); ctx.set_project(info.aux.clone()); Ok(info) } }; } }; } // Every time we do a wakeup http request, the compute node will stay up // for some time (highly depends on the console's scale-to-zero policy); // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. check_cache!(); let permit = self.locks.get_permit(&key).await?; // after getting back a permit - it's possible the cache was filled // double check if permit.should_check_cache() { // TODO: if there is something in the cache, mark the permit as success. check_cache!(); } // check rate limit if !self .wake_compute_endpoint_rate_limiter .check(user_info.endpoint.normalize_intern(), 1) { return Err(WakeComputeError::TooManyConnections); } let node = permit.release_result(self.do_wake_compute(ctx, user_info).await); match node { Ok(node) => { ctx.set_project(node.aux.clone()); debug!(key = &*key, "created a cache entry for woken compute node"); let mut stored_node = node.clone(); // store the cached node as 'warm_cached' stored_node.aux.cold_start_info = ColdStartInfo::WarmCached; self.caches.node_info.insert(key.clone(), Ok(stored_node)); Ok(Cached { token: Some((&self.caches.node_info, key)), value: node, }) } Err(err) => match err { WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => { let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info); // If we can retry this error, do not cache it, // unless we were given a retry delay. if msg.could_retry() && retry_info.is_none() { return Err(err); } debug!( key = &*key, "created a cache entry for the wake compute error" ); self.caches.node_info.insert(key, Err(msg.clone())); Err(err) } err => Err(err), }, } } } /// Parse http response body, taking status code into account. fn parse_body serde::Deserialize<'a>>( status: StatusCode, body: Bytes, ) -> Result { if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { let (host, port) = input.rsplit_once(':')?; let ipv6_brackets: &[_] = &['[', ']']; Some((host.trim_matches(ipv6_brackets), port.parse().ok()?)) } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_host_port_v4() { let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse"); assert_eq!(host, "127.0.0.1"); assert_eq!(port, 5432); } #[test] fn test_parse_host_port_v6() { let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse"); assert_eq!(host, "2001:db8::1"); assert_eq!(port, 5432); } #[test] fn test_parse_host_port_url() { let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432") .expect("failed to parse"); assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local"); assert_eq!(port, 5432); } }