build out auth proxy core logic

This commit is contained in:
Conrad Ludgate
2024-09-12 17:23:56 +01:00
parent 91e8b7d22b
commit ce200a53e8
7 changed files with 193 additions and 22 deletions

View File

@@ -4,9 +4,9 @@ pub mod backend;
pub use backend::Backend;
mod credentials;
pub use credentials::ComputeUserInfoMaybeEndpoint;
pub(crate) use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern,
};
mod password_hack;
@@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl {
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) struct AuthError(Box<AuthErrorImpl>);
pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {

View File

@@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
}
}
pub(crate) struct ComputeCredentials {
pub struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
pub(crate) keys: ComputeCredentialKeys,
}

View File

@@ -16,7 +16,7 @@ use thiserror::Error;
use tracing::{info, warn};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
pub enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
@@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError {
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ComputeUserInfoMaybeEndpoint {
pub(crate) user: RoleName,
pub(crate) endpoint_id: Option<EndpointId>,
pub(crate) options: NeonOptions,
pub struct ComputeUserInfoMaybeEndpoint {
pub user: RoleName,
pub endpoint_id: Option<EndpointId>,
pub options: NeonOptions,
}
impl ComputeUserInfoMaybeEndpoint {
@@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni(
}
impl ComputeUserInfoMaybeEndpoint {
pub(crate) fn parse(
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,

View File

@@ -86,7 +86,7 @@ impl std::fmt::Display for Backend<'_, ()> {
impl<T> Backend<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
pub fn as_ref(&self) -> Backend<'_, &T> {
match self {
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
}
@@ -97,7 +97,7 @@ impl<'a, T> Backend<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
match self {
Self::Console(c, x) => Backend::Console(c, f(x)),
}
@@ -202,13 +202,13 @@ async fn authenticate_with_secret(
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
pub fn get_user(&self) -> &str {
match self {
Self::Console(_, user_info) => &user_info.user,
}
}
pub(crate) async fn authenticate(
pub async fn authenticate(
self,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,

View File

@@ -1,7 +1,25 @@
use std::{sync::Arc, time::Duration};
use clap::Parser;
use futures::TryStreamExt;
use proxy::{PglbCodec, PglbControlMessage, PglbMessage};
use pq_proto::FeStartupPacket;
use proxy::{
auth::{self, backend::AuthRateLimiter},
auth_proxy::{self, backend::MaybeOwned, AuthProxyStream, Backend},
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
console::{
caches::ApiCaches,
locks::ApiLocks,
provider::{neon::Api, ConsoleBackend},
},
http,
metrics::Metrics,
proxy::NeonOptions,
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
scram::threadpool::ThreadPool,
stream::AuthProxyStreamExt,
PglbCodec, PglbControlMessage, PglbMessage,
};
use quinn::{
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
VarInt,
@@ -14,8 +32,75 @@ use tokio::{
};
use tokio_util::{codec::Framed, task::TaskTracker};
/// Neon proxy/router
#[derive(Parser)]
#[command(about)]
struct ProxyCliArgs {
/// cloud API endpoint for authenticating users
#[clap(
short,
long,
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
wake_compute_lock: String,
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
/// Endpoint rate limiter max number of requests per second.
///
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,
/// Whether to retry the wake_compute request
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
wake_compute_retry: String,
}
#[tokio::main]
async fn main() {
let args = ProxyCliArgs::parse();
let server = "127.0.0.1:5634".parse().unwrap();
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
@@ -37,6 +122,63 @@ async fn main() {
let tasks = TaskTracker::new();
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let backend = {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap();
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse().unwrap();
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse().unwrap();
let caches = Box::leak(Box::new(ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse().unwrap();
let locks = Box::leak(Box::new(
ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse().unwrap();
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap();
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter);
let api = ConsoleBackend::Console(api);
Backend::Console(MaybeOwned::Owned(api), ())
};
let auth = AuthenticationConfig {
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let config = Box::leak(Box::new(Config { backend, auth }));
loop {
select! {
_ = int.recv() => break,
@@ -48,7 +190,7 @@ async fn main() {
}
stream = conn.accept_bi() => {
let (send, recv) = stream.unwrap();
tasks.spawn(handle_stream(send, recv));
tasks.spawn(handle_stream(config, send, recv));
}
}
}
@@ -66,6 +208,11 @@ async fn main() {
conn.close(VarInt::from_u32(1), b"graceful shutdown");
}
struct Config {
backend: Backend<'static, ()>,
auth: AuthenticationConfig,
}
#[derive(Copy, Clone, Debug)]
struct NoVerify;
@@ -104,12 +251,35 @@ impl danger::ServerCertVerifier for NoVerify {
}
}
async fn handle_stream(send: SendStream, recv: RecvStream) {
let mut stream = Framed::new(join(recv, send), PglbCodec);
async fn handle_stream(config: &'static Config, send: SendStream, recv: RecvStream) {
let mut stream: AuthProxyStream = Framed::new(join(recv, send), PglbCodec);
let first_msg = stream.try_next().await.unwrap();
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_first_msg))) = first_msg
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
else {
panic!("invalid first msg")
};
let startup = stream.read_startup_packet().await.unwrap();
let FeStartupPacket::StartupMessage { version, params } = startup else {
panic!("invalid startup message")
};
// Extract credentials which we're going to use for auth.
let user_info = auth::ComputeUserInfoMaybeEndpoint {
user: params.get("user").unwrap().into(),
endpoint_id: first_msg
.server_name
.as_deref()
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
options: NeonOptions::parse_params(&params),
};
let user_info = config.backend.as_ref().map(|()| user_info);
let user_info = match user_info.authenticate(&mut stream, &config.auth).await {
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await.unwrap();
}
};
}

View File

@@ -377,10 +377,10 @@ async fn prepare_client_connection<P>(
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
impl NeonOptions {
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
pub fn parse_params(params: &StartupMessageParams) -> Self {
params
.options_raw()
.map(Self::parse_from_iter)

View File

@@ -298,7 +298,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
pub(crate) trait AuthProxyStreamExt {
#[allow(async_fn_in_trait)]
pub trait AuthProxyStreamExt {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;