Proxy refactor auth+connect (#6708)

## Problem

Not really a problem, just refactoring.

## Summary of changes

Separate authenticate from wake compute.

Do not call wake compute second time if we managed to connect to
postgres or if we got it not from cache.
This commit is contained in:
Anna Khanova
2024-02-12 19:41:02 +01:00
committed by GitHub
parent a1f37cba1c
commit fac50a6264
15 changed files with 307 additions and 191 deletions

View File

@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::AuthSecret;
use crate::console::{AuthSecret, NodeInfo};
use crate::context::RequestMonitoring;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -26,7 +26,6 @@ use crate::{
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -56,11 +55,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
pub enum BackendType<'a, T, D> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>),
Link(MaybeOwned<'a, url::ApiUrl>, D),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
impl std::fmt::Display for BackendType<'_, (), ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
@@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
impl<T> BackendType<'_, T> {
impl<T, D> BackendType<'_, T, D> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T> {
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
}
}
}
impl<'a, T> BackendType<'a, T> {
impl<'a, T, D> BackendType<'a, T, D> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c) => Link(c),
Link(c, x) => Link(c, x),
}
}
}
impl<'a, T, E> BackendType<'a, Result<T, E>> {
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
Link(c, x) => Ok(Link(c, x)),
}
}
}
pub struct ComputeCredentials<T> {
pub struct ComputeCredentials {
pub info: ComputeUserInfo,
pub keys: T,
pub keys: ComputeCredentialKeys,
}
#[derive(Debug, Clone)]
@@ -153,7 +151,6 @@ impl ComputeUserInfo {
}
pub enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -188,7 +185,7 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -198,8 +195,11 @@ async fn auth_quirks(
ctx.set_endpoint_id(res.info.endpoint.clone());
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
(res.info, Some(res.keys))
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
};
(res.info, Some(password))
}
Ok(info) => (info, None),
};
@@ -253,7 +253,7 @@ async fn authenticate_with_secret(
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
let keys = match auth_outcome {
@@ -283,14 +283,14 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_) => Some("link".into()),
Link(_, _) => Some("link".into()),
}
}
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => &user_info.user,
Link(_) => "link",
Link(_, _) => "link",
}
}
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
use BackendType::*;
let res = match self {
@@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let compute_credentials =
let credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
(node, BackendType::Console(api, compute_credentials.info))
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
Link(url, _) => {
info!("performing link authentication");
let node_info = link::authenticate(ctx, &url, client).await?;
let info = link::authenticate(ctx, &url, client).await?;
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
BackendType::Link(url, info)
}
};
@@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
}
}
impl BackendType<'_, ComputeUserInfo> {
impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
@@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(None)),
Link(_, _) => Ok(Cached::new_uncached(None)),
}
}
@@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
}
}
}

View File

@@ -17,7 +17,7 @@ pub(super) async fn authenticate(
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]

View File

@@ -20,7 +20,7 @@ pub async fn authenticate_cleartext(
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -51,7 +51,7 @@ pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -73,6 +73,6 @@ pub async fn password_hack_no_authentication(
options: info.options,
endpoint: payload.endpoint,
},
keys: payload.password,
keys: ComputeCredentialKeys::Password(payload.password),
})
}