mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 10:22:56 +00:00
Compare commits
20 Commits
arthur/rep
...
auth-broke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3f7d0d3f1 | ||
|
|
0724df1d3f | ||
|
|
4d47049b00 | ||
|
|
5687384a8e | ||
|
|
249f5ea17d | ||
|
|
6abcc1f298 | ||
|
|
3e97cf0d6e | ||
|
|
054ef4988b | ||
|
|
5202cd75b5 | ||
|
|
f475dac0e6 | ||
|
|
a4100373e5 | ||
|
|
040d8cf4f6 | ||
|
|
75bfd57e01 | ||
|
|
4bc2686dee | ||
|
|
8e7d2aab76 | ||
|
|
2703abccc7 | ||
|
|
76515cdae3 | ||
|
|
08c7f933a3 | ||
|
|
4ad3aa7c96 | ||
|
|
9c59e3b4b9 |
@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Clone for Box<dyn TestBackend> {
|
||||
fn clone(&self) -> Self {
|
||||
TestBackend::dyn_clone(&**self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend<'_, (), ()> {
|
||||
@@ -557,7 +565,7 @@ mod tests {
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
|
||||
use super::{auth_quirks, AuthRateLimiter};
|
||||
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
@@ -585,6 +593,14 @@ mod tests {
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: crate::EndpointId,
|
||||
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
@@ -595,12 +611,15 @@ mod tests {
|
||||
}
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: false,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
future::Future,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
@@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use dashmap::DashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde::{
|
||||
de::{DeserializeSeed, IgnoredAny, Visitor},
|
||||
Deserializer,
|
||||
};
|
||||
use signature::Verifier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
@@ -33,6 +37,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
|
||||
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct AuthRule {
|
||||
pub(crate) id: String,
|
||||
pub(crate) jwks_url: url::Url,
|
||||
@@ -303,35 +308,21 @@ impl JwkCacheEntryLock {
|
||||
}
|
||||
key => bail!("unsupported key type {key:?}"),
|
||||
};
|
||||
tracing::debug!("JWT signature valid");
|
||||
|
||||
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
|
||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
||||
let validator = JwtValidator {
|
||||
expected_audience,
|
||||
current_time: SystemTime::now(),
|
||||
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
|
||||
};
|
||||
|
||||
match (expected_audience, payload.audience) {
|
||||
// check the audience matches
|
||||
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
|
||||
// the audience is expected but is missing
|
||||
(Some(_), None) => bail!("invalid JWT token audience"),
|
||||
// we don't care for the audience field
|
||||
(None, _) => {}
|
||||
}
|
||||
let payload = validator
|
||||
.deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
|
||||
|
||||
let now = SystemTime::now();
|
||||
|
||||
if let Some(exp) = payload.expiration {
|
||||
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
|
||||
}
|
||||
|
||||
if let Some(nbf) = payload.not_before {
|
||||
ensure!(
|
||||
nbf < now + CLOCK_SKEW_LEEWAY,
|
||||
"JWT token is not yet ready to use"
|
||||
);
|
||||
}
|
||||
tracing::debug!(?payload, "JWT claims valid");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -419,37 +410,184 @@ struct JwtHeader<'a> {
|
||||
key_id: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug)]
|
||||
struct JwtPayload<'a> {
|
||||
/// Audience - Recipient for which the JWT is intended
|
||||
#[serde(rename = "aud")]
|
||||
audience: Option<&'a str>,
|
||||
/// Expiration - Time after which the JWT expires
|
||||
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
|
||||
expiration: Option<SystemTime>,
|
||||
/// Not before - Time after which the JWT expires
|
||||
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
|
||||
not_before: Option<SystemTime>,
|
||||
|
||||
// the following entries are only extracted for the sake of debug logging.
|
||||
/// Issuer of the JWT
|
||||
#[serde(rename = "iss")]
|
||||
issuer: Option<&'a str>,
|
||||
/// Subject of the JWT (the user)
|
||||
#[serde(rename = "sub")]
|
||||
subject: Option<&'a str>,
|
||||
/// Unique token identifier
|
||||
#[serde(rename = "jti")]
|
||||
jwt_id: Option<&'a str>,
|
||||
/// Unique session identifier
|
||||
#[serde(rename = "sid")]
|
||||
session_id: Option<&'a str>,
|
||||
struct JwtValidator<'a> {
|
||||
expected_audience: Option<&'a str>,
|
||||
current_time: SystemTime,
|
||||
clock_skew_leeway: Duration,
|
||||
}
|
||||
|
||||
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
|
||||
let d = <Option<u64>>::deserialize(d)?;
|
||||
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
|
||||
impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
|
||||
type Value = JwtPayload<'de>;
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for JwtValidator<'_> {
|
||||
type Value = JwtPayload<'de>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a JWT payload")
|
||||
}
|
||||
|
||||
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::MapAccess<'de>,
|
||||
{
|
||||
let mut payload = JwtPayload {
|
||||
issuer: None,
|
||||
subject: None,
|
||||
jwt_id: None,
|
||||
session_id: None,
|
||||
};
|
||||
|
||||
let mut aud = false;
|
||||
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
"iss" if payload.issuer.is_none() => {
|
||||
payload.issuer = Some(map.next_value()?);
|
||||
}
|
||||
"sub" if payload.subject.is_none() => {
|
||||
payload.subject = Some(map.next_value()?);
|
||||
}
|
||||
"jit" if payload.jwt_id.is_none() => {
|
||||
payload.jwt_id = Some(map.next_value()?);
|
||||
}
|
||||
"sid" if payload.session_id.is_none() => {
|
||||
payload.session_id = Some(map.next_value()?);
|
||||
}
|
||||
"exp" => {
|
||||
let exp = map.next_value::<u64>()?;
|
||||
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
|
||||
|
||||
if self.current_time > exp + self.clock_skew_leeway {
|
||||
return Err(serde::de::Error::custom("JWT token has expired"));
|
||||
}
|
||||
}
|
||||
"nbf" => {
|
||||
let nbf = map.next_value::<u64>()?;
|
||||
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
|
||||
|
||||
if self.current_time + self.clock_skew_leeway < nbf {
|
||||
return Err(serde::de::Error::custom(
|
||||
"JWT token is not yet ready to use",
|
||||
));
|
||||
}
|
||||
}
|
||||
"aud" => {
|
||||
if let Some(expected_audience) = self.expected_audience {
|
||||
map.next_value_seed(AudienceValidator { expected_audience })?;
|
||||
aud = true;
|
||||
} else {
|
||||
map.next_value::<IgnoredAny>()?;
|
||||
}
|
||||
}
|
||||
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
|
||||
}
|
||||
}
|
||||
|
||||
if self.expected_audience.is_some() && !aud {
|
||||
return Err(serde::de::Error::custom("invalid JWT token audience"));
|
||||
}
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_map(self)
|
||||
}
|
||||
}
|
||||
|
||||
struct AudienceValidator<'a> {
|
||||
expected_audience: &'a str,
|
||||
}
|
||||
|
||||
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
|
||||
type Value = ();
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for AudienceValidator<'_> {
|
||||
type Value = ();
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a single string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
if self.expected_audience == v {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(E::custom("invalid JWT token audience"))
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
while let Some(v) = seq.next_element_seed(SingleAudienceValidator {
|
||||
expected_audience: self.expected_audience,
|
||||
})? {
|
||||
if v {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(serde::de::Error::custom("invalid JWT token audience"))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_any(self)
|
||||
}
|
||||
}
|
||||
|
||||
struct SingleAudienceValidator<'a> {
|
||||
expected_audience: &'a str,
|
||||
}
|
||||
|
||||
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
|
||||
type Value = bool;
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
|
||||
type Value = bool;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a single audience string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(self.expected_audience == v)
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_any(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||
// the following entries are only extracted for the sake of debug logging.
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
struct JwtPayload<'a> {
|
||||
/// Issuer of the JWT
|
||||
issuer: Option<Cow<'a, str>>,
|
||||
/// Subject of the JWT (the user)
|
||||
subject: Option<Cow<'a, str>>,
|
||||
/// Unique token identifier
|
||||
jwt_id: Option<Cow<'a, str>>,
|
||||
/// Unique session identifier
|
||||
session_id: Option<Cow<'a, str>>,
|
||||
}
|
||||
|
||||
struct JwkRenewalPermit<'a> {
|
||||
@@ -530,6 +668,8 @@ mod tests {
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::pkcs8::DecodePrivateKey;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use signature::Signer;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
@@ -562,18 +702,36 @@ mod tests {
|
||||
}
|
||||
|
||||
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let body = typed_json::json! {{
|
||||
"exp": now + 3600,
|
||||
"nbf": now,
|
||||
"aud": ["audience1", "neon", "audience2"],
|
||||
"sub": "user1",
|
||||
"sid": "session1",
|
||||
"jti": "token1",
|
||||
"iss": "neon-testing",
|
||||
}};
|
||||
build_custom_jwt_payload(kid, body, sig)
|
||||
}
|
||||
|
||||
fn build_custom_jwt_payload(
|
||||
kid: String,
|
||||
body: impl Serialize,
|
||||
sig: jose_jwa::Signing,
|
||||
) -> String {
|
||||
let header = JwtHeader {
|
||||
typ: "JWT",
|
||||
algorithm: jose_jwa::Algorithm::Signing(sig),
|
||||
key_id: Some(&kid),
|
||||
};
|
||||
let body = typed_json::json! {{
|
||||
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
|
||||
}};
|
||||
|
||||
let header =
|
||||
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{header}.{body}")
|
||||
}
|
||||
@@ -588,6 +746,16 @@ mod tests {
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
|
||||
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
|
||||
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
|
||||
use rsa::pkcs1v15::SigningKey;
|
||||
use rsa::signature::SignatureEncoding;
|
||||
@@ -659,37 +827,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
-----END PRIVATE KEY-----
|
||||
";
|
||||
|
||||
#[tokio::test]
|
||||
async fn renew() {
|
||||
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("3".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("4".into());
|
||||
#[derive(Clone)]
|
||||
struct Fetch(Vec<AuthRule>);
|
||||
|
||||
let foo_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
Ok(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
async fn jwks_server(
|
||||
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
|
||||
) -> SocketAddr {
|
||||
let router = Arc::new(router);
|
||||
let service = service_fn(move |req| {
|
||||
let foo_jwks = foo_jwks.clone();
|
||||
let bar_jwks = bar_jwks.clone();
|
||||
let router = Arc::clone(&router);
|
||||
async move {
|
||||
let jwks = match req.uri().path() {
|
||||
"/foo" => &foo_jwks,
|
||||
"/bar" => &bar_jwks,
|
||||
_ => {
|
||||
return Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new()));
|
||||
}
|
||||
};
|
||||
let body = serde_json::to_vec(jwks).unwrap();
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.body(Full::new(Bytes::from(body)))
|
||||
match router(req.uri().path()) {
|
||||
Some(body) => Response::builder()
|
||||
.status(200)
|
||||
.body(Full::new(Bytes::from(body))),
|
||||
None => Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new())),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -704,84 +869,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
addr
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Fetch(SocketAddr, Vec<RoleNameInt>);
|
||||
#[tokio::test]
|
||||
async fn check_jwt_happy_path() {
|
||||
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("ec1".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("ec2".into());
|
||||
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
Ok(vec![
|
||||
AuthRule {
|
||||
id: "foo".to_owned(),
|
||||
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: self.1.clone(),
|
||||
},
|
||||
AuthRule {
|
||||
id: "bar".to_owned(),
|
||||
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: self.1.clone(),
|
||||
},
|
||||
])
|
||||
}
|
||||
}
|
||||
let foo_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
|
||||
"/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role_name1 = RoleName::from("anonymous");
|
||||
let role_name2 = RoleName::from("authenticated");
|
||||
|
||||
let fetch = Fetch(
|
||||
addr,
|
||||
vec![
|
||||
RoleNameInt::from(&role_name1),
|
||||
RoleNameInt::from(&role_name2),
|
||||
],
|
||||
);
|
||||
let roles = vec![
|
||||
RoleNameInt::from(&role_name1),
|
||||
RoleNameInt::from(&role_name2),
|
||||
];
|
||||
let rules = vec![
|
||||
AuthRule {
|
||||
id: "foo".to_owned(),
|
||||
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: roles.clone(),
|
||||
},
|
||||
AuthRule {
|
||||
id: "bar".to_owned(),
|
||||
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: roles.clone(),
|
||||
},
|
||||
];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let endpoint = EndpointId::from("ep");
|
||||
|
||||
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
|
||||
|
||||
let jwt1 = new_rsa_jwt("1".into(), rs1);
|
||||
let jwt2 = new_rsa_jwt("2".into(), rs2);
|
||||
let jwt3 = new_ec_jwt("3".into(), &ec1);
|
||||
let jwt4 = new_ec_jwt("4".into(), &ec2);
|
||||
|
||||
// had the wrong kid, therefore will have the wrong ecdsa signature
|
||||
let bad_jwt = new_ec_jwt("3".into(), &ec2);
|
||||
// this role_name is not accepted
|
||||
let bad_role_name = RoleName::from("cloud_admin");
|
||||
|
||||
let err = jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
&bad_jwt,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
&role_name1,
|
||||
&fetch,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("signature error"));
|
||||
|
||||
let err = jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
&jwt1,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
&bad_role_name,
|
||||
&fetch,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("jwk not found"));
|
||||
let jwt1 = new_rsa_jwt("rs1".into(), rs1);
|
||||
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
|
||||
let jwt3 = new_ec_jwt("ec1".into(), &ec1);
|
||||
let jwt4 = new_ec_jwt("ec2".into(), &ec2);
|
||||
|
||||
let tokens = [jwt1, jwt2, jwt3, jwt4];
|
||||
let role_names = [role_name1, role_name2];
|
||||
@@ -790,15 +932,194 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
token,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
role,
|
||||
&fetch,
|
||||
token,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_invalid_signature() {
|
||||
let (_, jwk) = new_ec_jwk("1".into());
|
||||
let (key, _) = new_ec_jwk("1".into());
|
||||
|
||||
// has a matching kid, but signed by the wrong key
|
||||
let bad_jwt = new_ec_jwt("1".into(), &key);
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = jwk_cache
|
||||
.check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("signature error"),
|
||||
"expected \"signature error\", got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_unknown_role() {
|
||||
let (key, jwk) = new_rsa_jwk(RS1, "1".into());
|
||||
let jwt = new_rsa_jwt("1".into(), key);
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
// this role_name is not accepted
|
||||
let bad_role_name = RoleName::from("cloud_admin");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = jwk_cache
|
||||
.check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("jwk not found"),
|
||||
"expected \"jwk not found\", got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_invalid_claims() {
|
||||
let (key, jwk) = new_ec_jwk("1".into());
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
struct Test {
|
||||
body: serde_json::Value,
|
||||
error: &'static str,
|
||||
}
|
||||
|
||||
let table = vec![
|
||||
Test {
|
||||
body: json! {{
|
||||
"nbf": now + 60,
|
||||
"aud": "neon",
|
||||
}},
|
||||
error: "JWT token is not yet ready to use",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"exp": now - 60,
|
||||
"aud": ["neon"],
|
||||
}},
|
||||
error: "JWT token has expired",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": [],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": "foo",
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": ["foo"],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": ["foo", "bar"],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
];
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: Some("neon".to_string()),
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
for test in table {
|
||||
let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
|
||||
|
||||
match jwk_cache
|
||||
.check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
|
||||
.await
|
||||
{
|
||||
Err(err) if err.to_string().contains(test.error) => {}
|
||||
Err(err) => {
|
||||
panic!("expected {:?}, got {err:?}", test.error)
|
||||
}
|
||||
Ok(()) => {
|
||||
panic!("expected {:?}, got ok", test.error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,17 +14,15 @@ use crate::{
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub(crate) jwks_cache: JwkCache,
|
||||
pub(crate) node_info: NodeInfo,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||
LocalBackend {
|
||||
jwks_cache: JwkCache::default(),
|
||||
node_info: NodeInfo {
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
|
||||
@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::Either;
|
||||
use proxy::{
|
||||
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
auth::backend::{
|
||||
jwt::JwkCache,
|
||||
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
},
|
||||
cancellation::CancellationHandlerMain,
|
||||
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
||||
console::{
|
||||
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
allow_self_signed_compute: false,
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(0),
|
||||
scram_protocol_timeout: Duration::from_secs(10),
|
||||
rate_limiter_enabled: false,
|
||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: true,
|
||||
},
|
||||
require_client_ip: false,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
|
||||
@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
@@ -102,6 +103,9 @@ struct ProxyCliArgs {
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
is_auth_broker: bool,
|
||||
/// path to TLS key for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
@@ -382,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let proxy_listener = if !args.is_auth_broker {
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
|
||||
Some(TcpListener::bind(proxy_address).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
Some(TcpListener::bind(serverless_address).await?)
|
||||
} else if args.is_auth_broker {
|
||||
bail!("wss arg must be present for auth-broker")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
@@ -430,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
let serverless_listener = TcpListener::bind(serverless_address).await?;
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
serverless_listener,
|
||||
@@ -674,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
accept_websockets: true,
|
||||
accept_websockets: !args.is_auth_broker,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
@@ -689,12 +707,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
auth::{
|
||||
self,
|
||||
backend::{jwt::JwkCache, AuthRateLimiter},
|
||||
},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
@@ -67,6 +70,9 @@ pub struct AuthenticationConfig {
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
pub accept_jwts: bool,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
@@ -250,18 +256,26 @@ impl CertResolver {
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ pub mod neon;
|
||||
use super::messages::{ConsoleError, MetricsAuxInfo};
|
||||
use crate::{
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
backend::{
|
||||
jwt::{AuthRule, FetchAuthRules},
|
||||
ComputeCredentialKeys, ComputeUserInfo,
|
||||
},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
@@ -16,7 +19,7 @@ use crate::{
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey,
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
@@ -334,6 +337,12 @@ pub(crate) trait Api {
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
@@ -343,6 +352,7 @@ pub(crate) trait Api {
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ConsoleBackend {
|
||||
/// Current Cloud API (V2).
|
||||
Console(neon::Api),
|
||||
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
match self {
|
||||
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_api) => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
@@ -552,3 +576,13 @@ impl WakeComputePermit {
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl FetchAuthRules for ConsoleBackend {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{
|
||||
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use crate::{
|
||||
@@ -118,6 +120,39 @@ impl Api {
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
let connection = tokio::spawn(connection);
|
||||
|
||||
let res = client.query(
|
||||
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
|
||||
&[&endpoint.as_str()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut rows = vec![];
|
||||
for row in res {
|
||||
rows.push(AuthRule {
|
||||
id: row.get("id"),
|
||||
jwks_url: url::Url::parse(row.get("jwks_url"))?,
|
||||
audience: row.get("audience"),
|
||||
role_names: row
|
||||
.get::<_, Vec<String>>("role_names")
|
||||
.into_iter()
|
||||
.map(RoleName::from)
|
||||
.map(|s| RoleNameInt::from(&s))
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
|
||||
drop(client);
|
||||
connection.await??;
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
@@ -185,6 +220,14 @@ impl super::Api for Api {
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
|
||||
@@ -7,27 +7,33 @@ use super::{
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
auth::backend::{jwt::AuthRule, ComputeUserInfo},
|
||||
compute,
|
||||
console::messages::{ColdStartInfo, Reason},
|
||||
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey,
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use ::http::{header::AUTHORIZATION, HeaderName};
|
||||
use anyhow::bail;
|
||||
use futures::TryFutureExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
jwt: String,
|
||||
// put in a shared ref so we don't copy secrets all over in memory
|
||||
jwt: Arc<str>,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
@@ -38,7 +44,9 @@ impl Api {
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
@@ -71,9 +79,9 @@ impl Api {
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.get_path("proxy_get_role_secret")
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
@@ -125,6 +133,61 @@ impl Api {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &endpoint.normalize())
|
||||
.await
|
||||
{
|
||||
bail!("endpoint not found");
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_with_url(|url| {
|
||||
url.path_segments_mut()
|
||||
.push("endpoints")
|
||||
.push(endpoint.as_str())
|
||||
.push("jwks");
|
||||
})
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||
|
||||
let rules = body
|
||||
.jwks
|
||||
.into_iter()
|
||||
.map(|jwks| AuthRule {
|
||||
id: jwks.id,
|
||||
jwks_url: jwks.jwks_url,
|
||||
audience: jwks.jwt_audience,
|
||||
role_names: jwks.role_names,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(rules)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
@@ -135,7 +198,7 @@ impl Api {
|
||||
async {
|
||||
let mut request_builder = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.get_path("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
@@ -262,6 +325,15 @@ impl super::Api for Api {
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
|
||||
@@ -86,9 +86,17 @@ impl Endpoint {
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// appending a single `path` segment to the base endpoint URL.
|
||||
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
|
||||
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
|
||||
self.get_with_url(|u| {
|
||||
u.path_segments_mut().push(path);
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// accepting a closure to modify the url path segments for more complex paths queries.
|
||||
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push(path);
|
||||
f(&mut url);
|
||||
self.client.get(url.into_inner())
|
||||
}
|
||||
|
||||
@@ -144,7 +152,7 @@ mod tests {
|
||||
|
||||
// Validate that this pattern makes sense.
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.get_path("frobnicate")
|
||||
.query(&[
|
||||
("foo", Some("10")), // should be just `foo=10`
|
||||
("bar", None), // shouldn't be passed at all
|
||||
@@ -162,7 +170,7 @@ mod tests {
|
||||
let endpoint = Endpoint::new(url, Client::new());
|
||||
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.get_path("frobnicate")
|
||||
.query(&[("session_id", uuid::Uuid::nil())])
|
||||
.build()?;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{
|
||||
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
|
||||
any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||
@@ -16,12 +17,21 @@ pub struct StringInterner<Id> {
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
|
||||
#[derive(PartialEq, Clone, Copy, Eq, Hash)]
|
||||
pub struct InternedString<Id> {
|
||||
inner: Spur,
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_tuple("InternedString")
|
||||
.field(&type_name::<Id>())
|
||||
.field(&self.as_str())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.as_str().fmt(f)
|
||||
|
||||
@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
|
||||
{
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestBackend> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
mod backend;
|
||||
pub mod cancel_set;
|
||||
mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod sql_over_http;
|
||||
@@ -19,7 +20,8 @@ use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
@@ -81,7 +83,28 @@ pub async fn task_main(
|
||||
}
|
||||
});
|
||||
|
||||
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let http_conn_pool = Arc::clone(&http_conn_pool);
|
||||
tokio::spawn(async move {
|
||||
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
|
||||
});
|
||||
}
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
let http_conn_pool = http_conn_pool.clone();
|
||||
async move {
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
@@ -342,7 +365,7 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -386,7 +409,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
@@ -409,7 +432,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Full::new(Bytes::new()))
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{io, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
@@ -27,9 +29,13 @@ use crate::{
|
||||
Host,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
use super::{
|
||||
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||
http_conn_pool::{self, poll_http2_client},
|
||||
};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -103,32 +109,44 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: &str,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
jwt: String,
|
||||
) -> Result<(), AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
crate::auth::Backend::Console(_, ()) => {
|
||||
Err(AuthError::auth_failed("JWT login is not yet supported"))
|
||||
crate::auth::Backend::Console(console, ()) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&**console,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
|
||||
"JWT login over web auth proxy is not supported",
|
||||
)),
|
||||
crate::auth::Backend::Local(cache) => {
|
||||
cache
|
||||
crate::auth::Backend::Local(_) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&StaticAuthRules,
|
||||
jwt,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
|
||||
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,14 +192,55 @@ impl PoolingBackend {
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Wake up the destination if needed
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client, HttpConnError> {
|
||||
info!("pool: looking for an existing connection");
|
||||
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|()| ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to compute")]
|
||||
ConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to postgres in compute")]
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum LocalProxyConnError {
|
||||
#[error("error with connection to local-proxy")]
|
||||
Io(#[source] std::io::Error),
|
||||
#[error("could not establish h2 connection")]
|
||||
H2(#[from] hyper1::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for HttpConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
|
||||
impl CouldRetry for HttpConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
HttpConnError::AuthError(_) => false,
|
||||
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
|
||||
impl ShouldRetryWakeCompute for HttpConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
|
||||
// we never checked cache validity
|
||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for LocalProxyConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => ErrorKind::Compute,
|
||||
LocalProxyConnError::H2(_) => ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for LocalProxyConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Could not establish HTTP connection to the database".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for LocalProxyConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<http_conn_pool::GlobalConnPool>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
|
||||
let res = connect_http2(&host, 10432, timeout).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
// assumption: host is an ip address so this should not actually perform any requests.
|
||||
// todo: add that assumption as a guarantee in the control-plane API.
|
||||
let mut addrs = lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
let stream = loop {
|
||||
let Some(addr) = addrs.next() else {
|
||||
return Err(last_err.unwrap_or_else(|| {
|
||||
LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"could not resolve any addresses",
|
||||
))
|
||||
}));
|
||||
};
|
||||
|
||||
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
|
||||
break stream;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(LocalProxyConnError::Io(e));
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
e,
|
||||
)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(20))
|
||||
.keep_alive_while_idle(true)
|
||||
.keep_alive_timeout(Duration::from_secs(5))
|
||||
.handshake(TokioIo::new(stream))
|
||||
.await?;
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
335
proxy/src/serverless/http_conn_pool.rs
Normal file
335
proxy/src/serverless/http_conn_pool.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
use dashmap::DashMap;
|
||||
use hyper1::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
use std::{sync::Arc, sync::Weak};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||
|
||||
use tracing::{debug, error};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnPoolEntry {
|
||||
conn: Send,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool {
|
||||
conns: VecDeque<ConnPoolEntry>,
|
||||
_guard: HttpEndpointPoolsGuard<'static>,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl EndpointConnPool {
|
||||
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
|
||||
let Self { conns, .. } = self;
|
||||
|
||||
let conn = conns.pop_front()?;
|
||||
conns.push_back(conn.clone());
|
||||
Some(conn)
|
||||
}
|
||||
|
||||
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
conns,
|
||||
global_connections_count,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let old_len = conns.len();
|
||||
conns.retain(|conn| conn.conn_id != conn_id);
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
removed > 0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EndpointConnPool {
|
||||
fn drop(&mut self) {
|
||||
if !self.conns.is_empty() {
|
||||
self.global_connections_count
|
||||
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.conns.len() as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GlobalConnPool {
|
||||
// endpoint -> per-endpoint connection pool
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// Number of endpoint-connection pools
|
||||
///
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
/// It's only used for diagnostics.
|
||||
global_pool_size: AtomicUsize,
|
||||
|
||||
/// Total number of connections in the pool
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
impl GlobalConnPool {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
let shards = config.pool_options.pool_shards;
|
||||
Arc::new(Self {
|
||||
global_pool: DashMap::with_shard_amount(shards),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
config,
|
||||
global_connections_count: Arc::new(AtomicUsize::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
// drops all strong references to endpoint-pools
|
||||
self.global_pool.clear();
|
||||
}
|
||||
|
||||
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
|
||||
let epoch = self.config.pool_options.gc_epoch;
|
||||
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let shard = rng.gen_range(0..self.global_pool.shards().len());
|
||||
self.gc(shard);
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self, shard: usize) {
|
||||
debug!(shard, "pool: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut shard = self.global_pool.shards()[shard].write();
|
||||
|
||||
let timer = Metrics::get()
|
||||
.proxy
|
||||
.http_pool_reclaimation_lag_seconds
|
||||
.start_timer();
|
||||
let current_len = shard.len();
|
||||
let mut clients_removed = 0;
|
||||
shard.retain(|endpoint, x| {
|
||||
// if the current endpoint pool is unique (no other strong or weak references)
|
||||
// then it is currently not in use by any connections.
|
||||
if let Some(pool) = Arc::get_mut(x.get_mut()) {
|
||||
let EndpointConnPool { conns, .. } = pool.get_mut();
|
||||
|
||||
let old_len = conns.len();
|
||||
|
||||
conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
clients_removed += removed;
|
||||
|
||||
// we only remove this pool if it has no active connections
|
||||
if conns.is_empty() {
|
||||
info!("pool: discarding pool for endpoint {endpoint}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let new_len = shard.len();
|
||||
drop(shard);
|
||||
timer.observe();
|
||||
|
||||
// Do logging outside of the lock.
|
||||
if clients_removed > 0 {
|
||||
let size = self
|
||||
.global_connections_count
|
||||
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
|
||||
- clients_removed;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(clients_removed as i64);
|
||||
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
|
||||
}
|
||||
let removed = current_len - new_len;
|
||||
|
||||
if removed > 0 {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_sub(removed, atomic::Ordering::Relaxed)
|
||||
- removed;
|
||||
info!("pool: performed global pool gc. size now {global_pool_size}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Option<Client> {
|
||||
let endpoint = conn_info.endpoint_cache_key()?;
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
let client = endpoint_pool.write().get_conn_entry()?;
|
||||
|
||||
if client.conn.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return None;
|
||||
}
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
Some(Client::new(client.conn, client.aux))
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
self: &Arc<Self>,
|
||||
endpoint: &EndpointCacheKey,
|
||||
) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
}
|
||||
|
||||
// slow path
|
||||
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||
conns: VecDeque::new(),
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count: self.global_connections_count.clone(),
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
let mut created = false;
|
||||
let pool = self
|
||||
.global_pool
|
||||
.entry(endpoint.clone())
|
||||
.or_insert_with(|| {
|
||||
created = true;
|
||||
new_pool
|
||||
})
|
||||
.clone();
|
||||
|
||||
// log new global pool size
|
||||
if created {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||
+ 1;
|
||||
info!(
|
||||
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||
);
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
|
||||
let pool = match conn_info.endpoint_cache_key() {
|
||||
Some(endpoint) => {
|
||||
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
|
||||
|
||||
pool.write().conns.push_back(ConnPoolEntry {
|
||||
conn: client.clone(),
|
||||
conn_id,
|
||||
aux: aux.clone(),
|
||||
});
|
||||
|
||||
Arc::downgrade(&pool)
|
||||
}
|
||||
None => Weak::new(),
|
||||
};
|
||||
|
||||
// let idle = global_pool.get_idle_timeout();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let res = connection.await;
|
||||
match res {
|
||||
Ok(()) => info!("connection closed"),
|
||||
Err(e) => error!(%session_id, "connection error: {}", e),
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.write().remove_conn(conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
Client::new(client, aux)
|
||||
}
|
||||
|
||||
pub(crate) struct Client {
|
||||
pub(crate) inner: Send,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
|
||||
Self { inner, aux }
|
||||
}
|
||||
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: self.aux.endpoint_id,
|
||||
branch_id: self.aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,13 +5,13 @@ use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
@@ -64,17 +64,24 @@ struct HttpErrorBody {
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn response_from_msg_and_status(
|
||||
msg: String,
|
||||
status: StatusCode,
|
||||
) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.body(
|
||||
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -83,14 +90,14 @@ impl HttpErrorBody {
|
||||
pub(crate) fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::Method;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
@@ -38,9 +40,11 @@ use url::Url;
|
||||
use urlencoding;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::LocalProxyConnError;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
@@ -181,21 +195,32 @@ fn get_conn_info(
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingPassword)?
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingPassword);
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
|
||||
let endpoint = match connection_url.host() {
|
||||
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
|
||||
|
||||
let mut message = e.to_string_client();
|
||||
let db_error = match &e {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
@@ -504,7 +529,7 @@ async fn handle_inner(
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
@@ -514,18 +539,50 @@ async fn handle_inner(
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
ctx,
|
||||
request.headers(),
|
||||
config.tls_config.as_ref(),
|
||||
)?;
|
||||
info!(
|
||||
user = conn_info.conn_info.user_info.user.as_str(),
|
||||
"credentials"
|
||||
);
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
cancel,
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
auth,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_db_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
auth: AuthData,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// Allow connection pooling only if explicitly requested
|
||||
// or if we have decided that http pool is no longer opt-in
|
||||
let allow_pool = !config.http_config.pool_options.opt_in
|
||||
@@ -563,26 +620,36 @@ async fn handle_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = match &conn_info.auth {
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.conn_info.user_info,
|
||||
pw,
|
||||
&conn_info.user_info,
|
||||
&pw,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
|
||||
.await?
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?;
|
||||
|
||||
ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
@@ -640,7 +707,11 @@ async fn handle_inner(
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
.body(
|
||||
Full::new(Bytes::from(json_output))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -656,6 +727,65 @@ async fn handle_inner(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
&CONN_STRING,
|
||||
&RAW_TEXT_OUTPUT,
|
||||
&ARRAY_MODE,
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
&TXN_READ_ONLY,
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
|
||||
let (mut parts, body) = request.into_parts();
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
|
||||
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
|
||||
// these instead just to ensure they remain normalised.
|
||||
for &h in HEADERS_TO_FORWARD {
|
||||
if let Some(hv) = parts.headers.remove(h) {
|
||||
req = req.header(h, hv);
|
||||
}
|
||||
}
|
||||
|
||||
let req = req
|
||||
.body(body)
|
||||
.expect("all headers and params received via hyper should be valid for request");
|
||||
|
||||
// todo: map body to count egress
|
||||
let _metrics = client.metrics();
|
||||
|
||||
Ok(client
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
async fn process(
|
||||
self,
|
||||
@@ -705,7 +835,9 @@ impl QueryData {
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(
|
||||
HttpConnError::PostgresConnectionError(e),
|
||||
)
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
@@ -56,6 +56,7 @@ from _pytest.fixtures import FixtureRequest
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from psycopg2.extensions import cursor as PgCursor
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
from pytest_httpserver import HTTPServer
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from fixtures import overlayfs
|
||||
@@ -440,9 +441,9 @@ class NeonEnvBuilder:
|
||||
|
||||
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
|
||||
|
||||
self.pageserver_default_tenant_config_compaction_algorithm: Optional[
|
||||
Dict[str, Any]
|
||||
] = pageserver_default_tenant_config_compaction_algorithm
|
||||
self.pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = (
|
||||
pageserver_default_tenant_config_compaction_algorithm
|
||||
)
|
||||
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||
log.debug(
|
||||
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
|
||||
@@ -1072,9 +1073,9 @@ class NeonEnv:
|
||||
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
|
||||
if config.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||
tenant_config = ps_cfg.setdefault("tenant_config", {})
|
||||
tenant_config[
|
||||
"compaction_algorithm"
|
||||
] = config.pageserver_default_tenant_config_compaction_algorithm
|
||||
tenant_config["compaction_algorithm"] = (
|
||||
config.pageserver_default_tenant_config_compaction_algorithm
|
||||
)
|
||||
|
||||
if self.pageserver_remote_storage is not None:
|
||||
ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
|
||||
@@ -1117,9 +1118,9 @@ class NeonEnv:
|
||||
if config.auth_enabled:
|
||||
sk_cfg["auth_enabled"] = True
|
||||
if self.safekeepers_remote_storage is not None:
|
||||
sk_cfg[
|
||||
"remote_storage"
|
||||
] = self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
||||
sk_cfg["remote_storage"] = (
|
||||
self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
||||
)
|
||||
self.safekeepers.append(
|
||||
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
|
||||
)
|
||||
@@ -3572,6 +3573,20 @@ class NeonProxy(PgProtocol):
|
||||
]
|
||||
return args
|
||||
|
||||
class AuthBroker(AuthBackend):
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
def extra_args(self) -> list[str]:
|
||||
args = [
|
||||
# Console auth backend params
|
||||
*["--auth-backend", "console"],
|
||||
*["--auth-endpoint", self.endpoint],
|
||||
*["--sql-over-http-pool-opt-in", "false"],
|
||||
*["--is-auth-broker"],
|
||||
]
|
||||
return args
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Postgres(AuthBackend):
|
||||
pg_conn_url: str
|
||||
@@ -3600,7 +3615,7 @@ class NeonProxy(PgProtocol):
|
||||
metric_collection_interval: Optional[str] = None,
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
domain = "proxy.localtest.me" # resolves to 127.0.0.1
|
||||
domain = "ep-foo-bar-1234.localtest.me" # resolves to 127.0.0.1
|
||||
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
|
||||
|
||||
self.domain = domain
|
||||
@@ -3886,6 +3901,50 @@ def static_proxy(
|
||||
yield proxy
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def static_auth_broker(
|
||||
vanilla_pg: VanillaPostgres,
|
||||
port_distributor: PortDistributor,
|
||||
neon_binpath: Path,
|
||||
test_output_dir: Path,
|
||||
httpserver: HTTPServer,
|
||||
) -> Iterator[NeonProxy]:
|
||||
"""Neon proxy that routes directly to vanilla postgres."""
|
||||
|
||||
auth_endpoint = httpserver.url_for("/cplane")
|
||||
|
||||
port = vanilla_pg.default_options["port"]
|
||||
host = vanilla_pg.default_options["host"]
|
||||
|
||||
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
|
||||
{
|
||||
"address": f"{host}:{port}",
|
||||
"aux": {
|
||||
"endpoint_id": "ep-foo-bar-1234",
|
||||
"branch_id": "br-foo-bar",
|
||||
"project_id": "foo-bar",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
mgmt_port = port_distributor.get_port()
|
||||
http_port = port_distributor.get_port()
|
||||
external_http_port = port_distributor.get_port()
|
||||
|
||||
with NeonProxy(
|
||||
neon_binpath=neon_binpath,
|
||||
test_output_dir=test_output_dir,
|
||||
proxy_port=proxy_port,
|
||||
http_port=http_port,
|
||||
mgmt_port=mgmt_port,
|
||||
external_http_port=external_http_port,
|
||||
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
|
||||
) as proxy:
|
||||
proxy.start()
|
||||
yield proxy
|
||||
|
||||
|
||||
class Endpoint(PgProtocol, LogUtils):
|
||||
"""An object representing a Postgres compute endpoint managed by the control plane."""
|
||||
|
||||
|
||||
71
test_runner/regress/test_auth_broker.py
Normal file
71
test_runner/regress/test_auth_broker.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import psycopg2
|
||||
import pytest
|
||||
import requests
|
||||
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
|
||||
|
||||
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
|
||||
|
||||
|
||||
def test_sql_over_http(static_auth_broker: NeonProxy):
|
||||
static_auth_broker.safe_psql("create role http with login password 'http' superuser")
|
||||
|
||||
def q(sql: str, params: Optional[List[Any]] = None) -> Any:
|
||||
params = params or []
|
||||
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||
data=json.dumps({"query": sql, "params": params}),
|
||||
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
|
||||
verify=str(static_proxy.test_output_dir / "proxy.crt"),
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
rows = q("select 42 as answer")["rows"]
|
||||
assert rows == [{"answer": 42}]
|
||||
|
||||
rows = q("select $1 as answer", [42])["rows"]
|
||||
assert rows == [{"answer": "42"}]
|
||||
|
||||
rows = q("select $1 * 1 as answer", [42])["rows"]
|
||||
assert rows == [{"answer": 42}]
|
||||
|
||||
rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"]
|
||||
assert rows == [{"answer": [1, 2, 3]}]
|
||||
|
||||
rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
|
||||
assert rows == [{"answer": {"b": 42}}]
|
||||
|
||||
rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
|
||||
assert rows == [{"answer": [{}]}]
|
||||
|
||||
rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
|
||||
assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]
|
||||
|
||||
rows = q("select * from pg_class limit 1")["rows"]
|
||||
assert len(rows) == 1
|
||||
|
||||
res = q("create table t(id serial primary key, val int)")
|
||||
assert res["command"] == "CREATE"
|
||||
assert res["rowCount"] is None
|
||||
|
||||
res = q("insert into t(val) values (10), (20), (30) returning id")
|
||||
assert res["command"] == "INSERT"
|
||||
assert res["rowCount"] == 3
|
||||
assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
|
||||
res = q("select * from t")
|
||||
assert res["command"] == "SELECT"
|
||||
assert res["rowCount"] == 3
|
||||
|
||||
res = q("drop table t")
|
||||
assert res["command"] == "DROP"
|
||||
assert res["rowCount"] is None
|
||||
|
||||
Reference in New Issue
Block a user