diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 7c408f817c..d51f45b5c2 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -4,9 +4,9 @@ pub mod backend; pub use backend::Backend; mod credentials; +pub use credentials::ComputeUserInfoMaybeEndpoint; pub(crate) use credentials::{ - check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, - ComputeUserInfoParseError, IpPattern, + check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern, }; mod password_hack; @@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl { #[derive(Debug, Error)] #[error(transparent)] -pub(crate) struct AuthError(Box); +pub struct AuthError(Box); impl AuthError { pub(crate) fn bad_auth_method(name: impl Into>) -> Self { diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 1d28c6df31..fc6f45af84 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result, D> { } } -pub(crate) struct ComputeCredentials { +pub struct ComputeCredentials { pub(crate) info: ComputeUserInfo, pub(crate) keys: ComputeCredentialKeys, } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 0e91ae570a..c993ae4eb6 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -16,7 +16,7 @@ use thiserror::Error; use tracing::{info, warn}; #[derive(Debug, Error, PartialEq, Eq, Clone)] -pub(crate) enum ComputeUserInfoParseError { +pub enum ComputeUserInfoParseError { #[error("Parameter '{0}' is missing in startup packet.")] MissingKey(&'static str), @@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError { /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct ComputeUserInfoMaybeEndpoint { - pub(crate) user: RoleName, - pub(crate) endpoint_id: Option, - pub(crate) options: NeonOptions, +pub struct ComputeUserInfoMaybeEndpoint { + pub user: RoleName, + pub endpoint_id: Option, + pub options: NeonOptions, } impl ComputeUserInfoMaybeEndpoint { @@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni( } impl ComputeUserInfoMaybeEndpoint { - pub(crate) fn parse( + pub fn parse( ctx: &RequestMonitoring, params: &StartupMessageParams, sni: Option<&str>, diff --git a/proxy/src/auth_proxy/backend.rs b/proxy/src/auth_proxy/backend.rs index e0b8eb4e8b..a1d6874570 100644 --- a/proxy/src/auth_proxy/backend.rs +++ b/proxy/src/auth_proxy/backend.rs @@ -86,7 +86,7 @@ impl std::fmt::Display for Backend<'_, ()> { impl Backend<'_, T> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub(crate) fn as_ref(&self) -> Backend<'_, &T> { + pub fn as_ref(&self) -> Backend<'_, &T> { match self { Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x), } @@ -97,7 +97,7 @@ impl<'a, T> Backend<'a, T> { /// Very similar to [`std::option::Option::map`]. /// Maps [`Backend`] to [`Backend`] by applying /// a function to a contained value. - pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> { + pub fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> { match self { Self::Console(c, x) => Backend::Console(c, f(x)), } @@ -202,13 +202,13 @@ async fn authenticate_with_secret( impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { /// Get username from the credentials. - pub(crate) fn get_user(&self) -> &str { + pub fn get_user(&self) -> &str { match self { Self::Console(_, user_info) => &user_info.user, } } - pub(crate) async fn authenticate( + pub async fn authenticate( self, client: &mut AuthProxyStream, config: &'static AuthenticationConfig, diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 2e9b5b2997..848d7fe9fe 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -1,7 +1,25 @@ use std::{sync::Arc, time::Duration}; +use clap::Parser; use futures::TryStreamExt; -use proxy::{PglbCodec, PglbControlMessage, PglbMessage}; +use pq_proto::FeStartupPacket; +use proxy::{ + auth::{self, backend::AuthRateLimiter}, + auth_proxy::{self, backend::MaybeOwned, AuthProxyStream, Backend}, + config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions}, + console::{ + caches::ApiCaches, + locks::ApiLocks, + provider::{neon::Api, ConsoleBackend}, + }, + http, + metrics::Metrics, + proxy::NeonOptions, + rate_limiter::{RateBucketInfo, WakeComputeRateLimiter}, + scram::threadpool::ThreadPool, + stream::AuthProxyStreamExt, + PglbCodec, PglbControlMessage, PglbMessage, +}; use quinn::{ crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream, VarInt, @@ -14,8 +32,75 @@ use tokio::{ }; use tokio_util::{codec::Framed, task::TaskTracker}; +/// Neon proxy/router +#[derive(Parser)] +#[command(about)] +struct ProxyCliArgs { + /// cloud API endpoint for authenticating users + #[clap( + short, + long, + default_value = "http://localhost:3000/authenticate_proxy_request/" + )] + auth_endpoint: String, + /// timeout for the TLS handshake + #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] + handshake_timeout: tokio::time::Duration, + /// cache for `wake_compute` api method (use `size=0` to disable) + #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] + wake_compute_cache: String, + /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable). + #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)] + wake_compute_lock: String, + /// timeout for scram authentication protocol + #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] + scram_protocol_timeout: tokio::time::Duration, + /// size of the threadpool for password hashing + #[clap(long, default_value_t = 4)] + scram_thread_pool_size: u8, + /// Disable dynamic rate limiter and store the metrics to ensure its production behaviour. + #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + disable_dynamic_rate_limiter: bool, + /// Endpoint rate limiter max number of requests per second. + /// + /// Provided in the form `@`. + /// Can be given multiple times for different bucket sizes. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] + endpoint_rps_limit: Vec, + /// Wake compute rate limiter max number of requests per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + wake_compute_limit: Vec, + /// Whether the auth rate limiter actually takes effect (for testing) + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + auth_rate_limit_enabled: bool, + /// Authentication rate limiter max number of hashes per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] + auth_rate_limit: Vec, + /// The IP subnet to use when considering whether two IP addresses are considered the same. + #[clap(long, default_value_t = 64)] + auth_rate_limit_ip_subnet: u8, + /// cache for `allowed_ips` (use `size=0` to disable) + #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] + allowed_ips_cache: String, + /// cache for `role_secret` (use `size=0` to disable) + #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] + role_secret_cache: String, + /// cache for `project_info` (use `size=0` to disable) + #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] + project_info_cache: String, + /// cache for all valid endpoints + #[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)] + endpoint_cache_config: String, + + /// Whether to retry the wake_compute request + #[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)] + wake_compute_retry: String, +} + #[tokio::main] async fn main() { + let args = ProxyCliArgs::parse(); + let server = "127.0.0.1:5634".parse().unwrap(); let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap(); @@ -37,6 +122,63 @@ async fn main() { let tasks = TaskTracker::new(); + let thread_pool = ThreadPool::new(args.scram_thread_pool_size); + Metrics::install(thread_pool.metrics.clone()); + + let backend = { + let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap(); + let project_info_cache_config: ProjectInfoCacheOptions = + args.project_info_cache.parse().unwrap(); + let endpoint_cache_config: config::EndpointCacheConfig = + args.endpoint_cache_config.parse().unwrap(); + + let caches = Box::leak(Box::new(ApiCaches::new( + wake_compute_cache_config, + project_info_cache_config, + endpoint_cache_config, + ))); + + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.wake_compute_lock.parse().unwrap(); + let locks = Box::leak(Box::new( + ApiLocks::new( + "wake_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().wake_compute_lock, + ) + .unwrap(), + )); + tokio::spawn(locks.garbage_collect_worker()); + + let url = args.auth_endpoint.parse().unwrap(); + let endpoint = http::Endpoint::new(url, http::new_client()); + + let mut wake_compute_rps_limit = args.wake_compute_limit.clone(); + RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap(); + let wake_compute_endpoint_rate_limiter = + Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit)); + let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter); + let api = ConsoleBackend::Console(api); + Backend::Console(MaybeOwned::Owned(api), ()) + }; + + let auth = AuthenticationConfig { + thread_pool, + scram_protocol_timeout: args.scram_protocol_timeout, + rate_limiter_enabled: args.auth_rate_limit_enabled, + rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, + }; + + let config = Box::leak(Box::new(Config { backend, auth })); + loop { select! { _ = int.recv() => break, @@ -48,7 +190,7 @@ async fn main() { } stream = conn.accept_bi() => { let (send, recv) = stream.unwrap(); - tasks.spawn(handle_stream(send, recv)); + tasks.spawn(handle_stream(config, send, recv)); } } } @@ -66,6 +208,11 @@ async fn main() { conn.close(VarInt::from_u32(1), b"graceful shutdown"); } +struct Config { + backend: Backend<'static, ()>, + auth: AuthenticationConfig, +} + #[derive(Copy, Clone, Debug)] struct NoVerify; @@ -104,12 +251,35 @@ impl danger::ServerCertVerifier for NoVerify { } } -async fn handle_stream(send: SendStream, recv: RecvStream) { - let mut stream = Framed::new(join(recv, send), PglbCodec); +async fn handle_stream(config: &'static Config, send: SendStream, recv: RecvStream) { + let mut stream: AuthProxyStream = Framed::new(join(recv, send), PglbCodec); let first_msg = stream.try_next().await.unwrap(); - let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_first_msg))) = first_msg + let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg else { panic!("invalid first msg") }; + + let startup = stream.read_startup_packet().await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("invalid startup message") + }; + + // Extract credentials which we're going to use for auth. + let user_info = auth::ComputeUserInfoMaybeEndpoint { + user: params.get("user").unwrap().into(), + endpoint_id: first_msg + .server_name + .as_deref() + .map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()), + options: NeonOptions::parse_params(¶ms), + }; + + let user_info = config.backend.as_ref().map(|()| user_info); + let user_info = match user_info.authenticate(&mut stream, &config.auth).await { + Ok(auth_result) => auth_result, + Err(e) => { + return stream.throw_error(e).await.unwrap(); + } + }; } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ff199ac701..8806f31f14 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -377,10 +377,10 @@ async fn prepare_client_connection

( } #[derive(Debug, Clone, PartialEq, Eq, Default)] -pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>); +pub struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { - pub(crate) fn parse_params(params: &StartupMessageParams) -> Self { + pub fn parse_params(params: &StartupMessageParams) -> Self { params .options_raw() .map(Self::parse_from_iter) diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index d0f6920271..a1331e5b2c 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -298,7 +298,8 @@ impl AsyncWrite for Stream { } } -pub(crate) trait AuthProxyStreamExt { +#[allow(async_fn_in_trait)] +pub trait AuthProxyStreamExt { /// Write the message into an internal buffer, but don't flush the underlying stream. fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;