From 34ddec67d9570bd275aedadadff89d5afef762cd Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 21 Jan 2024 08:58:42 +0000 Subject: [PATCH] proxy small tweaks (#6398) ## Problem In https://github.com/neondatabase/neon/pull/6283 I did a couple changes that weren't directly related to the goal of extracting the state machine, so I'm putting them here ## Summary of changes - move postgres vs console provider into another enum - reduce error cases for link auth - slightly refactor link flow --- proxy/Cargo.toml | 2 +- proxy/src/auth/backend.rs | 47 +++++++----------------- proxy/src/auth/backend/link.rs | 35 ++++++++++-------- proxy/src/bin/proxy.rs | 21 +++++++---- proxy/src/console/mgmt.rs | 16 +++------ proxy/src/console/provider.rs | 58 ++++++++++++++++++++++++++++-- proxy/src/proxy/connect_compute.rs | 2 -- 7 files changed, 109 insertions(+), 72 deletions(-) diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 9610071aa6..f075c718a7 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true license.workspace = true [features] -default = ["testing"] +default = [] testing = [] [dependencies] diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 120ed46992..34171d4d3f 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -10,6 +10,7 @@ use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::validate_password_and_exchange; use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; +use crate::console::provider::ConsoleBackend; use crate::console::AuthSecret; use crate::context::RequestMonitoring; use crate::proxy::connect_compute::handle_try_wake; @@ -43,11 +44,8 @@ use tracing::{error, info, warn}; /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. pub enum BackendType<'a, T> { - /// Current Cloud API (V2). - Console(Cow<'a, console::provider::neon::Api>, T), - /// Local mock of Cloud API (V2). - #[cfg(feature = "testing")] - Postgres(Cow<'a, console::provider::mock::Api>, T), + /// Cloud API (V2). + Console(Cow<'a, ConsoleBackend>, T), /// Authentication via a web browser. Link(Cow<'a, url::ApiUrl>), #[cfg(test)] @@ -64,9 +62,15 @@ impl std::fmt::Display for BackendType<'_, ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { - Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(), - #[cfg(feature = "testing")] - Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(), + Console(api, _) => match &**api { + ConsoleBackend::Console(endpoint) => { + fmt.debug_tuple("Console").field(&endpoint.url()).finish() + } + #[cfg(feature = "testing")] + ConsoleBackend::Postgres(endpoint) => { + fmt.debug_tuple("Postgres").field(&endpoint.url()).finish() + } + }, Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), #[cfg(test)] Test(_) => fmt.debug_tuple("Test").finish(), @@ -81,8 +85,6 @@ impl BackendType<'_, T> { use BackendType::*; match self { Console(c, x) => Console(Cow::Borrowed(c), x), - #[cfg(feature = "testing")] - Postgres(c, x) => Postgres(Cow::Borrowed(c), x), Link(c) => Link(Cow::Borrowed(c)), #[cfg(test)] Test(x) => Test(*x), @@ -98,8 +100,6 @@ impl<'a, T> BackendType<'a, T> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), - #[cfg(feature = "testing")] - Postgres(c, x) => Postgres(c, f(x)), Link(c) => Link(c), #[cfg(test)] Test(x) => Test(x), @@ -114,8 +114,6 @@ impl<'a, T, E> BackendType<'a, Result> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), - #[cfg(feature = "testing")] - Postgres(c, x) => x.map(|x| Postgres(c, x)), Link(c) => Ok(Link(c)), #[cfg(test)] Test(x) => Ok(Test(x)), @@ -325,8 +323,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => user_info.project.clone(), - #[cfg(feature = "testing")] - Postgres(_, user_info) => user_info.project.clone(), Link(_) => Some("link".into()), #[cfg(test)] Test(_) => Some("test".into()), @@ -339,8 +335,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, - #[cfg(feature = "testing")] - Postgres(_, user_info) => &user_info.user, Link(_) => "link", #[cfg(test)] Test(_) => "test", @@ -371,19 +365,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { .await?; (cache_info, BackendType::Console(api, user_info)) } - #[cfg(feature = "testing")] - Postgres(api, user_info) => { - info!( - user = &*user_info.user, - project = user_info.project(), - "performing authentication using a local postgres instance" - ); - - let (cache_info, user_info) = - auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config) - .await?; - (cache_info, BackendType::Postgres(api, user_info)) - } // NOTE: this auth backend doesn't use client credentials. Link(url) => { info!("performing link authentication"); @@ -414,8 +395,6 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - #[cfg(feature = "testing")] - Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await, Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), #[cfg(test)] Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))), @@ -432,8 +411,6 @@ impl BackendType<'_, ComputeUserInfo> { match self { Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, - #[cfg(feature = "testing")] - Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, Link(_) => Ok(None), #[cfg(test)] Test(x) => x.wake_compute().map(Some), diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 2cf7e3acc7..a7ddd257b3 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -57,24 +57,31 @@ pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result { - let psql_session_id = new_psql_session_id(); + // registering waiter can fail if we get unlucky with rng. + // just try again. + let (psql_session_id, waiter) = loop { + let psql_session_id = new_psql_session_id(); + + match console::mgmt::get_waiter(&psql_session_id) { + Ok(waiter) => break (psql_session_id, waiter), + Err(_e) => continue, + } + }; + let span = info_span!("link", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); - let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async { - // Give user a URL to spawn a new database. - info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + // Give user a URL to spawn a new database. + info!(parent: &span, "sending the auth URL to the user"); + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&Be::CLIENT_ENCODING)? + .write_message(&Be::NoticeResponse(&greeting)) + .await?; - // Wait for web console response (see `mgmt`). - info!(parent: &span, "waiting for console's reply..."); - waiter.await?.map_err(LinkAuthError::AuthFailed) - }) - .await?; + // Wait for web console response (see `mgmt`). + info!(parent: &span, "waiting for console's reply..."); + let db_info = waiter.await.map_err(LinkAuthError::from)?; client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index e1dac34a59..ba113a89eb 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -249,12 +249,19 @@ async fn main() -> anyhow::Result<()> { } if let auth::BackendType::Console(api, _) = &config.auth_backend { - let cache = api.caches.project_info.clone(); - if let Some(url) = args.redis_notifications { - info!("Starting redis notifications listener ({url})"); - maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + match &**api { + proxy::console::provider::ConsoleBackend::Console(api) => { + let cache = api.caches.project_info.clone(); + if let Some(url) = args.redis_notifications { + info!("Starting redis notifications listener ({url})"); + maintenance_tasks + .spawn(notifications::task_main(url.to_owned(), cache.clone())); + } + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + } + #[cfg(feature = "testing")] + proxy::console::provider::ConsoleBackend::Postgres(_) => {} } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } let maintenance = loop { @@ -351,13 +358,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config)); let api = console::provider::neon::Api::new(endpoint, caches, locks); + let api = console::provider::ConsoleBackend::Console(api); auth::BackendType::Console(Cow::Owned(api), ()) } #[cfg(feature = "testing")] AuthBackend::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); - auth::BackendType::Postgres(Cow::Owned(api), ()) + let api = console::provider::ConsoleBackend::Postgres(api); + auth::BackendType::Console(Cow::Owned(api), ()) } AuthBackend::Link => { let url = args.uri.parse()?; diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index f0e084b679..373138b09e 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -13,16 +13,10 @@ use tracing::{error, info, info_span, Instrument}; static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); /// Give caller an opportunity to wait for the cloud's reply. -pub async fn with_waiter( +pub fn get_waiter( psql_session_id: impl Into, - action: impl FnOnce(Waiter<'static, ComputeReady>) -> R, -) -> Result -where - R: std::future::Future>, - E: From, -{ - let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; - action(waiter).await +) -> Result, waiters::RegisterError> { + CPLANE_WAITERS.register(psql_session_id.into()) } pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { @@ -77,7 +71,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = Result; +pub type ComputeReady = DatabaseInfo; // TODO: replace with an http-based protocol. struct MgmtHandler; @@ -102,7 +96,7 @@ fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), Qu let _enter = span.enter(); info!("got response: {:?}", resp.result); - match notify(resp.session_id, Ok(resp.result)) { + match notify(resp.session_id, resp.result) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 84c43183cc..178a7a2f4c 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -248,23 +248,75 @@ pub trait Api { async fn get_role_secret( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result, errors::GetAuthInfoError>; async fn get_allowed_ips( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result; } +#[derive(Clone)] +pub enum ConsoleBackend { + /// Current Cloud API (V2). + Console(neon::Api), + /// Local mock of Cloud API (V2). + #[cfg(feature = "testing")] + Postgres(mock::Api), +} + +#[async_trait] +impl Api for ConsoleBackend { + async fn get_role_secret( + &self, + ctx: &mut RequestMonitoring, + user_info: &ComputeUserInfo, + ) -> Result, errors::GetAuthInfoError> { + use ConsoleBackend::*; + match self { + Console(api) => api.get_role_secret(ctx, user_info).await, + #[cfg(feature = "testing")] + Postgres(api) => api.get_role_secret(ctx, user_info).await, + } + } + + async fn get_allowed_ips( + &self, + ctx: &mut RequestMonitoring, + user_info: &ComputeUserInfo, + ) -> Result { + use ConsoleBackend::*; + match self { + Console(api) => api.get_allowed_ips(ctx, user_info).await, + #[cfg(feature = "testing")] + Postgres(api) => api.get_allowed_ips(ctx, user_info).await, + } + } + + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + user_info: &ComputeUserInfo, + ) -> Result { + use ConsoleBackend::*; + + match self { + Console(api) => api.wake_compute(ctx, user_info).await, + #[cfg(feature = "testing")] + Postgres(api) => api.wake_compute(ctx, user_info).await, + } + } +} + /// Various caches for [`console`](super). pub struct ApiCaches { /// Cache for the `wake_compute` API method. diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 72cab1fe5d..8bbe88aa51 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -160,8 +160,6 @@ where let node_info = loop { let wake_res = match user_info { auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await, - #[cfg(feature = "testing")] - auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await, // nothing to do? auth::BackendType::Link(_) => return Err(err.into()), // test backend