mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-01 20:40:37 +00:00
[local_proxy] update api for pg_session_jwt (#9359)
pg_session_jwt now: 1. Sets the JWK in a PGU_BACKEND session guc, no longer in the init() function. 2. JWK no longer needs the kid.
This commit is contained in:
@@ -2,8 +2,9 @@ use std::{io, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use p256::{ecdsa::SigningKey, elliptic_curve::JwkEcKey};
|
||||
use rand::rngs::OsRng;
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tracing::{debug, field::display, info};
|
||||
|
||||
use crate::{
|
||||
@@ -267,50 +268,58 @@ impl PoolingBackend {
|
||||
auth::Backend::Local(local) => local.node_info.clone(),
|
||||
};
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
let config = node_info
|
||||
.config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname);
|
||||
.dbname(&conn_info.dbname)
|
||||
.options(&format!(
|
||||
"-c pg_session_jwt.jwk={}",
|
||||
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
|
||||
));
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
let pid = client.get_process_id();
|
||||
tracing::Span::current().record("pid", pid);
|
||||
|
||||
let handle = local_conn_pool::poll_client(
|
||||
let mut handle = local_conn_pool::poll_client(
|
||||
self.local_pool.clone(),
|
||||
ctx,
|
||||
conn_info,
|
||||
client,
|
||||
connection,
|
||||
key,
|
||||
conn_id,
|
||||
node_info.aux.clone(),
|
||||
);
|
||||
|
||||
let kid = handle.get_client().get_process_id() as i64;
|
||||
let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();
|
||||
{
|
||||
let (client, mut discard) = handle.inner();
|
||||
debug!("setting up backend session state");
|
||||
|
||||
debug!(kid, ?jwk, "setting up backend session state");
|
||||
// initiates the auth session
|
||||
if let Err(e) = client.query("select auth.init()", &[]).await {
|
||||
discard.discard();
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
// initiates the auth session
|
||||
handle
|
||||
.get_client()
|
||||
.query(
|
||||
"select auth.init($1, $2);",
|
||||
&[
|
||||
&kid as &(dyn ToSql + Sync),
|
||||
&tokio_postgres::types::Json(jwk),
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(?kid, "backend session state init");
|
||||
info!("backend session state initialized");
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_random_jwk() -> (SigningKey, JwkEcKey) {
|
||||
let key = SigningKey::random(&mut OsRng);
|
||||
let jwk = p256::PublicKey::from(key.verifying_key()).to_jwk();
|
||||
(key, jwk)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use futures::{future::poll_fn, Future};
|
||||
use indexmap::IndexMap;
|
||||
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use parking_lot::RwLock;
|
||||
use rand::rngs::OsRng;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use signature::Signer;
|
||||
use std::task::{ready, Poll};
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
@@ -12,14 +12,13 @@ use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use typed_json::json;
|
||||
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, DbName, RoleName};
|
||||
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
@@ -245,12 +244,14 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn poll_client(
|
||||
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
key: SigningKey,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> LocalClient<tokio_postgres::Client> {
|
||||
@@ -346,8 +347,6 @@ pub(crate) fn poll_client(
|
||||
}
|
||||
.instrument(span));
|
||||
|
||||
let key = SigningKey::random(&mut OsRng);
|
||||
|
||||
let inner = ClientInner {
|
||||
inner: client,
|
||||
session: tx,
|
||||
@@ -430,13 +429,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { conn_info, pool })
|
||||
}
|
||||
pub(crate) fn key(&self) -> &SigningKey {
|
||||
let inner = &self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed");
|
||||
&inner.key
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalClient<tokio_postgres::Client> {
|
||||
@@ -445,25 +437,9 @@ impl LocalClient<tokio_postgres::Client> {
|
||||
.inner
|
||||
.as_mut()
|
||||
.expect("client inner should not be removed");
|
||||
|
||||
inner.jti += 1;
|
||||
|
||||
let kid = inner.inner.get_process_id();
|
||||
let header = json!({"kid":kid}).to_string();
|
||||
|
||||
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
|
||||
.map_err(HttpConnError::JwtPayloadError)?;
|
||||
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
|
||||
let payload = Value::Object(payload).to_string();
|
||||
|
||||
debug!(
|
||||
kid,
|
||||
jti = inner.jti,
|
||||
?header,
|
||||
?payload,
|
||||
"signing new ephemeral JWT"
|
||||
);
|
||||
|
||||
let token = sign_jwt(&inner.key, header, payload);
|
||||
let token = resign_jwt(&inner.key, payload, inner.jti)?;
|
||||
|
||||
// initiates the auth session
|
||||
inner.inner.simple_query("discard all").await?;
|
||||
@@ -475,20 +451,74 @@ impl LocalClient<tokio_postgres::Client> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(kid, jti = inner.jti, "user session state init");
|
||||
let pid = inner.inner.get_process_id();
|
||||
info!(pid, jti = inner.jti, "user session state init");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
|
||||
let header = Base64UrlUnpadded::encode_string(header.as_bytes());
|
||||
let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
|
||||
/// implements relatively efficient in-place json object key upserting
|
||||
///
|
||||
/// only supports top-level keys
|
||||
fn upsert_json_object(
|
||||
payload: &[u8],
|
||||
key: &str,
|
||||
value: &RawValue,
|
||||
) -> Result<String, serde_json::Error> {
|
||||
let mut payload = serde_json::from_slice::<IndexMap<&str, &RawValue>>(payload)?;
|
||||
payload.insert(key, value);
|
||||
serde_json::to_string(&payload)
|
||||
}
|
||||
|
||||
let message = format!("{header}.{payload}");
|
||||
let sig: Signature = sk.sign(message.as_bytes());
|
||||
let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
|
||||
format!("{message}.{base64_sig}")
|
||||
fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result<String, HttpConnError> {
|
||||
let mut buffer = itoa::Buffer::new();
|
||||
|
||||
// encode the jti integer to a json rawvalue
|
||||
let jti = serde_json::from_str::<&RawValue>(buffer.format(jti)).unwrap();
|
||||
|
||||
// update the jti in-place
|
||||
let payload =
|
||||
upsert_json_object(payload, "jti", jti).map_err(HttpConnError::JwtPayloadError)?;
|
||||
|
||||
// sign the jwt
|
||||
let token = sign_jwt(sk, payload.as_bytes());
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
|
||||
let header_len = 20;
|
||||
let payload_len = Base64UrlUnpadded::encoded_len(payload);
|
||||
let signature_len = Base64UrlUnpadded::encoded_len(&[0; 64]);
|
||||
let total_len = header_len + payload_len + signature_len + 2;
|
||||
|
||||
let mut jwt = String::with_capacity(total_len);
|
||||
let cap = jwt.capacity();
|
||||
|
||||
// we only need an empty header with the alg specified.
|
||||
// base64url(r#"{"alg":"ES256"}"#) == "eyJhbGciOiJFUzI1NiJ9"
|
||||
jwt.push_str("eyJhbGciOiJFUzI1NiJ9.");
|
||||
|
||||
// encode the jwt payload in-place
|
||||
base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
|
||||
|
||||
// create the signature from the encoded header || payload
|
||||
let sig: Signature = sk.sign(jwt.as_bytes());
|
||||
|
||||
jwt.push('.');
|
||||
|
||||
// encode the jwt signature in-place
|
||||
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
|
||||
|
||||
debug_assert_eq!(
|
||||
jwt.len(),
|
||||
total_len,
|
||||
"the jwt len should match our expected len"
|
||||
);
|
||||
debug_assert_eq!(jwt.capacity(), cap, "the jwt capacity should not change");
|
||||
|
||||
jwt
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
@@ -509,14 +539,6 @@ impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub fn get_client(&self) -> &C {
|
||||
&self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.inner
|
||||
}
|
||||
|
||||
fn do_drop(&mut self) -> Option<impl FnOnce()> {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
@@ -542,3 +564,30 @@ impl<C: ClientInnerExt> Drop for LocalClient<C> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use p256::ecdsa::SigningKey;
|
||||
use typed_json::json;
|
||||
|
||||
use super::resign_jwt;
|
||||
|
||||
#[test]
|
||||
fn jwt_token_snapshot() {
|
||||
let key = SigningKey::from_bytes(&[1; 32].into()).unwrap();
|
||||
let data =
|
||||
json!({"foo":"bar","jti":"foo\nbar","nested":{"jti":"tricky nesting"}}).to_string();
|
||||
|
||||
let jwt = resign_jwt(&key, data.as_bytes(), 2).unwrap();
|
||||
|
||||
// To validate the JWT, copy the JWT string and paste it into https://jwt.io/.
|
||||
// In the public-key box, paste the following jwk public key
|
||||
// `{"kty":"EC","crv":"P-256","x":"b_A7lJJBzh2t1DUZ5pYOCoW0GmmgXDKBA6orzhWUyhY","y":"PE91OlW_AdxT9sCwx-7ni0DG_30lqW4igrmJzvccFEo"}`
|
||||
|
||||
// let pub_key = p256::ecdsa::VerifyingKey::from(&key);
|
||||
// let pub_key = p256::PublicKey::from(pub_key);
|
||||
// println!("{}", pub_key.to_jwk_string());
|
||||
|
||||
assert_eq!(jwt, "eyJhbGciOiJFUzI1NiJ9.eyJmb28iOiJiYXIiLCJqdGkiOjIsIm5lc3RlZCI6eyJqdGkiOiJ0cmlja3kgbmVzdGluZyJ9fQ.pYf0LxoJ8sDgpmsYOgrbNecOSipnPBEGwnZzB-JhW2cONrKlqRsgXwK8_cOsyolGy-hTTe8GXbWTl_UdpF5RyA");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user