mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
proxy: start work on jwt auth
This commit is contained in:
91
Cargo.lock
generated
91
Cargo.lock
generated
@@ -274,7 +274,7 @@ dependencies = [
|
||||
"hex",
|
||||
"http",
|
||||
"hyper",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"time",
|
||||
"tokio",
|
||||
"tower",
|
||||
@@ -703,7 +703,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"dyn-clone",
|
||||
"futures",
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"http-types",
|
||||
"log",
|
||||
"paste",
|
||||
@@ -863,6 +863,22 @@ dependencies = [
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "biscuit"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e28fc7c56c61743a01d0d1b73e4fed68b8a4f032ea3a2d4bb8c6520a33fc05a"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"data-encoding",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ring 0.17.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@@ -945,11 +961,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.79"
|
||||
version = "1.0.83"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
|
||||
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1846,9 +1863,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.9"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
|
||||
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
@@ -2342,7 +2359,7 @@ checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378"
|
||||
dependencies = [
|
||||
"base64 0.21.1",
|
||||
"pem 1.1.1",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"simple_asn1",
|
||||
@@ -2382,9 +2399,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.144"
|
||||
version = "0.2.150"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1"
|
||||
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
@@ -2691,7 +2708,7 @@ checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"http",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
@@ -3474,6 +3491,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.13.1",
|
||||
"biscuit",
|
||||
"bstr",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -3619,7 +3637,7 @@ version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3660,7 +3678,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4954fbc00dcd4d8282c987710e50ba513d351400dbdd00e803a05172a90d8976"
|
||||
dependencies = [
|
||||
"pem 2.0.1",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"time",
|
||||
"yasna",
|
||||
]
|
||||
@@ -3830,7 +3848,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures",
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"http",
|
||||
"hyper",
|
||||
"parking_lot 0.11.2",
|
||||
@@ -3851,7 +3869,7 @@ checksum = "1b97ad83c2fc18113346b7158d79732242002427c30f620fa817c1f32901e0a8"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"matchit",
|
||||
"opentelemetry",
|
||||
"reqwest",
|
||||
@@ -3882,11 +3900,25 @@ dependencies = [
|
||||
"libc",
|
||||
"once_cell",
|
||||
"spin 0.5.2",
|
||||
"untrusted",
|
||||
"untrusted 0.7.1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"getrandom 0.2.11",
|
||||
"libc",
|
||||
"spin 0.9.8",
|
||||
"untrusted 0.9.0",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "routerify"
|
||||
version = "3.0.0"
|
||||
@@ -4003,7 +4035,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"rustls-webpki 0.101.4",
|
||||
"sct",
|
||||
]
|
||||
@@ -4035,8 +4067,8 @@ version = "0.100.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"ring 0.16.20",
|
||||
"untrusted 0.7.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4045,8 +4077,8 @@ version = "0.101.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"ring 0.16.20",
|
||||
"untrusted 0.7.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4191,8 +4223,8 @@ version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"ring 0.16.20",
|
||||
"untrusted 0.7.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4311,7 +4343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99dc599bd6646884fc403d593cdcb9816dd67c50cff3271c01ff123617908dcd"
|
||||
dependencies = [
|
||||
"debugid",
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"hex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4357,6 +4389,7 @@ version = "1.0.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
@@ -4960,7 +4993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"rustls",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
@@ -5416,6 +5449,12 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.7.1"
|
||||
@@ -5517,7 +5556,7 @@ version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2"
|
||||
dependencies = [
|
||||
"getrandom 0.2.9",
|
||||
"getrandom 0.2.11",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@@ -6010,7 +6049,7 @@ dependencies = [
|
||||
"regex",
|
||||
"regex-syntax 0.7.2",
|
||||
"reqwest",
|
||||
"ring",
|
||||
"ring 0.16.20",
|
||||
"rustls",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
|
||||
@@ -68,7 +68,7 @@ webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
|
||||
biscuit = { version = "0.7",features = [] }
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ mod hacks;
|
||||
mod link;
|
||||
|
||||
pub use link::LinkAuthError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
|
||||
@@ -319,4 +320,41 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
Test(x) => x.wake_compute().map(Some),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the password for the RLS user
|
||||
pub async fn ensure_row_level(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
dbname: String,
|
||||
username: String,
|
||||
policies: Vec<Policy>,
|
||||
) -> anyhow::Result<String> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => {
|
||||
api.ensure_row_level(extra, creds, dbname, username, policies)
|
||||
.await
|
||||
}
|
||||
Postgres(api, creds) => {
|
||||
api.ensure_row_level(extra, creds, dbname, username, policies)
|
||||
.await
|
||||
}
|
||||
Link(_) => Err(anyhow::anyhow!("not on link")),
|
||||
Test(_) => Err(anyhow::anyhow!("not on test")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(conrad): policies can be quite complex. Figure out how to configure this
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct Policy {
|
||||
table: String,
|
||||
column: String,
|
||||
}
|
||||
|
||||
// enum PolicyType {
|
||||
// ForSelect(),
|
||||
// ForUpdate()
|
||||
// }
|
||||
|
||||
@@ -3,7 +3,7 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
auth::{ClientCredentials, backend::Policy},
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
};
|
||||
@@ -248,6 +248,16 @@ pub trait Api {
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
|
||||
/// Get the password for the RLS user
|
||||
async fn ensure_row_level(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
dbname: String,
|
||||
username: String,
|
||||
policies: Vec<Policy>
|
||||
) -> anyhow::Result<String>;
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
|
||||
@@ -4,7 +4,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{
|
||||
auth::{backend::Policy, ClientCredentials},
|
||||
compute,
|
||||
error::io_error,
|
||||
scram,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
@@ -121,6 +127,18 @@ impl super::Api for Api {
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get the password for the RLS user
|
||||
async fn ensure_row_level(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
_creds: &ClientCredentials,
|
||||
_dbname: String,
|
||||
_username: String,
|
||||
_policies: Vec<Policy>,
|
||||
) -> anyhow::Result<String> {
|
||||
Err(anyhow::anyhow!("unimplemented"))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
|
||||
@@ -5,9 +5,13 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, http, scram};
|
||||
use crate::{
|
||||
auth::{backend::Policy, ClientCredentials},
|
||||
compute, http, scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
@@ -139,6 +143,55 @@ impl Api {
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_ensure_row_level(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
dbname: String,
|
||||
username: String,
|
||||
policies: Vec<Policy>,
|
||||
) -> anyhow::Result<String> {
|
||||
let project = creds.project().expect("impossible");
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.post("proxy_ensure_row_level")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(project)),
|
||||
("dbname", Some(&dbname)),
|
||||
("username", Some(&username)),
|
||||
("options", extra.options),
|
||||
])
|
||||
.json(&EnsureRowLevelReq { policies })
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = parse_body::<UserRowLevel>(response).await?;
|
||||
|
||||
Ok(body.password)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EnsureRowLevelReq {
|
||||
policies: Vec<Policy>,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct UserRowLevel {
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -188,6 +241,19 @@ impl super::Api for Api {
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
|
||||
/// Get the password for the RLS user
|
||||
async fn ensure_row_level(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
dbname: String,
|
||||
username: String,
|
||||
policies: Vec<Policy>,
|
||||
) -> anyhow::Result<String> {
|
||||
self.do_ensure_row_level(extra, creds, dbname, username, policies)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
|
||||
@@ -88,6 +88,14 @@ impl Endpoint {
|
||||
self.client.get(url.into_inner())
|
||||
}
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `POST` request,
|
||||
/// appending a single `path` segment to the base endpoint URL.
|
||||
pub fn post(&self, path: &str) -> RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push(path);
|
||||
self.client.post(url.into_inner())
|
||||
}
|
||||
|
||||
/// Execute a [request](reqwest::Request).
|
||||
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
|
||||
self.client.execute(request).await
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
//! Handles both SQL over HTTP and SQL over Websockets.
|
||||
|
||||
mod conn_pool;
|
||||
pub mod jwt_auth;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
use anyhow::bail;
|
||||
use dashmap::DashMap;
|
||||
use hyper::StatusCode;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
@@ -31,6 +33,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
use self::jwt_auth::JWKSetCaches;
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
@@ -41,6 +45,9 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(config);
|
||||
let jwk_cache_pool = Arc::new(JWKSetCaches {
|
||||
map: DashMap::new(),
|
||||
});
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
@@ -85,6 +92,7 @@ pub async fn task_main(
|
||||
let remote_addr = io.inner.remote_addr();
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let conn_pool = conn_pool.clone();
|
||||
let jwk_cache_pool = jwk_cache_pool.clone();
|
||||
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
@@ -96,13 +104,20 @@ pub async fn task_main(
|
||||
move |req: Request<Body>| {
|
||||
let sni_name = sni_name.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
let jwk_cache_pool = jwk_cache_pool.clone();
|
||||
|
||||
async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
request_handler(
|
||||
req, config, conn_pool, cancel_map, session_id, sni_name,
|
||||
req,
|
||||
config,
|
||||
conn_pool,
|
||||
jwk_cache_pool,
|
||||
cancel_map,
|
||||
session_id,
|
||||
sni_name,
|
||||
)
|
||||
.instrument(info_span!(
|
||||
"serverless",
|
||||
@@ -167,6 +182,7 @@ async fn request_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
conn_pool: Arc<conn_pool::GlobalConnPool>,
|
||||
jwk_cache_pool: Arc<JWKSetCaches>,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
sni_hostname: Option<String>,
|
||||
@@ -204,6 +220,7 @@ async fn request_handler(
|
||||
request,
|
||||
sni_hostname,
|
||||
conn_pool,
|
||||
jwk_cache_pool,
|
||||
session_id,
|
||||
&config.http_config,
|
||||
)
|
||||
@@ -214,7 +231,7 @@ async fn request_handler(
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
|
||||
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Authorization",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
|
||||
@@ -21,7 +21,8 @@ use tokio::time;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
|
||||
|
||||
use crate::{
|
||||
auth, console,
|
||||
auth::{self, backend::Policy},
|
||||
console,
|
||||
proxy::{
|
||||
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
|
||||
NUM_DB_CONNECTIONS_OPENED_COUNTER,
|
||||
@@ -45,6 +46,8 @@ pub struct ConnInfo {
|
||||
pub hostname: String,
|
||||
pub password: String,
|
||||
pub options: Option<String>,
|
||||
/// row level security mode enabled
|
||||
pub policies: Option<Vec<Policy>>,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
@@ -365,6 +368,7 @@ struct TokioMechanism<'a> {
|
||||
conn_info: &'a ConnInfo,
|
||||
session_id: uuid::Uuid,
|
||||
conn_id: uuid::Uuid,
|
||||
password: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -384,6 +388,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
timeout,
|
||||
self.conn_id,
|
||||
self.session_id,
|
||||
self.password.as_deref(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -431,11 +436,26 @@ async fn connect_to_compute(
|
||||
.await?
|
||||
.context("missing cache entry from wake_compute")?;
|
||||
|
||||
let mut password = None;
|
||||
if let Some(policies) = &conn_info.policies {
|
||||
password = Some(
|
||||
creds
|
||||
.ensure_row_level(
|
||||
&extra,
|
||||
conn_info.dbname.to_owned(),
|
||||
conn_info.username.to_owned(),
|
||||
policies.clone(),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
crate::proxy::connect_to_compute(
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
session_id,
|
||||
password,
|
||||
},
|
||||
node_info,
|
||||
&extra,
|
||||
@@ -451,12 +471,13 @@ async fn connect_to_compute_once(
|
||||
timeout: time::Duration,
|
||||
conn_id: uuid::Uuid,
|
||||
mut session: uuid::Uuid,
|
||||
password: Option<&str>,
|
||||
) -> Result<ClientInner, tokio_postgres::Error> {
|
||||
let mut config = (*node_info.config).clone();
|
||||
|
||||
let (client, mut connection) = config
|
||||
.user(&conn_info.username)
|
||||
.password(&conn_info.password)
|
||||
.password(password.unwrap_or(&conn_info.password))
|
||||
.dbname(&conn_info.dbname)
|
||||
.connect_timeout(timeout)
|
||||
.connect(tokio_postgres::NoTls)
|
||||
|
||||
98
proxy/src/serverless/jwt_auth.rs
Normal file
98
proxy/src/serverless/jwt_auth.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
// https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use biscuit::{
|
||||
jwk::{JWKSet, JWK},
|
||||
jws, CompactPart,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use reqwest::{IntoUrl, Url};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct JWKSetCaches {
|
||||
pub map: DashMap<Url, Arc<JWKSetCache>>,
|
||||
}
|
||||
|
||||
impl JWKSetCaches {
|
||||
pub async fn get_cache(&self, url: impl IntoUrl) -> anyhow::Result<Arc<JWKSetCache>> {
|
||||
let url = url.into_url()?;
|
||||
if let Some(x) = self.map.get(&url) {
|
||||
return Ok(x.clone());
|
||||
}
|
||||
let cache = JWKSetCache::new(url.clone()).await?;
|
||||
let cache = Arc::new(cache);
|
||||
self.map.insert(url, cache.clone());
|
||||
Ok(cache)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JWKSetCache {
|
||||
url: Url,
|
||||
current: RwLock<biscuit::jwk::JWKSet<()>>,
|
||||
}
|
||||
|
||||
impl JWKSetCache {
|
||||
pub async fn new(url: impl IntoUrl) -> anyhow::Result<Self> {
|
||||
let url = url.into_url()?;
|
||||
let current = reqwest::get(url.clone()).await?.json().await?;
|
||||
Ok(Self {
|
||||
url,
|
||||
current: RwLock::new(current),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get(&self, kid: &str) -> anyhow::Result<JWK<()>> {
|
||||
let current = self.current.read().await.clone();
|
||||
if let Some(key) = current.find(kid) {
|
||||
return Ok(key.clone());
|
||||
}
|
||||
let new = reqwest::get(self.url.clone()).await?.json().await?;
|
||||
if new == current {
|
||||
bail!("not found")
|
||||
}
|
||||
*self.current.write().await = new;
|
||||
current.find(kid).cloned().context("not found")
|
||||
}
|
||||
|
||||
pub async fn decode<T, H>(
|
||||
&self,
|
||||
token: &jws::Compact<T, H>,
|
||||
) -> anyhow::Result<jws::Compact<T, H>>
|
||||
where
|
||||
T: CompactPart,
|
||||
H: Serialize + DeserializeOwned,
|
||||
{
|
||||
let current = self.current.read().await.clone();
|
||||
match token.decode_with_jwks(¤t, None) {
|
||||
Ok(t) => Ok(t),
|
||||
Err(biscuit::errors::Error::ValidationError(
|
||||
biscuit::errors::ValidationError::KeyNotFound,
|
||||
)) => {
|
||||
let new: JWKSet<()> = reqwest::get(self.url.clone()).await?.json().await?;
|
||||
if new == current {
|
||||
bail!("not found")
|
||||
}
|
||||
*self.current.write().await = new.clone();
|
||||
token.decode_with_jwks(&new, None).context("error")
|
||||
// current.find(kid).cloned().context("not found")
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::JWKSetCache;
|
||||
#[tokio::test]
|
||||
async fn jwkset() {
|
||||
let cache =
|
||||
JWKSetCache::new("https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json")
|
||||
.await
|
||||
.unwrap();
|
||||
dbg!(cache.get("ins_2YFechxysnwZcZN6TDHEz6u6w6v").await.unwrap());
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use biscuit::JWT;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::header;
|
||||
use hyper::header::AUTHORIZATION;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
@@ -26,11 +31,13 @@ use url::Url;
|
||||
use utils::http::error::ApiError;
|
||||
use utils::http::json::json_response;
|
||||
|
||||
use crate::auth::backend::Policy;
|
||||
use crate::config::HttpConfig;
|
||||
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
use super::jwt_auth::JWKSetCaches;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct QueryData {
|
||||
@@ -118,9 +125,10 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
async fn get_conn_info(
|
||||
jwk_cache_pool: &JWKSetCaches,
|
||||
headers: &HeaderMap,
|
||||
sni_hostname: Option<String>,
|
||||
sni_hostname: &str,
|
||||
) -> Result<ConnInfo, anyhow::Error> {
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
@@ -144,18 +152,40 @@ fn get_conn_info(
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
|
||||
let username = connection_url.username();
|
||||
if username.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
}
|
||||
let mut password = "";
|
||||
let mut policies = None;
|
||||
let authorization = headers.get(AUTHORIZATION);
|
||||
let username = if let Some(auth) = authorization {
|
||||
// TODO: introduce control plane API to fetch this
|
||||
let jwks_url = match sni_hostname {
|
||||
"foo" => "https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json",
|
||||
_ => anyhow::bail!("invalid sni name"),
|
||||
};
|
||||
let jwk_cache = jwk_cache_pool.get_cache(jwks_url).await?;
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(anyhow::anyhow!("no password"))?;
|
||||
let auth = auth.to_str()?;
|
||||
let token = auth.strip_prefix("Bearer ").context("bad token")?;
|
||||
let jwt: JWT<NeonFields, ()> = JWT::new_encoded(token);
|
||||
let token = jwk_cache.decode(&jwt).await?;
|
||||
let payload = token.payload().unwrap();
|
||||
policies = Some(payload.private.policies.clone());
|
||||
payload
|
||||
.registered
|
||||
.subject
|
||||
.as_deref()
|
||||
.context("missing user id")?
|
||||
.to_owned()
|
||||
} else {
|
||||
password = connection_url
|
||||
.password()
|
||||
.ok_or(anyhow::anyhow!("no password"))?;
|
||||
|
||||
// TLS certificate selector now based on SNI hostname, so if we are running here
|
||||
// we are sure that SNI hostname is set to one of the configured domain names.
|
||||
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
|
||||
let u = connection_url.username();
|
||||
if u.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
}
|
||||
u.to_owned()
|
||||
};
|
||||
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
@@ -186,7 +216,8 @@ fn get_conn_info(
|
||||
}
|
||||
|
||||
Ok(ConnInfo {
|
||||
username: username.to_owned(),
|
||||
username,
|
||||
policies,
|
||||
dbname: dbname.to_owned(),
|
||||
hostname: hostname.to_owned(),
|
||||
password: password.to_owned(),
|
||||
@@ -199,12 +230,13 @@ pub async fn handle(
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
jwk_cache_pool: Arc<JWKSetCaches>,
|
||||
session_id: uuid::Uuid,
|
||||
config: &'static HttpConfig,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.sql_over_http_timeout,
|
||||
handle_inner(request, sni_hostname, conn_pool, session_id),
|
||||
handle_inner(request, sni_hostname, conn_pool, jwk_cache_pool, session_id),
|
||||
)
|
||||
.await;
|
||||
let mut response = match result {
|
||||
@@ -255,6 +287,7 @@ async fn handle_inner(
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
jwk_cache_pool: Arc<JWKSetCaches>,
|
||||
session_id: uuid::Uuid,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER
|
||||
@@ -264,11 +297,15 @@ async fn handle_inner(
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
|
||||
}
|
||||
|
||||
// TLS certificate selector now based on SNI hostname, so if we are running here
|
||||
// we are sure that SNI hostname is set to one of the configured domain names.
|
||||
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
|
||||
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
let conn_info = get_conn_info(headers, sni_hostname)?;
|
||||
let conn_info = get_conn_info(&jwk_cache_pool, headers, &sni_hostname).await?;
|
||||
|
||||
// Determine the output options. Default behaviour is 'false'. Anything that is not
|
||||
// strictly 'true' assumed to be false.
|
||||
@@ -697,6 +734,11 @@ fn _pg_array_parse(
|
||||
Ok((Value::Array(entries), 0))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct NeonFields {
|
||||
policies: Vec<Policy>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user