diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index cad5ff2afb..52ddfd90fb 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -619,6 +619,7 @@ mod tests { rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_auth_broker: false, + accept_jwts: false, }); async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index cba8601d14..c5ba8707fc 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -6,7 +6,7 @@ use crate::{ error::{ReportableError, UserFacingError}, metrics::{Metrics, SniKind}, proxy::NeonOptions, - serverless::SERVERLESS_DRIVER_SNI, + serverless::{SERVERLESS_DRIVER_AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}, EndpointId, RoleName, }; use itertools::Itertools; @@ -76,7 +76,7 @@ pub(crate) fn endpoint_sni( cn: common_name.into(), }); } - if subdomain == SERVERLESS_DRIVER_SNI { + if subdomain == SERVERLESS_DRIVER_SNI || subdomain == SERVERLESS_DRIVER_AUTH_BROKER_SNI { return Ok(None); } Ok(Some(EndpointId::from(subdomain))) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 44d264caaf..d14d7d80cd 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -278,6 +278,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_auth_broker: false, + accept_jwts: true, }, require_client_ip: false, handshake_timeout: Duration::from_secs(10), diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 748266edc9..425c7a3143 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -103,6 +103,9 @@ struct ProxyCliArgs { default_value = "http://localhost:3000/authenticate_proxy_request/" )] auth_endpoint: String, + /// if this is not local proxy, this toggles whether we accept jwt or passwords for http + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + is_auth_broker: bool, /// path to TLS key for client postgres connections /// /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir @@ -383,9 +386,27 @@ async fn main() -> anyhow::Result<()> { info!("Starting mgmt on {mgmt_address}"); let mgmt_listener = TcpListener::bind(mgmt_address).await?; - let proxy_address: SocketAddr = args.proxy.parse()?; - info!("Starting proxy on {proxy_address}"); - let proxy_listener = TcpListener::bind(proxy_address).await?; + let proxy_listener = if !args.is_auth_broker { + let proxy_address: SocketAddr = args.proxy.parse()?; + info!("Starting proxy on {proxy_address}"); + + Some(TcpListener::bind(proxy_address).await?) + } else { + None + }; + + // TODO: rename the argument to something like serverless. + // It now covers more than just websockets, it also covers SQL over HTTP. + let serverless_listener = if let Some(serverless_address) = args.wss { + let serverless_address: SocketAddr = serverless_address.parse()?; + info!("Starting wss on {serverless_address}"); + Some(TcpListener::bind(serverless_address).await?) + } else if args.is_auth_broker { + bail!("wss arg must be present for auth-broker") + } else { + None + }; + let cancellation_token = CancellationToken::new(); let cancel_map = CancelMap::default(); @@ -431,21 +452,17 @@ async fn main() -> anyhow::Result<()> { // client facing tasks. these will exit on error or on cancellation // cancellation returns Ok(()) let mut client_tasks = JoinSet::new(); - client_tasks.spawn(proxy::proxy::task_main( - config, - proxy_listener, - cancellation_token.clone(), - cancellation_handler.clone(), - endpoint_rate_limiter.clone(), - )); - - // TODO: rename the argument to something like serverless. - // It now covers more than just websockets, it also covers SQL over HTTP. - if let Some(serverless_address) = args.wss { - let serverless_address: SocketAddr = serverless_address.parse()?; - info!("Starting wss on {serverless_address}"); - let serverless_listener = TcpListener::bind(serverless_address).await?; + if let Some(proxy_listener) = proxy_listener { + client_tasks.spawn(proxy::proxy::task_main( + config, + proxy_listener, + cancellation_token.clone(), + cancellation_handler.clone(), + endpoint_rate_limiter.clone(), + )); + } + if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, serverless_listener, @@ -675,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { )?; let http_config = HttpConfig { - accept_websockets: true, + accept_websockets: !args.is_auth_broker, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, @@ -697,7 +714,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, - is_auth_broker: true, + is_auth_broker: args.is_auth_broker, + accept_jwts: args.is_auth_broker, }; let config = Box::leak(Box::new(ProxyConfig { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 207e4fdc18..1fe121d59c 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -72,6 +72,7 @@ pub struct AuthenticationConfig { pub ip_allowlist_check_enabled: bool, pub jwks_cache: JwkCache, pub is_auth_broker: bool, + pub accept_jwts: bool, } impl TlsConfig { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a7e3fa709b..8e0c62af34 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -51,6 +51,7 @@ use tracing::{error, info, warn, Instrument}; use utils::http::error::ApiError; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; +pub(crate) const SERVERLESS_DRIVER_AUTH_BROKER_SNI: &str = "apiauth"; pub async fn task_main( config: &'static ProxyConfig, diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 1a80632929..34c66bbc97 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,7 +1,7 @@ use std::{io, sync::Arc, time::Duration}; use async_trait::async_trait; -use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use tokio::net::{lookup_host, TcpStream}; use tracing::{field::display, info}; @@ -446,6 +446,9 @@ async fn connect_http2( }; let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .keep_alive_interval(Duration::from_secs(20)) + .keep_alive_timeout(Duration::from_secs(5)) .handshake(TokioIo::new(stream)) .await?; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index dffd94e13f..29df25228a 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -43,6 +43,7 @@ use crate::auth::backend::ComputeCredentials; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; +use crate::config::AuthenticationConfig; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -148,6 +149,7 @@ impl UserFacingError for ConnInfoError { } fn get_conn_info( + config: &'static AuthenticationConfig, ctx: &RequestMonitoring, headers: &HeaderMap, tls: Option<&TlsConfig>, @@ -183,6 +185,10 @@ fn get_conn_info( ctx.set_user(username.clone()); let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { + if !config.accept_jwts { + return Err(ConnInfoError::MissingPassword); + } + let auth = auth .to_str() .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; @@ -192,6 +198,10 @@ fn get_conn_info( .into(), ) } else if let Some(pass) = connection_url.password() { + if config.accept_jwts { + return Err(ConnInfoError::MissingPassword); + } + AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(), @@ -516,7 +526,12 @@ async fn handle_inner( "handling interactive connection from client" ); - let conn_info = get_conn_info(ctx, request.headers(), config.tls_config.as_ref())?; + let conn_info = get_conn_info( + &config.authentication_config, + ctx, + request.headers(), + config.tls_config.as_ref(), + )?; info!( user = conn_info.conn_info.user_info.user.as_str(), "credentials"