From d92d36a315f955cd39bc6f6b0948bae25ed195ad Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 15 Oct 2024 13:13:57 +0100 Subject: [PATCH] [local_proxy] update api for pg_session_jwt (#9359) pg_session_jwt now: 1. Sets the JWK in a PGU_BACKEND session guc, no longer in the init() function. 2. JWK no longer needs the kid. --- Cargo.lock | 7 +- Cargo.toml | 1 + compute/Dockerfile.compute-node | 4 +- proxy/Cargo.toml | 3 +- proxy/src/serverless/backend.rs | 49 ++++---- proxy/src/serverless/local_conn_pool.rs | 143 ++++++++++++++++-------- workspace_hack/Cargo.toml | 6 +- 7 files changed, 139 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5edf5cf7b4..7e772814ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2695,6 +2695,7 @@ checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e" dependencies = [ "equivalent", "hashbrown 0.14.5", + "serde", ] [[package]] @@ -2794,9 +2795,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" @@ -4296,6 +4297,7 @@ dependencies = [ "indexmap 2.0.1", "ipnet", "itertools 0.10.5", + "itoa", "jose-jwa", "jose-jwk", "lasso", @@ -7307,6 +7309,7 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "indexmap 1.9.3", + "indexmap 2.0.1", "itertools 0.12.1", "lazy_static", "libc", diff --git a/Cargo.toml b/Cargo.toml index dde80f5020..a1a974b33b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,6 +107,7 @@ indexmap = "2" indoc = "2" ipnet = "2.9.0" itertools = "0.10" +itoa = "1.0.11" jsonwebtoken = "9" lasso = "0.7" libc = "0.2" diff --git a/compute/Dockerfile.compute-node b/compute/Dockerfile.compute-node index 91528618da..412c64eda4 100644 --- a/compute/Dockerfile.compute-node +++ b/compute/Dockerfile.compute-node @@ -929,8 +929,8 @@ ARG PG_VERSION RUN case "${PG_VERSION}" in "v17") \ echo "pg_session_jwt does not yet have a release that supports pg17" && exit 0;; \ esac && \ - wget https://github.com/neondatabase/pg_session_jwt/archive/ff0a72440e8ff584dab24b3f9b7c00c56c660b8e.tar.gz -O pg_session_jwt.tar.gz && \ - echo "1fbb2b5a339263bcf6daa847fad8bccbc0b451cea6a62e6d3bf232b0087f05cb pg_session_jwt.tar.gz" | sha256sum --check && \ + wget https://github.com/neondatabase/pg_session_jwt/archive/5aee2625af38213650e1a07ae038fdc427250ee4.tar.gz -O pg_session_jwt.tar.gz && \ + echo "5d91b10bc1347d36cffc456cb87bec25047935d6503dc652ca046f04760828e7 pg_session_jwt.tar.gz" | sha256sum --check && \ mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \ sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ cargo pgrx install --release diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 963fb94a7d..e25d2fcbab 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -42,9 +42,10 @@ hyper0.workspace = true hyper = { workspace = true, features = ["server", "http1", "http2"] } hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } http-body-util = { version = "0.1" } -indexmap.workspace = true +indexmap = { workspace = true, features = ["serde"] } ipnet.workspace = true itertools.workspace = true +itoa.workspace = true lasso = { workspace = true, features = ["multi-threaded"] } measured = { workspace = true, features = ["lasso"] } metrics.workspace = true diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 2b060af9e1..927854897f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -2,8 +2,9 @@ use std::{io, sync::Arc, time::Duration}; use async_trait::async_trait; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; +use p256::{ecdsa::SigningKey, elliptic_curve::JwkEcKey}; +use rand::rngs::OsRng; use tokio::net::{lookup_host, TcpStream}; -use tokio_postgres::types::ToSql; use tracing::{debug, field::display, info}; use crate::{ @@ -267,50 +268,58 @@ impl PoolingBackend { auth::Backend::Local(local) => local.node_info.clone(), }; + let (key, jwk) = create_random_jwk(); + let config = node_info .config .user(&conn_info.user_info.user) - .dbname(&conn_info.dbname); + .dbname(&conn_info.dbname) + .options(&format!( + "-c pg_session_jwt.jwk={}", + serde_json::to_string(&jwk).expect("serializing jwk to json should not fail") + )); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (client, connection) = config.connect(tokio_postgres::NoTls).await?; drop(pause); - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); + let pid = client.get_process_id(); + tracing::Span::current().record("pid", pid); - let handle = local_conn_pool::poll_client( + let mut handle = local_conn_pool::poll_client( self.local_pool.clone(), ctx, conn_info, client, connection, + key, conn_id, node_info.aux.clone(), ); - let kid = handle.get_client().get_process_id() as i64; - let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk(); + { + let (client, mut discard) = handle.inner(); + debug!("setting up backend session state"); - debug!(kid, ?jwk, "setting up backend session state"); + // initiates the auth session + if let Err(e) = client.query("select auth.init()", &[]).await { + discard.discard(); + return Err(e.into()); + } - // initiates the auth session - handle - .get_client() - .query( - "select auth.init($1, $2);", - &[ - &kid as &(dyn ToSql + Sync), - &tokio_postgres::types::Json(jwk), - ], - ) - .await?; - - info!(?kid, "backend session state init"); + info!("backend session state initialized"); + } Ok(handle) } } +fn create_random_jwk() -> (SigningKey, JwkEcKey) { + let key = SigningKey::random(&mut OsRng); + let jwk = p256::PublicKey::from(key.verifying_key()).to_jwk(); + (key, jwk) +} + #[derive(Debug, thiserror::Error)] pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 1dde5952e1..4ab14ad35f 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -1,9 +1,9 @@ use futures::{future::poll_fn, Future}; +use indexmap::IndexMap; use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding}; use p256::ecdsa::{Signature, SigningKey}; use parking_lot::RwLock; -use rand::rngs::OsRng; -use serde_json::Value; +use serde_json::value::RawValue; use signature::Signer; use std::task::{ready, Poll}; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; @@ -12,14 +12,13 @@ use tokio_postgres::tls::NoTlsStream; use tokio_postgres::types::ToSql; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; use tokio_util::sync::CancellationToken; -use typed_json::json; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::Metrics; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{context::RequestMonitoring, DbName, RoleName}; -use tracing::{debug, error, warn, Span}; +use tracing::{error, warn, Span}; use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; @@ -245,12 +244,14 @@ impl LocalConnPool { } } +#[allow(clippy::too_many_arguments)] pub(crate) fn poll_client( global_pool: Arc>, ctx: &RequestMonitoring, conn_info: ConnInfo, client: tokio_postgres::Client, mut connection: tokio_postgres::Connection, + key: SigningKey, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> LocalClient { @@ -346,8 +347,6 @@ pub(crate) fn poll_client( } .instrument(span)); - let key = SigningKey::random(&mut OsRng); - let inner = ClientInner { inner: client, session: tx, @@ -430,13 +429,6 @@ impl LocalClient { let inner = inner.as_mut().expect("client inner should not be removed"); (&mut inner.inner, Discard { conn_info, pool }) } - pub(crate) fn key(&self) -> &SigningKey { - let inner = &self - .inner - .as_ref() - .expect("client inner should not be removed"); - &inner.key - } } impl LocalClient { @@ -445,25 +437,9 @@ impl LocalClient { .inner .as_mut() .expect("client inner should not be removed"); + inner.jti += 1; - - let kid = inner.inner.get_process_id(); - let header = json!({"kid":kid}).to_string(); - - let mut payload = serde_json::from_slice::>(payload) - .map_err(HttpConnError::JwtPayloadError)?; - payload.insert("jti".to_string(), Value::Number(inner.jti.into())); - let payload = Value::Object(payload).to_string(); - - debug!( - kid, - jti = inner.jti, - ?header, - ?payload, - "signing new ephemeral JWT" - ); - - let token = sign_jwt(&inner.key, header, payload); + let token = resign_jwt(&inner.key, payload, inner.jti)?; // initiates the auth session inner.inner.simple_query("discard all").await?; @@ -475,20 +451,74 @@ impl LocalClient { ) .await?; - info!(kid, jti = inner.jti, "user session state init"); + let pid = inner.inner.get_process_id(); + info!(pid, jti = inner.jti, "user session state init"); Ok(()) } } -fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String { - let header = Base64UrlUnpadded::encode_string(header.as_bytes()); - let payload = Base64UrlUnpadded::encode_string(payload.as_bytes()); +/// implements relatively efficient in-place json object key upserting +/// +/// only supports top-level keys +fn upsert_json_object( + payload: &[u8], + key: &str, + value: &RawValue, +) -> Result { + let mut payload = serde_json::from_slice::>(payload)?; + payload.insert(key, value); + serde_json::to_string(&payload) +} - let message = format!("{header}.{payload}"); - let sig: Signature = sk.sign(message.as_bytes()); - let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); - format!("{message}.{base64_sig}") +fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result { + let mut buffer = itoa::Buffer::new(); + + // encode the jti integer to a json rawvalue + let jti = serde_json::from_str::<&RawValue>(buffer.format(jti)).unwrap(); + + // update the jti in-place + let payload = + upsert_json_object(payload, "jti", jti).map_err(HttpConnError::JwtPayloadError)?; + + // sign the jwt + let token = sign_jwt(sk, payload.as_bytes()); + + Ok(token) +} + +fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { + let header_len = 20; + let payload_len = Base64UrlUnpadded::encoded_len(payload); + let signature_len = Base64UrlUnpadded::encoded_len(&[0; 64]); + let total_len = header_len + payload_len + signature_len + 2; + + let mut jwt = String::with_capacity(total_len); + let cap = jwt.capacity(); + + // we only need an empty header with the alg specified. + // base64url(r#"{"alg":"ES256"}"#) == "eyJhbGciOiJFUzI1NiJ9" + jwt.push_str("eyJhbGciOiJFUzI1NiJ9."); + + // encode the jwt payload in-place + base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt); + + // create the signature from the encoded header || payload + let sig: Signature = sk.sign(jwt.as_bytes()); + + jwt.push('.'); + + // encode the jwt signature in-place + base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt); + + debug_assert_eq!( + jwt.len(), + total_len, + "the jwt len should match our expected len" + ); + debug_assert_eq!(jwt.capacity(), cap, "the jwt capacity should not change"); + + jwt } impl Discard<'_, C> { @@ -509,14 +539,6 @@ impl Discard<'_, C> { } impl LocalClient { - pub fn get_client(&self) -> &C { - &self - .inner - .as_ref() - .expect("client inner should not be removed") - .inner - } - fn do_drop(&mut self) -> Option { let conn_info = self.conn_info.clone(); let client = self @@ -542,3 +564,30 @@ impl Drop for LocalClient { } } } + +#[cfg(test)] +mod tests { + use p256::ecdsa::SigningKey; + use typed_json::json; + + use super::resign_jwt; + + #[test] + fn jwt_token_snapshot() { + let key = SigningKey::from_bytes(&[1; 32].into()).unwrap(); + let data = + json!({"foo":"bar","jti":"foo\nbar","nested":{"jti":"tricky nesting"}}).to_string(); + + let jwt = resign_jwt(&key, data.as_bytes(), 2).unwrap(); + + // To validate the JWT, copy the JWT string and paste it into https://jwt.io/. + // In the public-key box, paste the following jwk public key + // `{"kty":"EC","crv":"P-256","x":"b_A7lJJBzh2t1DUZ5pYOCoW0GmmgXDKBA6orzhWUyhY","y":"PE91OlW_AdxT9sCwx-7ni0DG_30lqW4igrmJzvccFEo"}` + + // let pub_key = p256::ecdsa::VerifyingKey::from(&key); + // let pub_key = p256::PublicKey::from(pub_key); + // println!("{}", pub_key.to_jwk_string()); + + assert_eq!(jwt, "eyJhbGciOiJFUzI1NiJ9.eyJmb28iOiJiYXIiLCJqdGkiOjIsIm5lc3RlZCI6eyJqdGkiOiJ0cmlja3kgbmVzdGluZyJ9fQ.pYf0LxoJ8sDgpmsYOgrbNecOSipnPBEGwnZzB-JhW2cONrKlqRsgXwK8_cOsyolGy-hTTe8GXbWTl_UdpF5RyA"); + } +} diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 0a90b6b6f7..1347d6ddff 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -46,7 +46,8 @@ hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper-582f2526e08bb6a0 = { package = "hyper", version = "0.14", features = ["full"] } hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["full"] } hyper-util = { version = "0.1", features = ["client-legacy", "server-auto", "service"] } -indexmap = { version = "1", default-features = false, features = ["std"] } +indexmap-dff4ba8e3ae991db = { package = "indexmap", version = "1", default-features = false, features = ["std"] } +indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } itertools = { version = "0.12" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } @@ -101,7 +102,8 @@ either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } half = { version = "2", default-features = false, features = ["num-traits"] } hashbrown = { version = "0.14", features = ["raw"] } -indexmap = { version = "1", default-features = false, features = ["std"] } +indexmap-dff4ba8e3ae991db = { package = "indexmap", version = "1", default-features = false, features = ["std"] } +indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } itertools = { version = "0.12" } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] }