mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-04 12:02:55 +00:00
proxy: auth broker (#8855)
Opens an HTTP/2 connection to local-proxy and forwards requests over it, with all headers and body. Closes https://github.com/neondatabase/cloud/issues/16039
This commit is contained in:
@@ -565,7 +565,7 @@ mod tests {
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
|
||||
use super::{auth_quirks, AuthRateLimiter};
|
||||
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
@@ -611,12 +611,15 @@ mod tests {
|
||||
}
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: false,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
|
||||
@@ -8,7 +8,7 @@ use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use dashmap::DashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde::{de::Visitor, Deserialize, Deserializer};
|
||||
use signature::Verifier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
@@ -311,13 +311,11 @@ impl JwkCacheEntryLock {
|
||||
|
||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
||||
|
||||
match (expected_audience, payload.audience) {
|
||||
// check the audience matches
|
||||
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
|
||||
// the audience is expected but is missing
|
||||
(Some(_), None) => bail!("invalid JWT token audience"),
|
||||
// we don't care for the audience field
|
||||
(None, _) => {}
|
||||
if let Some(aud) = expected_audience {
|
||||
ensure!(
|
||||
payload.audience.0.iter().any(|s| s == aud),
|
||||
"invalid JWT token audience"
|
||||
);
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
@@ -420,11 +418,12 @@ struct JwtHeader<'a> {
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug)]
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
#[allow(dead_code)]
|
||||
struct JwtPayload<'a> {
|
||||
/// Audience - Recipient for which the JWT is intended
|
||||
#[serde(rename = "aud")]
|
||||
audience: Option<&'a str>,
|
||||
#[serde(rename = "aud", default)]
|
||||
audience: OneOrMany,
|
||||
/// Expiration - Time after which the JWT expires
|
||||
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
|
||||
expiration: Option<SystemTime>,
|
||||
@@ -447,6 +446,59 @@ struct JwtPayload<'a> {
|
||||
session_id: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// `OneOrMany` supports parsing either a single item or an array of items.
|
||||
///
|
||||
/// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
|
||||
///
|
||||
/// > The "aud" (audience) claim identifies the recipients that the JWT is
|
||||
/// > intended for. Each principal intended to process the JWT MUST
|
||||
/// > identify itself with a value in the audience claim. If the principal
|
||||
/// > processing the claim does not identify itself with a value in the
|
||||
/// > "aud" claim when this claim is present, then the JWT MUST be
|
||||
/// > rejected. In the general case, the "aud" value is **an array of case-
|
||||
/// > sensitive strings**, each containing a StringOrURI value. In the
|
||||
/// > special case when the JWT has one audience, the "aud" value MAY be a
|
||||
/// > **single case-sensitive string** containing a StringOrURI value. The
|
||||
/// > interpretation of audience values is generally application specific.
|
||||
/// > Use of this claim is OPTIONAL.
|
||||
#[derive(Default, Debug)]
|
||||
struct OneOrMany(Vec<String>);
|
||||
|
||||
impl<'de> Deserialize<'de> for OneOrMany {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct OneOrManyVisitor;
|
||||
impl<'de> Visitor<'de> for OneOrManyVisitor {
|
||||
type Value = OneOrMany;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a single string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(OneOrMany(vec![v.to_owned()]))
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let mut v = vec![];
|
||||
while let Some(s) = seq.next_element()? {
|
||||
v.push(s);
|
||||
}
|
||||
Ok(OneOrMany(v))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_any(OneOrManyVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
|
||||
let d = <Option<u64>>::deserialize(d)?;
|
||||
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
|
||||
|
||||
@@ -14,17 +14,15 @@ use crate::{
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub(crate) jwks_cache: JwkCache,
|
||||
pub(crate) node_info: NodeInfo,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||
LocalBackend {
|
||||
jwks_cache: JwkCache::default(),
|
||||
node_info: NodeInfo {
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
|
||||
@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::Either;
|
||||
use proxy::{
|
||||
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
auth::backend::{
|
||||
jwt::JwkCache,
|
||||
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
},
|
||||
cancellation::CancellationHandlerMain,
|
||||
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
||||
console::{
|
||||
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
allow_self_signed_compute: false,
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(0),
|
||||
scram_protocol_timeout: Duration::from_secs(10),
|
||||
rate_limiter_enabled: false,
|
||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: true,
|
||||
},
|
||||
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
|
||||
@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
@@ -103,6 +104,9 @@ struct ProxyCliArgs {
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
is_auth_broker: bool,
|
||||
/// path to TLS key for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
@@ -385,9 +389,27 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let proxy_listener = if !args.is_auth_broker {
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
|
||||
Some(TcpListener::bind(proxy_address).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
Some(TcpListener::bind(serverless_address).await?)
|
||||
} else if args.is_auth_broker {
|
||||
bail!("wss arg must be present for auth-broker")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
@@ -433,21 +455,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
let serverless_listener = TcpListener::bind(serverless_address).await?;
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
serverless_listener,
|
||||
@@ -677,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
accept_websockets: true,
|
||||
accept_websockets: !args.is_auth_broker,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
@@ -692,12 +710,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
auth::{
|
||||
self,
|
||||
backend::{jwt::JwkCache, AuthRateLimiter},
|
||||
},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
@@ -78,6 +81,9 @@ pub struct AuthenticationConfig {
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
pub accept_jwts: bool,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
@@ -261,18 +267,26 @@ impl CertResolver {
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
mod backend;
|
||||
pub mod cancel_set;
|
||||
mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod sql_over_http;
|
||||
@@ -19,7 +20,8 @@ use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
@@ -81,7 +83,28 @@ pub async fn task_main(
|
||||
}
|
||||
});
|
||||
|
||||
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let http_conn_pool = Arc::clone(&http_conn_pool);
|
||||
tokio::spawn(async move {
|
||||
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
|
||||
});
|
||||
}
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
let http_conn_pool = http_conn_pool.clone();
|
||||
async move {
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
@@ -342,7 +365,7 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -386,7 +409,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
@@ -409,7 +432,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Full::new(Bytes::new()))
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{io, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
@@ -27,9 +29,13 @@ use crate::{
|
||||
Host,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
use super::{
|
||||
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||
http_conn_pool::{self, poll_http2_client},
|
||||
};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -103,32 +109,44 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: &str,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
jwt: String,
|
||||
) -> Result<(), AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
crate::auth::Backend::Console(_, ()) => {
|
||||
Err(AuthError::auth_failed("JWT login is not yet supported"))
|
||||
crate::auth::Backend::Console(console, ()) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&**console,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
|
||||
"JWT login over web auth proxy is not supported",
|
||||
)),
|
||||
crate::auth::Backend::Local(cache) => {
|
||||
cache
|
||||
crate::auth::Backend::Local(_) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&StaticAuthRules,
|
||||
jwt,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
|
||||
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,14 +192,55 @@ impl PoolingBackend {
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Wake up the destination if needed
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client, HttpConnError> {
|
||||
info!("pool: looking for an existing connection");
|
||||
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|()| ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to compute")]
|
||||
ConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to postgres in compute")]
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum LocalProxyConnError {
|
||||
#[error("error with connection to local-proxy")]
|
||||
Io(#[source] std::io::Error),
|
||||
#[error("could not establish h2 connection")]
|
||||
H2(#[from] hyper1::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for HttpConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
|
||||
impl CouldRetry for HttpConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
HttpConnError::AuthError(_) => false,
|
||||
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
|
||||
impl ShouldRetryWakeCompute for HttpConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
|
||||
// we never checked cache validity
|
||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for LocalProxyConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => ErrorKind::Compute,
|
||||
LocalProxyConnError::H2(_) => ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for LocalProxyConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Could not establish HTTP connection to the database".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for LocalProxyConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<http_conn_pool::GlobalConnPool>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
|
||||
let res = connect_http2(&host, 10432, timeout).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
// assumption: host is an ip address so this should not actually perform any requests.
|
||||
// todo: add that assumption as a guarantee in the control-plane API.
|
||||
let mut addrs = lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
let stream = loop {
|
||||
let Some(addr) = addrs.next() else {
|
||||
return Err(last_err.unwrap_or_else(|| {
|
||||
LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"could not resolve any addresses",
|
||||
))
|
||||
}));
|
||||
};
|
||||
|
||||
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
|
||||
break stream;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(LocalProxyConnError::Io(e));
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
e,
|
||||
)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(20))
|
||||
.keep_alive_while_idle(true)
|
||||
.keep_alive_timeout(Duration::from_secs(5))
|
||||
.handshake(TokioIo::new(stream))
|
||||
.await?;
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
342
proxy/src/serverless/http_conn_pool.rs
Normal file
342
proxy/src/serverless/http_conn_pool.rs
Normal file
@@ -0,0 +1,342 @@
|
||||
use dashmap::DashMap;
|
||||
use hyper1::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
use std::{sync::Arc, sync::Weak};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||
|
||||
use tracing::{debug, error};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnPoolEntry {
|
||||
conn: Send,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool {
|
||||
// TODO(conrad):
|
||||
// either we should open more connections depending on stream count
|
||||
// (not exposed by hyper, need our own counter)
|
||||
// or we can change this to an Option rather than a VecDeque.
|
||||
//
|
||||
// Opening more connections to the same db because we run out of streams
|
||||
// seems somewhat redundant though.
|
||||
//
|
||||
// Probably we should use a semaphore and keep just a single conn. TBD.
|
||||
conns: VecDeque<ConnPoolEntry>,
|
||||
_guard: HttpEndpointPoolsGuard<'static>,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl EndpointConnPool {
|
||||
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
|
||||
let Self { conns, .. } = self;
|
||||
|
||||
loop {
|
||||
let conn = conns.pop_front()?;
|
||||
if !conn.conn.is_closed() {
|
||||
conns.push_back(conn.clone());
|
||||
return Some(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
conns,
|
||||
global_connections_count,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let old_len = conns.len();
|
||||
conns.retain(|conn| conn.conn_id != conn_id);
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
removed > 0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EndpointConnPool {
|
||||
fn drop(&mut self) {
|
||||
if !self.conns.is_empty() {
|
||||
self.global_connections_count
|
||||
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.conns.len() as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GlobalConnPool {
|
||||
// endpoint -> per-endpoint connection pool
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// Number of endpoint-connection pools
|
||||
///
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
/// It's only used for diagnostics.
|
||||
global_pool_size: AtomicUsize,
|
||||
|
||||
/// Total number of connections in the pool
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
impl GlobalConnPool {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
let shards = config.pool_options.pool_shards;
|
||||
Arc::new(Self {
|
||||
global_pool: DashMap::with_shard_amount(shards),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
config,
|
||||
global_connections_count: Arc::new(AtomicUsize::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
// drops all strong references to endpoint-pools
|
||||
self.global_pool.clear();
|
||||
}
|
||||
|
||||
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
|
||||
let epoch = self.config.pool_options.gc_epoch;
|
||||
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let shard = rng.gen_range(0..self.global_pool.shards().len());
|
||||
self.gc(shard);
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self, shard: usize) {
|
||||
debug!(shard, "pool: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut shard = self.global_pool.shards()[shard].write();
|
||||
|
||||
let timer = Metrics::get()
|
||||
.proxy
|
||||
.http_pool_reclaimation_lag_seconds
|
||||
.start_timer();
|
||||
let current_len = shard.len();
|
||||
let mut clients_removed = 0;
|
||||
shard.retain(|endpoint, x| {
|
||||
// if the current endpoint pool is unique (no other strong or weak references)
|
||||
// then it is currently not in use by any connections.
|
||||
if let Some(pool) = Arc::get_mut(x.get_mut()) {
|
||||
let EndpointConnPool { conns, .. } = pool.get_mut();
|
||||
|
||||
let old_len = conns.len();
|
||||
|
||||
conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
clients_removed += removed;
|
||||
|
||||
// we only remove this pool if it has no active connections
|
||||
if conns.is_empty() {
|
||||
info!("pool: discarding pool for endpoint {endpoint}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let new_len = shard.len();
|
||||
drop(shard);
|
||||
timer.observe();
|
||||
|
||||
// Do logging outside of the lock.
|
||||
if clients_removed > 0 {
|
||||
let size = self
|
||||
.global_connections_count
|
||||
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
|
||||
- clients_removed;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(clients_removed as i64);
|
||||
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
|
||||
}
|
||||
let removed = current_len - new_len;
|
||||
|
||||
if removed > 0 {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_sub(removed, atomic::Ordering::Relaxed)
|
||||
- removed;
|
||||
info!("pool: performed global pool gc. size now {global_pool_size}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Option<Client> {
|
||||
let endpoint = conn_info.endpoint_cache_key()?;
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
let client = endpoint_pool.write().get_conn_entry()?;
|
||||
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
Some(Client::new(client.conn, client.aux))
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
self: &Arc<Self>,
|
||||
endpoint: &EndpointCacheKey,
|
||||
) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
}
|
||||
|
||||
// slow path
|
||||
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||
conns: VecDeque::new(),
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count: self.global_connections_count.clone(),
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
let mut created = false;
|
||||
let pool = self
|
||||
.global_pool
|
||||
.entry(endpoint.clone())
|
||||
.or_insert_with(|| {
|
||||
created = true;
|
||||
new_pool
|
||||
})
|
||||
.clone();
|
||||
|
||||
// log new global pool size
|
||||
if created {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||
+ 1;
|
||||
info!(
|
||||
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||
);
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
|
||||
let pool = match conn_info.endpoint_cache_key() {
|
||||
Some(endpoint) => {
|
||||
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
|
||||
|
||||
pool.write().conns.push_back(ConnPoolEntry {
|
||||
conn: client.clone(),
|
||||
conn_id,
|
||||
aux: aux.clone(),
|
||||
});
|
||||
|
||||
Arc::downgrade(&pool)
|
||||
}
|
||||
None => Weak::new(),
|
||||
};
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let res = connection.await;
|
||||
match res {
|
||||
Ok(()) => info!("connection closed"),
|
||||
Err(e) => error!(%session_id, "connection error: {}", e),
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.write().remove_conn(conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
Client::new(client, aux)
|
||||
}
|
||||
|
||||
pub(crate) struct Client {
|
||||
pub(crate) inner: Send,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
|
||||
Self { inner, aux }
|
||||
}
|
||||
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: self.aux.endpoint_id,
|
||||
branch_id: self.aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,13 +5,13 @@ use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
@@ -64,17 +64,24 @@ struct HttpErrorBody {
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn response_from_msg_and_status(
|
||||
msg: String,
|
||||
status: StatusCode,
|
||||
) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.body(
|
||||
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -83,14 +90,14 @@ impl HttpErrorBody {
|
||||
pub(crate) fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::Method;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
@@ -38,9 +40,11 @@ use url::Url;
|
||||
use urlencoding;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::LocalProxyConnError;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
@@ -181,21 +195,32 @@ fn get_conn_info(
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingPassword)?
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingPassword);
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
|
||||
let endpoint = match connection_url.host() {
|
||||
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
|
||||
|
||||
let mut message = e.to_string_client();
|
||||
let db_error = match &e {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
@@ -504,7 +529,7 @@ async fn handle_inner(
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
@@ -514,18 +539,50 @@ async fn handle_inner(
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
ctx,
|
||||
request.headers(),
|
||||
config.tls_config.as_ref(),
|
||||
)?;
|
||||
info!(
|
||||
user = conn_info.conn_info.user_info.user.as_str(),
|
||||
"credentials"
|
||||
);
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
cancel,
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
auth,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_db_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
auth: AuthData,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// Allow connection pooling only if explicitly requested
|
||||
// or if we have decided that http pool is no longer opt-in
|
||||
let allow_pool = !config.http_config.pool_options.opt_in
|
||||
@@ -563,26 +620,36 @@ async fn handle_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = match &conn_info.auth {
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.conn_info.user_info,
|
||||
pw,
|
||||
&conn_info.user_info,
|
||||
&pw,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
|
||||
.await?
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?;
|
||||
|
||||
ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
@@ -640,7 +707,11 @@ async fn handle_inner(
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
.body(
|
||||
Full::new(Bytes::from(json_output))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -656,6 +727,65 @@ async fn handle_inner(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
&CONN_STRING,
|
||||
&RAW_TEXT_OUTPUT,
|
||||
&ARRAY_MODE,
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
&TXN_READ_ONLY,
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
|
||||
let (mut parts, body) = request.into_parts();
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
|
||||
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
|
||||
// these instead just to ensure they remain normalised.
|
||||
for &h in HEADERS_TO_FORWARD {
|
||||
if let Some(hv) = parts.headers.remove(h) {
|
||||
req = req.header(h, hv);
|
||||
}
|
||||
}
|
||||
|
||||
let req = req
|
||||
.body(body)
|
||||
.expect("all headers and params received via hyper should be valid for request");
|
||||
|
||||
// todo: map body to count egress
|
||||
let _metrics = client.metrics();
|
||||
|
||||
Ok(client
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
async fn process(
|
||||
self,
|
||||
@@ -705,7 +835,9 @@ impl QueryData {
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(
|
||||
HttpConnError::PostgresConnectionError(e),
|
||||
)
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user