mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
build out auth proxy core logic
This commit is contained in:
@@ -4,9 +4,9 @@ pub mod backend;
|
||||
pub use backend::Backend;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ComputeUserInfoMaybeEndpoint;
|
||||
pub(crate) use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
|
||||
ComputeUserInfoParseError, IpPattern,
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoParseError, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
@@ -77,7 +77,7 @@ pub(crate) enum AuthErrorImpl {
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct AuthError(Box<AuthErrorImpl>);
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
|
||||
|
||||
@@ -138,7 +138,7 @@ impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ComputeCredentials {
|
||||
pub struct ComputeCredentials {
|
||||
pub(crate) info: ComputeUserInfo,
|
||||
pub(crate) keys: ComputeCredentialKeys,
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub(crate) enum ComputeUserInfoParseError {
|
||||
pub enum ComputeUserInfoParseError {
|
||||
#[error("Parameter '{0}' is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
@@ -51,10 +51,10 @@ impl ReportableError for ComputeUserInfoParseError {
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct ComputeUserInfoMaybeEndpoint {
|
||||
pub(crate) user: RoleName,
|
||||
pub(crate) endpoint_id: Option<EndpointId>,
|
||||
pub(crate) options: NeonOptions,
|
||||
pub struct ComputeUserInfoMaybeEndpoint {
|
||||
pub user: RoleName,
|
||||
pub endpoint_id: Option<EndpointId>,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
@@ -83,7 +83,7 @@ pub(crate) fn endpoint_sni(
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
pub(crate) fn parse(
|
||||
pub fn parse(
|
||||
ctx: &RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
|
||||
@@ -86,7 +86,7 @@ impl std::fmt::Display for Backend<'_, ()> {
|
||||
impl<T> Backend<'_, T> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
|
||||
pub fn as_ref(&self) -> Backend<'_, &T> {
|
||||
match self {
|
||||
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
@@ -97,7 +97,7 @@ impl<'a, T> Backend<'a, T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
|
||||
match self {
|
||||
Self::Console(c, x) => Backend::Console(c, f(x)),
|
||||
}
|
||||
@@ -202,13 +202,13 @@ async fn authenticate_with_secret(
|
||||
|
||||
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get username from the credentials.
|
||||
pub(crate) fn get_user(&self) -> &str {
|
||||
pub fn get_user(&self) -> &str {
|
||||
match self {
|
||||
Self::Console(_, user_info) => &user_info.user,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn authenticate(
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
client: &mut AuthProxyStream,
|
||||
config: &'static AuthenticationConfig,
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use clap::Parser;
|
||||
use futures::TryStreamExt;
|
||||
use proxy::{PglbCodec, PglbControlMessage, PglbMessage};
|
||||
use pq_proto::FeStartupPacket;
|
||||
use proxy::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
auth_proxy::{self, backend::MaybeOwned, AuthProxyStream, Backend},
|
||||
config::{self, AuthenticationConfig, CacheOptions, ProjectInfoCacheOptions},
|
||||
console::{
|
||||
caches::ApiCaches,
|
||||
locks::ApiLocks,
|
||||
provider::{neon::Api, ConsoleBackend},
|
||||
},
|
||||
http,
|
||||
metrics::Metrics,
|
||||
proxy::NeonOptions,
|
||||
rate_limiter::{RateBucketInfo, WakeComputeRateLimiter},
|
||||
scram::threadpool::ThreadPool,
|
||||
stream::AuthProxyStreamExt,
|
||||
PglbCodec, PglbControlMessage, PglbMessage,
|
||||
};
|
||||
use quinn::{
|
||||
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
|
||||
VarInt,
|
||||
@@ -14,8 +32,75 @@ use tokio::{
|
||||
};
|
||||
use tokio_util::{codec::Framed, task::TaskTracker};
|
||||
|
||||
/// Neon proxy/router
|
||||
#[derive(Parser)]
|
||||
#[command(about)]
|
||||
struct ProxyCliArgs {
|
||||
/// cloud API endpoint for authenticating users
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
wake_compute_cache: String,
|
||||
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
|
||||
wake_compute_lock: String,
|
||||
/// timeout for scram authentication protocol
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
scram_protocol_timeout: tokio::time::Duration,
|
||||
/// size of the threadpool for password hashing
|
||||
#[clap(long, default_value_t = 4)]
|
||||
scram_thread_pool_size: u8,
|
||||
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
|
||||
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_dynamic_rate_limiter: bool,
|
||||
/// Endpoint rate limiter max number of requests per second.
|
||||
///
|
||||
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Whether the auth rate limiter actually takes effect (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
auth_rate_limit_enabled: bool,
|
||||
/// Authentication rate limiter max number of hashes per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
|
||||
auth_rate_limit: Vec<RateBucketInfo>,
|
||||
/// The IP subnet to use when considering whether two IP addresses are considered the same.
|
||||
#[clap(long, default_value_t = 64)]
|
||||
auth_rate_limit_ip_subnet: u8,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
/// cache for `role_secret` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
role_secret_cache: String,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
project_info_cache: String,
|
||||
/// cache for all valid endpoints
|
||||
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
|
||||
endpoint_cache_config: String,
|
||||
|
||||
/// Whether to retry the wake_compute request
|
||||
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
|
||||
wake_compute_retry: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args = ProxyCliArgs::parse();
|
||||
|
||||
let server = "127.0.0.1:5634".parse().unwrap();
|
||||
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap();
|
||||
|
||||
@@ -37,6 +122,63 @@ async fn main() {
|
||||
|
||||
let tasks = TaskTracker::new();
|
||||
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
|
||||
let backend = {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse().unwrap();
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse().unwrap();
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse().unwrap();
|
||||
|
||||
let caches = Box::leak(Box::new(ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse().unwrap();
|
||||
let locks = Box::leak(Box::new(
|
||||
ApiLocks::new(
|
||||
"wake_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().wake_compute_lock,
|
||||
)
|
||||
.unwrap(),
|
||||
));
|
||||
tokio::spawn(locks.garbage_collect_worker());
|
||||
|
||||
let url = args.auth_endpoint.parse().unwrap();
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit).unwrap();
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = Api::new(endpoint, caches, locks, wake_compute_endpoint_rate_limiter);
|
||||
let api = ConsoleBackend::Console(api);
|
||||
Backend::Console(MaybeOwned::Owned(api), ())
|
||||
};
|
||||
|
||||
let auth = AuthenticationConfig {
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(Config { backend, auth }));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
_ = int.recv() => break,
|
||||
@@ -48,7 +190,7 @@ async fn main() {
|
||||
}
|
||||
stream = conn.accept_bi() => {
|
||||
let (send, recv) = stream.unwrap();
|
||||
tasks.spawn(handle_stream(send, recv));
|
||||
tasks.spawn(handle_stream(config, send, recv));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -66,6 +208,11 @@ async fn main() {
|
||||
conn.close(VarInt::from_u32(1), b"graceful shutdown");
|
||||
}
|
||||
|
||||
struct Config {
|
||||
backend: Backend<'static, ()>,
|
||||
auth: AuthenticationConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
struct NoVerify;
|
||||
|
||||
@@ -104,12 +251,35 @@ impl danger::ServerCertVerifier for NoVerify {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream(send: SendStream, recv: RecvStream) {
|
||||
let mut stream = Framed::new(join(recv, send), PglbCodec);
|
||||
async fn handle_stream(config: &'static Config, send: SendStream, recv: RecvStream) {
|
||||
let mut stream: AuthProxyStream = Framed::new(join(recv, send), PglbCodec);
|
||||
|
||||
let first_msg = stream.try_next().await.unwrap();
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_first_msg))) = first_msg
|
||||
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(first_msg))) = first_msg
|
||||
else {
|
||||
panic!("invalid first msg")
|
||||
};
|
||||
|
||||
let startup = stream.read_startup_packet().await.unwrap();
|
||||
let FeStartupPacket::StartupMessage { version, params } = startup else {
|
||||
panic!("invalid startup message")
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let user_info = auth::ComputeUserInfoMaybeEndpoint {
|
||||
user: params.get("user").unwrap().into(),
|
||||
endpoint_id: first_msg
|
||||
.server_name
|
||||
.as_deref()
|
||||
.map(|h| h.split_once('.').map_or(h, |(ep, _)| ep).into()),
|
||||
options: NeonOptions::parse_params(¶ms),
|
||||
};
|
||||
|
||||
let user_info = config.backend.as_ref().map(|()| user_info);
|
||||
let user_info = match user_info.authenticate(&mut stream, &config.auth).await {
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e).await.unwrap();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -377,10 +377,10 @@ async fn prepare_client_connection<P>(
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
pub fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
params
|
||||
.options_raw()
|
||||
.map(Self::parse_from_iter)
|
||||
|
||||
@@ -298,7 +298,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait AuthProxyStreamExt {
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait AuthProxyStreamExt {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user