proxy: local auth backend (#8806)

Adds a Local authentication backend. Updates http to extract JWT bearer
tokens and passes them to the local backend to validate.
This commit is contained in:
Conrad Ludgate
2024-08-23 19:48:06 +01:00
committed by GitHub
parent 0aa1450936
commit 701cb61b57
9 changed files with 240 additions and 46 deletions

View File

@@ -2,6 +2,7 @@ mod classic;
mod hacks;
pub mod jwt;
mod link;
pub mod local;
use std::net::IpAddr;
use std::sync::Arc;
@@ -9,6 +10,7 @@ use std::time::Duration;
use ipnet::{Ipv4Net, Ipv6Net};
pub use link::LinkAuthError;
use local::LocalBackend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
@@ -68,6 +70,8 @@ pub enum BackendType<'a, T, D> {
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -93,6 +97,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
}
}
@@ -104,6 +109,7 @@ impl<T, D> BackendType<'_, T, D> {
match self {
Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),
Self::Local(l) => BackendType::Local(MaybeOwned::Borrowed(l)),
}
}
}
@@ -116,6 +122,7 @@ impl<'a, T, D> BackendType<'a, T, D> {
match self {
Self::Console(c, x) => BackendType::Console(c, f(x)),
Self::Link(c, x) => BackendType::Link(c, x),
Self::Local(l) => BackendType::Local(l),
}
}
}
@@ -126,6 +133,7 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
match self {
Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
Self::Link(c, x) => Ok(BackendType::Link(c, x)),
Self::Local(l) => Ok(BackendType::Local(l)),
}
}
}
@@ -157,6 +165,7 @@ impl ComputeUserInfo {
pub enum ComputeCredentialKeys {
Password(Vec<u8>),
AuthKeys(AuthKeys),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -289,7 +298,7 @@ async fn auth_quirks(
ctx.set_endpoint_id(res.info.endpoint.clone());
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
ComputeCredentialKeys::AuthKeys(_) => {
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
unreachable!("password hack should return a password")
}
};
@@ -401,6 +410,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
match self {
Self::Console(_, user_info) => user_info.endpoint_id.clone(),
Self::Link(_, _) => Some("link".into()),
Self::Local(_) => Some("local".into()),
}
}
@@ -409,6 +419,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
match self {
Self::Console(_, user_info) => &user_info.user,
Self::Link(_, _) => "link",
Self::Local(_) => "local",
}
}
@@ -450,6 +461,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
BackendType::Link(url, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
}
};
info!("user successfully authenticated");
@@ -465,6 +479,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
match self {
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Link(_, _) => Ok(Cached::new_uncached(None)),
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
@@ -475,6 +490,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
match self {
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
@@ -488,13 +504,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
Self::Console(_, creds) => &creds.keys,
Self::Link(_, _) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
@@ -508,13 +526,15 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
Self::Console(_, creds) => &creds.keys,
Self::Link(_, _) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}

View File

@@ -0,0 +1,79 @@
use std::{collections::HashMap, net::SocketAddr};
use anyhow::Context;
use arc_swap::ArcSwapOption;
use crate::{
compute::ConnCfg,
console::{
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
NodeInfo,
},
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
RoleName,
};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
pub struct LocalBackend {
pub jwks_cache: JwkCache,
pub postgres_addr: SocketAddr,
pub node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
postgres_addr,
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
cfg.host(&postgres_addr.ip().to_string());
cfg.port(postgres_addr.port());
cfg
},
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
project_id: ProjectIdTag::get_interner().get_or_intern("local"),
branch_id: BranchIdTag::get_interner().get_or_intern("local"),
cold_start_info: ColdStartInfo::WarmCached,
},
allow_self_signed_compute: false,
},
}
}
}
#[derive(Clone, Copy)]
pub struct StaticAuthRules;
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
#[derive(Debug, Clone)]
pub struct JwksRoleSettings {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
}
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()
.and_then(|m| m.roles.get(&role_name))
.context("JWKs settings for this role were not configured")?;
let mut rules = vec![];
for setting in &role_mappings.jwks {
rules.push(AuthRule {
id: setting.id.clone(),
jwks_url: setting.jwks_url.clone(),
audience: setting.jwt_audience.clone(),
});
}
Ok(rules)
}
}

View File

@@ -1,11 +1,13 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::proxy::retry::CouldRetry;
use crate::RoleName;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
@@ -341,6 +343,26 @@ impl ColdStartInfo {
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
pub jwks: Vec<JwksSettings>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
pub id: String,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
pub jwks_url: url::Url,
pub provider_name: String,
pub jwt_audience: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -305,6 +305,7 @@ impl NodeInfo {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::None => &mut self.config,
};
}
}

View File

@@ -61,7 +61,7 @@ pub trait ComputeConnectBackend {
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
fn get_keys(&self) -> &ComputeCredentialKeys;
}
pub struct TcpMechanism<'a> {
@@ -112,9 +112,8 @@ where
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.set_keys(user_info.get_keys());
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);

View File

@@ -407,7 +407,7 @@ async fn request_handler(
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
"Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@@ -4,7 +4,10 @@ use async_trait::async_trait;
use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
auth::{
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
compute,
config::{AuthenticationConfig, ProxyConfig},
console::{
@@ -24,7 +27,7 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -33,13 +36,14 @@ pub struct PoolingBackend {
}
impl PoolingBackend {
pub async fn authenticate(
pub async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
conn_info: &ConnInfo,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = conn_info.user_info.clone();
let user_info = user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
@@ -47,7 +51,7 @@ impl PoolingBackend {
}
if !self
.endpoint_rate_limiter
.check(conn_info.user_info.endpoint.clone().into(), 1)
.check(user_info.endpoint.clone().into(), 1)
{
return Err(AuthError::too_many_connections());
}
@@ -70,14 +74,10 @@ impl PoolingBackend {
return Err(AuthError::auth_failed(&*user_info.user));
}
};
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&config.thread_pool,
ep,
&conn_info.password,
secret,
)
.await?;
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -85,7 +85,7 @@ impl PoolingBackend {
}
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
Err(AuthError::auth_failed(&*user_info.user))
}
};
res.map(|key| ComputeCredentials {
@@ -94,6 +94,39 @@ impl PoolingBackend {
})
}
pub async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
crate::auth::BackendType::Console(_, _) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
}
crate::auth::BackendType::Link(_, _) => Err(AuthError::auth_failed(
"JWT login over link proxy is not supported",
)),
crate::auth::BackendType::Local(cache) => {
cache
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
user_info.user.clone(),
&StaticAuthRules,
jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
}
}
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
@@ -232,10 +265,16 @@ impl ConnectMechanism for TokioMechanism {
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.password(&*self.conn_info.password)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
match &self.conn_info.auth {
AuthData::Jwt(_) => {}
AuthData::Password(pw) => {
config.password(pw);
}
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);

View File

@@ -33,7 +33,13 @@ use super::backend::HttpConnError;
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
pub dbname: DbName,
pub password: SmallVec<[u8; 16]>,
pub auth: AuthData,
}
#[derive(Debug, Clone)]
pub enum AuthData {
Password(SmallVec<[u8; 16]>),
Jwt(String),
}
impl ConnInfo {
@@ -778,7 +784,7 @@ mod tests {
options: Default::default(),
},
dbname: "dbname".into(),
password: "password".as_bytes().into(),
auth: AuthData::Password("password".as_bytes().into()),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
@@ -836,7 +842,7 @@ mod tests {
options: Default::default(),
},
dbname: "dbname".into(),
password: "password".as_bytes().into(),
auth: AuthData::Password("password".as_bytes().into()),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),

View File

@@ -7,6 +7,7 @@ use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -56,6 +57,7 @@ use crate::DbName;
use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::http_util::json_response;
@@ -88,6 +90,7 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
@@ -109,7 +112,7 @@ where
#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static str),
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
@@ -153,10 +156,10 @@ fn get_conn_info(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let connection_string = headers
.get("Neon-Connection-String")
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
@@ -179,10 +182,23 @@ fn get_conn_info(
}
ctx.set_user(username.clone());
let password = connection_url
.password()
.ok_or(ConnInfoError::MissingPassword)?;
let password = urlencoding::decode_binary(password.as_bytes());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingPassword)?
.into(),
)
} else if let Some(pass) = connection_url.password() {
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else {
return Err(ConnInfoError::MissingPassword);
};
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
@@ -225,10 +241,7 @@ fn get_conn_info(
Ok(ConnInfo {
user_info,
dbname,
password: match password {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
},
auth,
})
}
@@ -550,9 +563,24 @@ async fn handle_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = backend
.authenticate(ctx, &config.authentication_config, &conn_info)
.await?;
let keys = match &conn_info.auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;