add authbroker cli flag and fix http2 ka

This commit is contained in:
Conrad Ludgate
2024-09-23 11:38:06 +01:00
parent 4bc2686dee
commit 75bfd57e01
8 changed files with 63 additions and 23 deletions

View File

@@ -619,6 +619,7 @@ mod tests {
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -6,7 +6,7 @@ use crate::{
error::{ReportableError, UserFacingError},
metrics::{Metrics, SniKind},
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
serverless::{SERVERLESS_DRIVER_AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI},
EndpointId, RoleName,
};
use itertools::Itertools;
@@ -76,7 +76,7 @@ pub(crate) fn endpoint_sni(
cn: common_name.into(),
});
}
if subdomain == SERVERLESS_DRIVER_SNI {
if subdomain == SERVERLESS_DRIVER_SNI || subdomain == SERVERLESS_DRIVER_AUTH_BROKER_SNI {
return Ok(None);
}
Ok(Some(EndpointId::from(subdomain)))

View File

@@ -278,6 +278,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),

View File

@@ -103,6 +103,9 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -383,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default();
@@ -431,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
@@ -675,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: true,
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -697,7 +714,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: true,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
};
let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -72,6 +72,7 @@ pub struct AuthenticationConfig {
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
}
impl TlsConfig {

View File

@@ -51,6 +51,7 @@ use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const SERVERLESS_DRIVER_AUTH_BROKER_SNI: &str = "apiauth";
pub async fn task_main(
config: &'static ProxyConfig,

View File

@@ -1,7 +1,7 @@
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
@@ -446,6 +446,9 @@ async fn connect_http2(
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;

View File

@@ -43,6 +43,7 @@ use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -148,6 +149,7 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
@@ -183,6 +185,10 @@ fn get_conn_info(
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingPassword);
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
@@ -192,6 +198,10 @@ fn get_conn_info(
.into(),
)
} else if let Some(pass) = connection_url.password() {
if config.accept_jwts {
return Err(ConnInfoError::MissingPassword);
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
@@ -516,7 +526,12 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(ctx, request.headers(), config.tls_config.as_ref())?;
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"