[local_proxy] update api for pg_session_jwt (#9359)

pg_session_jwt now:
1. Sets the JWK in a PGU_BACKEND session guc, no longer in the init()
function.
2. JWK no longer needs the kid.
This commit is contained in:
Conrad Ludgate
2024-10-15 13:13:57 +01:00
committed by GitHub
parent ec4cc30de9
commit d92d36a315
7 changed files with 139 additions and 74 deletions

View File

@@ -2,8 +2,9 @@ use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use p256::{ecdsa::SigningKey, elliptic_curve::JwkEcKey};
use rand::rngs::OsRng;
use tokio::net::{lookup_host, TcpStream};
use tokio_postgres::types::ToSql;
use tracing::{debug, field::display, info};
use crate::{
@@ -267,50 +268,58 @@ impl PoolingBackend {
auth::Backend::Local(local) => local.node_info.clone(),
};
let (key, jwk) = create_random_jwk();
let config = node_info
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
.dbname(&conn_info.dbname)
.options(&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
));
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let pid = client.get_process_id();
tracing::Span::current().record("pid", pid);
let handle = local_conn_pool::poll_client(
let mut handle = local_conn_pool::poll_client(
self.local_pool.clone(),
ctx,
conn_info,
client,
connection,
key,
conn_id,
node_info.aux.clone(),
);
let kid = handle.get_client().get_process_id() as i64;
let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();
{
let (client, mut discard) = handle.inner();
debug!("setting up backend session state");
debug!(kid, ?jwk, "setting up backend session state");
// initiates the auth session
if let Err(e) = client.query("select auth.init()", &[]).await {
discard.discard();
return Err(e.into());
}
// initiates the auth session
handle
.get_client()
.query(
"select auth.init($1, $2);",
&[
&kid as &(dyn ToSql + Sync),
&tokio_postgres::types::Json(jwk),
],
)
.await?;
info!(?kid, "backend session state init");
info!("backend session state initialized");
}
Ok(handle)
}
}
fn create_random_jwk() -> (SigningKey, JwkEcKey) {
let key = SigningKey::random(&mut OsRng);
let jwk = p256::PublicKey::from(key.verifying_key()).to_jwk();
(key, jwk)
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]

View File

@@ -1,9 +1,9 @@
use futures::{future::poll_fn, Future};
use indexmap::IndexMap;
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use rand::rngs::OsRng;
use serde_json::Value;
use serde_json::value::RawValue;
use signature::Signer;
use std::task::{ready, Poll};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
@@ -12,14 +12,13 @@ use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use typed_json::json;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
use tracing::{debug, error, warn, Span};
use tracing::{error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
@@ -245,12 +244,14 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn poll_client(
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: tokio_postgres::Client,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
key: SigningKey,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> LocalClient<tokio_postgres::Client> {
@@ -346,8 +347,6 @@ pub(crate) fn poll_client(
}
.instrument(span));
let key = SigningKey::random(&mut OsRng);
let inner = ClientInner {
inner: client,
session: tx,
@@ -430,13 +429,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
pub(crate) fn key(&self) -> &SigningKey {
let inner = &self
.inner
.as_ref()
.expect("client inner should not be removed");
&inner.key
}
}
impl LocalClient<tokio_postgres::Client> {
@@ -445,25 +437,9 @@ impl LocalClient<tokio_postgres::Client> {
.inner
.as_mut()
.expect("client inner should not be removed");
inner.jti += 1;
let kid = inner.inner.get_process_id();
let header = json!({"kid":kid}).to_string();
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
.map_err(HttpConnError::JwtPayloadError)?;
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
let payload = Value::Object(payload).to_string();
debug!(
kid,
jti = inner.jti,
?header,
?payload,
"signing new ephemeral JWT"
);
let token = sign_jwt(&inner.key, header, payload);
let token = resign_jwt(&inner.key, payload, inner.jti)?;
// initiates the auth session
inner.inner.simple_query("discard all").await?;
@@ -475,20 +451,74 @@ impl LocalClient<tokio_postgres::Client> {
)
.await?;
info!(kid, jti = inner.jti, "user session state init");
let pid = inner.inner.get_process_id();
info!(pid, jti = inner.jti, "user session state init");
Ok(())
}
}
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
let header = Base64UrlUnpadded::encode_string(header.as_bytes());
let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
/// implements relatively efficient in-place json object key upserting
///
/// only supports top-level keys
fn upsert_json_object(
payload: &[u8],
key: &str,
value: &RawValue,
) -> Result<String, serde_json::Error> {
let mut payload = serde_json::from_slice::<IndexMap<&str, &RawValue>>(payload)?;
payload.insert(key, value);
serde_json::to_string(&payload)
}
let message = format!("{header}.{payload}");
let sig: Signature = sk.sign(message.as_bytes());
let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
format!("{message}.{base64_sig}")
fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result<String, HttpConnError> {
let mut buffer = itoa::Buffer::new();
// encode the jti integer to a json rawvalue
let jti = serde_json::from_str::<&RawValue>(buffer.format(jti)).unwrap();
// update the jti in-place
let payload =
upsert_json_object(payload, "jti", jti).map_err(HttpConnError::JwtPayloadError)?;
// sign the jwt
let token = sign_jwt(sk, payload.as_bytes());
Ok(token)
}
fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
let header_len = 20;
let payload_len = Base64UrlUnpadded::encoded_len(payload);
let signature_len = Base64UrlUnpadded::encoded_len(&[0; 64]);
let total_len = header_len + payload_len + signature_len + 2;
let mut jwt = String::with_capacity(total_len);
let cap = jwt.capacity();
// we only need an empty header with the alg specified.
// base64url(r#"{"alg":"ES256"}"#) == "eyJhbGciOiJFUzI1NiJ9"
jwt.push_str("eyJhbGciOiJFUzI1NiJ9.");
// encode the jwt payload in-place
base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
// create the signature from the encoded header || payload
let sig: Signature = sk.sign(jwt.as_bytes());
jwt.push('.');
// encode the jwt signature in-place
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
debug_assert_eq!(
jwt.len(),
total_len,
"the jwt len should match our expected len"
);
debug_assert_eq!(jwt.capacity(), cap, "the jwt capacity should not change");
jwt
}
impl<C: ClientInnerExt> Discard<'_, C> {
@@ -509,14 +539,6 @@ impl<C: ClientInnerExt> Discard<'_, C> {
}
impl<C: ClientInnerExt> LocalClient<C> {
pub fn get_client(&self) -> &C {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
@@ -542,3 +564,30 @@ impl<C: ClientInnerExt> Drop for LocalClient<C> {
}
}
}
#[cfg(test)]
mod tests {
use p256::ecdsa::SigningKey;
use typed_json::json;
use super::resign_jwt;
#[test]
fn jwt_token_snapshot() {
let key = SigningKey::from_bytes(&[1; 32].into()).unwrap();
let data =
json!({"foo":"bar","jti":"foo\nbar","nested":{"jti":"tricky nesting"}}).to_string();
let jwt = resign_jwt(&key, data.as_bytes(), 2).unwrap();
// To validate the JWT, copy the JWT string and paste it into https://jwt.io/.
// In the public-key box, paste the following jwk public key
// `{"kty":"EC","crv":"P-256","x":"b_A7lJJBzh2t1DUZ5pYOCoW0GmmgXDKBA6orzhWUyhY","y":"PE91OlW_AdxT9sCwx-7ni0DG_30lqW4igrmJzvccFEo"}`
// let pub_key = p256::ecdsa::VerifyingKey::from(&key);
// let pub_key = p256::PublicKey::from(pub_key);
// println!("{}", pub_key.to_jwk_string());
assert_eq!(jwt, "eyJhbGciOiJFUzI1NiJ9.eyJmb28iOiJiYXIiLCJqdGkiOjIsIm5lc3RlZCI6eyJqdGkiOiJ0cmlja3kgbmVzdGluZyJ9fQ.pYf0LxoJ8sDgpmsYOgrbNecOSipnPBEGwnZzB-JhW2cONrKlqRsgXwK8_cOsyolGy-hTTe8GXbWTl_UdpF5RyA");
}
}