a little more type-safety, a little more verbose...

This commit is contained in:
Conrad Ludgate
2024-10-24 12:33:10 +01:00
parent c8108a4b84
commit b66e545e26
4 changed files with 110 additions and 88 deletions

View File

@@ -34,6 +34,7 @@ use crate::{scram, stream};
/// The [crate::serverless] module can authenticate either using control-plane
/// to get authentication state, or by using JWKs stored in the filesystem.
#[derive(Clone, Copy)]
pub enum ServerlessBackend<'a> {
/// Cloud API (V2).
ControlPlane(&'a ControlPlaneBackend),

View File

@@ -15,9 +15,9 @@ use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::local::{LocalBackend, StaticAuthRules};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, check_peer_addr_is_in_list, AuthError, ServerlessBackend};
use crate::auth::{check_peer_addr_is_in_list, AuthError, ServerlessBackend};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -26,7 +26,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::provider::ApiLockError;
use crate::control_plane::provider::{ApiLockError, ControlPlaneBackend};
use crate::control_plane::{Api, CachedNodeInfo};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
@@ -41,7 +41,6 @@ pub(crate) struct PoolingBackend {
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: ServerlessBackend<'static>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -49,19 +48,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
auth_backend: &ControlPlaneBackend,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let cplane = match self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => cplane,
ServerlessBackend::Local(_local) => {
return Err(AuthError::bad_auth_method(
"password authentication not supported by local_proxy",
))
}
};
let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?;
let (allowed_ips, maybe_secret) = auth_backend
.get_allowed_ips_and_secret(ctx, user_info)
.await?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
@@ -75,7 +68,7 @@ impl PoolingBackend {
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => cplane.get_role_secret(ctx, user_info).await?,
None => auth_backend.get_role_secret(ctx, user_info).await?,
};
let secret = match cached_secret.value.clone() {
@@ -118,10 +111,11 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
auth_backend: ServerlessBackend<'static>,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<ComputeCredentials, AuthError> {
match &self.auth_backend {
match auth_backend {
ServerlessBackend::ControlPlane(console) => {
self.config
.authentication_config
@@ -130,7 +124,7 @@ impl PoolingBackend {
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
console,
&jwt,
)
.await
@@ -171,6 +165,7 @@ impl PoolingBackend {
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestMonitoring,
auth_backend: ServerlessBackend<'static>,
conn_info: ConnInfo,
keys: ComputeCredentials,
force_new: bool,
@@ -190,11 +185,11 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let api = match &self.auth_backend {
let api = match auth_backend {
ServerlessBackend::ControlPlane(cplane) => {
&cplane.attach_to_credentials(keys) as &dyn ComputeConnectBackend
}
ServerlessBackend::Local(local_proxy) => &**local_proxy as &dyn ComputeConnectBackend,
ServerlessBackend::Local(local_proxy) => local_proxy as &dyn ComputeConnectBackend,
};
crate::proxy::connect_compute::connect_to_compute(
@@ -218,15 +213,9 @@ impl PoolingBackend {
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
auth_backend: &'static ControlPlaneBackend,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
let cplane = match &self.auth_backend {
ServerlessBackend::Local(_) => {
panic!("connect to local_proxy should not be called if we are already local_proxy")
}
ServerlessBackend::ControlPlane(cplane) => cplane,
};
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -236,7 +225,7 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = cplane.attach_to_credentials(ComputeCredentials {
let backend = auth_backend.attach_to_credentials(ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
@@ -271,26 +260,20 @@ impl PoolingBackend {
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestMonitoring,
auth_backend: &LocalBackend,
conn_info: ConnInfo,
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
}
let local_backend = match &self.auth_backend {
auth::ServerlessBackend::ControlPlane(_) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::ServerlessBackend::Local(local) => local,
};
if !self.local_pool.initialized(&conn_info) {
// only install and grant usage one at a time.
let _permit = local_backend.initialize.acquire().await.unwrap();
let _permit = auth_backend.initialize.acquire().await.unwrap();
// check again for race
if !self.local_pool.initialized(&conn_info) {
local_backend
auth_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
extension: EXT_NAME,
@@ -299,7 +282,7 @@ impl PoolingBackend {
})
.await?;
local_backend
auth_backend
.compute_ctl
.grant_role(&SetRoleGrantsRequest {
schema: EXT_SCHEMA,
@@ -317,7 +300,7 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let mut node_info = auth_backend.node_info.clone();
let (key, jwk) = create_random_jwk();

View File

@@ -112,7 +112,6 @@ pub async fn task_main(
local_pool,
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -185,6 +184,7 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
auth_backend,
backend,
connections2,
cancellation_handler,
@@ -290,6 +290,7 @@ async fn connection_startup(
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -324,6 +325,7 @@ async fn connection_handler(
request_handler(
req,
config,
auth_backend,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
@@ -363,6 +365,7 @@ async fn connection_handler(
async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -383,7 +386,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else {
let ServerlessBackend::ControlPlane(auth_backend) = auth_backend else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};
@@ -430,9 +433,16 @@ async fn request_handler(
);
let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
sql_over_http::handle(
config,
ctx,
request,
auth_backend,
backend,
http_cancellation_token,
)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

View File

@@ -30,10 +30,11 @@ use super::conn_pool_lib::{self, ConnInfo};
use super::http_util::json_response;
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
use super::local_conn_pool;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError, ServerlessBackend};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::context::RequestMonitoring;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
@@ -240,10 +241,11 @@ pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
request: Request<Incoming>,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let result = handle_inner(cancel, config, &ctx, request, auth_backend, backend).await;
let mut response = match result {
Ok(r) => {
@@ -498,6 +500,7 @@ async fn handle_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
@@ -522,7 +525,11 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
panic!("auth_broker must be configured with a control-plane auth backend.")
};
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, cplane, backend).await
}
auth => {
handle_db_inner(
@@ -532,6 +539,7 @@ async fn handle_inner(
request,
conn_info.conn_info,
auth,
auth_backend,
backend,
)
.await
@@ -539,6 +547,7 @@ async fn handle_inner(
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
@@ -546,6 +555,7 @@ async fn handle_db_inner(
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
auth_backend: ServerlessBackend<'static>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
//
@@ -588,48 +598,58 @@ async fn handle_db_inner(
.map_err(SqlOverHttpError::from),
);
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy = matches!(
backend.auth_backend,
crate::auth::ServerlessBackend::Local(_)
);
let authenticate_and_connect = Box::pin(async {
let creds = match auth {
AuthData::Password(pw) => {
let ServerlessBackend::ControlPlane(cplane) = auth_backend else {
return Err(SqlOverHttpError::ConnInfo(
ConnInfoError::MissingCredentials(Credentials::BearerJwt),
));
};
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
backend
.authenticate_with_password(ctx, cplane, &conn_info.user_info, &pw)
.await
.map_err(HttpConnError::from)?
}
AuthData::Jwt(jwt) => backend
.authenticate_with_jwt(ctx, auth_backend, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?,
};
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
}
};
let client = match (creds.keys, auth_backend) {
(ComputeCredentialKeys::JwtPayload(payload), ServerlessBackend::Local(local)) => {
let mut client = backend
.connect_to_local_postgres(ctx, local, conn_info)
.await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
(keys, auth_backend) => {
let client = backend
.connect_to_compute(
ctx,
auth_backend,
conn_info,
ComputeCredentials {
keys,
info: creds.info,
},
!allow_pool,
)
.await
.map_err(HttpConnError::from)?;
Client::Remote(client)
}
};
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from),
);
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
Ok::<_, SqlOverHttpError>(client)
});
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel
@@ -714,14 +734,22 @@ async fn handle_auth_broker_inner(
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
auth_backend: &'static ControlPlaneBackend,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.authenticate_with_jwt(
ctx,
ServerlessBackend::ControlPlane(auth_backend),
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let mut client = backend
.connect_to_local_proxy(ctx, auth_backend, conn_info)
.await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");