proxy: start work on jwt auth

This commit is contained in:
Conrad Ludgate
2023-11-16 15:15:08 +01:00
parent ab631e6792
commit 5004e62f5b
11 changed files with 406 additions and 49 deletions

View File

@@ -68,7 +68,7 @@ webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
biscuit = { version = "0.7",features = [] }
workspace_hack.workspace = true
tokio-util.workspace = true

View File

@@ -3,6 +3,7 @@ mod hacks;
mod link;
pub use link::LinkAuthError;
use serde::{Deserialize, Serialize};
use tokio_postgres::config::AuthKeys;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
@@ -319,4 +320,41 @@ impl BackendType<'_, ClientCredentials<'_>> {
Test(x) => x.wake_compute().map(Some),
}
}
/// Get the password for the RLS user
pub async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<String> {
use BackendType::*;
match self {
Console(api, creds) => {
api.ensure_row_level(extra, creds, dbname, username, policies)
.await
}
Postgres(api, creds) => {
api.ensure_row_level(extra, creds, dbname, username, policies)
.await
}
Link(_) => Err(anyhow::anyhow!("not on link")),
Test(_) => Err(anyhow::anyhow!("not on test")),
}
}
}
// TODO(conrad): policies can be quite complex. Figure out how to configure this
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Policy {
table: String,
column: String,
}
// enum PolicyType {
// ForSelect(),
// ForUpdate()
// }

View File

@@ -3,7 +3,7 @@ pub mod neon;
use super::messages::MetricsAuxInfo;
use crate::{
auth::ClientCredentials,
auth::{ClientCredentials, backend::Policy},
cache::{timed_lru, TimedLru},
compute, scram,
};
@@ -248,6 +248,16 @@ pub trait Api {
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
/// Get the password for the RLS user
async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
dbname: String,
username: String,
policies: Vec<Policy>
) -> anyhow::Result<String>;
}
/// Various caches for [`console`](super).

View File

@@ -4,7 +4,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use crate::{
auth::{backend::Policy, ClientCredentials},
compute,
error::io_error,
scram,
url::ApiUrl,
};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
@@ -121,6 +127,18 @@ impl super::Api for Api {
.map_ok(CachedNodeInfo::new_uncached)
.await
}
/// Get the password for the RLS user
async fn ensure_row_level(
&self,
_extra: &ConsoleReqExtra<'_>,
_creds: &ClientCredentials,
_dbname: String,
_username: String,
_policies: Vec<Policy>,
) -> anyhow::Result<String> {
Err(anyhow::anyhow!("unimplemented"))
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {

View File

@@ -5,9 +5,13 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use crate::{
auth::{backend::Policy, ClientCredentials},
compute, http, scram,
};
use async_trait::async_trait;
use futures::TryFutureExt;
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -139,6 +143,55 @@ impl Api {
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<String> {
let project = creds.project().expect("impossible");
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.post("proxy_ensure_row_level")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", extra.session_id)])
.query(&[
("application_name", extra.application_name),
("project", Some(project)),
("dbname", Some(&dbname)),
("username", Some(&username)),
("options", extra.options),
])
.json(&EnsureRowLevelReq { policies })
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = self.endpoint.execute(request).await?;
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<UserRowLevel>(response).await?;
Ok(body.password)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
}
#[derive(Serialize)]
struct EnsureRowLevelReq {
policies: Vec<Policy>,
}
#[derive(Deserialize)]
struct UserRowLevel {
password: String,
}
#[async_trait]
@@ -188,6 +241,19 @@ impl super::Api for Api {
Ok(cached)
}
/// Get the password for the RLS user
async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<String> {
self.do_ensure_row_level(extra, creds, dbname, username, policies)
.await
}
}
/// Parse http response body, taking status code into account.

View File

@@ -88,6 +88,14 @@ impl Endpoint {
self.client.get(url.into_inner())
}
/// Return a [builder](RequestBuilder) for a `POST` request,
/// appending a single `path` segment to the base endpoint URL.
pub fn post(&self, path: &str) -> RequestBuilder {
let mut url = self.endpoint.clone();
url.path_segments_mut().push(path);
self.client.post(url.into_inner())
}
/// Execute a [request](reqwest::Request).
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
self.client.execute(request).await

View File

@@ -3,10 +3,12 @@
//! Handles both SQL over HTTP and SQL over Websockets.
mod conn_pool;
pub mod jwt_auth;
mod sql_over_http;
mod websocket;
use anyhow::bail;
use dashmap::DashMap;
use hyper::StatusCode;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
@@ -31,6 +33,8 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
use self::jwt_auth::JWKSetCaches;
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
@@ -41,6 +45,9 @@ pub async fn task_main(
}
let conn_pool = conn_pool::GlobalConnPool::new(config);
let jwk_cache_pool = Arc::new(JWKSetCaches {
map: DashMap::new(),
});
// shutdown the connection pool
tokio::spawn({
@@ -85,6 +92,7 @@ pub async fn task_main(
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let jwk_cache_pool = jwk_cache_pool.clone();
async move {
let peer_addr = match client_addr {
@@ -96,13 +104,20 @@ pub async fn task_main(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let jwk_cache_pool = jwk_cache_pool.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
request_handler(
req, config, conn_pool, cancel_map, session_id, sni_name,
req,
config,
conn_pool,
jwk_cache_pool,
cancel_map,
session_id,
sni_name,
)
.instrument(info_span!(
"serverless",
@@ -167,6 +182,7 @@ async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
@@ -204,6 +220,7 @@ async fn request_handler(
request,
sni_hostname,
conn_pool,
jwk_cache_pool,
session_id,
&config.http_config,
)
@@ -214,7 +231,7 @@ async fn request_handler(
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Authorization",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@@ -21,7 +21,8 @@ use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use crate::{
auth, console,
auth::{self, backend::Policy},
console,
proxy::{
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
NUM_DB_CONNECTIONS_OPENED_COUNTER,
@@ -45,6 +46,8 @@ pub struct ConnInfo {
pub hostname: String,
pub password: String,
pub options: Option<String>,
/// row level security mode enabled
pub policies: Option<Vec<Policy>>,
}
impl ConnInfo {
@@ -365,6 +368,7 @@ struct TokioMechanism<'a> {
conn_info: &'a ConnInfo,
session_id: uuid::Uuid,
conn_id: uuid::Uuid,
password: Option<String>,
}
#[async_trait]
@@ -384,6 +388,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
timeout,
self.conn_id,
self.session_id,
self.password.as_deref(),
)
.await
}
@@ -431,11 +436,26 @@ async fn connect_to_compute(
.await?
.context("missing cache entry from wake_compute")?;
let mut password = None;
if let Some(policies) = &conn_info.policies {
password = Some(
creds
.ensure_row_level(
&extra,
conn_info.dbname.to_owned(),
conn_info.username.to_owned(),
policies.clone(),
)
.await?,
);
}
crate::proxy::connect_to_compute(
&TokioMechanism {
conn_id,
conn_info,
session_id,
password,
},
node_info,
&extra,
@@ -451,12 +471,13 @@ async fn connect_to_compute_once(
timeout: time::Duration,
conn_id: uuid::Uuid,
mut session: uuid::Uuid,
password: Option<&str>,
) -> Result<ClientInner, tokio_postgres::Error> {
let mut config = (*node_info.config).clone();
let (client, mut connection) = config
.user(&conn_info.username)
.password(&conn_info.password)
.password(password.unwrap_or(&conn_info.password))
.dbname(&conn_info.dbname)
.connect_timeout(timeout)
.connect(tokio_postgres::NoTls)

View File

@@ -0,0 +1,98 @@
// https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json
use std::sync::Arc;
use anyhow::{bail, Context};
use biscuit::{
jwk::{JWKSet, JWK},
jws, CompactPart,
};
use dashmap::DashMap;
use reqwest::{IntoUrl, Url};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;
pub struct JWKSetCaches {
pub map: DashMap<Url, Arc<JWKSetCache>>,
}
impl JWKSetCaches {
pub async fn get_cache(&self, url: impl IntoUrl) -> anyhow::Result<Arc<JWKSetCache>> {
let url = url.into_url()?;
if let Some(x) = self.map.get(&url) {
return Ok(x.clone());
}
let cache = JWKSetCache::new(url.clone()).await?;
let cache = Arc::new(cache);
self.map.insert(url, cache.clone());
Ok(cache)
}
}
pub struct JWKSetCache {
url: Url,
current: RwLock<biscuit::jwk::JWKSet<()>>,
}
impl JWKSetCache {
pub async fn new(url: impl IntoUrl) -> anyhow::Result<Self> {
let url = url.into_url()?;
let current = reqwest::get(url.clone()).await?.json().await?;
Ok(Self {
url,
current: RwLock::new(current),
})
}
pub async fn get(&self, kid: &str) -> anyhow::Result<JWK<()>> {
let current = self.current.read().await.clone();
if let Some(key) = current.find(kid) {
return Ok(key.clone());
}
let new = reqwest::get(self.url.clone()).await?.json().await?;
if new == current {
bail!("not found")
}
*self.current.write().await = new;
current.find(kid).cloned().context("not found")
}
pub async fn decode<T, H>(
&self,
token: &jws::Compact<T, H>,
) -> anyhow::Result<jws::Compact<T, H>>
where
T: CompactPart,
H: Serialize + DeserializeOwned,
{
let current = self.current.read().await.clone();
match token.decode_with_jwks(&current, None) {
Ok(t) => Ok(t),
Err(biscuit::errors::Error::ValidationError(
biscuit::errors::ValidationError::KeyNotFound,
)) => {
let new: JWKSet<()> = reqwest::get(self.url.clone()).await?.json().await?;
if new == current {
bail!("not found")
}
*self.current.write().await = new.clone();
token.decode_with_jwks(&new, None).context("error")
// current.find(kid).cloned().context("not found")
}
Err(e) => Err(e.into()),
}
}
}
#[cfg(test)]
mod tests {
use super::JWKSetCache;
#[tokio::test]
async fn jwkset() {
let cache =
JWKSetCache::new("https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json")
.await
.unwrap();
dbg!(cache.get("ins_2YFechxysnwZcZN6TDHEz6u6w6v").await.unwrap());
}
}

View File

@@ -1,15 +1,20 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use biscuit::JWT;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::header::AUTHORIZATION;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
@@ -26,11 +31,13 @@ use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::Policy;
use crate::config::HttpConfig;
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::jwt_auth::JWKSetCaches;
#[derive(serde::Deserialize)]
struct QueryData {
@@ -118,9 +125,10 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
}
}
fn get_conn_info(
async fn get_conn_info(
jwk_cache_pool: &JWKSetCaches,
headers: &HeaderMap,
sni_hostname: Option<String>,
sni_hostname: &str,
) -> Result<ConnInfo, anyhow::Error> {
let connection_string = headers
.get("Neon-Connection-String")
@@ -144,18 +152,40 @@ fn get_conn_info(
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let username = connection_url.username();
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
let mut password = "";
let mut policies = None;
let authorization = headers.get(AUTHORIZATION);
let username = if let Some(auth) = authorization {
// TODO: introduce control plane API to fetch this
let jwks_url = match sni_hostname {
"foo" => "https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json",
_ => anyhow::bail!("invalid sni name"),
};
let jwk_cache = jwk_cache_pool.get_cache(jwks_url).await?;
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
let auth = auth.to_str()?;
let token = auth.strip_prefix("Bearer ").context("bad token")?;
let jwt: JWT<NeonFields, ()> = JWT::new_encoded(token);
let token = jwk_cache.decode(&jwt).await?;
let payload = token.payload().unwrap();
policies = Some(payload.private.policies.clone());
payload
.registered
.subject
.as_deref()
.context("missing user id")?
.to_owned()
} else {
password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
let u = connection_url.username();
if u.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
u.to_owned()
};
let hostname = connection_url
.host_str()
@@ -186,7 +216,8 @@ fn get_conn_info(
}
Ok(ConnInfo {
username: username.to_owned(),
username,
policies,
dbname: dbname.to_owned(),
hostname: hostname.to_owned(),
password: password.to_owned(),
@@ -199,12 +230,13 @@ pub async fn handle(
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
session_id: uuid::Uuid,
config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.sql_over_http_timeout,
handle_inner(request, sni_hostname, conn_pool, session_id),
handle_inner(request, sni_hostname, conn_pool, jwk_cache_pool, session_id),
)
.await;
let mut response = match result {
@@ -255,6 +287,7 @@ async fn handle_inner(
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
session_id: uuid::Uuid,
) -> anyhow::Result<Response<Body>> {
NUM_CONNECTIONS_ACCEPTED_COUNTER
@@ -264,11 +297,15 @@ async fn handle_inner(
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
}
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
//
// Determine the destination and connection params
//
let headers = request.headers();
let conn_info = get_conn_info(headers, sni_hostname)?;
let conn_info = get_conn_info(&jwk_cache_pool, headers, &sni_hostname).await?;
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
@@ -697,6 +734,11 @@ fn _pg_array_parse(
Ok((Value::Array(entries), 0))
}
#[derive(Serialize, Deserialize)]
pub struct NeonFields {
policies: Vec<Policy>,
}
#[cfg(test)]
mod tests {
use super::*;