proxy: remove auth info from http conn info & fixup jwt api trait (#9047)

misc changes split out from #8855 

- **allow cloning the request context in a read-only fashion for
background tasks**
- **propagate endpoint and request context through the jwk cache**
- **only allow password based auth for md5 during testing**
- **remove auth info from conn info**
This commit is contained in:
Conrad Ludgate
2024-09-19 16:09:30 +01:00
committed by GitHub
parent ff9f065c43
commit 0a1ca7670c
11 changed files with 127 additions and 52 deletions

View File

@@ -163,6 +163,7 @@ impl ComputeUserInfo {
}
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
None,
@@ -293,16 +294,10 @@ async fn auth_quirks(
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
ctx.set_endpoint_id(res.info.endpoint.clone());
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
unreachable!("password hack should return a password")
}
};
(res.info, Some(password))
let (info, password) =
hacks::password_hack_no_authentication(ctx, info, client).await?;
ctx.set_endpoint_id(info.endpoint.clone());
(info, Some(password))
}
Ok(info) => (info, None),
};

View File

@@ -1,6 +1,4 @@
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
@@ -63,7 +61,7 @@ pub(crate) async fn password_hack_no_authentication(
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<(ComputeUserInfo, Vec<u8>)> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -79,12 +77,12 @@ pub(crate) async fn password_hack_no_authentication(
info!(project = &*payload.endpoint, "received missing parameter");
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
Ok((
ComputeUserInfo {
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
})
payload.password,
))
}

View File

@@ -25,6 +25,8 @@ const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
@@ -101,7 +103,9 @@ impl JwkCacheEntryLock {
async fn renew_jwks<F: FetchAuthRules>(
&self,
_permit: JwkRenewalPermit<'_>,
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
auth_rules: &F,
) -> anyhow::Result<Arc<JwkCacheEntry>> {
@@ -115,7 +119,9 @@ impl JwkCacheEntryLock {
}
}
let rules = auth_rules.fetch_auth_rules(role_name).await?;
let rules = auth_rules
.fetch_auth_rules(ctx, endpoint, role_name)
.await?;
let mut key_sets =
ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
// TODO(conrad): run concurrently
@@ -166,6 +172,7 @@ impl JwkCacheEntryLock {
self: &Arc<Self>,
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
@@ -176,7 +183,9 @@ impl JwkCacheEntryLock {
let Some(cached) = guard else {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
return self.renew_jwks(permit, client, role_name, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
};
let last_update = now.duration_since(cached.last_retrieved);
@@ -187,7 +196,9 @@ impl JwkCacheEntryLock {
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
return self.renew_jwks(permit, client, role_name, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
}
// every 5 minutes we should spawn a job to eagerly update the token.
@@ -198,8 +209,12 @@ impl JwkCacheEntryLock {
let entry = self.clone();
let client = client.clone();
let fetch = fetch.clone();
let ctx = ctx.clone();
tokio::spawn(async move {
if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await {
if let Err(e) = entry
.renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
.await
{
tracing::warn!(error=?e, "could not fetch JWKs in background job");
}
});
@@ -216,6 +231,7 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
jwt: &str,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
@@ -242,7 +258,7 @@ impl JwkCacheEntryLock {
let kid = header.key_id.context("missing key id")?;
let mut guard = self
.get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch)
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
.await?;
// get the key from the JWKs if possible. If not, wait for the keys to update.
@@ -254,7 +270,14 @@ impl JwkCacheEntryLock {
let permit = self.acquire_permit().await;
guard = self
.renew_jwks(permit, client, role_name.clone(), fetch)
.renew_jwks(
permit,
ctx,
client,
endpoint.clone(),
role_name.clone(),
fetch,
)
.await?;
}
_ => {
@@ -318,7 +341,7 @@ impl JwkCache {
jwt: &str,
) -> Result<(), anyhow::Error> {
// try with just a read lock first
let key = (endpoint, role_name.clone());
let key = (endpoint.clone(), role_name.clone());
let entry = self.map.get(&key).as_deref().map(Arc::clone);
let entry = entry.unwrap_or_else(|| {
// acquire a write lock after to insert.
@@ -327,7 +350,7 @@ impl JwkCache {
});
entry
.check_jwt(ctx, jwt, &self.client, role_name, fetch)
.check_jwt(ctx, jwt, &self.client, endpoint, role_name, fetch)
.await
}
}
@@ -688,6 +711,8 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
_role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
@@ -706,6 +731,7 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
}
let role_name = RoleName::from("user");
let endpoint = EndpointId::from("ep");
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
@@ -715,6 +741,7 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
&RequestMonitoring::test(),
&token,
&client,
endpoint.clone(),
role_name.clone(),
&Fetch(addr),
)

View File

@@ -9,8 +9,9 @@ use crate::{
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
NodeInfo,
},
context::RequestMonitoring,
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
RoleName,
EndpointId, RoleName,
};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
@@ -57,7 +58,12 @@ pub struct JwksRoleSettings {
}
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result<Vec<AuthRule>> {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()

View File

@@ -303,6 +303,7 @@ impl NodeInfo {
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
match keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::None => &mut self.config,

View File

@@ -79,6 +79,40 @@ pub(crate) enum AuthMethod {
Cleartext,
}
impl Clone for RequestMonitoring {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestMonitoringInner {
peer_addr: inner.peer_addr,
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
branch: inner.branch,
endpoint_id: inner.endpoint_id.clone(),
dbname: inner.dbname.clone(),
user: inner.user.clone(),
application: inner.application.clone(),
error_kind: inner.error_kind,
auth_method: inner.auth_method.clone(),
success: inner.success,
rejected: inner.rejected,
cold_start_info: inner.cold_start_info,
pg_options: inner.pg_options.clone(),
sender: None,
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
};
Self(TryLock::new(new))
}
}
impl RequestMonitoring {
pub fn new(
session_id: Uuid,

View File

@@ -397,6 +397,8 @@ pub struct LatencyTimer {
protocol: Protocol,
cold_start_info: ColdStartInfo,
outcome: ConnectOutcome,
skip_reporting: bool,
}
impl LatencyTimer {
@@ -409,6 +411,20 @@ impl LatencyTimer {
cold_start_info: ColdStartInfo::Unknown,
// assume failed unless otherwise specified
outcome: ConnectOutcome::Failed,
skip_reporting: false,
}
}
pub(crate) fn noop(protocol: Protocol) -> Self {
Self {
start: time::Instant::now(),
stop: None,
accumulated: Accumulated::default(),
protocol,
cold_start_info: ColdStartInfo::Unknown,
// assume failed unless otherwise specified
outcome: ConnectOutcome::Failed,
skip_reporting: true,
}
}
@@ -443,6 +459,10 @@ pub enum ConnectOutcome {
impl Drop for LatencyTimer {
fn drop(&mut self) {
if self.skip_reporting {
return;
}
let duration = self
.stop
.unwrap_or_else(time::Instant::now)

View File

@@ -27,7 +27,7 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub(crate) struct PoolingBackend {
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -274,13 +274,6 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
match &self.conn_info.auth {
AuthData::Jwt(_) => {}
AuthData::Password(pw) => {
config.password(pw);
}
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);

View File

@@ -29,11 +29,16 @@ use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
pub(crate) conn_info: ConnInfo,
pub(crate) auth: AuthData,
}
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
pub(crate) auth: AuthData,
}
#[derive(Debug, Clone)]
@@ -787,7 +792,6 @@ mod tests {
options: NeonOptions::default(),
},
dbname: "dbname".into(),
auth: AuthData::Password("password".as_bytes().into()),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
@@ -845,7 +849,6 @@ mod tests {
options: NeonOptions::default(),
},
dbname: "dbname".into(),
auth: AuthData::Password("password".as_bytes().into()),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),

View File

@@ -60,6 +60,7 @@ use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
@@ -148,7 +149,7 @@ fn get_conn_info(
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
) -> Result<ConnInfo, ConnInfoError> {
) -> Result<ConnInfoWithAuth, ConnInfoError> {
// HTTP only uses cleartext (for now and likely always)
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -235,11 +236,8 @@ fn get_conn_info(
options: options.unwrap_or_default(),
};
Ok(ConnInfo {
user_info,
dbname,
auth,
})
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
// TODO: return different http error codes
@@ -523,7 +521,10 @@ async fn handle_inner(
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
info!(user = conn_info.user_info.user.as_str(), "credentials");
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
@@ -568,20 +569,20 @@ async fn handle_inner(
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&conn_info.conn_info.user_info,
pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
.await?
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else