mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-11 14:40:36 +00:00
Compare commits
20 Commits
conrad/pro
...
auth-broke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3f7d0d3f1 | ||
|
|
0724df1d3f | ||
|
|
4d47049b00 | ||
|
|
5687384a8e | ||
|
|
249f5ea17d | ||
|
|
6abcc1f298 | ||
|
|
3e97cf0d6e | ||
|
|
054ef4988b | ||
|
|
5202cd75b5 | ||
|
|
f475dac0e6 | ||
|
|
a4100373e5 | ||
|
|
040d8cf4f6 | ||
|
|
75bfd57e01 | ||
|
|
4bc2686dee | ||
|
|
8e7d2aab76 | ||
|
|
2703abccc7 | ||
|
|
76515cdae3 | ||
|
|
08c7f933a3 | ||
|
|
4ad3aa7c96 | ||
|
|
9c59e3b4b9 |
@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
|
|||||||
fn get_allowed_ips_and_secret(
|
fn get_allowed_ips_and_secret(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
|
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
|
||||||
|
fn dyn_clone(&self) -> Box<dyn TestBackend>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
impl Clone for Box<dyn TestBackend> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
TestBackend::dyn_clone(&**self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for Backend<'_, (), ()> {
|
impl std::fmt::Display for Backend<'_, (), ()> {
|
||||||
@@ -557,7 +565,7 @@ mod tests {
|
|||||||
stream::{PqStream, Stream},
|
stream::{PqStream, Stream},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{auth_quirks, AuthRateLimiter};
|
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
|
||||||
|
|
||||||
struct Auth {
|
struct Auth {
|
||||||
ips: Vec<IpPattern>,
|
ips: Vec<IpPattern>,
|
||||||
@@ -585,6 +593,14 @@ mod tests {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
_ctx: &RequestMonitoring,
|
||||||
|
_endpoint: crate::EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
async fn wake_compute(
|
async fn wake_compute(
|
||||||
&self,
|
&self,
|
||||||
_ctx: &RequestMonitoring,
|
_ctx: &RequestMonitoring,
|
||||||
@@ -595,12 +611,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||||
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool: ThreadPool::new(1),
|
thread_pool: ThreadPool::new(1),
|
||||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||||
rate_limiter_enabled: true,
|
rate_limiter_enabled: true,
|
||||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||||
rate_limit_ip_subnet: 64,
|
rate_limit_ip_subnet: 64,
|
||||||
ip_allowlist_check_enabled: true,
|
ip_allowlist_check_enabled: true,
|
||||||
|
is_auth_broker: false,
|
||||||
|
accept_jwts: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
borrow::Cow,
|
||||||
future::Future,
|
future::Future,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
@@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context};
|
|||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use jose_jwk::crypto::KeyInfo;
|
use jose_jwk::crypto::KeyInfo;
|
||||||
use serde::{Deserialize, Deserializer};
|
use serde::{
|
||||||
|
de::{DeserializeSeed, IgnoredAny, Visitor},
|
||||||
|
Deserializer,
|
||||||
|
};
|
||||||
use signature::Verifier;
|
use signature::Verifier;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
@@ -33,6 +37,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
|
|||||||
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
|
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct AuthRule {
|
pub(crate) struct AuthRule {
|
||||||
pub(crate) id: String,
|
pub(crate) id: String,
|
||||||
pub(crate) jwks_url: url::Url,
|
pub(crate) jwks_url: url::Url,
|
||||||
@@ -303,35 +308,21 @@ impl JwkCacheEntryLock {
|
|||||||
}
|
}
|
||||||
key => bail!("unsupported key type {key:?}"),
|
key => bail!("unsupported key type {key:?}"),
|
||||||
};
|
};
|
||||||
|
tracing::debug!("JWT signature valid");
|
||||||
|
|
||||||
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
|
|
||||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
|
||||||
|
|
||||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
let validator = JwtValidator {
|
||||||
|
expected_audience,
|
||||||
|
current_time: SystemTime::now(),
|
||||||
|
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
|
||||||
|
};
|
||||||
|
|
||||||
match (expected_audience, payload.audience) {
|
let payload = validator
|
||||||
// check the audience matches
|
.deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
|
||||||
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
|
|
||||||
// the audience is expected but is missing
|
|
||||||
(Some(_), None) => bail!("invalid JWT token audience"),
|
|
||||||
// we don't care for the audience field
|
|
||||||
(None, _) => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let now = SystemTime::now();
|
tracing::debug!(?payload, "JWT claims valid");
|
||||||
|
|
||||||
if let Some(exp) = payload.expiration {
|
|
||||||
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(nbf) = payload.not_before {
|
|
||||||
ensure!(
|
|
||||||
nbf < now + CLOCK_SKEW_LEEWAY,
|
|
||||||
"JWT token is not yet ready to use"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -419,37 +410,184 @@ struct JwtHeader<'a> {
|
|||||||
key_id: Option<&'a str>,
|
key_id: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
struct JwtValidator<'a> {
|
||||||
#[derive(serde::Deserialize, serde::Serialize, Debug)]
|
expected_audience: Option<&'a str>,
|
||||||
struct JwtPayload<'a> {
|
current_time: SystemTime,
|
||||||
/// Audience - Recipient for which the JWT is intended
|
clock_skew_leeway: Duration,
|
||||||
#[serde(rename = "aud")]
|
|
||||||
audience: Option<&'a str>,
|
|
||||||
/// Expiration - Time after which the JWT expires
|
|
||||||
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
|
|
||||||
expiration: Option<SystemTime>,
|
|
||||||
/// Not before - Time after which the JWT expires
|
|
||||||
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
|
|
||||||
not_before: Option<SystemTime>,
|
|
||||||
|
|
||||||
// the following entries are only extracted for the sake of debug logging.
|
|
||||||
/// Issuer of the JWT
|
|
||||||
#[serde(rename = "iss")]
|
|
||||||
issuer: Option<&'a str>,
|
|
||||||
/// Subject of the JWT (the user)
|
|
||||||
#[serde(rename = "sub")]
|
|
||||||
subject: Option<&'a str>,
|
|
||||||
/// Unique token identifier
|
|
||||||
#[serde(rename = "jti")]
|
|
||||||
jwt_id: Option<&'a str>,
|
|
||||||
/// Unique session identifier
|
|
||||||
#[serde(rename = "sid")]
|
|
||||||
session_id: Option<&'a str>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
|
impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
|
||||||
let d = <Option<u64>>::deserialize(d)?;
|
type Value = JwtPayload<'de>;
|
||||||
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
|
|
||||||
|
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
impl<'de> Visitor<'de> for JwtValidator<'_> {
|
||||||
|
type Value = JwtPayload<'de>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("a JWT payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::MapAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut payload = JwtPayload {
|
||||||
|
issuer: None,
|
||||||
|
subject: None,
|
||||||
|
jwt_id: None,
|
||||||
|
session_id: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut aud = false;
|
||||||
|
|
||||||
|
while let Some(key) = map.next_key()? {
|
||||||
|
match key {
|
||||||
|
"iss" if payload.issuer.is_none() => {
|
||||||
|
payload.issuer = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
"sub" if payload.subject.is_none() => {
|
||||||
|
payload.subject = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
"jit" if payload.jwt_id.is_none() => {
|
||||||
|
payload.jwt_id = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
"sid" if payload.session_id.is_none() => {
|
||||||
|
payload.session_id = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
"exp" => {
|
||||||
|
let exp = map.next_value::<u64>()?;
|
||||||
|
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
|
||||||
|
|
||||||
|
if self.current_time > exp + self.clock_skew_leeway {
|
||||||
|
return Err(serde::de::Error::custom("JWT token has expired"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"nbf" => {
|
||||||
|
let nbf = map.next_value::<u64>()?;
|
||||||
|
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
|
||||||
|
|
||||||
|
if self.current_time + self.clock_skew_leeway < nbf {
|
||||||
|
return Err(serde::de::Error::custom(
|
||||||
|
"JWT token is not yet ready to use",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"aud" => {
|
||||||
|
if let Some(expected_audience) = self.expected_audience {
|
||||||
|
map.next_value_seed(AudienceValidator { expected_audience })?;
|
||||||
|
aud = true;
|
||||||
|
} else {
|
||||||
|
map.next_value::<IgnoredAny>()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.expected_audience.is_some() && !aud {
|
||||||
|
return Err(serde::de::Error::custom("invalid JWT token audience"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_map(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AudienceValidator<'a> {
|
||||||
|
expected_audience: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
|
||||||
|
type Value = ();
|
||||||
|
|
||||||
|
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
impl<'de> Visitor<'de> for AudienceValidator<'_> {
|
||||||
|
type Value = ();
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("a single string or an array of strings")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
if self.expected_audience == v {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(E::custom("invalid JWT token audience"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
while let Some(v) = seq.next_element_seed(SingleAudienceValidator {
|
||||||
|
expected_audience: self.expected_audience,
|
||||||
|
})? {
|
||||||
|
if v {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(serde::de::Error::custom("invalid JWT token audience"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deserializer.deserialize_any(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SingleAudienceValidator<'a> {
|
||||||
|
expected_audience: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
|
||||||
|
type Value = bool;
|
||||||
|
|
||||||
|
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
|
||||||
|
type Value = bool;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("a single audience string")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(self.expected_audience == v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deserializer.deserialize_any(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||||
|
// the following entries are only extracted for the sake of debug logging.
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct JwtPayload<'a> {
|
||||||
|
/// Issuer of the JWT
|
||||||
|
issuer: Option<Cow<'a, str>>,
|
||||||
|
/// Subject of the JWT (the user)
|
||||||
|
subject: Option<Cow<'a, str>>,
|
||||||
|
/// Unique token identifier
|
||||||
|
jwt_id: Option<Cow<'a, str>>,
|
||||||
|
/// Unique session identifier
|
||||||
|
session_id: Option<Cow<'a, str>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct JwkRenewalPermit<'a> {
|
struct JwkRenewalPermit<'a> {
|
||||||
@@ -530,6 +668,8 @@ mod tests {
|
|||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use rand::rngs::OsRng;
|
use rand::rngs::OsRng;
|
||||||
use rsa::pkcs8::DecodePrivateKey;
|
use rsa::pkcs8::DecodePrivateKey;
|
||||||
|
use serde::Serialize;
|
||||||
|
use serde_json::json;
|
||||||
use signature::Signer;
|
use signature::Signer;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
@@ -562,18 +702,36 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
|
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
let body = typed_json::json! {{
|
||||||
|
"exp": now + 3600,
|
||||||
|
"nbf": now,
|
||||||
|
"aud": ["audience1", "neon", "audience2"],
|
||||||
|
"sub": "user1",
|
||||||
|
"sid": "session1",
|
||||||
|
"jti": "token1",
|
||||||
|
"iss": "neon-testing",
|
||||||
|
}};
|
||||||
|
build_custom_jwt_payload(kid, body, sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_custom_jwt_payload(
|
||||||
|
kid: String,
|
||||||
|
body: impl Serialize,
|
||||||
|
sig: jose_jwa::Signing,
|
||||||
|
) -> String {
|
||||||
let header = JwtHeader {
|
let header = JwtHeader {
|
||||||
typ: "JWT",
|
typ: "JWT",
|
||||||
algorithm: jose_jwa::Algorithm::Signing(sig),
|
algorithm: jose_jwa::Algorithm::Signing(sig),
|
||||||
key_id: Some(&kid),
|
key_id: Some(&kid),
|
||||||
};
|
};
|
||||||
let body = typed_json::json! {{
|
|
||||||
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
|
|
||||||
}};
|
|
||||||
|
|
||||||
let header =
|
let header =
|
||||||
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
||||||
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
|
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
|
||||||
|
|
||||||
format!("{header}.{body}")
|
format!("{header}.{body}")
|
||||||
}
|
}
|
||||||
@@ -588,6 +746,16 @@ mod tests {
|
|||||||
format!("{payload}.{sig}")
|
format!("{payload}.{sig}")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
|
||||||
|
use p256::ecdsa::{Signature, SigningKey};
|
||||||
|
|
||||||
|
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
|
||||||
|
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||||
|
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||||
|
|
||||||
|
format!("{payload}.{sig}")
|
||||||
|
}
|
||||||
|
|
||||||
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
|
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
|
||||||
use rsa::pkcs1v15::SigningKey;
|
use rsa::pkcs1v15::SigningKey;
|
||||||
use rsa::signature::SignatureEncoding;
|
use rsa::signature::SignatureEncoding;
|
||||||
@@ -659,37 +827,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
|||||||
-----END PRIVATE KEY-----
|
-----END PRIVATE KEY-----
|
||||||
";
|
";
|
||||||
|
|
||||||
#[tokio::test]
|
#[derive(Clone)]
|
||||||
async fn renew() {
|
struct Fetch(Vec<AuthRule>);
|
||||||
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
|
|
||||||
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
|
|
||||||
let (ec1, jwk3) = new_ec_jwk("3".into());
|
|
||||||
let (ec2, jwk4) = new_ec_jwk("4".into());
|
|
||||||
|
|
||||||
let foo_jwks = jose_jwk::JwkSet {
|
impl FetchAuthRules for Fetch {
|
||||||
keys: vec![jwk1, jwk3],
|
async fn fetch_auth_rules(
|
||||||
};
|
&self,
|
||||||
let bar_jwks = jose_jwk::JwkSet {
|
_ctx: &RequestMonitoring,
|
||||||
keys: vec![jwk2, jwk4],
|
_endpoint: EndpointId,
|
||||||
};
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
Ok(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn jwks_server(
|
||||||
|
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
|
||||||
|
) -> SocketAddr {
|
||||||
|
let router = Arc::new(router);
|
||||||
let service = service_fn(move |req| {
|
let service = service_fn(move |req| {
|
||||||
let foo_jwks = foo_jwks.clone();
|
let router = Arc::clone(&router);
|
||||||
let bar_jwks = bar_jwks.clone();
|
|
||||||
async move {
|
async move {
|
||||||
let jwks = match req.uri().path() {
|
match router(req.uri().path()) {
|
||||||
"/foo" => &foo_jwks,
|
Some(body) => Response::builder()
|
||||||
"/bar" => &bar_jwks,
|
.status(200)
|
||||||
_ => {
|
.body(Full::new(Bytes::from(body))),
|
||||||
return Response::builder()
|
None => Response::builder()
|
||||||
.status(404)
|
.status(404)
|
||||||
.body(Full::new(Bytes::new()));
|
.body(Full::new(Bytes::new())),
|
||||||
}
|
}
|
||||||
};
|
|
||||||
let body = serde_json::to_vec(jwks).unwrap();
|
|
||||||
Response::builder()
|
|
||||||
.status(200)
|
|
||||||
.body(Full::new(Bytes::from(body)))
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -704,84 +869,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
addr
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[tokio::test]
|
||||||
struct Fetch(SocketAddr, Vec<RoleNameInt>);
|
async fn check_jwt_happy_path() {
|
||||||
|
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
|
||||||
|
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
|
||||||
|
let (ec1, jwk3) = new_ec_jwk("ec1".into());
|
||||||
|
let (ec2, jwk4) = new_ec_jwk("ec2".into());
|
||||||
|
|
||||||
impl FetchAuthRules for Fetch {
|
let foo_jwks = jose_jwk::JwkSet {
|
||||||
async fn fetch_auth_rules(
|
keys: vec![jwk1, jwk3],
|
||||||
&self,
|
};
|
||||||
_ctx: &RequestMonitoring,
|
let bar_jwks = jose_jwk::JwkSet {
|
||||||
_endpoint: EndpointId,
|
keys: vec![jwk2, jwk4],
|
||||||
) -> anyhow::Result<Vec<AuthRule>> {
|
};
|
||||||
Ok(vec![
|
|
||||||
AuthRule {
|
let jwks_addr = jwks_server(move |path| match path {
|
||||||
id: "foo".to_owned(),
|
"/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
|
||||||
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
|
"/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
|
||||||
audience: None,
|
_ => None,
|
||||||
role_names: self.1.clone(),
|
})
|
||||||
},
|
.await;
|
||||||
AuthRule {
|
|
||||||
id: "bar".to_owned(),
|
|
||||||
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
|
|
||||||
audience: None,
|
|
||||||
role_names: self.1.clone(),
|
|
||||||
},
|
|
||||||
])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let role_name1 = RoleName::from("anonymous");
|
let role_name1 = RoleName::from("anonymous");
|
||||||
let role_name2 = RoleName::from("authenticated");
|
let role_name2 = RoleName::from("authenticated");
|
||||||
|
|
||||||
let fetch = Fetch(
|
let roles = vec![
|
||||||
addr,
|
RoleNameInt::from(&role_name1),
|
||||||
vec![
|
RoleNameInt::from(&role_name2),
|
||||||
RoleNameInt::from(&role_name1),
|
];
|
||||||
RoleNameInt::from(&role_name2),
|
let rules = vec![
|
||||||
],
|
AuthRule {
|
||||||
);
|
id: "foo".to_owned(),
|
||||||
|
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
|
||||||
|
audience: None,
|
||||||
|
role_names: roles.clone(),
|
||||||
|
},
|
||||||
|
AuthRule {
|
||||||
|
id: "bar".to_owned(),
|
||||||
|
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
|
||||||
|
audience: None,
|
||||||
|
role_names: roles.clone(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let fetch = Fetch(rules);
|
||||||
|
let jwk_cache = JwkCache::default();
|
||||||
|
|
||||||
let endpoint = EndpointId::from("ep");
|
let endpoint = EndpointId::from("ep");
|
||||||
|
|
||||||
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
|
let jwt1 = new_rsa_jwt("rs1".into(), rs1);
|
||||||
|
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
|
||||||
let jwt1 = new_rsa_jwt("1".into(), rs1);
|
let jwt3 = new_ec_jwt("ec1".into(), &ec1);
|
||||||
let jwt2 = new_rsa_jwt("2".into(), rs2);
|
let jwt4 = new_ec_jwt("ec2".into(), &ec2);
|
||||||
let jwt3 = new_ec_jwt("3".into(), &ec1);
|
|
||||||
let jwt4 = new_ec_jwt("4".into(), &ec2);
|
|
||||||
|
|
||||||
// had the wrong kid, therefore will have the wrong ecdsa signature
|
|
||||||
let bad_jwt = new_ec_jwt("3".into(), &ec2);
|
|
||||||
// this role_name is not accepted
|
|
||||||
let bad_role_name = RoleName::from("cloud_admin");
|
|
||||||
|
|
||||||
let err = jwk_cache
|
|
||||||
.check_jwt(
|
|
||||||
&RequestMonitoring::test(),
|
|
||||||
&bad_jwt,
|
|
||||||
&client,
|
|
||||||
endpoint.clone(),
|
|
||||||
&role_name1,
|
|
||||||
&fetch,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_err();
|
|
||||||
assert!(err.to_string().contains("signature error"));
|
|
||||||
|
|
||||||
let err = jwk_cache
|
|
||||||
.check_jwt(
|
|
||||||
&RequestMonitoring::test(),
|
|
||||||
&jwt1,
|
|
||||||
&client,
|
|
||||||
endpoint.clone(),
|
|
||||||
&bad_role_name,
|
|
||||||
&fetch,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_err();
|
|
||||||
assert!(err.to_string().contains("jwk not found"));
|
|
||||||
|
|
||||||
let tokens = [jwt1, jwt2, jwt3, jwt4];
|
let tokens = [jwt1, jwt2, jwt3, jwt4];
|
||||||
let role_names = [role_name1, role_name2];
|
let role_names = [role_name1, role_name2];
|
||||||
@@ -790,15 +932,194 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
|||||||
jwk_cache
|
jwk_cache
|
||||||
.check_jwt(
|
.check_jwt(
|
||||||
&RequestMonitoring::test(),
|
&RequestMonitoring::test(),
|
||||||
token,
|
|
||||||
&client,
|
|
||||||
endpoint.clone(),
|
endpoint.clone(),
|
||||||
role,
|
role,
|
||||||
&fetch,
|
&fetch,
|
||||||
|
token,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn check_jwt_invalid_signature() {
|
||||||
|
let (_, jwk) = new_ec_jwk("1".into());
|
||||||
|
let (key, _) = new_ec_jwk("1".into());
|
||||||
|
|
||||||
|
// has a matching kid, but signed by the wrong key
|
||||||
|
let bad_jwt = new_ec_jwt("1".into(), &key);
|
||||||
|
|
||||||
|
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||||
|
let jwks_addr = jwks_server(move |path| match path {
|
||||||
|
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let role = RoleName::from("authenticated");
|
||||||
|
|
||||||
|
let rules = vec![AuthRule {
|
||||||
|
id: String::new(),
|
||||||
|
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||||
|
audience: None,
|
||||||
|
role_names: vec![RoleNameInt::from(&role)],
|
||||||
|
}];
|
||||||
|
|
||||||
|
let fetch = Fetch(rules);
|
||||||
|
let jwk_cache = JwkCache::default();
|
||||||
|
|
||||||
|
let ep = EndpointId::from("ep");
|
||||||
|
|
||||||
|
let ctx = RequestMonitoring::test();
|
||||||
|
let err = jwk_cache
|
||||||
|
.check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(
|
||||||
|
err.to_string().contains("signature error"),
|
||||||
|
"expected \"signature error\", got {err:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn check_jwt_unknown_role() {
|
||||||
|
let (key, jwk) = new_rsa_jwk(RS1, "1".into());
|
||||||
|
let jwt = new_rsa_jwt("1".into(), key);
|
||||||
|
|
||||||
|
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||||
|
let jwks_addr = jwks_server(move |path| match path {
|
||||||
|
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let role = RoleName::from("authenticated");
|
||||||
|
let rules = vec![AuthRule {
|
||||||
|
id: String::new(),
|
||||||
|
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||||
|
audience: None,
|
||||||
|
role_names: vec![RoleNameInt::from(&role)],
|
||||||
|
}];
|
||||||
|
|
||||||
|
let fetch = Fetch(rules);
|
||||||
|
let jwk_cache = JwkCache::default();
|
||||||
|
|
||||||
|
let ep = EndpointId::from("ep");
|
||||||
|
|
||||||
|
// this role_name is not accepted
|
||||||
|
let bad_role_name = RoleName::from("cloud_admin");
|
||||||
|
|
||||||
|
let ctx = RequestMonitoring::test();
|
||||||
|
let err = jwk_cache
|
||||||
|
.check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
err.to_string().contains("jwk not found"),
|
||||||
|
"expected \"jwk not found\", got {err:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn check_jwt_invalid_claims() {
|
||||||
|
let (key, jwk) = new_ec_jwk("1".into());
|
||||||
|
|
||||||
|
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||||
|
let jwks_addr = jwks_server(move |path| match path {
|
||||||
|
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
struct Test {
|
||||||
|
body: serde_json::Value,
|
||||||
|
error: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
let table = vec![
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"nbf": now + 60,
|
||||||
|
"aud": "neon",
|
||||||
|
}},
|
||||||
|
error: "JWT token is not yet ready to use",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"exp": now - 60,
|
||||||
|
"aud": ["neon"],
|
||||||
|
}},
|
||||||
|
error: "JWT token has expired",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
}},
|
||||||
|
error: "invalid JWT token audience",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"aud": [],
|
||||||
|
}},
|
||||||
|
error: "invalid JWT token audience",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"aud": "foo",
|
||||||
|
}},
|
||||||
|
error: "invalid JWT token audience",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"aud": ["foo"],
|
||||||
|
}},
|
||||||
|
error: "invalid JWT token audience",
|
||||||
|
},
|
||||||
|
Test {
|
||||||
|
body: json! {{
|
||||||
|
"aud": ["foo", "bar"],
|
||||||
|
}},
|
||||||
|
error: "invalid JWT token audience",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let role = RoleName::from("authenticated");
|
||||||
|
|
||||||
|
let rules = vec![AuthRule {
|
||||||
|
id: String::new(),
|
||||||
|
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||||
|
audience: Some("neon".to_string()),
|
||||||
|
role_names: vec![RoleNameInt::from(&role)],
|
||||||
|
}];
|
||||||
|
|
||||||
|
let fetch = Fetch(rules);
|
||||||
|
let jwk_cache = JwkCache::default();
|
||||||
|
|
||||||
|
let ep = EndpointId::from("ep");
|
||||||
|
|
||||||
|
let ctx = RequestMonitoring::test();
|
||||||
|
for test in table {
|
||||||
|
let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
|
||||||
|
|
||||||
|
match jwk_cache
|
||||||
|
.check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Err(err) if err.to_string().contains(test.error) => {}
|
||||||
|
Err(err) => {
|
||||||
|
panic!("expected {:?}, got {err:?}", test.error)
|
||||||
|
}
|
||||||
|
Ok(()) => {
|
||||||
|
panic!("expected {:?}, got ok", test.error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,17 +14,15 @@ use crate::{
|
|||||||
EndpointId,
|
EndpointId,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
|
use super::jwt::{AuthRule, FetchAuthRules};
|
||||||
|
|
||||||
pub struct LocalBackend {
|
pub struct LocalBackend {
|
||||||
pub(crate) jwks_cache: JwkCache,
|
|
||||||
pub(crate) node_info: NodeInfo,
|
pub(crate) node_info: NodeInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LocalBackend {
|
impl LocalBackend {
|
||||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||||
LocalBackend {
|
LocalBackend {
|
||||||
jwks_cache: JwkCache::default(),
|
|
||||||
node_info: NodeInfo {
|
node_info: NodeInfo {
|
||||||
config: {
|
config: {
|
||||||
let mut cfg = ConnCfg::new();
|
let mut cfg = ConnCfg::new();
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
|
|||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use futures::future::Either;
|
use futures::future::Either;
|
||||||
use proxy::{
|
use proxy::{
|
||||||
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP},
|
auth::backend::{
|
||||||
|
jwt::JwkCache,
|
||||||
|
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||||
|
},
|
||||||
cancellation::CancellationHandlerMain,
|
cancellation::CancellationHandlerMain,
|
||||||
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
||||||
console::{
|
console::{
|
||||||
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
|||||||
allow_self_signed_compute: false,
|
allow_self_signed_compute: false,
|
||||||
http_config,
|
http_config,
|
||||||
authentication_config: AuthenticationConfig {
|
authentication_config: AuthenticationConfig {
|
||||||
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool: ThreadPool::new(0),
|
thread_pool: ThreadPool::new(0),
|
||||||
scram_protocol_timeout: Duration::from_secs(10),
|
scram_protocol_timeout: Duration::from_secs(10),
|
||||||
rate_limiter_enabled: false,
|
rate_limiter_enabled: false,
|
||||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||||
rate_limit_ip_subnet: 64,
|
rate_limit_ip_subnet: 64,
|
||||||
ip_allowlist_check_enabled: true,
|
ip_allowlist_check_enabled: true,
|
||||||
|
is_auth_broker: false,
|
||||||
|
accept_jwts: true,
|
||||||
},
|
},
|
||||||
require_client_ip: false,
|
require_client_ip: false,
|
||||||
handshake_timeout: Duration::from_secs(10),
|
handshake_timeout: Duration::from_secs(10),
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
|||||||
use aws_config::Region;
|
use aws_config::Region;
|
||||||
use futures::future::Either;
|
use futures::future::Either;
|
||||||
use proxy::auth;
|
use proxy::auth;
|
||||||
|
use proxy::auth::backend::jwt::JwkCache;
|
||||||
use proxy::auth::backend::AuthRateLimiter;
|
use proxy::auth::backend::AuthRateLimiter;
|
||||||
use proxy::auth::backend::MaybeOwned;
|
use proxy::auth::backend::MaybeOwned;
|
||||||
use proxy::cancellation::CancelMap;
|
use proxy::cancellation::CancelMap;
|
||||||
@@ -102,6 +103,9 @@ struct ProxyCliArgs {
|
|||||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||||
)]
|
)]
|
||||||
auth_endpoint: String,
|
auth_endpoint: String,
|
||||||
|
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
|
||||||
|
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||||
|
is_auth_broker: bool,
|
||||||
/// path to TLS key for client postgres connections
|
/// path to TLS key for client postgres connections
|
||||||
///
|
///
|
||||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||||
@@ -382,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
info!("Starting mgmt on {mgmt_address}");
|
info!("Starting mgmt on {mgmt_address}");
|
||||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||||
|
|
||||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
let proxy_listener = if !args.is_auth_broker {
|
||||||
info!("Starting proxy on {proxy_address}");
|
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
info!("Starting proxy on {proxy_address}");
|
||||||
|
|
||||||
|
Some(TcpListener::bind(proxy_address).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: rename the argument to something like serverless.
|
||||||
|
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||||
|
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||||
|
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||||
|
info!("Starting wss on {serverless_address}");
|
||||||
|
Some(TcpListener::bind(serverless_address).await?)
|
||||||
|
} else if args.is_auth_broker {
|
||||||
|
bail!("wss arg must be present for auth-broker")
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let cancellation_token = CancellationToken::new();
|
let cancellation_token = CancellationToken::new();
|
||||||
|
|
||||||
let cancel_map = CancelMap::default();
|
let cancel_map = CancelMap::default();
|
||||||
@@ -430,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// client facing tasks. these will exit on error or on cancellation
|
// client facing tasks. these will exit on error or on cancellation
|
||||||
// cancellation returns Ok(())
|
// cancellation returns Ok(())
|
||||||
let mut client_tasks = JoinSet::new();
|
let mut client_tasks = JoinSet::new();
|
||||||
client_tasks.spawn(proxy::proxy::task_main(
|
if let Some(proxy_listener) = proxy_listener {
|
||||||
config,
|
client_tasks.spawn(proxy::proxy::task_main(
|
||||||
proxy_listener,
|
config,
|
||||||
cancellation_token.clone(),
|
proxy_listener,
|
||||||
cancellation_handler.clone(),
|
cancellation_token.clone(),
|
||||||
endpoint_rate_limiter.clone(),
|
cancellation_handler.clone(),
|
||||||
));
|
endpoint_rate_limiter.clone(),
|
||||||
|
));
|
||||||
// TODO: rename the argument to something like serverless.
|
}
|
||||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
|
||||||
if let Some(serverless_address) = args.wss {
|
|
||||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
|
||||||
info!("Starting wss on {serverless_address}");
|
|
||||||
let serverless_listener = TcpListener::bind(serverless_address).await?;
|
|
||||||
|
|
||||||
|
if let Some(serverless_listener) = serverless_listener {
|
||||||
client_tasks.spawn(serverless::task_main(
|
client_tasks.spawn(serverless::task_main(
|
||||||
config,
|
config,
|
||||||
serverless_listener,
|
serverless_listener,
|
||||||
@@ -674,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
let http_config = HttpConfig {
|
let http_config = HttpConfig {
|
||||||
accept_websockets: true,
|
accept_websockets: !args.is_auth_broker,
|
||||||
pool_options: GlobalConnPoolOptions {
|
pool_options: GlobalConnPoolOptions {
|
||||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||||
@@ -689,12 +707,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||||
};
|
};
|
||||||
let authentication_config = AuthenticationConfig {
|
let authentication_config = AuthenticationConfig {
|
||||||
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool,
|
thread_pool,
|
||||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||||
|
is_auth_broker: args.is_auth_broker,
|
||||||
|
accept_jwts: args.is_auth_broker,
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = Box::leak(Box::new(ProxyConfig {
|
let config = Box::leak(Box::new(ProxyConfig {
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
auth::{self, backend::AuthRateLimiter},
|
auth::{
|
||||||
|
self,
|
||||||
|
backend::{jwt::JwkCache, AuthRateLimiter},
|
||||||
|
},
|
||||||
console::locks::ApiLocks,
|
console::locks::ApiLocks,
|
||||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||||
scram::threadpool::ThreadPool,
|
scram::threadpool::ThreadPool,
|
||||||
@@ -67,6 +70,9 @@ pub struct AuthenticationConfig {
|
|||||||
pub rate_limiter: AuthRateLimiter,
|
pub rate_limiter: AuthRateLimiter,
|
||||||
pub rate_limit_ip_subnet: u8,
|
pub rate_limit_ip_subnet: u8,
|
||||||
pub ip_allowlist_check_enabled: bool,
|
pub ip_allowlist_check_enabled: bool,
|
||||||
|
pub jwks_cache: JwkCache,
|
||||||
|
pub is_auth_broker: bool,
|
||||||
|
pub accept_jwts: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TlsConfig {
|
impl TlsConfig {
|
||||||
@@ -250,18 +256,26 @@ impl CertResolver {
|
|||||||
|
|
||||||
let common_name = pem.subject().to_string();
|
let common_name = pem.subject().to_string();
|
||||||
|
|
||||||
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
|
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
// seen within the proxy codebase.
|
||||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
//
|
||||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||||
// of cutting off '*.' parts.
|
//
|
||||||
let common_name = if common_name.starts_with("CN=*.") {
|
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||||
|
//
|
||||||
|
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||||
|
// validation, so let's we can continue with any common-name
|
||||||
|
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||||
|
s.to_string()
|
||||||
|
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||||
|
s.to_string()
|
||||||
|
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||||
|
s.to_string()
|
||||||
} else {
|
} else {
|
||||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
bail!("Failed to parse common name from certificate")
|
||||||
}
|
};
|
||||||
.context("Failed to parse common name from certificate")?;
|
|
||||||
|
|
||||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ pub mod neon;
|
|||||||
use super::messages::{ConsoleError, MetricsAuxInfo};
|
use super::messages::{ConsoleError, MetricsAuxInfo};
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{
|
auth::{
|
||||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
backend::{
|
||||||
|
jwt::{AuthRule, FetchAuthRules},
|
||||||
|
ComputeCredentialKeys, ComputeUserInfo,
|
||||||
|
},
|
||||||
IpPattern,
|
IpPattern,
|
||||||
},
|
},
|
||||||
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||||
@@ -16,7 +19,7 @@ use crate::{
|
|||||||
intern::ProjectIdInt,
|
intern::ProjectIdInt,
|
||||||
metrics::ApiLockMetrics,
|
metrics::ApiLockMetrics,
|
||||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||||
scram, EndpointCacheKey,
|
scram, EndpointCacheKey, EndpointId,
|
||||||
};
|
};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||||
@@ -334,6 +337,12 @@ pub(crate) trait Api {
|
|||||||
user_info: &ComputeUserInfo,
|
user_info: &ComputeUserInfo,
|
||||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||||
|
|
||||||
|
async fn get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>>;
|
||||||
|
|
||||||
/// Wake up the compute node and return the corresponding connection info.
|
/// Wake up the compute node and return the corresponding connection info.
|
||||||
async fn wake_compute(
|
async fn wake_compute(
|
||||||
&self,
|
&self,
|
||||||
@@ -343,6 +352,7 @@ pub(crate) trait Api {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
|
#[derive(Clone)]
|
||||||
pub enum ConsoleBackend {
|
pub enum ConsoleBackend {
|
||||||
/// Current Cloud API (V2).
|
/// Current Cloud API (V2).
|
||||||
Console(neon::Api),
|
Console(neon::Api),
|
||||||
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
match self {
|
||||||
|
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||||
|
#[cfg(any(test, feature = "testing"))]
|
||||||
|
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||||
|
#[cfg(test)]
|
||||||
|
Self::Test(_api) => Ok(vec![]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn wake_compute(
|
async fn wake_compute(
|
||||||
&self,
|
&self,
|
||||||
ctx: &RequestMonitoring,
|
ctx: &RequestMonitoring,
|
||||||
@@ -552,3 +576,13 @@ impl WakeComputePermit {
|
|||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FetchAuthRules for ConsoleBackend {
|
||||||
|
async fn fetch_auth_rules(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
self.get_endpoint_jwks(ctx, endpoint).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ use super::{
|
|||||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||||
};
|
};
|
||||||
use crate::context::RequestMonitoring;
|
use crate::{
|
||||||
|
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
|
||||||
|
};
|
||||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||||
use crate::{auth::IpPattern, cache::Cached};
|
use crate::{auth::IpPattern, cache::Cached};
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -118,6 +120,39 @@ impl Api {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
let (client, connection) =
|
||||||
|
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||||
|
|
||||||
|
let connection = tokio::spawn(connection);
|
||||||
|
|
||||||
|
let res = client.query(
|
||||||
|
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
|
||||||
|
&[&endpoint.as_str()],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut rows = vec![];
|
||||||
|
for row in res {
|
||||||
|
rows.push(AuthRule {
|
||||||
|
id: row.get("id"),
|
||||||
|
jwks_url: url::Url::parse(row.get("jwks_url"))?,
|
||||||
|
audience: row.get("audience"),
|
||||||
|
role_names: row
|
||||||
|
.get::<_, Vec<String>>("role_names")
|
||||||
|
.into_iter()
|
||||||
|
.map(RoleName::from)
|
||||||
|
.map(|s| RoleNameInt::from(&s))
|
||||||
|
.collect(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
connection.await??;
|
||||||
|
|
||||||
|
Ok(rows)
|
||||||
|
}
|
||||||
|
|
||||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||||
let mut config = compute::ConnCfg::new();
|
let mut config = compute::ConnCfg::new();
|
||||||
config
|
config
|
||||||
@@ -185,6 +220,14 @@ impl super::Api for Api {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
_ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
self.do_get_endpoint_jwks(endpoint).await
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
async fn wake_compute(
|
async fn wake_compute(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -7,27 +7,33 @@ use super::{
|
|||||||
NodeInfo,
|
NodeInfo,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::backend::ComputeUserInfo,
|
auth::backend::{jwt::AuthRule, ComputeUserInfo},
|
||||||
compute,
|
compute,
|
||||||
console::messages::{ColdStartInfo, Reason},
|
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
|
||||||
http,
|
http,
|
||||||
metrics::{CacheOutcome, Metrics},
|
metrics::{CacheOutcome, Metrics},
|
||||||
rate_limiter::WakeComputeRateLimiter,
|
rate_limiter::WakeComputeRateLimiter,
|
||||||
scram, EndpointCacheKey,
|
scram, EndpointCacheKey, EndpointId,
|
||||||
};
|
};
|
||||||
use crate::{cache::Cached, context::RequestMonitoring};
|
use crate::{cache::Cached, context::RequestMonitoring};
|
||||||
|
use ::http::{header::AUTHORIZATION, HeaderName};
|
||||||
|
use anyhow::bail;
|
||||||
use futures::TryFutureExt;
|
use futures::TryFutureExt;
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_postgres::config::SslMode;
|
use tokio_postgres::config::SslMode;
|
||||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||||
|
|
||||||
|
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Api {
|
pub struct Api {
|
||||||
endpoint: http::Endpoint,
|
endpoint: http::Endpoint,
|
||||||
pub caches: &'static ApiCaches,
|
pub caches: &'static ApiCaches,
|
||||||
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
||||||
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||||
jwt: String,
|
// put in a shared ref so we don't copy secrets all over in memory
|
||||||
|
jwt: Arc<str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Api {
|
impl Api {
|
||||||
@@ -38,7 +44,9 @@ impl Api {
|
|||||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
|
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into();
|
||||||
Self {
|
Self {
|
||||||
endpoint,
|
endpoint,
|
||||||
caches,
|
caches,
|
||||||
@@ -71,9 +79,9 @@ impl Api {
|
|||||||
async {
|
async {
|
||||||
let request = self
|
let request = self
|
||||||
.endpoint
|
.endpoint
|
||||||
.get("proxy_get_role_secret")
|
.get_path("proxy_get_role_secret")
|
||||||
.header("X-Request-ID", &request_id)
|
.header(X_REQUEST_ID, &request_id)
|
||||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||||
.query(&[("session_id", ctx.session_id())])
|
.query(&[("session_id", ctx.session_id())])
|
||||||
.query(&[
|
.query(&[
|
||||||
("application_name", application_name.as_str()),
|
("application_name", application_name.as_str()),
|
||||||
@@ -125,6 +133,61 @@ impl Api {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn do_get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
if !self
|
||||||
|
.caches
|
||||||
|
.endpoints_cache
|
||||||
|
.is_valid(ctx, &endpoint.normalize())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
bail!("endpoint not found");
|
||||||
|
}
|
||||||
|
let request_id = ctx.session_id().to_string();
|
||||||
|
async {
|
||||||
|
let request = self
|
||||||
|
.endpoint
|
||||||
|
.get_with_url(|url| {
|
||||||
|
url.path_segments_mut()
|
||||||
|
.push("endpoints")
|
||||||
|
.push(endpoint.as_str())
|
||||||
|
.push("jwks");
|
||||||
|
})
|
||||||
|
.header(X_REQUEST_ID, &request_id)
|
||||||
|
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||||
|
.query(&[("session_id", ctx.session_id())])
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
info!(url = request.url().as_str(), "sending http request");
|
||||||
|
let start = Instant::now();
|
||||||
|
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||||
|
let response = self.endpoint.execute(request).await?;
|
||||||
|
drop(pause);
|
||||||
|
info!(duration = ?start.elapsed(), "received http response");
|
||||||
|
|
||||||
|
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||||
|
|
||||||
|
let rules = body
|
||||||
|
.jwks
|
||||||
|
.into_iter()
|
||||||
|
.map(|jwks| AuthRule {
|
||||||
|
id: jwks.id,
|
||||||
|
jwks_url: jwks.jwks_url,
|
||||||
|
audience: jwks.jwt_audience,
|
||||||
|
role_names: jwks.role_names,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(rules)
|
||||||
|
}
|
||||||
|
.map_err(crate::error::log_error)
|
||||||
|
.instrument(info_span!("http", id = request_id))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
async fn do_wake_compute(
|
async fn do_wake_compute(
|
||||||
&self,
|
&self,
|
||||||
ctx: &RequestMonitoring,
|
ctx: &RequestMonitoring,
|
||||||
@@ -135,7 +198,7 @@ impl Api {
|
|||||||
async {
|
async {
|
||||||
let mut request_builder = self
|
let mut request_builder = self
|
||||||
.endpoint
|
.endpoint
|
||||||
.get("proxy_wake_compute")
|
.get_path("proxy_wake_compute")
|
||||||
.header("X-Request-ID", &request_id)
|
.header("X-Request-ID", &request_id)
|
||||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||||
.query(&[("session_id", ctx.session_id())])
|
.query(&[("session_id", ctx.session_id())])
|
||||||
@@ -262,6 +325,15 @@ impl super::Api for Api {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip_all)]
|
||||||
|
async fn get_endpoint_jwks(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
endpoint: EndpointId,
|
||||||
|
) -> anyhow::Result<Vec<AuthRule>> {
|
||||||
|
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
async fn wake_compute(
|
async fn wake_compute(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -86,9 +86,17 @@ impl Endpoint {
|
|||||||
|
|
||||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||||
/// appending a single `path` segment to the base endpoint URL.
|
/// appending a single `path` segment to the base endpoint URL.
|
||||||
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
|
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
|
||||||
|
self.get_with_url(|u| {
|
||||||
|
u.path_segments_mut().push(path);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||||
|
/// accepting a closure to modify the url path segments for more complex paths queries.
|
||||||
|
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
|
||||||
let mut url = self.endpoint.clone();
|
let mut url = self.endpoint.clone();
|
||||||
url.path_segments_mut().push(path);
|
f(&mut url);
|
||||||
self.client.get(url.into_inner())
|
self.client.get(url.into_inner())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +152,7 @@ mod tests {
|
|||||||
|
|
||||||
// Validate that this pattern makes sense.
|
// Validate that this pattern makes sense.
|
||||||
let req = endpoint
|
let req = endpoint
|
||||||
.get("frobnicate")
|
.get_path("frobnicate")
|
||||||
.query(&[
|
.query(&[
|
||||||
("foo", Some("10")), // should be just `foo=10`
|
("foo", Some("10")), // should be just `foo=10`
|
||||||
("bar", None), // shouldn't be passed at all
|
("bar", None), // shouldn't be passed at all
|
||||||
@@ -162,7 +170,7 @@ mod tests {
|
|||||||
let endpoint = Endpoint::new(url, Client::new());
|
let endpoint = Endpoint::new(url, Client::new());
|
||||||
|
|
||||||
let req = endpoint
|
let req = endpoint
|
||||||
.get("frobnicate")
|
.get_path("frobnicate")
|
||||||
.query(&[("session_id", uuid::Uuid::nil())])
|
.query(&[("session_id", uuid::Uuid::nil())])
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use std::{
|
use std::{
|
||||||
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
|
any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
|
||||||
|
sync::OnceLock,
|
||||||
};
|
};
|
||||||
|
|
||||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||||
@@ -16,12 +17,21 @@ pub struct StringInterner<Id> {
|
|||||||
_id: PhantomData<Id>,
|
_id: PhantomData<Id>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
|
#[derive(PartialEq, Clone, Copy, Eq, Hash)]
|
||||||
pub struct InternedString<Id> {
|
pub struct InternedString<Id> {
|
||||||
inner: Spur,
|
inner: Spur,
|
||||||
_id: PhantomData<Id>,
|
_id: PhantomData<Id>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_tuple("InternedString")
|
||||||
|
.field(&type_name::<Id>())
|
||||||
|
.field(&self.as_str())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
|
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
self.as_str().fmt(f)
|
self.as_str().fmt(f)
|
||||||
|
|||||||
@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
|
|||||||
{
|
{
|
||||||
unimplemented!("not used in tests")
|
unimplemented!("not used in tests")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn dyn_clone(&self) -> Box<dyn TestBackend> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
mod backend;
|
mod backend;
|
||||||
pub mod cancel_set;
|
pub mod cancel_set;
|
||||||
mod conn_pool;
|
mod conn_pool;
|
||||||
|
mod http_conn_pool;
|
||||||
mod http_util;
|
mod http_util;
|
||||||
mod json;
|
mod json;
|
||||||
mod sql_over_http;
|
mod sql_over_http;
|
||||||
@@ -19,7 +20,8 @@ use anyhow::Context;
|
|||||||
use futures::future::{select, Either};
|
use futures::future::{select, Either};
|
||||||
use futures::TryFutureExt;
|
use futures::TryFutureExt;
|
||||||
use http::{Method, Response, StatusCode};
|
use http::{Method, Response, StatusCode};
|
||||||
use http_body_util::Full;
|
use http_body_util::combinators::BoxBody;
|
||||||
|
use http_body_util::{BodyExt, Empty};
|
||||||
use hyper1::body::Incoming;
|
use hyper1::body::Incoming;
|
||||||
use hyper_util::rt::TokioExecutor;
|
use hyper_util::rt::TokioExecutor;
|
||||||
use hyper_util::server::conn::auto::Builder;
|
use hyper_util::server::conn::auto::Builder;
|
||||||
@@ -81,7 +83,28 @@ pub async fn task_main(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
|
||||||
|
{
|
||||||
|
let http_conn_pool = Arc::clone(&http_conn_pool);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown the connection pool
|
||||||
|
tokio::spawn({
|
||||||
|
let cancellation_token = cancellation_token.clone();
|
||||||
|
let http_conn_pool = http_conn_pool.clone();
|
||||||
|
async move {
|
||||||
|
cancellation_token.cancelled().await;
|
||||||
|
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let backend = Arc::new(PoolingBackend {
|
let backend = Arc::new(PoolingBackend {
|
||||||
|
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||||
pool: Arc::clone(&conn_pool),
|
pool: Arc::clone(&conn_pool),
|
||||||
config,
|
config,
|
||||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||||
@@ -342,7 +365,7 @@ async fn request_handler(
|
|||||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||||
http_cancellation_token: CancellationToken,
|
http_cancellation_token: CancellationToken,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||||
let host = request
|
let host = request
|
||||||
.headers()
|
.headers()
|
||||||
.get("host")
|
.get("host")
|
||||||
@@ -386,7 +409,7 @@ async fn request_handler(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Return the response so the spawned future can continue.
|
// Return the response so the spawned future can continue.
|
||||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||||
let ctx = RequestMonitoring::new(
|
let ctx = RequestMonitoring::new(
|
||||||
session_id,
|
session_id,
|
||||||
@@ -409,7 +432,7 @@ async fn request_handler(
|
|||||||
)
|
)
|
||||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||||
.body(Full::new(Bytes::new()))
|
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||||
} else {
|
} else {
|
||||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
use std::{sync::Arc, time::Duration};
|
use std::{io, sync::Arc, time::Duration};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||||
|
use tokio::net::{lookup_host, TcpStream};
|
||||||
use tracing::{field::display, info};
|
use tracing::{field::display, info};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -27,9 +29,13 @@ use crate::{
|
|||||||
Host,
|
Host,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
use super::{
|
||||||
|
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||||
|
http_conn_pool::{self, poll_http2_client},
|
||||||
|
};
|
||||||
|
|
||||||
pub(crate) struct PoolingBackend {
|
pub(crate) struct PoolingBackend {
|
||||||
|
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||||
pub(crate) config: &'static ProxyConfig,
|
pub(crate) config: &'static ProxyConfig,
|
||||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||||
@@ -103,32 +109,44 @@ impl PoolingBackend {
|
|||||||
pub(crate) async fn authenticate_with_jwt(
|
pub(crate) async fn authenticate_with_jwt(
|
||||||
&self,
|
&self,
|
||||||
ctx: &RequestMonitoring,
|
ctx: &RequestMonitoring,
|
||||||
|
config: &AuthenticationConfig,
|
||||||
user_info: &ComputeUserInfo,
|
user_info: &ComputeUserInfo,
|
||||||
jwt: &str,
|
jwt: String,
|
||||||
) -> Result<ComputeCredentials, AuthError> {
|
) -> Result<(), AuthError> {
|
||||||
match &self.config.auth_backend {
|
match &self.config.auth_backend {
|
||||||
crate::auth::Backend::Console(_, ()) => {
|
crate::auth::Backend::Console(console, ()) => {
|
||||||
Err(AuthError::auth_failed("JWT login is not yet supported"))
|
config
|
||||||
|
.jwks_cache
|
||||||
|
.check_jwt(
|
||||||
|
ctx,
|
||||||
|
user_info.endpoint.clone(),
|
||||||
|
&user_info.user,
|
||||||
|
&**console,
|
||||||
|
&jwt,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
|
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
|
||||||
"JWT login over web auth proxy is not supported",
|
"JWT login over web auth proxy is not supported",
|
||||||
)),
|
)),
|
||||||
crate::auth::Backend::Local(cache) => {
|
crate::auth::Backend::Local(_) => {
|
||||||
cache
|
config
|
||||||
.jwks_cache
|
.jwks_cache
|
||||||
.check_jwt(
|
.check_jwt(
|
||||||
ctx,
|
ctx,
|
||||||
user_info.endpoint.clone(),
|
user_info.endpoint.clone(),
|
||||||
&user_info.user,
|
&user_info.user,
|
||||||
&StaticAuthRules,
|
&StaticAuthRules,
|
||||||
jwt,
|
&jwt,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||||
Ok(ComputeCredentials {
|
|
||||||
info: user_info.clone(),
|
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
Ok(())
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,14 +192,55 @@ impl PoolingBackend {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wake up the destination if needed
|
||||||
|
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||||
|
pub(crate) async fn connect_to_local_proxy(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
conn_info: ConnInfo,
|
||||||
|
) -> Result<http_conn_pool::Client, HttpConnError> {
|
||||||
|
info!("pool: looking for an existing connection");
|
||||||
|
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||||
|
return Ok(client);
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn_id = uuid::Uuid::new_v4();
|
||||||
|
tracing::Span::current().record("conn_id", display(conn_id));
|
||||||
|
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||||
|
let backend = self
|
||||||
|
.config
|
||||||
|
.auth_backend
|
||||||
|
.as_ref()
|
||||||
|
.map(|()| ComputeCredentials {
|
||||||
|
info: conn_info.user_info.clone(),
|
||||||
|
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||||
|
});
|
||||||
|
crate::proxy::connect_compute::connect_to_compute(
|
||||||
|
ctx,
|
||||||
|
&HyperMechanism {
|
||||||
|
conn_id,
|
||||||
|
conn_info,
|
||||||
|
pool: self.http_conn_pool.clone(),
|
||||||
|
locks: &self.config.connect_compute_locks,
|
||||||
|
},
|
||||||
|
&backend,
|
||||||
|
false, // do not allow self signed compute for http flow
|
||||||
|
self.config.wake_compute_retry_config,
|
||||||
|
self.config.connect_to_compute_retry_config,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub(crate) enum HttpConnError {
|
pub(crate) enum HttpConnError {
|
||||||
#[error("pooled connection closed at inconsistent state")]
|
#[error("pooled connection closed at inconsistent state")]
|
||||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||||
#[error("could not connection to compute")]
|
#[error("could not connection to postgres in compute")]
|
||||||
ConnectionError(#[from] tokio_postgres::Error),
|
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||||
|
#[error("could not connection to local-proxy in compute")]
|
||||||
|
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||||
|
|
||||||
#[error("could not get auth info")]
|
#[error("could not get auth info")]
|
||||||
GetAuthInfo(#[from] GetAuthInfoError),
|
GetAuthInfo(#[from] GetAuthInfoError),
|
||||||
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
|
|||||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub(crate) enum LocalProxyConnError {
|
||||||
|
#[error("error with connection to local-proxy")]
|
||||||
|
Io(#[source] std::io::Error),
|
||||||
|
#[error("could not establish h2 connection")]
|
||||||
|
H2(#[from] hyper1::Error),
|
||||||
|
}
|
||||||
|
|
||||||
impl ReportableError for HttpConnError {
|
impl ReportableError for HttpConnError {
|
||||||
fn get_error_kind(&self) -> ErrorKind {
|
fn get_error_kind(&self) -> ErrorKind {
|
||||||
match self {
|
match self {
|
||||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||||
|
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||||
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
|
|||||||
fn to_string_client(&self) -> String {
|
fn to_string_client(&self) -> String {
|
||||||
match self {
|
match self {
|
||||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||||
|
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||||
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
|
|||||||
impl CouldRetry for HttpConnError {
|
impl CouldRetry for HttpConnError {
|
||||||
fn could_retry(&self) -> bool {
|
fn could_retry(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
HttpConnError::ConnectionError(e) => e.could_retry(),
|
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||||
|
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||||
HttpConnError::GetAuthInfo(_) => false,
|
HttpConnError::GetAuthInfo(_) => false,
|
||||||
HttpConnError::AuthError(_) => false,
|
HttpConnError::AuthError(_) => false,
|
||||||
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
|
|||||||
impl ShouldRetryWakeCompute for HttpConnError {
|
impl ShouldRetryWakeCompute for HttpConnError {
|
||||||
fn should_retry_wake_compute(&self) -> bool {
|
fn should_retry_wake_compute(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
|
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
|
||||||
// we never checked cache validity
|
// we never checked cache validity
|
||||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||||
_ => true,
|
_ => true,
|
||||||
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ReportableError for LocalProxyConnError {
|
||||||
|
fn get_error_kind(&self) -> ErrorKind {
|
||||||
|
match self {
|
||||||
|
LocalProxyConnError::Io(_) => ErrorKind::Compute,
|
||||||
|
LocalProxyConnError::H2(_) => ErrorKind::Compute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserFacingError for LocalProxyConnError {
|
||||||
|
fn to_string_client(&self) -> String {
|
||||||
|
"Could not establish HTTP connection to the database".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CouldRetry for LocalProxyConnError {
|
||||||
|
fn could_retry(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
LocalProxyConnError::Io(_) => false,
|
||||||
|
LocalProxyConnError::H2(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||||
|
fn should_retry_wake_compute(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
LocalProxyConnError::Io(_) => false,
|
||||||
|
LocalProxyConnError::H2(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct TokioMechanism {
|
struct TokioMechanism {
|
||||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||||
conn_info: ConnInfo,
|
conn_info: ConnInfo,
|
||||||
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
|
|||||||
|
|
||||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct HyperMechanism {
|
||||||
|
pool: Arc<http_conn_pool::GlobalConnPool>,
|
||||||
|
conn_info: ConnInfo,
|
||||||
|
conn_id: uuid::Uuid,
|
||||||
|
|
||||||
|
/// connect_to_compute concurrency lock
|
||||||
|
locks: &'static ApiLocks<Host>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ConnectMechanism for HyperMechanism {
|
||||||
|
type Connection = http_conn_pool::Client;
|
||||||
|
type ConnectError = HttpConnError;
|
||||||
|
type Error = HttpConnError;
|
||||||
|
|
||||||
|
async fn connect_once(
|
||||||
|
&self,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
node_info: &CachedNodeInfo,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> Result<Self::Connection, Self::ConnectError> {
|
||||||
|
let host = node_info.config.get_host()?;
|
||||||
|
let permit = self.locks.get_permit(&host).await?;
|
||||||
|
|
||||||
|
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||||
|
|
||||||
|
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
|
||||||
|
let res = connect_http2(&host, 10432, timeout).await;
|
||||||
|
drop(pause);
|
||||||
|
let (client, connection) = permit.release_result(res)?;
|
||||||
|
|
||||||
|
Ok(poll_http2_client(
|
||||||
|
self.pool.clone(),
|
||||||
|
ctx,
|
||||||
|
&self.conn_info,
|
||||||
|
client,
|
||||||
|
connection,
|
||||||
|
self.conn_id,
|
||||||
|
node_info.aux.clone(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_http2(
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||||
|
// assumption: host is an ip address so this should not actually perform any requests.
|
||||||
|
// todo: add that assumption as a guarantee in the control-plane API.
|
||||||
|
let mut addrs = lookup_host((host, port))
|
||||||
|
.await
|
||||||
|
.map_err(LocalProxyConnError::Io)?;
|
||||||
|
|
||||||
|
let mut last_err = None;
|
||||||
|
|
||||||
|
let stream = loop {
|
||||||
|
let Some(addr) = addrs.next() else {
|
||||||
|
return Err(last_err.unwrap_or_else(|| {
|
||||||
|
LocalProxyConnError::Io(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"could not resolve any addresses",
|
||||||
|
))
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
|
||||||
|
Ok(Ok(stream)) => {
|
||||||
|
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
|
||||||
|
break stream;
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
last_err = Some(LocalProxyConnError::Io(e));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
last_err = Some(LocalProxyConnError::Io(io::Error::new(
|
||||||
|
io::ErrorKind::TimedOut,
|
||||||
|
e,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||||
|
.timer(TokioTimer::new())
|
||||||
|
.keep_alive_interval(Duration::from_secs(20))
|
||||||
|
.keep_alive_while_idle(true)
|
||||||
|
.keep_alive_timeout(Duration::from_secs(5))
|
||||||
|
.handshake(TokioIo::new(stream))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok((client, connection))
|
||||||
|
}
|
||||||
|
|||||||
335
proxy/src/serverless/http_conn_pool.rs
Normal file
335
proxy/src/serverless/http_conn_pool.rs
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
use dashmap::DashMap;
|
||||||
|
use hyper1::client::conn::http2;
|
||||||
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use rand::Rng;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::sync::atomic::{self, AtomicUsize};
|
||||||
|
use std::{sync::Arc, sync::Weak};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
|
||||||
|
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||||
|
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||||
|
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||||
|
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||||
|
|
||||||
|
use tracing::{debug, error};
|
||||||
|
use tracing::{info, info_span, Instrument};
|
||||||
|
|
||||||
|
use super::conn_pool::ConnInfo;
|
||||||
|
|
||||||
|
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
|
||||||
|
pub(crate) type Connect =
|
||||||
|
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ConnPoolEntry {
|
||||||
|
conn: Send,
|
||||||
|
conn_id: uuid::Uuid,
|
||||||
|
aux: MetricsAuxInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-endpoint connection pool
|
||||||
|
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||||
|
pub(crate) struct EndpointConnPool {
|
||||||
|
conns: VecDeque<ConnPoolEntry>,
|
||||||
|
_guard: HttpEndpointPoolsGuard<'static>,
|
||||||
|
global_connections_count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EndpointConnPool {
|
||||||
|
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
|
||||||
|
let Self { conns, .. } = self;
|
||||||
|
|
||||||
|
let conn = conns.pop_front()?;
|
||||||
|
conns.push_back(conn.clone());
|
||||||
|
Some(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
|
||||||
|
let Self {
|
||||||
|
conns,
|
||||||
|
global_connections_count,
|
||||||
|
..
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
let old_len = conns.len();
|
||||||
|
conns.retain(|conn| conn.conn_id != conn_id);
|
||||||
|
let new_len = conns.len();
|
||||||
|
let removed = old_len - new_len;
|
||||||
|
if removed > 0 {
|
||||||
|
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||||
|
Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.http_pool_opened_connections
|
||||||
|
.get_metric()
|
||||||
|
.dec_by(removed as i64);
|
||||||
|
}
|
||||||
|
removed > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for EndpointConnPool {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.conns.is_empty() {
|
||||||
|
self.global_connections_count
|
||||||
|
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
|
||||||
|
Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.http_pool_opened_connections
|
||||||
|
.get_metric()
|
||||||
|
.dec_by(self.conns.len() as i64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct GlobalConnPool {
|
||||||
|
// endpoint -> per-endpoint connection pool
|
||||||
|
//
|
||||||
|
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||||
|
// pool as early as possible and release the lock.
|
||||||
|
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
|
||||||
|
|
||||||
|
/// Number of endpoint-connection pools
|
||||||
|
///
|
||||||
|
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||||
|
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||||
|
/// It's only used for diagnostics.
|
||||||
|
global_pool_size: AtomicUsize,
|
||||||
|
|
||||||
|
/// Total number of connections in the pool
|
||||||
|
global_connections_count: Arc<AtomicUsize>,
|
||||||
|
|
||||||
|
config: &'static crate::config::HttpConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GlobalConnPool {
|
||||||
|
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||||
|
let shards = config.pool_options.pool_shards;
|
||||||
|
Arc::new(Self {
|
||||||
|
global_pool: DashMap::with_shard_amount(shards),
|
||||||
|
global_pool_size: AtomicUsize::new(0),
|
||||||
|
config,
|
||||||
|
global_connections_count: Arc::new(AtomicUsize::new(0)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn shutdown(&self) {
|
||||||
|
// drops all strong references to endpoint-pools
|
||||||
|
self.global_pool.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
|
||||||
|
let epoch = self.config.pool_options.gc_epoch;
|
||||||
|
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
let shard = rng.gen_range(0..self.global_pool.shards().len());
|
||||||
|
self.gc(shard);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gc(&self, shard: usize) {
|
||||||
|
debug!(shard, "pool: performing epoch reclamation");
|
||||||
|
|
||||||
|
// acquire a random shard lock
|
||||||
|
let mut shard = self.global_pool.shards()[shard].write();
|
||||||
|
|
||||||
|
let timer = Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.http_pool_reclaimation_lag_seconds
|
||||||
|
.start_timer();
|
||||||
|
let current_len = shard.len();
|
||||||
|
let mut clients_removed = 0;
|
||||||
|
shard.retain(|endpoint, x| {
|
||||||
|
// if the current endpoint pool is unique (no other strong or weak references)
|
||||||
|
// then it is currently not in use by any connections.
|
||||||
|
if let Some(pool) = Arc::get_mut(x.get_mut()) {
|
||||||
|
let EndpointConnPool { conns, .. } = pool.get_mut();
|
||||||
|
|
||||||
|
let old_len = conns.len();
|
||||||
|
|
||||||
|
conns.retain(|conn| !conn.conn.is_closed());
|
||||||
|
|
||||||
|
let new_len = conns.len();
|
||||||
|
let removed = old_len - new_len;
|
||||||
|
clients_removed += removed;
|
||||||
|
|
||||||
|
// we only remove this pool if it has no active connections
|
||||||
|
if conns.is_empty() {
|
||||||
|
info!("pool: discarding pool for endpoint {endpoint}");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
});
|
||||||
|
|
||||||
|
let new_len = shard.len();
|
||||||
|
drop(shard);
|
||||||
|
timer.observe();
|
||||||
|
|
||||||
|
// Do logging outside of the lock.
|
||||||
|
if clients_removed > 0 {
|
||||||
|
let size = self
|
||||||
|
.global_connections_count
|
||||||
|
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
|
||||||
|
- clients_removed;
|
||||||
|
Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.http_pool_opened_connections
|
||||||
|
.get_metric()
|
||||||
|
.dec_by(clients_removed as i64);
|
||||||
|
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
|
||||||
|
}
|
||||||
|
let removed = current_len - new_len;
|
||||||
|
|
||||||
|
if removed > 0 {
|
||||||
|
let global_pool_size = self
|
||||||
|
.global_pool_size
|
||||||
|
.fetch_sub(removed, atomic::Ordering::Relaxed)
|
||||||
|
- removed;
|
||||||
|
info!("pool: performed global pool gc. size now {global_pool_size}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
conn_info: &ConnInfo,
|
||||||
|
) -> Option<Client> {
|
||||||
|
let endpoint = conn_info.endpoint_cache_key()?;
|
||||||
|
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||||
|
let client = endpoint_pool.write().get_conn_entry()?;
|
||||||
|
|
||||||
|
if client.conn.is_closed() {
|
||||||
|
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||||
|
info!(
|
||||||
|
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||||
|
"pool: reusing connection '{conn_info}'"
|
||||||
|
);
|
||||||
|
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||||
|
ctx.success();
|
||||||
|
Some(Client::new(client.conn, client.aux))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_create_endpoint_pool(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
endpoint: &EndpointCacheKey,
|
||||||
|
) -> Arc<RwLock<EndpointConnPool>> {
|
||||||
|
// fast path
|
||||||
|
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||||
|
return pool.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// slow path
|
||||||
|
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||||
|
conns: VecDeque::new(),
|
||||||
|
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||||
|
global_connections_count: self.global_connections_count.clone(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// find or create a pool for this endpoint
|
||||||
|
let mut created = false;
|
||||||
|
let pool = self
|
||||||
|
.global_pool
|
||||||
|
.entry(endpoint.clone())
|
||||||
|
.or_insert_with(|| {
|
||||||
|
created = true;
|
||||||
|
new_pool
|
||||||
|
})
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
// log new global pool size
|
||||||
|
if created {
|
||||||
|
let global_pool_size = self
|
||||||
|
.global_pool_size
|
||||||
|
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||||
|
+ 1;
|
||||||
|
info!(
|
||||||
|
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn poll_http2_client(
|
||||||
|
global_pool: Arc<GlobalConnPool>,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
conn_info: &ConnInfo,
|
||||||
|
client: Send,
|
||||||
|
connection: Connect,
|
||||||
|
conn_id: uuid::Uuid,
|
||||||
|
aux: MetricsAuxInfo,
|
||||||
|
) -> Client {
|
||||||
|
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||||
|
let session_id = ctx.session_id();
|
||||||
|
|
||||||
|
let span = info_span!(parent: None, "connection", %conn_id);
|
||||||
|
let cold_start_info = ctx.cold_start_info();
|
||||||
|
span.in_scope(|| {
|
||||||
|
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||||
|
});
|
||||||
|
|
||||||
|
let pool = match conn_info.endpoint_cache_key() {
|
||||||
|
Some(endpoint) => {
|
||||||
|
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
|
||||||
|
|
||||||
|
pool.write().conns.push_back(ConnPoolEntry {
|
||||||
|
conn: client.clone(),
|
||||||
|
conn_id,
|
||||||
|
aux: aux.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
Arc::downgrade(&pool)
|
||||||
|
}
|
||||||
|
None => Weak::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// let idle = global_pool.get_idle_timeout();
|
||||||
|
|
||||||
|
tokio::spawn(
|
||||||
|
async move {
|
||||||
|
let _conn_gauge = conn_gauge;
|
||||||
|
let res = connection.await;
|
||||||
|
match res {
|
||||||
|
Ok(()) => info!("connection closed"),
|
||||||
|
Err(e) => error!(%session_id, "connection error: {}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove from connection pool
|
||||||
|
if let Some(pool) = pool.clone().upgrade() {
|
||||||
|
if pool.write().remove_conn(conn_id) {
|
||||||
|
info!("closed connection removed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(span),
|
||||||
|
);
|
||||||
|
|
||||||
|
Client::new(client, aux)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct Client {
|
||||||
|
pub(crate) inner: Send,
|
||||||
|
aux: MetricsAuxInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
|
||||||
|
Self { inner, aux }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||||
|
USAGE_METRICS.register(Ids {
|
||||||
|
endpoint_id: self.aux.endpoint_id,
|
||||||
|
branch_id: self.aux.branch_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,13 +5,13 @@ use bytes::Bytes;
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use http::{Response, StatusCode};
|
use http::{Response, StatusCode};
|
||||||
use http_body_util::Full;
|
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use utils::http::error::ApiError;
|
use utils::http::error::ApiError;
|
||||||
|
|
||||||
/// Like [`ApiError::into_response`]
|
/// Like [`ApiError::into_response`]
|
||||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||||
match this {
|
match this {
|
||||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||||
@@ -64,17 +64,24 @@ struct HttpErrorBody {
|
|||||||
|
|
||||||
impl HttpErrorBody {
|
impl HttpErrorBody {
|
||||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
fn response_from_msg_and_status(
|
||||||
|
msg: String,
|
||||||
|
status: StatusCode,
|
||||||
|
) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||||
HttpErrorBody { msg }.to_response(status)
|
HttpErrorBody { msg }.to_response(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||||
Response::builder()
|
Response::builder()
|
||||||
.status(status)
|
.status(status)
|
||||||
.header(http::header::CONTENT_TYPE, "application/json")
|
.header(http::header::CONTENT_TYPE, "application/json")
|
||||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
.body(
|
||||||
|
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||||
|
.map_err(|x| match x {})
|
||||||
|
.boxed(),
|
||||||
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -83,14 +90,14 @@ impl HttpErrorBody {
|
|||||||
pub(crate) fn json_response<T: Serialize>(
|
pub(crate) fn json_response<T: Serialize>(
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
data: T,
|
data: T,
|
||||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||||
let json = serde_json::to_string(&data)
|
let json = serde_json::to_string(&data)
|
||||||
.context("Failed to serialize JSON response")
|
.context("Failed to serialize JSON response")
|
||||||
.map_err(ApiError::InternalServerError)?;
|
.map_err(ApiError::InternalServerError)?;
|
||||||
let response = Response::builder()
|
let response = Response::builder()
|
||||||
.status(status)
|
.status(status)
|
||||||
.header(http::header::CONTENT_TYPE, "application/json")
|
.header(http::header::CONTENT_TYPE, "application/json")
|
||||||
.body(Full::new(Bytes::from(json)))
|
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
|
||||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ use futures::future::Either;
|
|||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use futures::TryFutureExt;
|
use futures::TryFutureExt;
|
||||||
use http::header::AUTHORIZATION;
|
use http::header::AUTHORIZATION;
|
||||||
|
use http::Method;
|
||||||
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use http_body_util::Full;
|
use http_body_util::Full;
|
||||||
use hyper1::body::Body;
|
use hyper1::body::Body;
|
||||||
@@ -38,9 +40,11 @@ use url::Url;
|
|||||||
use urlencoding;
|
use urlencoding;
|
||||||
use utils::http::error::ApiError;
|
use utils::http::error::ApiError;
|
||||||
|
|
||||||
|
use crate::auth::backend::ComputeCredentials;
|
||||||
use crate::auth::backend::ComputeUserInfo;
|
use crate::auth::backend::ComputeUserInfo;
|
||||||
use crate::auth::endpoint_sni;
|
use crate::auth::endpoint_sni;
|
||||||
use crate::auth::ComputeUserInfoParseError;
|
use crate::auth::ComputeUserInfoParseError;
|
||||||
|
use crate::config::AuthenticationConfig;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::config::TlsConfig;
|
use crate::config::TlsConfig;
|
||||||
use crate::context::RequestMonitoring;
|
use crate::context::RequestMonitoring;
|
||||||
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
|
|||||||
use crate::DbName;
|
use crate::DbName;
|
||||||
use crate::RoleName;
|
use crate::RoleName;
|
||||||
|
|
||||||
|
use super::backend::LocalProxyConnError;
|
||||||
use super::backend::PoolingBackend;
|
use super::backend::PoolingBackend;
|
||||||
use super::conn_pool::AuthData;
|
use super::conn_pool::AuthData;
|
||||||
use super::conn_pool::Client;
|
use super::conn_pool::Client;
|
||||||
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
|
|||||||
MissingUsername,
|
MissingUsername,
|
||||||
#[error("invalid username: {0}")]
|
#[error("invalid username: {0}")]
|
||||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||||
#[error("missing password")]
|
#[error("missing authentication credentials: {0}")]
|
||||||
MissingPassword,
|
MissingCredentials(Credentials),
|
||||||
#[error("missing hostname")]
|
#[error("missing hostname")]
|
||||||
MissingHostname,
|
MissingHostname,
|
||||||
#[error("invalid hostname: {0}")]
|
#[error("invalid hostname: {0}")]
|
||||||
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
|
|||||||
MalformedEndpoint,
|
MalformedEndpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub(crate) enum Credentials {
|
||||||
|
#[error("required password")]
|
||||||
|
Password,
|
||||||
|
#[error("required authorization bearer token in JWT format")]
|
||||||
|
BearerJwt,
|
||||||
|
}
|
||||||
|
|
||||||
impl ReportableError for ConnInfoError {
|
impl ReportableError for ConnInfoError {
|
||||||
fn get_error_kind(&self) -> ErrorKind {
|
fn get_error_kind(&self) -> ErrorKind {
|
||||||
ErrorKind::User
|
ErrorKind::User
|
||||||
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_conn_info(
|
fn get_conn_info(
|
||||||
|
config: &'static AuthenticationConfig,
|
||||||
ctx: &RequestMonitoring,
|
ctx: &RequestMonitoring,
|
||||||
headers: &HeaderMap,
|
headers: &HeaderMap,
|
||||||
tls: Option<&TlsConfig>,
|
tls: Option<&TlsConfig>,
|
||||||
@@ -181,21 +195,32 @@ fn get_conn_info(
|
|||||||
ctx.set_user(username.clone());
|
ctx.set_user(username.clone());
|
||||||
|
|
||||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||||
|
if !config.accept_jwts {
|
||||||
|
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||||
|
}
|
||||||
|
|
||||||
let auth = auth
|
let auth = auth
|
||||||
.to_str()
|
.to_str()
|
||||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||||
AuthData::Jwt(
|
AuthData::Jwt(
|
||||||
auth.strip_prefix("Bearer ")
|
auth.strip_prefix("Bearer ")
|
||||||
.ok_or(ConnInfoError::MissingPassword)?
|
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||||
.into(),
|
.into(),
|
||||||
)
|
)
|
||||||
} else if let Some(pass) = connection_url.password() {
|
} else if let Some(pass) = connection_url.password() {
|
||||||
|
// wrong credentials provided
|
||||||
|
if config.accept_jwts {
|
||||||
|
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||||
|
}
|
||||||
|
|
||||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||||
std::borrow::Cow::Owned(b) => b.into(),
|
std::borrow::Cow::Owned(b) => b.into(),
|
||||||
})
|
})
|
||||||
|
} else if config.accept_jwts {
|
||||||
|
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||||
} else {
|
} else {
|
||||||
return Err(ConnInfoError::MissingPassword);
|
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||||
};
|
};
|
||||||
|
|
||||||
let endpoint = match connection_url.host() {
|
let endpoint = match connection_url.host() {
|
||||||
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
|
|||||||
request: Request<Incoming>,
|
request: Request<Incoming>,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||||
|
|
||||||
let mut response = match result {
|
let mut response = match result {
|
||||||
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
|
|||||||
|
|
||||||
let mut message = e.to_string_client();
|
let mut message = e.to_string_client();
|
||||||
let db_error = match &e {
|
let db_error = match &e {
|
||||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
|
||||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
@@ -504,7 +529,7 @@ async fn handle_inner(
|
|||||||
ctx: &RequestMonitoring,
|
ctx: &RequestMonitoring,
|
||||||
request: Request<Incoming>,
|
request: Request<Incoming>,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||||
let _requeset_gauge = Metrics::get()
|
let _requeset_gauge = Metrics::get()
|
||||||
.proxy
|
.proxy
|
||||||
.connection_requests
|
.connection_requests
|
||||||
@@ -514,18 +539,50 @@ async fn handle_inner(
|
|||||||
"handling interactive connection from client"
|
"handling interactive connection from client"
|
||||||
);
|
);
|
||||||
|
|
||||||
//
|
let conn_info = get_conn_info(
|
||||||
// Determine the destination and connection params
|
&config.authentication_config,
|
||||||
//
|
ctx,
|
||||||
let headers = request.headers();
|
request.headers(),
|
||||||
|
config.tls_config.as_ref(),
|
||||||
// TLS config should be there.
|
)?;
|
||||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
|
|
||||||
info!(
|
info!(
|
||||||
user = conn_info.conn_info.user_info.user.as_str(),
|
user = conn_info.conn_info.user_info.user.as_str(),
|
||||||
"credentials"
|
"credentials"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
match conn_info.auth {
|
||||||
|
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||||
|
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||||
|
}
|
||||||
|
auth => {
|
||||||
|
handle_db_inner(
|
||||||
|
cancel,
|
||||||
|
config,
|
||||||
|
ctx,
|
||||||
|
request,
|
||||||
|
conn_info.conn_info,
|
||||||
|
auth,
|
||||||
|
backend,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_db_inner(
|
||||||
|
cancel: CancellationToken,
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
request: Request<Incoming>,
|
||||||
|
conn_info: ConnInfo,
|
||||||
|
auth: AuthData,
|
||||||
|
backend: Arc<PoolingBackend>,
|
||||||
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||||
|
//
|
||||||
|
// Determine the destination and connection params
|
||||||
|
//
|
||||||
|
let headers = request.headers();
|
||||||
|
|
||||||
// Allow connection pooling only if explicitly requested
|
// Allow connection pooling only if explicitly requested
|
||||||
// or if we have decided that http pool is no longer opt-in
|
// or if we have decided that http pool is no longer opt-in
|
||||||
let allow_pool = !config.http_config.pool_options.opt_in
|
let allow_pool = !config.http_config.pool_options.opt_in
|
||||||
@@ -563,26 +620,36 @@ async fn handle_inner(
|
|||||||
|
|
||||||
let authenticate_and_connect = Box::pin(
|
let authenticate_and_connect = Box::pin(
|
||||||
async {
|
async {
|
||||||
let keys = match &conn_info.auth {
|
let keys = match auth {
|
||||||
AuthData::Password(pw) => {
|
AuthData::Password(pw) => {
|
||||||
backend
|
backend
|
||||||
.authenticate_with_password(
|
.authenticate_with_password(
|
||||||
ctx,
|
ctx,
|
||||||
&config.authentication_config,
|
&config.authentication_config,
|
||||||
&conn_info.conn_info.user_info,
|
&conn_info.user_info,
|
||||||
pw,
|
&pw,
|
||||||
)
|
)
|
||||||
.await?
|
.await?
|
||||||
}
|
}
|
||||||
AuthData::Jwt(jwt) => {
|
AuthData::Jwt(jwt) => {
|
||||||
backend
|
backend
|
||||||
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
|
.authenticate_with_jwt(
|
||||||
.await?
|
ctx,
|
||||||
|
&config.authentication_config,
|
||||||
|
&conn_info.user_info,
|
||||||
|
jwt,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
ComputeCredentials {
|
||||||
|
info: conn_info.user_info.clone(),
|
||||||
|
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = backend
|
let client = backend
|
||||||
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
|
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||||
.await?;
|
.await?;
|
||||||
// not strictly necessary to mark success here,
|
// not strictly necessary to mark success here,
|
||||||
// but it's just insurance for if we forget it somewhere else
|
// but it's just insurance for if we forget it somewhere else
|
||||||
@@ -640,7 +707,11 @@ async fn handle_inner(
|
|||||||
|
|
||||||
let len = json_output.len();
|
let len = json_output.len();
|
||||||
let response = response
|
let response = response
|
||||||
.body(Full::new(Bytes::from(json_output)))
|
.body(
|
||||||
|
Full::new(Bytes::from(json_output))
|
||||||
|
.map_err(|x| match x {})
|
||||||
|
.boxed(),
|
||||||
|
)
|
||||||
// only fails if invalid status code or invalid header/values are given.
|
// only fails if invalid status code or invalid header/values are given.
|
||||||
// these are not user configurable so it cannot fail dynamically
|
// these are not user configurable so it cannot fail dynamically
|
||||||
.expect("building response payload should not fail");
|
.expect("building response payload should not fail");
|
||||||
@@ -656,6 +727,65 @@ async fn handle_inner(
|
|||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||||
|
&AUTHORIZATION,
|
||||||
|
&CONN_STRING,
|
||||||
|
&RAW_TEXT_OUTPUT,
|
||||||
|
&ARRAY_MODE,
|
||||||
|
&TXN_ISOLATION_LEVEL,
|
||||||
|
&TXN_READ_ONLY,
|
||||||
|
&TXN_DEFERRABLE,
|
||||||
|
];
|
||||||
|
|
||||||
|
async fn handle_auth_broker_inner(
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
ctx: &RequestMonitoring,
|
||||||
|
request: Request<Incoming>,
|
||||||
|
conn_info: ConnInfo,
|
||||||
|
jwt: String,
|
||||||
|
backend: Arc<PoolingBackend>,
|
||||||
|
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||||
|
backend
|
||||||
|
.authenticate_with_jwt(
|
||||||
|
ctx,
|
||||||
|
&config.authentication_config,
|
||||||
|
&conn_info.user_info,
|
||||||
|
jwt,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(HttpConnError::from)?;
|
||||||
|
|
||||||
|
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||||
|
|
||||||
|
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||||
|
|
||||||
|
let (mut parts, body) = request.into_parts();
|
||||||
|
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||||
|
|
||||||
|
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
|
||||||
|
// these instead just to ensure they remain normalised.
|
||||||
|
for &h in HEADERS_TO_FORWARD {
|
||||||
|
if let Some(hv) = parts.headers.remove(h) {
|
||||||
|
req = req.header(h, hv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let req = req
|
||||||
|
.body(body)
|
||||||
|
.expect("all headers and params received via hyper should be valid for request");
|
||||||
|
|
||||||
|
// todo: map body to count egress
|
||||||
|
let _metrics = client.metrics();
|
||||||
|
|
||||||
|
Ok(client
|
||||||
|
.inner
|
||||||
|
.send_request(req)
|
||||||
|
.await
|
||||||
|
.map_err(LocalProxyConnError::from)
|
||||||
|
.map_err(HttpConnError::from)?
|
||||||
|
.map(|b| b.boxed()))
|
||||||
|
}
|
||||||
|
|
||||||
impl QueryData {
|
impl QueryData {
|
||||||
async fn process(
|
async fn process(
|
||||||
self,
|
self,
|
||||||
@@ -705,7 +835,9 @@ impl QueryData {
|
|||||||
// query failed or was cancelled.
|
// query failed or was cancelled.
|
||||||
Ok(Err(error)) => {
|
Ok(Err(error)) => {
|
||||||
let db_error = match &error {
|
let db_error = match &error {
|
||||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
SqlOverHttpError::ConnectCompute(
|
||||||
|
HttpConnError::PostgresConnectionError(e),
|
||||||
|
)
|
||||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ from _pytest.fixtures import FixtureRequest
|
|||||||
from psycopg2.extensions import connection as PgConnection
|
from psycopg2.extensions import connection as PgConnection
|
||||||
from psycopg2.extensions import cursor as PgCursor
|
from psycopg2.extensions import cursor as PgCursor
|
||||||
from psycopg2.extensions import make_dsn, parse_dsn
|
from psycopg2.extensions import make_dsn, parse_dsn
|
||||||
|
from pytest_httpserver import HTTPServer
|
||||||
from urllib3.util.retry import Retry
|
from urllib3.util.retry import Retry
|
||||||
|
|
||||||
from fixtures import overlayfs
|
from fixtures import overlayfs
|
||||||
@@ -440,9 +441,9 @@ class NeonEnvBuilder:
|
|||||||
|
|
||||||
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
|
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
|
||||||
|
|
||||||
self.pageserver_default_tenant_config_compaction_algorithm: Optional[
|
self.pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = (
|
||||||
Dict[str, Any]
|
pageserver_default_tenant_config_compaction_algorithm
|
||||||
] = pageserver_default_tenant_config_compaction_algorithm
|
)
|
||||||
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
|
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
|
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
|
||||||
@@ -1072,9 +1073,9 @@ class NeonEnv:
|
|||||||
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
|
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
|
||||||
if config.pageserver_default_tenant_config_compaction_algorithm is not None:
|
if config.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||||
tenant_config = ps_cfg.setdefault("tenant_config", {})
|
tenant_config = ps_cfg.setdefault("tenant_config", {})
|
||||||
tenant_config[
|
tenant_config["compaction_algorithm"] = (
|
||||||
"compaction_algorithm"
|
config.pageserver_default_tenant_config_compaction_algorithm
|
||||||
] = config.pageserver_default_tenant_config_compaction_algorithm
|
)
|
||||||
|
|
||||||
if self.pageserver_remote_storage is not None:
|
if self.pageserver_remote_storage is not None:
|
||||||
ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
|
ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
|
||||||
@@ -1117,9 +1118,9 @@ class NeonEnv:
|
|||||||
if config.auth_enabled:
|
if config.auth_enabled:
|
||||||
sk_cfg["auth_enabled"] = True
|
sk_cfg["auth_enabled"] = True
|
||||||
if self.safekeepers_remote_storage is not None:
|
if self.safekeepers_remote_storage is not None:
|
||||||
sk_cfg[
|
sk_cfg["remote_storage"] = (
|
||||||
"remote_storage"
|
self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
||||||
] = self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
)
|
||||||
self.safekeepers.append(
|
self.safekeepers.append(
|
||||||
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
|
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
|
||||||
)
|
)
|
||||||
@@ -3572,6 +3573,20 @@ class NeonProxy(PgProtocol):
|
|||||||
]
|
]
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
class AuthBroker(AuthBackend):
|
||||||
|
def __init__(self, endpoint: str):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
|
||||||
|
def extra_args(self) -> list[str]:
|
||||||
|
args = [
|
||||||
|
# Console auth backend params
|
||||||
|
*["--auth-backend", "console"],
|
||||||
|
*["--auth-endpoint", self.endpoint],
|
||||||
|
*["--sql-over-http-pool-opt-in", "false"],
|
||||||
|
*["--is-auth-broker"],
|
||||||
|
]
|
||||||
|
return args
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Postgres(AuthBackend):
|
class Postgres(AuthBackend):
|
||||||
pg_conn_url: str
|
pg_conn_url: str
|
||||||
@@ -3600,7 +3615,7 @@ class NeonProxy(PgProtocol):
|
|||||||
metric_collection_interval: Optional[str] = None,
|
metric_collection_interval: Optional[str] = None,
|
||||||
):
|
):
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
domain = "proxy.localtest.me" # resolves to 127.0.0.1
|
domain = "ep-foo-bar-1234.localtest.me" # resolves to 127.0.0.1
|
||||||
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
|
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
|
||||||
|
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
@@ -3886,6 +3901,50 @@ def static_proxy(
|
|||||||
yield proxy
|
yield proxy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def static_auth_broker(
|
||||||
|
vanilla_pg: VanillaPostgres,
|
||||||
|
port_distributor: PortDistributor,
|
||||||
|
neon_binpath: Path,
|
||||||
|
test_output_dir: Path,
|
||||||
|
httpserver: HTTPServer,
|
||||||
|
) -> Iterator[NeonProxy]:
|
||||||
|
"""Neon proxy that routes directly to vanilla postgres."""
|
||||||
|
|
||||||
|
auth_endpoint = httpserver.url_for("/cplane")
|
||||||
|
|
||||||
|
port = vanilla_pg.default_options["port"]
|
||||||
|
host = vanilla_pg.default_options["host"]
|
||||||
|
|
||||||
|
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
|
||||||
|
{
|
||||||
|
"address": f"{host}:{port}",
|
||||||
|
"aux": {
|
||||||
|
"endpoint_id": "ep-foo-bar-1234",
|
||||||
|
"branch_id": "br-foo-bar",
|
||||||
|
"project_id": "foo-bar",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_port = port_distributor.get_port()
|
||||||
|
mgmt_port = port_distributor.get_port()
|
||||||
|
http_port = port_distributor.get_port()
|
||||||
|
external_http_port = port_distributor.get_port()
|
||||||
|
|
||||||
|
with NeonProxy(
|
||||||
|
neon_binpath=neon_binpath,
|
||||||
|
test_output_dir=test_output_dir,
|
||||||
|
proxy_port=proxy_port,
|
||||||
|
http_port=http_port,
|
||||||
|
mgmt_port=mgmt_port,
|
||||||
|
external_http_port=external_http_port,
|
||||||
|
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
|
||||||
|
) as proxy:
|
||||||
|
proxy.start()
|
||||||
|
yield proxy
|
||||||
|
|
||||||
|
|
||||||
class Endpoint(PgProtocol, LogUtils):
|
class Endpoint(PgProtocol, LogUtils):
|
||||||
"""An object representing a Postgres compute endpoint managed by the control plane."""
|
"""An object representing a Postgres compute endpoint managed by the control plane."""
|
||||||
|
|
||||||
|
|||||||
71
test_runner/regress/test_auth_broker.py
Normal file
71
test_runner/regress/test_auth_broker.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
|
||||||
|
|
||||||
|
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_over_http(static_auth_broker: NeonProxy):
|
||||||
|
static_auth_broker.safe_psql("create role http with login password 'http' superuser")
|
||||||
|
|
||||||
|
def q(sql: str, params: Optional[List[Any]] = None) -> Any:
|
||||||
|
params = params or []
|
||||||
|
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
|
||||||
|
response = requests.post(
|
||||||
|
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||||
|
data=json.dumps({"query": sql, "params": params}),
|
||||||
|
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
|
||||||
|
verify=str(static_proxy.test_output_dir / "proxy.crt"),
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
rows = q("select 42 as answer")["rows"]
|
||||||
|
assert rows == [{"answer": 42}]
|
||||||
|
|
||||||
|
rows = q("select $1 as answer", [42])["rows"]
|
||||||
|
assert rows == [{"answer": "42"}]
|
||||||
|
|
||||||
|
rows = q("select $1 * 1 as answer", [42])["rows"]
|
||||||
|
assert rows == [{"answer": 42}]
|
||||||
|
|
||||||
|
rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"]
|
||||||
|
assert rows == [{"answer": [1, 2, 3]}]
|
||||||
|
|
||||||
|
rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
|
||||||
|
assert rows == [{"answer": {"b": 42}}]
|
||||||
|
|
||||||
|
rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
|
||||||
|
assert rows == [{"answer": [{}]}]
|
||||||
|
|
||||||
|
rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
|
||||||
|
assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]
|
||||||
|
|
||||||
|
rows = q("select * from pg_class limit 1")["rows"]
|
||||||
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
res = q("create table t(id serial primary key, val int)")
|
||||||
|
assert res["command"] == "CREATE"
|
||||||
|
assert res["rowCount"] is None
|
||||||
|
|
||||||
|
res = q("insert into t(val) values (10), (20), (30) returning id")
|
||||||
|
assert res["command"] == "INSERT"
|
||||||
|
assert res["rowCount"] == 3
|
||||||
|
assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
|
|
||||||
|
res = q("select * from t")
|
||||||
|
assert res["command"] == "SELECT"
|
||||||
|
assert res["rowCount"] == 3
|
||||||
|
|
||||||
|
res = q("drop table t")
|
||||||
|
assert res["command"] == "DROP"
|
||||||
|
assert res["rowCount"] is None
|
||||||
|
|
||||||
Reference in New Issue
Block a user