mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
Proxy refactor auth+connect (#6708)
## Problem Not really a problem, just refactoring. ## Summary of changes Separate authenticate from wake compute. Do not call wake compute second time if we managed to connect to postgres or if we got it not from cache.
This commit is contained in:
@@ -36,9 +36,6 @@ pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
||||
|
||||
#[error(transparent)]
|
||||
WakeCompute(#[from] console::errors::WakeComputeError),
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
@@ -119,7 +116,6 @@ impl UserFacingError for AuthError {
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
WakeCompute(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
AuthFailed(_) => self.to_string(),
|
||||
BadAuthMethod(_) => self.to_string(),
|
||||
@@ -139,7 +135,6 @@ impl ReportableError for AuthError {
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.get_error_kind(),
|
||||
GetAuthInfo(e) => e.get_error_kind(),
|
||||
WakeCompute(e) => e.get_error_kind(),
|
||||
Sasl(e) => e.get_error_kind(),
|
||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::AuthSecret;
|
||||
use crate::console::{AuthSecret, NodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
@@ -26,7 +26,6 @@ use crate::{
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use futures::TryFutureExt;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
@@ -56,11 +55,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum BackendType<'a, T> {
|
||||
pub enum BackendType<'a, T, D> {
|
||||
/// Cloud API (V2).
|
||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(MaybeOwned<'a, url::ApiUrl>),
|
||||
Link(MaybeOwned<'a, url::ApiUrl>, D),
|
||||
}
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
@@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static {
|
||||
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
@@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
||||
#[cfg(test)]
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> BackendType<'_, T> {
|
||||
impl<T, D> BackendType<'_, T, D> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T> {
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
||||
Link(c) => Link(MaybeOwned::Borrowed(c)),
|
||||
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> BackendType<'a, T> {
|
||||
impl<'a, T, D> BackendType<'a, T, D> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
Link(c) => Link(c),
|
||||
Link(c, x) => Link(c, x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
Link(c) => Ok(Link(c)),
|
||||
Link(c, x) => Ok(Link(c, x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComputeCredentials<T> {
|
||||
pub struct ComputeCredentials {
|
||||
pub info: ComputeUserInfo,
|
||||
pub keys: T,
|
||||
pub keys: ComputeCredentialKeys,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -153,7 +151,6 @@ impl ComputeUserInfo {
|
||||
}
|
||||
|
||||
pub enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
}
|
||||
@@ -188,7 +185,7 @@ async fn auth_quirks(
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
@@ -198,8 +195,11 @@ async fn auth_quirks(
|
||||
|
||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
|
||||
|
||||
(res.info, Some(res.keys))
|
||||
let password = match res.keys {
|
||||
ComputeCredentialKeys::Password(p) => p,
|
||||
_ => unreachable!("password hack should return a password"),
|
||||
};
|
||||
(res.info, Some(password))
|
||||
}
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
@@ -253,7 +253,7 @@ async fn authenticate_with_secret(
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
||||
let keys = match auth_outcome {
|
||||
@@ -283,14 +283,14 @@ async fn authenticate_with_secret(
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Link(_) => Some("link".into()),
|
||||
Link(_, _) => Some("link".into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => &user_info.user,
|
||||
Link(_) => "link",
|
||||
Link(_, _) => "link",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
||||
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
@@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let compute_credentials =
|
||||
let credentials =
|
||||
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut node =
|
||||
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
|
||||
|
||||
ctx.set_project(node.aux.clone());
|
||||
|
||||
match compute_credentials.keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
||||
};
|
||||
|
||||
(node, BackendType::Console(api, compute_credentials.info))
|
||||
BackendType::Console(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
Link(url, _) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
let node_info = link::authenticate(ctx, &url, client).await?;
|
||||
let info = link::authenticate(ctx, &url, client).await?;
|
||||
|
||||
(
|
||||
CachedNodeInfo::new_uncached(node_info),
|
||||
BackendType::Link(url),
|
||||
)
|
||||
BackendType::Link(url, info)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<'_, ComputeUserInfo> {
|
||||
impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
pub async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
@@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Link(_) => Ok(Cached::new_uncached(None)),
|
||||
Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// When applicable, wake the compute node, gaining its connection info in the process.
|
||||
/// The link auth flow doesn't support this, so we return [`None`] in that case.
|
||||
pub async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ pub(super) async fn authenticate(
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
|
||||
@@ -20,7 +20,7 @@ pub async fn authenticate_cleartext(
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
@@ -51,7 +51,7 @@ pub async fn password_hack_no_authentication(
|
||||
ctx: &mut RequestMonitoring,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
@@ -73,6 +73,6 @@ pub async fn password_hack_no_authentication(
|
||||
options: info.options,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: payload.password,
|
||||
keys: ComputeCredentialKeys::Password(payload.password),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
}
|
||||
AuthBackend::Link => {
|
||||
let url = args.uri.parse()?;
|
||||
auth::BackendType::Link(MaybeOwned::Owned(url))
|
||||
auth::BackendType::Link(MaybeOwned::Owned(url), ())
|
||||
}
|
||||
};
|
||||
let http_config = HttpConfig {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
auth::parse_endpoint_param,
|
||||
cancellation::CancelClosure,
|
||||
console::errors::WakeComputeError,
|
||||
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
@@ -93,7 +93,7 @@ impl ConnCfg {
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: &Self) {
|
||||
pub fn reuse_password(&mut self, other: Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
@@ -253,6 +253,8 @@ pub struct PostgresConnection {
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub cancel_closure: CancelClosure,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
_guage: IntCounterPairGuard,
|
||||
}
|
||||
@@ -263,6 +265,7 @@ impl ConnCfg {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
allow_self_signed_compute: bool,
|
||||
aux: MetricsAuxInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
@@ -297,6 +300,7 @@ impl ConnCfg {
|
||||
stream,
|
||||
params,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: NUM_DB_CONNECTIONS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard(),
|
||||
|
||||
@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<'static, ()>,
|
||||
pub auth_backend: auth::BackendType<'static, (), ()>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
|
||||
@@ -4,7 +4,10 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
@@ -261,6 +264,34 @@ pub struct NodeInfo {
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
|
||||
@@ -176,9 +176,7 @@ impl super::Api for Api {
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute()
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
.await
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum ErrorKind {
|
||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||
User,
|
||||
@@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed {
|
||||
ErrorKind::RateLimit
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for tokio_postgres::error::Error {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
if self.as_db_error().is_some() {
|
||||
ErrorKind::Postgres
|
||||
} else {
|
||||
ErrorKind::Compute
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,14 +163,14 @@ pub enum ClientMode {
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
fn allow_cleartext(&self) -> bool {
|
||||
pub fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||
ClientMode::Websockets { .. } => false,
|
||||
@@ -287,7 +287,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let (mut node_info, user_info) = match user_info
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
@@ -306,14 +306,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
};
|
||||
|
||||
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
|
||||
|
||||
let aux = node_info.aux.clone();
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism { params: ¶ms },
|
||||
node_info,
|
||||
&user_info,
|
||||
mode.allow_self_signed_compute(config),
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
@@ -330,8 +327,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node,
|
||||
aux,
|
||||
req: _request_gauge,
|
||||
conn: _client_gauge,
|
||||
}))
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{
|
||||
auth,
|
||||
auth::backend::ComputeCredentialKeys,
|
||||
compute::{self, PostgresConnection},
|
||||
console::{self, errors::WakeComputeError},
|
||||
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::NUM_CONNECTION_FAILURES,
|
||||
proxy::{
|
||||
retry::{retry_after, ShouldRetry},
|
||||
@@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
|
||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
warn!("invalidating stalled compute node info cache entry");
|
||||
@@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
|
||||
};
|
||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||
|
||||
node_info.invalidate().config
|
||||
node_info.invalidate()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConnectMechanism {
|
||||
type Connection;
|
||||
type ConnectError;
|
||||
type ConnectError: ReportableError;
|
||||
type Error: From<Self::ConnectError>;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
@@ -49,6 +50,16 @@ pub trait ConnectMechanism {
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||
}
|
||||
|
||||
pub struct TcpMechanism<'a> {
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub params: &'a StartupMessageParams,
|
||||
@@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let allow_self_signed_compute = node_info.allow_self_signed_compute;
|
||||
node_info
|
||||
.config
|
||||
.connect(ctx, allow_self_signed_compute, timeout)
|
||||
.await
|
||||
node_info.connect(ctx, timeout).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
@@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
/// This function might update `node_info`, so we take it by `&mut`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn connect_to_compute<M: ConnectMechanism>(
|
||||
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
mechanism: &M,
|
||||
mut node_info: console::CachedNodeInfo,
|
||||
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
user_info: &B,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
||||
M::Error: From<WakeComputeError>,
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||
if let Some(keys) = user_info.get_keys() {
|
||||
node_info.set_keys(keys);
|
||||
}
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
@@ -108,28 +122,31 @@ where
|
||||
|
||||
error!(error = ?err, "could not connect to compute node");
|
||||
|
||||
let mut num_retries = 1;
|
||||
|
||||
match user_info {
|
||||
auth::BackendType::Console(api, info) => {
|
||||
let node_info =
|
||||
if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() {
|
||||
// If the error is Postgres, that means that we managed to connect to the compute node, but there was an error.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if !err.should_retry(num_retries) {
|
||||
return Err(err.into());
|
||||
}
|
||||
node_info
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
|
||||
ctx.latency_timer.cache_miss();
|
||||
let config = invalidate_cache(node_info);
|
||||
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
node_info.config.reuse_password(&config);
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
}
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => {}
|
||||
};
|
||||
node_info
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
info!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
|
||||
@@ -2,13 +2,19 @@
|
||||
|
||||
mod mitm;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::ShouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||
};
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::caches::NodeInfoCache;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
@@ -369,12 +375,15 @@ enum ConnectAction {
|
||||
Connect,
|
||||
Retry,
|
||||
Fail,
|
||||
RetryPg,
|
||||
FailPg,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestConnectMechanism {
|
||||
counter: Arc<std::sync::Mutex<usize>>,
|
||||
sequence: Vec<ConnectAction>,
|
||||
cache: &'static NodeInfoCache,
|
||||
}
|
||||
|
||||
impl TestConnectMechanism {
|
||||
@@ -393,6 +402,12 @@ impl TestConnectMechanism {
|
||||
Self {
|
||||
counter: Arc::new(std::sync::Mutex::new(0)),
|
||||
sequence,
|
||||
cache: Box::leak(Box::new(NodeInfoCache::new(
|
||||
"test",
|
||||
1,
|
||||
Duration::from_secs(100),
|
||||
false,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -403,6 +418,13 @@ struct TestConnection;
|
||||
#[derive(Debug)]
|
||||
struct TestConnectError {
|
||||
retryable: bool,
|
||||
kind: crate::error::ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportableError for TestConnectError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
self.kind
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestConnectError {
|
||||
@@ -436,8 +458,22 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Connect => Ok(TestConnection),
|
||||
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
|
||||
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
|
||||
ConnectAction::Retry => Err(TestConnectError {
|
||||
retryable: true,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
ConnectAction::Fail => Err(TestConnectError {
|
||||
retryable: false,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
ConnectAction::FailPg => Err(TestConnectError {
|
||||
retryable: false,
|
||||
kind: ErrorKind::Postgres,
|
||||
}),
|
||||
ConnectAction::RetryPg => Err(TestConnectError {
|
||||
retryable: true,
|
||||
kind: ErrorKind::Postgres,
|
||||
}),
|
||||
x => panic!("expecting action {:?}, connect is called instead", x),
|
||||
}
|
||||
}
|
||||
@@ -451,7 +487,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
let action = self.sequence[*counter];
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||
ConnectAction::WakeFail => {
|
||||
let err = console::errors::ApiError::Console {
|
||||
status: http::StatusCode::FORBIDDEN,
|
||||
@@ -483,37 +519,41 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info() -> CachedNodeInfo {
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new(),
|
||||
aux: Default::default(),
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
CachedNodeInfo::new_uncached(node)
|
||||
let (_, node) = cache.insert("key".into(), node);
|
||||
node
|
||||
}
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
|
||||
let cache = helper_create_cached_node_info();
|
||||
) -> auth::BackendType<'static, ComputeCredentials, &()> {
|
||||
let user_info = auth::BackendType::Console(
|
||||
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
|
||||
ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
},
|
||||
);
|
||||
(cache, user_info)
|
||||
user_info
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_success() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -521,24 +561,52 @@ async fn connect_to_compute_success() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry_pg() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_fail_pg() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Test that we don't retry if the error is not retryable.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_1() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -547,11 +615,12 @@ async fn connect_to_compute_non_retry_1() {
|
||||
/// Even for non-retryable errors, we should retry at least once.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_2() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -560,15 +629,16 @@ async fn connect_to_compute_non_retry_2() {
|
||||
/// Retry for at most `NUM_RETRIES_CONNECT` times.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_3() {
|
||||
let _ = env_logger::try_init();
|
||||
assert_eq!(NUM_RETRIES_CONNECT, 16);
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![
|
||||
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -577,11 +647,12 @@ async fn connect_to_compute_non_retry_3() {
|
||||
/// Should retry wake compute.
|
||||
#[tokio::test]
|
||||
async fn wake_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -590,11 +661,12 @@ async fn wake_retry() {
|
||||
/// Wake failed with a non-retryable error.
|
||||
#[tokio::test]
|
||||
async fn wake_non_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::console::{
|
||||
errors::WakeComputeError,
|
||||
provider::{CachedNodeInfo, ConsoleBackend},
|
||||
Api,
|
||||
};
|
||||
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
|
||||
use crate::proxy::retry::retry_after;
|
||||
@@ -11,17 +6,16 @@ use hyper::StatusCode;
|
||||
use std::ops::ControlFlow;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use super::connect_compute::ComputeConnectBackend;
|
||||
use super::retry::ShouldRetry;
|
||||
|
||||
/// wake a compute (or retrieve an existing compute session from cache)
|
||||
pub async fn wake_compute(
|
||||
pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||
num_retries: &mut u32,
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &ConsoleBackend,
|
||||
info: &ComputeUserInfo,
|
||||
api: &B,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
loop {
|
||||
let wake_res = api.wake_compute(ctx, info).await;
|
||||
let wake_res = api.wake_compute(ctx).await;
|
||||
match handle_try_wake(wake_res, *num_retries) {
|
||||
Err(e) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
|
||||
@@ -4,7 +4,7 @@ use async_trait::async_trait;
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
|
||||
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||
compute,
|
||||
config::ProxyConfig,
|
||||
console::{
|
||||
@@ -27,7 +27,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<ComputeCredentialKeys, AuthError> {
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
@@ -49,13 +49,17 @@ impl PoolingBackend {
|
||||
};
|
||||
let auth_outcome =
|
||||
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
|
||||
match auth_outcome {
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => Ok(key),
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||
}
|
||||
}
|
||||
};
|
||||
res.map(|key| ComputeCredentials {
|
||||
info: user_info,
|
||||
keys: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Wake up the destination if needed. Code here is a bit involved because
|
||||
@@ -66,7 +70,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentialKeys,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
let maybe_client = if !force_new {
|
||||
@@ -82,26 +86,8 @@ impl PoolingBackend {
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!("pool: opening a new connection '{conn_info}'");
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| conn_info.user_info.clone());
|
||||
|
||||
let mut node_info = backend
|
||||
.wake_compute(ctx)
|
||||
.await?
|
||||
.ok_or(HttpConnError::NoComputeInfo)?;
|
||||
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
|
||||
};
|
||||
|
||||
ctx.set_project(node_info.aux.clone());
|
||||
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
@@ -109,8 +95,8 @@ impl PoolingBackend {
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
},
|
||||
node_info,
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -129,8 +115,6 @@ pub enum HttpConnError {
|
||||
AuthError(#[from] AuthError),
|
||||
#[error("wake_compute returned error")]
|
||||
WakeCompute(#[from] WakeComputeError),
|
||||
#[error("wake_compute returned nothing")]
|
||||
NoComputeInfo,
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
|
||||
Reference in New Issue
Block a user