mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-28 00:23:00 +00:00
proxy small tweaks (#6398)
## Problem In https://github.com/neondatabase/neon/pull/6283 I did a couple changes that weren't directly related to the goal of extracting the state machine, so I'm putting them here ## Summary of changes - move postgres vs console provider into another enum - reduce error cases for link auth - slightly refactor link flow
This commit is contained in:
@@ -5,7 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["testing"]
|
||||
default = []
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::validate_password_and_exchange;
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::ConsoleBackend;
|
||||
use crate::console::AuthSecret;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::proxy::connect_compute::handle_try_wake;
|
||||
@@ -43,11 +44,8 @@ use tracing::{error, info, warn};
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum BackendType<'a, T> {
|
||||
/// Current Cloud API (V2).
|
||||
Console(Cow<'a, console::provider::neon::Api>, T),
|
||||
/// Local mock of Cloud API (V2).
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(Cow<'a, console::provider::mock::Api>, T),
|
||||
/// Cloud API (V2).
|
||||
Console(Cow<'a, ConsoleBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(Cow<'a, url::ApiUrl>),
|
||||
#[cfg(test)]
|
||||
@@ -64,9 +62,15 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
|
||||
Console(api, _) => match &**api {
|
||||
ConsoleBackend::Console(endpoint) => {
|
||||
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
ConsoleBackend::Postgres(endpoint) => {
|
||||
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
|
||||
}
|
||||
},
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
#[cfg(test)]
|
||||
Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
@@ -81,8 +85,6 @@ impl<T> BackendType<'_, T> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(Cow::Borrowed(c), x),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
|
||||
Link(c) => Link(Cow::Borrowed(c)),
|
||||
#[cfg(test)]
|
||||
Test(x) => Test(*x),
|
||||
@@ -98,8 +100,6 @@ impl<'a, T> BackendType<'a, T> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => Postgres(c, f(x)),
|
||||
Link(c) => Link(c),
|
||||
#[cfg(test)]
|
||||
Test(x) => Test(x),
|
||||
@@ -114,8 +114,6 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => x.map(|x| Postgres(c, x)),
|
||||
Link(c) => Ok(Link(c)),
|
||||
#[cfg(test)]
|
||||
Test(x) => Ok(Test(x)),
|
||||
@@ -325,8 +323,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.project.clone(),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, user_info) => user_info.project.clone(),
|
||||
Link(_) => Some("link".into()),
|
||||
#[cfg(test)]
|
||||
Test(_) => Some("test".into()),
|
||||
@@ -339,8 +335,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => &user_info.user,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, user_info) => &user_info.user,
|
||||
Link(_) => "link",
|
||||
#[cfg(test)]
|
||||
Test(_) => "test",
|
||||
@@ -371,19 +365,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
.await?;
|
||||
(cache_info, BackendType::Console(api, user_info))
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, user_info) => {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.project(),
|
||||
"performing authentication using a local postgres instance"
|
||||
);
|
||||
|
||||
let (cache_info, user_info) =
|
||||
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
|
||||
.await?;
|
||||
(cache_info, BackendType::Postgres(api, user_info))
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
info!("performing link authentication");
|
||||
@@ -414,8 +395,6 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
|
||||
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
|
||||
#[cfg(test)]
|
||||
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
|
||||
@@ -432,8 +411,6 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
|
||||
match self {
|
||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.wake_compute().map(Some),
|
||||
|
||||
@@ -57,24 +57,31 @@ pub(super) async fn authenticate(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
// just try again.
|
||||
let (psql_session_id, waiter) = loop {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
|
||||
match console::mgmt::get_waiter(&psql_session_id) {
|
||||
Ok(waiter) => break (psql_session_id, waiter),
|
||||
Err(_e) => continue,
|
||||
}
|
||||
};
|
||||
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
waiter.await?.map_err(LinkAuthError::AuthFailed)
|
||||
})
|
||||
.await?;
|
||||
// Wait for web console response (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
let db_info = waiter.await.map_err(LinkAuthError::from)?;
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
|
||||
@@ -249,12 +249,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
if let auth::BackendType::Console(api, _) = &config.auth_backend {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(url) = args.redis_notifications {
|
||||
info!("Starting redis notifications listener ({url})");
|
||||
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
match &**api {
|
||||
proxy::console::provider::ConsoleBackend::Console(api) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(url) = args.redis_notifications {
|
||||
info!("Starting redis notifications listener ({url})");
|
||||
maintenance_tasks
|
||||
.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
proxy::console::provider::ConsoleBackend::Postgres(_) => {}
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
|
||||
let maintenance = loop {
|
||||
@@ -351,13 +358,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
|
||||
|
||||
let api = console::provider::neon::Api::new(endpoint, caches, locks);
|
||||
let api = console::provider::ConsoleBackend::Console(api);
|
||||
auth::BackendType::Console(Cow::Owned(api), ())
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
AuthBackend::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = console::provider::mock::Api::new(url);
|
||||
auth::BackendType::Postgres(Cow::Owned(api), ())
|
||||
let api = console::provider::ConsoleBackend::Postgres(api);
|
||||
auth::BackendType::Console(Cow::Owned(api), ())
|
||||
}
|
||||
AuthBackend::Link => {
|
||||
let url = args.uri.parse()?;
|
||||
|
||||
@@ -13,16 +13,10 @@ use tracing::{error, info, info_span, Instrument};
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
pub fn get_waiter(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
|
||||
CPLANE_WAITERS.register(psql_session_id.into())
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
@@ -77,7 +71,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
pub type ComputeReady = DatabaseInfo;
|
||||
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
@@ -102,7 +96,7 @@ fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), Qu
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match notify(resp.session_id, Ok(resp.result)) {
|
||||
match notify(resp.session_id, resp.result) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
|
||||
@@ -248,23 +248,75 @@ pub trait Api {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<Option<CachedRoleSecret>, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum ConsoleBackend {
|
||||
/// Current Cloud API (V2).
|
||||
Console(neon::Api),
|
||||
/// Local mock of Cloud API (V2).
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(mock::Api),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Api for ConsoleBackend {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<Option<CachedRoleSecret>, errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api) => api.get_role_secret(ctx, user_info).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
use ConsoleBackend::*;
|
||||
|
||||
match self {
|
||||
Console(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api) => api.wake_compute(ctx, user_info).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
|
||||
@@ -160,8 +160,6 @@ where
|
||||
let node_info = loop {
|
||||
let wake_res = match user_info {
|
||||
auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
|
||||
Reference in New Issue
Block a user