mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-21 12:22:56 +00:00
Compare commits
5 Commits
cloneable/
...
proxy-simp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b66e545e26 | ||
|
|
c8108a4b84 | ||
|
|
2d34fec39b | ||
|
|
3da4705775 | ||
|
|
80c5576816 |
@@ -21,10 +21,7 @@ use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserIn
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::GetAuthInfoError;
|
||||
use crate::control_plane::provider::{
|
||||
CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
|
||||
};
|
||||
use crate::control_plane::provider::{CachedNodeInfo, ControlPlaneBackend};
|
||||
use crate::control_plane::{self, Api, AuthSecret};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
@@ -35,38 +32,19 @@ use crate::stream::Stream;
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{scram, stream};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
Owned(T),
|
||||
Borrowed(&'a T),
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
MaybeOwned::Owned(t) => t,
|
||||
MaybeOwned::Borrowed(t) => t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
/// which we use in [`crate::config::ProxyConfig`].
|
||||
///
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum Backend<'a, T> {
|
||||
/// The [crate::serverless] module can authenticate either using control-plane
|
||||
/// to get authentication state, or by using JWKs stored in the filesystem.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ServerlessBackend<'a> {
|
||||
/// Cloud API (V2).
|
||||
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
|
||||
ControlPlane(&'a ControlPlaneBackend),
|
||||
/// Local proxy uses configured auth credentials and does not wake compute
|
||||
Local(MaybeOwned<'a, LocalBackend>),
|
||||
Local(&'a LocalBackend),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
|
||||
@@ -83,56 +61,20 @@ impl Clone for Box<dyn TestBackend> {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend<'_, ()> {
|
||||
impl std::fmt::Display for ControlPlaneBackend {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ControlPlane(api, ()) => match &**api {
|
||||
ControlPlaneBackend::Management(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::Management")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ControlPlaneBackend::PostgresMock(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::PostgresMock")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(test)]
|
||||
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
|
||||
},
|
||||
Self::Local(_) => fmt.debug_tuple("Local").finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Backend<'_, T> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
|
||||
match self {
|
||||
Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
|
||||
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Backend<'a, T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
|
||||
match self {
|
||||
Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
|
||||
Self::Local(l) => Backend::Local(l),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a, T, E> Backend<'a, Result<T, E>> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
|
||||
match self {
|
||||
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
|
||||
Self::Local(l) => Ok(Backend::Local(l)),
|
||||
ControlPlaneBackend::Management(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::Management")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ControlPlaneBackend::PostgresMock(endpoint) => fmt
|
||||
.debug_tuple("ControlPlane::PostgresMock")
|
||||
.field(&endpoint.url())
|
||||
.finish(),
|
||||
#[cfg(test)]
|
||||
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -399,96 +341,79 @@ async fn authenticate_with_secret(
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get username from the credentials.
|
||||
pub(crate) fn get_user(&self) -> &str {
|
||||
match self {
|
||||
Self::ControlPlane(_, user_info) => &user_info.user,
|
||||
Self::Local(_) => "local",
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
impl ControlPlaneBackend {
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub(crate) async fn authenticate(
|
||||
self,
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<Backend<'a, ComputeCredentials>> {
|
||||
let res = match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
) -> auth::Result<ControlPlaneComputeBackend> {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
&*api,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
Backend::ControlPlane(api, credentials)
|
||||
}
|
||||
Self::Local(_) => {
|
||||
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
|
||||
}
|
||||
};
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
self,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("user successfully authenticated");
|
||||
Ok(res)
|
||||
Ok(ControlPlaneComputeBackend {
|
||||
api: self,
|
||||
creds: credentials,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn attach_to_credentials(
|
||||
&self,
|
||||
creds: ComputeCredentials,
|
||||
) -> ControlPlaneComputeBackend {
|
||||
ControlPlaneComputeBackend { api: self, creds }
|
||||
}
|
||||
}
|
||||
|
||||
impl Backend<'_, ComputeUserInfo> {
|
||||
pub(crate) async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::Local(_) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
api.get_allowed_ips_and_secret(ctx, user_info).await
|
||||
}
|
||||
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
pub struct ControlPlaneComputeBackend<'a> {
|
||||
api: &'a ControlPlaneBackend,
|
||||
creds: ComputeCredentials,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
impl ComputeConnectBackend for ControlPlaneComputeBackend<'static> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
self.api.wake_compute(ctx, &self.creds.info).await
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
match self {
|
||||
Self::ControlPlane(_, creds) => &creds.keys,
|
||||
Self::Local(_) => &ComputeCredentialKeys::None,
|
||||
}
|
||||
&self.creds.keys
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for LocalBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
Ok(Cached::new_uncached(self.node_info.clone()))
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
&ComputeCredentialKeys::None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::Backend;
|
||||
pub use backend::ServerlessBackend;
|
||||
|
||||
mod credentials;
|
||||
pub(crate) use credentials::{
|
||||
|
||||
@@ -203,7 +203,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let task = serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
auth::ServerlessBackend::Local(auth_backend),
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandlerMain::new(
|
||||
@@ -295,12 +295,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
}
|
||||
|
||||
/// auth::Backend is created at proxy startup, and lives forever.
|
||||
fn build_auth_backend(
|
||||
args: &LocalProxyCliArgs,
|
||||
) -> anyhow::Result<&'static auth::Backend<'static, ()>> {
|
||||
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
|
||||
LocalBackend::new(args.postgres, args.compute_ctl.clone()),
|
||||
));
|
||||
fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> {
|
||||
let auth_backend = LocalBackend::new(args.postgres, args.compute_ctl.clone());
|
||||
|
||||
Ok(Box::leak(Box::new(auth_backend)))
|
||||
}
|
||||
|
||||
@@ -13,13 +13,14 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
|
||||
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend};
|
||||
use proxy::cancellation::{CancelMap, CancellationHandler};
|
||||
use proxy::config::{
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
|
||||
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
|
||||
};
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::control_plane::provider::ControlPlaneBackend;
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
use proxy::metrics::Metrics;
|
||||
use proxy::rate_limiter::{
|
||||
@@ -467,7 +468,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
auth::ServerlessBackend::ControlPlane(auth_backend),
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
@@ -515,40 +516,38 @@ async fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
}
|
||||
|
||||
if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend {
|
||||
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -694,7 +693,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
/// auth::Backend is created at proxy startup, and lives forever.
|
||||
fn build_auth_backend(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<Either<&'static auth::Backend<'static, ()>, &'static ConsoleRedirectBackend>> {
|
||||
) -> anyhow::Result<Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>> {
|
||||
match &args.auth_backend {
|
||||
AuthBackendType::Console => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
@@ -744,8 +743,7 @@ fn build_auth_backend(
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
);
|
||||
let api = control_plane::provider::ControlPlaneBackend::Management(api);
|
||||
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
|
||||
let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api);
|
||||
|
||||
let config = Box::leak(Box::new(auth_backend));
|
||||
|
||||
@@ -756,9 +754,7 @@ fn build_auth_backend(
|
||||
AuthBackendType::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
|
||||
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
|
||||
|
||||
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
|
||||
let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
|
||||
|
||||
let config = Box::leak(Box::new(auth_backend));
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ pub(crate) trait ConnectMechanism {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ComputeConnectBackend {
|
||||
pub(crate) trait ComputeConnectBackend: Send + Sync + 'static {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
@@ -98,10 +98,10 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism>(
|
||||
ctx: &RequestMonitoring,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
user_info: &dyn ComputeConnectBackend,
|
||||
allow_self_signed_compute: bool,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
connect_to_compute_retry_config: RetryConfig,
|
||||
|
||||
@@ -26,6 +26,7 @@ use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::ControlPlaneBackend;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
@@ -54,7 +55,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -241,7 +242,7 @@ impl ReportableError for ClientRequestError {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
ctx: &RequestMonitoring,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
stream: S,
|
||||
@@ -282,20 +283,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth_backend
|
||||
.as_ref()
|
||||
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names);
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e).await?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
let user = user_info.user.clone();
|
||||
let user_info = match auth_backend
|
||||
.authenticate(
|
||||
ctx,
|
||||
user_info,
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use auth::backend::ControlPlaneComputeBackend;
|
||||
use http::StatusCode;
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
@@ -19,7 +20,7 @@ use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend,
|
||||
};
|
||||
use crate::config::{CertResolver, RetryConfig};
|
||||
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
|
||||
@@ -566,19 +567,21 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> auth::Backend<'static, ComputeCredentials> {
|
||||
let user_info = auth::Backend::ControlPlane(
|
||||
MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
) -> ControlPlaneComputeBackend<'static> {
|
||||
let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new(
|
||||
mechanism.clone(),
|
||||
))));
|
||||
|
||||
let creds = ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
);
|
||||
user_info
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
};
|
||||
|
||||
api.attach_to_credentials(creds)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -11,10 +11,10 @@ use crate::metrics::{
|
||||
};
|
||||
use crate::proxy::retry::{retry_after, should_retry};
|
||||
|
||||
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
|
||||
pub(crate) async fn wake_compute(
|
||||
num_retries: &mut u32,
|
||||
ctx: &RequestMonitoring,
|
||||
api: &B,
|
||||
api: &dyn ComputeConnectBackend,
|
||||
config: RetryConfig,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let retry_type = RetryType::WakeCompute;
|
||||
|
||||
@@ -15,9 +15,9 @@ use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, poll_http2_client, Send};
|
||||
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::local::{LocalBackend, StaticAuthRules};
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
|
||||
use crate::auth::{check_peer_addr_is_in_list, AuthError, ServerlessBackend};
|
||||
use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
@@ -26,11 +26,11 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::provider::ApiLockError;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::control_plane::provider::{ApiLockError, ControlPlaneBackend};
|
||||
use crate::control_plane::{Api, CachedNodeInfo};
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::connect_compute::{ComputeConnectBackend, ConnectMechanism};
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::types::{EndpointId, Host};
|
||||
@@ -41,7 +41,6 @@ pub(crate) struct PoolingBackend {
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
@@ -49,12 +48,13 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_password(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &ControlPlaneBackend,
|
||||
user_info: &ComputeUserInfo,
|
||||
password: &[u8],
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = user_info.clone();
|
||||
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
let (allowed_ips, maybe_secret) = auth_backend
|
||||
.get_allowed_ips_and_secret(ctx, user_info)
|
||||
.await?;
|
||||
if self.config.authentication_config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
{
|
||||
@@ -68,7 +68,7 @@ impl PoolingBackend {
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => backend.get_role_secret(ctx).await?,
|
||||
None => auth_backend.get_role_secret(ctx, user_info).await?,
|
||||
};
|
||||
|
||||
let secret = match cached_secret.value.clone() {
|
||||
@@ -103,7 +103,7 @@ impl PoolingBackend {
|
||||
}
|
||||
};
|
||||
res.map(|key| ComputeCredentials {
|
||||
info: user_info,
|
||||
info: user_info.clone(),
|
||||
keys: key,
|
||||
})
|
||||
}
|
||||
@@ -111,11 +111,12 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: String,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
match &self.auth_backend {
|
||||
crate::auth::Backend::ControlPlane(console, ()) => {
|
||||
match auth_backend {
|
||||
ServerlessBackend::ControlPlane(console) => {
|
||||
self.config
|
||||
.authentication_config
|
||||
.jwks_cache
|
||||
@@ -123,7 +124,7 @@ impl PoolingBackend {
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&**console,
|
||||
console,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
@@ -134,7 +135,7 @@ impl PoolingBackend {
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
}
|
||||
crate::auth::Backend::Local(_) => {
|
||||
ServerlessBackend::Local(_) => {
|
||||
let keys = self
|
||||
.config
|
||||
.authentication_config
|
||||
@@ -164,6 +165,7 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
@@ -182,7 +184,14 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||
|
||||
let api = match auth_backend {
|
||||
ServerlessBackend::ControlPlane(cplane) => {
|
||||
&cplane.attach_to_credentials(keys) as &dyn ComputeConnectBackend
|
||||
}
|
||||
ServerlessBackend::Local(local_proxy) => local_proxy as &dyn ComputeConnectBackend,
|
||||
};
|
||||
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
@@ -191,7 +200,7 @@ impl PoolingBackend {
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
&backend,
|
||||
api,
|
||||
false, // do not allow self signed compute for http flow
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
@@ -204,6 +213,7 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
info!("pool: looking for an existing connection");
|
||||
@@ -214,7 +224,8 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
|
||||
|
||||
let backend = auth_backend.attach_to_credentials(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
|
||||
@@ -249,26 +260,20 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_local_postgres(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &LocalBackend,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let local_backend = match &self.auth_backend {
|
||||
auth::Backend::ControlPlane(_, ()) => {
|
||||
unreachable!("only local_proxy can connect to local postgres")
|
||||
}
|
||||
auth::Backend::Local(local) => local,
|
||||
};
|
||||
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
// only install and grant usage one at a time.
|
||||
let _permit = local_backend.initialize.acquire().await.unwrap();
|
||||
let _permit = auth_backend.initialize.acquire().await.unwrap();
|
||||
|
||||
// check again for race
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
local_backend
|
||||
auth_backend
|
||||
.compute_ctl
|
||||
.install_extension(&ExtensionInstallRequest {
|
||||
extension: EXT_NAME,
|
||||
@@ -277,7 +282,7 @@ impl PoolingBackend {
|
||||
})
|
||||
.await?;
|
||||
|
||||
local_backend
|
||||
auth_backend
|
||||
.compute_ctl
|
||||
.grant_role(&SetRoleGrantsRequest {
|
||||
schema: EXT_SCHEMA,
|
||||
@@ -295,7 +300,7 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = local_backend.node_info.clone();
|
||||
let mut node_info = auth_backend.node_info.clone();
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ use tokio_util::task::TaskTracker;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::ServerlessBackend;
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -55,7 +56,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -111,7 +112,6 @@ pub async fn task_main(
|
||||
local_pool,
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
auth_backend,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
});
|
||||
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
|
||||
@@ -184,6 +184,7 @@ pub async fn task_main(
|
||||
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
auth_backend,
|
||||
backend,
|
||||
connections2,
|
||||
cancellation_handler,
|
||||
@@ -289,6 +290,7 @@ async fn connection_startup(
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -323,6 +325,7 @@ async fn connection_handler(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
auth_backend,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
@@ -362,6 +365,7 @@ async fn connection_handler(
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -382,6 +386,10 @@ async fn request_handler(
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let ServerlessBackend::ControlPlane(auth_backend) = auth_backend else {
|
||||
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
|
||||
};
|
||||
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
@@ -399,7 +407,7 @@ async fn request_handler(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
backend.auth_backend,
|
||||
auth_backend,
|
||||
ctx,
|
||||
websocket,
|
||||
cancellation_handler,
|
||||
@@ -425,9 +433,16 @@ async fn request_handler(
|
||||
);
|
||||
let span = ctx.span();
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.await
|
||||
sql_over_http::handle(
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
auth_backend,
|
||||
backend,
|
||||
http_cancellation_token,
|
||||
)
|
||||
.instrument(span)
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
|
||||
@@ -30,10 +30,11 @@ use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::http_util::json_response;
|
||||
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
|
||||
use super::local_conn_pool;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, ComputeUserInfoParseError, ServerlessBackend};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::ControlPlaneBackend;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::{HttpDirection, Metrics};
|
||||
use crate::proxy::{run_until_cancelled, NeonOptions};
|
||||
@@ -240,10 +241,11 @@ pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
let result = handle_inner(cancel, config, &ctx, request, auth_backend, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => {
|
||||
@@ -498,6 +500,7 @@ async fn handle_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
@@ -522,7 +525,11 @@ async fn handle_inner(
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
|
||||
panic!("auth_broker must be configured with a control-plane auth backend.")
|
||||
};
|
||||
|
||||
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, cplane, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
@@ -532,6 +539,7 @@ async fn handle_inner(
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
auth,
|
||||
auth_backend,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
@@ -539,6 +547,7 @@ async fn handle_inner(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_db_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
@@ -546,6 +555,7 @@ async fn handle_db_inner(
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
auth: AuthData,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
@@ -588,45 +598,58 @@ async fn handle_db_inner(
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
|
||||
let authenticate_and_connect = Box::pin(async {
|
||||
let creds = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
|
||||
return Err(SqlOverHttpError::ConnInfo(
|
||||
ConnInfoError::MissingCredentials(Credentials::BearerJwt),
|
||||
));
|
||||
};
|
||||
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
backend
|
||||
.authenticate_with_password(ctx, cplane, &conn_info.user_info, &pw)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?
|
||||
}
|
||||
AuthData::Jwt(jwt) => backend
|
||||
.authenticate_with_jwt(ctx, auth_backend, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?,
|
||||
};
|
||||
|
||||
let client = match keys.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
_ => {
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
let client = match (creds.keys, auth_backend) {
|
||||
(ComputeCredentialKeys::JwtPayload(payload), ServerlessBackend::Local(local)) => {
|
||||
let mut client = backend
|
||||
.connect_to_local_postgres(ctx, local, conn_info)
|
||||
.await?;
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
(keys, auth_backend) => {
|
||||
let client = backend
|
||||
.connect_to_compute(
|
||||
ctx,
|
||||
auth_backend,
|
||||
conn_info,
|
||||
ComputeCredentials {
|
||||
keys,
|
||||
info: creds.info,
|
||||
},
|
||||
!allow_pool,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
Ok::<_, SqlOverHttpError>(client)
|
||||
});
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
@@ -711,14 +734,22 @@ async fn handle_auth_broker_inner(
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
ServerlessBackend::ControlPlane(auth_backend),
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
let mut client = backend
|
||||
.connect_to_local_proxy(ctx, auth_backend, conn_info)
|
||||
.await?;
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ use tracing::warn;
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::ControlPlaneBackend;
|
||||
use crate::error::{io_error, ReportableError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::{handle_client, ClientMode, ErrorSource};
|
||||
@@ -125,7 +126,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
|
||||
pub(crate) async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
ctx: RequestMonitoring,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
|
||||
Reference in New Issue
Block a user