mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 17:02:56 +00:00
a little more type-safety, a little more verbose...
This commit is contained in:
@@ -34,6 +34,7 @@ use crate::{scram, stream};
|
||||
|
||||
/// The [crate::serverless] module can authenticate either using control-plane
|
||||
/// to get authentication state, or by using JWKs stored in the filesystem.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ServerlessBackend<'a> {
|
||||
/// Cloud API (V2).
|
||||
ControlPlane(&'a ControlPlaneBackend),
|
||||
|
||||
@@ -15,9 +15,9 @@ use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, poll_http2_client, Send};
|
||||
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::local::{LocalBackend, StaticAuthRules};
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, check_peer_addr_is_in_list, AuthError, ServerlessBackend};
|
||||
use crate::auth::{check_peer_addr_is_in_list, AuthError, ServerlessBackend};
|
||||
use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
@@ -26,7 +26,7 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::provider::ApiLockError;
|
||||
use crate::control_plane::provider::{ApiLockError, ControlPlaneBackend};
|
||||
use crate::control_plane::{Api, CachedNodeInfo};
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
@@ -41,7 +41,6 @@ pub(crate) struct PoolingBackend {
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: ServerlessBackend<'static>,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
@@ -49,19 +48,13 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_password(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &ControlPlaneBackend,
|
||||
user_info: &ComputeUserInfo,
|
||||
password: &[u8],
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let cplane = match self.auth_backend {
|
||||
ServerlessBackend::ControlPlane(cplane) => cplane,
|
||||
ServerlessBackend::Local(_local) => {
|
||||
return Err(AuthError::bad_auth_method(
|
||||
"password authentication not supported by local_proxy",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?;
|
||||
let (allowed_ips, maybe_secret) = auth_backend
|
||||
.get_allowed_ips_and_secret(ctx, user_info)
|
||||
.await?;
|
||||
if self.config.authentication_config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
{
|
||||
@@ -75,7 +68,7 @@ impl PoolingBackend {
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => cplane.get_role_secret(ctx, user_info).await?,
|
||||
None => auth_backend.get_role_secret(ctx, user_info).await?,
|
||||
};
|
||||
|
||||
let secret = match cached_secret.value.clone() {
|
||||
@@ -118,10 +111,11 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: String,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
match &self.auth_backend {
|
||||
match auth_backend {
|
||||
ServerlessBackend::ControlPlane(console) => {
|
||||
self.config
|
||||
.authentication_config
|
||||
@@ -130,7 +124,7 @@ impl PoolingBackend {
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&**console,
|
||||
console,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
@@ -171,6 +165,7 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
@@ -190,11 +185,11 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let api = match &self.auth_backend {
|
||||
let api = match auth_backend {
|
||||
ServerlessBackend::ControlPlane(cplane) => {
|
||||
&cplane.attach_to_credentials(keys) as &dyn ComputeConnectBackend
|
||||
}
|
||||
ServerlessBackend::Local(local_proxy) => &**local_proxy as &dyn ComputeConnectBackend,
|
||||
ServerlessBackend::Local(local_proxy) => local_proxy as &dyn ComputeConnectBackend,
|
||||
};
|
||||
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
@@ -218,15 +213,9 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
let cplane = match &self.auth_backend {
|
||||
ServerlessBackend::Local(_) => {
|
||||
panic!("connect to local_proxy should not be called if we are already local_proxy")
|
||||
}
|
||||
ServerlessBackend::ControlPlane(cplane) => cplane,
|
||||
};
|
||||
|
||||
info!("pool: looking for an existing connection");
|
||||
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
@@ -236,7 +225,7 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let backend = cplane.attach_to_credentials(ComputeCredentials {
|
||||
let backend = auth_backend.attach_to_credentials(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
|
||||
@@ -271,26 +260,20 @@ impl PoolingBackend {
|
||||
pub(crate) async fn connect_to_local_postgres(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_backend: &LocalBackend,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let local_backend = match &self.auth_backend {
|
||||
auth::ServerlessBackend::ControlPlane(_) => {
|
||||
unreachable!("only local_proxy can connect to local postgres")
|
||||
}
|
||||
auth::ServerlessBackend::Local(local) => local,
|
||||
};
|
||||
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
// only install and grant usage one at a time.
|
||||
let _permit = local_backend.initialize.acquire().await.unwrap();
|
||||
let _permit = auth_backend.initialize.acquire().await.unwrap();
|
||||
|
||||
// check again for race
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
local_backend
|
||||
auth_backend
|
||||
.compute_ctl
|
||||
.install_extension(&ExtensionInstallRequest {
|
||||
extension: EXT_NAME,
|
||||
@@ -299,7 +282,7 @@ impl PoolingBackend {
|
||||
})
|
||||
.await?;
|
||||
|
||||
local_backend
|
||||
auth_backend
|
||||
.compute_ctl
|
||||
.grant_role(&SetRoleGrantsRequest {
|
||||
schema: EXT_SCHEMA,
|
||||
@@ -317,7 +300,7 @@ impl PoolingBackend {
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let mut node_info = local_backend.node_info.clone();
|
||||
let mut node_info = auth_backend.node_info.clone();
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
|
||||
@@ -112,7 +112,6 @@ pub async fn task_main(
|
||||
local_pool,
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
auth_backend,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
});
|
||||
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
|
||||
@@ -185,6 +184,7 @@ pub async fn task_main(
|
||||
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
auth_backend,
|
||||
backend,
|
||||
connections2,
|
||||
cancellation_handler,
|
||||
@@ -290,6 +290,7 @@ async fn connection_startup(
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -324,6 +325,7 @@ async fn connection_handler(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
auth_backend,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
@@ -363,6 +365,7 @@ async fn connection_handler(
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -383,7 +386,7 @@ async fn request_handler(
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else {
|
||||
let ServerlessBackend::ControlPlane(auth_backend) = auth_backend else {
|
||||
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
|
||||
};
|
||||
|
||||
@@ -430,9 +433,16 @@ async fn request_handler(
|
||||
);
|
||||
let span = ctx.span();
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.await
|
||||
sql_over_http::handle(
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
auth_backend,
|
||||
backend,
|
||||
http_cancellation_token,
|
||||
)
|
||||
.instrument(span)
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
|
||||
@@ -30,10 +30,11 @@ use super::conn_pool_lib::{self, ConnInfo};
|
||||
use super::http_util::json_response;
|
||||
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
|
||||
use super::local_conn_pool;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{endpoint_sni, ComputeUserInfoParseError, ServerlessBackend};
|
||||
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::provider::ControlPlaneBackend;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::{HttpDirection, Metrics};
|
||||
use crate::proxy::{run_until_cancelled, NeonOptions};
|
||||
@@ -240,10 +241,11 @@ pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
let result = handle_inner(cancel, config, &ctx, request, auth_backend, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => {
|
||||
@@ -498,6 +500,7 @@ async fn handle_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
@@ -522,7 +525,11 @@ async fn handle_inner(
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
|
||||
panic!("auth_broker must be configured with a control-plane auth backend.")
|
||||
};
|
||||
|
||||
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, cplane, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
@@ -532,6 +539,7 @@ async fn handle_inner(
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
auth,
|
||||
auth_backend,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
@@ -539,6 +547,7 @@ async fn handle_inner(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_db_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
@@ -546,6 +555,7 @@ async fn handle_db_inner(
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
auth: AuthData,
|
||||
auth_backend: ServerlessBackend<'static>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
@@ -588,48 +598,58 @@ async fn handle_db_inner(
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let is_local_proxy = matches!(
|
||||
backend.auth_backend,
|
||||
crate::auth::ServerlessBackend::Local(_)
|
||||
);
|
||||
let authenticate_and_connect = Box::pin(async {
|
||||
let creds = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
|
||||
return Err(SqlOverHttpError::ConnInfo(
|
||||
ConnInfoError::MissingCredentials(Credentials::BearerJwt),
|
||||
));
|
||||
};
|
||||
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
backend
|
||||
.authenticate_with_password(ctx, cplane, &conn_info.user_info, &pw)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?
|
||||
}
|
||||
AuthData::Jwt(jwt) => backend
|
||||
.authenticate_with_jwt(ctx, auth_backend, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?,
|
||||
};
|
||||
|
||||
let client = match keys.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
_ => {
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
let client = match (creds.keys, auth_backend) {
|
||||
(ComputeCredentialKeys::JwtPayload(payload), ServerlessBackend::Local(local)) => {
|
||||
let mut client = backend
|
||||
.connect_to_local_postgres(ctx, local, conn_info)
|
||||
.await?;
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
(keys, auth_backend) => {
|
||||
let client = backend
|
||||
.connect_to_compute(
|
||||
ctx,
|
||||
auth_backend,
|
||||
conn_info,
|
||||
ComputeCredentials {
|
||||
keys,
|
||||
info: creds.info,
|
||||
},
|
||||
!allow_pool,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
Ok::<_, SqlOverHttpError>(client)
|
||||
});
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
@@ -714,14 +734,22 @@ async fn handle_auth_broker_inner(
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
auth_backend: &'static ControlPlaneBackend,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
ServerlessBackend::ControlPlane(auth_backend),
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
let mut client = backend
|
||||
.connect_to_local_proxy(ctx, auth_backend, conn_info)
|
||||
.await?;
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user