Compare commits

...

20 Commits

Author SHA1 Message Date
Conrad Ludgate
f3f7d0d3f1 zero-copy jwt claim validation 2024-09-30 12:47:07 +01:00
Conrad Ludgate
0724df1d3f stash 2024-09-29 20:29:26 +01:00
Conrad Ludgate
4d47049b00 split up jwt tests 2024-09-27 16:31:49 +01:00
Conrad Ludgate
5687384a8e remove deref impl 2024-09-27 11:43:34 +01:00
Conrad Ludgate
249f5ea17d cleaner local-proxy conn error code 2024-09-27 11:43:34 +01:00
Conrad Ludgate
6abcc1f298 add explicit panic reason 2024-09-27 11:43:34 +01:00
Conrad Ludgate
3e97cf0d6e refine missing credentials error 2024-09-27 11:43:34 +01:00
Conrad Ludgate
054ef4988b update certification comment 2024-09-27 11:43:34 +01:00
Conrad Ludgate
5202cd75b5 only forward expected headers 2024-09-27 11:43:34 +01:00
Conrad Ludgate
f475dac0e6 keepalive while idle 2024-09-27 11:43:34 +01:00
Conrad Ludgate
a4100373e5 fix common name parsing 2024-09-27 11:43:34 +01:00
Conrad Ludgate
040d8cf4f6 fix common name parsing 2024-09-27 11:43:34 +01:00
Conrad Ludgate
75bfd57e01 add authbroker cli flag and fix http2 ka 2024-09-27 11:43:34 +01:00
Conrad Ludgate
4bc2686dee small tweaks 2024-09-27 11:43:34 +01:00
Conrad Ludgate
8e7d2aab76 put it all together 2024-09-27 11:43:34 +01:00
Conrad Ludgate
2703abccc7 start on http2 local proxy connection pool 2024-09-27 11:43:34 +01:00
Conrad Ludgate
76515cdae3 split out auth info from conn info, return the jwt as the auth keys 2024-09-27 11:43:34 +01:00
Conrad Ludgate
08c7f933a3 add support for console backend jwt 2024-09-27 11:43:34 +01:00
Conrad Ludgate
4ad3aa7c96 update doc comment for get_with_url 2024-09-27 10:24:50 +01:00
Conrad Ludgate
9c59e3b4b9 proxy: add jwks endpoint to control plane and mock providers 2024-09-27 10:24:43 +01:00
19 changed files with 1645 additions and 270 deletions

View File

@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret( fn get_allowed_ips_and_secret(
&self, &self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>; ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
} }
impl std::fmt::Display for Backend<'_, (), ()> { impl std::fmt::Display for Backend<'_, (), ()> {
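The `dyn_clone` method added above is the usual way to make a boxed trait object cloneable when the trait itself cannot require `Clone`. A minimal, self-contained sketch of the pattern, not part of this diff (the `Widget`/`Circle` names are made up):

// Object-safe clone: the trait returns a boxed clone of the concrete type,
// and a `Clone` impl for the boxed trait object forwards to it.
trait Widget {
    fn dyn_clone(&self) -> Box<dyn Widget>;
}

impl Clone for Box<dyn Widget> {
    fn clone(&self) -> Self {
        // `&**self`: deref the Box down to `dyn Widget`, then take a reference.
        Widget::dyn_clone(&**self)
    }
}

#[derive(Clone)]
struct Circle;

impl Widget for Circle {
    fn dyn_clone(&self) -> Box<dyn Widget> {
        Box::new(self.clone())
    }
}

fn main() {
    let a: Box<dyn Widget> = Box::new(Circle);
    let _b = a.clone(); // resolves to the impl above
}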
@@ -557,7 +565,7 @@ mod tests {
stream::{PqStream, Stream}, stream::{PqStream, Stream},
}; };
use super::{auth_quirks, AuthRateLimiter}; use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
struct Auth { struct Auth {
ips: Vec<IpPattern>, ips: Vec<IpPattern>,
@@ -585,6 +593,14 @@ mod tests {
)) ))
} }
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}
async fn wake_compute( async fn wake_compute(
&self, &self,
_ctx: &RequestMonitoring, _ctx: &RequestMonitoring,
@@ -595,12 +611,15 @@ mod tests {
} }
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig { static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1), thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5), scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true, rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64, rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true, ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
}); });
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -1,4 +1,5 @@
use std::{ use std::{
borrow::Cow,
future::Future, future::Future,
sync::Arc, sync::Arc,
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use dashmap::DashMap; use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo; use jose_jwk::crypto::KeyInfo;
use serde::{Deserialize, Deserializer}; use serde::{
de::{DeserializeSeed, IgnoredAny, Visitor},
Deserializer,
};
use signature::Verifier; use signature::Verifier;
use tokio::time::Instant; use tokio::time::Instant;
@@ -33,6 +37,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send; ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
} }
#[derive(Debug, Clone)]
pub(crate) struct AuthRule { pub(crate) struct AuthRule {
pub(crate) id: String, pub(crate) id: String,
pub(crate) jwks_url: url::Url, pub(crate) jwks_url: url::Url,
@@ -303,35 +308,21 @@ impl JwkCacheEntryLock {
} }
key => bail!("unsupported key type {key:?}"), key => bail!("unsupported key type {key:?}"),
}; };
tracing::debug!("JWT signature valid");
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?; .context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
.context("Provided authentication token is not a valid JWT encoding")?;
tracing::debug!(?payload, "JWT signature valid with claims"); let validator = JwtValidator {
expected_audience,
current_time: SystemTime::now(),
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
};
match (expected_audience, payload.audience) { let payload = validator
// check the audience matches .deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
}
let now = SystemTime::now(); tracing::debug!(?payload, "JWT claims valid");
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
}
if let Some(nbf) = payload.not_before {
ensure!(
nbf < now + CLOCK_SKEW_LEEWAY,
"JWT token is not yet ready to use"
);
}
Ok(()) Ok(())
} }
@@ -419,37 +410,184 @@ struct JwtHeader<'a> {
key_id: Option<&'a str>, key_id: Option<&'a str>,
} }
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1> struct JwtValidator<'a> {
#[derive(serde::Deserialize, serde::Serialize, Debug)] expected_audience: Option<&'a str>,
struct JwtPayload<'a> { current_time: SystemTime,
/// Audience - Recipient for which the JWT is intended clock_skew_leeway: Duration,
#[serde(rename = "aud")]
audience: Option<&'a str>,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
/// Issuer of the JWT
#[serde(rename = "iss")]
issuer: Option<&'a str>,
/// Subject of the JWT (the user)
#[serde(rename = "sub")]
subject: Option<&'a str>,
/// Unique token identifier
#[serde(rename = "jti")]
jwt_id: Option<&'a str>,
/// Unique session identifier
#[serde(rename = "sid")]
session_id: Option<&'a str>,
} }
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> { impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
let d = <Option<u64>>::deserialize(d)?; type Value = JwtPayload<'de>;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for JwtValidator<'_> {
type Value = JwtPayload<'de>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a JWT payload")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut payload = JwtPayload {
issuer: None,
subject: None,
jwt_id: None,
session_id: None,
};
let mut aud = false;
while let Some(key) = map.next_key()? {
match key {
"iss" if payload.issuer.is_none() => {
payload.issuer = Some(map.next_value()?);
}
"sub" if payload.subject.is_none() => {
payload.subject = Some(map.next_value()?);
}
"jit" if payload.jwt_id.is_none() => {
payload.jwt_id = Some(map.next_value()?);
}
"sid" if payload.session_id.is_none() => {
payload.session_id = Some(map.next_value()?);
}
"exp" => {
let exp = map.next_value::<u64>()?;
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
if self.current_time > exp + self.clock_skew_leeway {
return Err(serde::de::Error::custom("JWT token has expired"));
}
}
"nbf" => {
let nbf = map.next_value::<u64>()?;
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
if self.current_time + self.clock_skew_leeway < nbf {
return Err(serde::de::Error::custom(
"JWT token is not yet ready to use",
));
}
}
"aud" => {
if let Some(expected_audience) = self.expected_audience {
map.next_value_seed(AudienceValidator { expected_audience })?;
aud = true;
} else {
map.next_value::<IgnoredAny>()?;
}
}
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
}
}
if self.expected_audience.is_some() && !aud {
return Err(serde::de::Error::custom("invalid JWT token audience"));
}
Ok(payload)
}
}
deserializer.deserialize_map(self)
}
}
struct AudienceValidator<'a> {
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
type Value = ();
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for AudienceValidator<'_> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single string or an array of strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if self.expected_audience == v {
Ok(())
} else {
Err(E::custom("invalid JWT token audience"))
}
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
while let Some(v) = seq.next_element_seed(SingleAudienceValidator {
expected_audience: self.expected_audience,
})? {
if v {
return Ok(());
}
}
Err(serde::de::Error::custom("invalid JWT token audience"))
}
}
deserializer.deserialize_any(self)
}
}
struct SingleAudienceValidator<'a> {
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
type Value = bool;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
type Value = bool;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single audience string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(self.expected_audience == v)
}
}
deserializer.deserialize_any(self)
}
}
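For reference, RFC 7519 allows the `aud` claim to be either a single string or an array of strings, which is why `AudienceValidator` implements both `visit_str` and `visit_seq`. The sketch below shows the same accept-either-shape rule in a simpler, allocating form via `serde_json::Value`; it is an illustration only and not how the diff implements it (the audience values are examples):

// Both `aud` shapes accepted: "neon" and ["audience1", "neon", "audience2"].
fn audience_matches(aud: &serde_json::Value, expected: &str) -> bool {
    match aud {
        serde_json::Value::String(s) => s.as_str() == expected,
        serde_json::Value::Array(items) => items.iter().any(|v| v.as_str() == Some(expected)),
        _ => false,
    }
}

fn main() {
    assert!(audience_matches(&serde_json::json!("neon"), "neon"));
    assert!(audience_matches(&serde_json::json!(["audience1", "neon", "audience2"]), "neon"));
    assert!(!audience_matches(&serde_json::json!(["foo", "bar"]), "neon"));
}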
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
// the following entries are only extracted for the sake of debug logging.
#[derive(Debug)]
#[allow(dead_code)]
struct JwtPayload<'a> {
/// Issuer of the JWT
issuer: Option<Cow<'a, str>>,
/// Subject of the JWT (the user)
subject: Option<Cow<'a, str>>,
/// Unique token identifier
jwt_id: Option<Cow<'a, str>>,
/// Unique session identifier
session_id: Option<Cow<'a, str>>,
} }
struct JwkRenewalPermit<'a> { struct JwkRenewalPermit<'a> {
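To see why the `DeserializeSeed` rewrite above is "zero-copy": the claims are checked while the payload is being deserialized, and string claims can be borrowed straight from the decoded payload buffer instead of being allocated. Below is a minimal, self-contained sketch of the same technique, not part of this diff; `ExpirySeed` and the claim set are illustrative, and unlike the real validator it borrows `&str` directly rather than using `Cow<str>`, so it would reject strings containing escape sequences.

// Validate-while-deserializing with serde's DeserializeSeed.
use serde::de::{DeserializeSeed, Deserializer, Error, IgnoredAny, MapAccess, Visitor};
use std::fmt;

struct ExpirySeed {
    now: u64, // current unix time in seconds
}

#[derive(Debug)]
struct Claims<'de> {
    subject: Option<&'de str>, // borrowed from the input buffer: zero-copy
}

impl<'de> DeserializeSeed<'de> for ExpirySeed {
    type Value = Claims<'de>;

    fn deserialize<D: Deserializer<'de>>(self, d: D) -> Result<Self::Value, D::Error> {
        d.deserialize_map(self)
    }
}

impl<'de> Visitor<'de> for ExpirySeed {
    type Value = Claims<'de>;

    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("a JWT payload object")
    }

    fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
        let mut claims = Claims { subject: None };
        while let Some(key) = map.next_key::<&str>()? {
            match key {
                // Borrowed straight out of the payload slice, no allocation.
                "sub" => claims.subject = Some(map.next_value()?),
                // Validated on the fly, never stored.
                "exp" => {
                    let exp: u64 = map.next_value()?;
                    if exp < self.now {
                        return Err(A::Error::custom("token has expired"));
                    }
                }
                // Unknown claims are skipped without building any value.
                _ => {
                    map.next_value::<IgnoredAny>()?;
                }
            }
        }
        Ok(claims)
    }
}

fn main() {
    let payload = br#"{"sub":"user1","exp":4102444800,"extra":[1,2,3]}"#;
    let claims = ExpirySeed { now: 1_700_000_000 }
        .deserialize(&mut serde_json::Deserializer::from_slice(payload))
        .unwrap();
    assert_eq!(claims.subject, Some("user1"));
}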
@@ -530,6 +668,8 @@ mod tests {
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey; use rsa::pkcs8::DecodePrivateKey;
use serde::Serialize;
use serde_json::json;
use signature::Signer; use signature::Signer;
use tokio::net::TcpListener; use tokio::net::TcpListener;
@@ -562,18 +702,36 @@ mod tests {
} }
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let body = typed_json::json! {{
"exp": now + 3600,
"nbf": now,
"aud": ["audience1", "neon", "audience2"],
"sub": "user1",
"sid": "session1",
"jti": "token1",
"iss": "neon-testing",
}};
build_custom_jwt_payload(kid, body, sig)
}
fn build_custom_jwt_payload(
kid: String,
body: impl Serialize,
sig: jose_jwa::Signing,
) -> String {
let header = JwtHeader { let header = JwtHeader {
typ: "JWT", typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig), algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid), key_id: Some(&kid),
}; };
let body = typed_json::json! {{
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
}};
let header = let header =
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD); let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
format!("{header}.{body}") format!("{header}.{body}")
} }
@@ -588,6 +746,16 @@ mod tests {
format!("{payload}.{sig}") format!("{payload}.{sig}")
} }
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
use p256::ecdsa::{Signature, SigningKey};
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String { fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
use rsa::pkcs1v15::SigningKey; use rsa::pkcs1v15::SigningKey;
use rsa::signature::SignatureEncoding; use rsa::signature::SignatureEncoding;
@@ -659,37 +827,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
-----END PRIVATE KEY----- -----END PRIVATE KEY-----
"; ";
#[tokio::test] #[derive(Clone)]
async fn renew() { struct Fetch(Vec<AuthRule>);
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
let foo_jwks = jose_jwk::JwkSet { impl FetchAuthRules for Fetch {
keys: vec![jwk1, jwk3], async fn fetch_auth_rules(
}; &self,
let bar_jwks = jose_jwk::JwkSet { _ctx: &RequestMonitoring,
keys: vec![jwk2, jwk4], _endpoint: EndpointId,
}; ) -> anyhow::Result<Vec<AuthRule>> {
Ok(self.0.clone())
}
}
async fn jwks_server(
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
) -> SocketAddr {
let router = Arc::new(router);
let service = service_fn(move |req| { let service = service_fn(move |req| {
let foo_jwks = foo_jwks.clone(); let router = Arc::clone(&router);
let bar_jwks = bar_jwks.clone();
async move { async move {
let jwks = match req.uri().path() { match router(req.uri().path()) {
"/foo" => &foo_jwks, Some(body) => Response::builder()
"/bar" => &bar_jwks, .status(200)
_ => { .body(Full::new(Bytes::from(body))),
return Response::builder() None => Response::builder()
.status(404) .status(404)
.body(Full::new(Bytes::new())); .body(Full::new(Bytes::new())),
} }
};
let body = serde_json::to_vec(jwks).unwrap();
Response::builder()
.status(200)
.body(Full::new(Bytes::from(body)))
} }
}); });
@@ -704,84 +869,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
} }
}); });
let client = reqwest::Client::new(); addr
}
#[derive(Clone)] #[tokio::test]
struct Fetch(SocketAddr, Vec<RoleNameInt>); async fn check_jwt_happy_path() {
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
let (ec1, jwk3) = new_ec_jwk("ec1".into());
let (ec2, jwk4) = new_ec_jwk("ec2".into());
impl FetchAuthRules for Fetch { let foo_jwks = jose_jwk::JwkSet {
async fn fetch_auth_rules( keys: vec![jwk1, jwk3],
&self, };
_ctx: &RequestMonitoring, let bar_jwks = jose_jwk::JwkSet {
_endpoint: EndpointId, keys: vec![jwk2, jwk4],
) -> anyhow::Result<Vec<AuthRule>> { };
Ok(vec![
AuthRule { let jwks_addr = jwks_server(move |path| match path {
id: "foo".to_owned(), "/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(), "/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
audience: None, _ => None,
role_names: self.1.clone(), })
}, .await;
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
])
}
}
let role_name1 = RoleName::from("anonymous"); let role_name1 = RoleName::from("anonymous");
let role_name2 = RoleName::from("authenticated"); let role_name2 = RoleName::from("authenticated");
let fetch = Fetch( let roles = vec![
addr, RoleNameInt::from(&role_name1),
vec![ RoleNameInt::from(&role_name2),
RoleNameInt::from(&role_name1), ];
RoleNameInt::from(&role_name2), let rules = vec![
], AuthRule {
); id: "foo".to_owned(),
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let endpoint = EndpointId::from("ep"); let endpoint = EndpointId::from("ep");
let jwk_cache = Arc::new(JwkCacheEntryLock::default()); let jwt1 = new_rsa_jwt("rs1".into(), rs1);
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
let jwt1 = new_rsa_jwt("1".into(), rs1); let jwt3 = new_ec_jwt("ec1".into(), &ec1);
let jwt2 = new_rsa_jwt("2".into(), rs2); let jwt4 = new_ec_jwt("ec2".into(), &ec2);
let jwt3 = new_ec_jwt("3".into(), &ec1);
let jwt4 = new_ec_jwt("4".into(), &ec2);
// had the wrong kid, therefore will have the wrong ecdsa signature
let bad_jwt = new_ec_jwt("3".into(), &ec2);
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&bad_jwt,
&client,
endpoint.clone(),
&role_name1,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("signature error"));
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&jwt1,
&client,
endpoint.clone(),
&bad_role_name,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("jwk not found"));
let tokens = [jwt1, jwt2, jwt3, jwt4]; let tokens = [jwt1, jwt2, jwt3, jwt4];
let role_names = [role_name1, role_name2]; let role_names = [role_name1, role_name2];
@@ -790,15 +932,194 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
jwk_cache jwk_cache
.check_jwt( .check_jwt(
&RequestMonitoring::test(), &RequestMonitoring::test(),
token,
&client,
endpoint.clone(), endpoint.clone(),
role, role,
&fetch, &fetch,
token,
) )
.await .await
.unwrap(); .unwrap();
} }
} }
} }
#[tokio::test]
async fn check_jwt_invalid_signature() {
let (_, jwk) = new_ec_jwk("1".into());
let (key, _) = new_ec_jwk("1".into());
// has a matching kid, but signed by the wrong key
let bad_jwt = new_ec_jwt("1".into(), &key);
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: None,
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
let ctx = RequestMonitoring::test();
let err = jwk_cache
.check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
.await
.unwrap_err();
assert!(
err.to_string().contains("signature error"),
"expected \"signature error\", got {err:?}"
);
}
#[tokio::test]
async fn check_jwt_unknown_role() {
let (key, jwk) = new_rsa_jwk(RS1, "1".into());
let jwt = new_rsa_jwt("1".into(), key);
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: None,
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let ctx = RequestMonitoring::test();
let err = jwk_cache
.check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
.await
.unwrap_err();
assert!(
err.to_string().contains("jwk not found"),
"expected \"jwk not found\", got {err:?}"
);
}
#[tokio::test]
async fn check_jwt_invalid_claims() {
let (key, jwk) = new_ec_jwk("1".into());
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
struct Test {
body: serde_json::Value,
error: &'static str,
}
let table = vec![
Test {
body: json! {{
"nbf": now + 60,
"aud": "neon",
}},
error: "JWT token is not yet ready to use",
},
Test {
body: json! {{
"exp": now - 60,
"aud": ["neon"],
}},
error: "JWT token has expired",
},
Test {
body: json! {{
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": [],
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": "foo",
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": ["foo"],
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": ["foo", "bar"],
}},
error: "invalid JWT token audience",
},
];
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: Some("neon".to_string()),
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
let ctx = RequestMonitoring::test();
for test in table {
let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
match jwk_cache
.check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
.await
{
Err(err) if err.to_string().contains(test.error) => {}
Err(err) => {
panic!("expected {:?}, got {err:?}", test.error)
}
Ok(()) => {
panic!("expected {:?}, got ok", test.error)
}
}
}
}
} }

View File

@@ -14,17 +14,15 @@ use crate::{
EndpointId, EndpointId,
}; };
use super::jwt::{AuthRule, FetchAuthRules, JwkCache}; use super::jwt::{AuthRule, FetchAuthRules};
pub struct LocalBackend { pub struct LocalBackend {
pub(crate) jwks_cache: JwkCache,
pub(crate) node_info: NodeInfo, pub(crate) node_info: NodeInfo,
} }
impl LocalBackend { impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self { pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend { LocalBackend {
jwks_cache: JwkCache::default(),
node_info: NodeInfo { node_info: NodeInfo {
config: { config: {
let mut cfg = ConnCfg::new(); let mut cfg = ConnCfg::new();

View File

@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap; use dashmap::DashMap;
use futures::future::Either; use futures::future::Either;
use proxy::{ use proxy::{
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP}, auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
cancellation::CancellationHandlerMain, cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{ console::{
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
allow_self_signed_compute: false, allow_self_signed_compute: false,
http_config, http_config,
authentication_config: AuthenticationConfig { authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0), thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10), scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false, rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]), rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64, rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true, ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
}, },
require_client_ip: false, require_client_ip: false,
handshake_timeout: Duration::from_secs(10), handshake_timeout: Duration::from_secs(10),

View File

@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region; use aws_config::Region;
use futures::future::Either; use futures::future::Either;
use proxy::auth; use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter; use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned; use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap; use proxy::cancellation::CancelMap;
@@ -102,6 +103,9 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/" default_value = "http://localhost:3000/authenticate_proxy_request/"
)] )]
auth_endpoint: String, auth_endpoint: String,
/// If this is not the local proxy, this toggles whether we accept JWTs or passwords for HTTP.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections /// path to TLS key for client postgres connections
/// ///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir /// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -382,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}"); info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?; let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_address: SocketAddr = args.proxy.parse()?; let proxy_listener = if !args.is_auth_broker {
info!("Starting proxy on {proxy_address}"); let proxy_address: SocketAddr = args.proxy.parse()?;
let proxy_listener = TcpListener::bind(proxy_address).await?; info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default(); let cancel_map = CancelMap::default();
@@ -430,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation // client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(()) // cancellation returns Ok(())
let mut client_tasks = JoinSet::new(); let mut client_tasks = JoinSet::new();
client_tasks.spawn(proxy::proxy::task_main( if let Some(proxy_listener) = proxy_listener {
config, client_tasks.spawn(proxy::proxy::task_main(
proxy_listener, config,
cancellation_token.clone(), proxy_listener,
cancellation_handler.clone(), cancellation_token.clone(),
endpoint_rate_limiter.clone(), cancellation_handler.clone(),
)); endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless. }
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main( client_tasks.spawn(serverless::task_main(
config, config,
serverless_listener, serverless_listener,
@@ -674,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?; )?;
let http_config = HttpConfig { let http_config = HttpConfig {
accept_websockets: true, accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions { pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -689,12 +707,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
}; };
let authentication_config = AuthenticationConfig { let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool, thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout, scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled, rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy, ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
}; };
let config = Box::leak(Box::new(ProxyConfig { let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -1,5 +1,8 @@
use crate::{ use crate::{
auth::{self, backend::AuthRateLimiter}, auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
console::locks::ApiLocks, console::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool, scram::threadpool::ThreadPool,
@@ -67,6 +70,9 @@ pub struct AuthenticationConfig {
pub rate_limiter: AuthRateLimiter, pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8, pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool, pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
} }
impl TlsConfig { impl TlsConfig {
@@ -250,18 +256,26 @@ impl CertResolver {
let common_name = pem.subject().to_string(); let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context("Failed to parse common name from certificate")?;
// We need to get the canonical name for this certificate so we can match it against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// The auth broker has the subdomain `apiauth`, which we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console web proxy does not use any wildcard domains and does not need any certificate selection or
// conn string validation, so we can continue with any common name.
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key)); let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
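A quick standalone check of the new prefix-stripping order above; the domain names are made up, and unlike the real code this sketch returns None rather than bailing with an error:

fn canonical_name(common_name: &str) -> Option<String> {
    if let Some(s) = common_name.strip_prefix("CN=*.") {
        Some(s.to_string())
    } else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
        Some(s.to_string())
    } else if let Some(s) = common_name.strip_prefix("CN=") {
        Some(s.to_string())
    } else {
        None
    }
}

fn main() {
    // scram-proxy: wildcard cert, endpoint subdomain comes from SNI.
    assert_eq!(canonical_name("CN=*.db.example.com").as_deref(), Some("db.example.com"));
    // auth-broker: the `apiauth.` subdomain is stripped as well.
    assert_eq!(canonical_name("CN=apiauth.db.example.com").as_deref(), Some("db.example.com"));
    // console web proxy: plain common name, only `CN=` is removed.
    assert_eq!(canonical_name("CN=web.example.com").as_deref(), Some("web.example.com"));
    assert_eq!(canonical_name("not a distinguished name"), None);
}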

View File

@@ -5,7 +5,10 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo}; use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{ use crate::{
auth::{ auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo}, backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern, IpPattern,
}, },
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
@@ -16,7 +19,7 @@ use crate::{
intern::ProjectIdInt, intern::ProjectIdInt,
metrics::ApiLockMetrics, metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}, rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, scram, EndpointCacheKey, EndpointId,
}; };
use dashmap::DashMap; use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration}; use std::{hash::Hash, sync::Arc, time::Duration};
@@ -334,6 +337,12 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo, user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>; ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;
/// Wake up the compute node and return the corresponding connection info. /// Wake up the compute node and return the corresponding connection info.
async fn wake_compute( async fn wake_compute(
&self, &self,
@@ -343,6 +352,7 @@ pub(crate) trait Api {
} }
#[non_exhaustive] #[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend { pub enum ConsoleBackend {
/// Current Cloud API (V2). /// Current Cloud API (V2).
Console(neon::Api), Console(neon::Api),
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
} }
} }
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}
async fn wake_compute( async fn wake_compute(
&self, &self,
ctx: &RequestMonitoring, ctx: &RequestMonitoring,
@@ -552,3 +576,13 @@ impl WakeComputePermit {
res res
} }
} }
impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}

View File

@@ -4,7 +4,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError}, errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
}; };
use crate::context::RequestMonitoring; use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached}; use crate::{auth::IpPattern, cache::Cached};
use crate::{ use crate::{
@@ -118,6 +120,39 @@ impl Api {
}) })
} }
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}
drop(client);
connection.await??;
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> { async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new(); let mut config = compute::ConnCfg::new();
config config
@@ -185,6 +220,14 @@ impl super::Api for Api {
)) ))
} }
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn wake_compute( async fn wake_compute(
&self, &self,

View File

@@ -7,27 +7,33 @@ use super::{
NodeInfo, NodeInfo,
}; };
use crate::{ use crate::{
auth::backend::ComputeUserInfo, auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute, compute,
console::messages::{ColdStartInfo, Reason}, console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
http, http,
metrics::{CacheOutcome, Metrics}, metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter, rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey, scram, EndpointCacheKey, EndpointId,
}; };
use crate::{cache::Cached, context::RequestMonitoring}; use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt; use futures::TryFutureExt;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_postgres::config::SslMode; use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct Api { pub struct Api {
endpoint: http::Endpoint, endpoint: http::Endpoint,
pub caches: &'static ApiCaches, pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>, pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>, pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
jwt: String, // put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
} }
impl Api { impl Api {
@@ -38,7 +44,9 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>, locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>, wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self { ) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default(); let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
Self { Self {
endpoint, endpoint,
caches, caches,
@@ -71,9 +79,9 @@ impl Api {
async { async {
let request = self let request = self
.endpoint .endpoint
.get("proxy_get_role_secret") .get_path("proxy_get_role_secret")
.header("X-Request-ID", &request_id) .header(X_REQUEST_ID, &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt)) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())]) .query(&[("session_id", ctx.session_id())])
.query(&[ .query(&[
("application_name", application_name.as_str()), ("application_name", application_name.as_str()),
@@ -125,6 +133,61 @@ impl Api {
.await .await
} }
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
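For illustration, the `get_with_url` closure above builds a URL of the following shape. This sketch uses the `url` crate directly rather than the crate's ApiUrl wrapper, and the control-plane base URL and endpoint id are made-up values, not taken from the diff:

fn main() {
    let mut url = url::Url::parse("https://console.example.com/api/v2").unwrap();
    url.path_segments_mut()
        .unwrap() // only fails for cannot-be-a-base URLs
        .push("endpoints")
        .push("ep-cool-darkness-123456")
        .push("jwks");
    assert_eq!(
        url.as_str(),
        "https://console.example.com/api/v2/endpoints/ep-cool-darkness-123456/jwks"
    );
}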
async fn do_wake_compute( async fn do_wake_compute(
&self, &self,
ctx: &RequestMonitoring, ctx: &RequestMonitoring,
@@ -135,7 +198,7 @@ impl Api {
async { async {
let mut request_builder = self let mut request_builder = self
.endpoint .endpoint
.get("proxy_wake_compute") .get_path("proxy_wake_compute")
.header("X-Request-ID", &request_id) .header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt)) .header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())]) .query(&[("session_id", ctx.session_id())])
@@ -262,6 +325,15 @@ impl super::Api for Api {
)) ))
} }
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn wake_compute( async fn wake_compute(
&self, &self,

View File

@@ -86,9 +86,17 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request, /// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL. /// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get(&self, path: &str) -> RequestBuilder { pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the URL path segments for more complex paths and queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
let mut url = self.endpoint.clone(); let mut url = self.endpoint.clone();
url.path_segments_mut().push(path); f(&mut url);
self.client.get(url.into_inner()) self.client.get(url.into_inner())
} }
@@ -144,7 +152,7 @@ mod tests {
// Validate that this pattern makes sense. // Validate that this pattern makes sense.
let req = endpoint let req = endpoint
.get("frobnicate") .get_path("frobnicate")
.query(&[ .query(&[
("foo", Some("10")), // should be just `foo=10` ("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all ("bar", None), // shouldn't be passed at all
@@ -162,7 +170,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new()); let endpoint = Endpoint::new(url, Client::new());
let req = endpoint let req = endpoint
.get("frobnicate") .get_path("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())]) .query(&[("session_id", uuid::Uuid::nil())])
.build()?; .build()?;

View File

@@ -1,5 +1,6 @@
use std::{ use std::{
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock, any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
sync::OnceLock,
}; };
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
@@ -16,12 +17,21 @@ pub struct StringInterner<Id> {
_id: PhantomData<Id>, _id: PhantomData<Id>,
} }
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)] #[derive(PartialEq, Clone, Copy, Eq, Hash)]
pub struct InternedString<Id> { pub struct InternedString<Id> {
inner: Spur, inner: Spur,
_id: PhantomData<Id>, _id: PhantomData<Id>,
} }
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("InternedString")
.field(&type_name::<Id>())
.field(&self.as_str())
.finish()
}
}
impl<Id: InternId> std::fmt::Display for InternedString<Id> { impl<Id: InternId> std::fmt::Display for InternedString<Id> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.as_str().fmt(f) self.as_str().fmt(f)

View File

@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
{ {
unimplemented!("not used in tests") unimplemented!("not used in tests")
} }
fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
} }
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {

View File

@@ -5,6 +5,7 @@
mod backend; mod backend;
pub mod cancel_set; pub mod cancel_set;
mod conn_pool; mod conn_pool;
mod http_conn_pool;
mod http_util; mod http_util;
mod json; mod json;
mod sql_over_http; mod sql_over_http;
@@ -19,7 +20,8 @@ use anyhow::Context;
use futures::future::{select, Either}; use futures::future::{select, Either};
use futures::TryFutureExt; use futures::TryFutureExt;
use http::{Method, Response, StatusCode}; use http::{Method, Response, StatusCode};
use http_body_util::Full; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper1::body::Incoming; use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder; use hyper_util::server::conn::auto::Builder;
@@ -81,7 +83,28 @@ pub async fn task_main(
} }
}); });
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend { let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool), pool: Arc::clone(&conn_pool),
config, config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -342,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets // used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken, http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> { ) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let host = request let host = request
.headers() .headers()
.get("host") .get("host")
@@ -386,7 +409,7 @@ async fn request_handler(
); );
// Return the response so the spawned future can continue. // Return the response so the spawned future can continue.
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new()))) Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST { } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new( let ctx = RequestMonitoring::new(
session_id, session_id,
@@ -409,7 +432,7 @@ async fn request_handler(
) )
.header("Access-Control-Max-Age", "86400" /* 24 hours */) .header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new())) .body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into())) .map_err(|e| ApiError::InternalServerError(e.into()))
} else { } else {
json_response(StatusCode::BAD_REQUEST, "query is not supported") json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,6 +1,8 @@
use std::{sync::Arc, time::Duration}; use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait; use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info}; use tracing::{field::display, info};
use crate::{ use crate::{
@@ -27,9 +29,13 @@ use crate::{
Host, Host,
}; };
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
pub(crate) struct PoolingBackend { pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>, pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig, pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>, pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -103,32 +109,44 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt( pub(crate) async fn authenticate_with_jwt(
&self, &self,
ctx: &RequestMonitoring, ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo, user_info: &ComputeUserInfo,
jwt: &str, jwt: String,
) -> Result<ComputeCredentials, AuthError> { ) -> Result<(), AuthError> {
match &self.config.auth_backend { match &self.config.auth_backend {
crate::auth::Backend::Console(_, ()) => { crate::auth::Backend::Console(console, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported")) config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(())
} }
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed( crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported", "JWT login over web auth proxy is not supported",
)), )),
crate::auth::Backend::Local(cache) => { crate::auth::Backend::Local(_) => {
cache config
.jwks_cache .jwks_cache
.check_jwt( .check_jwt(
ctx, ctx,
user_info.endpoint.clone(), user_info.endpoint.clone(),
&user_info.user, &user_info.user,
&StaticAuthRules, &StaticAuthRules,
jwt, &jwt,
) )
.await .await
.map_err(|e| AuthError::auth_failed(e.to_string()))?; .map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(), // todo: rewrite JWT signature with key shared somehow between local proxy and postgres
keys: crate::auth::backend::ComputeCredentialKeys::None, Ok(())
})
} }
} }
} }
@@ -174,14 +192,55 @@ impl PoolingBackend {
) )
.await .await
} }
// Wake up the destination if needed
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError { pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")] #[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>), ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")] #[error("could not connection to postgres in compute")]
ConnectionError(#[from] tokio_postgres::Error), PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not get auth info")] #[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError), GetAuthInfo(#[from] GetAuthInfoError),
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError), TooManyConnectionAttempts(#[from] ApiLockError),
} }
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError { impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind { fn get_error_kind(&self) -> ErrorKind {
match self { match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(), HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(), HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(), HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(), HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String { fn to_string_client(&self) -> String {
match self { match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(), HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(), HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(), HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError { impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool { fn could_retry(&self) -> bool {
match self { match self {
HttpConnError::ConnectionError(e) => e.could_retry(), HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false, HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false, HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false, HttpConnError::AuthError(_) => false,
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
impl ShouldRetryWakeCompute for HttpConnError { impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool { fn should_retry_wake_compute(&self) -> bool {
match self { match self {
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(), HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity // we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false, HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true, _ => true,
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
} }
} }
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
}
impl UserFacingError for LocalProxyConnError {
fn to_string_client(&self) -> String {
"Could not establish HTTP connection to the database".to_string()
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism { struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>, pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo, conn_info: ConnInfo,
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
} }
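/// Connect mechanism that speaks HTTP/2 to the local-proxy running inside the compute,
/// rather than opening a postgres connection directly.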
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// TODO: use the port from node_info.config.get_ports() instead of hard-coding 10432 below.
let res = connect_http2(&host, 10432, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
// assumption: host is an ip address, so `lookup_host` should not perform any DNS queries.
// todo: add that assumption as a guarantee in the control-plane API.
let mut addrs = lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?;
let mut last_err = None;
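// Try each resolved address in turn, keeping the most recent error so it can be
// reported if every attempt fails.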
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
};
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}

View File

@@ -0,0 +1,335 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
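// `Send` is the handle used to issue requests over an h2 connection; `Connect` is the
// connection driver, which must be awaited (see `poll_http2_client`) for requests to
// make progress.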
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool.
// The number of open connections is limited by `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
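// Rotate the deque: pop the front entry, push a clone of it to the back, and hand the
// popped handle out. h2 request handles are cheap to clone and multiplex over the same
// underlying connection, so one connection can serve many callers.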
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
let conn = conns.pop_front()?;
conns.push_back(conn.clone());
Some(conn)
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
let Self {
conns,
global_connections_count,
..
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
removed > 0
}
}
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.conns.len() as i64);
}
}
}
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// This is expected to be a fairly contended map, so return a reference to the
// per-endpoint pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
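// Reclaim one randomly chosen shard per tick; scaling the interval by the shard count
// means each shard is visited roughly once per gc epoch on average.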
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
if client.conn.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return None;
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, client.aux))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
// let idle = global_pool.get_idle_timeout();
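// Drive the h2 connection to completion in the background; once it terminates (cleanly
// or with an error), evict the matching entry from the endpoint pool.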
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, aux)
}
pub(crate) struct Client {
pub(crate) inner: Send,
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context; use anyhow::Context;
use http::{Response, StatusCode}; use http::{Response, StatusCode};
use http_body_util::Full; use http_body_util::{combinators::BoxBody, BodyExt, Full};
use serde::Serialize; use serde::Serialize;
use utils::http::error::ApiError; use utils::http::error::ApiError;
/// Like [`ApiError::into_response`] /// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> { pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
match this { match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,17 +64,24 @@ struct HttpErrorBody {
impl HttpErrorBody { impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`] /// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> { fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
HttpErrorBody { msg }.to_response(status) HttpErrorBody { msg }.to_response(status)
} }
/// Same as [`utils::http::error::HttpErrorBody::to_response`] /// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> { fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
Response::builder() Response::builder()
.status(status) .status(status)
.header(http::header::CONTENT_TYPE, "application/json") .header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail // we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap()))) .body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.unwrap() .unwrap()
} }
} }
@@ -83,14 +90,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>( pub(crate) fn json_response<T: Serialize>(
status: StatusCode, status: StatusCode,
data: T, data: T,
) -> Result<Response<Full<Bytes>>, ApiError> { ) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let json = serde_json::to_string(&data) let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response") .context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?; .map_err(ApiError::InternalServerError)?;
let response = Response::builder() let response = Response::builder()
.status(status) .status(status)
.header(http::header::CONTENT_TYPE, "application/json") .header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json))) .body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))?; .map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response) Ok(response)
} }

View File

@@ -8,6 +8,8 @@ use futures::future::Either;
use futures::StreamExt; use futures::StreamExt;
use futures::TryFutureExt; use futures::TryFutureExt;
use http::header::AUTHORIZATION; use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use http_body_util::Full; use http_body_util::Full;
use hyper1::body::Body; use hyper1::body::Body;
@@ -38,9 +40,11 @@ use url::Url;
use urlencoding; use urlencoding;
use utils::http::error::ApiError; use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni; use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError; use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::config::TlsConfig; use crate::config::TlsConfig;
use crate::context::RequestMonitoring; use crate::context::RequestMonitoring;
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName; use crate::DbName;
use crate::RoleName; use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend; use super::backend::PoolingBackend;
use super::conn_pool::AuthData; use super::conn_pool::AuthData;
use super::conn_pool::Client; use super::conn_pool::Client;
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
MissingUsername, MissingUsername,
#[error("invalid username: {0}")] #[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error), InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing password")] #[error("missing authentication credentials: {0}")]
MissingPassword, MissingCredentials(Credentials),
#[error("missing hostname")] #[error("missing hostname")]
MissingHostname, MissingHostname,
#[error("invalid hostname: {0}")] #[error("invalid hostname: {0}")]
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
MalformedEndpoint, MalformedEndpoint,
} }
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError { impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind { fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User ErrorKind::User
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
} }
fn get_conn_info( fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring, ctx: &RequestMonitoring,
headers: &HeaderMap, headers: &HeaderMap,
tls: Option<&TlsConfig>, tls: Option<&TlsConfig>,
@@ -181,21 +195,32 @@ fn get_conn_info(
ctx.set_user(username.clone()); ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth let auth = auth
.to_str() .to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt( AuthData::Jwt(
auth.strip_prefix("Bearer ") auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingPassword)? .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(), .into(),
) )
} else if let Some(pass) = connection_url.password() { } else if let Some(pass) = connection_url.password() {
// a password was supplied, but this proxy only accepts JWTs
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(),
}) })
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else { } else {
return Err(ConnInfoError::MissingPassword); return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}; };
let endpoint = match connection_url.host() { let endpoint = match connection_url.host() {
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
request: Request<Incoming>, request: Request<Incoming>,
backend: Arc<PoolingBackend>, backend: Arc<PoolingBackend>,
cancel: CancellationToken, cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> { ) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await; let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result { let mut response = match result {
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
let mut message = e.to_string_client(); let mut message = e.to_string_client();
let db_error = match &e { let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(), | SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None, _ => None,
}; };
@@ -504,7 +529,7 @@ async fn handle_inner(
ctx: &RequestMonitoring, ctx: &RequestMonitoring,
request: Request<Incoming>, request: Request<Incoming>,
backend: Arc<PoolingBackend>, backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> { ) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get() let _requeset_gauge = Metrics::get()
.proxy .proxy
.connection_requests .connection_requests
@@ -514,18 +539,50 @@ async fn handle_inner(
"handling interactive connection from client" "handling interactive connection from client"
); );
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?; let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
info!( info!(
user = conn_info.conn_info.user_info.user.as_str(), user = conn_info.conn_info.user_info.user.as_str(),
"credentials" "credentials"
); );
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
auth,
backend,
)
.await
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let headers = request.headers();
// Allow connection pooling only if explicitly requested // Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in // or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in let allow_pool = !config.http_config.pool_options.opt_in
@@ -563,26 +620,36 @@ async fn handle_inner(
let authenticate_and_connect = Box::pin( let authenticate_and_connect = Box::pin(
async { async {
let keys = match &conn_info.auth { let keys = match auth {
AuthData::Password(pw) => { AuthData::Password(pw) => {
backend backend
.authenticate_with_password( .authenticate_with_password(
ctx, ctx,
&config.authentication_config, &config.authentication_config,
&conn_info.conn_info.user_info, &conn_info.user_info,
pw, &pw,
) )
.await? .await?
} }
AuthData::Jwt(jwt) => { AuthData::Jwt(jwt) => {
backend backend
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt) .authenticate_with_jwt(
.await? ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
} }
}; };
let client = backend let client = backend
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool) .connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?; .await?;
// not strictly necessary to mark success here, // not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else // but it's just insurance for if we forget it somewhere else
@@ -640,7 +707,11 @@ async fn handle_inner(
let len = json_output.len(); let len = json_output.len();
let response = response let response = response
.body(Full::new(Bytes::from(json_output))) .body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
// only fails if invalid status code or invalid header/values are given. // only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically // these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail"); .expect("building response payload should not fail");
@@ -656,6 +727,65 @@ async fn handle_inner(
Ok(response) Ok(response)
} }
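/// The only request headers the auth-broker forwards to local-proxy; anything else sent
/// by the client is dropped.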
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&AUTHORIZATION,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
let (mut parts, body) = request.into_parts();
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
// these instead just to ensure they remain normalised.
for &h in HEADERS_TO_FORWARD {
if let Some(hv) = parts.headers.remove(h) {
req = req.header(h, hv);
}
}
let req = req
.body(body)
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
let _metrics = client.metrics();
Ok(client
.inner
.send_request(req)
.await
.map_err(LocalProxyConnError::from)
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
}
impl QueryData { impl QueryData {
async fn process( async fn process(
self, self,
@@ -705,7 +835,9 @@ impl QueryData {
// query failed or was cancelled. // query failed or was cancelled.
Ok(Err(error)) => { Ok(Err(error)) => {
let db_error = match &error { let db_error = match &error {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(), | SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None, _ => None,
}; };

View File

@@ -56,6 +56,7 @@ from _pytest.fixtures import FixtureRequest
from psycopg2.extensions import connection as PgConnection from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import cursor as PgCursor from psycopg2.extensions import cursor as PgCursor
from psycopg2.extensions import make_dsn, parse_dsn from psycopg2.extensions import make_dsn, parse_dsn
from pytest_httpserver import HTTPServer
from urllib3.util.retry import Retry from urllib3.util.retry import Retry
from fixtures import overlayfs from fixtures import overlayfs
@@ -440,9 +441,9 @@ class NeonEnvBuilder:
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
self.pageserver_default_tenant_config_compaction_algorithm: Optional[ self.pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = (
Dict[str, Any] pageserver_default_tenant_config_compaction_algorithm
] = pageserver_default_tenant_config_compaction_algorithm )
if self.pageserver_default_tenant_config_compaction_algorithm is not None: if self.pageserver_default_tenant_config_compaction_algorithm is not None:
log.debug( log.debug(
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}" f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
@@ -1072,9 +1073,9 @@ class NeonEnv:
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
if config.pageserver_default_tenant_config_compaction_algorithm is not None: if config.pageserver_default_tenant_config_compaction_algorithm is not None:
tenant_config = ps_cfg.setdefault("tenant_config", {}) tenant_config = ps_cfg.setdefault("tenant_config", {})
tenant_config[ tenant_config["compaction_algorithm"] = (
"compaction_algorithm" config.pageserver_default_tenant_config_compaction_algorithm
] = config.pageserver_default_tenant_config_compaction_algorithm )
if self.pageserver_remote_storage is not None: if self.pageserver_remote_storage is not None:
ps_cfg["remote_storage"] = remote_storage_to_toml_dict( ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
@@ -1117,9 +1118,9 @@ class NeonEnv:
if config.auth_enabled: if config.auth_enabled:
sk_cfg["auth_enabled"] = True sk_cfg["auth_enabled"] = True
if self.safekeepers_remote_storage is not None: if self.safekeepers_remote_storage is not None:
sk_cfg[ sk_cfg["remote_storage"] = (
"remote_storage" self.safekeepers_remote_storage.to_toml_inline_table().strip()
] = self.safekeepers_remote_storage.to_toml_inline_table().strip() )
self.safekeepers.append( self.safekeepers.append(
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts) Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
) )
@@ -3572,6 +3573,20 @@ class NeonProxy(PgProtocol):
] ]
return args return args
class AuthBroker(AuthBackend):
def __init__(self, endpoint: str):
self.endpoint = endpoint
def extra_args(self) -> list[str]:
args = [
# Console auth backend params
*["--auth-backend", "console"],
*["--auth-endpoint", self.endpoint],
*["--sql-over-http-pool-opt-in", "false"],
*["--is-auth-broker"],
]
return args
@dataclass(frozen=True) @dataclass(frozen=True)
class Postgres(AuthBackend): class Postgres(AuthBackend):
pg_conn_url: str pg_conn_url: str
@@ -3600,7 +3615,7 @@ class NeonProxy(PgProtocol):
metric_collection_interval: Optional[str] = None, metric_collection_interval: Optional[str] = None,
): ):
host = "127.0.0.1" host = "127.0.0.1"
domain = "proxy.localtest.me" # resolves to 127.0.0.1 domain = "ep-foo-bar-1234.localtest.me" # resolves to 127.0.0.1
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port) super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
self.domain = domain self.domain = domain
@@ -3886,6 +3901,50 @@ def static_proxy(
yield proxy yield proxy
@pytest.fixture(scope="function")
def static_auth_broker(
vanilla_pg: VanillaPostgres,
port_distributor: PortDistributor,
neon_binpath: Path,
test_output_dir: Path,
httpserver: HTTPServer,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
auth_endpoint = httpserver.url_for("/cplane")
port = vanilla_pg.default_options["port"]
host = vanilla_pg.default_options["host"]
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
{
"address": f"{host}:{port}",
"aux": {
"endpoint_id": "ep-foo-bar-1234",
"branch_id": "br-foo-bar",
"project_id": "foo-bar",
},
}
)
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
) as proxy:
proxy.start()
yield proxy
class Endpoint(PgProtocol, LogUtils): class Endpoint(PgProtocol, LogUtils):
"""An object representing a Postgres compute endpoint managed by the control plane.""" """An object representing a Postgres compute endpoint managed by the control plane."""

View File

@@ -0,0 +1,71 @@
import asyncio
import json
import subprocess
import time
import urllib.parse
from typing import Any, List, Optional, Tuple
import psycopg2
import pytest
import requests
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
def test_sql_over_http(static_auth_broker: NeonProxy):
static_auth_broker.safe_psql("create role http with login password 'http' superuser")
def q(sql: str, params: Optional[List[Any]] = None) -> Any:
params = params or []
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
response = requests.post(
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
data=json.dumps({"query": sql, "params": params}),
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
verify=str(static_proxy.test_output_dir / "proxy.crt"),
)
assert response.status_code == 200, response.text
return response.json()
rows = q("select 42 as answer")["rows"]
assert rows == [{"answer": 42}]
rows = q("select $1 as answer", [42])["rows"]
assert rows == [{"answer": "42"}]
rows = q("select $1 * 1 as answer", [42])["rows"]
assert rows == [{"answer": 42}]
rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"]
assert rows == [{"answer": [1, 2, 3]}]
rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
assert rows == [{"answer": {"b": 42}}]
rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
assert rows == [{"answer": [{}]}]
rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]
rows = q("select * from pg_class limit 1")["rows"]
assert len(rows) == 1
res = q("create table t(id serial primary key, val int)")
assert res["command"] == "CREATE"
assert res["rowCount"] is None
res = q("insert into t(val) values (10), (20), (30) returning id")
assert res["command"] == "INSERT"
assert res["rowCount"] == 3
assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}]
res = q("select * from t")
assert res["command"] == "SELECT"
assert res["rowCount"] == 3
res = q("drop table t")
assert res["command"] == "DROP"
assert res["rowCount"] is None