From 5004e62f5bf5e3ee28d09da29c5ea976ea6a4441 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 16 Nov 2023 15:15:08 +0100 Subject: [PATCH] proxy: start work on jwt auth --- Cargo.lock | 91 ++++++++++++++++++------- proxy/Cargo.toml | 2 +- proxy/src/auth/backend.rs | 38 +++++++++++ proxy/src/console/provider.rs | 12 +++- proxy/src/console/provider/mock.rs | 20 +++++- proxy/src/console/provider/neon.rs | 68 ++++++++++++++++++- proxy/src/http.rs | 8 +++ proxy/src/serverless.rs | 21 +++++- proxy/src/serverless/conn_pool.rs | 25 ++++++- proxy/src/serverless/jwt_auth.rs | 98 +++++++++++++++++++++++++++ proxy/src/serverless/sql_over_http.rs | 72 ++++++++++++++++---- 11 files changed, 406 insertions(+), 49 deletions(-) create mode 100644 proxy/src/serverless/jwt_auth.rs diff --git a/Cargo.lock b/Cargo.lock index 841c60c7e4..53bfd6956f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -274,7 +274,7 @@ dependencies = [ "hex", "http", "hyper", - "ring", + "ring 0.16.20", "time", "tokio", "tower", @@ -703,7 +703,7 @@ dependencies = [ "bytes", "dyn-clone", "futures", - "getrandom 0.2.9", + "getrandom 0.2.11", "http-types", "log", "paste", @@ -863,6 +863,22 @@ dependencies = [ "which", ] +[[package]] +name = "biscuit" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e28fc7c56c61743a01d0d1b73e4fed68b8a4f032ea3a2d4bb8c6520a33fc05a" +dependencies = [ + "chrono", + "data-encoding", + "num-bigint", + "num-traits", + "once_cell", + "ring 0.17.5", + "serde", + "serde_json", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -945,11 +961,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", + "libc", ] [[package]] @@ -1846,9 +1863,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "js-sys", @@ -2342,7 +2359,7 @@ checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ "base64 0.21.1", "pem 1.1.1", - "ring", + "ring 0.16.20", "serde", "serde_json", "simple_asn1", @@ -2382,9 +2399,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.144" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -2691,7 +2708,7 @@ checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f" dependencies = [ "base64 0.13.1", "chrono", - "getrandom 0.2.9", + "getrandom 0.2.11", "http", "rand 0.8.5", "serde", @@ -3474,6 +3491,7 @@ dependencies = [ "anyhow", "async-trait", "base64 0.13.1", + "biscuit", "bstr", "bytes", "chrono", @@ -3619,7 +3637,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", ] [[package]] @@ -3660,7 +3678,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4954fbc00dcd4d8282c987710e50ba513d351400dbdd00e803a05172a90d8976" dependencies = [ "pem 2.0.1", - "ring", + "ring 0.16.20", "time", "yasna", ] @@ -3830,7 +3848,7 @@ dependencies = [ "async-trait", "chrono", "futures", - "getrandom 0.2.9", + "getrandom 0.2.11", "http", "hyper", "parking_lot 0.11.2", @@ -3851,7 +3869,7 @@ checksum = "1b97ad83c2fc18113346b7158d79732242002427c30f620fa817c1f32901e0a8" dependencies = [ "anyhow", "async-trait", - "getrandom 0.2.9", + "getrandom 0.2.11", "matchit", "opentelemetry", "reqwest", @@ -3882,11 +3900,25 @@ dependencies = [ "libc", "once_cell", "spin 0.5.2", - "untrusted", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +dependencies = [ + "cc", + "getrandom 0.2.11", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.48.0", +] + [[package]] name = "routerify" version = "3.0.0" @@ -4003,7 +4035,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" dependencies = [ "log", - "ring", + "ring 0.16.20", "rustls-webpki 0.101.4", "sct", ] @@ -4035,8 +4067,8 @@ version = "0.100.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -4045,8 +4077,8 @@ version = "0.101.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -4191,8 +4223,8 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ - "ring", - "untrusted", + "ring 0.16.20", + "untrusted 0.7.1", ] [[package]] @@ -4311,7 +4343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99dc599bd6646884fc403d593cdcb9816dd67c50cff3271c01ff123617908dcd" dependencies = [ "debugid", - "getrandom 0.2.9", + "getrandom 0.2.11", "hex", "serde", "serde_json", @@ -4357,6 +4389,7 @@ version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ + "indexmap", "itoa", "ryu", "serde", @@ -4960,7 +4993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f" dependencies = [ "futures", - "ring", + "ring 0.16.20", "rustls", "tokio", "tokio-postgres", @@ -5416,6 +5449,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "ureq" version = "2.7.1" @@ -5517,7 +5556,7 @@ version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", "serde", ] @@ -6010,7 +6049,7 @@ dependencies = [ "regex", "regex-syntax 0.7.2", "reqwest", - "ring", + "ring 0.16.20", "rustls", "scopeguard", "serde", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0ec7efd316..4d834c92b6 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -68,7 +68,7 @@ webpki-roots.workspace = true x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true - +biscuit = { version = "0.7",features = [] } workspace_hack.workspace = true tokio-util.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 9cf45c0eec..f904cee985 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -3,6 +3,7 @@ mod hacks; mod link; pub use link::LinkAuthError; +use serde::{Deserialize, Serialize}; use tokio_postgres::config::AuthKeys; use crate::proxy::{handle_try_wake, retry_after, LatencyTimer}; @@ -319,4 +320,41 @@ impl BackendType<'_, ClientCredentials<'_>> { Test(x) => x.wake_compute().map(Some), } } + + /// Get the password for the RLS user + pub async fn ensure_row_level( + &self, + extra: &ConsoleReqExtra<'_>, + dbname: String, + username: String, + policies: Vec, + ) -> anyhow::Result { + use BackendType::*; + + match self { + Console(api, creds) => { + api.ensure_row_level(extra, creds, dbname, username, policies) + .await + } + Postgres(api, creds) => { + api.ensure_row_level(extra, creds, dbname, username, policies) + .await + } + Link(_) => Err(anyhow::anyhow!("not on link")), + Test(_) => Err(anyhow::anyhow!("not on test")), + } + } } + +// TODO(conrad): policies can be quite complex. Figure out how to configure this + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct Policy { + table: String, + column: String, +} + +// enum PolicyType { +// ForSelect(), +// ForUpdate() +// } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 54bcd1f081..dbbf71a4b9 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -3,7 +3,7 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::ClientCredentials, + auth::{ClientCredentials, backend::Policy}, cache::{timed_lru, TimedLru}, compute, scram, }; @@ -248,6 +248,16 @@ pub trait Api { extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials, ) -> Result; + + /// Get the password for the RLS user + async fn ensure_row_level( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials, + dbname: String, + username: String, + policies: Vec + ) -> anyhow::Result; } /// Various caches for [`console`](super). diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 750a2d141e..68bc909b00 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -4,7 +4,13 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl}; +use crate::{ + auth::{backend::Policy, ClientCredentials}, + compute, + error::io_error, + scram, + url::ApiUrl, +}; use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; @@ -121,6 +127,18 @@ impl super::Api for Api { .map_ok(CachedNodeInfo::new_uncached) .await } + + /// Get the password for the RLS user + async fn ensure_row_level( + &self, + _extra: &ConsoleReqExtra<'_>, + _creds: &ClientCredentials, + _dbname: String, + _username: String, + _policies: Vec, + ) -> anyhow::Result { + Err(anyhow::anyhow!("unimplemented")) + } } fn parse_md5(input: &str) -> Option<[u8; 16]> { diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 0dc7c71534..a373dda000 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -5,9 +5,13 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{auth::ClientCredentials, compute, http, scram}; +use crate::{ + auth::{backend::Policy, ClientCredentials}, + compute, http, scram, +}; use async_trait::async_trait; use futures::TryFutureExt; +use serde::{Deserialize, Serialize}; use std::{net::SocketAddr, sync::Arc}; use tokio::time::Instant; use tokio_postgres::config::SslMode; @@ -139,6 +143,55 @@ impl Api { .instrument(info_span!("http", id = request_id)) .await } + + async fn do_ensure_row_level( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials<'_>, + dbname: String, + username: String, + policies: Vec, + ) -> anyhow::Result { + let project = creds.project().expect("impossible"); + let request_id = uuid::Uuid::new_v4().to_string(); + async { + let request = self + .endpoint + .post("proxy_ensure_row_level") + .header("X-Request-ID", &request_id) + .header("Authorization", format!("Bearer {}", &self.jwt)) + .query(&[("session_id", extra.session_id)]) + .query(&[ + ("application_name", extra.application_name), + ("project", Some(project)), + ("dbname", Some(&dbname)), + ("username", Some(&username)), + ("options", extra.options), + ]) + .json(&EnsureRowLevelReq { policies }) + .build()?; + + info!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let response = self.endpoint.execute(request).await?; + info!(duration = ?start.elapsed(), "received http response"); + let body = parse_body::(response).await?; + + Ok(body.password) + } + .map_err(crate::error::log_error) + .instrument(info_span!("http", id = request_id)) + .await + } +} + +#[derive(Serialize)] +struct EnsureRowLevelReq { + policies: Vec, +} +#[derive(Deserialize)] +struct UserRowLevel { + password: String, } #[async_trait] @@ -188,6 +241,19 @@ impl super::Api for Api { Ok(cached) } + + /// Get the password for the RLS user + async fn ensure_row_level( + &self, + extra: &ConsoleReqExtra<'_>, + creds: &ClientCredentials, + dbname: String, + username: String, + policies: Vec, + ) -> anyhow::Result { + self.do_ensure_row_level(extra, creds, dbname, username, policies) + .await + } } /// Parse http response body, taking status code into account. diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 159b949da3..0fb40bb6be 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -88,6 +88,14 @@ impl Endpoint { self.client.get(url.into_inner()) } + /// Return a [builder](RequestBuilder) for a `POST` request, + /// appending a single `path` segment to the base endpoint URL. + pub fn post(&self, path: &str) -> RequestBuilder { + let mut url = self.endpoint.clone(); + url.path_segments_mut().push(path); + self.client.post(url.into_inner()) + } + /// Execute a [request](reqwest::Request). pub async fn execute(&self, request: Request) -> Result { self.client.execute(request).await diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 23deda3ae6..93fac65e4e 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -3,10 +3,12 @@ //! Handles both SQL over HTTP and SQL over Websockets. mod conn_pool; +pub mod jwt_auth; mod sql_over_http; mod websocket; use anyhow::bail; +use dashmap::DashMap; use hyper::StatusCode; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; @@ -31,6 +33,8 @@ use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; use utils::http::{error::ApiError, json::json_response}; +use self::jwt_auth::JWKSetCaches; + pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, @@ -41,6 +45,9 @@ pub async fn task_main( } let conn_pool = conn_pool::GlobalConnPool::new(config); + let jwk_cache_pool = Arc::new(JWKSetCaches { + map: DashMap::new(), + }); // shutdown the connection pool tokio::spawn({ @@ -85,6 +92,7 @@ pub async fn task_main( let remote_addr = io.inner.remote_addr(); let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); + let jwk_cache_pool = jwk_cache_pool.clone(); async move { let peer_addr = match client_addr { @@ -96,13 +104,20 @@ pub async fn task_main( move |req: Request| { let sni_name = sni_name.clone(); let conn_pool = conn_pool.clone(); + let jwk_cache_pool = jwk_cache_pool.clone(); async move { let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); request_handler( - req, config, conn_pool, cancel_map, session_id, sni_name, + req, + config, + conn_pool, + jwk_cache_pool, + cancel_map, + session_id, + sni_name, ) .instrument(info_span!( "serverless", @@ -167,6 +182,7 @@ async fn request_handler( mut request: Request, config: &'static ProxyConfig, conn_pool: Arc, + jwk_cache_pool: Arc, cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, @@ -204,6 +220,7 @@ async fn request_handler( request, sni_hostname, conn_pool, + jwk_cache_pool, session_id, &config.http_config, ) @@ -214,7 +231,7 @@ async fn request_handler( .header("Access-Control-Allow-Origin", "*") .header( "Access-Control-Allow-Headers", - "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In", + "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Authorization", ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index d09554a922..6b48dc1c0e 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -21,7 +21,8 @@ use tokio::time; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ - auth, console, + auth::{self, backend::Policy}, + console, proxy::{ neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER, @@ -45,6 +46,8 @@ pub struct ConnInfo { pub hostname: String, pub password: String, pub options: Option, + /// row level security mode enabled + pub policies: Option>, } impl ConnInfo { @@ -365,6 +368,7 @@ struct TokioMechanism<'a> { conn_info: &'a ConnInfo, session_id: uuid::Uuid, conn_id: uuid::Uuid, + password: Option, } #[async_trait] @@ -384,6 +388,7 @@ impl ConnectMechanism for TokioMechanism<'_> { timeout, self.conn_id, self.session_id, + self.password.as_deref(), ) .await } @@ -431,11 +436,26 @@ async fn connect_to_compute( .await? .context("missing cache entry from wake_compute")?; + let mut password = None; + if let Some(policies) = &conn_info.policies { + password = Some( + creds + .ensure_row_level( + &extra, + conn_info.dbname.to_owned(), + conn_info.username.to_owned(), + policies.clone(), + ) + .await?, + ); + } + crate::proxy::connect_to_compute( &TokioMechanism { conn_id, conn_info, session_id, + password, }, node_info, &extra, @@ -451,12 +471,13 @@ async fn connect_to_compute_once( timeout: time::Duration, conn_id: uuid::Uuid, mut session: uuid::Uuid, + password: Option<&str>, ) -> Result { let mut config = (*node_info.config).clone(); let (client, mut connection) = config .user(&conn_info.username) - .password(&conn_info.password) + .password(password.unwrap_or(&conn_info.password)) .dbname(&conn_info.dbname) .connect_timeout(timeout) .connect(tokio_postgres::NoTls) diff --git a/proxy/src/serverless/jwt_auth.rs b/proxy/src/serverless/jwt_auth.rs new file mode 100644 index 0000000000..f204f3714a --- /dev/null +++ b/proxy/src/serverless/jwt_auth.rs @@ -0,0 +1,98 @@ +// https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json + +use std::sync::Arc; + +use anyhow::{bail, Context}; +use biscuit::{ + jwk::{JWKSet, JWK}, + jws, CompactPart, +}; +use dashmap::DashMap; +use reqwest::{IntoUrl, Url}; +use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::RwLock; + +pub struct JWKSetCaches { + pub map: DashMap>, +} + +impl JWKSetCaches { + pub async fn get_cache(&self, url: impl IntoUrl) -> anyhow::Result> { + let url = url.into_url()?; + if let Some(x) = self.map.get(&url) { + return Ok(x.clone()); + } + let cache = JWKSetCache::new(url.clone()).await?; + let cache = Arc::new(cache); + self.map.insert(url, cache.clone()); + Ok(cache) + } +} + +pub struct JWKSetCache { + url: Url, + current: RwLock>, +} + +impl JWKSetCache { + pub async fn new(url: impl IntoUrl) -> anyhow::Result { + let url = url.into_url()?; + let current = reqwest::get(url.clone()).await?.json().await?; + Ok(Self { + url, + current: RwLock::new(current), + }) + } + + pub async fn get(&self, kid: &str) -> anyhow::Result> { + let current = self.current.read().await.clone(); + if let Some(key) = current.find(kid) { + return Ok(key.clone()); + } + let new = reqwest::get(self.url.clone()).await?.json().await?; + if new == current { + bail!("not found") + } + *self.current.write().await = new; + current.find(kid).cloned().context("not found") + } + + pub async fn decode( + &self, + token: &jws::Compact, + ) -> anyhow::Result> + where + T: CompactPart, + H: Serialize + DeserializeOwned, + { + let current = self.current.read().await.clone(); + match token.decode_with_jwks(¤t, None) { + Ok(t) => Ok(t), + Err(biscuit::errors::Error::ValidationError( + biscuit::errors::ValidationError::KeyNotFound, + )) => { + let new: JWKSet<()> = reqwest::get(self.url.clone()).await?.json().await?; + if new == current { + bail!("not found") + } + *self.current.write().await = new.clone(); + token.decode_with_jwks(&new, None).context("error") + // current.find(kid).cloned().context("not found") + } + Err(e) => Err(e.into()), + } + } +} + +#[cfg(test)] +mod tests { + use super::JWKSetCache; + #[tokio::test] + async fn jwkset() { + let cache = + JWKSetCache::new("https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json") + .await + .unwrap(); + dbg!(cache.get("ins_2YFechxysnwZcZN6TDHEz6u6w6v").await.unwrap()); + } +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 16736ac00d..47ff0bb2fd 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,15 +1,20 @@ use std::sync::Arc; use anyhow::bail; +use anyhow::Context; +use biscuit::JWT; use futures::pin_mut; use futures::StreamExt; use hyper::body::HttpBody; use hyper::header; +use hyper::header::AUTHORIZATION; use hyper::http::HeaderName; use hyper::http::HeaderValue; use hyper::Response; use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; +use serde::Deserialize; +use serde::Serialize; use serde_json::json; use serde_json::Map; use serde_json::Value; @@ -26,11 +31,13 @@ use url::Url; use utils::http::error::ApiError; use utils::http::json::json_response; +use crate::auth::backend::Policy; use crate::config::HttpConfig; use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER}; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; +use super::jwt_auth::JWKSetCaches; #[derive(serde::Deserialize)] struct QueryData { @@ -118,9 +125,10 @@ fn json_array_to_pg_array(value: &Value) -> Option { } } -fn get_conn_info( +async fn get_conn_info( + jwk_cache_pool: &JWKSetCaches, headers: &HeaderMap, - sni_hostname: Option, + sni_hostname: &str, ) -> Result { let connection_string = headers .get("Neon-Connection-String") @@ -144,18 +152,40 @@ fn get_conn_info( .next() .ok_or(anyhow::anyhow!("invalid database name"))?; - let username = connection_url.username(); - if username.is_empty() { - return Err(anyhow::anyhow!("missing username")); - } + let mut password = ""; + let mut policies = None; + let authorization = headers.get(AUTHORIZATION); + let username = if let Some(auth) = authorization { + // TODO: introduce control plane API to fetch this + let jwks_url = match sni_hostname { + "foo" => "https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json", + _ => anyhow::bail!("invalid sni name"), + }; + let jwk_cache = jwk_cache_pool.get_cache(jwks_url).await?; - let password = connection_url - .password() - .ok_or(anyhow::anyhow!("no password"))?; + let auth = auth.to_str()?; + let token = auth.strip_prefix("Bearer ").context("bad token")?; + let jwt: JWT = JWT::new_encoded(token); + let token = jwk_cache.decode(&jwt).await?; + let payload = token.payload().unwrap(); + policies = Some(payload.private.policies.clone()); + payload + .registered + .subject + .as_deref() + .context("missing user id")? + .to_owned() + } else { + password = connection_url + .password() + .ok_or(anyhow::anyhow!("no password"))?; - // TLS certificate selector now based on SNI hostname, so if we are running here - // we are sure that SNI hostname is set to one of the configured domain names. - let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?; + let u = connection_url.username(); + if u.is_empty() { + return Err(anyhow::anyhow!("missing username")); + } + u.to_owned() + }; let hostname = connection_url .host_str() @@ -186,7 +216,8 @@ fn get_conn_info( } Ok(ConnInfo { - username: username.to_owned(), + username, + policies, dbname: dbname.to_owned(), hostname: hostname.to_owned(), password: password.to_owned(), @@ -199,12 +230,13 @@ pub async fn handle( request: Request, sni_hostname: Option, conn_pool: Arc, + jwk_cache_pool: Arc, session_id: uuid::Uuid, config: &'static HttpConfig, ) -> Result, ApiError> { let result = tokio::time::timeout( config.sql_over_http_timeout, - handle_inner(request, sni_hostname, conn_pool, session_id), + handle_inner(request, sni_hostname, conn_pool, jwk_cache_pool, session_id), ) .await; let mut response = match result { @@ -255,6 +287,7 @@ async fn handle_inner( request: Request, sni_hostname: Option, conn_pool: Arc, + jwk_cache_pool: Arc, session_id: uuid::Uuid, ) -> anyhow::Result> { NUM_CONNECTIONS_ACCEPTED_COUNTER @@ -264,11 +297,15 @@ async fn handle_inner( NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc(); } + // TLS certificate selector now based on SNI hostname, so if we are running here + // we are sure that SNI hostname is set to one of the configured domain names. + let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?; + // // Determine the destination and connection params // let headers = request.headers(); - let conn_info = get_conn_info(headers, sni_hostname)?; + let conn_info = get_conn_info(&jwk_cache_pool, headers, &sni_hostname).await?; // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. @@ -697,6 +734,11 @@ fn _pg_array_parse( Ok((Value::Array(entries), 0)) } +#[derive(Serialize, Deserialize)] +pub struct NeonFields { + policies: Vec, +} + #[cfg(test)] mod tests { use super::*;