diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 48de4e2353..c8028d1bf0 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -36,9 +36,6 @@ pub enum AuthErrorImpl { #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), - #[error(transparent)] - WakeCompute(#[from] console::errors::WakeComputeError), - /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -119,7 +116,6 @@ impl UserFacingError for AuthError { match self.0.as_ref() { Link(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), - WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), @@ -139,7 +135,6 @@ impl ReportableError for AuthError { match self.0.as_ref() { Link(e) => e.get_error_kind(), GetAuthInfo(e) => e.get_error_kind(), - WakeCompute(e) => e.get_error_kind(), Sasl(e) => e.get_error_kind(), AuthFailed(_) => crate::error::ErrorKind::User, BadAuthMethod(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index c9f21f1cf5..47c1dc4e92 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange; use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; -use crate::console::AuthSecret; +use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -26,7 +26,6 @@ use crate::{ stream, url, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; -use futures::TryFutureExt; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -56,11 +55,11 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T> { +pub enum BackendType<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(MaybeOwned<'a, url::ApiUrl>), + Link(MaybeOwned<'a, url::ApiUrl>, D), } pub trait TestBackend: Send + Sync + 'static { @@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, ()> { +impl std::fmt::Display for BackendType<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { @@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), } } } -impl BackendType<'_, T> { +impl BackendType<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T> { + pub fn as_ref(&self) -> BackendType<'_, &T, &D> { use BackendType::*; match self { Console(c, x) => Console(MaybeOwned::Borrowed(c), x), - Link(c) => Link(MaybeOwned::Borrowed(c)), + Link(c, x) => Link(MaybeOwned::Borrowed(c), x), } } } -impl<'a, T> BackendType<'a, T> { +impl<'a, T, D> BackendType<'a, T, D> { /// Very similar to [`std::option::Option::map`]. /// Maps [`BackendType`] to [`BackendType`] by applying /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> { + pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), - Link(c) => Link(c), + Link(c, x) => Link(c, x), } } } - -impl<'a, T, E> BackendType<'a, Result> { +impl<'a, T, D, E> BackendType<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. - pub fn transpose(self) -> Result, E> { + pub fn transpose(self) -> Result, E> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), - Link(c) => Ok(Link(c)), + Link(c, x) => Ok(Link(c, x)), } } } -pub struct ComputeCredentials { +pub struct ComputeCredentials { pub info: ComputeUserInfo, - pub keys: T, + pub keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] @@ -153,7 +151,6 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -188,7 +185,7 @@ async fn auth_quirks( client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -198,8 +195,11 @@ async fn auth_quirks( ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); - - (res.info, Some(res.keys)) + let password = match res.keys { + ComputeCredentialKeys::Password(p) => p, + _ => unreachable!("password hack should return a password"), + }; + (res.info, Some(password)) } Ok(info) => (info, None), }; @@ -253,7 +253,7 @@ async fn authenticate_with_secret( unauthenticated_password: Option>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { if let Some(password) = unauthenticated_password { let auth_outcome = validate_password_and_exchange(&password, secret)?; let keys = match auth_outcome { @@ -283,14 +283,14 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { +impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get compute endpoint name from the credentials. pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { Console(_, user_info) => user_info.endpoint_id.clone(), - Link(_) => Some("link".into()), + Link(_, _) => Some("link".into()), } } @@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, - Link(_) => "link", + Link(_, _) => "link", } } @@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, - ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { + ) -> auth::Result> { use BackendType::*; let res = match self { @@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let compute_credentials = + let credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - - let mut num_retries = 0; - let mut node = - wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - (node, BackendType::Console(api, compute_credentials.info)) + BackendType::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Link(url) => { + Link(url, _) => { info!("performing link authentication"); - let node_info = link::authenticate(ctx, &url, client).await?; + let info = link::authenticate(ctx, &url, client).await?; - ( - CachedNodeInfo::new_uncached(node_info), - BackendType::Link(url), - ) + BackendType::Link(url, info) } }; @@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } } -impl BackendType<'_, ComputeUserInfo> { +impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_role_secret( &self, ctx: &mut RequestMonitoring, @@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Link(_) => Ok(Cached::new_uncached(None)), + Link(_, _) => Ok(Cached::new_uncached(None)), } } @@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - } - } - - /// When applicable, wake the compute node, gaining its connection info in the process. - /// The link auth flow doesn't support this, so we return [`None`] in that case. - pub async fn wake_compute( - &self, - ctx: &mut RequestMonitoring, - ) -> Result, console::errors::WakeComputeError> { - use BackendType::*; - - match self { - Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, - Link(_) => Ok(None), + Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, info) => Ok(Cached::new_uncached(info.clone())), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index e855843bc3..d075331846 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,7 +17,7 @@ pub(super) async fn authenticate( client: &mut PqStream>, config: &'static AuthenticationConfig, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 9f60b709d4..26cf7a01f2 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -20,7 +20,7 @@ pub async fn authenticate_cleartext( info: ComputeUserInfo, client: &mut stream::PqStream>, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { warn!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -51,7 +51,7 @@ pub async fn password_hack_no_authentication( ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, -) -> auth::Result>> { +) -> auth::Result { warn!("project not specified, resorting to the password hack auth flow"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -73,6 +73,6 @@ pub async fn password_hack_no_authentication( options: info.options, endpoint: payload.endpoint, }, - keys: payload.password, + keys: ComputeCredentialKeys::Password(payload.password), }) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 8fbcb56758..00a229c135 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -383,7 +383,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url), ()) } }; let http_config = HttpConfig { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 83940d80ec..b61c1fb9ef 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,7 +1,7 @@ use crate::{ auth::parse_endpoint_param, cancellation::CancelClosure, - console::errors::WakeComputeError, + console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, error::{ReportableError, UserFacingError}, metrics::NUM_DB_CONNECTIONS_GAUGE, @@ -93,7 +93,7 @@ impl ConnCfg { } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: &Self) { + pub fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -253,6 +253,8 @@ pub struct PostgresConnection { pub params: std::collections::HashMap, /// Query cancellation token. pub cancel_closure: CancelClosure, + /// Labels for proxy's metrics. + pub aux: MetricsAuxInfo, _guage: IntCounterPairGuard, } @@ -263,6 +265,7 @@ impl ConnCfg { &self, ctx: &mut RequestMonitoring, allow_self_signed_compute: bool, + aux: MetricsAuxInfo, timeout: Duration, ) -> Result { let (socket_addr, stream, host) = self.connect_raw(timeout).await?; @@ -297,6 +300,7 @@ impl ConnCfg { stream, params, cancel_closure, + aux, _guage: NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 31c9228b35..5fcb537834 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -13,7 +13,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, ()>, + pub auth_backend: auth::BackendType<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index e5cad42753..640444d14e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,10 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::{backend::ComputeUserInfo, IpPattern}, + auth::{ + backend::{ComputeCredentialKeys, ComputeUserInfo}, + IpPattern, + }, cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, config::{CacheOptions, ProjectInfoCacheOptions}, @@ -261,6 +264,34 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } +impl NodeInfo { + pub async fn connect( + &self, + ctx: &mut RequestMonitoring, + timeout: Duration, + ) -> Result { + self.config + .connect( + ctx, + self.allow_self_signed_compute, + self.aux.clone(), + timeout, + ) + .await + } + pub fn reuse_settings(&mut self, other: Self) { + self.allow_self_signed_compute = other.allow_self_signed_compute; + self.config.reuse_password(other.config); + } + + pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + match keys { + ComputeCredentialKeys::Password(password) => self.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + }; + } +} + pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 79a04f255d..0579ef6fc4 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -176,9 +176,7 @@ impl super::Api for Api { _ctx: &mut RequestMonitoring, _user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute() - .map_ok(CachedNodeInfo::new_uncached) - .await + self.do_wake_compute().map_ok(Cached::new_uncached).await } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index eafe92bf48..69fe1ebc12 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, @@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed { ErrorKind::RateLimit } } + +impl ReportableError for tokio_postgres::error::Error { + fn get_error_kind(&self) -> ErrorKind { + if self.as_db_error().is_some() { + ErrorKind::Postgres + } else { + ErrorKind::Compute + } + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 77aadb6f28..5f65de4c98 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -163,14 +163,14 @@ pub enum ClientMode { /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - fn allow_cleartext(&self) -> bool { + pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -287,7 +287,7 @@ pub async fn handle_client( } let user = user_info.get_user().to_owned(); - let (mut node_info, user_info) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -306,14 +306,11 @@ pub async fn handle_client( } }; - node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config); - - let aux = node_info.aux.clone(); let mut node = connect_to_compute( ctx, &TcpMechanism { params: ¶ms }, - node_info, &user_info, + mode.allow_self_signed_compute(config), ) .or_else(|e| stream.throw_error(e)) .await?; @@ -330,8 +327,8 @@ pub async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, + aux: node.aux.clone(), compute: node, - aux, req: _request_gauge, conn: _client_gauge, })) diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index b9346aa743..6e57caf998 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,8 +1,9 @@ use crate::{ - auth, + auth::backend::ComputeCredentialKeys, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError}, + console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + error::ReportableError, metrics::NUM_CONNECTION_FAILURES, proxy::{ retry::{retry_after, ShouldRetry}, @@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg { +pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg }; NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); - node_info.invalidate().config + node_info.invalidate() } #[async_trait] pub trait ConnectMechanism { type Connection; - type ConnectError; + type ConnectError: ReportableError; type Error: From; async fn connect_once( &self, @@ -49,6 +50,16 @@ pub trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } +#[async_trait] +pub trait ComputeConnectBackend { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result; + + fn get_keys(&self) -> Option<&ComputeCredentialKeys>; +} + pub struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. pub params: &'a StartupMessageParams, @@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await + node_info.connect(ctx, timeout).await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, - mut node_info: console::CachedNodeInfo, - user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, + user_info: &B, + allow_self_signed_compute: bool, ) -> Result where M::ConnectError: ShouldRetry + std::fmt::Debug, M::Error: From, { + let mut num_retries = 0; + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + if let Some(keys) = user_info.get_keys() { + node_info.set_keys(keys); + } + node_info.allow_self_signed_compute = allow_self_signed_compute; + // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); // try once @@ -108,28 +122,31 @@ where error!(error = ?err, "could not connect to compute node"); - let mut num_retries = 1; - - match user_info { - auth::BackendType::Console(api, info) => { + let node_info = + if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() { + // If the error is Postgres, that means that we managed to connect to the compute node, but there was an error. + // Do not need to retrieve a new node_info, just return the old one. + if !err.should_retry(num_retries) { + return Err(err.into()); + } + node_info + } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node info!("compute node's state has likely changed; requesting a wake-up"); - ctx.latency_timer.cache_miss(); - let config = invalidate_cache(node_info); - node_info = wake_compute(&mut num_retries, ctx, api, info).await?; + let old_node_info = invalidate_cache(node_info); + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + node_info.reuse_settings(old_node_info); - node_info.config.reuse_password(&config); mechanism.update_connect_config(&mut node_info.config); - } - // nothing to do? - auth::BackendType::Link(_) => {} - }; + node_info + }; // now that we have a new node, try connect to it repeatedly. // this can error for a few reasons, for instance: // * DNS connection settings haven't quite propagated yet info!("wake_compute success. attempting to connect"); + num_retries = 1; loop { match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 5bb43c0375..efbd661bbf 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -2,13 +2,19 @@ mod mitm; +use std::time::Duration; + use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; +use crate::auth::backend::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, +}; use crate::config::CertResolver; +use crate::console::caches::NodeInfoCache; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::error::ErrorKind; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; @@ -369,12 +375,15 @@ enum ConnectAction { Connect, Retry, Fail, + RetryPg, + FailPg, } #[derive(Clone)] struct TestConnectMechanism { counter: Arc>, sequence: Vec, + cache: &'static NodeInfoCache, } impl TestConnectMechanism { @@ -393,6 +402,12 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, + cache: Box::leak(Box::new(NodeInfoCache::new( + "test", + 1, + Duration::from_secs(100), + false, + ))), } } } @@ -403,6 +418,13 @@ struct TestConnection; #[derive(Debug)] struct TestConnectError { retryable: bool, + kind: crate::error::ErrorKind, +} + +impl ReportableError for TestConnectError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + self.kind + } } impl std::fmt::Display for TestConnectError { @@ -436,8 +458,22 @@ impl ConnectMechanism for TestConnectMechanism { *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { retryable: true }), - ConnectAction::Fail => Err(TestConnectError { retryable: false }), + ConnectAction::Retry => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Compute, + }), + ConnectAction::Fail => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Compute, + }), + ConnectAction::FailPg => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Postgres, + }), + ConnectAction::RetryPg => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Postgres, + }), x => panic!("expecting action {:?}, connect is called instead", x), } } @@ -451,7 +487,7 @@ impl TestBackend for TestConnectMechanism { let action = self.sequence[*counter]; *counter += 1; match action { - ConnectAction::Wake => Ok(helper_create_cached_node_info()), + ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console { status: http::StatusCode::FORBIDDEN, @@ -483,37 +519,41 @@ impl TestBackend for TestConnectMechanism { } } -fn helper_create_cached_node_info() -> CachedNodeInfo { +fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { config: compute::ConnCfg::new(), aux: Default::default(), allow_self_signed_compute: false, }; - CachedNodeInfo::new_uncached(node) + let (_, node) = cache.insert("key".into(), node); + node } fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { - let cache = helper_create_cached_node_info(); +) -> auth::BackendType<'static, ComputeCredentials, &()> { let user_info = auth::BackendType::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), - ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), + ComputeCredentials { + info: ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), + }, + keys: ComputeCredentialKeys::Password("password".into()), }, ); - (cache, user_info) + user_info } #[tokio::test] async fn connect_to_compute_success() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -521,24 +561,52 @@ async fn connect_to_compute_success() { #[tokio::test] async fn connect_to_compute_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); } +#[tokio::test] +async fn connect_to_compute_retry_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap(); + mechanism.verify(); +} + +#[tokio::test] +async fn connect_to_compute_fail_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap_err(); + mechanism.verify(); +} + /// Test that we don't retry if the error is not retryable. #[tokio::test] async fn connect_to_compute_non_retry_1() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -547,11 +615,12 @@ async fn connect_to_compute_non_retry_1() { /// Even for non-retryable errors, we should retry at least once. #[tokio::test] async fn connect_to_compute_non_retry_2() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -560,15 +629,16 @@ async fn connect_to_compute_non_retry_2() { /// Retry for at most `NUM_RETRIES_CONNECT` times. #[tokio::test] async fn connect_to_compute_non_retry_3() { + let _ = env_logger::try_init(); assert_eq!(NUM_RETRIES_CONNECT, 16); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![ - Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, - Retry, Retry, Retry, Retry, /* the 17th time */ Retry, + Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, + Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry, ]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -577,11 +647,12 @@ async fn connect_to_compute_non_retry_3() { /// Should retry wake compute. #[tokio::test] async fn wake_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -590,11 +661,12 @@ async fn wake_retry() { /// Wake failed with a non-retryable error. #[tokio::test] async fn wake_non_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 925727bdab..2c593451b4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,4 @@ -use crate::auth::backend::ComputeUserInfo; -use crate::console::{ - errors::WakeComputeError, - provider::{CachedNodeInfo, ConsoleBackend}, - Api, -}; +use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; use crate::proxy::retry::retry_after; @@ -11,17 +6,16 @@ use hyper::StatusCode; use std::ops::ControlFlow; use tracing::{error, warn}; +use super::connect_compute::ComputeConnectBackend; use super::retry::ShouldRetry; -/// wake a compute (or retrieve an existing compute session from cache) -pub async fn wake_compute( +pub async fn wake_compute( num_retries: &mut u32, ctx: &mut RequestMonitoring, - api: &ConsoleBackend, - info: &ComputeUserInfo, + api: &B, ) -> Result { loop { - let wake_res = api.wake_compute(ctx, info).await; + let wake_res = api.wake_compute(ctx).await; match handle_try_wake(wake_res, *num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 156002006d..6f93f86d5f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, compute, config::ProxyConfig, console::{ @@ -27,7 +27,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: &ConnInfo, - ) -> Result { + ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; @@ -49,13 +49,17 @@ impl PoolingBackend { }; let auth_outcome = crate::auth::validate_password_and_exchange(&conn_info.password, secret)?; - match auth_outcome { + let res = match auth_outcome { crate::sasl::Outcome::Success(key) => Ok(key), crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); Err(AuthError::auth_failed(&*conn_info.user_info.user)) } - } + }; + res.map(|key| ComputeCredentials { + info: user_info, + keys: key, + }) } // Wake up the destination if needed. Code here is a bit involved because @@ -66,7 +70,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: ConnInfo, - keys: ComputeCredentialKeys, + keys: ComputeCredentials, force_new: bool, ) -> Result, HttpConnError> { let maybe_client = if !force_new { @@ -82,26 +86,8 @@ impl PoolingBackend { } let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); - info!("pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - let mut node_info = backend - .wake_compute(ctx) - .await? - .ok_or(HttpConnError::NoComputeInfo)?; - - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node_info.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), - }; - - ctx.set_project(node_info.aux.clone()); - + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + let backend = self.config.auth_backend.as_ref().map(|_| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -109,8 +95,8 @@ impl PoolingBackend { conn_info, pool: self.pool.clone(), }, - node_info, &backend, + false, // do not allow self signed compute for http flow ) .await } @@ -129,8 +115,6 @@ pub enum HttpConnError { AuthError(#[from] AuthError), #[error("wake_compute returned error")] WakeCompute(#[from] WakeComputeError), - #[error("wake_compute returned nothing")] - NoComputeInfo, } struct TokioMechanism {