mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-30 19:40:39 +00:00
move proxy to proxy/code
This commit is contained in:
150
proxy/core/src/auth.rs
Normal file
150
proxy/core/src/auth.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{
|
||||
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
|
||||
ComputeUserInfoParseError, IpPattern,
|
||||
};
|
||||
|
||||
mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
use password_hack::PasswordHackPayload;
|
||||
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
use tokio::time::error::Elapsed;
|
||||
|
||||
use crate::{
|
||||
console,
|
||||
error::{ReportableError, UserFacingError},
|
||||
};
|
||||
use std::{io, net::IpAddr};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Convenience wrapper for the authentication error.
|
||||
pub type Result<T> = std::result::Result<T, AuthError>;
|
||||
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
Link(#[from] backend::LinkAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
#[error("Unsupported authentication method: {0}")]
|
||||
BadAuthMethod(Box<str>),
|
||||
|
||||
#[error("Malformed password message: {0}")]
|
||||
MalformedPassword(&'static str),
|
||||
|
||||
#[error(
|
||||
"Endpoint ID is not specified. \
|
||||
Either please upgrade the postgres client library (libpq) for SNI support \
|
||||
or pass the endpoint ID (first part of the domain name) as a parameter: '?options=endpoint%3D<endpoint-id>'. \
|
||||
See more at https://neon.tech/sni"
|
||||
)]
|
||||
MissingEndpointName,
|
||||
|
||||
#[error("password authentication failed for user '{0}'")]
|
||||
AuthFailed(Box<str>),
|
||||
|
||||
/// Errors produced by e.g. [`crate::stream::PqStream`].
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
|
||||
#[error(
|
||||
"This IP address {0} is not allowed to connect to this endpoint. \
|
||||
Please add it to the allowed list in the Neon console. \
|
||||
Make sure to check for IPv4 or IPv6 addresses."
|
||||
)]
|
||||
IpAddressNotAllowed(IpAddr),
|
||||
|
||||
#[error("Too many connections to this endpoint. Please try again later.")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("Authentication timed out")]
|
||||
UserTimeout(Elapsed),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
pub fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
|
||||
AuthErrorImpl::BadAuthMethod(name.into()).into()
|
||||
}
|
||||
|
||||
pub fn auth_failed(user: impl Into<Box<str>>) -> Self {
|
||||
AuthErrorImpl::AuthFailed(user.into()).into()
|
||||
}
|
||||
|
||||
pub fn ip_address_not_allowed(ip: IpAddr) -> Self {
|
||||
AuthErrorImpl::IpAddressNotAllowed(ip).into()
|
||||
}
|
||||
|
||||
pub fn too_many_connections() -> Self {
|
||||
AuthErrorImpl::TooManyConnections.into()
|
||||
}
|
||||
|
||||
pub fn is_auth_failed(&self) -> bool {
|
||||
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
|
||||
}
|
||||
|
||||
pub fn user_timeout(elapsed: Elapsed) -> Self {
|
||||
AuthErrorImpl::UserTimeout(elapsed).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
fn from(e: E) -> Self {
|
||||
Self(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
AuthFailed(_) => self.to_string(),
|
||||
BadAuthMethod(_) => self.to_string(),
|
||||
MalformedPassword(_) => self.to_string(),
|
||||
MissingEndpointName => self.to_string(),
|
||||
Io(_) => "Internal error".to_string(),
|
||||
IpAddressNotAllowed(_) => self.to_string(),
|
||||
TooManyConnections => self.to_string(),
|
||||
UserTimeout(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for AuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.get_error_kind(),
|
||||
GetAuthInfo(e) => e.get_error_kind(),
|
||||
Sasl(e) => e.get_error_kind(),
|
||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
MissingEndpointName => crate::error::ErrorKind::User,
|
||||
Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
847
proxy/core/src/auth/backend.rs
Normal file
847
proxy/core/src/auth/backend.rs
Normal file
@@ -0,0 +1,847 @@
|
||||
mod classic;
|
||||
mod hacks;
|
||||
pub mod jwt;
|
||||
mod link;
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
pub use link::LinkAuthError;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::{validate_password_and_exchange, AuthError};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{AuthSecret, NodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ComputeUserInfoMaybeEndpoint},
|
||||
config::AuthenticationConfig,
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedAllowedIps, CachedNodeInfo},
|
||||
Api,
|
||||
},
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
Owned(T),
|
||||
Borrowed(&'a T),
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
match self {
|
||||
MaybeOwned::Owned(t) => t,
|
||||
MaybeOwned::Borrowed(t) => t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
/// which we use in [`crate::config::ProxyConfig`].
|
||||
///
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum BackendType<'a, T, D> {
|
||||
/// Cloud API (V2).
|
||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(MaybeOwned<'a, url::ApiUrl>, D),
|
||||
}
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
|
||||
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, _) => match &**api {
|
||||
ConsoleBackend::Console(endpoint) => {
|
||||
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
|
||||
}
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ConsoleBackend::Postgres(endpoint) => {
|
||||
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
|
||||
}
|
||||
#[cfg(test)]
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, D> BackendType<'_, T, D> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
||||
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, D> BackendType<'a, T, D> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
Link(c, x) => Link(c, x),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
Link(c, x) => Ok(Link(c, x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComputeCredentials {
|
||||
pub info: ComputeUserInfo,
|
||||
pub keys: ComputeCredentialKeys,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfoNoEndpoint {
|
||||
pub user: RoleName,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfo {
|
||||
pub endpoint: EndpointId,
|
||||
pub user: RoleName,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfo {
|
||||
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
|
||||
self.options.get_cache_key(&self.endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ComputeCredentialKeys {
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
}
|
||||
|
||||
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
||||
// user name
|
||||
type Error = ComputeUserInfoNoEndpoint;
|
||||
|
||||
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
|
||||
match user_info.endpoint_id {
|
||||
None => Err(ComputeUserInfoNoEndpoint {
|
||||
user: user_info.user,
|
||||
options: user_info.options,
|
||||
}),
|
||||
Some(endpoint) => Ok(ComputeUserInfo {
|
||||
endpoint,
|
||||
user: user_info.user,
|
||||
options: user_info.options,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
|
||||
pub struct MaskedIp(IpAddr);
|
||||
|
||||
impl MaskedIp {
|
||||
fn new(value: IpAddr, prefix: u8) -> Self {
|
||||
match value {
|
||||
IpAddr::V4(v4) => Self(IpAddr::V4(
|
||||
Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
|
||||
)),
|
||||
IpAddr::V6(v6) => Self(IpAddr::V6(
|
||||
Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This can't be just per IP because that would limit some PaaS that share IP addresses
|
||||
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
|
||||
|
||||
impl RateBucketInfo {
|
||||
/// All of these are per endpoint-maskedip pair.
|
||||
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
|
||||
///
|
||||
/// First bucket: 1000mcpus total per endpoint-ip pair
|
||||
/// * 4096000 requests per second with 1 hash rounds.
|
||||
/// * 1000 requests per second with 4096 hash rounds.
|
||||
/// * 6.8 requests per second with 600000 hash rounds.
|
||||
pub const DEFAULT_AUTH_SET: [Self; 3] = [
|
||||
Self::new(1000 * 4096, Duration::from_secs(1)),
|
||||
Self::new(600 * 4096, Duration::from_secs(60)),
|
||||
Self::new(300 * 4096, Duration::from_secs(600)),
|
||||
];
|
||||
}
|
||||
|
||||
impl AuthenticationConfig {
|
||||
pub fn check_rate_limit(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
endpoint: &EndpointId,
|
||||
is_cleartext: bool,
|
||||
) -> auth::Result<AuthSecret> {
|
||||
// we have validated the endpoint exists, so let's intern it.
|
||||
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
|
||||
|
||||
// only count the full hash count if password hack or websocket flow.
|
||||
// in other words, if proxy needs to run the hashing
|
||||
let password_weight = if is_cleartext {
|
||||
match &secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => 1,
|
||||
AuthSecret::Scram(s) => s.iterations + 1,
|
||||
}
|
||||
} else {
|
||||
// validating scram takes just 1 hmac_sha_256 operation.
|
||||
1
|
||||
};
|
||||
|
||||
let limit_not_exceeded = self.rate_limiter.check(
|
||||
(
|
||||
endpoint_int,
|
||||
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
|
||||
),
|
||||
password_weight,
|
||||
);
|
||||
|
||||
if !limit_not_exceeded {
|
||||
warn!(
|
||||
enabled = self.rate_limiter_enabled,
|
||||
"rate limiting authentication"
|
||||
);
|
||||
Metrics::get().proxy.requests_auth_rate_limits_total.inc();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.endpoints_auth_rate_limits
|
||||
.get_metric()
|
||||
.measure(endpoint);
|
||||
|
||||
if self.rate_limiter_enabled {
|
||||
return Err(auth::AuthError::too_many_connections());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(secret)
|
||||
}
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
ctx: &RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let (info, unauthenticated_password) = match user_info.try_into() {
|
||||
Err(info) => {
|
||||
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
|
||||
|
||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||
let password = match res.keys {
|
||||
ComputeCredentialKeys::Password(p) => p,
|
||||
_ => unreachable!("password hack should return a password"),
|
||||
};
|
||||
(res.info, Some(password))
|
||||
}
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
|
||||
info!("fetching user's authentication info");
|
||||
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
|
||||
// check allowed list
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
|
||||
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => api.get_role_secret(ctx, &info).await?,
|
||||
};
|
||||
let (cached_entry, secret) = cached_secret.take_value();
|
||||
|
||||
let secret = match secret {
|
||||
Some(secret) => config.check_rate_limit(
|
||||
ctx,
|
||||
config,
|
||||
secret,
|
||||
&info.endpoint,
|
||||
unauthenticated_password.is_some() || allow_cleartext,
|
||||
)?,
|
||||
None => {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
|
||||
}
|
||||
};
|
||||
|
||||
match authenticate_with_secret(
|
||||
ctx,
|
||||
secret,
|
||||
info,
|
||||
client,
|
||||
unauthenticated_password,
|
||||
allow_cleartext,
|
||||
config,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => Ok(keys),
|
||||
Err(e) => {
|
||||
if e.is_auth_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
cached_entry.invalidate();
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate_with_secret(
|
||||
ctx: &RequestMonitoring,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_outcome =
|
||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
}
|
||||
};
|
||||
|
||||
// we have authenticated the password
|
||||
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
|
||||
|
||||
return Ok(ComputeCredentials { info, keys });
|
||||
}
|
||||
|
||||
// -- the remaining flows are self-authenticating --
|
||||
|
||||
// Perform cleartext auth if we're allowed to do that.
|
||||
// Currently, we use it for websocket connections (latency).
|
||||
if allow_cleartext {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
|
||||
}
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Link(_, _) => Some("link".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get username from the credentials.
|
||||
pub fn get_user(&self) -> &str {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => &user_info.user,
|
||||
Link(_, _) => "link",
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
ctx: &RequestMonitoring,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
Console(api, user_info) => {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
&*api,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
BackendType::Console(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url, _) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
let info = link::authenticate(ctx, &url, client).await?;
|
||||
|
||||
BackendType::Link(url, info)
|
||||
}
|
||||
};
|
||||
|
||||
info!("user successfully authenticated");
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
pub async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::IpAddr, sync::Arc, time::Duration};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_protocol::{
|
||||
authentication::sasl::{ChannelBinding, ScramSha256},
|
||||
message::{backend::Message as PgMessage, frontend},
|
||||
};
|
||||
use provider::AuthSecret;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
|
||||
config::AuthenticationConfig,
|
||||
console::{
|
||||
self,
|
||||
provider::{self, CachedAllowedIps, CachedRoleSecret},
|
||||
CachedNodeInfo,
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
proxy::NeonOptions,
|
||||
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
|
||||
scram::{threadpool::ThreadPool, ServerSecret},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
|
||||
use super::{auth_quirks, AuthRateLimiter};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl console::Api for Auth {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
|
||||
{
|
||||
Ok((
|
||||
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
|
||||
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
|
||||
))
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
loop {
|
||||
r.read_buf(&mut *b).await.unwrap();
|
||||
if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
|
||||
break m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn masked_ip() {
|
||||
let ip_a = IpAddr::V4([127, 0, 0, 1].into());
|
||||
let ip_b = IpAddr::V4([127, 0, 0, 2].into());
|
||||
let ip_c = IpAddr::V4([192, 168, 1, 101].into());
|
||||
let ip_d = IpAddr::V4([192, 168, 1, 102].into());
|
||||
let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
|
||||
let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
|
||||
|
||||
assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
|
||||
assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
|
||||
assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
|
||||
assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
|
||||
|
||||
assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
|
||||
assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_auth_rate_limit_set() {
|
||||
// these values used to exceed u32::MAX
|
||||
assert_eq!(
|
||||
RateBucketInfo::DEFAULT_AUTH_SET,
|
||||
[
|
||||
RateBucketInfo {
|
||||
interval: Duration::from_secs(1),
|
||||
max_rpi: 1000 * 4096,
|
||||
},
|
||||
RateBucketInfo {
|
||||
interval: Duration::from_secs(60),
|
||||
max_rpi: 600 * 4096 * 60,
|
||||
},
|
||||
RateBucketInfo {
|
||||
interval: Duration::from_secs(600),
|
||||
max_rpi: 300 * 4096 * 600,
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
for x in RateBucketInfo::DEFAULT_AUTH_SET {
|
||||
let y = x.to_string().parse().unwrap();
|
||||
assert_eq!(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
user: "conrad".into(),
|
||||
endpoint_id: Some("endpoint".into()),
|
||||
options: NeonOptions::default(),
|
||||
};
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
|
||||
|
||||
let mut read = BytesMut::new();
|
||||
|
||||
// server should offer scram
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationSasl(a) => {
|
||||
let options: Vec<&str> = a.mechanisms().collect().unwrap();
|
||||
assert_eq!(options, ["SCRAM-SHA-256"]);
|
||||
}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
|
||||
// client sends client-first-message
|
||||
let mut write = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
|
||||
// server response with server-first-message
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationSaslContinue(a) => {
|
||||
scram.update(a.data()).await.unwrap();
|
||||
}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
|
||||
// client response with client-final-message
|
||||
write.clear();
|
||||
frontend::sasl_response(scram.message(), &mut write).unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
|
||||
// server response with server-final-message
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationSaslFinal(a) => {
|
||||
scram.finish(a.data()).unwrap();
|
||||
}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
});
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
false,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
user: "conrad".into(),
|
||||
endpoint_id: Some("endpoint".into()),
|
||||
options: NeonOptions::default(),
|
||||
};
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut read = BytesMut::new();
|
||||
let mut write = BytesMut::new();
|
||||
|
||||
// server should offer cleartext
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationCleartextPassword => {}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
|
||||
// client responds with password
|
||||
write.clear();
|
||||
frontend::password_message(b"my-secret-password", &mut write).unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
user: "conrad".into(),
|
||||
endpoint_id: None,
|
||||
options: NeonOptions::default(),
|
||||
};
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut read = BytesMut::new();
|
||||
|
||||
// server should offer cleartext
|
||||
match read_message(&mut client, &mut read).await {
|
||||
PgMessage::AuthenticationCleartextPassword => {}
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
|
||||
// client responds with password
|
||||
let mut write = BytesMut::new();
|
||||
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
|
||||
.unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let creds = auth_quirks(
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(creds.info.endpoint, "my-endpoint");
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
72
proxy/core/src/auth/backend/classic.rs
Normal file
72
proxy/core/src/auth/backend/classic.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use super::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::{
|
||||
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
|
||||
compute,
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
sasl,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &RequestMonitoring,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
async {
|
||||
|
||||
flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
error
|
||||
})?.authenticate().await.map_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
error
|
||||
})
|
||||
}
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::AuthError::user_timeout(e)
|
||||
})??;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*creds.user));
|
||||
}
|
||||
};
|
||||
|
||||
compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
info: creds,
|
||||
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
scram_keys,
|
||||
)),
|
||||
})
|
||||
}
|
||||
90
proxy/core/src/auth/backend/hacks.rs
Normal file
90
proxy/core/src/auth/backend/hacks.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use super::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||
};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
|
||||
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn authenticate_cleartext(
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_flow = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
})
|
||||
.await?;
|
||||
drop(paused);
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
// If we're here, we already received the password in the first message.
|
||||
// Scram protocol will be executed on the proxy side.
|
||||
let auth_outcome = auth_flow.authenticate().await?;
|
||||
|
||||
let keys = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ComputeCredentials { info, keys })
|
||||
}
|
||||
|
||||
/// Workaround for clients which don't provide an endpoint (project) name.
|
||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||
pub async fn password_hack_no_authentication(
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
.get_password()
|
||||
.await?;
|
||||
|
||||
info!(project = &*payload.endpoint, "received missing parameter");
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: info.user,
|
||||
options: info.options,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password(payload.password),
|
||||
})
|
||||
}
|
||||
554
proxy/core/src/auth/backend/jwt.rs
Normal file
554
proxy/core/src/auth/backend/jwt.rs
Normal file
@@ -0,0 +1,554 @@
|
||||
use std::{future::Future, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use dashmap::DashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use signature::Verifier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt};
|
||||
|
||||
// TODO(conrad): make these configurable.
|
||||
const MIN_RENEW: Duration = Duration::from_secs(30);
|
||||
const AUTO_RENEW: Duration = Duration::from_secs(300);
|
||||
const MAX_RENEW: Duration = Duration::from_secs(3600);
|
||||
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// How to get the JWT auth rules
|
||||
pub trait FetchAuthRules: Clone + Send + Sync + 'static {
|
||||
fn fetch_auth_rules(&self) -> impl Future<Output = anyhow::Result<AuthRules>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FetchAuthRulesFromCplane {
|
||||
#[allow(dead_code)]
|
||||
endpoint: EndpointIdInt,
|
||||
}
|
||||
|
||||
impl FetchAuthRules for FetchAuthRulesFromCplane {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
|
||||
Err(anyhow::anyhow!("not yet implemented"))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthRules {
|
||||
jwks_urls: Vec<url::Url>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct JwkCache {
|
||||
client: reqwest::Client,
|
||||
|
||||
map: DashMap<EndpointIdInt, Arc<JwkCacheEntryLock>>,
|
||||
}
|
||||
|
||||
pub struct JwkCacheEntryLock {
|
||||
cached: ArcSwapOption<JwkCacheEntry>,
|
||||
lookup: tokio::sync::Semaphore,
|
||||
}
|
||||
|
||||
impl Default for JwkCacheEntryLock {
|
||||
fn default() -> Self {
|
||||
JwkCacheEntryLock {
|
||||
cached: ArcSwapOption::empty(),
|
||||
lookup: tokio::sync::Semaphore::new(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JwkCacheEntry {
|
||||
/// Should refetch at least every hour to verify when old keys have been removed.
|
||||
/// Should refetch when new key IDs are seen only every 5 minutes or so
|
||||
last_retrieved: Instant,
|
||||
|
||||
/// cplane will return multiple JWKs urls that we need to scrape.
|
||||
key_sets: ahash::HashMap<url::Url, jose_jwk::JwkSet>,
|
||||
}
|
||||
|
||||
impl JwkCacheEntryLock {
|
||||
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
|
||||
JwkRenewalPermit::acquire_permit(self).await
|
||||
}
|
||||
|
||||
fn try_acquire_permit<'a>(self: &'a Arc<Self>) -> Option<JwkRenewalPermit<'a>> {
|
||||
JwkRenewalPermit::try_acquire_permit(self)
|
||||
}
|
||||
|
||||
async fn renew_jwks<F: FetchAuthRules>(
|
||||
&self,
|
||||
_permit: JwkRenewalPermit<'_>,
|
||||
client: &reqwest::Client,
|
||||
auth_rules: &F,
|
||||
) -> anyhow::Result<Arc<JwkCacheEntry>> {
|
||||
// double check that no one beat us to updating the cache.
|
||||
let now = Instant::now();
|
||||
let guard = self.cached.load_full();
|
||||
if let Some(cached) = guard {
|
||||
let last_update = now.duration_since(cached.last_retrieved);
|
||||
if last_update < Duration::from_secs(300) {
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
|
||||
let rules = auth_rules.fetch_auth_rules().await?;
|
||||
let mut key_sets = ahash::HashMap::with_capacity_and_hasher(
|
||||
rules.jwks_urls.len(),
|
||||
ahash::RandomState::new(),
|
||||
);
|
||||
// TODO(conrad): run concurrently
|
||||
for url in rules.jwks_urls {
|
||||
let req = client.get(url.clone());
|
||||
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
|
||||
match req.send().await.and_then(|r| r.error_for_status()) {
|
||||
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
|
||||
// I expect these failures would be quite sparse.
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
|
||||
Ok(r) => {
|
||||
let resp: http::Response<reqwest::Body> = r.into();
|
||||
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
|
||||
resp.into_body(),
|
||||
MAX_JWK_BODY_SIZE,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
|
||||
Ok(jwks) => {
|
||||
key_sets.insert(url, jwks);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let entry = Arc::new(JwkCacheEntry {
|
||||
last_retrieved: now,
|
||||
key_sets,
|
||||
});
|
||||
self.cached.swap(Some(Arc::clone(&entry)));
|
||||
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
async fn get_or_update_jwk_cache<F: FetchAuthRules>(
|
||||
self: &Arc<Self>,
|
||||
client: &reqwest::Client,
|
||||
fetch: &F,
|
||||
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
|
||||
let now = Instant::now();
|
||||
let guard = self.cached.load_full();
|
||||
|
||||
// if we have no cached JWKs, try and get some
|
||||
let Some(cached) = guard else {
|
||||
let permit = self.acquire_permit().await;
|
||||
return self.renew_jwks(permit, client, fetch).await;
|
||||
};
|
||||
|
||||
let last_update = now.duration_since(cached.last_retrieved);
|
||||
|
||||
// check if the cached JWKs need updating.
|
||||
if last_update > MAX_RENEW {
|
||||
let permit = self.acquire_permit().await;
|
||||
|
||||
// it's been too long since we checked the keys. wait for them to update.
|
||||
return self.renew_jwks(permit, client, fetch).await;
|
||||
}
|
||||
|
||||
// every 5 minutes we should spawn a job to eagerly update the token.
|
||||
if last_update > AUTO_RENEW {
|
||||
if let Some(permit) = self.try_acquire_permit() {
|
||||
tracing::debug!("JWKs should be renewed. Renewal permit acquired");
|
||||
let permit = permit.into_owned();
|
||||
let entry = self.clone();
|
||||
let client = client.clone();
|
||||
let fetch = fetch.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
|
||||
tracing::warn!(error=?e, "could not fetch JWKs in background job");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
tracing::debug!("JWKs should be renewed. Renewal permit already taken, skipping");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
|
||||
async fn check_jwt<F: FetchAuthRules>(
|
||||
self: &Arc<Self>,
|
||||
jwt: String,
|
||||
client: &reqwest::Client,
|
||||
fetch: &F,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// JWT compact form is defined to be
|
||||
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
|
||||
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
|
||||
|
||||
let (header_payload, signature) = jwt
|
||||
.rsplit_once(".")
|
||||
.context("not a valid compact JWT encoding")?;
|
||||
let (header, _payload) = header_payload
|
||||
.split_once(".")
|
||||
.context("not a valid compact JWT encoding")?;
|
||||
|
||||
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
|
||||
.context("not a valid compact JWT encoding")?;
|
||||
let header = serde_json::from_slice::<JWTHeader>(&header)
|
||||
.context("not a valid compact JWT encoding")?;
|
||||
|
||||
ensure!(header.typ == "JWT");
|
||||
let kid = header.kid.context("missing key id")?;
|
||||
|
||||
let mut guard = self.get_or_update_jwk_cache(client, fetch).await?;
|
||||
|
||||
// get the key from the JWKs if possible. If not, wait for the keys to update.
|
||||
let jwk = loop {
|
||||
let jwk = guard
|
||||
.key_sets
|
||||
.values()
|
||||
.flat_map(|jwks| &jwks.keys)
|
||||
.find(|jwk| jwk.prm.kid.as_deref() == Some(kid));
|
||||
|
||||
match jwk {
|
||||
Some(jwk) => break jwk,
|
||||
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
|
||||
let permit = self.acquire_permit().await;
|
||||
guard = self.renew_jwks(permit, client, fetch).await?;
|
||||
}
|
||||
_ => {
|
||||
bail!("jwk not found");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ensure!(
|
||||
jwk.is_supported(&header.alg),
|
||||
"signature algorithm not supported"
|
||||
);
|
||||
|
||||
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
|
||||
.context("not a valid compact JWT encoding")?;
|
||||
match &jwk.key {
|
||||
jose_jwk::Key::Ec(key) => {
|
||||
verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
|
||||
}
|
||||
jose_jwk::Key::Rsa(key) => {
|
||||
verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
|
||||
}
|
||||
key => bail!("unsupported key type {key:?}"),
|
||||
};
|
||||
|
||||
// TODO(conrad): verify iss, exp, nbf, etc...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl JwkCache {
|
||||
pub async fn check_jwt(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
jwt: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// try with just a read lock first
|
||||
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
|
||||
let entry = match entry {
|
||||
Some(entry) => entry,
|
||||
None => {
|
||||
// acquire a write lock after to insert.
|
||||
let entry = self.map.entry(endpoint).or_default();
|
||||
Arc::clone(&*entry)
|
||||
}
|
||||
};
|
||||
|
||||
let fetch = FetchAuthRulesFromCplane { endpoint };
|
||||
entry.check_jwt(jwt, &self.client, &fetch).await
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_ec_signature(data: &[u8], sig: &[u8], key: &jose_jwk::Ec) -> anyhow::Result<()> {
|
||||
use ecdsa::Signature;
|
||||
use signature::Verifier;
|
||||
|
||||
match key.crv {
|
||||
jose_jwk::EcCurves::P256 => {
|
||||
let pk =
|
||||
p256::PublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid P256 key"))?;
|
||||
let key = p256::ecdsa::VerifyingKey::from(&pk);
|
||||
let sig = Signature::from_slice(sig)?;
|
||||
key.verify(data, &sig)?;
|
||||
}
|
||||
key => bail!("unsupported ec key type {key:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_rsa_signature(
|
||||
data: &[u8],
|
||||
sig: &[u8],
|
||||
key: &jose_jwk::Rsa,
|
||||
alg: &Option<jose_jwa::Algorithm>,
|
||||
) -> anyhow::Result<()> {
|
||||
use jose_jwa::{Algorithm, Signing};
|
||||
use rsa::{
|
||||
pkcs1v15::{Signature, VerifyingKey},
|
||||
RsaPublicKey,
|
||||
};
|
||||
|
||||
let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
|
||||
|
||||
match alg {
|
||||
Some(Algorithm::Signing(Signing::Rs256)) => {
|
||||
let key = VerifyingKey::<sha2::Sha256>::new(key);
|
||||
let sig = Signature::try_from(sig)?;
|
||||
key.verify(data, &sig)?;
|
||||
}
|
||||
_ => bail!("invalid RSA signing algorithm"),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
|
||||
#[derive(serde::Deserialize, serde::Serialize)]
|
||||
struct JWTHeader<'a> {
|
||||
/// must be "JWT"
|
||||
typ: &'a str,
|
||||
/// must be a supported alg
|
||||
alg: jose_jwa::Algorithm,
|
||||
/// key id, must be provided for our usecase
|
||||
kid: Option<&'a str>,
|
||||
}
|
||||
|
||||
struct JwkRenewalPermit<'a> {
|
||||
inner: Option<JwkRenewalPermitInner<'a>>,
|
||||
}
|
||||
|
||||
enum JwkRenewalPermitInner<'a> {
|
||||
Owned(Arc<JwkCacheEntryLock>),
|
||||
Borrowed(&'a Arc<JwkCacheEntryLock>),
|
||||
}
|
||||
|
||||
impl JwkRenewalPermit<'_> {
|
||||
fn into_owned(mut self) -> JwkRenewalPermit<'static> {
|
||||
JwkRenewalPermit {
|
||||
inner: self.inner.take().map(JwkRenewalPermitInner::into_owned),
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit {
|
||||
match from.lookup.acquire().await {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
JwkRenewalPermit {
|
||||
inner: Some(JwkRenewalPermitInner::Borrowed(from)),
|
||||
}
|
||||
}
|
||||
Err(_) => panic!("semaphore should not be closed"),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit> {
|
||||
match from.lookup.try_acquire() {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
Some(JwkRenewalPermit {
|
||||
inner: Some(JwkRenewalPermitInner::Borrowed(from)),
|
||||
})
|
||||
}
|
||||
Err(tokio::sync::TryAcquireError::NoPermits) => None,
|
||||
Err(tokio::sync::TryAcquireError::Closed) => panic!("semaphore should not be closed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JwkRenewalPermitInner<'_> {
|
||||
fn into_owned(self) -> JwkRenewalPermitInner<'static> {
|
||||
match self {
|
||||
JwkRenewalPermitInner::Owned(p) => JwkRenewalPermitInner::Owned(p),
|
||||
JwkRenewalPermitInner::Borrowed(p) => JwkRenewalPermitInner::Owned(Arc::clone(p)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for JwkRenewalPermit<'_> {
|
||||
fn drop(&mut self) {
|
||||
let entry = match &self.inner {
|
||||
None => return,
|
||||
Some(JwkRenewalPermitInner::Owned(p)) => p,
|
||||
Some(JwkRenewalPermitInner::Borrowed(p)) => *p,
|
||||
};
|
||||
entry.lookup.add_permits(1);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
|
||||
|
||||
use base64::URL_SAFE_NO_PAD;
|
||||
use bytes::Bytes;
|
||||
use http::Response;
|
||||
use http_body_util::Full;
|
||||
use hyper1::service::service_fn;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rand::rngs::OsRng;
|
||||
use signature::Signer;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
|
||||
let sk = p256::SecretKey::random(&mut OsRng);
|
||||
let pk = sk.public_key().into();
|
||||
let jwk = jose_jwk::Jwk {
|
||||
key: jose_jwk::Key::Ec(pk),
|
||||
prm: jose_jwk::Parameters {
|
||||
kid: Some(kid),
|
||||
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
(sk, jwk)
|
||||
}
|
||||
|
||||
fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) {
|
||||
let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
|
||||
let pk = sk.to_public_key().into();
|
||||
let jwk = jose_jwk::Jwk {
|
||||
key: jose_jwk::Key::Rsa(pk),
|
||||
prm: jose_jwk::Parameters {
|
||||
kid: Some(kid),
|
||||
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
(sk, jwk)
|
||||
}
|
||||
|
||||
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
|
||||
let header = JWTHeader {
|
||||
typ: "JWT",
|
||||
alg: jose_jwa::Algorithm::Signing(sig),
|
||||
kid: Some(&kid),
|
||||
};
|
||||
let body = typed_json::json! {{
|
||||
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
|
||||
}};
|
||||
|
||||
let header =
|
||||
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{header}.{body}")
|
||||
}
|
||||
|
||||
fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
|
||||
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
|
||||
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
|
||||
use rsa::pkcs1v15::SigningKey;
|
||||
use rsa::signature::SignatureEncoding;
|
||||
|
||||
let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
|
||||
let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn renew() {
|
||||
let (rs1, jwk1) = new_rsa_jwk("1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk("2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("3".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("4".into());
|
||||
|
||||
let jwt1 = new_rsa_jwt("1".into(), rs1);
|
||||
let jwt2 = new_rsa_jwt("2".into(), rs2);
|
||||
let jwt3 = new_ec_jwt("3".into(), ec1);
|
||||
let jwt4 = new_ec_jwt("4".into(), ec2);
|
||||
|
||||
let foo_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
|
||||
let service = service_fn(move |req| {
|
||||
let foo_jwks = foo_jwks.clone();
|
||||
let bar_jwks = bar_jwks.clone();
|
||||
async move {
|
||||
let jwks = match req.uri().path() {
|
||||
"/foo" => &foo_jwks,
|
||||
"/bar" => &bar_jwks,
|
||||
_ => {
|
||||
return Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new()));
|
||||
}
|
||||
};
|
||||
let body = serde_json::to_vec(jwks).unwrap();
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.body(Full::new(Bytes::from(body)))
|
||||
}
|
||||
});
|
||||
|
||||
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
|
||||
let server = hyper1::server::conn::http1::Builder::new();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (s, _) = listener.accept().await.unwrap();
|
||||
let serve = server.serve_connection(TokioIo::new(s), service.clone());
|
||||
tokio::spawn(serve.into_future());
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Fetch(SocketAddr);
|
||||
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
|
||||
Ok(AuthRules {
|
||||
jwks_urls: vec![
|
||||
format!("http://{}/foo", self.0).parse().unwrap(),
|
||||
format!("http://{}/bar", self.0).parse().unwrap(),
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
|
||||
|
||||
jwk_cache
|
||||
.check_jwt(jwt1, &client, &Fetch(addr))
|
||||
.await
|
||||
.unwrap();
|
||||
jwk_cache
|
||||
.check_jwt(jwt2, &client, &Fetch(addr))
|
||||
.await
|
||||
.unwrap();
|
||||
jwk_cache
|
||||
.check_jwt(jwt3, &client, &Fetch(addr))
|
||||
.await
|
||||
.unwrap();
|
||||
jwk_cache
|
||||
.check_jwt(jwt4, &client, &Fetch(addr))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
127
proxy/core/src/auth/backend/link.rs
Normal file
127
proxy/core/src/auth/backend/link.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use crate::{
|
||||
auth, compute,
|
||||
console::{self, provider::NodeInfo},
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
stream::PqStream,
|
||||
waiters,
|
||||
};
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LinkAuthError {
|
||||
#[error(transparent)]
|
||||
WaiterRegister(#[from] waiters::RegisterError),
|
||||
|
||||
#[error(transparent)]
|
||||
WaiterWait(#[from] waiters::WaitError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for LinkAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Internal error".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for LinkAuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
|
||||
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
|
||||
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"Welcome to Neon!\n",
|
||||
"Authenticate by visiting:\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
],
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &RequestMonitoring,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Web);
|
||||
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
// just try again.
|
||||
let (psql_session_id, waiter) = loop {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
|
||||
match console::mgmt::get_waiter(&psql_session_id) {
|
||||
Ok(waiter) => break (psql_session_id, waiter),
|
||||
Err(_e) => continue,
|
||||
}
|
||||
};
|
||||
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
let db_info = waiter.await.map_err(LinkAuthError::from)?;
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
ctx.set_dbname(db_info.dbname.into());
|
||||
ctx.set_user(db_info.user.into());
|
||||
ctx.set_project(db_info.aux.clone());
|
||||
info!("woken up a compute node");
|
||||
|
||||
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
|
||||
// while direct connections do not. Once we migrate to pg_sni_proxy
|
||||
// everywhere, we can remove this.
|
||||
if db_info.host.contains("--") {
|
||||
// we need TLS connection with SNI info to properly route it
|
||||
config.ssl_mode(SslMode::Require);
|
||||
} else {
|
||||
config.ssl_mode(SslMode::Disable);
|
||||
}
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok(NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux,
|
||||
allow_self_signed_compute: false, // caller may override
|
||||
})
|
||||
}
|
||||
533
proxy/core/src/auth/credentials.rs
Normal file
533
proxy/core/src/auth/credentials.rs
Normal file
@@ -0,0 +1,533 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param,
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::{Metrics, SniKind},
|
||||
proxy::NeonOptions,
|
||||
serverless::SERVERLESS_DRIVER_SNI,
|
||||
EndpointId, RoleName,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::{collections::HashSet, net::IpAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub enum ComputeUserInfoParseError {
|
||||
#[error("Parameter '{0}' is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
#[error(
|
||||
"Inconsistent project name inferred from \
|
||||
SNI ('{}') and project option ('{}').",
|
||||
.domain, .option,
|
||||
)]
|
||||
InconsistentProjectNames {
|
||||
domain: EndpointId,
|
||||
option: EndpointId,
|
||||
},
|
||||
|
||||
#[error(
|
||||
"Common name inferred from SNI ('{}') is not known",
|
||||
.cn,
|
||||
)]
|
||||
UnknownCommonName { cn: String },
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(EndpointId),
|
||||
}
|
||||
|
||||
impl UserFacingError for ComputeUserInfoParseError {}
|
||||
|
||||
impl ReportableError for ComputeUserInfoParseError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ComputeUserInfoMaybeEndpoint {
|
||||
pub user: RoleName,
|
||||
pub endpoint_id: Option<EndpointId>,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
#[inline]
|
||||
pub fn endpoint(&self) -> Option<&str> {
|
||||
self.endpoint_id.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_sni(
|
||||
sni: &str,
|
||||
common_names: &HashSet<String>,
|
||||
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
|
||||
let Some((subdomain, common_name)) = sni.split_once('.') else {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
|
||||
};
|
||||
if !common_names.contains(common_name) {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName {
|
||||
cn: common_name.into(),
|
||||
});
|
||||
}
|
||||
if subdomain == SERVERLESS_DRIVER_SNI {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(EndpointId::from(subdomain)))
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
pub fn parse(
|
||||
ctx: &RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_names: Option<&HashSet<String>>,
|
||||
) -> Result<Self, ComputeUserInfoParseError> {
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user: RoleName = get_param("user")?.into();
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let endpoint_option = params
|
||||
.options_raw()
|
||||
.and_then(|options| {
|
||||
// We support both `project` (deprecated) and `endpoint` options for backward compatibility.
|
||||
// However, if both are present, we don't exactly know which one to use.
|
||||
// Therefore we require that only one of them is present.
|
||||
options
|
||||
.filter_map(parse_endpoint_param)
|
||||
.at_most_one()
|
||||
.ok()?
|
||||
})
|
||||
.map(|name| name.into());
|
||||
|
||||
let endpoint_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
endpoint_sni(sni_str, cn)?
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let endpoint = match (endpoint_option, endpoint_from_domain) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(option), Some(domain)) if option != domain => {
|
||||
Some(Err(InconsistentProjectNames { domain, option }))
|
||||
}
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
|
||||
false => Err(MalformedProjectName(name)),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
.transpose()?;
|
||||
|
||||
if let Some(ep) = &endpoint {
|
||||
ctx.set_endpoint_id(ep.clone());
|
||||
}
|
||||
|
||||
let metrics = Metrics::get();
|
||||
info!(%user, "credentials");
|
||||
if sni.is_some() {
|
||||
info!("Connection with sni");
|
||||
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
|
||||
} else if endpoint.is_some() {
|
||||
metrics
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniKind::NoSni);
|
||||
info!("Connection without sni");
|
||||
} else {
|
||||
metrics
|
||||
.proxy
|
||||
.accepted_connections_by_sni
|
||||
.inc(SniKind::PasswordHack);
|
||||
info!("Connection with password hack");
|
||||
}
|
||||
|
||||
let options = NeonOptions::parse_params(params);
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
endpoint_id: endpoint,
|
||||
options,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
|
||||
ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum IpPattern {
|
||||
Subnet(ipnet::IpNet),
|
||||
Range(IpAddr, IpAddr),
|
||||
Single(IpAddr),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for IpPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct StrVisitor;
|
||||
impl<'de> serde::de::Visitor<'de> for StrVisitor {
|
||||
type Value = IpPattern;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(parse_ip_pattern(v).unwrap_or_else(|e| {
|
||||
warn!("Cannot parse ip pattern {v}: {e}");
|
||||
IpPattern::None
|
||||
}))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_str(StrVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for IpPattern {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
parse_ip_pattern(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_pattern(pattern: &str) -> anyhow::Result<IpPattern> {
|
||||
if pattern.contains('/') {
|
||||
let subnet: ipnet::IpNet = pattern.parse()?;
|
||||
return Ok(IpPattern::Subnet(subnet));
|
||||
}
|
||||
if let Some((start, end)) = pattern.split_once('-') {
|
||||
let start: IpAddr = start.parse()?;
|
||||
let end: IpAddr = end.parse()?;
|
||||
return Ok(IpPattern::Range(start, end));
|
||||
}
|
||||
let addr: IpAddr = pattern.parse()?;
|
||||
Ok(IpPattern::Single(addr))
|
||||
}
|
||||
|
||||
fn check_ip(ip: &IpAddr, pattern: &IpPattern) -> bool {
|
||||
match pattern {
|
||||
IpPattern::Subnet(subnet) => subnet.contains(ip),
|
||||
IpPattern::Range(start, end) => start <= ip && ip <= end,
|
||||
IpPattern::Single(addr) => addr == ip,
|
||||
IpPattern::None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn project_name_valid(name: &str) -> bool {
|
||||
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
#[test]
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_excessive() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("database", "world"), // should be ignored
|
||||
("foo", "bar"), // should be ignored
|
||||
]);
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_sni() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let sni = Some("foo.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
|
||||
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_project_from_options() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoint_from_options() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("options", "-ckey=1 endpoint=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_three_endpoints_from_options() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
(
|
||||
"options",
|
||||
"-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off",
|
||||
),
|
||||
]);
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
|
||||
]);
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_projects_identical() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
|
||||
|
||||
let sni = Some("baz.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multi_common_names() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_projects_different() {
|
||||
let options =
|
||||
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
|
||||
|
||||
let sni = Some("second.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
assert_eq!(domain, "second");
|
||||
}
|
||||
_ => panic!("bad error: {err:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_inconsistent_sni() {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
UnknownCommonName { cn } => {
|
||||
assert_eq!(cn, "localhost");
|
||||
}
|
||||
_ => panic!("bad error: {err:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neon_options() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([
|
||||
("user", "john_doe"),
|
||||
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
|
||||
]);
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
|
||||
assert_eq!(
|
||||
user_info.options.get_cache_key("project"),
|
||||
"project endpoint_type:read_write lsn:0/2"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_peer_addr_is_in_list() {
|
||||
fn check(v: serde_json::Value) -> bool {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let ip_list: Vec<IpPattern> = serde_json::from_value(v).unwrap();
|
||||
check_peer_addr_is_in_list(&peer_addr, &ip_list)
|
||||
}
|
||||
|
||||
assert!(check(json!([])));
|
||||
assert!(check(json!(["127.0.0.1"])));
|
||||
assert!(!check(json!(["8.8.8.8"])));
|
||||
// If there is an incorrect address, it will be skipped.
|
||||
assert!(check(json!(["88.8.8", "127.0.0.1"])));
|
||||
}
|
||||
#[test]
|
||||
fn test_parse_ip_v4() -> anyhow::Result<()> {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
// Ok
|
||||
assert_eq!(parse_ip_pattern("127.0.0.1")?, IpPattern::Single(peer_addr));
|
||||
assert_eq!(
|
||||
parse_ip_pattern("127.0.0.1/31")?,
|
||||
IpPattern::Subnet(ipnet::IpNet::new(peer_addr, 31)?)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_ip_pattern("0.0.0.0-200.0.1.2")?,
|
||||
IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
|
||||
);
|
||||
|
||||
// Error
|
||||
assert!(parse_ip_pattern("300.0.1.2").is_err());
|
||||
assert!(parse_ip_pattern("30.1.2").is_err());
|
||||
assert!(parse_ip_pattern("127.0.0.1/33").is_err());
|
||||
assert!(parse_ip_pattern("127.0.0.1-127.0.3").is_err());
|
||||
assert!(parse_ip_pattern("1234.0.0.1-127.0.3.0").is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_ipv4() -> anyhow::Result<()> {
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let peer_addr_next = IpAddr::from([127, 0, 0, 2]);
|
||||
let peer_addr_prev = IpAddr::from([127, 0, 0, 0]);
|
||||
// Success
|
||||
assert!(check_ip(&peer_addr, &IpPattern::Single(peer_addr)));
|
||||
assert!(check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_prev, 31)?)
|
||||
));
|
||||
assert!(check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 30)?)
|
||||
));
|
||||
assert!(check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), IpAddr::from([200, 0, 1, 2]))
|
||||
));
|
||||
assert!(check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Range(peer_addr, peer_addr)
|
||||
));
|
||||
|
||||
// Not success
|
||||
assert!(!check_ip(&peer_addr, &IpPattern::Single(peer_addr_prev)));
|
||||
assert!(!check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Subnet(ipnet::IpNet::new(peer_addr_next, 31)?)
|
||||
));
|
||||
assert!(!check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Range(IpAddr::from([0, 0, 0, 0]), peer_addr_prev)
|
||||
));
|
||||
assert!(!check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Range(peer_addr_next, IpAddr::from([128, 0, 0, 0]))
|
||||
));
|
||||
// There is no check that for range start <= end. But it's fine as long as for all this cases the result is false.
|
||||
assert!(!check_ip(
|
||||
&peer_addr,
|
||||
&IpPattern::Range(peer_addr, peer_addr_prev)
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
226
proxy/core/src/auth/flow.rs
Normal file
226
proxy/core/src/auth/flow.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::{io, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
|
||||
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
pub struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword {
|
||||
pub pool: Arc<ThreadPool>,
|
||||
pub endpoint: EndpointIdInt,
|
||||
pub secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let payload = PasswordHackPayload::parse(password)
|
||||
// If we ended up here and the payload is malformed, it means that
|
||||
// the user neither enabled SNI nor resorted to any other method
|
||||
// for passing the project name we rely on. We should show them
|
||||
// the most helpful error message and point to the documentation.
|
||||
.ok_or(AuthErrorImpl::MissingEndpointName)?;
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let outcome = validate_password_and_exchange(
|
||||
&self.state.pool,
|
||||
self.state.endpoint,
|
||||
password,
|
||||
self.state.secret,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
|
||||
_ => {}
|
||||
}
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
// test only
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_owned(),
|
||||
)))
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
|
||||
};
|
||||
|
||||
let keys = crate::compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: scram_secret.server_key.as_bytes(),
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
|
||||
tokio_postgres::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
121
proxy/core/src/auth/password_hack.rs
Normal file
121
proxy/core/src/auth/password_hack.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
//! Payload for ad hoc authentication method for clients that don't support SNI.
|
||||
//! See the `impl` for [`super::backend::BackendType<ClientCredentials>`].
|
||||
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
|
||||
|
||||
use bstr::ByteSlice;
|
||||
|
||||
use crate::EndpointId;
|
||||
|
||||
pub struct PasswordHackPayload {
|
||||
pub endpoint: EndpointId,
|
||||
pub password: Vec<u8>,
|
||||
}
|
||||
|
||||
impl PasswordHackPayload {
|
||||
pub fn parse(bytes: &[u8]) -> Option<Self> {
|
||||
// The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
|
||||
let separators = [";", "$"];
|
||||
for sep in separators {
|
||||
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
|
||||
let endpoint = endpoint.to_str().ok()?;
|
||||
return Some(Self {
|
||||
endpoint: parse_endpoint_param(endpoint)?.into(),
|
||||
password: password.to_owned(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_endpoint_param(bytes: &str) -> Option<&str> {
|
||||
bytes
|
||||
.strip_prefix("project=")
|
||||
.or_else(|| bytes.strip_prefix("endpoint="))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoint_param_fn() {
|
||||
let input = "";
|
||||
assert!(parse_endpoint_param(input).is_none());
|
||||
|
||||
let input = "project=";
|
||||
assert_eq!(parse_endpoint_param(input), Some(""));
|
||||
|
||||
let input = "project=foobar";
|
||||
assert_eq!(parse_endpoint_param(input), Some("foobar"));
|
||||
|
||||
let input = "endpoint=";
|
||||
assert_eq!(parse_endpoint_param(input), Some(""));
|
||||
|
||||
let input = "endpoint=foobar";
|
||||
assert_eq!(parse_endpoint_param(input), Some("foobar"));
|
||||
|
||||
let input = "other_option=foobar";
|
||||
assert!(parse_endpoint_param(input).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_project() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"project=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"project=;";
|
||||
let payload: PasswordHackPayload =
|
||||
PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"project=foobar;pass;word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass;word");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_endpoint() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=;";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"endpoint=foobar;pass;word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass;word");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_hack_payload_dollar() {
|
||||
let bytes = b"";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=";
|
||||
assert!(PasswordHackPayload::parse(bytes).is_none());
|
||||
|
||||
let bytes = b"endpoint=$";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "");
|
||||
assert_eq!(payload.password, b"");
|
||||
|
||||
let bytes = b"endpoint=foobar$pass$word";
|
||||
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
|
||||
assert_eq!(payload.endpoint, "foobar");
|
||||
assert_eq!(payload.password, b"pass$word");
|
||||
}
|
||||
}
|
||||
296
proxy/core/src/bin/pg_sni_router.rs
Normal file
296
proxy/core/src/bin/pg_sni_router.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
/// A stand-alone program that routes connections, e.g. from
|
||||
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
|
||||
///
|
||||
/// This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
/// the outside. Similar to an ingress controller for HTTPS.
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
use clap::Arg;
|
||||
use futures::TryFutureExt;
|
||||
use proxy::stream::{PqStream, Stream};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
use tracing::{error, info, Instrument};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
fn cli() -> clap::Command {
|
||||
clap::Command::new("Neon proxy/router")
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::new("listen")
|
||||
.short('l')
|
||||
.long("listen")
|
||||
.help("listen for incoming client connections on ip:port")
|
||||
.default_value("127.0.0.1:4432"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-key")
|
||||
.short('k')
|
||||
.long("tls-key")
|
||||
.help("path to TLS key for client postgres connections")
|
||||
.required(true),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-cert")
|
||||
.short('c')
|
||||
.long("tls-cert")
|
||||
.help("path to TLS cert for client postgres connections")
|
||||
.required(true),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("dest")
|
||||
.short('d')
|
||||
.long("destination")
|
||||
.help("append this domain zone to the SNI hostname to get the destination address")
|
||||
.required(true),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
|
||||
let args = cli().get_matches();
|
||||
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
let mut keys =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
PrivateKeyDer::Pkcs8(
|
||||
keys.pop()
|
||||
.unwrap()
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?,
|
||||
)
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain: Vec<_> = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
|
||||
&rustls::version::TLS13,
|
||||
&rustls::version::TLS12,
|
||||
])
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
(tls_config, tls_server_end_point)
|
||||
}
|
||||
_ => bail!("tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
// Start listening for incoming client connections
|
||||
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
|
||||
info!("Starting sni router on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let main = tokio::spawn(task_main(
|
||||
Arc::new(destination),
|
||||
tls_config,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
|
||||
|
||||
// the signal task cant ever succeed.
|
||||
// the main task can error, or can succeed on cancellation.
|
||||
// we want to immediately exit on either of these cases
|
||||
let signal = match futures::future::select(signals_task, main).await {
|
||||
Either::Left((res, _)) => proxy::flatten_err(res)?,
|
||||
Either::Right((res, _)) => return proxy::flatten_err(res),
|
||||
};
|
||||
|
||||
// maintenance tasks return `Infallible` success values, this is an impossible value
|
||||
// so this match statically ensures that there are no possibilities for that value
|
||||
match signal {}
|
||||
}
|
||||
|
||||
async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let tls_config = Arc::clone(&tls_config);
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
|
||||
connections.spawn(
|
||||
async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr.ip(),
|
||||
proxy::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
})
|
||||
.instrument(tracing::info_span!("handle_client", ?session_id)),
|
||||
);
|
||||
}
|
||||
|
||||
connections.close();
|
||||
drop(listener);
|
||||
|
||||
connections.wait().await;
|
||||
|
||||
info!("all client connections have finished");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestMonitoring,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::*;
|
||||
|
||||
match msg {
|
||||
SslRequest { direct: false } => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empy.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
?unexpected,
|
||||
"unexpected startup packet, rejecting connection"
|
||||
);
|
||||
stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
ctx: RequestMonitoring,
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
||||
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
|
||||
let dest: Vec<&str> = sni
|
||||
.split_once('.')
|
||||
.context("invalid SNI")?
|
||||
.0
|
||||
.splitn(3, "--")
|
||||
.collect();
|
||||
let port = dest[2].parse::<u16>().context("invalid port")?;
|
||||
let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
|
||||
|
||||
info!("destination: {}", destination);
|
||||
|
||||
let mut client = tokio::net::TcpStream::connect(destination).await?;
|
||||
|
||||
// doesn't yet matter as pg-sni-router doesn't report analytics logs
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
730
proxy/core/src/bin/proxy.rs
Normal file
730
proxy/core/src/bin/proxy.rs
Normal file
@@ -0,0 +1,730 @@
|
||||
use aws_config::environment::EnvironmentVariableCredentialsProvider;
|
||||
use aws_config::imds::credentials::ImdsCredentialsProvider;
|
||||
use aws_config::meta::credentials::CredentialsProviderChain;
|
||||
use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_config::profile::ProfileFileCredentialsProvider;
|
||||
use aws_config::provider_config::ProviderConfig;
|
||||
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::remote_storage_from_toml;
|
||||
use proxy::config::AuthenticationConfig;
|
||||
use proxy::config::CacheOptions;
|
||||
use proxy::config::HttpConfig;
|
||||
use proxy::config::ProjectInfoCacheOptions;
|
||||
use proxy::console;
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::http;
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
use proxy::metrics::Metrics;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::LeakyBucketConfig;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::WakeComputeRateLimiter;
|
||||
use proxy::redis::cancellation_publisher::RedisPublisherClient;
|
||||
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use proxy::redis::elasticache;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
|
||||
use anyhow::bail;
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use proxy::serverless;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
use tracing::Instrument;
|
||||
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[derive(Clone, Debug, ValueEnum)]
|
||||
enum AuthBackend {
|
||||
Console,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres,
|
||||
Link,
|
||||
}
|
||||
|
||||
/// Neon proxy/router
|
||||
#[derive(Parser)]
|
||||
#[command(version = GIT_VERSION, about)]
|
||||
struct ProxyCliArgs {
|
||||
/// Name of the region this proxy is deployed in
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
region: String,
|
||||
/// listen for incoming client connections on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:4432")]
|
||||
proxy: String,
|
||||
#[clap(value_enum, long, default_value_t = AuthBackend::Link)]
|
||||
auth_backend: AuthBackend,
|
||||
/// listen for management callback connection on ip:port
|
||||
#[clap(short, long, default_value = "127.0.0.1:7000")]
|
||||
mgmt: String,
|
||||
/// listen for incoming http connections (metrics, etc) on ip:port
|
||||
#[clap(long, default_value = "127.0.0.1:7001")]
|
||||
http: String,
|
||||
/// listen for incoming wss connections on ip:port
|
||||
#[clap(long)]
|
||||
wss: Option<String>,
|
||||
/// redirect unauthenticated users to the given uri in case of link auth
|
||||
#[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
|
||||
uri: String,
|
||||
/// cloud API endpoint for authenticating users
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// path to TLS key for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
#[clap(short = 'k', long, alias = "ssl-key")]
|
||||
tls_key: Option<String>,
|
||||
/// path to TLS cert for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
#[clap(short = 'c', long, alias = "ssl-cert")]
|
||||
tls_cert: Option<String>,
|
||||
/// path to directory with TLS certificates for client postgres connections
|
||||
#[clap(long)]
|
||||
certs_dir: Option<String>,
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
/// http endpoint to receive periodic metric updates
|
||||
#[clap(long)]
|
||||
metric_collection_endpoint: Option<String>,
|
||||
/// how often metrics should be sent to a collection endpoint
|
||||
#[clap(long)]
|
||||
metric_collection_interval: Option<String>,
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
wake_compute_cache: String,
|
||||
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
|
||||
wake_compute_lock: String,
|
||||
/// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
|
||||
connect_compute_lock: String,
|
||||
/// Allow self-signed certificates for compute nodes (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
allow_self_signed_compute: bool,
|
||||
#[clap(flatten)]
|
||||
sql_over_http: SqlOverHttpArgs,
|
||||
/// timeout for scram authentication protocol
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
scram_protocol_timeout: tokio::time::Duration,
|
||||
/// size of the threadpool for password hashing
|
||||
#[clap(long, default_value_t = 4)]
|
||||
scram_thread_pool_size: u8,
|
||||
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
require_client_ip: bool,
|
||||
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
|
||||
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_dynamic_rate_limiter: bool,
|
||||
/// Endpoint rate limiter max number of requests per second.
|
||||
///
|
||||
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Whether the auth rate limiter actually takes effect (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
auth_rate_limit_enabled: bool,
|
||||
/// Authentication rate limiter max number of hashes per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
|
||||
auth_rate_limit: Vec<RateBucketInfo>,
|
||||
/// The IP subnet to use when considering whether two IP addresses are considered the same.
|
||||
#[clap(long, default_value_t = 64)]
|
||||
auth_rate_limit_ip_subnet: u8,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
/// cache for `role_secret` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
role_secret_cache: String,
|
||||
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_ip_check_for_http: bool,
|
||||
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
|
||||
#[clap(long)]
|
||||
redis_notifications: Option<String>,
|
||||
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
|
||||
#[clap(long, default_value = "irsa")]
|
||||
redis_auth_type: String,
|
||||
/// redis host for streaming connections (might be different from the notifications host)
|
||||
#[clap(long)]
|
||||
redis_host: Option<String>,
|
||||
/// redis port for streaming connections (might be different from the notifications host)
|
||||
#[clap(long)]
|
||||
redis_port: Option<u16>,
|
||||
/// redis cluster name, used in aws elasticache
|
||||
#[clap(long)]
|
||||
redis_cluster_name: Option<String>,
|
||||
/// redis user_id, used in aws elasticache
|
||||
#[clap(long)]
|
||||
redis_user_id: Option<String>,
|
||||
/// aws region to retrieve credentials
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
aws_region: String,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
project_info_cache: String,
|
||||
/// cache for all valid endpoints
|
||||
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
|
||||
endpoint_cache_config: String,
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
|
||||
/// interval for backup metric collection
|
||||
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
|
||||
metric_backup_collection_interval: std::time::Duration,
|
||||
/// remote storage configuration for backup metric collection
|
||||
/// Encoded as toml (same format as pageservers), eg
|
||||
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
metric_backup_collection_remote_storage: Option<RemoteStorageConfig>,
|
||||
/// chunk size for backup metric collection
|
||||
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
|
||||
#[clap(long, default_value = "4194304")]
|
||||
metric_backup_collection_chunk_size: usize,
|
||||
/// Whether to retry the connection to the compute node
|
||||
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
|
||||
connect_to_compute_retry: String,
|
||||
/// Whether to retry the wake_compute request
|
||||
#[clap(long, default_value = config::RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)]
|
||||
wake_compute_retry: String,
|
||||
}
|
||||
|
||||
#[derive(clap::Args, Clone, Copy, Debug)]
|
||||
struct SqlOverHttpArgs {
|
||||
/// timeout for http connection requests
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
sql_over_http_timeout: tokio::time::Duration,
|
||||
|
||||
/// Whether the SQL over http pool is opt-in
|
||||
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
sql_over_http_pool_opt_in: bool,
|
||||
|
||||
/// How many connections to pool for each endpoint. Excess connections are discarded
|
||||
#[clap(long, default_value_t = 20)]
|
||||
sql_over_http_pool_max_conns_per_endpoint: usize,
|
||||
|
||||
/// How many connections to pool for each endpoint. Excess connections are discarded
|
||||
#[clap(long, default_value_t = 20000)]
|
||||
sql_over_http_pool_max_total_conns: usize,
|
||||
|
||||
/// How long pooled connections should remain idle for before closing
|
||||
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
|
||||
sql_over_http_idle_timeout: tokio::time::Duration,
|
||||
|
||||
/// Duration each shard will wait on average before a GC sweep.
|
||||
/// A longer time will causes sweeps to take longer but will interfere less frequently.
|
||||
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
|
||||
sql_over_http_pool_gc_epoch: tokio::time::Duration,
|
||||
|
||||
/// How many shards should the global pool have. Must be a power of two.
|
||||
/// More shards will introduce less contention for pool operations, but can
|
||||
/// increase memory used by the pool
|
||||
#[clap(long, default_value_t = 128)]
|
||||
sql_over_http_pool_shards: usize,
|
||||
|
||||
#[clap(long, default_value_t = 10000)]
|
||||
sql_over_http_client_conn_threshold: u64,
|
||||
|
||||
#[clap(long, default_value_t = 64)]
|
||||
sql_over_http_cancel_set_shards: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _logging_guard = proxy::logging::init().await?;
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
info!("Version: {GIT_VERSION}");
|
||||
info!("Build_tag: {BUILD_TAG}");
|
||||
let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
|
||||
revision: GIT_VERSION,
|
||||
build_tag: BUILD_TAG,
|
||||
});
|
||||
|
||||
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
|
||||
Ok(t) => Some(t),
|
||||
Err(e) => {
|
||||
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
info!("Using region: {}", args.aws_region);
|
||||
|
||||
let region_provider =
|
||||
RegionProviderChain::default_provider().or_else(Region::new(args.aws_region.clone()));
|
||||
let provider_conf =
|
||||
ProviderConfig::without_region().with_region(region_provider.region().await);
|
||||
let aws_credentials_provider = {
|
||||
// uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
|
||||
CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new())
|
||||
// uses "AWS_PROFILE" / `aws sso login --profile <profile>`
|
||||
.or_else(
|
||||
"profile-sso",
|
||||
ProfileFileCredentialsProvider::builder()
|
||||
.configure(&provider_conf)
|
||||
.build(),
|
||||
)
|
||||
// uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
|
||||
// needed to access remote extensions bucket
|
||||
.or_else(
|
||||
"token",
|
||||
WebIdentityTokenCredentialsProvider::builder()
|
||||
.configure(&provider_conf)
|
||||
.build(),
|
||||
)
|
||||
// uses imds v2
|
||||
.or_else("imds", ImdsCredentialsProvider::builder().build())
|
||||
};
|
||||
let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new(
|
||||
elasticache::AWSIRSAConfig::new(
|
||||
args.aws_region.clone(),
|
||||
args.redis_cluster_name,
|
||||
args.redis_user_id,
|
||||
),
|
||||
aws_credentials_provider,
|
||||
));
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
}
|
||||
Some(url) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_static_credentials(url.to_string()),
|
||||
),
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.to_string(),
|
||||
port,
|
||||
elasticache_credentials_provider.clone(),
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
warn!("irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client");
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("unknown auth type given");
|
||||
}
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.to_string()))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
};
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
let http_address: SocketAddr = args.http.parse()?;
|
||||
info!("Starting http on {http_address}");
|
||||
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
|
||||
|
||||
let mgmt_address: SocketAddr = args.mgmt.parse()?;
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
|
||||
RateBucketInfo::validate(redis_rps_limit)?;
|
||||
|
||||
let redis_publisher = match ®ional_redis_client {
|
||||
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||
redis_publisher.clone(),
|
||||
args.region.clone(),
|
||||
redis_rps_limit,
|
||||
)?))),
|
||||
None => None,
|
||||
};
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<
|
||||
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
|
||||
>::new(
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
));
|
||||
|
||||
// bit of a hack - find the min rps and max rps supported and turn it into
|
||||
// leaky bucket config instead
|
||||
let max = args
|
||||
.endpoint_rps_limit
|
||||
.iter()
|
||||
.map(|x| x.rps())
|
||||
.max_by(f64::total_cmp)
|
||||
.unwrap_or(EndpointRateLimiter::DEFAULT.max);
|
||||
let rps = args
|
||||
.endpoint_rps_limit
|
||||
.iter()
|
||||
.map(|x| x.rps())
|
||||
.min_by(f64::total_cmp)
|
||||
.unwrap_or(EndpointRateLimiter::DEFAULT.rps);
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
LeakyBucketConfig { rps, max },
|
||||
64,
|
||||
));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
let serverless_listener = TcpListener::bind(serverless_address).await?;
|
||||
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
client_tasks.spawn(proxy::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
));
|
||||
|
||||
// maintenance tasks. these never return unless there's an error
|
||||
let mut maintenance_tasks = JoinSet::new();
|
||||
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone()));
|
||||
maintenance_tasks.spawn(http::health_server::task_main(
|
||||
http_listener,
|
||||
AppMetrics {
|
||||
jemalloc,
|
||||
neon_metrics,
|
||||
proxy: proxy::metrics::Metrics::get(),
|
||||
},
|
||||
));
|
||||
maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
|
||||
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
// TODO: Add gc regardles of the metric collection being enabled.
|
||||
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
|
||||
client_tasks.spawn(usage_metrics::task_backup(
|
||||
&metrics_config.backup_metric_collection_config,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let auth::BackendType::Console(api, _) = &config.auth_backend {
|
||||
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
}
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let maintenance = loop {
|
||||
// get one complete task
|
||||
match futures::future::select(
|
||||
pin!(maintenance_tasks.join_next()),
|
||||
pin!(client_tasks.join_next()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
// exit immediately on maintenance task completion
|
||||
Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
|
||||
// exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
|
||||
Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
|
||||
// exit immediately on client task error
|
||||
Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
|
||||
// exit if all our client tasks have shutdown gracefully
|
||||
Either::Right((None, _)) => return Ok(()),
|
||||
}
|
||||
};
|
||||
|
||||
// maintenance tasks return Infallible success values, this is an impossible value
|
||||
// so this match statically ensures that there are no possibilities for that value
|
||||
match maintenance {}
|
||||
}
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
cert_path,
|
||||
args.certs_dir.as_ref(),
|
||||
)?),
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
if args.allow_self_signed_compute {
|
||||
warn!("allowing self-signed compute certificates");
|
||||
}
|
||||
let backup_metric_collection_config = config::MetricBackupCollectionConfig {
|
||||
interval: args.metric_backup_collection_interval,
|
||||
remote_storage_config: args.metric_backup_collection_remote_storage.clone(),
|
||||
chunk_size: args.metric_backup_collection_chunk_size,
|
||||
};
|
||||
|
||||
let metric_collection = match (
|
||||
&args.metric_collection_endpoint,
|
||||
&args.metric_collection_interval,
|
||||
) {
|
||||
(Some(endpoint), Some(interval)) => Some(config::MetricCollectionConfig {
|
||||
endpoint: endpoint.parse()?,
|
||||
interval: humantime::parse_duration(interval)?,
|
||||
backup_metric_collection_config,
|
||||
}),
|
||||
(None, None) => None,
|
||||
_ => bail!(
|
||||
"either both or neither metric-collection-endpoint \
|
||||
and metric-collection-interval must be specified"
|
||||
),
|
||||
};
|
||||
if !args.disable_dynamic_rate_limiter {
|
||||
bail!("dynamic rate limiter should be disabled");
|
||||
}
|
||||
|
||||
let auth_backend = match &args.auth_backend {
|
||||
AuthBackend::Console => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
|
||||
info!(
|
||||
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
|
||||
);
|
||||
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
|
||||
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse()?;
|
||||
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
|
||||
let locks = Box::leak(Box::new(console::locks::ApiLocks::new(
|
||||
"wake_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().wake_compute_lock,
|
||||
)?));
|
||||
tokio::spawn(locks.garbage_collect_worker());
|
||||
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = console::provider::neon::Api::new(
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
);
|
||||
let api = console::provider::ConsoleBackend::Console(api);
|
||||
auth::BackendType::Console(MaybeOwned::Owned(api), ())
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
AuthBackend::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = console::provider::mock::Api::new(url);
|
||||
let api = console::provider::ConsoleBackend::Postgres(api);
|
||||
auth::BackendType::Console(MaybeOwned::Owned(api), ())
|
||||
}
|
||||
AuthBackend::Link => {
|
||||
let url = args.uri.parse()?;
|
||||
auth::BackendType::Link(MaybeOwned::Owned(url), ())
|
||||
}
|
||||
};
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.connect_compute_lock.parse()?;
|
||||
info!(
|
||||
?limiter,
|
||||
shards,
|
||||
?epoch,
|
||||
"Using NodeLocks (connect_compute)"
|
||||
);
|
||||
let connect_compute_locks = console::locks::ApiLocks::new(
|
||||
"connect_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().proxy.connect_compute_lock,
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
|
||||
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
|
||||
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
|
||||
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
|
||||
},
|
||||
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
|
||||
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
authentication_config,
|
||||
require_client_ip: args.require_client_ip,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
}));
|
||||
|
||||
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use clap::Parser;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoint_rps_limit() {
|
||||
let config = super::ProxyCliArgs::parse_from([
|
||||
"proxy",
|
||||
"--endpoint-rps-limit",
|
||||
"100@1s",
|
||||
"--endpoint-rps-limit",
|
||||
"20@30s",
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
config.endpoint_rps_limit,
|
||||
vec![
|
||||
RateBucketInfo::new(100, Duration::from_secs(1)),
|
||||
RateBucketInfo::new(20, Duration::from_secs(30)),
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
7
proxy/core/src/cache.rs
Normal file
7
proxy/core/src/cache.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod common;
|
||||
pub mod endpoints;
|
||||
pub mod project_info;
|
||||
mod timed_lru;
|
||||
|
||||
pub use common::{Cache, Cached};
|
||||
pub use timed_lru::TimedLru;
|
||||
89
proxy/core/src/cache/common.rs
vendored
Normal file
89
proxy/core/src/cache/common.rs
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// A generic trait which exposes types of cache's key and value,
|
||||
/// as well as the notion of cache entry invalidation.
|
||||
/// This is useful for [`Cached`].
|
||||
pub trait Cache {
|
||||
/// Entry's key.
|
||||
type Key;
|
||||
|
||||
/// Entry's value.
|
||||
type Value;
|
||||
|
||||
/// Used for entry invalidation.
|
||||
type LookupInfo<Key>;
|
||||
|
||||
/// Invalidate an entry using a lookup info.
|
||||
/// We don't have an empty default impl because it's error-prone.
|
||||
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
|
||||
}
|
||||
|
||||
impl<C: Cache> Cache for &C {
|
||||
type Key = C::Key;
|
||||
type Value = C::Value;
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for convenient entry invalidation.
|
||||
pub struct Cached<C: Cache, V = <C as Cache>::Value> {
|
||||
/// Cache + lookup info.
|
||||
pub token: Option<(C, C::LookupInfo<C::Key>)>,
|
||||
|
||||
/// The value itself.
|
||||
pub value: V,
|
||||
}
|
||||
|
||||
impl<C: Cache, V> Cached<C, V> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
pub fn new_uncached(value: V) -> Self {
|
||||
Self { token: None, value }
|
||||
}
|
||||
|
||||
pub fn take_value(self) -> (Cached<C, ()>, V) {
|
||||
(
|
||||
Cached {
|
||||
token: self.token,
|
||||
value: (),
|
||||
},
|
||||
self.value,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
|
||||
Cached {
|
||||
token: self.token,
|
||||
value: f(self.value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(self) -> V {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
cache.invalidate(info);
|
||||
}
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Tell if this entry is actually cached.
|
||||
pub fn cached(&self) -> bool {
|
||||
self.token.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache, V> Deref for Cached<C, V> {
|
||||
type Target = V;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache, V> DerefMut for Cached<C, V> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
}
|
||||
247
proxy/core/src/cache/endpoints.rs
vendored
Normal file
247
proxy/core/src/cache/endpoints.rs
vendored
Normal file
@@ -0,0 +1,247 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashSet;
|
||||
use redis::{
|
||||
streams::{StreamReadOptions, StreamReadReply},
|
||||
AsyncCommands, FromRedisValue, Value,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
config::EndpointCacheConfig,
|
||||
context::RequestMonitoring,
|
||||
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
|
||||
metrics::{Metrics, RedisErrors, RedisEventsCount},
|
||||
rate_limiter::GlobalRateLimiter,
|
||||
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct ControlPlaneEventKey {
|
||||
endpoint_created: Option<EndpointCreated>,
|
||||
branch_created: Option<BranchCreated>,
|
||||
project_created: Option<ProjectCreated>,
|
||||
}
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct EndpointCreated {
|
||||
endpoint_id: String,
|
||||
}
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct BranchCreated {
|
||||
branch_id: String,
|
||||
}
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
struct ProjectCreated {
|
||||
project_id: String,
|
||||
}
|
||||
|
||||
pub struct EndpointsCache {
|
||||
config: EndpointCacheConfig,
|
||||
endpoints: DashSet<EndpointIdInt>,
|
||||
branches: DashSet<BranchIdInt>,
|
||||
projects: DashSet<ProjectIdInt>,
|
||||
ready: AtomicBool,
|
||||
limiter: Arc<Mutex<GlobalRateLimiter>>,
|
||||
}
|
||||
|
||||
impl EndpointsCache {
|
||||
pub fn new(config: EndpointCacheConfig) -> Self {
|
||||
Self {
|
||||
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
|
||||
config.limiter_info.clone(),
|
||||
))),
|
||||
config,
|
||||
endpoints: DashSet::new(),
|
||||
branches: DashSet::new(),
|
||||
projects: DashSet::new(),
|
||||
ready: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
|
||||
if !self.ready.load(Ordering::Acquire) {
|
||||
return true;
|
||||
}
|
||||
let rejected = self.should_reject(endpoint);
|
||||
ctx.set_rejected(rejected);
|
||||
info!(?rejected, "check endpoint is valid, disabled cache");
|
||||
// If cache is disabled, just collect the metrics and return or
|
||||
// If the limiter allows, we don't need to check the cache.
|
||||
if self.config.disable_cache || self.limiter.lock().await.check() {
|
||||
return true;
|
||||
}
|
||||
!rejected
|
||||
}
|
||||
fn should_reject(&self, endpoint: &EndpointId) -> bool {
|
||||
if endpoint.is_endpoint() {
|
||||
!self.endpoints.contains(&EndpointIdInt::from(endpoint))
|
||||
} else if endpoint.is_branch() {
|
||||
!self
|
||||
.branches
|
||||
.contains(&BranchIdInt::from(&endpoint.as_branch()))
|
||||
} else {
|
||||
!self
|
||||
.projects
|
||||
.contains(&ProjectIdInt::from(&endpoint.as_project()))
|
||||
}
|
||||
}
|
||||
fn insert_event(&self, key: ControlPlaneEventKey) {
|
||||
// Do not do normalization here, we expect the events to be normalized.
|
||||
if let Some(endpoint_created) = key.endpoint_created {
|
||||
self.endpoints
|
||||
.insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::EndpointCreated);
|
||||
}
|
||||
if let Some(branch_created) = key.branch_created {
|
||||
self.branches
|
||||
.insert(BranchIdInt::from(&branch_created.branch_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BranchCreated);
|
||||
}
|
||||
if let Some(project_created) = key.project_created {
|
||||
self.projects
|
||||
.insert(ProjectIdInt::from(&project_created.project_id.into()));
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::ProjectCreated);
|
||||
}
|
||||
}
|
||||
pub async fn do_read(
|
||||
&self,
|
||||
mut con: ConnectionWithCredentialsProvider,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
let mut last_id = "0-0".to_string();
|
||||
loop {
|
||||
if let Err(e) = con.connect().await {
|
||||
tracing::error!("error connecting to redis: {:?}", e);
|
||||
self.ready.store(false, Ordering::Release);
|
||||
}
|
||||
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
|
||||
tracing::error!("error reading from redis: {:?}", e);
|
||||
self.ready.store(false, Ordering::Release);
|
||||
}
|
||||
if cancellation_token.is_cancelled() {
|
||||
info!("cancellation token is cancelled, exiting");
|
||||
tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await;
|
||||
// 1 week.
|
||||
}
|
||||
tokio::time::sleep(self.config.retry_interval).await;
|
||||
}
|
||||
}
|
||||
async fn read_from_stream(
|
||||
&self,
|
||||
con: &mut ConnectionWithCredentialsProvider,
|
||||
last_id: &mut String,
|
||||
) -> anyhow::Result<()> {
|
||||
tracing::info!("reading endpoints/branches/projects from redis");
|
||||
self.batch_read(
|
||||
con,
|
||||
StreamReadOptions::default().count(self.config.initial_batch_size),
|
||||
last_id,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
tracing::info!("ready to filter user requests");
|
||||
self.ready.store(true, Ordering::Release);
|
||||
self.batch_read(
|
||||
con,
|
||||
StreamReadOptions::default()
|
||||
.count(self.config.default_batch_size)
|
||||
.block(self.config.xread_timeout.as_millis() as usize),
|
||||
last_id,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
|
||||
let s: String = FromRedisValue::from_redis_value(value)?;
|
||||
Ok(serde_json::from_str(&s)?)
|
||||
}
|
||||
async fn batch_read(
|
||||
&self,
|
||||
conn: &mut ConnectionWithCredentialsProvider,
|
||||
opts: StreamReadOptions,
|
||||
last_id: &mut String,
|
||||
return_when_finish: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut total: usize = 0;
|
||||
loop {
|
||||
let mut res: StreamReadReply = conn
|
||||
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
|
||||
.await?;
|
||||
|
||||
if res.keys.is_empty() {
|
||||
if return_when_finish {
|
||||
if total != 0 {
|
||||
break;
|
||||
}
|
||||
anyhow::bail!(
|
||||
"Redis stream {} is empty, cannot be used to filter endpoints",
|
||||
self.config.stream_name
|
||||
);
|
||||
}
|
||||
// If we are not returning when finish, we should wait for more data.
|
||||
continue;
|
||||
}
|
||||
if res.keys.len() != 1 {
|
||||
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
|
||||
}
|
||||
|
||||
let res = res.keys.pop().expect("Checked length above");
|
||||
let len = res.ids.len();
|
||||
for x in res.ids {
|
||||
total += 1;
|
||||
for (_, v) in x.map {
|
||||
let key = match Self::parse_key_value(&v) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
|
||||
channel: &self.config.stream_name,
|
||||
});
|
||||
tracing::error!("error parsing value {v:?}: {e:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
self.insert_event(key);
|
||||
}
|
||||
if total.is_power_of_two() {
|
||||
tracing::debug!("endpoints read {}", total);
|
||||
}
|
||||
*last_id = x.id;
|
||||
}
|
||||
if return_when_finish && len <= self.config.default_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
tracing::info!("read {} endpoints/branches/projects from redis", total);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ControlPlaneEventKey;
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
|
||||
let _: ControlPlaneEventKey = serde_json::from_str(s).unwrap();
|
||||
}
|
||||
}
|
||||
574
proxy/core/src/cache/project_info.rs
vendored
Normal file
574
proxy/core/src/cache/project_info.rs
vendored
Normal file
@@ -0,0 +1,574 @@
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
convert::Infallible,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use rand::{thread_rng, Rng};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{
|
||||
auth::IpPattern,
|
||||
config::ProjectInfoCacheOptions,
|
||||
console::AuthSecret,
|
||||
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
|
||||
EndpointId, RoleName,
|
||||
};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
|
||||
#[async_trait]
|
||||
pub trait ProjectInfoCache {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
|
||||
async fn decrement_active_listeners(&self);
|
||||
async fn increment_active_listeners(&self);
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<T> Entry<T> {
|
||||
pub fn new(value: T) -> Self {
|
||||
Self {
|
||||
created_at: Instant::now(),
|
||||
value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Entry<T> {
|
||||
fn from(value: T) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
|
||||
match ignore_cache_since {
|
||||
None => false,
|
||||
Some(t) => t < created_at,
|
||||
}
|
||||
}
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
role_name: RoleNameInt,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Option<AuthSecret>, bool)> {
|
||||
if let Some(secret) = self.secret.get(&role_name) {
|
||||
if valid_since < secret.created_at {
|
||||
return Some((
|
||||
secret.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
|
||||
if let Some(allowed_ips) = &self.allowed_ips {
|
||||
if valid_since < allowed_ips.created_at {
|
||||
return Some((
|
||||
allowed_ips.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
pub fn invalidate_allowed_ips(&mut self) {
|
||||
self.allowed_ips = None;
|
||||
}
|
||||
pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
|
||||
self.secret.remove(&role_name);
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache for project info.
|
||||
/// This is used to cache auth data for endpoints.
|
||||
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
|
||||
///
|
||||
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
|
||||
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
|
||||
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
|
||||
pub struct ProjectInfoCacheImpl {
|
||||
cache: DashMap<EndpointIdInt, EndpointInfo>,
|
||||
|
||||
project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
|
||||
config: ProjectInfoCacheOptions,
|
||||
|
||||
start_time: Instant,
|
||||
ttl_disabled_since_us: AtomicU64,
|
||||
active_listeners_lock: Mutex<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
|
||||
info!("invalidating allowed ips for project `{}`", project_id);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(&project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
}
|
||||
}
|
||||
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
|
||||
info!(
|
||||
"invalidating role secret for project_id `{}` and role_name `{}`",
|
||||
project_id, role_name,
|
||||
);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(&project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(role_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn decrement_active_listeners(&self) {
|
||||
let mut listeners_guard = self.active_listeners_lock.lock().await;
|
||||
if *listeners_guard == 0 {
|
||||
tracing::error!("active_listeners count is already 0, something is broken");
|
||||
return;
|
||||
}
|
||||
*listeners_guard -= 1;
|
||||
if *listeners_guard == 0 {
|
||||
self.ttl_disabled_since_us
|
||||
.store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
async fn increment_active_listeners(&self) {
|
||||
let mut listeners_guard = self.active_listeners_lock.lock().await;
|
||||
*listeners_guard += 1;
|
||||
if *listeners_guard == 1 {
|
||||
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
|
||||
self.ttl_disabled_since_us
|
||||
.store(new_ttl, std::sync::atomic::Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheImpl {
|
||||
pub fn new(config: ProjectInfoCacheOptions) -> Self {
|
||||
Self {
|
||||
cache: DashMap::new(),
|
||||
project2ep: DashMap::new(),
|
||||
config,
|
||||
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
|
||||
start_time: Instant::now(),
|
||||
active_listeners_lock: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
role_name: &RoleName,
|
||||
) -> Option<Cached<&Self, Option<AuthSecret>>> {
|
||||
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
|
||||
let role_name = RoleNameInt::get(role_name)?;
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(&endpoint_id)?;
|
||||
let (value, ignore_cache) =
|
||||
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((
|
||||
self,
|
||||
CachedLookupInfo::new_role_secret(endpoint_id, role_name),
|
||||
)),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
|
||||
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(&endpoint_id)?;
|
||||
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
|
||||
let (value, ignore_cache) = value?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub fn insert_role_secret(
|
||||
&self,
|
||||
project_id: ProjectIdInt,
|
||||
endpoint_id: EndpointIdInt,
|
||||
role_name: RoleNameInt,
|
||||
secret: Option<AuthSecret>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
let mut entry = self.cache.entry(endpoint_id).or_default();
|
||||
if entry.secret.len() < self.config.max_roles {
|
||||
entry.secret.insert(role_name, secret.into());
|
||||
}
|
||||
}
|
||||
pub fn insert_allowed_ips(
|
||||
&self,
|
||||
project_id: ProjectIdInt,
|
||||
endpoint_id: EndpointIdInt,
|
||||
allowed_ips: Arc<Vec<IpPattern>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
|
||||
}
|
||||
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
|
||||
if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
|
||||
endpoints.insert(endpoint_id);
|
||||
} else {
|
||||
self.project2ep
|
||||
.insert(project_id, HashSet::from([endpoint_id]));
|
||||
}
|
||||
}
|
||||
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
|
||||
let mut valid_since = Instant::now() - self.config.ttl;
|
||||
// Only ignore cache if ttl is disabled.
|
||||
let ttl_disabled_since_us = self
|
||||
.ttl_disabled_since_us
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
|
||||
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
|
||||
// We are fine if entry is not older than ttl or was added before we are getting notifications.
|
||||
valid_since = valid_since.min(ignore_cache_since);
|
||||
Some(ignore_cache_since)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(valid_since, ignore_cache_since)
|
||||
}
|
||||
|
||||
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
|
||||
let mut interval =
|
||||
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if self.cache.len() < self.config.size {
|
||||
// If there are not too many entries, wait until the next gc cycle.
|
||||
continue;
|
||||
}
|
||||
self.gc();
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self) {
|
||||
let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
|
||||
debug!(shard, "project_info_cache: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut removed = 0;
|
||||
let shard = self.project2ep.shards()[shard].write();
|
||||
for (_, endpoints) in shard.iter() {
|
||||
for endpoint in endpoints.get().iter() {
|
||||
self.cache.remove(endpoint);
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
// We can drop this shard only after making sure that all endpoints are removed.
|
||||
drop(shard);
|
||||
info!("project_info_cache: removed {removed} endpoints");
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup info for project info cache.
|
||||
/// This is used to invalidate cache entries.
|
||||
pub struct CachedLookupInfo {
|
||||
/// Search by this key.
|
||||
endpoint_id: EndpointIdInt,
|
||||
lookup_type: LookupType,
|
||||
}
|
||||
|
||||
impl CachedLookupInfo {
|
||||
pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::RoleSecret(role_name),
|
||||
}
|
||||
}
|
||||
pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::AllowedIps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum LookupType {
|
||||
RoleSecret(RoleNameInt),
|
||||
AllowedIps,
|
||||
}
|
||||
|
||||
impl Cache for ProjectInfoCacheImpl {
|
||||
type Key = SmolStr;
|
||||
// Value is not really used here, but we need to specify it.
|
||||
type Value = SmolStr;
|
||||
|
||||
type LookupInfo<Key> = CachedLookupInfo;
|
||||
|
||||
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
|
||||
match &key.lookup_type {
|
||||
LookupType::RoleSecret(role_name) => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(*role_name);
|
||||
}
|
||||
}
|
||||
LookupType::AllowedIps => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{scram::ServerSecret, ProjectId};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
tokio::time::pause();
|
||||
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id: ProjectId = "project".into();
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = None;
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
secret1.clone(),
|
||||
);
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
secret2.clone(),
|
||||
);
|
||||
cache.insert_allowed_ips(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, secret1);
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, secret2);
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: RoleName = "user3".into();
|
||||
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user3).into(),
|
||||
secret3.clone(),
|
||||
);
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_allowed_ips(&endpoint_id);
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_invalidations() {
|
||||
tokio::time::pause();
|
||||
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
cache.clone().increment_active_listeners().await;
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
|
||||
let project_id: ProjectId = "project".into();
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
secret1.clone(),
|
||||
);
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
secret2.clone(),
|
||||
);
|
||||
cache.insert_allowed_ips(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
// Nothing should be invalidated.
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
// TTL is disabled, so it should be impossible to invalidate this value.
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, secret1);
|
||||
|
||||
cached.invalidate(); // Shouldn't do anything.
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert_eq!(cached.value, secret1);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, secret2);
|
||||
|
||||
// The only way to invalidate this value is to invalidate via the api.
|
||||
cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_active_listeners_invalidate_added_before() {
|
||||
tokio::time::pause();
|
||||
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
|
||||
let project_id: ProjectId = "project".into();
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
]);
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
secret1.clone(),
|
||||
);
|
||||
cache.clone().increment_active_listeners().await;
|
||||
tokio::time::advance(Duration::from_millis(100)).await;
|
||||
cache.insert_role_secret(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
secret2.clone(),
|
||||
);
|
||||
|
||||
// Added before ttl was disabled + ttl should be still cached.
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert!(cached.cached());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(cached.cached());
|
||||
|
||||
tokio::time::advance(Duration::from_secs(1)).await;
|
||||
// Added before ttl was disabled + ttl should expire.
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
|
||||
// Added after ttl was disabled + ttl should not be cached.
|
||||
cache.insert_allowed_ips(
|
||||
(&project_id).into(),
|
||||
(&endpoint_id).into(),
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
|
||||
tokio::time::advance(Duration::from_secs(1)).await;
|
||||
// Added before ttl was disabled + ttl still should expire.
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
// Shouldn't be invalidated.
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
}
|
||||
}
|
||||
290
proxy/core/src/cache/timed_lru.rs
vendored
Normal file
290
proxy/core/src/cache/timed_lru.rs
vendored
Normal file
@@ -0,0 +1,290 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
|
||||
use super::{common::Cached, *};
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
|
||||
update_ttl_on_retrieval: bool,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
ttl: Duration,
|
||||
update_ttl_on_retrieval: bool,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
capacity: usize,
|
||||
ttl: Duration,
|
||||
update_ttl_on_retrieval: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
update_ttl_on_retrieval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(raw_entry.key(), entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow");
|
||||
if raw_entry.get().update_ttl_on_retrieval {
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
}
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
old_expires_at = format_args!("{expires_at:?}"),
|
||||
new_expires_at = format_args!("{deadline:?}"),
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw_ttl(
|
||||
&self,
|
||||
key: K,
|
||||
value: V,
|
||||
ttl: Duration,
|
||||
update: bool,
|
||||
) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
ttl,
|
||||
update_ttl_on_retrieval: update,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(key, entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(created_at, old)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
|
||||
self.insert_raw_ttl(key, value, ttl, false);
|
||||
}
|
||||
|
||||
pub fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value);
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value: (),
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
/// Retrieve a cached entry in convenient wrapper.
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieve a cached entry in convenient wrapper, ignoring its TTL.
|
||||
pub fn get_ignoring_ttl<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let mut cache = self.cache.lock();
|
||||
cache
|
||||
.get(key)
|
||||
.map(|entry| Cached::new_uncached(entry.value.clone()))
|
||||
}
|
||||
|
||||
/// Remove an entry from the cache.
|
||||
pub fn remove<Q>(&self, key: &Q) -> Option<V>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let mut cache = self.cache.lock();
|
||||
cache.remove(key).map(|entry| entry.value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
235
proxy/core/src/cancellation.rs
Normal file
235
proxy/core/src/cancellation.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use dashmap::DashMap;
|
||||
use pq_proto::CancelKeyData;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
error::ReportableError,
|
||||
metrics::{CancellationRequest, CancellationSource, Metrics},
|
||||
redis::cancellation_publisher::{
|
||||
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
|
||||
},
|
||||
};
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
|
||||
pub type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
///
|
||||
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler<P> {
|
||||
map: CancelMap,
|
||||
client: P,
|
||||
/// This field used for the monitoring purposes.
|
||||
/// Represents the source of the cancellation request.
|
||||
from: CancellationSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CancelError {
|
||||
#[error("{0}")]
|
||||
IO(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for CancelError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
CancelError::IO(_) => crate::error::ErrorKind::Compute,
|
||||
CancelError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub fn get_session(self: Arc<Self>) -> Session<P> {
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
let key = loop {
|
||||
let key = rand::random();
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
match self.map.entry(key) {
|
||||
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||
e.insert(None);
|
||||
}
|
||||
}
|
||||
break key;
|
||||
};
|
||||
|
||||
info!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
pub async fn cancel_session(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> Result<(), CancelError> {
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::NotFound,
|
||||
});
|
||||
match self.client.try_publish(key, session_id).await {
|
||||
Ok(()) => {} // do nothing
|
||||
Err(e) => {
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::Found,
|
||||
});
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn contains(&self, session: &Session<P>) -> bool {
|
||||
self.map.contains_key(&session.key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl CancellationHandler<()> {
|
||||
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
|
||||
Self {
|
||||
map,
|
||||
client: (),
|
||||
from,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
|
||||
Self { map, client, from }
|
||||
}
|
||||
}
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
}
|
||||
|
||||
impl CancelClosure {
|
||||
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
|
||||
Self {
|
||||
socket_addr,
|
||||
cancel_token,
|
||||
}
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub async fn try_cancel_query(self) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
|
||||
info!("query was cancelled");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for registering query cancellation tokens.
|
||||
pub struct Session<P> {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancellation_handler: Arc<CancellationHandler<P>>,
|
||||
}
|
||||
|
||||
impl<P> Session<P> {
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
info!("enabling query cancellation for this session");
|
||||
self.cancellation_handler
|
||||
.map
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Drop for Session<P> {
|
||||
fn drop(&mut self) {
|
||||
self.cancellation_handler.map.remove(&self.key);
|
||||
info!("dropped query cancellation key {}", &self.key);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
CancelMap::default(),
|
||||
CancellationSource::FromRedis,
|
||||
));
|
||||
|
||||
let session = cancellation_handler.clone().get_session();
|
||||
assert!(cancellation_handler.contains(&session));
|
||||
drop(session);
|
||||
// Check that the session has been dropped.
|
||||
assert!(cancellation_handler.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_session_noop_regression() {
|
||||
let handler = CancellationHandler::<()>::new(Default::default(), CancellationSource::Local);
|
||||
handler
|
||||
.cancel_session(
|
||||
CancelKeyData {
|
||||
backend_pid: 0,
|
||||
cancel_key: 0,
|
||||
},
|
||||
Uuid::new_v4(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
439
proxy/core/src/compute.rs
Normal file
439
proxy/core/src/compute.rs
Normal file
@@ -0,0 +1,439 @@
|
||||
use crate::{
|
||||
auth::parse_endpoint_param,
|
||||
cancellation::CancelClosure,
|
||||
console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::{Metrics, NumDbConnectionsGuard},
|
||||
proxy::neon_option,
|
||||
Host,
|
||||
};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
|
||||
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
WakeComputeError(#[from] WakeComputeError),
|
||||
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnectionError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ConnectionError::*;
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
|
||||
} else {
|
||||
msg.to_owned()
|
||||
}
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
WakeComputeError(err) => err.to_string_client(),
|
||||
TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
_ => COULD_NOT_CONNECT.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ConnectionError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
|
||||
if let Some(keys) = other.get_auth_keys() {
|
||||
self.auth_keys(keys);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_host(&self) -> Result<Host, WakeComputeError> {
|
||||
match self.0.get_hosts() {
|
||||
[tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
|
||||
// we should not have multiple address or unix addresses.
|
||||
_ => Err(WakeComputeError::BadComputeAddress(
|
||||
"invalid compute address".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Link auth flow takes username from the console's response.
|
||||
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
|
||||
self.user(user);
|
||||
}
|
||||
|
||||
// Only set `dbname` if it's not present in the config.
|
||||
// Link auth flow takes dbname from the console's response.
|
||||
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
|
||||
self.dbname(dbname);
|
||||
}
|
||||
|
||||
// Don't add `options` if they were only used for specifying a project.
|
||||
// Connection pools don't support `options`, because they affect backend startup.
|
||||
if let Some(options) = filtered_options(params) {
|
||||
self.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
"database" => {
|
||||
self.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
//
|
||||
// This and the reverse params problem can be better addressed
|
||||
// in a bespoke connection machinery (a new library for that sake).
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = tokio_postgres::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// For now, let's make it easier to setup the config.
|
||||
impl std::ops::DerefMut for ConnCfg {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
use tokio_postgres::config::Host;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |host, port| {
|
||||
tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
|
||||
move |res| match res {
|
||||
Ok(tcpstream_connect_res) => tcpstream_connect_res,
|
||||
Err(_) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
format!("exceeded connection timeout {timeout:?}"),
|
||||
)),
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let connect_once = |host, port| {
|
||||
info!("trying to connect to compute node at {host}:{port}");
|
||||
connect_with_timeout(host, port).and_then(|socket| async {
|
||||
let socket_addr = socket.peer_addr()?;
|
||||
// This prevents load balancer from severing the connection.
|
||||
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
||||
Ok((socket_addr, socket))
|
||||
})
|
||||
};
|
||||
|
||||
// We can't reuse connection establishing logic from `tokio_postgres` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let mut connection_error = None;
|
||||
let ports = self.0.get_ports();
|
||||
let hosts = self.0.get_hosts();
|
||||
// the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
|
||||
if ports.len() > 1 && ports.len() != hosts.len() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
"bad compute config, \
|
||||
ports and hosts entries' count does not match: {:?}",
|
||||
self.0
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
Host::Unix(_) => continue, // unix sockets are not welcome here
|
||||
};
|
||||
|
||||
match connect_once(host, *port).await {
|
||||
Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
|
||||
Err(err) => {
|
||||
// We can't throw an error here, as there might be more hosts to try.
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
connection_error = Some(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(connection_error.unwrap_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("bad compute config: {:?}", self.0),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
|
||||
tokio::net::TcpStream,
|
||||
tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
|
||||
>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub cancel_closure: CancelClosure,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnCfg {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
allow_self_signed_compute: bool,
|
||||
aux: MetricsAuxInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
let client_config = if allow_self_signed_compute {
|
||||
// Allow all certificates for creating the connection
|
||||
let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(verifier)
|
||||
} else {
|
||||
let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
|
||||
rustls::ClientConfig::builder().with_root_certificates(root_store)
|
||||
};
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = self.0.connect_raw(stream, tls).await?;
|
||||
drop(pause);
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
let stream = connection.stream.into_inner();
|
||||
|
||||
info!(
|
||||
cold_start_info = ctx.cold_start_info().as_str(),
|
||||
"connected to compute node at {host} ({socket_addr}) sslmode={:?}",
|
||||
self.0.get_ssl_mode()
|
||||
);
|
||||
|
||||
// This is very ugly but as of now there's no better way to
|
||||
// extract the connection parameters from tokio-postgres' connection.
|
||||
// TODO: solve this problem in a more elegant manner (e.g. the new library).
|
||||
let params = connection.parameters;
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
||||
|
||||
let connection = PostgresConnection {
|
||||
stream,
|
||||
params,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve `options` from a startup message, dropping all proxy-secific flags.
|
||||
fn filtered_options(params: &StartupMessageParams) -> Option<String> {
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = params
|
||||
.options_raw()?
|
||||
.filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
|
||||
// Don't even bother with empty options.
|
||||
if options.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(options)
|
||||
}
|
||||
|
||||
fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
|
||||
let der_certs = rustls_native_certs::load_native_certs()?;
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs);
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AcceptEverythingVerifier;
|
||||
impl ServerCertVerifier for AcceptEverythingVerifier {
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
use rustls::SignatureScheme::*;
|
||||
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
|
||||
vec![
|
||||
ECDSA_NISTP521_SHA512,
|
||||
ECDSA_NISTP384_SHA384,
|
||||
ECDSA_NISTP256_SHA256,
|
||||
RSA_PSS_SHA512,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA256,
|
||||
ED25519,
|
||||
]
|
||||
}
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_filtered_options() {
|
||||
// Empty options is unlikely to be useful anyway.
|
||||
let params = StartupMessageParams::new([("options", "")]);
|
||||
assert_eq!(filtered_options(¶ms), None);
|
||||
|
||||
// It's likely that clients will only use options to specify endpoint/project.
|
||||
let params = StartupMessageParams::new([("options", "project=foo")]);
|
||||
assert_eq!(filtered_options(¶ms), None);
|
||||
|
||||
// Same, because unescaped whitespaces are no-op.
|
||||
let params = StartupMessageParams::new([("options", " project=foo ")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), None);
|
||||
|
||||
let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some(r"\ \ "));
|
||||
|
||||
let params = StartupMessageParams::new([("options", "project = foo")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
|
||||
|
||||
let params = StartupMessageParams::new([(
|
||||
"options",
|
||||
"project = foo neon_endpoint_type:read_write neon_lsn:0/2",
|
||||
)]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
|
||||
}
|
||||
}
|
||||
763
proxy/core/src/config.rs
Normal file
763
proxy/core/src/config.rs
Normal file
@@ -0,0 +1,763 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
|
||||
Host,
|
||||
};
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use itertools::Itertools;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use rustls::{
|
||||
crypto::ring::sign,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<'static, (), ()>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub require_client_ip: bool,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
pub connect_to_compute_retry_config: RetryConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MetricCollectionConfig {
|
||||
pub endpoint: reqwest::Url,
|
||||
pub interval: Duration,
|
||||
pub backup_metric_collection_config: MetricBackupCollectionConfig,
|
||||
}
|
||||
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: HashSet<String>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
pub pool_options: GlobalConnPoolOptions,
|
||||
pub cancel_set: CancelSet,
|
||||
pub client_conn_threshold: u64,
|
||||
}
|
||||
|
||||
pub struct AuthenticationConfig {
|
||||
pub thread_pool: Arc<ThreadPool>,
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub rate_limiter_enabled: bool,
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
|
||||
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
certs_dir: Option<&String>,
|
||||
) -> anyhow::Result<TlsConfig> {
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
|
||||
// add default certificate
|
||||
cert_resolver.add_cert_path(key_path, cert_path, true)?;
|
||||
|
||||
// add extra certificates
|
||||
if let Some(certs_dir) = certs_dir {
|
||||
for entry in std::fs::read_dir(certs_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
// file names aligned with default cert-manager names
|
||||
let key_path = path.join("tls.key");
|
||||
let cert_path = path.join("tls.crt");
|
||||
if key_path.exists() && cert_path.exists() {
|
||||
cert_resolver.add_cert_path(
|
||||
&key_path.to_string_lossy(),
|
||||
&cert_path.to_string_lossy(),
|
||||
false,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
let cert_resolver = Arc::new(cert_resolver);
|
||||
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
|
||||
&rustls::version::TLS13,
|
||||
&rustls::version::TLS12,
|
||||
])
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(cert_resolver.clone());
|
||||
|
||||
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
|
||||
|
||||
Ok(TlsConfig {
|
||||
config: Arc::new(config),
|
||||
common_names,
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_cert_path(
|
||||
&mut self,
|
||||
key_path: &str,
|
||||
cert_path: &str,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
PrivateKeyDer::Pkcs8(
|
||||
keys.pop()
|
||||
.unwrap()
|
||||
.context(format!("Failed to parse TLS keys at '{key_path}'"))?,
|
||||
)
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
|
||||
let cert_chain = {
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.try_collect()
|
||||
.with_context(|| {
|
||||
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
|
||||
})?
|
||||
};
|
||||
|
||||
self.add_cert(priv_key, cert_chain, is_default)
|
||||
}
|
||||
|
||||
pub fn add_cert(
|
||||
&mut self,
|
||||
priv_key: PrivateKeyDer<'static>,
|
||||
cert_chain: Vec<CertificateDer<'static>>,
|
||||
is_default: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
|
||||
|
||||
let first_cert = &cert_chain[0];
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
let pem = x509_parser::parse_x509_certificate(first_cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
if is_default {
|
||||
self.default = Some((cert.clone(), tls_server_end_point));
|
||||
}
|
||||
|
||||
self.certs.insert(common_name, (cert, tls_server_end_point));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_common_names(&self) -> HashSet<String> {
|
||||
self.certs.keys().map(|s| s.to_string()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl CertResolver {
|
||||
pub fn resolve(
|
||||
&self,
|
||||
server_name: Option<&str>,
|
||||
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
|
||||
// loop here and cut off more and more subdomains until we find
|
||||
// a match to get a proper wildcard support. OTOH, we now do not
|
||||
// use nested domains, so keep this simple for now.
|
||||
//
|
||||
// With the current coding foo.com will match *.foo.com and that
|
||||
// repeats behavior of the old code.
|
||||
if let Some(mut sni_name) = server_name {
|
||||
loop {
|
||||
if let Some(cert) = self.certs.get(sni_name) {
|
||||
return Some(cert.clone());
|
||||
}
|
||||
if let Some((_, rest)) = sni_name.split_once('.') {
|
||||
sni_name = rest;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No SNI, use the default certificate, otherwise we can't get to
|
||||
// options parameter which can be used to set endpoint name too.
|
||||
// That means that non-SNI flow will not work for CNAME domains in
|
||||
// verify-full mode.
|
||||
//
|
||||
// If that will be a problem we can:
|
||||
//
|
||||
// a) Instead of multi-cert approach use single cert with extra
|
||||
// domains listed in Subject Alternative Name (SAN).
|
||||
// b) Deploy separate proxy instances for extra domains.
|
||||
self.default.as_ref().cloned()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EndpointCacheConfig {
|
||||
/// Batch size to receive all endpoints on the startup.
|
||||
pub initial_batch_size: usize,
|
||||
/// Batch size to receive endpoints.
|
||||
pub default_batch_size: usize,
|
||||
/// Timeouts for the stream read operation.
|
||||
pub xread_timeout: Duration,
|
||||
/// Stream name to read from.
|
||||
pub stream_name: String,
|
||||
/// Limiter info (to distinguish when to enable cache).
|
||||
pub limiter_info: Vec<RateBucketInfo>,
|
||||
/// Disable cache.
|
||||
/// If true, cache is ignored, but reports all statistics.
|
||||
pub disable_cache: bool,
|
||||
/// Retry interval for the stream read operation.
|
||||
pub retry_interval: Duration,
|
||||
}
|
||||
|
||||
impl EndpointCacheConfig {
|
||||
/// Default options for [`crate::console::provider::NodeInfoCache`].
|
||||
/// Notice that by default the limiter is empty, which means that cache is disabled.
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut initial_batch_size = None;
|
||||
let mut default_batch_size = None;
|
||||
let mut xread_timeout = None;
|
||||
let mut stream_name = None;
|
||||
let mut limiter_info = vec![];
|
||||
let mut disable_cache = false;
|
||||
let mut retry_interval = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
|
||||
"default_batch_size" => default_batch_size = Some(value.parse()?),
|
||||
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
|
||||
"stream_name" => stream_name = Some(value.to_string()),
|
||||
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
|
||||
"disable_cache" => disable_cache = value.parse()?,
|
||||
"retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
RateBucketInfo::validate(&mut limiter_info)?;
|
||||
|
||||
Ok(Self {
|
||||
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
|
||||
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
|
||||
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
|
||||
stream_name: stream_name.context("missing `stream_name`")?,
|
||||
disable_cache,
|
||||
limiter_info,
|
||||
retry_interval: retry_interval.context("missing `retry_interval`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EndpointCacheConfig {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse endpoint cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct MetricBackupCollectionConfig {
|
||||
pub interval: Duration,
|
||||
pub remote_storage_config: Option<RemoteStorageConfig>,
|
||||
pub chunk_size: usize,
|
||||
}
|
||||
|
||||
pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig> {
|
||||
RemoteStorageConfig::from_toml(&s.parse()?)
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheOptions {
|
||||
/// Max number of entries.
|
||||
pub size: usize,
|
||||
/// Entry's time-to-live.
|
||||
pub ttl: Duration,
|
||||
}
|
||||
|
||||
impl CacheOptions {
|
||||
/// Default options for [`crate::console::provider::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut size = None;
|
||||
let mut ttl = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"size" => size = Some(value.parse()?),
|
||||
"ttl" => ttl = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// TTL doesn't matter if cache is always empty.
|
||||
if let Some(0) = size {
|
||||
ttl.get_or_insert(Duration::default());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
size: size.context("missing `size`")?,
|
||||
ttl: ttl.context("missing `ttl`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for CacheOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(Debug)]
|
||||
pub struct ProjectInfoCacheOptions {
|
||||
/// Max number of entries.
|
||||
pub size: usize,
|
||||
/// Entry's time-to-live.
|
||||
pub ttl: Duration,
|
||||
/// Max number of roles per endpoint.
|
||||
pub max_roles: usize,
|
||||
/// Gc interval.
|
||||
pub gc_interval: Duration,
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheOptions {
|
||||
/// Default options for [`crate::console::provider::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut size = None;
|
||||
let mut ttl = None;
|
||||
let mut max_roles = None;
|
||||
let mut gc_interval = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"size" => size = Some(value.parse()?),
|
||||
"ttl" => ttl = Some(humantime::parse_duration(value)?),
|
||||
"max_roles" => max_roles = Some(value.parse()?),
|
||||
"gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// TTL doesn't matter if cache is always empty.
|
||||
if let Some(0) = size {
|
||||
ttl.get_or_insert(Duration::default());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
size: size.context("missing `size`")?,
|
||||
ttl: ttl.context("missing `ttl`")?,
|
||||
max_roles: max_roles.context("missing `max_roles`")?,
|
||||
gc_interval: gc_interval.context("missing `gc_interval`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ProjectInfoCacheOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a config for connect to compute and wake compute.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct RetryConfig {
|
||||
/// Number of times we should retry.
|
||||
pub max_retries: u32,
|
||||
/// Retry duration is base_delay * backoff_factor ^ n, where n starts at 0
|
||||
pub base_delay: tokio::time::Duration,
|
||||
/// Exponential base for retry wait duration
|
||||
pub backoff_factor: f64,
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// Default options for RetryConfig.
|
||||
|
||||
/// Total delay for 5 retries with 200ms base delay and 2 backoff factor is about 6s.
|
||||
pub const CONNECT_TO_COMPUTE_DEFAULT_VALUES: &'static str =
|
||||
"num_retries=5,base_retry_wait_duration=200ms,retry_wait_exponent_base=2";
|
||||
/// Total delay for 8 retries with 100ms base delay and 1.6 backoff factor is about 7s.
|
||||
/// Cplane has timeout of 60s on each request. 8m7s in total.
|
||||
pub const WAKE_COMPUTE_DEFAULT_VALUES: &'static str =
|
||||
"num_retries=8,base_retry_wait_duration=100ms,retry_wait_exponent_base=1.6";
|
||||
|
||||
/// Parse retry options passed via cmdline.
|
||||
/// Example: [`Self::CONNECT_TO_COMPUTE_DEFAULT_VALUES`].
|
||||
pub fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut num_retries = None;
|
||||
let mut base_retry_wait_duration = None;
|
||||
let mut retry_wait_exponent_base = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"num_retries" => num_retries = Some(value.parse()?),
|
||||
"base_retry_wait_duration" => {
|
||||
base_retry_wait_duration = Some(humantime::parse_duration(value)?)
|
||||
}
|
||||
"retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
max_retries: num_retries.context("missing `num_retries`")?,
|
||||
base_delay: base_retry_wait_duration.context("missing `base_retry_wait_duration`")?,
|
||||
backoff_factor: retry_wait_exponent_base
|
||||
.context("missing `retry_wait_exponent_base`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct ConcurrencyLockOptions {
|
||||
/// The number of shards the lock map should have
|
||||
pub shards: usize,
|
||||
/// The number of allowed concurrent requests for each endpoitn
|
||||
#[serde(flatten)]
|
||||
pub limiter: RateLimiterConfig,
|
||||
/// Garbage collection epoch
|
||||
#[serde(deserialize_with = "humantime_serde::deserialize")]
|
||||
pub epoch: Duration,
|
||||
/// Lock timeout
|
||||
#[serde(deserialize_with = "humantime_serde::deserialize")]
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
impl ConcurrencyLockOptions {
|
||||
/// Default options for [`crate::console::provider::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
|
||||
/// Default options for [`crate::console::provider::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
|
||||
"shards=64,permits=100,epoch=10m,timeout=10ms";
|
||||
|
||||
// pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
|
||||
|
||||
/// Parse lock options passed via cmdline.
|
||||
/// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let options = options.trim();
|
||||
if options.starts_with('{') && options.ends_with('}') {
|
||||
return Ok(serde_json::from_str(options)?);
|
||||
}
|
||||
|
||||
let mut shards = None;
|
||||
let mut permits = None;
|
||||
let mut epoch = None;
|
||||
let mut timeout = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"shards" => shards = Some(value.parse()?),
|
||||
"permits" => permits = Some(value.parse()?),
|
||||
"epoch" => epoch = Some(humantime::parse_duration(value)?),
|
||||
"timeout" => timeout = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// these dont matter if lock is disabled
|
||||
if let Some(0) = permits {
|
||||
timeout = Some(Duration::default());
|
||||
epoch = Some(Duration::default());
|
||||
shards = Some(2);
|
||||
}
|
||||
|
||||
let permits = permits.context("missing `permits`")?;
|
||||
let out = Self {
|
||||
shards: shards.context("missing `shards`")?,
|
||||
limiter: RateLimiterConfig {
|
||||
algorithm: RateLimitAlgorithm::Fixed,
|
||||
initial_limit: permits,
|
||||
},
|
||||
epoch: epoch.context("missing `epoch`")?,
|
||||
timeout: timeout.context("missing `timeout`")?,
|
||||
};
|
||||
|
||||
ensure!(out.shards > 1, "shard count must be > 1");
|
||||
ensure!(
|
||||
out.shards.is_power_of_two(),
|
||||
"shard count must be a power of two"
|
||||
);
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ConcurrencyLockOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache lock options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::rate_limiter::Aimd;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_cache_options() -> anyhow::Result<()> {
|
||||
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
|
||||
assert_eq!(size, 4096);
|
||||
assert_eq!(ttl, Duration::from_secs(5 * 60));
|
||||
|
||||
let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
|
||||
assert_eq!(size, 2);
|
||||
assert_eq!(ttl, Duration::from_secs(4 * 60));
|
||||
|
||||
let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
|
||||
assert_eq!(size, 0);
|
||||
assert_eq!(ttl, Duration::from_secs(1));
|
||||
|
||||
let CacheOptions { size, ttl } = "size=0".parse()?;
|
||||
assert_eq!(size, 0);
|
||||
assert_eq!(ttl, Duration::default());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_lock_options() -> anyhow::Result<()> {
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(10 * 60));
|
||||
assert_eq!(timeout, Duration::from_secs(1));
|
||||
assert_eq!(shards, 32);
|
||||
assert_eq!(limiter.initial_limit, 4);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(60));
|
||||
assert_eq!(timeout, Duration::from_millis(100));
|
||||
assert_eq!(shards, 16);
|
||||
assert_eq!(limiter.initial_limit, 8);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "permits=0".parse()?;
|
||||
assert_eq!(epoch, Duration::ZERO);
|
||||
assert_eq!(timeout, Duration::ZERO);
|
||||
assert_eq!(shards, 2);
|
||||
assert_eq!(limiter.initial_limit, 0);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_json_lock_options() -> anyhow::Result<()> {
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
|
||||
.parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(10 * 60));
|
||||
assert_eq!(timeout, Duration::from_secs(1));
|
||||
assert_eq!(shards, 32);
|
||||
assert_eq!(limiter.initial_limit, 44);
|
||||
assert_eq!(
|
||||
limiter.algorithm,
|
||||
RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 5,
|
||||
max: 500,
|
||||
dec: 0.9,
|
||||
inc: 10,
|
||||
utilisation: 0.8
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
22
proxy/core/src/console.rs
Normal file
22
proxy/core/src/console.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Various stuff for dealing with the Neon Console.
|
||||
//! Later we might move some API wrappers here.
|
||||
|
||||
/// Payloads used in the console's APIs.
|
||||
pub mod messages;
|
||||
|
||||
/// Wrappers for console APIs and their mocks.
|
||||
pub mod provider;
|
||||
pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod caches {
|
||||
pub use super::provider::{ApiCaches, NodeInfoCache};
|
||||
}
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod locks {
|
||||
pub use super::provider::ApiLocks;
|
||||
}
|
||||
|
||||
/// Console's management API.
|
||||
pub mod mgmt;
|
||||
447
proxy/core/src/console/messages.rs
Normal file
447
proxy/core/src/console/messages.rs
Normal file
@@ -0,0 +1,447 @@
|
||||
use measured::FixedCardinalityLabel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ConsoleError {
|
||||
pub error: Box<str>,
|
||||
#[serde(skip)]
|
||||
pub http_status_code: http::StatusCode,
|
||||
pub status: Option<Status>,
|
||||
}
|
||||
|
||||
impl ConsoleError {
|
||||
pub fn get_reason(&self) -> Reason {
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.error_info.as_ref())
|
||||
.map(|e| e.reason)
|
||||
.unwrap_or(Reason::Unknown)
|
||||
}
|
||||
pub fn get_user_facing_message(&self) -> String {
|
||||
use super::provider::errors::REQUEST_FAILED;
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map(|m| m.message.clone().into())
|
||||
.unwrap_or_else(|| {
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
match self.http_status_code {
|
||||
http::StatusCode::NOT_FOUND => {
|
||||
// Status 404: failed to get a project-related resource.
|
||||
format!("{REQUEST_FAILED}: endpoint cannot be found")
|
||||
}
|
||||
http::StatusCode::NOT_ACCEPTABLE => {
|
||||
// Status 406: endpoint is disabled (we don't allow connections).
|
||||
format!("{REQUEST_FAILED}: endpoint is disabled")
|
||||
}
|
||||
http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
|
||||
// Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
|
||||
format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ConsoleError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let msg = self
|
||||
.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map(|m| m.message.as_ref())
|
||||
.unwrap_or_else(|| &self.error);
|
||||
write!(f, "{}", msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ConsoleError {
|
||||
fn could_retry(&self) -> bool {
|
||||
// If the error message does not have a status,
|
||||
// the error is unknown and probably should not retry automatically
|
||||
let Some(status) = &self.status else {
|
||||
return false;
|
||||
};
|
||||
|
||||
// retry if the retry info is set.
|
||||
if status.details.retry_info.is_some() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// if no retry info set, attempt to use the error code to guess the retry state.
|
||||
let reason = status
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |e| e.reason);
|
||||
|
||||
reason.can_retry()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Status {
|
||||
pub code: Box<str>,
|
||||
pub message: Box<str>,
|
||||
pub details: Details,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct Details {
|
||||
pub error_info: Option<ErrorInfo>,
|
||||
pub retry_info: Option<RetryInfo>,
|
||||
pub user_facing_message: Option<UserFacingMessage>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
pub struct ErrorInfo {
|
||||
pub reason: Reason,
|
||||
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default)]
|
||||
pub enum Reason {
|
||||
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
|
||||
#[serde(rename = "ROLE_PROTECTED")]
|
||||
RoleProtected,
|
||||
/// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
|
||||
/// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
|
||||
/// access the requested resource.
|
||||
/// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
|
||||
#[serde(rename = "RESOURCE_NOT_FOUND")]
|
||||
ResourceNotFound,
|
||||
/// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested project.
|
||||
#[serde(rename = "PROJECT_NOT_FOUND")]
|
||||
ProjectNotFound,
|
||||
/// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested endpoint.
|
||||
#[serde(rename = "ENDPOINT_NOT_FOUND")]
|
||||
EndpointNotFound,
|
||||
/// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested branch.
|
||||
#[serde(rename = "BRANCH_NOT_FOUND")]
|
||||
BranchNotFound,
|
||||
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
|
||||
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
|
||||
RateLimitExceeded,
|
||||
/// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
|
||||
/// exceeded.
|
||||
#[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
|
||||
NonDefaultBranchComputeTimeExceeded,
|
||||
/// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
|
||||
#[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
|
||||
ActiveTimeQuotaExceeded,
|
||||
/// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
|
||||
#[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
|
||||
ComputeTimeQuotaExceeded,
|
||||
/// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
|
||||
#[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
|
||||
WrittenDataQuotaExceeded,
|
||||
/// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
|
||||
#[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
|
||||
DataTransferQuotaExceeded,
|
||||
/// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
|
||||
#[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
|
||||
LogicalSizeQuotaExceeded,
|
||||
/// RunningOperations indicates that the project already has some running operations
|
||||
/// and scheduling of new ones is prohibited.
|
||||
#[serde(rename = "RUNNING_OPERATIONS")]
|
||||
RunningOperations,
|
||||
/// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
|
||||
#[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
|
||||
ConcurrencyLimitReached,
|
||||
/// LockAlreadyTaken indicates that the we attempted to take a lock that was already taken.
|
||||
#[serde(rename = "LOCK_ALREADY_TAKEN")]
|
||||
LockAlreadyTaken,
|
||||
#[default]
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Reason {
|
||||
pub fn is_not_found(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound
|
||||
)
|
||||
}
|
||||
|
||||
pub fn can_retry(&self) -> bool {
|
||||
match self {
|
||||
// do not retry role protected errors
|
||||
// not a transitive error
|
||||
Reason::RoleProtected => false,
|
||||
// on retry, it will still not be found
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
| Reason::ActiveTimeQuotaExceeded
|
||||
| Reason::ComputeTimeQuotaExceeded
|
||||
| Reason::WrittenDataQuotaExceeded
|
||||
| Reason::DataTransferQuotaExceeded
|
||||
| Reason::LogicalSizeQuotaExceeded => false,
|
||||
// transitive error. control plane is currently busy
|
||||
// but might be ready soon
|
||||
Reason::RunningOperations
|
||||
| Reason::ConcurrencyLimitReached
|
||||
| Reason::LockAlreadyTaken => true,
|
||||
// unknown error. better not retry it.
|
||||
Reason::Unknown => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
pub struct RetryInfo {
|
||||
pub retry_delay_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct UserFacingMessage {
|
||||
pub message: Box<str>,
|
||||
}
|
||||
|
||||
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
|
||||
/// Returned by the `/proxy_get_role_secret` API method.
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl fmt::Debug for GetRoleSecret {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("GetRoleSecret").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
/// Returned by the `/proxy_wake_compute` API method.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WakeCompute {
|
||||
pub address: Box<str>,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
/// Async response which concludes the link auth flow.
|
||||
/// Also known as `kickResponse` in the console.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct KickSession<'a> {
|
||||
/// Session ID is assigned by the proxy.
|
||||
pub session_id: &'a str,
|
||||
|
||||
/// Compute node connection params.
|
||||
#[serde(deserialize_with = "KickSession::parse_db_info")]
|
||||
pub result: DatabaseInfo,
|
||||
}
|
||||
|
||||
impl KickSession<'_> {
|
||||
fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
enum Wrapper {
|
||||
// Currently, console only reports `Success`.
|
||||
// `Failure(String)` used to be here... RIP.
|
||||
Success(DatabaseInfo),
|
||||
}
|
||||
|
||||
Wrapper::deserialize(des).map(|x| match x {
|
||||
Wrapper::Success(info) => info,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute node connection params.
|
||||
#[derive(Deserialize)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: Box<str>,
|
||||
pub port: u16,
|
||||
pub dbname: Box<str>,
|
||||
pub user: Box<str>,
|
||||
/// Console always provides a password, but it might
|
||||
/// be inconvenient for debug with local PG instance.
|
||||
pub password: Option<Box<str>>,
|
||||
pub aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.field("dbname", &self.dbname)
|
||||
.field("user", &self.user)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Various labels for prometheus metrics.
|
||||
/// Also known as `ProxyMetricsAuxInfo` in the console.
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct MetricsAuxInfo {
|
||||
pub endpoint_id: EndpointIdInt,
|
||||
pub project_id: ProjectIdInt,
|
||||
pub branch_id: BranchIdInt,
|
||||
#[serde(default)]
|
||||
pub cold_start_info: ColdStartInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ColdStartInfo {
|
||||
#[default]
|
||||
Unknown,
|
||||
/// Compute was already running
|
||||
Warm,
|
||||
#[serde(rename = "pool_hit")]
|
||||
#[label(rename = "pool_hit")]
|
||||
/// Compute was not running but there was an available VM
|
||||
VmPoolHit,
|
||||
#[serde(rename = "pool_miss")]
|
||||
#[label(rename = "pool_miss")]
|
||||
/// Compute was not running and there were no VMs available
|
||||
VmPoolMiss,
|
||||
|
||||
// not provided by control plane
|
||||
/// Connection available from HTTP pool
|
||||
HttpPoolHit,
|
||||
/// Cached connection info
|
||||
WarmCached,
|
||||
}
|
||||
|
||||
impl ColdStartInfo {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ColdStartInfo::Unknown => "unknown",
|
||||
ColdStartInfo::Warm => "warm",
|
||||
ColdStartInfo::VmPoolHit => "pool_hit",
|
||||
ColdStartInfo::VmPoolMiss => "pool_miss",
|
||||
ColdStartInfo::HttpPoolHit => "http_pool_hit",
|
||||
ColdStartInfo::WarmCached => "warm_cached",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn dummy_aux() -> serde_json::Value {
|
||||
json!({
|
||||
"endpoint_id": "endpoint",
|
||||
"project_id": "project",
|
||||
"branch_id": "branch",
|
||||
"cold_start_info": "unknown",
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_kick_session() -> anyhow::Result<()> {
|
||||
// This is what the console's kickResponse looks like.
|
||||
let json = json!({
|
||||
"session_id": "deadbeef",
|
||||
"result": {
|
||||
"Success": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
"aux": dummy_aux(),
|
||||
}
|
||||
}
|
||||
});
|
||||
let _: KickSession = serde_json::from_str(&json.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
// with password
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
// without password
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
// new field (forward compatibility)
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"project": "hello_world",
|
||||
"N.E.W": "forward compatibility check",
|
||||
"aux": dummy_aux(),
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wake_compute() -> anyhow::Result<()> {
|
||||
let json = json!({
|
||||
"address": "0.0.0.0",
|
||||
"aux": dummy_aux(),
|
||||
});
|
||||
let _: WakeCompute = serde_json::from_str(&json.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_get_role_secret() -> anyhow::Result<()> {
|
||||
// Empty `allowed_ips` field.
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
"project_id": "project",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
115
proxy/core/src/console/mgmt.rs
Normal file
115
proxy/core/src/console/mgmt.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use crate::{
|
||||
console::messages::{DatabaseInfo, KickSession},
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::convert::Infallible;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub fn get_waiter(
|
||||
psql_session_id: impl Into<String>,
|
||||
) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
|
||||
CPLANE_WAITERS.register(psql_session_id.into())
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
}
|
||||
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
let span = info_span!("mgmt", peer = %peer_addr);
|
||||
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
info!("serving a new console management API connection");
|
||||
|
||||
// these might be long running connections, have a separate logging for cancelling
|
||||
// on shutdown and other ways of stopping.
|
||||
let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
|
||||
let _e = span.entered();
|
||||
info!("console management API task cancelled");
|
||||
});
|
||||
|
||||
if let Err(e) = handle_connection(socket).await {
|
||||
error!("serving failed with an error: {e}");
|
||||
} else {
|
||||
info!("serving completed");
|
||||
}
|
||||
|
||||
// we can no longer get dropped
|
||||
scopeguard::ScopeGuard::into_inner(cancelled);
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
pgbackend
|
||||
.run(&mut MgmtHandler, &CancellationToken::new())
|
||||
.await
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = DatabaseInfo;
|
||||
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
query: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match notify(resp.session_id, resp.result) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
561
proxy/core/src/console/provider.rs
Normal file
561
proxy/core/src/console/provider.rs
Normal file
@@ -0,0 +1,561 @@
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::{ConsoleError, MetricsAuxInfo};
|
||||
use crate::{
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
pub mod errors {
|
||||
use crate::{
|
||||
console::messages::{self, ConsoleError, Reason},
|
||||
error::{io_error, ReportableError, UserFacingError},
|
||||
proxy::retry::CouldRetry,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiLockError;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
pub const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {0}")]
|
||||
Console(ConsoleError),
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub fn get_reason(&self) -> messages::Reason {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
Console(e) => e.get_reason(),
|
||||
_ => messages::Reason::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
Console(c) => c.get_user_facing_message(),
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ApiError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiError::Console(e) => {
|
||||
use crate::error::ErrorKind::*;
|
||||
match e.get_reason() {
|
||||
Reason::RoleProtected => User,
|
||||
Reason::ResourceNotFound => User,
|
||||
Reason::ProjectNotFound => User,
|
||||
Reason::EndpointNotFound => User,
|
||||
Reason::BranchNotFound => User,
|
||||
Reason::RateLimitExceeded => ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => User,
|
||||
Reason::ActiveTimeQuotaExceeded => User,
|
||||
Reason::ComputeTimeQuotaExceeded => User,
|
||||
Reason::WrittenDataQuotaExceeded => User,
|
||||
Reason::DataTransferQuotaExceeded => User,
|
||||
Reason::LogicalSizeQuotaExceeded => User,
|
||||
Reason::ConcurrencyLimitReached => ControlPlane,
|
||||
Reason::LockAlreadyTaken => ControlPlane,
|
||||
Reason::RunningOperations => ControlPlane,
|
||||
Reason::Unknown => match &e {
|
||||
ConsoleError {
|
||||
http_status_code:
|
||||
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
|
||||
..
|
||||
} => crate::error::ErrorKind::User,
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
|
||||
error,
|
||||
..
|
||||
} if error.contains(
|
||||
"compute time quota of non-primary branches is exceeded",
|
||||
) =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::LOCKED,
|
||||
error,
|
||||
..
|
||||
} if error.contains("quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => crate::error::ErrorKind::ServiceRateLimit,
|
||||
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
|
||||
},
|
||||
}
|
||||
}
|
||||
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ApiError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
// retry some transport errors
|
||||
Self::Transport(io) => io.could_retry(),
|
||||
Self::Console(e) => e.could_retry(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_middleware::Error> for ApiError {
|
||||
fn from(e: reqwest_middleware::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
|
||||
#[error("Too many connections attempts")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
|
||||
TooManyConnections => self.to_string(),
|
||||
|
||||
TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
WakeComputeError::ApiError(e) => e.get_error_kind(),
|
||||
WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for WakeComputeError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
WakeComputeError::BadComputeAddress(_) => false,
|
||||
WakeComputeError::ApiError(e) => e.could_retry(),
|
||||
WakeComputeError::TooManyConnections => false,
|
||||
WakeComputeError::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AuthInfo {
|
||||
pub secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
pub(crate) trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
/// Returns option because user not found situation is special.
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum ConsoleBackend {
|
||||
/// Current Cloud API (V2).
|
||||
Console(neon::Api),
|
||||
/// Local mock of Cloud API (V2).
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(mock::Api),
|
||||
/// Internal testing
|
||||
#[cfg(test)]
|
||||
Test(Box<dyn crate::auth::backend::TestBackend>),
|
||||
}
|
||||
|
||||
impl Api for ConsoleBackend {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(_) => unreachable!("this function should never be called in the test backend"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(api) => api.get_allowed_ips_and_secret(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
use ConsoleBackend::*;
|
||||
|
||||
match self {
|
||||
Console(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(api) => api.wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
/// List of all valid endpoints.
|
||||
pub endpoints_cache: Arc<EndpointsCache>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
endpoint_cache_config: EndpointCacheConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiLocks<K> {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<K, Arc<DynamicLimiter>>,
|
||||
config: RateLimiterConfig,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ApiLockError {
|
||||
#[error("timeout acquiring resource permit")]
|
||||
TimeoutError(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
impl ReportableError for ApiLockError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
config: RateLimiterConfig,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
) -> prometheus::Result<Self> {
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
if self.config.initial_limit == 0 {
|
||||
return Ok(WakeComputePermit {
|
||||
permit: Token::disabled(),
|
||||
});
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
// get fast path
|
||||
if let Some(semaphore) = self.node_locks.get(key) {
|
||||
semaphore.clone()
|
||||
} else {
|
||||
self.node_locks
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.metrics.semaphores_registered.inc();
|
||||
DynamicLimiter::new(self.config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = semaphore.acquire_timeout(self.timeout).await;
|
||||
|
||||
self.metrics
|
||||
.semaphore_acquire_seconds
|
||||
.observe(now.elapsed().as_secs_f64());
|
||||
info!("acquired permit {:?}", now.elapsed().as_secs_f64());
|
||||
Ok(WakeComputePermit { permit: permit? })
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self) {
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let mut interval =
|
||||
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
|
||||
loop {
|
||||
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
||||
interval.tick().await;
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
info!(
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
);
|
||||
let mut lock = shard.write();
|
||||
let timer = self.metrics.reclamation_lag_seconds.start_timer();
|
||||
let count = lock
|
||||
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
|
||||
.count();
|
||||
drop(lock);
|
||||
self.metrics.semaphores_unregistered.inc_by(count as u64);
|
||||
timer.observe();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WakeComputePermit {
|
||||
permit: Token,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub fn should_check_cache(&self) -> bool {
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome)
|
||||
}
|
||||
pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
Ok(_) => self.release(Outcome::Success),
|
||||
Err(_) => self.release(Outcome::Overload),
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
199
proxy/core/src/console/provider/mock.rs
Normal file
199
proxy/core/src/console/provider/mock.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use crate::{
|
||||
console::{
|
||||
messages::MetricsAuxInfo,
|
||||
provider::{CachedAllowedIps, CachedRoleSecret},
|
||||
},
|
||||
BranchId, EndpointId, ProjectId,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
#[error("Failed to read password: {0}")]
|
||||
PasswordNotSet(tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl From<MockApiError> for ApiError {
|
||||
fn from(e: MockApiError) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for ApiError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: ApiUrl,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
pub fn new(endpoint: ApiUrl) -> Self {
|
||||
Self { endpoint }
|
||||
}
|
||||
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let secret = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*user_info.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(entry) => {
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
}
|
||||
None => {
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
None
|
||||
}
|
||||
};
|
||||
let allowed_ips = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&user_info.endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',')
|
||||
.map(|s| IpPattern::from_str(s).unwrap())
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
Ok((secret, allowed_ips))
|
||||
}
|
||||
.map_err(crate::error::log_error::<GetAuthInfoError>)
|
||||
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
project_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.ssl_mode(SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_execute_postgres_query(
|
||||
client: &Client,
|
||||
query: &str,
|
||||
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
|
||||
idx: &str,
|
||||
) -> Result<Option<String>, GetAuthInfoError> {
|
||||
let rows = client.query(query, params).await?;
|
||||
|
||||
// We can get at most one row, because `rolname` is unique.
|
||||
let row = match rows.first() {
|
||||
Some(row) => row,
|
||||
// This means that the user doesn't exist, so there can be no secret.
|
||||
// However, this is still a *valid* outcome which is very similar
|
||||
// to getting `404 Not found` from the Neon console.
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
|
||||
Ok(Some(entry))
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
self.do_get_auth_info(user_info).await?.secret,
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
Ok((
|
||||
Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
let text = input.strip_prefix("md5")?;
|
||||
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
425
proxy/core/src/console/provider/neon.rs
Normal file
425
proxy/core/src/console/provider/neon.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! Production console backend.
|
||||
|
||||
use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
compute,
|
||||
console::messages::{ColdStartInfo, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use futures::TryFutureExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
jwt: String,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
Err(_) => "".to_string(),
|
||||
};
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
jwt,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn url(&self) -> &str {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &user_info.endpoint.normalize())
|
||||
.await
|
||||
{
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Ok(AuthInfo::default());
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
("role", user_info.user.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = match parse_body::<GetRoleSecret>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
// TODO(anna): retry
|
||||
Err(e) => {
|
||||
if e.get_reason().is_not_found() {
|
||||
return Ok(AuthInfo::default());
|
||||
} else {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let secret = if body.role_secret.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthSecret::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
Some(secret)
|
||||
};
|
||||
let allowed_ips = body.allowed_ips.unwrap_or_default();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_number
|
||||
.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
project_id: body.project_id,
|
||||
})
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let mut request_builder = self
|
||||
.endpoint
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
]);
|
||||
|
||||
let options = user_info.options.to_deep_object();
|
||||
if !options.is_empty() {
|
||||
request_builder = request_builder.query(&options);
|
||||
}
|
||||
|
||||
let request = request_builder.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux,
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
let user = &user_info.user;
|
||||
if let Some(role_secret) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_role_secret(normalized_ep, user)
|
||||
{
|
||||
return Ok(role_secret);
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
// When we just got a secret, we don't need to invalidate it.
|
||||
Ok(Cached::new_uncached(auth_info.secret))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok((allowed_ips, None));
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Miss);
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let user = &user_info.user;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok((
|
||||
Cached::new_uncached(allowed_ips),
|
||||
Some(Cached::new_uncached(auth_info.secret)),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
macro_rules! check_cache {
|
||||
() => {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
let (cached, info) = cached.take_value();
|
||||
let info = info.map_err(|c| {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
WakeComputeError::ApiError(ApiError::Console(*c))
|
||||
})?;
|
||||
|
||||
debug!(key = &*key, "found cached compute node info");
|
||||
ctx.set_project(info.aux.clone());
|
||||
return Ok(cached.map(|()| info));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
check_cache!();
|
||||
|
||||
let permit = self.locks.get_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
// double check
|
||||
if permit.should_check_cache() {
|
||||
check_cache!();
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !self
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize_intern(), 1)
|
||||
{
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
}
|
||||
|
||||
let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
|
||||
match node {
|
||||
Ok(node) => {
|
||||
ctx.set_project(node.aux.clone());
|
||||
debug!(key = &*key, "created a cache entry for woken compute node");
|
||||
|
||||
let mut stored_node = node.clone();
|
||||
// store the cached node as 'warm_cached'
|
||||
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
|
||||
|
||||
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
|
||||
|
||||
Ok(cached.map(|()| node))
|
||||
}
|
||||
Err(err) => match err {
|
||||
WakeComputeError::ApiError(ApiError::Console(err)) => {
|
||||
let Some(status) = &err.status else {
|
||||
return Err(WakeComputeError::ApiError(ApiError::Console(err)));
|
||||
};
|
||||
|
||||
let reason = status
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |x| x.reason);
|
||||
|
||||
// if we can retry this error, do not cache it.
|
||||
if reason.can_retry() {
|
||||
return Err(WakeComputeError::ApiError(ApiError::Console(err)));
|
||||
}
|
||||
|
||||
// at this point, we should only have quota errors.
|
||||
debug!(
|
||||
key = &*key,
|
||||
"created a cache entry for the wake compute error"
|
||||
);
|
||||
|
||||
self.caches.node_info.insert_ttl(
|
||||
key,
|
||||
Err(Box::new(err.clone())),
|
||||
Duration::from_secs(30),
|
||||
);
|
||||
|
||||
Err(WakeComputeError::ApiError(ApiError::Console(err)))
|
||||
}
|
||||
err => return Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: http::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
}
|
||||
let s = response.bytes().await?;
|
||||
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
|
||||
info!("response_error plaintext: {:?}", s);
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ConsoleError {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
http_status_code: status,
|
||||
status: None,
|
||||
}
|
||||
});
|
||||
body.http_status_code = status;
|
||||
|
||||
error!("console responded with an error ({status}): {body:?}");
|
||||
Err(ApiError::Console(body))
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
let (host, port) = input.rsplit_once(':')?;
|
||||
let ipv6_brackets: &[_] = &['[', ']'];
|
||||
Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port_v4() {
|
||||
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
|
||||
assert_eq!(host, "127.0.0.1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port_v6() {
|
||||
let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
|
||||
assert_eq!(host, "2001:db8::1");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_host_port_url() {
|
||||
let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
|
||||
.expect("failed to parse");
|
||||
assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
|
||||
assert_eq!(port, 5432);
|
||||
}
|
||||
}
|
||||
386
proxy/core/src/context.rs
Normal file
386
proxy/core/src/context.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Connection request monitoring contexts
|
||||
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{field::display, info, info_span, Span};
|
||||
use try_lock::TryLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
console::messages::{ColdStartInfo, MetricsAuxInfo},
|
||||
error::ErrorKind,
|
||||
intern::{BranchIdInt, ProjectIdInt},
|
||||
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},
|
||||
DbName, EndpointId, RoleName,
|
||||
};
|
||||
|
||||
use self::parquet::RequestData;
|
||||
|
||||
pub mod parquet;
|
||||
|
||||
pub static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
|
||||
pub static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
|
||||
|
||||
/// Context data for a single request to connect to a database.
|
||||
///
|
||||
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
|
||||
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
|
||||
pub struct RequestMonitoring(
|
||||
/// To allow easier use of the ctx object, we have interior mutability.
|
||||
/// I would typically use a RefCell but that would break the `Send` requirements
|
||||
/// so we need something with thread-safety. `TryLock` is a cheap alternative
|
||||
/// that offers similar semantics to a `RefCell` but with synchronisation.
|
||||
TryLock<RequestMonitoringInner>,
|
||||
);
|
||||
|
||||
struct RequestMonitoringInner {
|
||||
pub peer_addr: IpAddr,
|
||||
pub session_id: Uuid,
|
||||
pub protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
region: &'static str,
|
||||
pub span: Span,
|
||||
|
||||
// filled in as they are discovered
|
||||
project: Option<ProjectIdInt>,
|
||||
branch: Option<BranchIdInt>,
|
||||
endpoint_id: Option<EndpointId>,
|
||||
dbname: Option<DbName>,
|
||||
user: Option<RoleName>,
|
||||
application: Option<SmolStr>,
|
||||
error_kind: Option<ErrorKind>,
|
||||
pub(crate) auth_method: Option<AuthMethod>,
|
||||
success: bool,
|
||||
pub(crate) cold_start_info: ColdStartInfo,
|
||||
pg_options: Option<StartupMessageParams>,
|
||||
|
||||
// extra
|
||||
// This sender is here to keep the request monitoring channel open while requests are taking place.
|
||||
sender: Option<mpsc::UnboundedSender<RequestData>>,
|
||||
// This sender is only used to log the length of session in case of success.
|
||||
disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
|
||||
pub latency_timer: LatencyTimer,
|
||||
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
|
||||
rejected: Option<bool>,
|
||||
disconnect_timestamp: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AuthMethod {
|
||||
// aka link aka passwordless
|
||||
Web,
|
||||
ScramSha256,
|
||||
ScramSha256Plus,
|
||||
Cleartext,
|
||||
}
|
||||
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
protocol: Protocol,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
let span = info_span!(
|
||||
"connect_request",
|
||||
%protocol,
|
||||
?session_id,
|
||||
%peer_addr,
|
||||
ep = tracing::field::Empty,
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let inner = RequestMonitoringInner {
|
||||
peer_addr,
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
region,
|
||||
span,
|
||||
|
||||
project: None,
|
||||
branch: None,
|
||||
endpoint_id: None,
|
||||
dbname: None,
|
||||
user: None,
|
||||
application: None,
|
||||
error_kind: None,
|
||||
auth_method: None,
|
||||
success: false,
|
||||
rejected: None,
|
||||
cold_start_info: ColdStartInfo::Unknown,
|
||||
pg_options: None,
|
||||
|
||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
|
||||
latency_timer: LatencyTimer::new(protocol),
|
||||
disconnect_timestamp: None,
|
||||
};
|
||||
|
||||
Self(TryLock::new(inner))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
|
||||
}
|
||||
|
||||
pub fn console_application_name(&self) -> String {
|
||||
let this = self.0.try_lock().expect("should not deadlock");
|
||||
format!(
|
||||
"{}/{}",
|
||||
this.application.as_deref().unwrap_or_default(),
|
||||
this.protocol
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_rejected(&self, rejected: bool) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.rejected = Some(rejected);
|
||||
}
|
||||
|
||||
pub fn set_cold_start_info(&self, info: ColdStartInfo) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_cold_start_info(info);
|
||||
}
|
||||
|
||||
pub fn set_db_options(&self, options: StartupMessageParams) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.set_application(options.get("application_name").map(SmolStr::from));
|
||||
if let Some(user) = options.get("user") {
|
||||
this.set_user(user.into());
|
||||
}
|
||||
if let Some(dbname) = options.get("database") {
|
||||
this.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
this.pg_options = Some(options);
|
||||
}
|
||||
|
||||
pub fn set_project(&self, x: MetricsAuxInfo) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
if this.endpoint_id.is_none() {
|
||||
this.set_endpoint_id(x.endpoint_id.as_str().into())
|
||||
}
|
||||
this.branch = Some(x.branch_id);
|
||||
this.project = Some(x.project_id);
|
||||
this.set_cold_start_info(x.cold_start_info);
|
||||
}
|
||||
|
||||
pub fn set_project_id(&self, project_id: ProjectIdInt) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.project = Some(project_id);
|
||||
}
|
||||
|
||||
pub fn set_endpoint_id(&self, endpoint_id: EndpointId) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_endpoint_id(endpoint_id);
|
||||
}
|
||||
|
||||
pub fn set_dbname(&self, dbname: DbName) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_dbname(dbname);
|
||||
}
|
||||
|
||||
pub fn set_user(&self, user: RoleName) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_user(user);
|
||||
}
|
||||
|
||||
pub fn set_auth_method(&self, auth_method: AuthMethod) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.auth_method = Some(auth_method);
|
||||
}
|
||||
|
||||
pub fn has_private_peer_addr(&self) -> bool {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.has_private_peer_addr()
|
||||
}
|
||||
|
||||
pub fn set_error_kind(&self, kind: ErrorKind) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
// Do not record errors from the private address to metrics.
|
||||
if !this.has_private_peer_addr() {
|
||||
Metrics::get().proxy.errors_total.inc(kind);
|
||||
}
|
||||
if let Some(ep) = &this.endpoint_id {
|
||||
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
|
||||
let label = metric.with_labels(kind);
|
||||
metric.get_metric(label).measure(ep);
|
||||
}
|
||||
this.error_kind = Some(kind);
|
||||
}
|
||||
|
||||
pub fn set_success(&self) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.success = true;
|
||||
}
|
||||
|
||||
pub fn log_connect(&self) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.log_connect();
|
||||
}
|
||||
|
||||
pub fn protocol(&self) -> Protocol {
|
||||
self.0.try_lock().expect("should not deadlock").protocol
|
||||
}
|
||||
|
||||
pub fn span(&self) -> Span {
|
||||
self.0.try_lock().expect("should not deadlock").span.clone()
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> Uuid {
|
||||
self.0.try_lock().expect("should not deadlock").session_id
|
||||
}
|
||||
|
||||
pub fn peer_addr(&self) -> IpAddr {
|
||||
self.0.try_lock().expect("should not deadlock").peer_addr
|
||||
}
|
||||
|
||||
pub fn cold_start_info(&self) -> ColdStartInfo {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.cold_start_info
|
||||
}
|
||||
|
||||
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause {
|
||||
LatencyTimerPause {
|
||||
ctx: self,
|
||||
start: tokio::time::Instant::now(),
|
||||
waiting_for,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn success(&self) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.latency_timer
|
||||
.success()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LatencyTimerPause<'a> {
|
||||
ctx: &'a RequestMonitoring,
|
||||
start: tokio::time::Instant,
|
||||
waiting_for: Waiting,
|
||||
}
|
||||
|
||||
impl Drop for LatencyTimerPause<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.ctx
|
||||
.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.latency_timer
|
||||
.unpause(self.start, self.waiting_for);
|
||||
}
|
||||
}
|
||||
|
||||
impl RequestMonitoringInner {
|
||||
fn set_cold_start_info(&mut self, info: ColdStartInfo) {
|
||||
self.cold_start_info = info;
|
||||
self.latency_timer.cold_start_info(info);
|
||||
}
|
||||
|
||||
fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
|
||||
if self.endpoint_id.is_none() {
|
||||
self.span.record("ep", display(&endpoint_id));
|
||||
let metric = &Metrics::get().proxy.connecting_endpoints;
|
||||
let label = metric.with_labels(self.protocol);
|
||||
metric.get_metric(label).measure(&endpoint_id);
|
||||
self.endpoint_id = Some(endpoint_id);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_application(&mut self, app: Option<SmolStr>) {
|
||||
if let Some(app) = app {
|
||||
self.application = Some(app);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_dbname(&mut self, dbname: DbName) {
|
||||
self.dbname = Some(dbname);
|
||||
}
|
||||
|
||||
fn set_user(&mut self, user: RoleName) {
|
||||
self.span.record("role", display(&user));
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
fn has_private_peer_addr(&self) -> bool {
|
||||
match self.peer_addr {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn log_connect(&mut self) {
|
||||
let outcome = if self.success {
|
||||
ConnectOutcome::Success
|
||||
} else {
|
||||
ConnectOutcome::Failed
|
||||
};
|
||||
if let Some(rejected) = self.rejected {
|
||||
let ep = self
|
||||
.endpoint_id
|
||||
.as_ref()
|
||||
.map(|x| x.as_str())
|
||||
.unwrap_or_default();
|
||||
// This makes sense only if cache is disabled
|
||||
info!(
|
||||
?outcome,
|
||||
?rejected,
|
||||
?ep,
|
||||
"check endpoint is valid with outcome"
|
||||
);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.invalid_endpoints_total
|
||||
.inc(InvalidEndpointsGroup {
|
||||
protocol: self.protocol,
|
||||
rejected: rejected.into(),
|
||||
outcome,
|
||||
});
|
||||
}
|
||||
if let Some(tx) = self.sender.take() {
|
||||
let _: Result<(), _> = tx.send(RequestData::from(&*self));
|
||||
}
|
||||
}
|
||||
|
||||
fn log_disconnect(&mut self) {
|
||||
// If we are here, it's guaranteed that the user successfully connected to the endpoint.
|
||||
// Here we log the length of the session.
|
||||
self.disconnect_timestamp = Some(Utc::now());
|
||||
if let Some(tx) = self.disconnect_sender.take() {
|
||||
let _: Result<(), _> = tx.send(RequestData::from(&*self));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RequestMonitoringInner {
|
||||
fn drop(&mut self) {
|
||||
if self.sender.is_some() {
|
||||
self.log_connect();
|
||||
} else {
|
||||
self.log_disconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
754
proxy/core/src/context/parquet.rs
Normal file
754
proxy/core/src/context/parquet.rs
Normal file
@@ -0,0 +1,754 @@
|
||||
use std::{sync::Arc, time::SystemTime};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::{buf::Writer, BufMut, BytesMut};
|
||||
use chrono::{Datelike, Timelike};
|
||||
use futures::{Stream, StreamExt};
|
||||
use parquet::{
|
||||
basic::Compression,
|
||||
file::{
|
||||
metadata::RowGroupMetaDataPtr,
|
||||
properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE},
|
||||
writer::SerializedFileWriter,
|
||||
},
|
||||
record::RecordWriter,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
|
||||
use serde::ser::SerializeMap;
|
||||
use tokio::{sync::mpsc, time};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, Span};
|
||||
use utils::backoff;
|
||||
|
||||
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
|
||||
|
||||
use super::{RequestMonitoringInner, LOG_CHAN};
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
/// Storage location to upload the parquet files to.
|
||||
/// Encoded as toml (same format as pageservers), eg
|
||||
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_remote_storage: Option<RemoteStorageConfig>,
|
||||
|
||||
#[clap(long, value_parser = remote_storage_from_toml)]
|
||||
parquet_upload_disconnect_events_remote_storage: Option<RemoteStorageConfig>,
|
||||
|
||||
/// How many rows to include in a row group
|
||||
#[clap(long, default_value_t = 8192)]
|
||||
parquet_upload_row_group_size: usize,
|
||||
|
||||
/// How large each column page should be in bytes
|
||||
#[clap(long, default_value_t = DEFAULT_PAGE_SIZE)]
|
||||
parquet_upload_page_size: usize,
|
||||
|
||||
/// How large the total parquet file should be in bytes
|
||||
#[clap(long, default_value_t = 100_000_000)]
|
||||
parquet_upload_size: i64,
|
||||
|
||||
/// How long to wait before forcing a file upload
|
||||
#[clap(long, default_value = "20m", value_parser = humantime::parse_duration)]
|
||||
parquet_upload_maximum_duration: tokio::time::Duration,
|
||||
|
||||
/// What level of compression to use
|
||||
#[clap(long, default_value_t = Compression::UNCOMPRESSED)]
|
||||
parquet_upload_compression: Compression,
|
||||
}
|
||||
|
||||
// Occasional network issues and such can cause remote operations to fail, and
|
||||
// that's expected. If a upload fails, we log it at info-level, and retry.
|
||||
// But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN
|
||||
// level instead, as repeated failures can mean a more serious problem. If it
|
||||
// fails more than FAILED_UPLOAD_RETRIES times, we give up
|
||||
pub const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3;
|
||||
pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
|
||||
|
||||
// the parquet crate leaves a lot to be desired...
|
||||
// what follows is an attempt to write parquet files with minimal allocs.
|
||||
// complication: parquet is a columnar format, while we want to write in as rows.
|
||||
// design:
|
||||
// * we batch up to 1024 rows, then flush them into a 'row group'
|
||||
// * after each rowgroup write, we check the length of the file and upload to s3 if large enough
|
||||
|
||||
#[derive(parquet_derive::ParquetRecordWriter)]
|
||||
pub struct RequestData {
|
||||
region: &'static str,
|
||||
protocol: &'static str,
|
||||
/// Must be UTC. The derive macro doesn't like the timezones
|
||||
timestamp: chrono::NaiveDateTime,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: String,
|
||||
username: Option<String>,
|
||||
application_name: Option<String>,
|
||||
endpoint_id: Option<String>,
|
||||
database: Option<String>,
|
||||
project: Option<String>,
|
||||
branch: Option<String>,
|
||||
pg_options: Option<String>,
|
||||
auth_method: Option<&'static str>,
|
||||
error: Option<&'static str>,
|
||||
/// Success is counted if we form a HTTP response with sql rows inside
|
||||
/// Or if we make it to proxy_pass
|
||||
success: bool,
|
||||
/// Indicates if the cplane started the new compute node for this request.
|
||||
cold_start_info: &'static str,
|
||||
/// Tracks time from session start (HTTP request/libpq TCP handshake)
|
||||
/// Through to success/failure
|
||||
duration_us: u64,
|
||||
/// If the session was successful after the disconnect, will be created one more event with filled `disconnect_timestamp`.
|
||||
disconnect_timestamp: Option<chrono::NaiveDateTime>,
|
||||
}
|
||||
|
||||
struct Options<'a> {
|
||||
options: &'a StartupMessageParams,
|
||||
}
|
||||
|
||||
impl<'a> serde::Serialize for Options<'a> {
|
||||
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut state = s.serialize_map(None)?;
|
||||
for (k, v) in self.options.iter() {
|
||||
state.serialize_entry(k, v)?;
|
||||
}
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RequestMonitoringInner> for RequestData {
|
||||
fn from(value: &RequestMonitoringInner) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
peer_addr: value.peer_addr.to_string(),
|
||||
timestamp: value.first_packet.naive_utc(),
|
||||
username: value.user.as_deref().map(String::from),
|
||||
application_name: value.application.as_deref().map(String::from),
|
||||
endpoint_id: value.endpoint_id.as_deref().map(String::from),
|
||||
database: value.dbname.as_deref().map(String::from),
|
||||
project: value.project.as_deref().map(String::from),
|
||||
branch: value.branch.as_deref().map(String::from),
|
||||
pg_options: value
|
||||
.pg_options
|
||||
.as_ref()
|
||||
.and_then(|options| serde_json::to_string(&Options { options }).ok()),
|
||||
auth_method: value.auth_method.as_ref().map(|x| match x {
|
||||
super::AuthMethod::Web => "web",
|
||||
super::AuthMethod::ScramSha256 => "scram_sha_256",
|
||||
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
|
||||
super::AuthMethod::Cleartext => "cleartext",
|
||||
}),
|
||||
protocol: value.protocol.as_str(),
|
||||
region: value.region,
|
||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||
success: value.success,
|
||||
cold_start_info: value.cold_start_info.as_str(),
|
||||
duration_us: SystemTime::from(value.first_packet)
|
||||
.elapsed()
|
||||
.unwrap_or_default()
|
||||
.as_micros() as u64, // 584 millenia... good enough
|
||||
disconnect_timestamp: value.disconnect_timestamp.map(|x| x.naive_utc()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parquet request context worker
|
||||
///
|
||||
/// It listened on a channel for all completed requests, extracts the data and writes it into a parquet file,
|
||||
/// then uploads a completed batch to S3
|
||||
pub async fn worker(
|
||||
cancellation_token: CancellationToken,
|
||||
config: ParquetUploadArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
|
||||
tracing::warn!("parquet request upload: no s3 bucket configured");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
LOG_CHAN.set(tx.downgrade()).unwrap();
|
||||
|
||||
// setup row stream that will close on cancellation
|
||||
let cancellation_token2 = cancellation_token.clone();
|
||||
tokio::spawn(async move {
|
||||
cancellation_token2.cancelled().await;
|
||||
// dropping this sender will cause the channel to close only once
|
||||
// all the remaining inflight requests have been completed.
|
||||
drop(tx);
|
||||
});
|
||||
let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
|
||||
let rx = rx.map(RequestData::from);
|
||||
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config)
|
||||
.await
|
||||
.context("remote storage init")?;
|
||||
|
||||
let properties = WriterProperties::builder()
|
||||
.set_data_page_size_limit(config.parquet_upload_page_size)
|
||||
.set_compression(config.parquet_upload_compression);
|
||||
|
||||
let parquet_config = ParquetConfig {
|
||||
propeties: Arc::new(properties.build()),
|
||||
rows_per_group: config.parquet_upload_row_group_size,
|
||||
file_size: config.parquet_upload_size,
|
||||
max_duration: config.parquet_upload_maximum_duration,
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
// TODO(anna): consider moving this to a separate function.
|
||||
if let Some(disconnect_events_storage_config) =
|
||||
config.parquet_upload_disconnect_events_remote_storage
|
||||
{
|
||||
let (tx_disconnect, mut rx_disconnect) = mpsc::unbounded_channel();
|
||||
LOG_CHAN_DISCONNECT.set(tx_disconnect.downgrade()).unwrap();
|
||||
|
||||
// setup row stream that will close on cancellation
|
||||
tokio::spawn(async move {
|
||||
cancellation_token.cancelled().await;
|
||||
// dropping this sender will cause the channel to close only once
|
||||
// all the remaining inflight requests have been completed.
|
||||
drop(tx_disconnect);
|
||||
});
|
||||
let rx_disconnect = futures::stream::poll_fn(move |cx| rx_disconnect.poll_recv(cx));
|
||||
let rx_disconnect = rx_disconnect.map(RequestData::from);
|
||||
|
||||
let storage_disconnect =
|
||||
GenericRemoteStorage::from_config(&disconnect_events_storage_config)
|
||||
.await
|
||||
.context("remote storage for disconnect events init")?;
|
||||
let parquet_config_disconnect = parquet_config.clone();
|
||||
tokio::try_join!(
|
||||
worker_inner(storage, rx, parquet_config),
|
||||
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
|
||||
)
|
||||
.map(|_| ())
|
||||
} else {
|
||||
worker_inner(storage, rx, parquet_config).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ParquetConfig {
|
||||
propeties: WriterPropertiesPtr,
|
||||
rows_per_group: usize,
|
||||
file_size: i64,
|
||||
|
||||
max_duration: tokio::time::Duration,
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
test_remote_failures: u64,
|
||||
}
|
||||
|
||||
async fn worker_inner(
|
||||
storage: GenericRemoteStorage,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
config: ParquetConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
let storage = if config.test_remote_failures > 0 {
|
||||
GenericRemoteStorage::unreliable_wrapper(storage, config.test_remote_failures)
|
||||
} else {
|
||||
storage
|
||||
};
|
||||
|
||||
let mut rx = std::pin::pin!(rx);
|
||||
|
||||
let mut rows = Vec::with_capacity(config.rows_per_group);
|
||||
|
||||
let schema = rows.as_slice().schema()?;
|
||||
let buffer = BytesMut::new();
|
||||
let w = buffer.writer();
|
||||
let mut w = SerializedFileWriter::new(w, schema.clone(), config.propeties.clone())?;
|
||||
|
||||
let mut last_upload = time::Instant::now();
|
||||
|
||||
let mut len = 0;
|
||||
while let Some(row) = rx.next().await {
|
||||
rows.push(row);
|
||||
let force = last_upload.elapsed() > config.max_duration;
|
||||
if rows.len() == config.rows_per_group || force {
|
||||
let rg_meta;
|
||||
(rows, w, rg_meta) = flush_rows(rows, w).await?;
|
||||
len += rg_meta.compressed_size();
|
||||
}
|
||||
if len > config.file_size || force {
|
||||
last_upload = time::Instant::now();
|
||||
let file = upload_parquet(w, len, &storage).await?;
|
||||
w = SerializedFileWriter::new(file, schema.clone(), config.propeties.clone())?;
|
||||
len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if !rows.is_empty() {
|
||||
let rg_meta;
|
||||
(_, w, rg_meta) = flush_rows(rows, w).await?;
|
||||
len += rg_meta.compressed_size();
|
||||
}
|
||||
|
||||
if !w.flushed_row_groups().is_empty() {
|
||||
let _: Writer<BytesMut> = upload_parquet(w, len, &storage).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn flush_rows<W>(
|
||||
rows: Vec<RequestData>,
|
||||
mut w: SerializedFileWriter<W>,
|
||||
) -> anyhow::Result<(
|
||||
Vec<RequestData>,
|
||||
SerializedFileWriter<W>,
|
||||
RowGroupMetaDataPtr,
|
||||
)>
|
||||
where
|
||||
W: std::io::Write + Send + 'static,
|
||||
{
|
||||
let span = Span::current();
|
||||
let (mut rows, w, rg_meta) = tokio::task::spawn_blocking(move || {
|
||||
let _enter = span.enter();
|
||||
|
||||
let mut rg = w.next_row_group()?;
|
||||
rows.as_slice().write_to_row_group(&mut rg)?;
|
||||
let rg_meta = rg.close()?;
|
||||
|
||||
let size = rg_meta.compressed_size();
|
||||
let compression = rg_meta.compressed_size() as f64 / rg_meta.total_byte_size() as f64;
|
||||
|
||||
debug!(size, compression, "flushed row group to parquet file");
|
||||
|
||||
Ok::<_, parquet::errors::ParquetError>((rows, w, rg_meta))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
rows.clear();
|
||||
Ok((rows, w, rg_meta))
|
||||
}
|
||||
|
||||
async fn upload_parquet(
|
||||
mut w: SerializedFileWriter<Writer<BytesMut>>,
|
||||
len: i64,
|
||||
storage: &GenericRemoteStorage,
|
||||
) -> anyhow::Result<Writer<BytesMut>> {
|
||||
let len_uncompressed = w
|
||||
.flushed_row_groups()
|
||||
.iter()
|
||||
.map(|rg| rg.total_byte_size())
|
||||
.sum::<i64>();
|
||||
|
||||
// I don't know how compute intensive this is, although it probably isn't much... better be safe than sorry.
|
||||
// finish method only available on the fork: https://github.com/apache/arrow-rs/issues/5253
|
||||
let (mut buffer, metadata) =
|
||||
tokio::task::spawn_blocking(move || -> parquet::errors::Result<_> {
|
||||
let metadata = w.finish()?;
|
||||
let buffer = std::mem::take(w.inner_mut().get_mut());
|
||||
Ok((buffer, metadata))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
let data = buffer.split().freeze();
|
||||
|
||||
let compression = len as f64 / len_uncompressed as f64;
|
||||
let size = data.len();
|
||||
let now = chrono::Utc::now();
|
||||
let id = uuid::Uuid::new_v7(uuid::Timestamp::from_unix(
|
||||
uuid::NoContext,
|
||||
// we won't be running this in 1970. this cast is ok
|
||||
now.timestamp() as u64,
|
||||
now.timestamp_subsec_nanos(),
|
||||
));
|
||||
|
||||
info!(
|
||||
%id,
|
||||
rows = metadata.num_rows,
|
||||
size, compression, "uploading request parquet file"
|
||||
);
|
||||
|
||||
let year = now.year();
|
||||
let month = now.month();
|
||||
let day = now.day();
|
||||
let hour = now.hour();
|
||||
// segment files by time for S3 performance
|
||||
let path = RemotePath::from_string(&format!(
|
||||
"{year:04}/{month:02}/{day:02}/{hour:02}/requests_{id}.parquet"
|
||||
))?;
|
||||
let cancel = CancellationToken::new();
|
||||
let maybe_err = backoff::retry(
|
||||
|| async {
|
||||
let stream = futures::stream::once(futures::future::ready(Ok(data.clone())));
|
||||
storage
|
||||
.upload(stream, data.len(), &path, None, &cancel)
|
||||
.await
|
||||
},
|
||||
TimeoutOrCancel::caused_by_cancel,
|
||||
FAILED_UPLOAD_WARN_THRESHOLD,
|
||||
FAILED_UPLOAD_MAX_RETRIES,
|
||||
"request_data_upload",
|
||||
// we don't want cancellation to interrupt here, so we make a dummy cancel token
|
||||
&cancel,
|
||||
)
|
||||
.await
|
||||
.ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel))
|
||||
.and_then(|x| x)
|
||||
.context("request_data_upload")
|
||||
.err();
|
||||
|
||||
if let Some(err) = maybe_err {
|
||||
tracing::warn!(%id, %err, "failed to upload request data");
|
||||
}
|
||||
|
||||
Ok(buffer.writer())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{net::Ipv4Addr, num::NonZeroUsize, sync::Arc};
|
||||
|
||||
use camino::Utf8Path;
|
||||
use clap::Parser;
|
||||
use futures::{Stream, StreamExt};
|
||||
use itertools::Itertools;
|
||||
use parquet::{
|
||||
basic::{Compression, ZstdLevel},
|
||||
file::{
|
||||
properties::{WriterProperties, DEFAULT_PAGE_SIZE},
|
||||
reader::FileReader,
|
||||
serialized_reader::SerializedFileReader,
|
||||
},
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use remote_storage::{
|
||||
GenericRemoteStorage, RemoteStorageConfig, RemoteStorageKind, S3Config,
|
||||
DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
|
||||
};
|
||||
use tokio::{sync::mpsc, time};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
use super::{worker_inner, ParquetConfig, ParquetUploadArgs, RequestData};
|
||||
|
||||
#[derive(Parser)]
|
||||
struct ProxyCliArgs {
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_parser() {
|
||||
let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from(["proxy"]);
|
||||
assert_eq!(parquet_upload.parquet_upload_remote_storage, None);
|
||||
assert_eq!(parquet_upload.parquet_upload_row_group_size, 8192);
|
||||
assert_eq!(parquet_upload.parquet_upload_page_size, DEFAULT_PAGE_SIZE);
|
||||
assert_eq!(parquet_upload.parquet_upload_size, 100_000_000);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_maximum_duration,
|
||||
time::Duration::from_secs(20 * 60)
|
||||
);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_compression,
|
||||
Compression::UNCOMPRESSED
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_parser() {
|
||||
let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from([
|
||||
"proxy",
|
||||
"--parquet-upload-remote-storage",
|
||||
"{bucket_name='default',prefix_in_bucket='proxy/',bucket_region='us-east-1',endpoint='http://minio:9000'}",
|
||||
"--parquet-upload-row-group-size",
|
||||
"100",
|
||||
"--parquet-upload-page-size",
|
||||
"10000",
|
||||
"--parquet-upload-size",
|
||||
"10000000",
|
||||
"--parquet-upload-maximum-duration",
|
||||
"10m",
|
||||
"--parquet-upload-compression",
|
||||
"zstd(5)",
|
||||
]);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_remote_storage,
|
||||
Some(RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: "default".into(),
|
||||
bucket_region: "us-east-1".into(),
|
||||
prefix_in_bucket: Some("proxy/".into()),
|
||||
endpoint: Some("http://minio:9000".into()),
|
||||
concurrency_limit: NonZeroUsize::new(
|
||||
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
|
||||
)
|
||||
.unwrap(),
|
||||
max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
|
||||
upload_storage_class: None,
|
||||
}),
|
||||
timeout: RemoteStorageConfig::DEFAULT_TIMEOUT,
|
||||
})
|
||||
);
|
||||
assert_eq!(parquet_upload.parquet_upload_row_group_size, 100);
|
||||
assert_eq!(parquet_upload.parquet_upload_page_size, 10000);
|
||||
assert_eq!(parquet_upload.parquet_upload_size, 10_000_000);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_maximum_duration,
|
||||
time::Duration::from_secs(10 * 60)
|
||||
);
|
||||
assert_eq!(
|
||||
parquet_upload.parquet_upload_compression,
|
||||
Compression::ZSTD(ZstdLevel::try_new(5).unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
fn generate_request_data(rng: &mut impl Rng) -> RequestData {
|
||||
RequestData {
|
||||
session_id: uuid::Builder::from_random_bytes(rng.gen()).into_uuid(),
|
||||
peer_addr: Ipv4Addr::from(rng.gen::<[u8; 4]>()).to_string(),
|
||||
timestamp: chrono::DateTime::from_timestamp_millis(
|
||||
rng.gen_range(1703862754..1803862754),
|
||||
)
|
||||
.unwrap()
|
||||
.naive_utc(),
|
||||
application_name: Some("test".to_owned()),
|
||||
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
||||
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
pg_options: None,
|
||||
auth_method: None,
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
error: None,
|
||||
success: rng.gen(),
|
||||
cold_start_info: "no",
|
||||
duration_us: rng.gen_range(0..30_000_000),
|
||||
disconnect_timestamp: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn random_stream(len: usize) -> impl Stream<Item = RequestData> + Unpin {
|
||||
let mut rng = StdRng::from_seed([0x39; 32]);
|
||||
futures::stream::iter(
|
||||
std::iter::repeat_with(move || generate_request_data(&mut rng)).take(len),
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_test(
|
||||
tmpdir: &Utf8Path,
|
||||
config: ParquetConfig,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
) -> Vec<(u64, usize, i64)> {
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::LocalFs {
|
||||
local_path: tmpdir.to_path_buf(),
|
||||
},
|
||||
timeout: std::time::Duration::from_secs(120),
|
||||
};
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
worker_inner(storage, rx, config).await.unwrap();
|
||||
|
||||
let mut files = WalkDir::new(tmpdir.as_std_path())
|
||||
.into_iter()
|
||||
.filter_map(|entry| entry.ok())
|
||||
.filter(|entry| entry.file_type().is_file())
|
||||
.map(|entry| entry.path().to_path_buf())
|
||||
.collect_vec();
|
||||
files.sort();
|
||||
|
||||
files
|
||||
.into_iter()
|
||||
.map(|path| std::fs::File::open(tmpdir.as_std_path().join(path)).unwrap())
|
||||
.map(|file| {
|
||||
(
|
||||
file.metadata().unwrap(),
|
||||
SerializedFileReader::new(file).unwrap().metadata().clone(),
|
||||
)
|
||||
})
|
||||
.map(|(file_meta, parquet_meta)| {
|
||||
(
|
||||
file_meta.len(),
|
||||
parquet_meta.num_row_groups(),
|
||||
parquet_meta.file_metadata().num_rows(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_no_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315874, 3, 6000),
|
||||
(1315867, 3, 6000),
|
||||
(1315927, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(1316014, 3, 6000),
|
||||
(1315856, 3, 6000),
|
||||
(1315648, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(438913, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_min_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(
|
||||
WriterProperties::builder()
|
||||
.set_compression(parquet::basic::Compression::ZSTD(ZstdLevel::default()))
|
||||
.build(),
|
||||
),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// with compression, there are fewer files with more rows per file
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1223214, 5, 10000),
|
||||
(1229364, 5, 10000),
|
||||
(1231158, 5, 10000),
|
||||
(1230520, 5, 10000),
|
||||
(1221798, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_strong_compression() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(
|
||||
WriterProperties::builder()
|
||||
.set_compression(parquet::basic::Compression::ZSTD(
|
||||
ZstdLevel::try_new(10).unwrap(),
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 0,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// with strong compression, the files are smaller
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1208861, 5, 10000),
|
||||
(1208592, 5, 10000),
|
||||
(1208885, 5, 10000),
|
||||
(1208873, 5, 10000),
|
||||
(1209128, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_parquet_unreliable_upload() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(20 * 60),
|
||||
test_remote_failures: 2,
|
||||
};
|
||||
|
||||
let rx = random_stream(50_000);
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315874, 3, 6000),
|
||||
(1315867, 3, 6000),
|
||||
(1315927, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(1316014, 3, 6000),
|
||||
(1315856, 3, 6000),
|
||||
(1315648, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(438913, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn verify_parquet_regular_upload() {
|
||||
let tmpdir = camino_tempfile::tempdir().unwrap();
|
||||
|
||||
let config = ParquetConfig {
|
||||
propeties: Arc::new(WriterProperties::new()),
|
||||
rows_per_group: 2_000,
|
||||
file_size: 1_000_000,
|
||||
max_duration: time::Duration::from_secs(60),
|
||||
test_remote_failures: 2,
|
||||
};
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
for _ in 0..3 {
|
||||
let mut s = random_stream(3000);
|
||||
while let Some(r) = s.next().await {
|
||||
tx.send(r).unwrap();
|
||||
}
|
||||
time::sleep(time::Duration::from_secs(70)).await
|
||||
}
|
||||
});
|
||||
|
||||
let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
|
||||
let file_stats = run_test(tmpdir.path(), config, rx).await;
|
||||
|
||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[(659836, 2, 3001), (659550, 2, 3000), (659346, 2, 2999)]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
}
|
||||
}
|
||||
93
proxy/core/src/error.rs
Normal file
93
proxy/core/src/error.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use std::{error::Error as StdError, fmt, io};
|
||||
|
||||
use measured::FixedCardinalityLabel;
|
||||
|
||||
/// Upcast (almost) any error into an opaque [`io::Error`].
|
||||
pub fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
}
|
||||
|
||||
/// A small combinator for pluggable error logging.
|
||||
pub fn log_error<E: fmt::Display>(e: E) -> E {
|
||||
tracing::error!("{e}");
|
||||
e
|
||||
}
|
||||
|
||||
/// Marks errors that may be safely shown to a client.
|
||||
/// This trait can be seen as a specialized version of [`ToString`].
|
||||
///
|
||||
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
|
||||
/// is way too convenient and tends to proliferate all across the codebase,
|
||||
/// ultimately leading to accidental leaks of sensitive data.
|
||||
pub trait UserFacingError: ReportableError {
|
||||
/// Format the error for client, stripping all sensitive info.
|
||||
///
|
||||
/// Although this might be a no-op for many types, it's highly
|
||||
/// recommended to override the default impl in case error type
|
||||
/// contains anything sensitive: various IDs, IP addresses etc.
|
||||
#[inline(always)]
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, FixedCardinalityLabel)]
|
||||
#[label(singleton = "type")]
|
||||
pub enum ErrorKind {
|
||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||
User,
|
||||
|
||||
/// Network error between user and proxy. Not necessarily user error
|
||||
#[label(rename = "clientdisconnect")]
|
||||
ClientDisconnect,
|
||||
|
||||
/// Proxy self-imposed user rate limits
|
||||
#[label(rename = "ratelimit")]
|
||||
RateLimit,
|
||||
|
||||
/// Proxy self-imposed service-wise rate limits
|
||||
#[label(rename = "serviceratelimit")]
|
||||
ServiceRateLimit,
|
||||
|
||||
/// internal errors
|
||||
Service,
|
||||
|
||||
/// Error communicating with control plane
|
||||
#[label(rename = "controlplane")]
|
||||
ControlPlane,
|
||||
|
||||
/// Postgres error
|
||||
Postgres,
|
||||
|
||||
/// Error communicating with compute
|
||||
Compute,
|
||||
}
|
||||
|
||||
impl ErrorKind {
|
||||
pub fn to_metric_label(&self) -> &'static str {
|
||||
match self {
|
||||
ErrorKind::User => "user",
|
||||
ErrorKind::ClientDisconnect => "clientdisconnect",
|
||||
ErrorKind::RateLimit => "ratelimit",
|
||||
ErrorKind::ServiceRateLimit => "serviceratelimit",
|
||||
ErrorKind::Service => "service",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "postgres",
|
||||
ErrorKind::Compute => "compute",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ReportableError: fmt::Display + Send + 'static {
|
||||
fn get_error_kind(&self) -> ErrorKind;
|
||||
}
|
||||
|
||||
impl ReportableError for tokio_postgres::error::Error {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
if self.as_db_error().is_some() {
|
||||
ErrorKind::Postgres
|
||||
} else {
|
||||
ErrorKind::Compute
|
||||
}
|
||||
}
|
||||
}
|
||||
173
proxy/core/src/http.rs
Normal file
173
proxy/core/src/http.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
//! HTTP client and server impls.
|
||||
//! Other modules should use stuff from this module instead of
|
||||
//! directly relying on deps like `reqwest` (think loose coupling).
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper1::body::Body;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
pub use reqwest::{Request, Response, StatusCode};
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
|
||||
use crate::{
|
||||
metrics::{ConsoleRequest, Metrics},
|
||||
url::ApiUrl,
|
||||
};
|
||||
use reqwest_middleware::RequestBuilder;
|
||||
|
||||
/// This is the preferred way to create new http clients,
|
||||
/// because it takes care of observability (OpenTelemetry).
|
||||
/// We deliberately don't want to replace this with a public static.
|
||||
pub fn new_client() -> ClientWithMiddleware {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.build()
|
||||
.expect("Failed to create http client");
|
||||
|
||||
reqwest_middleware::ClientBuilder::new(client)
|
||||
.with(reqwest_tracing::TracingMiddleware::default())
|
||||
.build()
|
||||
}
|
||||
|
||||
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
|
||||
let timeout_client = reqwest::ClientBuilder::new()
|
||||
.timeout(default_timout)
|
||||
.build()
|
||||
.expect("Failed to create http client with timeout");
|
||||
|
||||
let retry_policy =
|
||||
ExponentialBackoff::builder().build_with_total_retry_duration(default_timout);
|
||||
|
||||
reqwest_middleware::ClientBuilder::new(timeout_client)
|
||||
.with(reqwest_tracing::TracingMiddleware::default())
|
||||
// As per docs, "This middleware always errors when given requests with streaming bodies".
|
||||
// That's all right because we only use this client to send `serde_json::RawValue`, which
|
||||
// is not a stream.
|
||||
//
|
||||
// ex-maintainer note:
|
||||
// this limitation can be fixed if streaming is necessary.
|
||||
// retries will still not be performed, but it wont error immediately
|
||||
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
||||
.build()
|
||||
}
|
||||
|
||||
/// Thin convenience wrapper for an API provided by an http endpoint.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Endpoint {
|
||||
/// API's base URL.
|
||||
endpoint: ApiUrl,
|
||||
/// Connection manager with built-in pooling.
|
||||
client: ClientWithMiddleware,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
/// Construct a new HTTP endpoint wrapper.
|
||||
/// Http client is not constructed under the hood so that it can be shared.
|
||||
pub fn new(endpoint: ApiUrl, client: impl Into<ClientWithMiddleware>) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
client: client.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn url(&self) -> &ApiUrl {
|
||||
&self.endpoint
|
||||
}
|
||||
|
||||
/// Return a [builder](RequestBuilder) for a `GET` request,
|
||||
/// appending a single `path` segment to the base endpoint URL.
|
||||
pub fn get(&self, path: &str) -> RequestBuilder {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push(path);
|
||||
self.client.get(url.into_inner())
|
||||
}
|
||||
|
||||
/// Execute a [request](reqwest::Request).
|
||||
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
|
||||
let _timer = Metrics::get()
|
||||
.proxy
|
||||
.console_request_latency
|
||||
.start_timer(ConsoleRequest {
|
||||
request: request.url().path(),
|
||||
});
|
||||
|
||||
self.client.execute(request).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn parse_json_body_with_limit<D: DeserializeOwned>(
|
||||
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<D> {
|
||||
// We could use `b.limited().collect().await.to_bytes()` here
|
||||
// but this ends up being slightly more efficient as far as I can tell.
|
||||
|
||||
// check the lower bound of the size hint.
|
||||
// in reqwest, this value is influenced by the Content-Length header.
|
||||
let lower_bound = match usize::try_from(b.size_hint().lower()) {
|
||||
Ok(bound) if bound <= limit => bound,
|
||||
_ => bail!("content length exceeds limit"),
|
||||
};
|
||||
let mut bytes = Vec::with_capacity(lower_bound);
|
||||
|
||||
while let Some(frame) = b.frame().await.transpose()? {
|
||||
if let Ok(data) = frame.into_data() {
|
||||
if bytes.len() + data.len() > limit {
|
||||
bail!("content length exceeds limit")
|
||||
}
|
||||
bytes.extend_from_slice(&data);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(serde_json::from_slice::<D>(&bytes)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use reqwest::Client;
|
||||
|
||||
#[test]
|
||||
fn optional_query_params() -> anyhow::Result<()> {
|
||||
let url = "http://example.com".parse()?;
|
||||
let endpoint = Endpoint::new(url, Client::new());
|
||||
|
||||
// Validate that this pattern makes sense.
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.query(&[
|
||||
("foo", Some("10")), // should be just `foo=10`
|
||||
("bar", None), // shouldn't be passed at all
|
||||
])
|
||||
.build()?;
|
||||
|
||||
assert_eq!(req.url().as_str(), "http://example.com/frobnicate?foo=10");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uuid_params() -> anyhow::Result<()> {
|
||||
let url = "http://example.com".parse()?;
|
||||
let endpoint = Endpoint::new(url, Client::new());
|
||||
|
||||
let req = endpoint
|
||||
.get("frobnicate")
|
||||
.query(&[("session_id", uuid::Uuid::nil())])
|
||||
.build()?;
|
||||
|
||||
assert_eq!(
|
||||
req.url().as_str(),
|
||||
"http://example.com/frobnicate?session_id=00000000-0000-0000-0000-000000000000"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
107
proxy/core/src/http/health_server.rs
Normal file
107
proxy/core/src/http/health_server.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use anyhow::{anyhow, bail};
|
||||
use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
|
||||
use measured::{text::BufferedTextEncoder, MetricGroup};
|
||||
use metrics::NeonMetrics;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
net::TcpListener,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use tracing::{info, info_span};
|
||||
use utils::http::{
|
||||
endpoint::{self, request_span},
|
||||
error::ApiError,
|
||||
json::json_response,
|
||||
RouterBuilder, RouterService,
|
||||
};
|
||||
|
||||
use crate::jemalloc;
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
}
|
||||
|
||||
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let state = Arc::new(Mutex::new(PrometheusHandler {
|
||||
encoder: BufferedTextEncoder::new(),
|
||||
metrics,
|
||||
}));
|
||||
|
||||
endpoint::make_router()
|
||||
.get("/metrics", move |r| {
|
||||
let state = state.clone();
|
||||
request_span(r, move |b| prometheus_metrics_handler(b, state))
|
||||
})
|
||||
.get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
http_listener: TcpListener,
|
||||
metrics: AppMetrics,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
scopeguard::defer! {
|
||||
info!("http has shut down");
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router(metrics).build()?);
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
|
||||
bail!("hyper server without shutdown handling cannot shutdown successfully");
|
||||
}
|
||||
|
||||
struct PrometheusHandler {
|
||||
encoder: BufferedTextEncoder,
|
||||
metrics: AppMetrics,
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
pub struct AppMetrics {
|
||||
#[metric(namespace = "jemalloc")]
|
||||
pub jemalloc: Option<jemalloc::MetricRecorder>,
|
||||
#[metric(flatten)]
|
||||
pub neon_metrics: NeonMetrics,
|
||||
#[metric(flatten)]
|
||||
pub proxy: &'static crate::metrics::Metrics,
|
||||
}
|
||||
|
||||
async fn prometheus_metrics_handler(
|
||||
_req: Request<Body>,
|
||||
state: Arc<Mutex<PrometheusHandler>>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
let span = info_span!("blocking");
|
||||
let body = tokio::task::spawn_blocking(move || {
|
||||
let _span = span.entered();
|
||||
|
||||
let mut state = state.lock().unwrap();
|
||||
let PrometheusHandler { encoder, metrics } = &mut *state;
|
||||
|
||||
metrics
|
||||
.collect_group_into(&mut *encoder)
|
||||
.unwrap_or_else(|infallible| match infallible {});
|
||||
|
||||
let body = encoder.finish();
|
||||
|
||||
tracing::info!(
|
||||
bytes = body.len(),
|
||||
elapsed_ms = started_at.elapsed().as_millis(),
|
||||
"responded /metrics"
|
||||
);
|
||||
|
||||
body
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, "text/plain; version=0.0.4")
|
||||
.body(Body::from(body))
|
||||
.unwrap();
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
252
proxy/core/src/intern.rs
Normal file
252
proxy/core/src/intern.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
use std::{
|
||||
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
|
||||
};
|
||||
|
||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||
use rustc_hash::FxHasher;
|
||||
|
||||
use crate::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
|
||||
pub trait InternId: Sized + 'static {
|
||||
fn get_interner() -> &'static StringInterner<Self>;
|
||||
}
|
||||
|
||||
pub struct StringInterner<Id> {
|
||||
inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
|
||||
pub struct InternedString<Id> {
|
||||
inner: Spur,
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> InternedString<Id> {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
Id::get_interner().inner.resolve(&self.inner)
|
||||
}
|
||||
pub fn get(s: &str) -> Option<Self> {
|
||||
Id::get_interner().get(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> AsRef<str> for InternedString<Id> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::ops::Deref for InternedString<Id> {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &str {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
|
||||
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
struct Visitor<Id>(PhantomData<Id>);
|
||||
impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
|
||||
type Value = InternedString<Id>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(Id::get_interner().get_or_intern(v))
|
||||
}
|
||||
}
|
||||
d.deserialize_str(Visitor::<Id>(PhantomData))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> serde::Serialize for InternedString<Id> {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
self.as_str().serialize(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> StringInterner<Id> {
|
||||
pub fn new() -> Self {
|
||||
StringInterner {
|
||||
inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
|
||||
Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
|
||||
// unbounded
|
||||
MemoryLimits::for_memory_usage(usize::MAX),
|
||||
BuildHasherDefault::<FxHasher>::default(),
|
||||
),
|
||||
_id: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.inner.is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.inner.len()
|
||||
}
|
||||
|
||||
pub fn current_memory_usage(&self) -> usize {
|
||||
self.inner.current_memory_usage()
|
||||
}
|
||||
|
||||
pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
|
||||
InternedString {
|
||||
inner: self.inner.get_or_intern(s),
|
||||
_id: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
|
||||
Some(InternedString {
|
||||
inner: self.inner.get(s)?,
|
||||
_id: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
|
||||
type Output = str;
|
||||
|
||||
fn index(&self, index: InternedString<Id>) -> &Self::Output {
|
||||
self.inner.resolve(&index.inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> Default for StringInterner<Id> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct RoleNameTag;
|
||||
impl InternId for RoleNameTag {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
pub type RoleNameInt = InternedString<RoleNameTag>;
|
||||
impl From<&RoleName> for RoleNameInt {
|
||||
fn from(value: &RoleName) -> Self {
|
||||
RoleNameTag::get_interner().get_or_intern(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EndpointIdTag;
|
||||
impl InternId for EndpointIdTag {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
pub type EndpointIdInt = InternedString<EndpointIdTag>;
|
||||
impl From<&EndpointId> for EndpointIdInt {
|
||||
fn from(value: &EndpointId) -> Self {
|
||||
EndpointIdTag::get_interner().get_or_intern(value)
|
||||
}
|
||||
}
|
||||
impl From<EndpointId> for EndpointIdInt {
|
||||
fn from(value: EndpointId) -> Self {
|
||||
EndpointIdTag::get_interner().get_or_intern(&value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct BranchIdTag;
|
||||
impl InternId for BranchIdTag {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
pub type BranchIdInt = InternedString<BranchIdTag>;
|
||||
impl From<&BranchId> for BranchIdInt {
|
||||
fn from(value: &BranchId) -> Self {
|
||||
BranchIdTag::get_interner().get_or_intern(value)
|
||||
}
|
||||
}
|
||||
impl From<BranchId> for BranchIdInt {
|
||||
fn from(value: BranchId) -> Self {
|
||||
BranchIdTag::get_interner().get_or_intern(&value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct ProjectIdTag;
|
||||
impl InternId for ProjectIdTag {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
pub type ProjectIdInt = InternedString<ProjectIdTag>;
|
||||
impl From<&ProjectId> for ProjectIdInt {
|
||||
fn from(value: &ProjectId) -> Self {
|
||||
ProjectIdTag::get_interner().get_or_intern(value)
|
||||
}
|
||||
}
|
||||
impl From<ProjectId> for ProjectIdInt {
|
||||
fn from(value: ProjectId) -> Self {
|
||||
ProjectIdTag::get_interner().get_or_intern(&value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::intern::StringInterner;
|
||||
|
||||
use super::InternId;
|
||||
|
||||
struct MyId;
|
||||
impl InternId for MyId {
|
||||
fn get_interner() -> &'static StringInterner<Self> {
|
||||
pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
|
||||
ROLE_NAMES.get_or_init(Default::default)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_many_strings() {
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use rand_distr::Zipf;
|
||||
|
||||
let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
|
||||
let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
|
||||
|
||||
let interner = MyId::get_interner();
|
||||
|
||||
const N: usize = 100_000;
|
||||
let mut verify = Vec::with_capacity(N);
|
||||
for endpoint in endpoints.take(N) {
|
||||
let endpoint = format!("ep-string-interning-{endpoint}");
|
||||
let key = interner.get_or_intern(&endpoint);
|
||||
verify.push((endpoint, key));
|
||||
}
|
||||
|
||||
for (s, key) in verify {
|
||||
assert_eq!(interner[key], s);
|
||||
}
|
||||
|
||||
// 2031616/59861 = 34 bytes per string
|
||||
assert_eq!(interner.len(), 59_861);
|
||||
// will have other overhead for the internal hashmaps that are not accounted for.
|
||||
assert_eq!(interner.current_memory_usage(), 2_031_616);
|
||||
}
|
||||
}
|
||||
116
proxy/core/src/jemalloc.rs
Normal file
116
proxy/core/src/jemalloc.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use measured::{
|
||||
label::NoLabels,
|
||||
metric::{
|
||||
gauge::GaugeState, group::Encoding, name::MetricNameEncoder, MetricEncoding,
|
||||
MetricFamilyEncoding, MetricType,
|
||||
},
|
||||
text::TextEncoder,
|
||||
LabelGroup, MetricGroup,
|
||||
};
|
||||
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
|
||||
|
||||
pub struct MetricRecorder {
|
||||
epoch: epoch_mib,
|
||||
inner: Metrics,
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
struct Metrics {
|
||||
active_bytes: JemallocGaugeFamily<stats::active_mib>,
|
||||
allocated_bytes: JemallocGaugeFamily<stats::allocated_mib>,
|
||||
mapped_bytes: JemallocGaugeFamily<stats::mapped_mib>,
|
||||
metadata_bytes: JemallocGaugeFamily<stats::metadata_mib>,
|
||||
resident_bytes: JemallocGaugeFamily<stats::resident_mib>,
|
||||
retained_bytes: JemallocGaugeFamily<stats::retained_mib>,
|
||||
}
|
||||
|
||||
impl<Enc: Encoding> MetricGroup<Enc> for MetricRecorder
|
||||
where
|
||||
Metrics: MetricGroup<Enc>,
|
||||
{
|
||||
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
|
||||
if self.epoch.advance().is_ok() {
|
||||
self.inner.collect_group_into(enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricRecorder {
|
||||
pub fn new() -> Result<Self, anyhow::Error> {
|
||||
tracing::info!(
|
||||
config = config::malloc_conf::read()?,
|
||||
version = version::read()?,
|
||||
"starting jemalloc recorder"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
epoch: epoch::mib()?,
|
||||
inner: Metrics {
|
||||
active_bytes: JemallocGaugeFamily(stats::active::mib()?),
|
||||
allocated_bytes: JemallocGaugeFamily(stats::allocated::mib()?),
|
||||
mapped_bytes: JemallocGaugeFamily(stats::mapped::mib()?),
|
||||
metadata_bytes: JemallocGaugeFamily(stats::metadata::mib()?),
|
||||
resident_bytes: JemallocGaugeFamily(stats::resident::mib()?),
|
||||
retained_bytes: JemallocGaugeFamily(stats::retained::mib()?),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct JemallocGauge<T>(PhantomData<T>);
|
||||
|
||||
impl<T> Default for JemallocGauge<T> {
|
||||
fn default() -> Self {
|
||||
JemallocGauge(PhantomData)
|
||||
}
|
||||
}
|
||||
impl<T> MetricType for JemallocGauge<T> {
|
||||
type Metadata = T;
|
||||
}
|
||||
|
||||
struct JemallocGaugeFamily<T>(T);
|
||||
impl<M, T: Encoding> MetricFamilyEncoding<T> for JemallocGaugeFamily<M>
|
||||
where
|
||||
JemallocGauge<M>: MetricEncoding<T, Metadata = M>,
|
||||
{
|
||||
fn collect_family_into(&self, name: impl MetricNameEncoder, enc: &mut T) -> Result<(), T::Err> {
|
||||
JemallocGauge::write_type(&name, enc)?;
|
||||
JemallocGauge(PhantomData).collect_into(&self.0, NoLabels, name, enc)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! jemalloc_gauge {
|
||||
($stat:ident, $mib:ident) => {
|
||||
impl<W: std::io::Write> MetricEncoding<TextEncoder<W>> for JemallocGauge<stats::$mib> {
|
||||
fn write_type(
|
||||
name: impl MetricNameEncoder,
|
||||
enc: &mut TextEncoder<W>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
GaugeState::write_type(name, enc)
|
||||
}
|
||||
|
||||
fn collect_into(
|
||||
&self,
|
||||
mib: &stats::$mib,
|
||||
labels: impl LabelGroup,
|
||||
name: impl MetricNameEncoder,
|
||||
enc: &mut TextEncoder<W>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
if let Ok(v) = mib.read() {
|
||||
GaugeState::new(v as i64).collect_into(&(), labels, name, enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
jemalloc_gauge!(active, active_mib);
|
||||
jemalloc_gauge!(allocated, allocated_mib);
|
||||
jemalloc_gauge!(mapped, mapped_mib);
|
||||
jemalloc_gauge!(metadata, metadata_mib);
|
||||
jemalloc_gauge!(resident, resident_mib);
|
||||
jemalloc_gauge!(retained, retained_mib);
|
||||
185
proxy/core/src/lib.rs
Normal file
185
proxy/core/src/lib.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cache;
|
||||
pub mod cancellation;
|
||||
pub mod compute;
|
||||
pub mod config;
|
||||
pub mod console;
|
||||
pub mod context;
|
||||
pub mod error;
|
||||
pub mod http;
|
||||
pub mod intern;
|
||||
pub mod jemalloc;
|
||||
pub mod logging;
|
||||
pub mod metrics;
|
||||
pub mod parse;
|
||||
pub mod protocol2;
|
||||
pub mod proxy;
|
||||
pub mod rate_limiter;
|
||||
pub mod redis;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
pub mod serverless;
|
||||
pub mod stream;
|
||||
pub mod url;
|
||||
pub mod usage_metrics;
|
||||
pub mod waiters;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP; config reload is not supported");
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, shutting down once all existing connections have closed");
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
|
||||
r.context("join error").and_then(|x| x)
|
||||
}
|
||||
|
||||
macro_rules! smol_str_wrapper {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct $name(smol_str::SmolStr);
|
||||
|
||||
impl $name {
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::cmp::PartialEq<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: std::cmp::PartialEq<T>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.0.eq(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: From<T>,
|
||||
{
|
||||
fn from(x: T) -> Self {
|
||||
Self(x.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for $name {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &str {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for $name {
|
||||
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
<smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for $name {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
self.0.serialize(s)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const POOLER_SUFFIX: &str = "-pooler";
|
||||
|
||||
impl EndpointId {
|
||||
fn normalize(&self) -> Self {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
stripped.into()
|
||||
} else {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_intern(&self) -> EndpointIdInt {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
EndpointIdTag::get_interner().get_or_intern(stripped)
|
||||
} else {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 90% of role name strings are 20 characters or less.
|
||||
smol_str_wrapper!(RoleName);
|
||||
// 50% of endpoint strings are 23 characters or less.
|
||||
smol_str_wrapper!(EndpointId);
|
||||
// 50% of branch strings are 23 characters or less.
|
||||
smol_str_wrapper!(BranchId);
|
||||
// 90% of project strings are 23 characters or less.
|
||||
smol_str_wrapper!(ProjectId);
|
||||
|
||||
// will usually equal endpoint ID
|
||||
smol_str_wrapper!(EndpointCacheKey);
|
||||
|
||||
smol_str_wrapper!(DbName);
|
||||
|
||||
// postgres hostname, will likely be a port:ip addr
|
||||
smol_str_wrapper!(Host);
|
||||
|
||||
// Endpoints are a bit tricky. Rare they might be branches or projects.
|
||||
impl EndpointId {
|
||||
pub fn is_endpoint(&self) -> bool {
|
||||
self.0.starts_with("ep-")
|
||||
}
|
||||
pub fn is_branch(&self) -> bool {
|
||||
self.0.starts_with("br-")
|
||||
}
|
||||
pub fn is_project(&self) -> bool {
|
||||
!self.is_endpoint() && !self.is_branch()
|
||||
}
|
||||
pub fn as_branch(&self) -> BranchId {
|
||||
BranchId(self.0.clone())
|
||||
}
|
||||
pub fn as_project(&self) -> ProjectId {
|
||||
ProjectId(self.0.clone())
|
||||
}
|
||||
}
|
||||
48
proxy/core/src/logging.rs
Normal file
48
proxy/core/src/logging.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use tracing_opentelemetry::OpenTelemetryLayer;
|
||||
use tracing_subscriber::{
|
||||
filter::{EnvFilter, LevelFilter},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
/// Initialize logging and OpenTelemetry tracing and exporter.
|
||||
///
|
||||
/// Logging can be configured using `RUST_LOG` environment variable.
|
||||
///
|
||||
/// OpenTelemetry is configured with OTLP/HTTP exporter. It picks up
|
||||
/// configuration from environment variables. For example, to change the
|
||||
/// destination, set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`.
|
||||
/// See <https://opentelemetry.io/docs/reference/specification/sdk-environment-variables>
|
||||
pub async fn init() -> anyhow::Result<LoggingGuard> {
|
||||
let env_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env_lossy()
|
||||
.add_directive("azure_core::policies::transport=off".parse().unwrap());
|
||||
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
.with_writer(std::io::stderr)
|
||||
.with_target(false);
|
||||
|
||||
let otlp_layer = tracing_utils::init_tracing("proxy")
|
||||
.await
|
||||
.map(OpenTelemetryLayer::new);
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(otlp_layer)
|
||||
.with(fmt_layer)
|
||||
.try_init()?;
|
||||
|
||||
Ok(LoggingGuard)
|
||||
}
|
||||
|
||||
pub struct LoggingGuard;
|
||||
|
||||
impl Drop for LoggingGuard {
|
||||
fn drop(&mut self) {
|
||||
// Shutdown trace pipeline gracefully, so that it has a chance to send any
|
||||
// pending traces before we exit.
|
||||
tracing::info!("shutting down the tracing machinery");
|
||||
tracing_utils::shutdown_tracing();
|
||||
}
|
||||
}
|
||||
623
proxy/core/src/metrics.rs
Normal file
623
proxy/core/src/metrics.rs
Normal file
@@ -0,0 +1,623 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use lasso::ThreadedRodeo;
|
||||
use measured::{
|
||||
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
|
||||
metric::{histogram::Thresholds, name::MetricName},
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
|
||||
LabelGroup, MetricGroup,
|
||||
};
|
||||
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
|
||||
|
||||
use tokio::time::{self, Instant};
|
||||
|
||||
use crate::console::messages::ColdStartInfo;
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
pub struct Metrics {
|
||||
#[metric(namespace = "proxy")]
|
||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
||||
pub proxy: ProxyMetrics,
|
||||
|
||||
#[metric(namespace = "wake_compute_lock")]
|
||||
pub wake_compute_lock: ApiLockMetrics,
|
||||
}
|
||||
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
impl Metrics {
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
||||
SELF.set(Metrics::new(thread_pool))
|
||||
.ok()
|
||||
.expect("proxy metrics must not be installed more than once");
|
||||
}
|
||||
|
||||
pub fn get() -> &'static Self {
|
||||
#[cfg(test)]
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
||||
|
||||
#[cfg(not(test))]
|
||||
SELF.get()
|
||||
.expect("proxy metrics must be installed by the main() function")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
pub struct ProxyMetrics {
|
||||
#[metric(flatten)]
|
||||
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
||||
#[metric(flatten)]
|
||||
pub client_connections: CounterPairVec<NumClientConnectionsGauge>,
|
||||
#[metric(flatten)]
|
||||
pub connection_requests: CounterPairVec<NumConnectionRequestsGauge>,
|
||||
#[metric(flatten)]
|
||||
pub http_endpoint_pools: HttpEndpointPools,
|
||||
|
||||
/// Time it took for proxy to establish a connection to the compute endpoint.
|
||||
// largest bucket = 2^16 * 0.5ms = 32s
|
||||
#[metric(metadata = Thresholds::exponential_buckets(0.0005, 2.0))]
|
||||
pub compute_connection_latency_seconds: HistogramVec<ComputeConnectionLatencySet, 16>,
|
||||
|
||||
/// Time it took for proxy to receive a response from control plane.
|
||||
#[metric(
|
||||
// largest bucket = 2^16 * 0.2ms = 13s
|
||||
metadata = Thresholds::exponential_buckets(0.0002, 2.0),
|
||||
)]
|
||||
pub console_request_latency: HistogramVec<ConsoleRequestSet, 16>,
|
||||
|
||||
/// Time it takes to acquire a token to call console plane.
|
||||
// largest bucket = 3^16 * 0.05ms = 2.15s
|
||||
#[metric(metadata = Thresholds::exponential_buckets(0.00005, 3.0))]
|
||||
pub control_plane_token_acquire_seconds: Histogram<16>,
|
||||
|
||||
/// Size of the HTTP request body lengths.
|
||||
// smallest bucket = 16 bytes
|
||||
// largest bucket = 4^12 * 16 bytes = 256MB
|
||||
#[metric(metadata = Thresholds::exponential_buckets(16.0, 4.0))]
|
||||
pub http_conn_content_length_bytes: HistogramVec<StaticLabelSet<HttpDirection>, 12>,
|
||||
|
||||
/// Time it takes to reclaim unused connection pools.
|
||||
#[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))]
|
||||
pub http_pool_reclaimation_lag_seconds: Histogram<16>,
|
||||
|
||||
/// Number of opened connections to a database.
|
||||
pub http_pool_opened_connections: Gauge,
|
||||
|
||||
/// Number of cache hits/misses for allowed ips.
|
||||
pub allowed_ips_cache_misses: CounterVec<StaticLabelSet<CacheOutcome>>,
|
||||
|
||||
/// Number of allowed ips
|
||||
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))]
|
||||
pub allowed_ips_number: Histogram<10>,
|
||||
|
||||
/// Number of connections (per sni).
|
||||
pub accepted_connections_by_sni: CounterVec<StaticLabelSet<SniKind>>,
|
||||
|
||||
/// Number of connection failures (per kind).
|
||||
pub connection_failures_total: CounterVec<StaticLabelSet<ConnectionFailureKind>>,
|
||||
|
||||
/// Number of wake-up failures (per kind).
|
||||
pub connection_failures_breakdown: CounterVec<ConnectionFailuresBreakdownSet>,
|
||||
|
||||
/// Number of bytes sent/received between all clients and backends.
|
||||
pub io_bytes: CounterVec<StaticLabelSet<Direction>>,
|
||||
|
||||
/// Number of errors by a given classification.
|
||||
pub errors_total: CounterVec<StaticLabelSet<crate::error::ErrorKind>>,
|
||||
|
||||
/// Number of cancellation requests (per found/not_found).
|
||||
pub cancellation_requests_total: CounterVec<CancellationRequestSet>,
|
||||
|
||||
/// Number of errors by a given classification
|
||||
pub redis_errors_total: CounterVec<RedisErrorsSet>,
|
||||
|
||||
/// Number of TLS handshake failures
|
||||
pub tls_handshake_failures: Counter,
|
||||
|
||||
/// Number of connection requests affected by authentication rate limits
|
||||
pub requests_auth_rate_limits_total: Counter,
|
||||
|
||||
/// HLL approximate cardinality of endpoints that are connecting
|
||||
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
|
||||
|
||||
/// Number of endpoints affected by errors of a given classification
|
||||
pub endpoints_affected_by_errors: HyperLogLogVec<StaticLabelSet<crate::error::ErrorKind>, 32>,
|
||||
|
||||
/// Number of endpoints affected by authentication rate limits
|
||||
pub endpoints_auth_rate_limits: HyperLogLog<32>,
|
||||
|
||||
/// Number of invalid endpoints (per protocol, per rejected).
|
||||
pub invalid_endpoints_total: CounterVec<InvalidEndpointsSet>,
|
||||
|
||||
/// Number of retries (per outcome, per retry_type).
|
||||
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))]
|
||||
pub retries_metric: HistogramVec<RetriesMetricSet, 9>,
|
||||
|
||||
/// Number of events consumed from redis (per event type).
|
||||
pub redis_events_count: CounterVec<StaticLabelSet<RedisEventsCount>>,
|
||||
|
||||
#[metric(namespace = "connect_compute_lock")]
|
||||
pub connect_compute_lock: ApiLockMetrics,
|
||||
|
||||
#[metric(namespace = "scram_pool")]
|
||||
#[metric(init = thread_pool)]
|
||||
pub scram_pool: Arc<ThreadPoolMetrics>,
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new())]
|
||||
pub struct ApiLockMetrics {
|
||||
/// Number of semaphores registered in this api lock
|
||||
pub semaphores_registered: Counter,
|
||||
/// Number of semaphores unregistered in this api lock
|
||||
pub semaphores_unregistered: Counter,
|
||||
/// Time it takes to reclaim unused semaphores in the api lock
|
||||
#[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))]
|
||||
pub reclamation_lag_seconds: Histogram<16>,
|
||||
/// Time it takes to acquire a semaphore lock
|
||||
#[metric(metadata = Thresholds::exponential_buckets(1e-4, 2.0))]
|
||||
pub semaphore_acquire_seconds: Histogram<16>,
|
||||
}
|
||||
|
||||
impl Default for ApiLockMetrics {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "direction")]
|
||||
pub enum HttpDirection {
|
||||
Request,
|
||||
Response,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "direction")]
|
||||
pub enum Direction {
|
||||
Tx,
|
||||
Rx,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
#[label(singleton = "protocol")]
|
||||
pub enum Protocol {
|
||||
Http,
|
||||
Ws,
|
||||
Tcp,
|
||||
SniRouter,
|
||||
}
|
||||
|
||||
impl Protocol {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Protocol::Http => "http",
|
||||
Protocol::Ws => "ws",
|
||||
Protocol::Tcp => "tcp",
|
||||
Protocol::SniRouter => "sni_router",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum Bool {
|
||||
True,
|
||||
False,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "outcome")]
|
||||
pub enum Outcome {
|
||||
Success,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "outcome")]
|
||||
pub enum CacheOutcome {
|
||||
Hit,
|
||||
Miss,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = ConsoleRequestSet)]
|
||||
pub struct ConsoleRequest<'a> {
|
||||
#[label(dynamic_with = ThreadedRodeo, default)]
|
||||
pub request: &'a str,
|
||||
}
|
||||
|
||||
#[derive(MetricGroup, Default)]
|
||||
pub struct HttpEndpointPools {
|
||||
/// Number of endpoints we have registered pools for
|
||||
pub http_pool_endpoints_registered_total: Counter,
|
||||
/// Number of endpoints we have unregistered pools for
|
||||
pub http_pool_endpoints_unregistered_total: Counter,
|
||||
}
|
||||
|
||||
pub struct HttpEndpointPoolsGuard<'a> {
|
||||
dec: &'a Counter,
|
||||
}
|
||||
|
||||
impl Drop for HttpEndpointPoolsGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.dec.inc();
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpEndpointPools {
|
||||
pub fn guard(&self) -> HttpEndpointPoolsGuard {
|
||||
self.http_pool_endpoints_registered_total.inc();
|
||||
HttpEndpointPoolsGuard {
|
||||
dec: &self.http_pool_endpoints_unregistered_total,
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct NumDbConnectionsGauge;
|
||||
impl CounterPairAssoc for NumDbConnectionsGauge {
|
||||
const INC_NAME: &'static MetricName = MetricName::from_str("opened_db_connections_total");
|
||||
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_db_connections_total");
|
||||
const INC_HELP: &'static str = "Number of opened connections to a database.";
|
||||
const DEC_HELP: &'static str = "Number of closed connections to a database.";
|
||||
type LabelGroupSet = StaticLabelSet<Protocol>;
|
||||
}
|
||||
pub type NumDbConnectionsGuard<'a> = metrics::MeasuredCounterPairGuard<'a, NumDbConnectionsGauge>;
|
||||
|
||||
pub struct NumClientConnectionsGauge;
|
||||
impl CounterPairAssoc for NumClientConnectionsGauge {
|
||||
const INC_NAME: &'static MetricName = MetricName::from_str("opened_client_connections_total");
|
||||
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_client_connections_total");
|
||||
const INC_HELP: &'static str = "Number of opened connections from a client.";
|
||||
const DEC_HELP: &'static str = "Number of closed connections from a client.";
|
||||
type LabelGroupSet = StaticLabelSet<Protocol>;
|
||||
}
|
||||
pub type NumClientConnectionsGuard<'a> =
|
||||
metrics::MeasuredCounterPairGuard<'a, NumClientConnectionsGauge>;
|
||||
|
||||
pub struct NumConnectionRequestsGauge;
|
||||
impl CounterPairAssoc for NumConnectionRequestsGauge {
|
||||
const INC_NAME: &'static MetricName = MetricName::from_str("accepted_connections_total");
|
||||
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_connections_total");
|
||||
const INC_HELP: &'static str = "Number of client connections accepted.";
|
||||
const DEC_HELP: &'static str = "Number of client connections closed.";
|
||||
type LabelGroupSet = StaticLabelSet<Protocol>;
|
||||
}
|
||||
pub type NumConnectionRequestsGuard<'a> =
|
||||
metrics::MeasuredCounterPairGuard<'a, NumConnectionRequestsGauge>;
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = ComputeConnectionLatencySet)]
|
||||
pub struct ComputeConnectionLatencyGroup {
|
||||
protocol: Protocol,
|
||||
cold_start_info: ColdStartInfo,
|
||||
outcome: ConnectOutcome,
|
||||
excluded: LatencyExclusions,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum LatencyExclusions {
|
||||
Client,
|
||||
ClientAndCplane,
|
||||
ClientCplaneCompute,
|
||||
ClientCplaneComputeRetry,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "kind")]
|
||||
pub enum SniKind {
|
||||
Sni,
|
||||
NoSni,
|
||||
PasswordHack,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "kind")]
|
||||
pub enum ConnectionFailureKind {
|
||||
ComputeCached,
|
||||
ComputeUncached,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "kind")]
|
||||
pub enum WakeupFailureKind {
|
||||
BadComputeAddress,
|
||||
ApiTransportError,
|
||||
QuotaExceeded,
|
||||
ApiConsoleLocked,
|
||||
ApiConsoleBadRequest,
|
||||
ApiConsoleOtherServerError,
|
||||
ApiConsoleOtherError,
|
||||
TimeoutError,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = ConnectionFailuresBreakdownSet)]
|
||||
pub struct ConnectionFailuresBreakdownGroup {
|
||||
pub kind: WakeupFailureKind,
|
||||
pub retry: Bool,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup, Copy, Clone)]
|
||||
#[label(set = RedisErrorsSet)]
|
||||
pub struct RedisErrors<'a> {
|
||||
#[label(dynamic_with = ThreadedRodeo, default)]
|
||||
pub channel: &'a str,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum CancellationSource {
|
||||
FromClient,
|
||||
FromRedis,
|
||||
Local,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum CancellationOutcome {
|
||||
NotFound,
|
||||
Found,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = CancellationRequestSet)]
|
||||
pub struct CancellationRequest {
|
||||
pub source: CancellationSource,
|
||||
pub kind: CancellationOutcome,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Waiting {
|
||||
Cplane,
|
||||
Client,
|
||||
Compute,
|
||||
RetryTimeout,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Accumulated {
|
||||
cplane: time::Duration,
|
||||
client: time::Duration,
|
||||
compute: time::Duration,
|
||||
retry: time::Duration,
|
||||
}
|
||||
|
||||
pub struct LatencyTimer {
|
||||
// time since the stopwatch was started
|
||||
start: time::Instant,
|
||||
// time since the stopwatch was stopped
|
||||
stop: Option<time::Instant>,
|
||||
// accumulated time on the stopwatch
|
||||
accumulated: Accumulated,
|
||||
// label data
|
||||
protocol: Protocol,
|
||||
cold_start_info: ColdStartInfo,
|
||||
outcome: ConnectOutcome,
|
||||
}
|
||||
|
||||
impl LatencyTimer {
|
||||
pub fn new(protocol: Protocol) -> Self {
|
||||
Self {
|
||||
start: time::Instant::now(),
|
||||
stop: None,
|
||||
accumulated: Accumulated::default(),
|
||||
protocol,
|
||||
cold_start_info: ColdStartInfo::Unknown,
|
||||
// assume failed unless otherwise specified
|
||||
outcome: ConnectOutcome::Failed,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unpause(&mut self, start: Instant, waiting_for: Waiting) {
|
||||
let dur = start.elapsed();
|
||||
match waiting_for {
|
||||
Waiting::Cplane => self.accumulated.cplane += dur,
|
||||
Waiting::Client => self.accumulated.client += dur,
|
||||
Waiting::Compute => self.accumulated.compute += dur,
|
||||
Waiting::RetryTimeout => self.accumulated.retry += dur,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cold_start_info(&mut self, cold_start_info: ColdStartInfo) {
|
||||
self.cold_start_info = cold_start_info;
|
||||
}
|
||||
|
||||
pub fn success(&mut self) {
|
||||
// stop the stopwatch and record the time that we have accumulated
|
||||
self.stop = Some(time::Instant::now());
|
||||
|
||||
// success
|
||||
self.outcome = ConnectOutcome::Success;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
pub enum ConnectOutcome {
|
||||
Success,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl Drop for LatencyTimer {
|
||||
fn drop(&mut self) {
|
||||
let duration = self
|
||||
.stop
|
||||
.unwrap_or_else(time::Instant::now)
|
||||
.duration_since(self.start);
|
||||
|
||||
let metric = &Metrics::get().proxy.compute_connection_latency_seconds;
|
||||
|
||||
// Excluding client communication from the accumulated time.
|
||||
metric.observe(
|
||||
ComputeConnectionLatencyGroup {
|
||||
protocol: self.protocol,
|
||||
cold_start_info: self.cold_start_info,
|
||||
outcome: self.outcome,
|
||||
excluded: LatencyExclusions::Client,
|
||||
},
|
||||
duration
|
||||
.saturating_sub(self.accumulated.client)
|
||||
.as_secs_f64(),
|
||||
);
|
||||
|
||||
// Exclude client and cplane communication from the accumulated time.
|
||||
let accumulated_total = self.accumulated.client + self.accumulated.cplane;
|
||||
metric.observe(
|
||||
ComputeConnectionLatencyGroup {
|
||||
protocol: self.protocol,
|
||||
cold_start_info: self.cold_start_info,
|
||||
outcome: self.outcome,
|
||||
excluded: LatencyExclusions::ClientAndCplane,
|
||||
},
|
||||
duration.saturating_sub(accumulated_total).as_secs_f64(),
|
||||
);
|
||||
|
||||
// Exclude client cplane, compue communication from the accumulated time.
|
||||
let accumulated_total =
|
||||
self.accumulated.client + self.accumulated.cplane + self.accumulated.compute;
|
||||
metric.observe(
|
||||
ComputeConnectionLatencyGroup {
|
||||
protocol: self.protocol,
|
||||
cold_start_info: self.cold_start_info,
|
||||
outcome: self.outcome,
|
||||
excluded: LatencyExclusions::ClientCplaneCompute,
|
||||
},
|
||||
duration.saturating_sub(accumulated_total).as_secs_f64(),
|
||||
);
|
||||
|
||||
// Exclude client cplane, compue, retry communication from the accumulated time.
|
||||
let accumulated_total = self.accumulated.client
|
||||
+ self.accumulated.cplane
|
||||
+ self.accumulated.compute
|
||||
+ self.accumulated.retry;
|
||||
metric.observe(
|
||||
ComputeConnectionLatencyGroup {
|
||||
protocol: self.protocol,
|
||||
cold_start_info: self.cold_start_info,
|
||||
outcome: self.outcome,
|
||||
excluded: LatencyExclusions::ClientCplaneComputeRetry,
|
||||
},
|
||||
duration.saturating_sub(accumulated_total).as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bool> for Bool {
|
||||
fn from(value: bool) -> Self {
|
||||
if value {
|
||||
Bool::True
|
||||
} else {
|
||||
Bool::False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = InvalidEndpointsSet)]
|
||||
pub struct InvalidEndpointsGroup {
|
||||
pub protocol: Protocol,
|
||||
pub rejected: Bool,
|
||||
pub outcome: ConnectOutcome,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = RetriesMetricSet)]
|
||||
pub struct RetriesMetricGroup {
|
||||
pub outcome: ConnectOutcome,
|
||||
pub retry_type: RetryType,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
pub enum RetryType {
|
||||
WakeCompute,
|
||||
ConnectToCompute,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
#[label(singleton = "event")]
|
||||
pub enum RedisEventsCount {
|
||||
EndpointCreated,
|
||||
BranchCreated,
|
||||
ProjectCreated,
|
||||
CancelSession,
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
pub struct ThreadPoolWorkerId(pub usize);
|
||||
|
||||
impl LabelValue for ThreadPoolWorkerId {
|
||||
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
|
||||
v.write_int(self.0 as i64)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroup for ThreadPoolWorkerId {
|
||||
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
|
||||
v.write_value(LabelName::from_str("worker"), self);
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroupSet for ThreadPoolWorkers {
|
||||
type Group<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
|
||||
Some(value)
|
||||
}
|
||||
|
||||
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
|
||||
type Unique = usize;
|
||||
|
||||
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
|
||||
Some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelSet for ThreadPoolWorkers {
|
||||
type Value<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn dynamic_cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
|
||||
(value.0 < self.0).then_some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: usize) -> Self::Value<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedCardinalitySet for ThreadPoolWorkers {
|
||||
fn cardinality(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(workers: usize))]
|
||||
pub struct ThreadPoolMetrics {
|
||||
pub injector_queue_depth: Gauge,
|
||||
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
|
||||
}
|
||||
43
proxy/core/src/parse.rs
Normal file
43
proxy/core/src/parse.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
//! Small parsing helpers.
|
||||
|
||||
use std::ffi::CStr;
|
||||
|
||||
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
|
||||
let cstr = CStr::from_bytes_until_nul(bytes).ok()?;
|
||||
let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len());
|
||||
Some((cstr, other))
|
||||
}
|
||||
|
||||
/// See <https://doc.rust-lang.org/std/primitive.slice.html#method.split_array_ref>.
|
||||
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
|
||||
(bytes.len() >= N).then(|| {
|
||||
let (head, tail) = bytes.split_at(N);
|
||||
(head.try_into().unwrap(), tail)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_split_cstr() {
|
||||
assert!(split_cstr(b"").is_none());
|
||||
assert!(split_cstr(b"foo").is_none());
|
||||
|
||||
let (cstr, rest) = split_cstr(b"\0").expect("uh-oh");
|
||||
assert_eq!(cstr.to_bytes(), b"");
|
||||
assert_eq!(rest, b"");
|
||||
|
||||
let (cstr, rest) = split_cstr(b"foo\0bar").expect("uh-oh");
|
||||
assert_eq!(cstr.to_bytes(), b"foo");
|
||||
assert_eq!(rest, b"bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_at_const() {
|
||||
assert!(split_at_const::<0>(b"").is_some());
|
||||
assert!(split_at_const::<1>(b"").is_none());
|
||||
assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k"))));
|
||||
}
|
||||
}
|
||||
349
proxy/core/src/protocol2.rs
Normal file
349
proxy/core/src/protocol2.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
//! Proxy Protocol V2 implementation
|
||||
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
|
||||
pub struct ChainRW<T> {
|
||||
#[pin]
|
||||
pub inner: T,
|
||||
buf: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy Protocol Version 2 Header
|
||||
const HEADER: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
pub async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
while buf.len() < 16 {
|
||||
let bytes_read = read.read_buf(&mut buf).await?;
|
||||
|
||||
// exit for bad header
|
||||
let len = usize::min(buf.len(), HEADER.len());
|
||||
if buf[..len] != HEADER[..len] {
|
||||
return Ok((ChainRW { inner: read, buf }, None));
|
||||
}
|
||||
|
||||
// if no more bytes available then exit
|
||||
if bytes_read == 0 {
|
||||
return Ok((ChainRW { inner: read, buf }, None));
|
||||
};
|
||||
}
|
||||
|
||||
let header = buf.split_to(16);
|
||||
|
||||
// The next byte (the 13th one) is the protocol version and command.
|
||||
// The highest four bits contains the version. As of this specification, it must
|
||||
// always be sent as \x2 and the receiver must only accept this value.
|
||||
let vc = header[12];
|
||||
let version = vc >> 4;
|
||||
let command = vc & 0b1111;
|
||||
if version != 2 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol version. expected version 2",
|
||||
));
|
||||
}
|
||||
match command {
|
||||
// the connection was established on purpose by the proxy
|
||||
// without being relayed. The connection endpoints are the sender and the
|
||||
// receiver. Such connections exist when the proxy sends health-checks to the
|
||||
// server. The receiver must accept this connection as valid and must use the
|
||||
// real connection endpoints and discard the protocol block including the
|
||||
// family which is ignored.
|
||||
0 => {}
|
||||
// the connection was established on behalf of another node,
|
||||
// and reflects the original connection endpoints. The receiver must then use
|
||||
// the information provided in the protocol block to get original the address.
|
||||
1 => {}
|
||||
// other values are unassigned and must not be emitted by senders. Receivers
|
||||
// must drop connections presenting unexpected values here.
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol command. expected local (0) or proxy (1)",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// The 14th byte contains the transport protocol and address family. The highest 4
|
||||
// bits contain the address family, the lowest 4 bits contain the protocol.
|
||||
let ft = header[13];
|
||||
let address_length = match ft {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
0x11 | 0x12 => 12,
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
0x21 | 0x22 => 36,
|
||||
// unspecified or unix stream. ignore the addresses
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
// The 15th and 16th bytes is the address length in bytes in network endian order.
|
||||
// It is used so that the receiver knows how many address bytes to skip even when
|
||||
// it does not implement the presented protocol. Thus the length of the protocol
|
||||
// header in bytes is always exactly 16 + this value. When a sender presents a
|
||||
// LOCAL connection, it should not present any address so it sets this field to
|
||||
// zero. Receivers MUST always consider this field to skip the appropriate number
|
||||
// of bytes and must not assume zero is presented for LOCAL connections. When a
|
||||
// receiver accepts an incoming connection showing an UNSPEC address family or
|
||||
// protocol, it may or may not decide to log the address information if present.
|
||||
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
|
||||
if remaining_length < address_length {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol length. not enough to fit requested IP addresses",
|
||||
));
|
||||
}
|
||||
drop(header);
|
||||
|
||||
while buf.len() < remaining_length as usize {
|
||||
if read.read_buf(&mut buf).await? == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"stream closed while waiting for proxy protocol addresses",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Starting from the 17th byte, addresses are presented in network byte order.
|
||||
// The address order is always the same :
|
||||
// - source layer 3 address in network byte order
|
||||
// - destination layer 3 address in network byte order
|
||||
// - source layer 4 address if any, in network byte order (port)
|
||||
// - destination layer 4 address if any, in network byte order (port)
|
||||
let addresses = buf.split_to(remaining_length as usize);
|
||||
let socket = match address_length {
|
||||
12 => {
|
||||
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
}
|
||||
36 => {
|
||||
let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, socket))
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.buf.is_empty() {
|
||||
self.project().inner.poll_read(cx, buf)
|
||||
} else {
|
||||
self.read_from_buf(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> ChainRW<T> {
|
||||
#[cold]
|
||||
fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||
debug_assert!(!self.buf.is_empty());
|
||||
let this = self.project();
|
||||
|
||||
let write = usize::min(this.buf.len(), buf.remaining());
|
||||
let slice = this.buf.split_to(write).freeze();
|
||||
buf.put_slice(&slice);
|
||||
|
||||
// reset the allocation so it can be freed
|
||||
if this.buf.is_empty() {
|
||||
*this.buf = BytesMut::new();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ipv4() {
|
||||
let header = super::HEADER
|
||||
// Proxy command, IPV4 | TCP
|
||||
.chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
|
||||
// 12 + 3 bytes
|
||||
.chain([0, 15].as_slice())
|
||||
// src ip
|
||||
.chain([127, 0, 0, 1].as_slice())
|
||||
// dst ip
|
||||
.chain([192, 168, 0, 1].as_slice())
|
||||
// src port
|
||||
.chain([255, 255].as_slice())
|
||||
// dst port
|
||||
.chain([1, 1].as_slice())
|
||||
// TLV
|
||||
.chain([1, 2, 3].as_slice());
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ipv6() {
|
||||
let header = super::HEADER
|
||||
// Proxy command, IPV6 | UDP
|
||||
.chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
|
||||
// 36 + 3 bytes
|
||||
.chain([0, 39].as_slice())
|
||||
// src ip
|
||||
.chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
|
||||
// dst ip
|
||||
.chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
|
||||
// src port
|
||||
.chain([1, 1].as_slice())
|
||||
// dst port
|
||||
.chain([255, 255].as_slice())
|
||||
// TLV
|
||||
.chain([1, 2, 3].as_slice());
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(
|
||||
addr,
|
||||
Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid() {
|
||||
let data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_short() {
|
||||
let data = [0x55; 10];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_large_tlv() {
|
||||
let tlv = vec![0x55; 32768];
|
||||
let len = (12 + tlv.len() as u16).to_be_bytes();
|
||||
|
||||
let header = super::HEADER
|
||||
// Proxy command, Inet << 4 | Stream
|
||||
.chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
|
||||
// 12 + 3 bytes
|
||||
.chain(len.as_slice())
|
||||
// src ip
|
||||
.chain([55, 56, 57, 58].as_slice())
|
||||
// dst ip
|
||||
.chain([192, 168, 0, 1].as_slice())
|
||||
// src port
|
||||
.chain([255, 255].as_slice())
|
||||
// dst port
|
||||
.chain([1, 1].as_slice())
|
||||
// TLV
|
||||
.chain(tlv.as_slice());
|
||||
|
||||
let extra_data = [0xaa; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
|
||||
}
|
||||
}
|
||||
433
proxy/core/src/proxy.rs
Normal file
433
proxy/core/src/proxy.rs
Normal file
@@ -0,0 +1,433 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub mod connect_compute;
|
||||
mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod passthrough;
|
||||
pub mod retry;
|
||||
pub mod wake_compute;
|
||||
pub use copy_bidirectional::copy_bidirectional_client_compute;
|
||||
pub use copy_bidirectional::ErrorSource;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
|
||||
compute,
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::{Metrics, NumClientConnectionsGuard},
|
||||
protocol2::read_proxy_protocol,
|
||||
proxy::handshake::{handshake, HandshakeData},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
EndpointCacheKey,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, Instrument};
|
||||
|
||||
use self::{
|
||||
connect_compute::{connect_to_compute, TcpMechanism},
|
||||
passthrough::ProxyPassthrough,
|
||||
};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> Option<F::Output> {
|
||||
match futures::future::select(
|
||||
std::pin::pin!(f),
|
||||
std::pin::pin!(cancellation_token.cancelled()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
futures::future::Either::Left((f, _)) => Some(f),
|
||||
futures::future::Either::Right(((), _)) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
|
||||
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
Ok((socket, Some(addr))) => (socket, addr.ip()),
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, None)) if config.require_client_ip => {
|
||||
error!("missing required client IP");
|
||||
return;
|
||||
}
|
||||
Ok((socket, None)) => (socket, peer_addr.ip()),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: failed to set socket option: {e:#}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span();
|
||||
|
||||
let startup = Box::pin(
|
||||
handle_client(
|
||||
config,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
)
|
||||
.instrument(span.clone()),
|
||||
);
|
||||
let res = startup.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
// todo: log and push to ctx the error kind
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
error!(parent: &span, "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass().instrument(span.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub enum ClientMode {
|
||||
Tcp,
|
||||
Websockets { hostname: Option<String> },
|
||||
}
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||
ClientMode::Websockets { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||
match self {
|
||||
ClientMode::Tcp => s.sni_hostname(),
|
||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||
match self {
|
||||
ClientMode::Tcp => tls,
|
||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||
ClientMode::Websockets { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] handshake::HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
info!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let _request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let (mut stream, params) =
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(cancel_key_data, ctx.session_id())
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e).await?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return stream.throw_error(e).instrument(params_span).await?;
|
||||
}
|
||||
};
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&user_info,
|
||||
mode.allow_self_signed_compute(config),
|
||||
config.wake_compute_retry_config,
|
||||
config.connect_to_compute_retry_config,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let session = cancellation_handler.get_session();
|
||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node,
|
||||
req: _request_gauge,
|
||||
conn: conn_gauge,
|
||||
cancel: session,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn prepare_client_connection<P>(
|
||||
node: &compute::PostgresConnection,
|
||||
session: &cancellation::Session<P>,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
// Register compute's query cancellation token and produce a new, unique one.
|
||||
// The new token (cancel_key_data) will be sent to the client.
|
||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
// Right now the implementation is very hacky and inefficent (ideally,
|
||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||
for (name, value) in &node.params {
|
||||
// TODO: Theoretically, this could result in a big pile of params...
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
pub fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
params
|
||||
.options_raw()
|
||||
.map(Self::parse_from_iter)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
pub fn parse_options_raw(options: &str) -> Self {
|
||||
Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
|
||||
}
|
||||
|
||||
pub fn is_ephemeral(&self) -> bool {
|
||||
// Currently, neon endpoint options are all reserved for ephemeral endpoints.
|
||||
!self.0.is_empty()
|
||||
}
|
||||
|
||||
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut options = options
|
||||
.filter_map(neon_option)
|
||||
.map(|(k, v)| (k.into(), v.into()))
|
||||
.collect_vec();
|
||||
options.sort();
|
||||
Self(options)
|
||||
}
|
||||
|
||||
pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
|
||||
// prefix + format!(" {k}:{v}")
|
||||
// kinda jank because SmolStr is immutable
|
||||
std::iter::once(prefix)
|
||||
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
|
||||
.collect::<SmolStr>()
|
||||
.into()
|
||||
}
|
||||
|
||||
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
|
||||
/// `paramName[prop1]=value1¶mName[prop2]=value2&...`
|
||||
pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
|
||||
self.0
|
||||
.iter()
|
||||
.map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
|
||||
static RE: OnceCell<Regex> = OnceCell::new();
|
||||
let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
|
||||
|
||||
let cap = re.captures(bytes)?;
|
||||
let (_, [k, v]) = cap.extract();
|
||||
Some((k, v))
|
||||
}
|
||||
216
proxy/core/src/proxy/connect_compute.rs
Normal file
216
proxy/core/src/proxy/connect_compute.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use crate::{
|
||||
auth::backend::ComputeCredentialKeys,
|
||||
compute::{self, PostgresConnection},
|
||||
config::RetryConfig,
|
||||
console::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType},
|
||||
proxy::{
|
||||
retry::{retry_after, should_retry, CouldRetry},
|
||||
wake_compute::wake_compute,
|
||||
},
|
||||
Host,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use tokio::time;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::retry::ShouldRetryWakeCompute;
|
||||
|
||||
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
|
||||
/// If we couldn't connect, a cached connection info might be to blame
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
warn!("invalidating stalled compute node info cache entry");
|
||||
}
|
||||
let label = match is_cached {
|
||||
true => ConnectionFailureKind::ComputeCached,
|
||||
false => ConnectionFailureKind::ComputeUncached,
|
||||
};
|
||||
Metrics::get().proxy.connection_failures_total.inc(label);
|
||||
|
||||
node_info.invalidate()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConnectMechanism {
|
||||
type Connection;
|
||||
type ConnectError: ReportableError;
|
||||
type Error: From<Self::ConnectError>;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||
}
|
||||
|
||||
pub struct TcpMechanism<'a> {
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub params: &'a StartupMessageParams,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
pub locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TcpMechanism<'_> {
|
||||
type Connection = PostgresConnection;
|
||||
type ConnectError = compute::ConnectionError;
|
||||
type Error = compute::ConnectionError;
|
||||
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, timeout).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
config.set_startup_params(self.params);
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &RequestMonitoring,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
allow_self_signed_compute: bool,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
connect_to_compute_retry_config: RetryConfig,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
|
||||
M::Error: From<WakeComputeError>,
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
if let Some(keys) = user_info.get_keys() {
|
||||
node_info.set_keys(keys);
|
||||
}
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
let retry_type = RetryType::ConnectToCompute;
|
||||
|
||||
// try once
|
||||
let err = match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
error!(error = ?err, "could not connect to compute node");
|
||||
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just recieved this from cplane and dodn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(err.into());
|
||||
}
|
||||
node_info
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
node_info
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
info!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
info!(?num_retries, "connected to compute node after");
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't connect to compute node");
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
warn!(error = ?e, num_retries, retriable = true, "couldn't connect to compute node");
|
||||
}
|
||||
};
|
||||
|
||||
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
}
|
||||
306
proxy/core/src/proxy/copy_bidirectional.rs
Normal file
306
proxy/core/src/proxy/copy_bidirectional.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::info;
|
||||
|
||||
use std::future::poll_fn;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransferState {
|
||||
Running(CopyBuffer),
|
||||
ShuttingDown(u64),
|
||||
Done(u64),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ErrorDirection {
|
||||
Read(io::Error),
|
||||
Write(io::Error),
|
||||
}
|
||||
|
||||
impl ErrorSource {
|
||||
fn from_client(err: ErrorDirection) -> ErrorSource {
|
||||
match err {
|
||||
ErrorDirection::Read(client) => Self::Client(client),
|
||||
ErrorDirection::Write(compute) => Self::Compute(compute),
|
||||
}
|
||||
}
|
||||
fn from_compute(err: ErrorDirection) -> ErrorSource {
|
||||
match err {
|
||||
ErrorDirection::Write(client) => Self::Client(client),
|
||||
ErrorDirection::Read(compute) => Self::Compute(compute),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ErrorSource {
|
||||
Client(io::Error),
|
||||
Compute(io::Error),
|
||||
}
|
||||
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
where
|
||||
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut r = Pin::new(r);
|
||||
let mut w = Pin::new(w);
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf) => {
|
||||
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
|
||||
*state = TransferState::Done(*count);
|
||||
}
|
||||
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
) -> Result<(u64, u64), ErrorSource>
|
||||
where
|
||||
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
|
||||
// Early termination checks from compute to client.
|
||||
if let TransferState::Done(_) = compute_to_client {
|
||||
if let TransferState::Running(buf) = &client_to_compute {
|
||||
info!("Compute is done, terminate client");
|
||||
// Initiate shutdown
|
||||
client_to_compute = TransferState::ShuttingDown(buf.amt);
|
||||
client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Early termination checks from client to compute.
|
||||
if let TransferState::Done(_) = client_to_compute {
|
||||
if let TransferState::Running(buf) = &compute_to_client {
|
||||
info!("Client is done, terminate compute");
|
||||
// Initiate shutdown
|
||||
compute_to_client = TransferState::ShuttingDown(buf.amt);
|
||||
compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
}
|
||||
}
|
||||
|
||||
// It is not a problem if ready! returns early ... (comment remains the same)
|
||||
let client_to_compute = ready!(client_to_compute_result);
|
||||
let compute_to_client = ready!(compute_to_client_result);
|
||||
|
||||
Poll::Ready(Ok((client_to_compute, compute_to_client)))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct CopyBuffer {
|
||||
read_done: bool,
|
||||
need_flush: bool,
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
amt: u64,
|
||||
buf: Box<[u8]>,
|
||||
}
|
||||
const DEFAULT_BUF_SIZE: usize = 1024;
|
||||
|
||||
impl CopyBuffer {
|
||||
pub(super) fn new() -> Self {
|
||||
Self {
|
||||
read_done: false,
|
||||
need_flush: false,
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
amt: 0,
|
||||
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
{
|
||||
let me = &mut *self;
|
||||
let mut buf = ReadBuf::new(&mut me.buf);
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
me.read_done = me.cap == filled_len;
|
||||
me.cap = filled_len;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<usize, ErrorDirection>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
W: AsyncWrite + ?Sized,
|
||||
{
|
||||
let me = &mut *self;
|
||||
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
|
||||
Poll::Pending => {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
res => res.map_err(ErrorDirection::Write),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
W: AsyncWrite + ?Sized,
|
||||
{
|
||||
loop {
|
||||
// If our buffer is empty, then we need to read some data to
|
||||
// continue.
|
||||
if self.pos == self.cap && !self.read_done {
|
||||
self.pos = 0;
|
||||
self.cap = 0;
|
||||
|
||||
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
|
||||
Poll::Pending => {
|
||||
// Try flushing when the reader has no progress to avoid deadlock
|
||||
// when the reader depends on buffered writer.
|
||||
if self.need_flush {
|
||||
ready!(writer.as_mut().poll_flush(cx))
|
||||
.map_err(ErrorDirection::Write)?;
|
||||
self.need_flush = false;
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||
if i == 0 {
|
||||
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"write zero byte into writer",
|
||||
))));
|
||||
} else {
|
||||
self.pos += i;
|
||||
self.amt += i as u64;
|
||||
self.need_flush = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If pos larger than cap, this loop will never stop.
|
||||
// In particular, user's wrong poll_write implementation returning
|
||||
// incorrect written length may lead to thread blocking.
|
||||
debug_assert!(
|
||||
self.pos <= self.cap,
|
||||
"writer returned length larger than input slice"
|
||||
);
|
||||
|
||||
// If we've written all the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
if self.pos == self.cap && self.read_done {
|
||||
ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
|
||||
return Poll::Ready(Ok(self.amt));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_to_compute() {
|
||||
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||
|
||||
// Simulate 'a' finishing while there's still data for 'b'
|
||||
client_client.write_all(b"hello").await.unwrap();
|
||||
client_client.shutdown().await.unwrap();
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
|
||||
assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_compute_to_client() {
|
||||
let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
|
||||
// Simulate 'a' finishing while there's still data for 'b'
|
||||
compute_client.write_all(b"hello").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
client_client
|
||||
.write_all(b"Neon Serverless Postgres")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
|
||||
assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
|
||||
}
|
||||
}
|
||||
258
proxy/core/src/proxy/handshake.rs
Normal file
258
proxy/core/src/proxy/handshake.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use bytes::Buf;
|
||||
use pq_proto::{
|
||||
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
|
||||
StartupMessageParams,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
auth::endpoint_sni,
|
||||
config::{TlsConfig, PG_ALPN_PROTOCOL},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::Metrics,
|
||||
proxy::ERR_INSECURE_CONNECTION,
|
||||
stream::{PqStream, Stream, StreamUpgradeError},
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum HandshakeError {
|
||||
#[error("data is sent before server replied with EncryptionResponse")]
|
||||
EarlyData,
|
||||
|
||||
#[error("protocol violation")]
|
||||
ProtocolViolation,
|
||||
|
||||
#[error("missing certificate")]
|
||||
MissingCertificate,
|
||||
|
||||
#[error("{0}")]
|
||||
StreamUpgradeError(#[from] StreamUpgradeError),
|
||||
|
||||
#[error("{0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for HandshakeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
HandshakeError::EarlyData => crate::error::ErrorKind::User,
|
||||
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
|
||||
// This error should not happen, but will if we have no default certificate and
|
||||
// the client sends no SNI extension.
|
||||
// If they provide SNI then we can be sure there is a certificate that matches.
|
||||
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
|
||||
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
|
||||
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
|
||||
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
},
|
||||
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
HandshakeError::ReportedError(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
/// For better testing experience, `stream` can be any object satisfying the traits.
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestMonitoring,
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<HandshakeData<S>, HandshakeError> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest { direct } => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let have_tls = tls.is_some();
|
||||
if !direct {
|
||||
stream
|
||||
.write_message(&Be::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
} else if !have_tls {
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let Framed {
|
||||
stream: raw,
|
||||
read_buf,
|
||||
write_buf,
|
||||
} = stream.framed;
|
||||
|
||||
let Stream::Raw { raw } = raw else {
|
||||
return Err(HandshakeError::StreamUpgradeError(
|
||||
StreamUpgradeError::AlreadyTls,
|
||||
));
|
||||
};
|
||||
|
||||
let mut read_buf = read_buf.reader();
|
||||
let mut res = Ok(());
|
||||
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
|
||||
.accept_with(raw, |session| {
|
||||
// push the early data to the tls session
|
||||
while !read_buf.get_ref().is_empty() {
|
||||
match session.read_tls(&mut read_buf) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
res = Err(e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
res?;
|
||||
|
||||
let read_buf = read_buf.into_inner();
|
||||
if !read_buf.is_empty() {
|
||||
return Err(HandshakeError::EarlyData);
|
||||
}
|
||||
|
||||
let tls_stream = accept.await.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc()
|
||||
}
|
||||
})?;
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
|
||||
// try parse endpoint
|
||||
let ep = conn_info
|
||||
.server_name()
|
||||
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
|
||||
if let Some(ep) = ep {
|
||||
ctx.set_endpoint_id(ep);
|
||||
}
|
||||
|
||||
// check the ALPN, if exists, as required.
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
}
|
||||
|
||||
let (_, tls_server_end_point) = tls
|
||||
.cert_resolver
|
||||
.resolve(conn_info.server_name())
|
||||
.ok_or(HandshakeError::MissingCertificate)?;
|
||||
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
GssEncRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_gss => {
|
||||
tried_gss = true;
|
||||
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
StartupMessage { params, version }
|
||||
if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
return stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
?version,
|
||||
?params,
|
||||
session_type = "normal",
|
||||
"successful handshake"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
// downgrade protocol version
|
||||
StartupMessage { params, version }
|
||||
if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
warn!(?version, "unsupported minor version");
|
||||
|
||||
// no protocol extensions are supported.
|
||||
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
|
||||
let mut unsupported = vec![];
|
||||
for (k, _) in params.iter() {
|
||||
if k.starts_with("_pq_.") {
|
||||
unsupported.push(k);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove unsupported options so we don't send them to compute.
|
||||
|
||||
stream
|
||||
.write_message(&Be::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
})
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
?version,
|
||||
session_type = "normal",
|
||||
"successful handshake; unsupported minor version requested"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
StartupMessage { version, .. } => {
|
||||
warn!(
|
||||
?version,
|
||||
session_type = "normal",
|
||||
"unsuccessful handshake; unsupported version"
|
||||
);
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
78
proxy/core/src/proxy/passthrough.rs
Normal file
78
proxy/core/src/proxy/passthrough.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use crate::{
|
||||
cancellation,
|
||||
compute::PostgresConnection,
|
||||
console::messages::MetricsAuxInfo,
|
||||
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
|
||||
stream::Stream,
|
||||
usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let usage = USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
});
|
||||
|
||||
let metrics = &Metrics::get().proxy.io_bytes;
|
||||
let m_sent = metrics.with_labels(Direction::Tx);
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
metrics.get_metric(m_sent).inc_by(cnt as u64);
|
||||
usage.record_egress(cnt as u64);
|
||||
},
|
||||
);
|
||||
|
||||
let m_recv = metrics.with_labels(Direction::Rx);
|
||||
let mut compute = MeasuredStream::new(
|
||||
compute,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
metrics.get_metric(m_recv).inc_by(cnt as u64);
|
||||
},
|
||||
);
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
|
||||
&mut client,
|
||||
&mut compute,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct ProxyPassthrough<P, S> {
|
||||
pub client: Stream<S>,
|
||||
pub compute: PostgresConnection,
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
pub req: NumConnectionRequestsGuard<'static>,
|
||||
pub conn: NumClientConnectionsGuard<'static>,
|
||||
pub cancel: cancellation::Session<P>,
|
||||
}
|
||||
|
||||
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
|
||||
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
|
||||
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
|
||||
tracing::error!(?err, "could not cancel the query in the database");
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
108
proxy/core/src/proxy/retry.rs
Normal file
108
proxy/core/src/proxy/retry.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use crate::{compute, config::RetryConfig};
|
||||
use std::{error::Error, io};
|
||||
use tokio::time;
|
||||
|
||||
pub trait CouldRetry {
|
||||
/// Returns true if the error could be retried
|
||||
fn could_retry(&self) -> bool;
|
||||
}
|
||||
|
||||
pub trait ShouldRetryWakeCompute {
|
||||
/// Returns true if we need to invalidate the cache for this node.
|
||||
/// If false, we can continue retrying with the current node cache.
|
||||
fn should_retry_wake_compute(&self) -> bool;
|
||||
}
|
||||
|
||||
pub fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool {
|
||||
num_retries < config.max_retries && err.could_retry()
|
||||
}
|
||||
|
||||
impl CouldRetry for io::Error {
|
||||
fn could_retry(&self) -> bool {
|
||||
use std::io::ErrorKind;
|
||||
matches!(
|
||||
self.kind(),
|
||||
ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for tokio_postgres::error::DbError {
|
||||
fn could_retry(&self) -> bool {
|
||||
use tokio_postgres::error::SqlState;
|
||||
matches!(
|
||||
self.code(),
|
||||
&SqlState::CONNECTION_FAILURE
|
||||
| &SqlState::CONNECTION_EXCEPTION
|
||||
| &SqlState::CONNECTION_DOES_NOT_EXIST
|
||||
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
|
||||
)
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::error::DbError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
use tokio_postgres::error::SqlState;
|
||||
// Here are errors that happens after the user successfully authenticated to the database.
|
||||
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
|
||||
!matches!(
|
||||
self.code(),
|
||||
&SqlState::TOO_MANY_CONNECTIONS
|
||||
| &SqlState::OUT_OF_MEMORY
|
||||
| &SqlState::SYNTAX_ERROR
|
||||
| &SqlState::T_R_SERIALIZATION_FAILURE
|
||||
| &SqlState::INVALID_CATALOG_NAME
|
||||
| &SqlState::INVALID_SCHEMA_NAME
|
||||
| &SqlState::INVALID_PARAMETER_VALUE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for tokio_postgres::Error {
|
||||
fn could_retry(&self) -> bool {
|
||||
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
io::Error::could_retry(io_err)
|
||||
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
tokio_postgres::error::DbError::could_retry(db_err)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::Error {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
tokio_postgres::error::DbError::should_retry_wake_compute(db_err)
|
||||
} else {
|
||||
// likely an IO error. Possible the compute has shutdown and the
|
||||
// cache is stale.
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for compute::ConnectionError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(),
|
||||
// the cache entry was not checked for validity
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
|
||||
config
|
||||
.base_delay
|
||||
.mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
|
||||
}
|
||||
708
proxy/core/src/proxy/tests.rs
Normal file
708
proxy/core/src/proxy/tests.rs
Normal file
@@ -0,0 +1,708 @@
|
||||
//! A group of high-level tests for connection establishing logic and auth.
|
||||
|
||||
mod mitm;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||
};
|
||||
use crate::config::{CertResolver, RetryConfig};
|
||||
use crate::console::caches::NodeInfoCache;
|
||||
use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
use rustls::pki_types;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
common_name: &str,
|
||||
) -> anyhow::Result<(
|
||||
pki_types::CertificateDer<'static>,
|
||||
pki_types::CertificateDer<'static>,
|
||||
pki_types::PrivateKeyDer<'static>,
|
||||
)> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
|
||||
params
|
||||
})?;
|
||||
|
||||
let cert = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
|
||||
params.distinguished_name = rcgen::DistinguishedName::new();
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, common_name);
|
||||
params
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
pki_types::CertificateDer::from(ca.serialize_der()?),
|
||||
pki_types::CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
|
||||
pki_types::PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
|
||||
))
|
||||
}
|
||||
|
||||
struct ClientConfig<'a> {
|
||||
config: rustls::ClientConfig,
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
self,
|
||||
) -> anyhow::Result<
|
||||
impl tokio_postgres::tls::TlsConnect<
|
||||
S,
|
||||
Error = impl std::fmt::Debug,
|
||||
Future = impl Send,
|
||||
Stream = RustlsStream<S>,
|
||||
>,
|
||||
> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate TLS certificates and build rustls configs for client and server.
|
||||
fn generate_tls_config<'a>(
|
||||
hostname: &'a str,
|
||||
common_name: &'a str,
|
||||
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
|
||||
let (ca, cert, key) = generate_certs(hostname, common_name)?;
|
||||
|
||||
let tls_config = {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![cert.clone()], key.clone_key())?
|
||||
.into();
|
||||
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
TlsConfig {
|
||||
config,
|
||||
common_names,
|
||||
cert_resolver: Arc::new(cert_resolver),
|
||||
}
|
||||
};
|
||||
|
||||
let client_config = {
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(ca)?;
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
|
||||
ClientConfig { config, hostname }
|
||||
};
|
||||
|
||||
Ok((client_config, tls_config))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct NoAuth;
|
||||
impl TestAuth for NoAuth {}
|
||||
|
||||
struct Scram(scram::ServerSecret);
|
||||
|
||||
impl Scram {
|
||||
async fn new(password: &str) -> anyhow::Result<Self> {
|
||||
let secret = scram::ServerSecret::build(password)
|
||||
.await
|
||||
.context("failed to generate scram secret")?;
|
||||
Ok(Scram(secret))
|
||||
}
|
||||
|
||||
fn mock() -> Self {
|
||||
Scram(scram::ServerSecret::mock(rand::random()))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0, &RequestMonitoring::test()))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
use sasl::Outcome::*;
|
||||
match outcome {
|
||||
Success(_) => Ok(()),
|
||||
Failure(reason) => bail!("autentication failed with an error: {reason}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A dummy proxy impl which performs a handshake and reports auth success.
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
tls: Option<TlsConfig>,
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let mut stream =
|
||||
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Disable)
|
||||
.connect_raw(server, NoTls)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
assert!(client_err.to_string().contains(ERR_INSECURE_CONNECTION));
|
||||
|
||||
let server_err = proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
assert!(client_err.to_string().contains(&server_err.to_string()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_raw() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.options("project=generic-project-name")
|
||||
.ssl_mode(SslMode::Prefer)
|
||||
.connect_raw(server, NoTls)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn keepalive_is_inherited() -> anyhow::Result<()> {
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let port = listener.local_addr()?.port();
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let t = tokio::spawn(async move {
|
||||
let (client, _) = listener.accept().await?;
|
||||
let keepalive = socket2::SockRef::from(&client).keepalive()?;
|
||||
anyhow::Ok(keepalive)
|
||||
});
|
||||
|
||||
let _ = TcpStream::connect(("127.0.0.1", port)).await?;
|
||||
assert!(t.await??, "keepalive should be inherited");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case("password_foo")]
|
||||
#[case("pwd-bar")]
|
||||
#[case("")]
|
||||
#[tokio::test]
|
||||
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new(password).await?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
let password: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(rand::random::<u8>() as usize)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_compute_total_wait() {
|
||||
let mut total_wait = tokio::time::Duration::ZERO;
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
for num_retries in 1..config.max_retries {
|
||||
total_wait += retry_after(num_retries, config);
|
||||
}
|
||||
assert!(f64::abs(total_wait.as_secs_f64() - 15.0) < 0.1);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum ConnectAction {
|
||||
Wake,
|
||||
WakeFail,
|
||||
WakeRetry,
|
||||
Connect,
|
||||
Retry,
|
||||
Fail,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestConnectMechanism {
|
||||
counter: Arc<std::sync::Mutex<usize>>,
|
||||
sequence: Vec<ConnectAction>,
|
||||
cache: &'static NodeInfoCache,
|
||||
}
|
||||
|
||||
impl TestConnectMechanism {
|
||||
fn verify(&self) {
|
||||
let counter = self.counter.lock().unwrap();
|
||||
assert_eq!(
|
||||
*counter,
|
||||
self.sequence.len(),
|
||||
"sequence does not proceed to the end"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestConnectMechanism {
|
||||
fn new(sequence: Vec<ConnectAction>) -> Self {
|
||||
Self {
|
||||
counter: Arc::new(std::sync::Mutex::new(0)),
|
||||
sequence,
|
||||
cache: Box::leak(Box::new(NodeInfoCache::new(
|
||||
"test",
|
||||
1,
|
||||
Duration::from_secs(100),
|
||||
false,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestConnection;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestConnectError {
|
||||
retryable: bool,
|
||||
kind: crate::error::ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportableError for TestConnectError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
self.kind
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestConnectError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for TestConnectError {}
|
||||
|
||||
impl CouldRetry for TestConnectError {
|
||||
fn could_retry(&self) -> bool {
|
||||
self.retryable
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for TestConnectError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TestConnectMechanism {
|
||||
type Connection = TestConnection;
|
||||
type ConnectError = TestConnectError;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_node_info: &console::CachedNodeInfo,
|
||||
_timeout: std::time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let action = self.sequence[*counter];
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Connect => Ok(TestConnection),
|
||||
ConnectAction::Retry => Err(TestConnectError {
|
||||
retryable: true,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
ConnectAction::Fail => Err(TestConnectError {
|
||||
retryable: false,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
x => panic!("expecting action {:?}, connect is called instead", x),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
impl TestBackend for TestConnectMechanism {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let action = self.sequence[*counter];
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||
ConnectAction::WakeFail => {
|
||||
let err = console::errors::ApiError::Console(ConsoleError {
|
||||
http_status_code: http::StatusCode::BAD_REQUEST,
|
||||
error: "TEST".into(),
|
||||
status: None,
|
||||
});
|
||||
assert!(!err.could_retry());
|
||||
Err(console::errors::WakeComputeError::ApiError(err))
|
||||
}
|
||||
ConnectAction::WakeRetry => {
|
||||
let err = console::errors::ApiError::Console(ConsoleError {
|
||||
http_status_code: http::StatusCode::BAD_REQUEST,
|
||||
error: "TEST".into(),
|
||||
status: Some(Status {
|
||||
code: "error".into(),
|
||||
message: "error".into(),
|
||||
details: Details {
|
||||
error_info: None,
|
||||
retry_info: Some(console::messages::RetryInfo { retry_delay_ms: 1 }),
|
||||
user_facing_message: None,
|
||||
},
|
||||
}),
|
||||
});
|
||||
assert!(err.could_retry());
|
||||
Err(console::errors::WakeComputeError::ApiError(err))
|
||||
}
|
||||
x => panic!("expecting action {:?}, wake_compute is called instead", x),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
|
||||
{
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new(),
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
|
||||
node2.map(|()| node)
|
||||
}
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> auth::BackendType<'static, ComputeCredentials, &()> {
|
||||
let user_info = auth::BackendType::Console(
|
||||
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
},
|
||||
);
|
||||
user_info
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_success() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Test that we don't retry if the error is not retryable.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_1() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Even for non-retryable errors, we should retry at least once.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_2() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Retry for at most `NUM_RETRIES_CONNECT` times.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_3() {
|
||||
let _ = env_logger::try_init();
|
||||
tokio::time::pause();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism =
|
||||
TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let wake_compute_retry_config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 1,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
let connect_to_compute_retry_config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(
|
||||
&ctx,
|
||||
&mechanism,
|
||||
&user_info,
|
||||
false,
|
||||
wake_compute_retry_config,
|
||||
connect_to_compute_retry_config,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Should retry wake compute.
|
||||
#[tokio::test]
|
||||
async fn wake_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Wake failed with a non-retryable error.
|
||||
#[tokio::test]
|
||||
async fn wake_non_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
}
|
||||
262
proxy/core/src/proxy/tests/mitm.rs
Normal file
262
proxy/core/src/proxy/tests/mitm.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
//! Man-in-the-middle tests
|
||||
//!
|
||||
//! Channel binding should prevent a proxy server
|
||||
//! *that has access to create valid certificates*
|
||||
//! from controlling the TLS connection.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::*;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
enum Intercept {
|
||||
None,
|
||||
Methods,
|
||||
SASLResponse,
|
||||
}
|
||||
|
||||
async fn proxy_mitm(
|
||||
intercept: Intercept,
|
||||
) -> (DuplexStream, DuplexStream, ClientConfig<'static>, TlsConfig) {
|
||||
let (end_server1, client1) = tokio::io::duplex(1024);
|
||||
let (server2, end_client2) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config1, server_config1) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
let (client_config2, server_config2) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost").unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// begin handshake with end_server
|
||||
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
|
||||
let (end_client, startup) = match handshake(
|
||||
&RequestMonitoring::test(),
|
||||
client1,
|
||||
Some(&server_config1),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
assert!(buf.is_empty());
|
||||
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(startup.iter(), &mut buf).unwrap();
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
|
||||
// proxy messages between end_client and end_server
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = end_server.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL and return only SCRAM-SHA-256 ;)
|
||||
if matches!(intercept, Intercept::Methods) && message.starts_with(b"R") && message[5..].starts_with(&[0,0,0,10]) {
|
||||
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_client.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
message = end_client.next() => {
|
||||
match message {
|
||||
Some(Ok(message)) => {
|
||||
// intercept SASL response and return SCRAM-SHA-256 with no channel binding ;)
|
||||
if matches!(intercept, Intercept::SASLResponse) && message.starts_with(b"p") && message[5..].starts_with(b"SCRAM-SHA-256-PLUS\0") {
|
||||
let sasl_message = &message[1+4+19+4..];
|
||||
let mut new_message = b"n,,".to_vec();
|
||||
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
|
||||
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_server.send(message).await.unwrap()
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
else => { break }
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(end_server1, end_client2, client_config1, server_config2)
|
||||
}
|
||||
|
||||
/// taken from tokio-postgres
|
||||
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
T::Error: Debug,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::ssl_request(&mut buf);
|
||||
stream.write_all(&buf).await.unwrap();
|
||||
|
||||
let mut buf = [0];
|
||||
stream.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
if buf[0] != b'S' {
|
||||
panic!("ssl not supported by server");
|
||||
}
|
||||
|
||||
tls.connect(stream).await.unwrap()
|
||||
}
|
||||
|
||||
struct PgFrame;
|
||||
impl Decoder for PgFrame {
|
||||
type Item = Bytes;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if src.len() < 5 {
|
||||
src.reserve(5 - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize + 1;
|
||||
if src.len() < len {
|
||||
src.reserve(len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(src.split_to(len).freeze()))
|
||||
}
|
||||
}
|
||||
impl Encoder<Bytes> for PgFrame {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// If the client doesn't support channel bindings, it can be exploited.
|
||||
#[tokio::test]
|
||||
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like SCRAM-PLUS isn't available, but the client supports it, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the MITM pretends like the client doesn't support channel bindings, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client chooses SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// If the client requires SCRAM-PLUS, and it is spoofed to remove SCRAM-PLUS, it will fail
|
||||
#[tokio::test]
|
||||
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_failure(
|
||||
intercept: Intercept,
|
||||
channel_binding: tokio_postgres::config::ChannelBinding,
|
||||
) -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(channel_binding)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err()
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err()
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
125
proxy/core/src/proxy/wake_compute.rs
Normal file
125
proxy/core/src/proxy/wake_compute.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use crate::config::RetryConfig;
|
||||
use crate::console::messages::{ConsoleError, Reason};
|
||||
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
|
||||
WakeupFailureKind,
|
||||
};
|
||||
use crate::proxy::retry::{retry_after, should_retry};
|
||||
use hyper1::StatusCode;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::connect_compute::ComputeConnectBackend;
|
||||
|
||||
pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||
num_retries: &mut u32,
|
||||
ctx: &RequestMonitoring,
|
||||
api: &B,
|
||||
config: RetryConfig,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let retry_type = RetryType::WakeCompute;
|
||||
loop {
|
||||
match api.wake_compute(ctx).await {
|
||||
Err(e) if !should_retry(&e, *num_retries, config) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
report_error(&e, false);
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type,
|
||||
},
|
||||
(*num_retries).into(),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
|
||||
report_error(&e, true);
|
||||
}
|
||||
Ok(n) => {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type,
|
||||
},
|
||||
(*num_retries).into(),
|
||||
);
|
||||
info!(?num_retries, "compute node woken up after");
|
||||
return Ok(n);
|
||||
}
|
||||
}
|
||||
|
||||
let wait_duration = retry_after(*num_retries, config);
|
||||
*num_retries += 1;
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
tokio::time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
}
|
||||
|
||||
fn report_error(e: &WakeComputeError, retry: bool) {
|
||||
use crate::console::errors::ApiError;
|
||||
let kind = match e {
|
||||
WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress,
|
||||
WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError,
|
||||
WakeComputeError::ApiError(ApiError::Console(e)) => match e.get_reason() {
|
||||
Reason::RoleProtected => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
Reason::ResourceNotFound => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
Reason::ProjectNotFound => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
Reason::EndpointNotFound => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
Reason::BranchNotFound => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
Reason::RateLimitExceeded => WakeupFailureKind::ApiConsoleLocked,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::ActiveTimeQuotaExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::ComputeTimeQuotaExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::WrittenDataQuotaExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::DataTransferQuotaExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::LogicalSizeQuotaExceeded => WakeupFailureKind::QuotaExceeded,
|
||||
Reason::ConcurrencyLimitReached => WakeupFailureKind::ApiConsoleLocked,
|
||||
Reason::LockAlreadyTaken => WakeupFailureKind::ApiConsoleLocked,
|
||||
Reason::RunningOperations => WakeupFailureKind::ApiConsoleLocked,
|
||||
Reason::Unknown => match e {
|
||||
ConsoleError {
|
||||
http_status_code: StatusCode::LOCKED,
|
||||
ref error,
|
||||
..
|
||||
} if error.contains("written data quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
WakeupFailureKind::QuotaExceeded
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: StatusCode::UNPROCESSABLE_ENTITY,
|
||||
ref error,
|
||||
..
|
||||
} if error.contains("compute time quota of non-primary branches is exceeded") => {
|
||||
WakeupFailureKind::QuotaExceeded
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: StatusCode::LOCKED,
|
||||
..
|
||||
} => WakeupFailureKind::ApiConsoleLocked,
|
||||
ConsoleError {
|
||||
http_status_code: StatusCode::BAD_REQUEST,
|
||||
..
|
||||
} => WakeupFailureKind::ApiConsoleBadRequest,
|
||||
ConsoleError {
|
||||
http_status_code, ..
|
||||
} if http_status_code.is_server_error() => {
|
||||
WakeupFailureKind::ApiConsoleOtherServerError
|
||||
}
|
||||
ConsoleError { .. } => WakeupFailureKind::ApiConsoleOtherError,
|
||||
},
|
||||
},
|
||||
WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked,
|
||||
WakeComputeError::TooManyConnectionAttempts(_) => WakeupFailureKind::TimeoutError,
|
||||
};
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.connection_failures_breakdown
|
||||
.inc(ConnectionFailuresBreakdownGroup {
|
||||
kind,
|
||||
retry: retry.into(),
|
||||
});
|
||||
}
|
||||
10
proxy/core/src/rate_limiter.rs
Normal file
10
proxy/core/src/rate_limiter.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
mod limit_algorithm;
|
||||
mod limiter;
|
||||
pub use limit_algorithm::{
|
||||
aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
|
||||
};
|
||||
pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
|
||||
mod leaky_bucket;
|
||||
pub use leaky_bucket::{
|
||||
EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState,
|
||||
};
|
||||
171
proxy/core/src/rate_limiter/leaky_bucket.rs
Normal file
171
proxy/core/src/rate_limiter/leaky_bucket.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use std::{
|
||||
hash::Hash,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
use ahash::RandomState;
|
||||
use dashmap::DashMap;
|
||||
use rand::{thread_rng, Rng};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
|
||||
|
||||
pub struct LeakyBucketRateLimiter<Key> {
|
||||
map: DashMap<Key, LeakyBucketState, RandomState>,
|
||||
config: LeakyBucketConfig,
|
||||
access_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
|
||||
pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
|
||||
rps: 600.0,
|
||||
max: 1500.0,
|
||||
};
|
||||
|
||||
pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
|
||||
Self {
|
||||
map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
|
||||
config,
|
||||
access_count: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
pub fn check(&self, key: K, n: u32) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
|
||||
self.do_gc(now);
|
||||
}
|
||||
|
||||
let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState {
|
||||
time: now,
|
||||
filled: 0.0,
|
||||
});
|
||||
|
||||
entry.check(&self.config, now, n as f64)
|
||||
}
|
||||
|
||||
fn do_gc(&self, now: Instant) {
|
||||
info!(
|
||||
"cleaning up bucket rate limiter, current size = {}",
|
||||
self.map.len()
|
||||
);
|
||||
let n = self.map.shards().len();
|
||||
let shard = thread_rng().gen_range(0..n);
|
||||
self.map.shards()[shard]
|
||||
.write()
|
||||
.retain(|_, value| !value.get_mut().update(&self.config, now));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LeakyBucketConfig {
|
||||
pub rps: f64,
|
||||
pub max: f64,
|
||||
}
|
||||
|
||||
pub struct LeakyBucketState {
|
||||
filled: f64,
|
||||
time: Instant,
|
||||
}
|
||||
|
||||
impl LeakyBucketConfig {
|
||||
pub fn new(rps: f64, max: f64) -> Self {
|
||||
assert!(rps > 0.0, "rps must be positive");
|
||||
assert!(max > 0.0, "max must be positive");
|
||||
Self { rps, max }
|
||||
}
|
||||
}
|
||||
|
||||
impl LeakyBucketState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
filled: 0.0,
|
||||
time: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// updates the timer and returns true if the bucket is empty
|
||||
fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool {
|
||||
let drain = now.duration_since(self.time);
|
||||
let drain = drain.as_secs_f64() * info.rps;
|
||||
|
||||
self.filled = (self.filled - drain).clamp(0.0, info.max);
|
||||
self.time = now;
|
||||
|
||||
self.filled == 0.0
|
||||
}
|
||||
|
||||
pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
|
||||
self.update(info, now);
|
||||
|
||||
if self.filled + n > info.max {
|
||||
return false;
|
||||
}
|
||||
self.filled += n;
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LeakyBucketState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::time::Instant;
|
||||
|
||||
use super::{LeakyBucketConfig, LeakyBucketState};
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn check() {
|
||||
let info = LeakyBucketConfig::new(500.0, 2000.0);
|
||||
let mut bucket = LeakyBucketState::new();
|
||||
|
||||
// should work for 2000 requests this second
|
||||
for _ in 0..2000 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
assert_eq!(bucket.filled, 2000.0);
|
||||
|
||||
// in 1ms we should drain 0.5 tokens.
|
||||
// make sure we don't lose any tokens
|
||||
tokio::time::advance(Duration::from_millis(1)).await;
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
tokio::time::advance(Duration::from_millis(1)).await;
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// in 10ms we should drain 5 tokens
|
||||
tokio::time::advance(Duration::from_millis(10)).await;
|
||||
for _ in 0..5 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// in 10s we should drain 5000 tokens
|
||||
// but cap is only 2000
|
||||
tokio::time::advance(Duration::from_secs(10)).await;
|
||||
for _ in 0..2000 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// should sustain 500rps
|
||||
for _ in 0..2000 {
|
||||
tokio::time::advance(Duration::from_millis(10)).await;
|
||||
for _ in 0..5 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
265
proxy/core/src/rate_limiter/limit_algorithm.rs
Normal file
265
proxy/core/src/rate_limiter/limit_algorithm.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
//! Algorithms for controlling concurrency limits.
|
||||
use parking_lot::Mutex;
|
||||
use std::{pin::pin, sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::Notify,
|
||||
time::{error::Elapsed, Instant},
|
||||
};
|
||||
|
||||
use self::aimd::Aimd;
|
||||
|
||||
pub mod aimd;
|
||||
|
||||
/// Whether a job succeeded or failed as a result of congestion/overload.
|
||||
///
|
||||
/// Errors not considered to be caused by overload should be ignored.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Outcome {
|
||||
/// The job succeeded, or failed in a way unrelated to overload.
|
||||
Success,
|
||||
/// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
|
||||
/// was observed.
|
||||
Overload,
|
||||
}
|
||||
|
||||
/// An algorithm for controlling a concurrency limit.
|
||||
pub trait LimitAlgorithm: Send + Sync + 'static {
|
||||
/// Update the concurrency limit in response to a new job completion.
|
||||
fn update(&self, old_limit: usize, sample: Sample) -> usize;
|
||||
}
|
||||
|
||||
/// The result of a job (or jobs), including the [`Outcome`] (loss) and latency (delay).
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub struct Sample {
|
||||
pub(crate) latency: Duration,
|
||||
/// Jobs in flight when the sample was taken.
|
||||
pub(crate) in_flight: usize,
|
||||
pub(crate) outcome: Outcome,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, serde::Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RateLimitAlgorithm {
|
||||
#[default]
|
||||
Fixed,
|
||||
Aimd {
|
||||
#[serde(flatten)]
|
||||
conf: Aimd,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Fixed;
|
||||
|
||||
impl LimitAlgorithm for Fixed {
|
||||
fn update(&self, old_limit: usize, _sample: Sample) -> usize {
|
||||
old_limit
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
|
||||
pub struct RateLimiterConfig {
|
||||
#[serde(flatten)]
|
||||
pub algorithm: RateLimitAlgorithm,
|
||||
pub initial_limit: usize,
|
||||
}
|
||||
|
||||
impl RateLimiterConfig {
|
||||
pub fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
|
||||
match self.algorithm {
|
||||
RateLimitAlgorithm::Fixed => Box::new(Fixed),
|
||||
RateLimitAlgorithm::Aimd { conf } => Box::new(conf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LimiterInner {
|
||||
alg: Box<dyn LimitAlgorithm>,
|
||||
available: usize,
|
||||
limit: usize,
|
||||
in_flight: usize,
|
||||
}
|
||||
|
||||
impl LimiterInner {
|
||||
fn update_limit(&mut self, latency: Duration, outcome: Option<Outcome>) {
|
||||
if let Some(outcome) = outcome {
|
||||
let sample = Sample {
|
||||
latency,
|
||||
in_flight: self.in_flight,
|
||||
outcome,
|
||||
};
|
||||
self.limit = self.alg.update(self.limit, sample);
|
||||
}
|
||||
}
|
||||
|
||||
fn take(&mut self, ready: &Notify) -> Option<()> {
|
||||
if self.available >= 1 {
|
||||
self.available -= 1;
|
||||
self.in_flight += 1;
|
||||
|
||||
// tell the next in the queue that there is a permit ready
|
||||
if self.available >= 1 {
|
||||
ready.notify_one();
|
||||
}
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Limits the number of concurrent jobs.
|
||||
///
|
||||
/// Concurrency is limited through the use of [`Token`]s. Acquire a token to run a job, and release the
|
||||
/// token once the job is finished.
|
||||
///
|
||||
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
|
||||
/// caused by overload (loss).
|
||||
pub struct DynamicLimiter {
|
||||
config: RateLimiterConfig,
|
||||
inner: Mutex<LimiterInner>,
|
||||
// to notify when a token is available
|
||||
ready: Notify,
|
||||
}
|
||||
|
||||
/// A concurrency token, required to run a job.
|
||||
///
|
||||
/// Release the token back to the [`DynamicLimiter`] after the job is complete.
|
||||
pub struct Token {
|
||||
start: Instant,
|
||||
limiter: Option<Arc<DynamicLimiter>>,
|
||||
}
|
||||
|
||||
/// A snapshot of the state of the [`DynamicLimiter`].
|
||||
///
|
||||
/// Not guaranteed to be consistent under high concurrency.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct LimiterState {
|
||||
limit: usize,
|
||||
in_flight: usize,
|
||||
}
|
||||
|
||||
impl DynamicLimiter {
|
||||
/// Create a limiter with a given limit control algorithm.
|
||||
pub fn new(config: RateLimiterConfig) -> Arc<Self> {
|
||||
let ready = Notify::new();
|
||||
ready.notify_one();
|
||||
|
||||
Arc::new(Self {
|
||||
inner: Mutex::new(LimiterInner {
|
||||
alg: config.create_rate_limit_algorithm(),
|
||||
available: config.initial_limit,
|
||||
limit: config.initial_limit,
|
||||
in_flight: 0,
|
||||
}),
|
||||
ready,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
|
||||
pub async fn acquire_timeout(self: &Arc<Self>, duration: Duration) -> Result<Token, Elapsed> {
|
||||
tokio::time::timeout(duration, self.acquire()).await?
|
||||
}
|
||||
|
||||
/// Try to acquire a concurrency [Token].
|
||||
async fn acquire(self: &Arc<Self>) -> Result<Token, Elapsed> {
|
||||
if self.config.initial_limit == 0 {
|
||||
// If the rate limiter is disabled, we can always acquire a token.
|
||||
Ok(Token::disabled())
|
||||
} else {
|
||||
let mut notified = pin!(self.ready.notified());
|
||||
let mut ready = notified.as_mut().enable();
|
||||
loop {
|
||||
if ready {
|
||||
let mut inner = self.inner.lock();
|
||||
if inner.take(&self.ready).is_some() {
|
||||
break Ok(Token::new(self.clone()));
|
||||
} else {
|
||||
notified.set(self.ready.notified());
|
||||
}
|
||||
}
|
||||
notified.as_mut().await;
|
||||
ready = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the concurrency [Token], along with the outcome of the job.
|
||||
///
|
||||
/// The [Outcome] of the job, and the time taken to perform it, may be used
|
||||
/// to update the concurrency limit.
|
||||
///
|
||||
/// Set the outcome to `None` to ignore the job.
|
||||
fn release_inner(&self, start: Instant, outcome: Option<Outcome>) {
|
||||
tracing::info!("outcome is {:?}", outcome);
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut inner = self.inner.lock();
|
||||
|
||||
inner.update_limit(start.elapsed(), outcome);
|
||||
|
||||
inner.in_flight -= 1;
|
||||
if inner.in_flight < inner.limit {
|
||||
inner.available = inner.limit - inner.in_flight;
|
||||
// At least 1 permit is now available
|
||||
self.ready.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
/// The current state of the limiter.
|
||||
pub fn state(&self) -> LimiterState {
|
||||
let inner = self.inner.lock();
|
||||
LimiterState {
|
||||
limit: inner.limit,
|
||||
in_flight: inner.in_flight,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Token {
|
||||
fn new(limiter: Arc<DynamicLimiter>) -> Self {
|
||||
Self {
|
||||
start: Instant::now(),
|
||||
limiter: Some(limiter),
|
||||
}
|
||||
}
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
start: Instant::now(),
|
||||
limiter: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_disabled(&self) -> bool {
|
||||
self.limiter.is_none()
|
||||
}
|
||||
|
||||
pub fn release(mut self, outcome: Outcome) {
|
||||
self.release_mut(Some(outcome))
|
||||
}
|
||||
|
||||
pub fn release_mut(&mut self, outcome: Option<Outcome>) {
|
||||
if let Some(limiter) = self.limiter.take() {
|
||||
limiter.release_inner(self.start, outcome);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Token {
|
||||
fn drop(&mut self) {
|
||||
self.release_mut(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl LimiterState {
|
||||
/// The current concurrency limit.
|
||||
pub fn limit(&self) -> usize {
|
||||
self.limit
|
||||
}
|
||||
/// The number of jobs in flight.
|
||||
pub fn in_flight(&self) -> usize {
|
||||
self.in_flight
|
||||
}
|
||||
}
|
||||
266
proxy/core/src/rate_limiter/limit_algorithm/aimd.rs
Normal file
266
proxy/core/src/rate_limiter/limit_algorithm/aimd.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
use super::{LimitAlgorithm, Outcome, Sample};
|
||||
|
||||
/// Loss-based congestion avoidance.
|
||||
///
|
||||
/// Additive-increase, multiplicative decrease.
|
||||
///
|
||||
/// Adds available currency when:
|
||||
/// 1. no load-based errors are observed, and
|
||||
/// 2. the utilisation of the current limit is high.
|
||||
///
|
||||
/// Reduces available concurrency by a factor when load-based errors are detected.
|
||||
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
|
||||
pub struct Aimd {
|
||||
/// Minimum limit for AIMD algorithm.
|
||||
pub min: usize,
|
||||
/// Maximum limit for AIMD algorithm.
|
||||
pub max: usize,
|
||||
/// Decrease AIMD decrease by value in case of error.
|
||||
pub dec: f32,
|
||||
/// Increase AIMD increase by value in case of success.
|
||||
pub inc: usize,
|
||||
/// A threshold below which the limit won't be increased.
|
||||
pub utilisation: f32,
|
||||
}
|
||||
|
||||
impl LimitAlgorithm for Aimd {
|
||||
fn update(&self, old_limit: usize, sample: Sample) -> usize {
|
||||
use Outcome::*;
|
||||
match sample.outcome {
|
||||
Success => {
|
||||
let utilisation = sample.in_flight as f32 / old_limit as f32;
|
||||
|
||||
if utilisation > self.utilisation {
|
||||
let limit = old_limit + self.inc;
|
||||
let increased_limit = limit.clamp(self.min, self.max);
|
||||
if increased_limit > old_limit {
|
||||
tracing::info!(increased_limit, "limit increased");
|
||||
}
|
||||
|
||||
increased_limit
|
||||
} else {
|
||||
old_limit
|
||||
}
|
||||
}
|
||||
Overload => {
|
||||
let limit = old_limit as f32 * self.dec;
|
||||
|
||||
// Floor instead of round, so the limit reduces even with small numbers.
|
||||
// E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
|
||||
let limit = limit.floor() as usize;
|
||||
|
||||
let limit = limit.clamp(self.min, self.max);
|
||||
tracing::info!(limit, "limit decreased");
|
||||
limit
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::rate_limiter::limit_algorithm::{
|
||||
DynamicLimiter, RateLimitAlgorithm, RateLimiterConfig,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn increase_decrease() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 1,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 2,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.8,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Success);
|
||||
|
||||
assert_eq!(limiter.state().limit(), 2);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Success);
|
||||
assert_eq!(limiter.state().limit(), 2);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Overload);
|
||||
assert_eq!(limiter.state().limit(), 1);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Overload);
|
||||
assert_eq!(limiter.state().limit(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_decrease_limit_on_overload() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 10,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.8,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(100))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Overload);
|
||||
|
||||
assert_eq!(limiter.state().limit(), 5, "overload: decrease");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn acquire_timeout_times_out() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 1,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 2,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.8,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
let now = tokio::time::Instant::now();
|
||||
limiter
|
||||
.acquire_timeout(Duration::from_secs(1))
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
|
||||
assert!(now.elapsed() >= Duration::from_secs(1));
|
||||
|
||||
token.release(Outcome::Success);
|
||||
|
||||
assert_eq!(limiter.state().limit(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 4,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 1,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
let _token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
let _token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
token.release(Outcome::Success);
|
||||
assert_eq!(limiter.state().limit(), 5, "success: increase");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 4,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
token.release(Outcome::Success);
|
||||
assert_eq!(
|
||||
limiter.state().limit(),
|
||||
4,
|
||||
"success: ignore when < half limit"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_not_change_limit_when_no_outcome() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 10,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(token);
|
||||
assert_eq!(limiter.state().limit(), 10, "ignore");
|
||||
}
|
||||
}
|
||||
353
proxy/core/src/rate_limiter/limiter.rs
Normal file
353
proxy/core/src/rate_limiter/limiter.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::hash_map::RandomState,
|
||||
hash::{BuildHasher, Hash},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::bail;
|
||||
use dashmap::DashMap;
|
||||
use itertools::Itertools;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
pub struct GlobalRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: Vec<RateBucketInfo>,
|
||||
}
|
||||
|
||||
impl GlobalRateLimiter {
|
||||
pub fn new(info: Vec<RateBucketInfo>) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
RateBucket {
|
||||
start: Instant::now(),
|
||||
count: 0,
|
||||
};
|
||||
info.len()
|
||||
],
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections is below `max_rps` rps.
|
||||
pub fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let should_allow_request = self
|
||||
.data
|
||||
.iter_mut()
|
||||
.zip(&self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
self.data.iter_mut().for_each(|b| b.inc(1));
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
}
|
||||
}
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
//
|
||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
// Purposefully ignore user name and database name as clients can reconnect
|
||||
// with different names, so we'll end up sending some http requests to
|
||||
// the control plane.
|
||||
pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
|
||||
|
||||
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
|
||||
map: DashMap<Key, Vec<RateBucket>, Hasher>,
|
||||
info: Cow<'static, [RateBucketInfo]>,
|
||||
access_count: AtomicUsize,
|
||||
rand: Mutex<Rand>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct RateBucket {
|
||||
start: Instant,
|
||||
count: u32,
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
|
||||
if now - self.start < info.interval {
|
||||
self.count + n <= info.max_rpi
|
||||
} else {
|
||||
// bucket expired, reset
|
||||
self.count = 0;
|
||||
self.start = now;
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn inc(&mut self, n: u32) {
|
||||
self.count += n;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub struct RateBucketInfo {
|
||||
pub interval: Duration,
|
||||
// requests per interval
|
||||
pub max_rpi: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let rps = self.rps().floor() as u64;
|
||||
write!(f, "{rps}@{}", humantime::format_duration(self.interval))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RateBucketInfo {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let Some((max_rps, interval)) = s.split_once('@') else {
|
||||
bail!("invalid rate info")
|
||||
};
|
||||
let max_rps = max_rps.parse()?;
|
||||
let interval = humantime::parse_duration(interval)?;
|
||||
Ok(Self::new(max_rps, interval))
|
||||
}
|
||||
}
|
||||
|
||||
impl RateBucketInfo {
|
||||
pub const DEFAULT_SET: [Self; 3] = [
|
||||
Self::new(300, Duration::from_secs(1)),
|
||||
Self::new(200, Duration::from_secs(60)),
|
||||
Self::new(100, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
|
||||
Self::new(500, Duration::from_secs(1)),
|
||||
Self::new(300, Duration::from_secs(60)),
|
||||
Self::new(200, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub fn rps(&self) -> f64 {
|
||||
(self.max_rpi as f64) / self.interval.as_secs_f64()
|
||||
}
|
||||
|
||||
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
|
||||
info.sort_unstable_by_key(|info| info.interval);
|
||||
let invalid = info
|
||||
.iter()
|
||||
.tuple_windows()
|
||||
.find(|(a, b)| a.max_rpi > b.max_rpi);
|
||||
if let Some((a, b)) = invalid {
|
||||
bail!(
|
||||
"invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
|
||||
b.max_rpi,
|
||||
a.max_rpi,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub const fn new(max_rps: u32, interval: Duration) -> Self {
|
||||
Self {
|
||||
interval,
|
||||
max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq> BucketRateLimiter<K> {
|
||||
pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
|
||||
Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
|
||||
fn new_with_rand_and_hasher(
|
||||
info: impl Into<Cow<'static, [RateBucketInfo]>>,
|
||||
rand: R,
|
||||
hasher: S,
|
||||
) -> Self {
|
||||
let info = info.into();
|
||||
info!(buckets = ?info, "endpoint rate limiter");
|
||||
Self {
|
||||
info,
|
||||
map: DashMap::with_hasher_and_shard_amount(hasher, 64),
|
||||
access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
|
||||
rand: Mutex::new(rand),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
pub fn check(&self, key: K, n: u32) -> bool {
|
||||
// do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
|
||||
// worst case memory usage is about:
|
||||
// = 2 * 2048 * 64 * (48B + 72B)
|
||||
// = 30MB
|
||||
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
|
||||
self.do_gc();
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut entry = self.map.entry(key).or_insert_with(|| {
|
||||
vec![
|
||||
RateBucket {
|
||||
start: now,
|
||||
count: 0,
|
||||
};
|
||||
self.info.len()
|
||||
]
|
||||
});
|
||||
|
||||
let should_allow_request = entry
|
||||
.iter_mut()
|
||||
.zip(&*self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now, n));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
entry.iter_mut().for_each(|b| b.inc(n));
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
}
|
||||
|
||||
/// Clean the map. Simple strategy: remove all entries in a random shard.
|
||||
/// At worst, we'll double the effective max_rps during the cleanup.
|
||||
/// But that way deletion does not aquire mutex on each entry access.
|
||||
pub fn do_gc(&self) {
|
||||
info!(
|
||||
"cleaning up bucket rate limiter, current size = {}",
|
||||
self.map.len()
|
||||
);
|
||||
let n = self.map.shards().len();
|
||||
// this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
|
||||
// (impossible, infact, unless we have 2048 threads)
|
||||
let shard = self.rand.lock().unwrap().gen_range(0..n);
|
||||
self.map.shards()[shard].write().clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{hash::BuildHasherDefault, time::Duration};
|
||||
|
||||
use rand::SeedableRng;
|
||||
use rustc_hash::FxHasher;
|
||||
use tokio::time;
|
||||
|
||||
use super::{BucketRateLimiter, WakeComputeRateLimiter};
|
||||
use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
|
||||
|
||||
#[test]
|
||||
fn rate_bucket_rpi() {
|
||||
let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
|
||||
assert_eq!(rate_bucket.max_rpi, 50 * 5);
|
||||
|
||||
let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
|
||||
assert_eq!(rate_bucket.max_rpi, 50 / 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_bucket_parse() {
|
||||
let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
|
||||
assert_eq!(rate_bucket.interval, Duration::from_secs(10));
|
||||
assert_eq!(rate_bucket.max_rpi, 100 * 10);
|
||||
assert_eq!(rate_bucket.to_string(), "100@10s");
|
||||
|
||||
let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
|
||||
assert_eq!(rate_bucket.interval, Duration::from_secs(60));
|
||||
assert_eq!(rate_bucket.max_rpi, 100 * 60);
|
||||
assert_eq!(rate_bucket.to_string(), "100@1m");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_rate_buckets() {
|
||||
let mut defaults = RateBucketInfo::DEFAULT_SET;
|
||||
RateBucketInfo::validate(&mut defaults[..]).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
|
||||
fn rate_buckets_validate() {
|
||||
let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
|
||||
.into_iter()
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rate_limits() {
|
||||
let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
|
||||
.into_iter()
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
let limiter = WakeComputeRateLimiter::new(rates);
|
||||
|
||||
let endpoint = EndpointId::from("ep-my-endpoint-1234");
|
||||
let endpoint = EndpointIdInt::from(endpoint);
|
||||
|
||||
time::pause();
|
||||
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint, 1));
|
||||
}
|
||||
// more connections fail
|
||||
assert!(!limiter.check(endpoint, 1));
|
||||
|
||||
// fail even after 500ms as it's in the same bucket
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
assert!(!limiter.check(endpoint, 1));
|
||||
|
||||
// after a full 1s, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
for _ in 1..6 {
|
||||
for _ in 0..50 {
|
||||
assert!(limiter.check(endpoint, 2));
|
||||
}
|
||||
time::advance(time::Duration::from_millis(1000)).await;
|
||||
}
|
||||
|
||||
// more connections after 600 will exceed the 20rps@30s limit
|
||||
assert!(!limiter.check(endpoint, 1));
|
||||
|
||||
// will still fail before the 30 second limit
|
||||
time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
|
||||
assert!(!limiter.check(endpoint, 1));
|
||||
|
||||
// after the full 30 seconds, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(1)).await;
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint, 1));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rate_limits_gc() {
|
||||
// fixed seeded random/hasher to ensure that the test is not flaky
|
||||
let rand = rand::rngs::StdRng::from_seed([1; 32]);
|
||||
let hasher = BuildHasherDefault::<FxHasher>::default();
|
||||
|
||||
let limiter =
|
||||
BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
|
||||
for i in 0..1_000_000 {
|
||||
limiter.check(i, 1);
|
||||
}
|
||||
assert!(limiter.map.len() < 150_000);
|
||||
}
|
||||
}
|
||||
4
proxy/core/src/redis.rs
Normal file
4
proxy/core/src/redis.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod cancellation_publisher;
|
||||
pub mod connection_with_credentials_provider;
|
||||
pub mod elasticache;
|
||||
pub mod notifications;
|
||||
161
proxy/core/src/redis/cancellation_publisher.rs
Normal file
161
proxy/core/src/redis/cancellation_publisher.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::AsyncCommands;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
|
||||
|
||||
use super::{
|
||||
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
|
||||
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
|
||||
};
|
||||
|
||||
pub trait CancellationPublisherMut: Send + Sync + 'static {
|
||||
#[allow(async_fn_in_trait)]
|
||||
async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
pub trait CancellationPublisher: Send + Sync + 'static {
|
||||
#[allow(async_fn_in_trait)]
|
||||
async fn try_publish(
|
||||
&self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
impl CancellationPublisher for () {
|
||||
async fn try_publish(
|
||||
&self,
|
||||
_cancel_key_data: CancelKeyData,
|
||||
_session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisher> CancellationPublisherMut for P {
|
||||
async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
|
||||
async fn try_publish(
|
||||
&self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
if let Some(p) = self {
|
||||
p.try_publish(cancel_key_data, session_id).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
|
||||
async fn try_publish(
|
||||
&self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
self.lock()
|
||||
.await
|
||||
.try_publish(cancel_key_data, session_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RedisPublisherClient {
|
||||
client: ConnectionWithCredentialsProvider,
|
||||
region_id: String,
|
||||
limiter: GlobalRateLimiter,
|
||||
}
|
||||
|
||||
impl RedisPublisherClient {
|
||||
pub fn new(
|
||||
client: ConnectionWithCredentialsProvider,
|
||||
region_id: String,
|
||||
info: &'static [RateBucketInfo],
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
client,
|
||||
region_id,
|
||||
limiter: GlobalRateLimiter::new(info.into()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
|
||||
region_id: Some(self.region_id.clone()),
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
}))?;
|
||||
let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
async fn try_publish_internal(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping cancellation message");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
match self.publish(cancel_key_data, session_id).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
}
|
||||
}
|
||||
tracing::info!("Publisher is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.publish(cancel_key_data, session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
impl CancellationPublisherMut for RedisPublisherClient {
|
||||
async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
tracing::info!("publishing cancellation key to Redis");
|
||||
match self.try_publish_internal(cancel_key_data, session_id).await {
|
||||
Ok(()) => {
|
||||
tracing::info!("cancellation key successfuly published to Redis");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
237
proxy/core/src/redis/connection_with_credentials_provider.rs
Normal file
237
proxy/core/src/redis/connection_with_credentials_provider.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use futures::FutureExt;
|
||||
use redis::{
|
||||
aio::{ConnectionLike, MultiplexedConnection},
|
||||
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
|
||||
};
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{error, info};
|
||||
|
||||
use super::elasticache::CredentialsProvider;
|
||||
|
||||
enum Credentials {
|
||||
Static(ConnectionInfo),
|
||||
Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
|
||||
}
|
||||
|
||||
impl Clone for Credentials {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Credentials::Static(info) => Credentials::Static(info.clone()),
|
||||
Credentials::Dynamic(provider, addr) => {
|
||||
Credentials::Dynamic(Arc::clone(provider), addr.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
|
||||
/// Provides PubSub connection without credentials refresh.
|
||||
pub struct ConnectionWithCredentialsProvider {
|
||||
credentials: Credentials,
|
||||
con: Option<MultiplexedConnection>,
|
||||
refresh_token_task: Option<JoinHandle<()>>,
|
||||
mutex: tokio::sync::Mutex<()>,
|
||||
}
|
||||
|
||||
impl Clone for ConnectionWithCredentialsProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
credentials: self.credentials.clone(),
|
||||
con: None,
|
||||
refresh_token_task: None,
|
||||
mutex: tokio::sync::Mutex::new(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionWithCredentialsProvider {
|
||||
pub fn new_with_credentials_provider(
|
||||
host: String,
|
||||
port: u16,
|
||||
credentials_provider: Arc<CredentialsProvider>,
|
||||
) -> Self {
|
||||
Self {
|
||||
credentials: Credentials::Dynamic(
|
||||
credentials_provider,
|
||||
redis::ConnectionAddr::TcpTls {
|
||||
host,
|
||||
port,
|
||||
insecure: false,
|
||||
tls_params: None,
|
||||
},
|
||||
),
|
||||
con: None,
|
||||
refresh_token_task: None,
|
||||
mutex: tokio::sync::Mutex::new(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
|
||||
Self {
|
||||
credentials: Credentials::Static(params.into_connection_info().unwrap()),
|
||||
con: None,
|
||||
refresh_token_task: None,
|
||||
mutex: tokio::sync::Mutex::new(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> {
|
||||
redis::cmd("PING").query_async(con).await
|
||||
}
|
||||
|
||||
pub async fn connect(&mut self) -> anyhow::Result<()> {
|
||||
let _guard = self.mutex.lock().await;
|
||||
if let Some(con) = self.con.as_mut() {
|
||||
match Self::ping(con).await {
|
||||
Ok(()) => {
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error during PING: {e:?}");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("Connection is not established");
|
||||
}
|
||||
info!("Establishing a new connection...");
|
||||
self.con = None;
|
||||
if let Some(f) = self.refresh_token_task.take() {
|
||||
f.abort()
|
||||
}
|
||||
let mut con = self
|
||||
.get_client()
|
||||
.await?
|
||||
.get_multiplexed_tokio_connection()
|
||||
.await?;
|
||||
if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
|
||||
let credentials_provider = credentials_provider.clone();
|
||||
let con2 = con.clone();
|
||||
let f = tokio::spawn(async move {
|
||||
let _ = Self::keep_connection(con2, credentials_provider).await;
|
||||
});
|
||||
self.refresh_token_task = Some(f);
|
||||
}
|
||||
match Self::ping(&mut con).await {
|
||||
Ok(()) => {
|
||||
info!("Connection succesfully established");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Connection is broken. Error during PING: {e:?}");
|
||||
}
|
||||
}
|
||||
self.con = Some(con);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
|
||||
match &self.credentials {
|
||||
Credentials::Static(info) => Ok(info.clone()),
|
||||
Credentials::Dynamic(provider, addr) => {
|
||||
let (username, password) = provider.provide_credentials().await?;
|
||||
Ok(ConnectionInfo {
|
||||
addr: addr.clone(),
|
||||
redis: RedisConnectionInfo {
|
||||
db: 0,
|
||||
username: Some(username),
|
||||
password: Some(password.clone()),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_client(&self) -> anyhow::Result<redis::Client> {
|
||||
let client = redis::Client::open(self.get_connection_info().await?)?;
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
// PubSub does not support credentials refresh.
|
||||
// Requires manual reconnection every 12h.
|
||||
pub async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
|
||||
Ok(self.get_client().await?.get_async_pubsub().await?)
|
||||
}
|
||||
|
||||
// The connection lives for 12h.
|
||||
// It can be prolonged with sending `AUTH` commands with the refreshed token.
|
||||
// https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
|
||||
async fn keep_connection(
|
||||
mut con: MultiplexedConnection,
|
||||
credentials_provider: Arc<CredentialsProvider>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
// The connection lives for 12h, for the sanity check we refresh it every hour.
|
||||
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
|
||||
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
|
||||
Ok(()) => {
|
||||
info!("Token refreshed");
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error during token refresh: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn refresh_token(
|
||||
con: &mut MultiplexedConnection,
|
||||
credentials_provider: Arc<CredentialsProvider>,
|
||||
) -> anyhow::Result<()> {
|
||||
let (user, password) = credentials_provider.provide_credentials().await?;
|
||||
let _: () = redis::cmd("AUTH")
|
||||
.arg(user)
|
||||
.arg(password)
|
||||
.query_async(con)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
/// Sends an already encoded (packed) command into the TCP socket and
|
||||
/// reads the single response from it.
|
||||
pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult<redis::Value> {
|
||||
// Clone connection to avoid having to lock the ArcSwap in write mode
|
||||
let con = self.con.as_mut().ok_or(redis::RedisError::from((
|
||||
redis::ErrorKind::IoError,
|
||||
"Connection not established",
|
||||
)))?;
|
||||
con.send_packed_command(cmd).await
|
||||
}
|
||||
|
||||
/// Sends multiple already encoded (packed) command into the TCP socket
|
||||
/// and reads `count` responses from it. This is used to implement
|
||||
/// pipelining.
|
||||
pub async fn send_packed_commands(
|
||||
&mut self,
|
||||
cmd: &redis::Pipeline,
|
||||
offset: usize,
|
||||
count: usize,
|
||||
) -> RedisResult<Vec<redis::Value>> {
|
||||
// Clone shared connection future to avoid having to lock the ArcSwap in write mode
|
||||
let con = self.con.as_mut().ok_or(redis::RedisError::from((
|
||||
redis::ErrorKind::IoError,
|
||||
"Connection not established",
|
||||
)))?;
|
||||
con.send_packed_commands(cmd, offset, count).await
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionLike for ConnectionWithCredentialsProvider {
|
||||
fn req_packed_command<'a>(
|
||||
&'a mut self,
|
||||
cmd: &'a redis::Cmd,
|
||||
) -> redis::RedisFuture<'a, redis::Value> {
|
||||
(async move { self.send_packed_command(cmd).await }).boxed()
|
||||
}
|
||||
|
||||
fn req_packed_commands<'a>(
|
||||
&'a mut self,
|
||||
cmd: &'a redis::Pipeline,
|
||||
offset: usize,
|
||||
count: usize,
|
||||
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
|
||||
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
|
||||
}
|
||||
|
||||
fn get_db(&self) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
110
proxy/core/src/redis/elasticache.rs
Normal file
110
proxy/core/src/redis/elasticache.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use aws_config::meta::credentials::CredentialsProviderChain;
|
||||
use aws_sdk_iam::config::ProvideCredentials;
|
||||
use aws_sigv4::http_request::{
|
||||
self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AWSIRSAConfig {
|
||||
region: String,
|
||||
service_name: String,
|
||||
cluster_name: String,
|
||||
user_id: String,
|
||||
token_ttl: Duration,
|
||||
action: String,
|
||||
}
|
||||
|
||||
impl AWSIRSAConfig {
|
||||
pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
|
||||
AWSIRSAConfig {
|
||||
region,
|
||||
service_name: "elasticache".to_string(),
|
||||
cluster_name: cluster_name.unwrap_or_default(),
|
||||
user_id: user_id.unwrap_or_default(),
|
||||
// "The IAM authentication token is valid for 15 minutes"
|
||||
// https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits
|
||||
token_ttl: Duration::from_secs(15 * 60),
|
||||
action: "connect".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Credentials provider for AWS elasticache authentication.
|
||||
///
|
||||
/// Official documentation:
|
||||
/// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
|
||||
///
|
||||
/// Useful resources:
|
||||
/// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
|
||||
pub struct CredentialsProvider {
|
||||
config: AWSIRSAConfig,
|
||||
credentials_provider: CredentialsProviderChain,
|
||||
}
|
||||
|
||||
impl CredentialsProvider {
|
||||
pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self {
|
||||
CredentialsProvider {
|
||||
config,
|
||||
credentials_provider,
|
||||
}
|
||||
}
|
||||
pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
|
||||
let aws_credentials = self
|
||||
.credentials_provider
|
||||
.provide_credentials()
|
||||
.await?
|
||||
.into();
|
||||
info!("AWS credentials successfully obtained");
|
||||
info!("Connecting to Redis with configuration: {:?}", self.config);
|
||||
let mut settings = SigningSettings::default();
|
||||
settings.signature_location = SignatureLocation::QueryParams;
|
||||
settings.expires_in = Some(self.config.token_ttl);
|
||||
let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
|
||||
.identity(&aws_credentials)
|
||||
.region(&self.config.region)
|
||||
.name(&self.config.service_name)
|
||||
.time(SystemTime::now())
|
||||
.settings(settings)
|
||||
.build()?
|
||||
.into();
|
||||
let auth_params = [
|
||||
("Action", &self.config.action),
|
||||
("User", &self.config.user_id),
|
||||
];
|
||||
let auth_params = url::form_urlencoded::Serializer::new(String::new())
|
||||
.extend_pairs(auth_params)
|
||||
.finish();
|
||||
let auth_uri = http::Uri::builder()
|
||||
.scheme("http")
|
||||
.authority(self.config.cluster_name.as_bytes())
|
||||
.path_and_query(format!("/?{auth_params}"))
|
||||
.build()?;
|
||||
info!("{}", auth_uri);
|
||||
|
||||
// Convert the HTTP request into a signable request
|
||||
let signable_request = SignableRequest::new(
|
||||
"GET",
|
||||
auth_uri.to_string(),
|
||||
std::iter::empty(),
|
||||
SignableBody::Bytes(&[]),
|
||||
)?;
|
||||
|
||||
// Sign and then apply the signature to the request
|
||||
let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
|
||||
let mut signable_request = http::Request::builder()
|
||||
.method("GET")
|
||||
.uri(auth_uri)
|
||||
.body(())?;
|
||||
si.apply_to_request_http1x(&mut signable_request);
|
||||
Ok((
|
||||
self.config.user_id.clone(),
|
||||
signable_request
|
||||
.uri()
|
||||
.to_string()
|
||||
.replacen("http://", "", 1),
|
||||
))
|
||||
}
|
||||
}
|
||||
356
proxy/core/src/redis/notifications.rs
Normal file
356
proxy/core/src/redis/notifications.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use futures::StreamExt;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::{
|
||||
cache::project_info::ProjectInfoCache,
|
||||
cancellation::{CancelMap, CancellationHandler},
|
||||
intern::{ProjectIdInt, RoleNameInt},
|
||||
metrics::{Metrics, RedisErrors, RedisEventsCount},
|
||||
};
|
||||
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
|
||||
async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
|
||||
let mut conn = client.get_async_pubsub().await?;
|
||||
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
|
||||
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
|
||||
conn.subscribe(PROXY_CHANNEL_NAME).await?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
pub(crate) enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate,
|
||||
},
|
||||
#[serde(
|
||||
rename = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
#[serde(rename = "/cancel_session")]
|
||||
Cancel(CancelSession),
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedIpsUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct PasswordUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct CancelSession {
|
||||
pub region_id: Option<String>,
|
||||
pub cancel_key_data: CancelKeyData,
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
||||
}
|
||||
|
||||
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<CancellationHandler<()>>,
|
||||
region_id: String,
|
||||
}
|
||||
|
||||
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cache: self.cache.clone(),
|
||||
cancellation_handler: self.cancellation_handler.clone(),
|
||||
region_id: self.region_id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
pub fn new(
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<CancellationHandler<()>>,
|
||||
region_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
cache,
|
||||
cancellation_handler,
|
||||
region_id,
|
||||
}
|
||||
}
|
||||
pub async fn increment_active_listeners(&self) {
|
||||
self.cache.increment_active_listeners().await;
|
||||
}
|
||||
pub async fn decrement_active_listeners(&self) {
|
||||
self.cache.decrement_active_listeners().await;
|
||||
}
|
||||
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
|
||||
use Notification::*;
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
|
||||
channel: msg.get_channel_name(),
|
||||
});
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
tracing::field::display(cancel_session.session_id),
|
||||
);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::CancelSession);
|
||||
if let Some(cancel_region) = cancel_session.region_id {
|
||||
// If the message is not for this region, ignore it.
|
||||
if cancel_region != self.region_id {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
|
||||
match self
|
||||
.cancellation_handler
|
||||
.cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to cancel session: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
if matches!(msg, AllowedIpsUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedIpsUpdate);
|
||||
} else if matches!(msg, PasswordUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::PasswordUpdate);
|
||||
}
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
use Notification::*;
|
||||
match msg {
|
||||
AllowedIpsUpdate { allowed_ips_update } => {
|
||||
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
|
||||
}
|
||||
PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
handler: MessageHandler<C>,
|
||||
redis: ConnectionWithCredentialsProvider,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
if cancellation_token.is_cancelled() {
|
||||
return Ok(());
|
||||
}
|
||||
let mut conn = match try_connect(&redis).await {
|
||||
Ok(conn) => {
|
||||
handler.increment_active_listeners().await;
|
||||
conn
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
|
||||
);
|
||||
tokio::time::sleep(RECONNECT_TIMEOUT).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let mut stream = conn.on_message();
|
||||
while let Some(msg) = stream.next().await {
|
||||
match handler.handle_message(msg).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
||||
break;
|
||||
}
|
||||
}
|
||||
if cancellation_token.is_cancelled() {
|
||||
handler.decrement_active_listeners().await;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
handler.decrement_active_listeners().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "redis_notifications", skip_all)]
|
||||
pub async fn task_main<C>(
|
||||
redis: ConnectionWithCredentialsProvider,
|
||||
cache: Arc<C>,
|
||||
cancel_map: CancelMap,
|
||||
region_id: String,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
cancel_map,
|
||||
crate::metrics::CancellationSource::FromRedis,
|
||||
));
|
||||
let handler = MessageHandler::new(cache, cancellation_handler, region_id);
|
||||
// 6h - 1m.
|
||||
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
|
||||
loop {
|
||||
let cancellation_token = CancellationToken::new();
|
||||
interval.tick().await;
|
||||
|
||||
tokio::spawn(handle_messages(
|
||||
handler.clone(),
|
||||
redis.clone(),
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
|
||||
cancellation_token.cancel();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{ProjectId, RoleName};
|
||||
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_allowed_ips() -> anyhow::Result<()> {
|
||||
let project_id: ProjectId = "new_project".into();
|
||||
let data = format!("{{\"project_id\": \"{project_id}\"}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/allowed_ips_updated",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate {
|
||||
project_id: (&project_id).into()
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_updated() -> anyhow::Result<()> {
|
||||
let project_id: ProjectId = "new_project".into();
|
||||
let role_name: RoleName = "new_role".into();
|
||||
let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/password_updated",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::PasswordUpdate {
|
||||
password_update: PasswordUpdate {
|
||||
project_id: (&project_id).into(),
|
||||
role_name: (&role_name).into(),
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn parse_cancel_session() -> anyhow::Result<()> {
|
||||
let cancel_key_data = CancelKeyData {
|
||||
backend_pid: 42,
|
||||
cancel_key: 41,
|
||||
};
|
||||
let uuid = uuid::Uuid::new_v4();
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: None,
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result);
|
||||
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: Some("region".to_string()),
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result,);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
89
proxy/core/src/sasl.rs
Normal file
89
proxy/core/src/sasl.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
//! Simple Authentication and Security Layer.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
|
||||
|
||||
mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use channel_binding::ChannelBinding;
|
||||
pub use messages::FirstMessage;
|
||||
pub use stream::{Outcome, SaslStream};
|
||||
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Channel binding failed: {0}")]
|
||||
ChannelBindingFailed(&'static str),
|
||||
|
||||
#[error("Unsupported channel binding method: {0}")]
|
||||
ChannelBindingBadMethod(Box<str>),
|
||||
|
||||
#[error("Bad client message: {0}")]
|
||||
BadClientMessage(&'static str),
|
||||
|
||||
#[error("Internal error: missing digest")]
|
||||
MissingBinding,
|
||||
|
||||
#[error("could not decode salt: {0}")]
|
||||
Base64(#[from] base64::DecodeError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for Error {
|
||||
fn to_string_client(&self) -> String {
|
||||
use Error::*;
|
||||
match self {
|
||||
ChannelBindingFailed(m) => m.to_string(),
|
||||
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
|
||||
_ => "authentication protocol violation".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// A result of one SASL exchange.
|
||||
#[must_use]
|
||||
pub enum Step<T, R> {
|
||||
/// We should continue exchanging messages.
|
||||
Continue(T, String),
|
||||
/// The client has been authenticated successfully.
|
||||
Success(R, String),
|
||||
/// Authentication failed (reason attached).
|
||||
Failure(&'static str),
|
||||
}
|
||||
|
||||
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
|
||||
pub trait Mechanism: Sized {
|
||||
/// What's produced as a result of successful authentication.
|
||||
type Output;
|
||||
|
||||
/// Produce a server challenge to be sent to the client.
|
||||
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
|
||||
fn exchange(self, input: &str) -> Result<Step<Self, Self::Output>>;
|
||||
}
|
||||
84
proxy/core/src/sasl/channel_binding.rs
Normal file
84
proxy/core/src/sasl/channel_binding.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
//! Definition and parser for channel binding flag (a part of the `GS2` header).
|
||||
|
||||
/// Channel binding flag (possibly with params).
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ChannelBinding<T> {
|
||||
/// Client doesn't support channel binding.
|
||||
NotSupportedClient,
|
||||
/// Client thinks server doesn't support channel binding.
|
||||
NotSupportedServer,
|
||||
/// Client wants to use this type of channel binding.
|
||||
Required(T),
|
||||
}
|
||||
|
||||
impl<T> ChannelBinding<T> {
|
||||
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => NotSupportedClient,
|
||||
NotSupportedServer => NotSupportedServer,
|
||||
Required(x) => Required(f(x)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ChannelBinding<&'a str> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
use ChannelBinding::*;
|
||||
Some(match input {
|
||||
"n" => NotSupportedClient,
|
||||
"y" => NotSupportedServer,
|
||||
other => Required(other.strip_prefix("p=")?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
/// Encode channel binding data as base64 for subsequent checks.
|
||||
pub fn encode<'a, E>(
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => {
|
||||
// base64::encode("n,,")
|
||||
"biws".into()
|
||||
}
|
||||
NotSupportedServer => {
|
||||
// base64::encode("y,,")
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
use std::io::Write;
|
||||
let mut cbind_input = vec![];
|
||||
write!(&mut cbind_input, "p={mode},,",).unwrap();
|
||||
cbind_input.extend_from_slice(get_cbind_data(mode)?);
|
||||
base64::encode(&cbind_input).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn channel_binding_encode() -> anyhow::Result<()> {
|
||||
use ChannelBinding::*;
|
||||
|
||||
let cases = [
|
||||
(NotSupportedClient, base64::encode("n,,")),
|
||||
(NotSupportedServer, base64::encode("y,,")),
|
||||
(Required("foo"), base64::encode("p=foo,,bar")),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
68
proxy/core/src/sasl/messages.rs
Normal file
68
proxy/core/src/sasl/messages.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
//! Definitions for SASL messages.
|
||||
|
||||
use crate::parse::{split_at_const, split_cstr};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
|
||||
#[derive(Debug)]
|
||||
pub struct FirstMessage<'a> {
|
||||
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
|
||||
pub method: &'a str,
|
||||
/// Initial client message.
|
||||
pub message: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> FirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
|
||||
let (method_cstr, tail) = split_cstr(bytes)?;
|
||||
let method = method_cstr.to_str().ok()?;
|
||||
|
||||
let (len_bytes, bytes) = split_at_const(tail)?;
|
||||
let len = u32::from_be_bytes(*len_bytes) as usize;
|
||||
if len != bytes.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = std::str::from_utf8(bytes).ok()?;
|
||||
Some(Self { method, message })
|
||||
}
|
||||
}
|
||||
|
||||
/// A single SASL message.
|
||||
/// This struct is deliberately decoupled from lower-level
|
||||
/// [`BeAuthenticationSaslMessage`].
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ServerMessage<T> {
|
||||
/// We expect to see more steps.
|
||||
Continue(T),
|
||||
/// This is the final step.
|
||||
Final(T),
|
||||
}
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_sasl_first_message() {
|
||||
let proto = "SCRAM-SHA-256";
|
||||
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
|
||||
let sasl_len = (sasl.len() as u32).to_be_bytes();
|
||||
let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
|
||||
|
||||
let password = FirstMessage::parse(&bytes).unwrap();
|
||||
assert_eq!(password.method, proto);
|
||||
assert_eq!(password.message, sasl);
|
||||
}
|
||||
}
|
||||
92
proxy/core/src/sasl/stream.rs
Normal file
92
proxy/core/src/sasl/stream.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Abstraction for the string-oriented SASL protocols.
|
||||
|
||||
use super::{messages::ServerMessage, Mechanism};
|
||||
use crate::stream::PqStream;
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// SASL authentication outcome.
|
||||
/// It's much easier to match on those two variants
|
||||
/// than to peek into a noisy protocol error type.
|
||||
#[must_use = "caller must explicitly check for success"]
|
||||
pub enum Outcome<R> {
|
||||
/// Authentication succeeded and produced some value.
|
||||
Success(R),
|
||||
/// Authentication failed (reason attached).
|
||||
Failure(&'static str),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub async fn authenticate<M: Mechanism>(
|
||||
mut self,
|
||||
mut mechanism: M,
|
||||
) -> super::Result<Outcome<M::Output>> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let step = mechanism.exchange(input).map_err(|error| {
|
||||
info!(?error, "error during SASL exchange");
|
||||
error
|
||||
})?;
|
||||
|
||||
use super::Step;
|
||||
return Ok(match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved_mechanism;
|
||||
continue;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
self.send(&ServerMessage::Final(&reply)).await?;
|
||||
Outcome::Success(result)
|
||||
}
|
||||
Step::Failure(reason) => Outcome::Failure(reason),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
148
proxy/core/src/scram.rs
Normal file
148
proxy/core/src/scram.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
//! Salted Challenge Response Authentication Mechanism.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod countmin;
|
||||
mod exchange;
|
||||
mod key;
|
||||
mod messages;
|
||||
mod pbkdf2;
|
||||
mod secret;
|
||||
mod signature;
|
||||
pub mod threadpool;
|
||||
|
||||
pub use exchange::{exchange, Exchange};
|
||||
pub use key::ScramKey;
|
||||
pub use secret::ServerSecret;
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
|
||||
/// A list of supported SCRAM methods.
|
||||
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
|
||||
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
|
||||
|
||||
/// Decode base64 into array without any heap allocations
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
let mut bytes = [0u8; N];
|
||||
|
||||
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
|
||||
if size != N {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
|
||||
parts.into_iter().for_each(|s| mac.update(s));
|
||||
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
parts.into_iter().for_each(|s| hasher.update(s));
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
sasl::{Mechanism, Step},
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
|
||||
|
||||
#[test]
|
||||
fn snapshot() {
|
||||
let iterations = 4096;
|
||||
let salt = "QSXCR+Q6sek8bf92";
|
||||
let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
|
||||
let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
|
||||
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
|
||||
let secret = ServerSecret::parse(&secret).unwrap();
|
||||
|
||||
const NONCE: [u8; 18] = [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
];
|
||||
let mut exchange = Exchange::new(
|
||||
&secret,
|
||||
|| NONCE,
|
||||
crate::config::TlsServerEndPoint::Undefined,
|
||||
);
|
||||
|
||||
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
|
||||
let server_first =
|
||||
"r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
|
||||
let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";
|
||||
|
||||
exchange = match exchange.exchange(client_first).unwrap() {
|
||||
Step::Continue(exchange, message) => {
|
||||
assert_eq!(message, server_first);
|
||||
exchange
|
||||
}
|
||||
Step::Success(_, _) => panic!("expected continue, got success"),
|
||||
Step::Failure(f) => panic!("{f}"),
|
||||
};
|
||||
|
||||
let key = match exchange.exchange(client_final).unwrap() {
|
||||
Step::Success(key, message) => {
|
||||
assert_eq!(message, server_final);
|
||||
key
|
||||
}
|
||||
Step::Continue(_, _) => panic!("expected success, got continue"),
|
||||
Step::Failure(f) => panic!("{f}"),
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
key.as_bytes(),
|
||||
[
|
||||
74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
|
||||
27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match outcome {
|
||||
crate::sasl::Outcome::Success(_) => {}
|
||||
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip() {
|
||||
run_round_trip_test("pencil", "pencil").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic(expected = "password doesn't match")]
|
||||
async fn failure() {
|
||||
run_round_trip_test("pencil", "eraser").await
|
||||
}
|
||||
}
|
||||
173
proxy/core/src/scram/countmin.rs
Normal file
173
proxy/core/src/scram/countmin.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::hash::Hash;
|
||||
|
||||
/// estimator of hash jobs per second.
|
||||
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
|
||||
pub struct CountMinSketch {
|
||||
// one for each depth
|
||||
hashers: Vec<ahash::RandomState>,
|
||||
width: usize,
|
||||
depth: usize,
|
||||
// buckets, width*depth
|
||||
buckets: Vec<u32>,
|
||||
}
|
||||
|
||||
impl CountMinSketch {
|
||||
/// Given parameters (ε, δ),
|
||||
/// set width = ceil(e/ε)
|
||||
/// set depth = ceil(ln(1/δ))
|
||||
///
|
||||
/// guarantees:
|
||||
/// actual <= estimate
|
||||
/// estimate <= actual + ε * N with probability 1 - δ
|
||||
/// where N is the cardinality of the stream
|
||||
pub fn with_params(epsilon: f64, delta: f64) -> Self {
|
||||
CountMinSketch::new(
|
||||
(std::f64::consts::E / epsilon).ceil() as usize,
|
||||
(1.0_f64 / delta).ln().ceil() as usize,
|
||||
)
|
||||
}
|
||||
|
||||
fn new(width: usize, depth: usize) -> Self {
|
||||
Self {
|
||||
#[cfg(test)]
|
||||
hashers: (0..depth)
|
||||
.map(|i| {
|
||||
// digits of pi for good randomness
|
||||
ahash::RandomState::with_seeds(
|
||||
314159265358979323,
|
||||
84626433832795028,
|
||||
84197169399375105,
|
||||
82097494459230781 + i as u64,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
#[cfg(not(test))]
|
||||
hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
|
||||
width,
|
||||
depth,
|
||||
buckets: vec![0; width * depth],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
|
||||
let mut min = u32::MAX;
|
||||
for row in 0..self.depth {
|
||||
let col = (self.hashers[row].hash_one(t) as usize) % self.width;
|
||||
|
||||
let row = &mut self.buckets[row * self.width..][..self.width];
|
||||
row[col] = row[col].saturating_add(x);
|
||||
min = std::cmp::min(min, row[col]);
|
||||
}
|
||||
min
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.buckets.clear();
|
||||
self.buckets.resize(self.width * self.depth, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
|
||||
|
||||
use super::CountMinSketch;
|
||||
|
||||
fn eval_precision(n: usize, p: f64, q: f64) -> usize {
|
||||
// fixed value of phi for consistent test
|
||||
let mut rng = StdRng::seed_from_u64(16180339887498948482);
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
let mut N = 0;
|
||||
|
||||
let mut ids = vec![];
|
||||
|
||||
for _ in 0..n {
|
||||
// number of insert operations
|
||||
let n = rng.gen_range(1..100);
|
||||
// number to insert at once
|
||||
let m = rng.gen_range(1..4096);
|
||||
|
||||
let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
|
||||
ids.push((id, n, m));
|
||||
|
||||
// N = sum(actual)
|
||||
N += n * m;
|
||||
}
|
||||
|
||||
// q% of counts will be within p of the actual value
|
||||
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
|
||||
|
||||
dbg!(sketch.buckets.len());
|
||||
|
||||
// insert a bunch of entries in a random order
|
||||
let mut ids2 = ids.clone();
|
||||
while !ids2.is_empty() {
|
||||
ids2.shuffle(&mut rng);
|
||||
|
||||
let mut i = 0;
|
||||
while i < ids2.len() {
|
||||
sketch.inc_and_return(&ids2[i].0, ids2[i].1);
|
||||
ids2[i].2 -= 1;
|
||||
if ids2[i].2 == 0 {
|
||||
ids2.remove(i);
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut within_p = 0;
|
||||
for (id, n, m) in ids {
|
||||
let actual = n * m;
|
||||
let estimate = sketch.inc_and_return(&id, 0);
|
||||
|
||||
// This estimate has the guarantee that actual <= estimate
|
||||
assert!(actual <= estimate);
|
||||
|
||||
// This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
|
||||
// ε = p / N, δ = 1 - q;
|
||||
// therefore, estimate <= actual + p with probability q.
|
||||
if estimate as f64 <= actual as f64 + p {
|
||||
within_p += 1;
|
||||
}
|
||||
}
|
||||
within_p
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn precision() {
|
||||
assert_eq!(eval_precision(100, 100.0, 0.99), 100);
|
||||
assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
|
||||
|
||||
// seems to be more precise than the literature indicates?
|
||||
// probably numbers are too small to truly represent the probabilities.
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
|
||||
}
|
||||
|
||||
// returns memory usage in bytes, and the time complexity per insert.
|
||||
fn eval_cost(p: f64, q: f64) -> (usize, usize) {
|
||||
#[allow(non_snake_case)]
|
||||
// N = sum(actual)
|
||||
// Let's assume 1021 samples, all of 4096
|
||||
let N = 1021 * 4096;
|
||||
let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
|
||||
|
||||
let memory = size_of::<u32>() * sketch.buckets.len();
|
||||
let time = sketch.depth;
|
||||
(memory, time)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn memory_usage() {
|
||||
assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
|
||||
assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
|
||||
assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
|
||||
assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
|
||||
}
|
||||
}
|
||||
234
proxy/core/src/scram/exchange.rs
Normal file
234
proxy/core/src/scram/exchange.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
//! Implementation of the SCRAM authentication algorithm.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::ScramKey;
|
||||
use crate::config;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
struct TlsServerEndPoint;
|
||||
|
||||
impl std::fmt::Display for TlsServerEndPoint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "tls-server-end-point")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TlsServerEndPoint {
|
||||
type Err = sasl::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"tls-server-end-point" => Ok(TlsServerEndPoint),
|
||||
_ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SaslSentInner {
|
||||
cbind_flag: ChannelBinding<TlsServerEndPoint>,
|
||||
client_first_message_bare: String,
|
||||
server_first_message: OwnedServerFirstMessage,
|
||||
}
|
||||
|
||||
struct SaslInitial {
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
}
|
||||
|
||||
enum ExchangeState {
|
||||
/// Waiting for [`ClientFirstMessage`].
|
||||
Initial(SaslInitial),
|
||||
/// Waiting for [`ClientFinalMessage`].
|
||||
SaltSent(SaslSentInner),
|
||||
}
|
||||
|
||||
/// Server's side of SCRAM auth algorithm.
|
||||
pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial(SaslInitial { nonce }),
|
||||
secret,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
||||
async fn derive_client_key(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
salt: &[u8],
|
||||
iterations: u32,
|
||||
) -> ScramKey {
|
||||
let salted_password = pool
|
||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await
|
||||
.expect("job should not be cancelled");
|
||||
|
||||
let make_key = |name| {
|
||||
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
.expect("HMAC is able to accept all key sizes")
|
||||
.chain_update(name)
|
||||
.finalize();
|
||||
|
||||
<[u8; 32]>::from(key.into_bytes())
|
||||
};
|
||||
|
||||
make_key(b"Client Key").into()
|
||||
}
|
||||
|
||||
pub async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = base64::decode(&secret.salt_base64)?;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
||||
} else {
|
||||
Ok(sasl::Outcome::Success(client_key))
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslInitial {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
|
||||
|
||||
// If the flag is set to "y" and the server supports channel
|
||||
// binding, the server MUST fail authentication
|
||||
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
|
||||
&& tls_server_end_point.supported()
|
||||
{
|
||||
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
|
||||
}
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&secret.salt_base64,
|
||||
secret.iterations,
|
||||
);
|
||||
let msg = server_first_message.as_str().to_owned();
|
||||
|
||||
let next = SaslSentInner {
|
||||
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
|
||||
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||
server_first_message,
|
||||
};
|
||||
|
||||
Ok(sasl::Step::Continue(next, msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslSentInner {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
|
||||
let Self {
|
||||
cbind_flag,
|
||||
client_first_message_bare,
|
||||
server_first_message,
|
||||
} = self;
|
||||
|
||||
let client_final_message = ClientFinalMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
if client_final_message.channel_binding != channel_binding {
|
||||
return Err(SaslError::ChannelBindingFailed(
|
||||
"insecure connection: secure channel data mismatch",
|
||||
));
|
||||
}
|
||||
|
||||
if client_final_message.nonce != server_first_message.nonce() {
|
||||
return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
|
||||
}
|
||||
|
||||
let signature_builder = SignatureBuilder {
|
||||
client_first_message_bare,
|
||||
server_first_message: server_first_message.as_str(),
|
||||
client_final_message_without_proof: client_final_message.without_proof,
|
||||
};
|
||||
|
||||
let client_key = signature_builder
|
||||
.build(&secret.stored_key)
|
||||
.derive_client_key(&client_final_message.proof);
|
||||
|
||||
// Auth fails either if keys don't match or it's pre-determined to fail.
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
return Ok(sasl::Step::Failure("password doesn't match"));
|
||||
}
|
||||
|
||||
let msg =
|
||||
client_final_message.build_server_final_message(signature_builder, &secret.server_key);
|
||||
|
||||
Ok(sasl::Step::Success(client_key, msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl sasl::Mechanism for Exchange<'_> {
|
||||
type Output = super::ScramKey;
|
||||
|
||||
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
|
||||
use {sasl::Step::*, ExchangeState::*};
|
||||
match &self.state {
|
||||
Initial(init) => {
|
||||
match init.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Continue(sent, msg) => {
|
||||
self.state = SaltSent(sent);
|
||||
Ok(Continue(self, msg))
|
||||
}
|
||||
Success(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
}
|
||||
}
|
||||
SaltSent(sent) => {
|
||||
match sent.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Success(keys, msg) => Ok(Success(keys, msg)),
|
||||
Continue(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
51
proxy/core/src/scram/key.rs
Normal file
51
proxy/core/src/scram/key.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
//! Tools for client/server/stored key management.
|
||||
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_KEY_LEN: usize = 32;
|
||||
|
||||
/// One of the keys derived from the user's password.
|
||||
/// We use the same structure for all keys, i.e.
|
||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||
#[derive(Clone, Default, Eq, Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl PartialEq for ScramKey {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.ct_eq(other).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConstantTimeEq for ScramKey {
|
||||
fn ct_eq(&self, other: &Self) -> subtle::Choice {
|
||||
self.bytes.ct_eq(&other.bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScramKey {
|
||||
pub fn sha256(&self) -> Self {
|
||||
super::sha256([self.as_ref()]).into()
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
|
||||
self.bytes
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
257
proxy/core/src/scram/messages.rs
Normal file
257
proxy/core/src/scram/messages.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Definitions for SCRAM messages.
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::sasl::ChannelBinding;
|
||||
use std::fmt;
|
||||
use std::ops::Range;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
|
||||
|
||||
/// Although we ignore all extensions, we still have to validate the message.
|
||||
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
|
||||
for mut chars in parts.map(|s| s.chars()) {
|
||||
let attr = chars.next()?;
|
||||
if !attr.is_ascii_alphabetic() {
|
||||
return None;
|
||||
}
|
||||
let eq = chars.next()?;
|
||||
if eq != '=' {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFirstMessage<'a> {
|
||||
/// `client-first-message-bare`.
|
||||
pub bare: &'a str,
|
||||
/// Channel binding mode.
|
||||
pub cbind_flag: ChannelBinding<&'a str>,
|
||||
/// Client nonce.
|
||||
pub nonce: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> ClientFirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let mut parts = input.split(',');
|
||||
|
||||
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
|
||||
|
||||
// PG doesn't support authorization identity,
|
||||
// so we don't bother defining GS2 header type
|
||||
let authzid = parts.next()?;
|
||||
if !authzid.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Unfortunately, `parts.as_str()` is unstable
|
||||
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
|
||||
let (_, bare) = input.split_at(pos);
|
||||
|
||||
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
|
||||
let username = parts.next()?.strip_prefix("n=")?;
|
||||
|
||||
// https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
|
||||
if !username.is_empty() {
|
||||
tracing::warn!(username, "scram username provided, but is not expected")
|
||||
// TODO(conrad):
|
||||
// return None;
|
||||
}
|
||||
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
Some(Self {
|
||||
bare,
|
||||
cbind_flag,
|
||||
nonce,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFirstMessage`].
|
||||
pub fn build_server_first_message(
|
||||
&self,
|
||||
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
|
||||
salt_base64: &str,
|
||||
iterations: u32,
|
||||
) -> OwnedServerFirstMessage {
|
||||
use std::fmt::Write;
|
||||
|
||||
let mut message = String::new();
|
||||
write!(&mut message, "r={}", self.nonce).unwrap();
|
||||
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
|
||||
let combined_nonce = 2..message.len();
|
||||
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
|
||||
|
||||
// This design guarantees that it's impossible to create a
|
||||
// server-first-message without receiving a client-first-message
|
||||
OwnedServerFirstMessage {
|
||||
message,
|
||||
nonce: combined_nonce,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFinalMessage<'a> {
|
||||
/// `client-final-message-without-proof`.
|
||||
pub without_proof: &'a str,
|
||||
/// Channel binding data (base64).
|
||||
pub channel_binding: &'a str,
|
||||
/// Combined client & server nonce.
|
||||
pub nonce: &'a str,
|
||||
/// Client auth proof.
|
||||
pub proof: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl<'a> ClientFinalMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let (without_proof, proof) = input.rsplit_once(',')?;
|
||||
|
||||
let mut parts = without_proof.split(',');
|
||||
let channel_binding = parts.next()?.strip_prefix("c=")?;
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
|
||||
|
||||
Some(Self {
|
||||
without_proof,
|
||||
channel_binding,
|
||||
nonce,
|
||||
proof,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFinalMessage`].
|
||||
pub fn build_server_final_message(
|
||||
&self,
|
||||
signature_builder: SignatureBuilder,
|
||||
server_key: &ScramKey,
|
||||
) -> String {
|
||||
let mut buf = String::from("v=");
|
||||
base64::encode_config_buf(
|
||||
signature_builder.build(server_key),
|
||||
base64::STANDARD,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
/// We need to keep a convenient representation of this
|
||||
/// message for the next authentication step.
|
||||
pub struct OwnedServerFirstMessage {
|
||||
/// Owned `server-first-message`.
|
||||
message: String,
|
||||
/// Slice into `message`.
|
||||
nonce: Range<usize>,
|
||||
}
|
||||
|
||||
impl OwnedServerFirstMessage {
|
||||
/// Extract combined nonce from the message.
|
||||
#[inline(always)]
|
||||
pub fn nonce(&self) -> &str {
|
||||
&self.message[self.nonce.clone()]
|
||||
}
|
||||
|
||||
/// Get reference to a text representation of the message.
|
||||
#[inline(always)]
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for OwnedServerFirstMessage {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ServerFirstMessage")
|
||||
.field("message", &self.as_str())
|
||||
.field("nonce", &self.nonce())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message() {
|
||||
use ChannelBinding::*;
|
||||
|
||||
// (Almost) real strings captured during debug sessions
|
||||
let cases = [
|
||||
(NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(
|
||||
Required("tls-server-end-point"),
|
||||
"p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
|
||||
),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
let msg = ClientFirstMessage::parse(input).unwrap();
|
||||
|
||||
assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.cbind_flag, cb);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message_with_invalid_gs2_authz() {
|
||||
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message_with_extra_params() {
|
||||
let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
|
||||
assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
|
||||
assert_eq!(msg.nonce, "nonce");
|
||||
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message_with_extra_params_invalid() {
|
||||
// must be of the form `<ascii letter>=<...>`
|
||||
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
|
||||
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
|
||||
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_final_message() {
|
||||
let input = [
|
||||
"c=eSws",
|
||||
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
|
||||
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
|
||||
]
|
||||
.join(",");
|
||||
|
||||
let msg = ClientFinalMessage::parse(&input).unwrap();
|
||||
assert_eq!(
|
||||
msg.without_proof,
|
||||
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
msg.nonce,
|
||||
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(msg.proof),
|
||||
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
|
||||
);
|
||||
}
|
||||
}
|
||||
89
proxy/core/src/scram/pbkdf2.rs
Normal file
89
proxy/core/src/scram/pbkdf2.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use hmac::{
|
||||
digest::{consts::U32, generic_array::GenericArray},
|
||||
Hmac, Mac,
|
||||
};
|
||||
use sha2::Sha256;
|
||||
|
||||
pub struct Pbkdf2 {
|
||||
hmac: Hmac<Sha256>,
|
||||
prev: GenericArray<u8, U32>,
|
||||
hi: GenericArray<u8, U32>,
|
||||
iterations: u32,
|
||||
}
|
||||
|
||||
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||
impl Pbkdf2 {
|
||||
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
let hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
|
||||
let prev = hmac
|
||||
.clone()
|
||||
.chain_update(salt)
|
||||
.chain_update(1u32.to_be_bytes())
|
||||
.finalize()
|
||||
.into_bytes();
|
||||
|
||||
Self {
|
||||
hmac,
|
||||
// one consumed for the hash above
|
||||
iterations: iterations - 1,
|
||||
hi: prev,
|
||||
prev,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cost(&self) -> u32 {
|
||||
(self.iterations).clamp(0, 4096)
|
||||
}
|
||||
|
||||
pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
|
||||
let Self {
|
||||
hmac,
|
||||
prev,
|
||||
hi,
|
||||
iterations,
|
||||
} = self;
|
||||
|
||||
// only do 4096 iterations per turn before sharing the thread for fairness
|
||||
let n = (*iterations).clamp(0, 4096);
|
||||
for _ in 0..n {
|
||||
*prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
|
||||
|
||||
for (hi, prev) in hi.iter_mut().zip(*prev) {
|
||||
*hi ^= prev;
|
||||
}
|
||||
}
|
||||
|
||||
*iterations -= n;
|
||||
if *iterations == 0 {
|
||||
std::task::Poll::Ready((*hi).into())
|
||||
} else {
|
||||
std::task::Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Pbkdf2;
|
||||
use pbkdf2::pbkdf2_hmac_array;
|
||||
use sha2::Sha256;
|
||||
|
||||
#[test]
|
||||
fn works() {
|
||||
let salt = b"sodium chloride";
|
||||
let pass = b"Ne0n_!5_50_C007";
|
||||
|
||||
let mut job = Pbkdf2::start(pass, salt, 600000);
|
||||
let hash = loop {
|
||||
let std::task::Poll::Ready(hash) = job.turn() else {
|
||||
continue;
|
||||
};
|
||||
break hash;
|
||||
};
|
||||
|
||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
|
||||
assert_eq!(hash, expected)
|
||||
}
|
||||
}
|
||||
100
proxy/core/src/scram/secret.rs
Normal file
100
proxy/core/src/scram/secret.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
//! Tools for SCRAM server secret management.
|
||||
|
||||
use subtle::{Choice, ConstantTimeEq};
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
|
||||
/// Server secret is produced from user's password,
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub struct ServerSecret {
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub iterations: u32,
|
||||
/// Salt used to hash user's password.
|
||||
pub salt_base64: String,
|
||||
/// Hashed `ClientKey`.
|
||||
pub stored_key: ScramKey,
|
||||
/// Used by client to verify server's signature.
|
||||
pub server_key: ScramKey,
|
||||
/// Should auth fail no matter what?
|
||||
/// This is exactly the case for mocked secrets.
|
||||
pub doomed: bool,
|
||||
}
|
||||
|
||||
impl ServerSecret {
|
||||
pub fn parse(input: &str) -> Option<Self> {
|
||||
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
|
||||
let s = input.strip_prefix("SCRAM-SHA-256$")?;
|
||||
let (params, keys) = s.split_once('$')?;
|
||||
|
||||
let ((iterations, salt), (stored_key, server_key)) =
|
||||
params.split_once(':').zip(keys.split_once(':'))?;
|
||||
|
||||
let secret = ServerSecret {
|
||||
iterations: iterations.parse().ok()?,
|
||||
salt_base64: salt.to_owned(),
|
||||
stored_key: base64_decode_array(stored_key)?.into(),
|
||||
server_key: base64_decode_array(server_key)?.into(),
|
||||
doomed: false,
|
||||
};
|
||||
|
||||
Some(secret)
|
||||
}
|
||||
|
||||
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
|
||||
// constant time to not leak partial key match
|
||||
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
|
||||
}
|
||||
|
||||
/// To avoid revealing information to an attacker, we use a
|
||||
/// mocked server secret even if the user doesn't exist.
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub fn mock(nonce: [u8; 32]) -> Self {
|
||||
Self {
|
||||
// this doesn't reveal much information as we're going to use
|
||||
// iteration count 1 for our generated passwords going forward.
|
||||
// PG16 users can set iteration count=1 already today.
|
||||
iterations: 1,
|
||||
salt_base64: base64::encode(nonce),
|
||||
stored_key: ScramKey::default(),
|
||||
server_key: ScramKey::default(),
|
||||
doomed: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a new server secret from the prerequisites.
|
||||
/// XXX: We only use this function in tests.
|
||||
#[cfg(test)]
|
||||
pub async fn build(password: &str) -> Option<Self> {
|
||||
Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_scram_secret() {
|
||||
let iterations = 4096;
|
||||
let salt = "+/tQQax7twvwTj64mjBsxQ==";
|
||||
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
|
||||
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
|
||||
|
||||
let secret = format!(
|
||||
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
|
||||
iterations = iterations,
|
||||
salt = salt,
|
||||
stored_key = stored_key,
|
||||
server_key = server_key,
|
||||
);
|
||||
|
||||
let parsed = ServerSecret::parse(&secret).unwrap();
|
||||
assert_eq!(parsed.iterations, iterations);
|
||||
assert_eq!(parsed.salt_base64, salt);
|
||||
|
||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||
}
|
||||
}
|
||||
66
proxy/core/src/scram/signature.rs
Normal file
66
proxy/core/src/scram/signature.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Tools for client/server signature management.
|
||||
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
|
||||
/// A collection of message parts needed to derive the client's signature.
|
||||
#[derive(Debug)]
|
||||
pub struct SignatureBuilder<'a> {
|
||||
pub client_first_message_bare: &'a str,
|
||||
pub server_first_message: &'a str,
|
||||
pub client_final_message_without_proof: &'a str,
|
||||
}
|
||||
|
||||
impl SignatureBuilder<'_> {
|
||||
pub fn build(&self, key: &ScramKey) -> Signature {
|
||||
let parts = [
|
||||
self.client_first_message_bare.as_bytes(),
|
||||
b",",
|
||||
self.server_first_message.as_bytes(),
|
||||
b",",
|
||||
self.client_final_message_without_proof.as_bytes(),
|
||||
];
|
||||
|
||||
super::hmac_sha256(key.as_ref(), parts).into()
|
||||
}
|
||||
}
|
||||
|
||||
/// A computed value which, when xored with `ClientProof`,
|
||||
/// produces `ClientKey` that we need for authentication.
|
||||
#[derive(Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct Signature {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl Signature {
|
||||
/// Derive `ClientKey` from client's signature and proof.
|
||||
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
|
||||
// This is how the proof is calculated:
|
||||
//
|
||||
// 1. sha256(ClientKey) -> StoredKey
|
||||
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
|
||||
// 3. ClientKey ^ ClientSignature -> ClientProof
|
||||
//
|
||||
// Step 3 implies that we can restore ClientKey from the proof
|
||||
// by xoring the latter with the ClientSignature. Afterwards we
|
||||
// can check that the presumed ClientKey meets our expectations.
|
||||
let mut signature = self.bytes;
|
||||
for (i, x) in proof.iter().enumerate() {
|
||||
signature[i] ^= x;
|
||||
}
|
||||
|
||||
signature.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Signature {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
321
proxy/core/src/scram/threadpool.rs
Normal file
321
proxy/core/src/scram/threadpool.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
//! Custom threadpool implementation for password hashing.
|
||||
//!
|
||||
//! Requirements:
|
||||
//! 1. Fairness per endpoint.
|
||||
//! 2. Yield support for high iteration counts.
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use crossbeam_deque::{Injector, Stealer, Worker};
|
||||
use itertools::Itertools;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use rand::Rng;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
|
||||
scram::countmin::CountMinSketch,
|
||||
};
|
||||
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
|
||||
pub struct ThreadPool {
|
||||
queue: Injector<JobSpec>,
|
||||
stealers: Vec<Stealer<JobSpec>>,
|
||||
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
|
||||
/// bitpacked representation.
|
||||
/// lower 8 bits = number of sleeping threads
|
||||
/// next 8 bits = number of idle threads (searching for work)
|
||||
counters: AtomicU64,
|
||||
|
||||
pub metrics: Arc<ThreadPoolMetrics>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum ThreadState {
|
||||
Parked,
|
||||
Active,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
pub fn new(n_workers: u8) -> Arc<Self> {
|
||||
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
|
||||
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
|
||||
|
||||
let parkers = (0..n_workers)
|
||||
.map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
|
||||
.collect_vec();
|
||||
|
||||
let pool = Arc::new(Self {
|
||||
queue: Injector::new(),
|
||||
stealers,
|
||||
parkers,
|
||||
// threads start searching for work
|
||||
counters: AtomicU64::new((n_workers as u64) << 8),
|
||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||
});
|
||||
|
||||
for (i, worker) in workers.into_iter().enumerate() {
|
||||
let pool = Arc::clone(&pool);
|
||||
std::thread::spawn(move || thread_rt(pool, worker, i));
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
|
||||
pub fn spawn_job(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
pbkdf2: Pbkdf2,
|
||||
) -> oneshot::Receiver<[u8; 32]> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let queue_was_empty = self.queue.is_empty();
|
||||
|
||||
self.metrics.injector_queue_depth.inc();
|
||||
self.queue.push(JobSpec {
|
||||
response: tx,
|
||||
pbkdf2,
|
||||
endpoint,
|
||||
});
|
||||
|
||||
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
|
||||
let counts = self.counters.load(Ordering::SeqCst);
|
||||
let num_awake_but_idle = (counts >> 8) & 0xff;
|
||||
let num_sleepers = counts & 0xff;
|
||||
|
||||
// If the queue is non-empty, then we always wake up a worker
|
||||
// -- clearly the existing idle jobs aren't enough. Otherwise,
|
||||
// check to see if we have enough idle workers.
|
||||
if !queue_was_empty || num_awake_but_idle == 0 {
|
||||
let num_to_wake = Ord::min(1, num_sleepers);
|
||||
self.wake_any_threads(num_to_wake);
|
||||
}
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
#[cold]
|
||||
fn wake_any_threads(&self, mut num_to_wake: u64) {
|
||||
if num_to_wake > 0 {
|
||||
for i in 0..self.parkers.len() {
|
||||
if self.wake_specific_thread(i) {
|
||||
num_to_wake -= 1;
|
||||
if num_to_wake == 0 {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wake_specific_thread(&self, index: usize) -> bool {
|
||||
let (condvar, lock) = &self.parkers[index];
|
||||
|
||||
let mut state = lock.lock();
|
||||
if *state == ThreadState::Parked {
|
||||
condvar.notify_one();
|
||||
|
||||
// When the thread went to sleep, it will have incremented
|
||||
// this value. When we wake it, its our job to decrement
|
||||
// it. We could have the thread do it, but that would
|
||||
// introduce a delay between when the thread was
|
||||
// *notified* and when this counter was decremented. That
|
||||
// might mislead people with new work into thinking that
|
||||
// there are sleeping threads that they should try to
|
||||
// wake, when in fact there is nothing left for them to
|
||||
// do.
|
||||
self.counters.fetch_sub(1, Ordering::SeqCst);
|
||||
*state = ThreadState::Active;
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
|
||||
// announce thread as idle
|
||||
self.counters.fetch_add(256, Ordering::SeqCst);
|
||||
|
||||
// try steal from the global queue
|
||||
loop {
|
||||
match self.queue.steal_batch_and_pop(worker) {
|
||||
crossbeam_deque::Steal::Success(job) => {
|
||||
self.metrics
|
||||
.injector_queue_depth
|
||||
.set(self.queue.len() as i64);
|
||||
// no longer idle
|
||||
self.counters.fetch_sub(256, Ordering::SeqCst);
|
||||
return Some(job);
|
||||
}
|
||||
crossbeam_deque::Steal::Retry => continue,
|
||||
crossbeam_deque::Steal::Empty => break,
|
||||
}
|
||||
}
|
||||
|
||||
// try steal from our neighbours
|
||||
loop {
|
||||
let mut retry = false;
|
||||
let start = rng.gen_range(0..self.stealers.len());
|
||||
let job = (start..self.stealers.len())
|
||||
.chain(0..start)
|
||||
.filter(|i| *i != skip)
|
||||
.find_map(
|
||||
|victim| match self.stealers[victim].steal_batch_and_pop(worker) {
|
||||
crossbeam_deque::Steal::Success(job) => Some(job),
|
||||
crossbeam_deque::Steal::Empty => None,
|
||||
crossbeam_deque::Steal::Retry => {
|
||||
retry = true;
|
||||
None
|
||||
}
|
||||
},
|
||||
);
|
||||
if job.is_some() {
|
||||
// no longer idle
|
||||
self.counters.fetch_sub(256, Ordering::SeqCst);
|
||||
return job;
|
||||
}
|
||||
if !retry {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
/// interval when we should steal from the global queue
|
||||
/// so that tail latencies are managed appropriately
|
||||
const STEAL_INTERVAL: usize = 61;
|
||||
|
||||
/// How often to reset the sketch values
|
||||
const SKETCH_RESET_INTERVAL: usize = 1021;
|
||||
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
|
||||
// used to determine whether we should temporarily skip tasks for fairness.
|
||||
// 99% of estimates will overcount by no more than 4096 samples
|
||||
let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
|
||||
|
||||
let (condvar, lock) = &pool.parkers[index];
|
||||
|
||||
'wait: loop {
|
||||
// wait for notification of work
|
||||
{
|
||||
let mut lock = lock.lock();
|
||||
|
||||
// queue is empty
|
||||
pool.metrics
|
||||
.worker_queue_depth
|
||||
.set(ThreadPoolWorkerId(index), 0);
|
||||
|
||||
// subtract 1 from idle count, add 1 to sleeping count.
|
||||
pool.counters.fetch_sub(255, Ordering::SeqCst);
|
||||
|
||||
*lock = ThreadState::Parked;
|
||||
condvar.wait(&mut lock);
|
||||
}
|
||||
|
||||
for i in 0.. {
|
||||
let mut job = match worker
|
||||
.pop()
|
||||
.or_else(|| pool.steal(&mut rng, index, &worker))
|
||||
{
|
||||
Some(job) => job,
|
||||
None => continue 'wait,
|
||||
};
|
||||
|
||||
pool.metrics
|
||||
.worker_queue_depth
|
||||
.set(ThreadPoolWorkerId(index), worker.len() as i64);
|
||||
|
||||
// receiver is closed, cancel the task
|
||||
if !job.response.is_closed() {
|
||||
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
|
||||
|
||||
const P: f64 = 2000.0;
|
||||
// probability decreases as rate increases.
|
||||
// lower probability, higher chance of being skipped
|
||||
//
|
||||
// estimates (rate in terms of 4096 rounds):
|
||||
// rate = 0 => probability = 100%
|
||||
// rate = 10 => probability = 71.3%
|
||||
// rate = 50 => probability = 62.1%
|
||||
// rate = 500 => probability = 52.3%
|
||||
// rate = 1021 => probability = 49.8%
|
||||
//
|
||||
// My expectation is that the pool queue will only begin backing up at ~1000rps
|
||||
// in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
|
||||
// are in requests per second.
|
||||
let probability = P.ln() / (P + rate as f64).ln();
|
||||
if pool.queue.len() > 32 || rng.gen_bool(probability) {
|
||||
pool.metrics
|
||||
.worker_task_turns_total
|
||||
.inc(ThreadPoolWorkerId(index));
|
||||
|
||||
match job.pbkdf2.turn() {
|
||||
std::task::Poll::Ready(result) => {
|
||||
let _ = job.response.send(result);
|
||||
}
|
||||
std::task::Poll::Pending => worker.push(job),
|
||||
}
|
||||
} else {
|
||||
pool.metrics
|
||||
.worker_task_skips_total
|
||||
.inc(ThreadPoolWorkerId(index));
|
||||
|
||||
// skip for now
|
||||
worker.push(job)
|
||||
}
|
||||
}
|
||||
|
||||
// if we get stuck with a few long lived jobs in the queue
|
||||
// it's better to try and steal from the queue too for fairness
|
||||
if i % STEAL_INTERVAL == 0 {
|
||||
let _ = pool.queue.steal_batch(&worker);
|
||||
}
|
||||
|
||||
if i % SKETCH_RESET_INTERVAL == 0 {
|
||||
sketch.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct JobSpec {
|
||||
response: oneshot::Sender<[u8; 32]>,
|
||||
pbkdf2: Pbkdf2,
|
||||
endpoint: EndpointIdInt,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::EndpointId;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn hash_is_correct() {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
|
||||
let salt = [0x55; 32];
|
||||
let actual = pool
|
||||
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let expected = [
|
||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||
];
|
||||
assert_eq!(actual, expected)
|
||||
}
|
||||
}
|
||||
390
proxy/core/src/serverless.rs
Normal file
390
proxy/core/src/serverless.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! Routers for our serverless APIs
|
||||
//!
|
||||
//! Handles both SQL over HTTP and SQL over Websockets.
|
||||
|
||||
mod backend;
|
||||
pub mod cancel_set;
|
||||
mod conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
use atomic_take::AtomicTake;
|
||||
use bytes::Bytes;
|
||||
pub use conn_pool::GlobalConnPoolOptions;
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::{server::TlsStream, TlsAcceptor};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::serverless::http_util::{api_error_into_response, json_response};
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
pub const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let conn_pool = Arc::clone(&conn_pool);
|
||||
tokio::spawn(async move {
|
||||
conn_pool.gc_worker(StdRng::from_entropy()).await;
|
||||
});
|
||||
}
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
async move {
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
});
|
||||
|
||||
let tls_config = match config.tls_config.as_ref() {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
warn!("TLS config is missing, WebSocket Secure server will not be started");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
|
||||
// prefer http2, but support http/1.1
|
||||
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
tracing::error!("could not set nodelay: {e}");
|
||||
continue;
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
|
||||
|
||||
let n_connections = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.sample(crate::metrics::Protocol::Http);
|
||||
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
|
||||
if n_connections > config.http_config.client_conn_threshold {
|
||||
tracing::trace!("attempting to cancel a random connection");
|
||||
if let Some(token) = config.http_config.cancel_set.take() {
|
||||
tracing::debug!("cancelling a random connection");
|
||||
token.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
let conn_token = cancellation_token.child_token();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let backend = backend.clone();
|
||||
let connections2 = connections.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
let _gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Http);
|
||||
|
||||
let startup_result = Box::pin(connection_startup(
|
||||
config,
|
||||
tls_acceptor,
|
||||
session_id,
|
||||
conn,
|
||||
peer_addr,
|
||||
))
|
||||
.await;
|
||||
let Some((conn, peer_addr)) = startup_result else {
|
||||
return;
|
||||
};
|
||||
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
backend,
|
||||
connections2,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conn,
|
||||
peer_addr,
|
||||
session_id,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
.instrument(http_conn_span),
|
||||
);
|
||||
}
|
||||
|
||||
connections.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handles the TCP startup lifecycle.
|
||||
/// 1. Parses PROXY protocol V2
|
||||
/// 2. Handles TLS handshake
|
||||
async fn connection_startup(
|
||||
config: &ProxyConfig,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
session_id: uuid::Uuid,
|
||||
conn: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
|
||||
// handle PROXY protocol
|
||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = peer.unwrap_or(peer_addr).ip();
|
||||
let has_private_peer_addr = match peer_addr {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
_ => false,
|
||||
};
|
||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||
|
||||
// try upgrade to TLS, but with a timeout.
|
||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed
|
||||
Ok(Err(e)) => {
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
// The handshake timed out
|
||||
Err(e) => {
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some((conn, peer_addr))
|
||||
}
|
||||
|
||||
/// Handles HTTP connection
|
||||
/// 1. With graceful shutdowns
|
||||
/// 2. With graceful request cancellation with connection failure
|
||||
/// 3. With websocket upgrade support.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
conn: TlsStream<ChainRW<TcpStream>>,
|
||||
peer_addr: IpAddr,
|
||||
session_id: uuid::Uuid,
|
||||
) {
|
||||
let session_id = AtomicTake::new(session_id);
|
||||
|
||||
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
|
||||
// First HTTP request shares the same session ID
|
||||
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
||||
|
||||
// Cancel the current inflight HTTP request if the requets stream is closed.
|
||||
// This is slightly different to `_cancel_connection` in that
|
||||
// h2 can cancel individual requests with a `RST_STREAM`.
|
||||
let http_request_token = http_cancellation_token.child_token();
|
||||
let cancel_request = http_request_token.clone().drop_guard();
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
peer_addr,
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
);
|
||||
async move {
|
||||
let res = handler.await;
|
||||
cancel_request.disarm();
|
||||
res
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// On cancellation, trigger the HTTP connection handler to shut down.
|
||||
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
|
||||
Either::Left((_cancelled, mut conn)) => {
|
||||
tracing::debug!(%peer_addr, "cancelling connection");
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
Either::Right((res, _)) => res,
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: hyper1::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if framed_websockets::upgrade::is_upgrade_request(&request) {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span();
|
||||
info!(parent: &span, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
ctx,
|
||||
websocket,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("error in websocket connection: {e:#}");
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Http,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span();
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Full::new(Bytes::new()))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
}
|
||||
}
|
||||
257
proxy/core/src/serverless/backend.rs
Normal file
257
proxy/core/src/serverless/backend.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||
compute,
|
||||
config::{AuthenticationConfig, ProxyConfig},
|
||||
console::{
|
||||
errors::{GetAuthInfoError, WakeComputeError},
|
||||
locks::ApiLocks,
|
||||
provider::ApiLockError,
|
||||
CachedNodeInfo,
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
error::{ErrorKind, ReportableError, UserFacingError},
|
||||
intern::EndpointIdInt,
|
||||
proxy::{
|
||||
connect_compute::ConnectMechanism,
|
||||
retry::{CouldRetry, ShouldRetryWakeCompute},
|
||||
},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
Host,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
|
||||
pub struct PoolingBackend {
|
||||
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub config: &'static ProxyConfig,
|
||||
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
impl PoolingBackend {
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
.check(conn_info.user_info.endpoint.clone().into(), 1)
|
||||
{
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => backend.get_role_secret(ctx).await?,
|
||||
};
|
||||
|
||||
let secret = match cached_secret.value.clone() {
|
||||
Some(secret) => self.config.authentication_config.check_rate_limit(
|
||||
ctx,
|
||||
config,
|
||||
secret,
|
||||
&user_info.endpoint,
|
||||
true,
|
||||
)?,
|
||||
None => {
|
||||
// If we don't have an authentication secret, for the http flow we can just return an error.
|
||||
info!("authentication info not found");
|
||||
return Err(AuthError::auth_failed(&*user_info.user));
|
||||
}
|
||||
};
|
||||
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&config.thread_pool,
|
||||
ep,
|
||||
&conn_info.password,
|
||||
secret,
|
||||
)
|
||||
.await?;
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => {
|
||||
info!("user successfully authenticated");
|
||||
Ok(key)
|
||||
}
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||
}
|
||||
};
|
||||
res.map(|key| ComputeCredentials {
|
||||
info: user_info,
|
||||
keys: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Wake up the destination if needed. Code here is a bit involved because
|
||||
// we reuse the code from the usual proxy and we need to prepare few structures
|
||||
// that this code expects.
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub async fn connect_to_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
let maybe_client = if !force_new {
|
||||
info!("pool: looking for an existing connection");
|
||||
self.pool.get(ctx, &conn_info)?
|
||||
} else {
|
||||
info!("pool: pool is disabled");
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(client) = maybe_client {
|
||||
return Ok(client);
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to compute")]
|
||||
ConnectionError(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
#[error("user not authenticated")]
|
||||
AuthError(#[from] AuthError),
|
||||
#[error("wake_compute returned error")]
|
||||
WakeCompute(#[from] WakeComputeError),
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
impl ReportableError for HttpConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for HttpConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
HttpConnError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for HttpConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
HttpConnError::AuthError(_) => false,
|
||||
HttpConnError::WakeCompute(_) => false,
|
||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for HttpConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
|
||||
// we never checked cache validity
|
||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TokioMechanism {
|
||||
type Connection = Client<tokio_postgres::Client>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let mut config = (*node_info.config).clone();
|
||||
let config = config
|
||||
.user(&self.conn_info.user_info.user)
|
||||
.password(&*self.conn_info.password)
|
||||
.dbname(&self.conn_info.dbname)
|
||||
.connect_timeout(timeout);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
Ok(poll_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
self.conn_info.clone(),
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
102
proxy/core/src/serverless/cancel_set.rs
Normal file
102
proxy/core/src/serverless/cancel_set.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! A set for cancelling random http connections
|
||||
|
||||
use std::{
|
||||
hash::{BuildHasher, BuildHasherDefault},
|
||||
num::NonZeroUsize,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use parking_lot::Mutex;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rustc_hash::FxHasher;
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
type Hasher = BuildHasherDefault<FxHasher>;
|
||||
|
||||
pub struct CancelSet {
|
||||
shards: Box<[Mutex<CancelShard>]>,
|
||||
// keyed by random uuid, fxhasher is fine
|
||||
hasher: Hasher,
|
||||
}
|
||||
|
||||
pub struct CancelShard {
|
||||
tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
|
||||
}
|
||||
|
||||
impl CancelSet {
|
||||
pub fn new(shards: usize) -> Self {
|
||||
CancelSet {
|
||||
shards: (0..shards)
|
||||
.map(|_| {
|
||||
Mutex::new(CancelShard {
|
||||
tokens: IndexMap::with_hasher(Hasher::default()),
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
hasher: Hasher::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take(&self) -> Option<CancellationToken> {
|
||||
for _ in 0..4 {
|
||||
if let Some(token) = self.take_raw(thread_rng().gen()) {
|
||||
return Some(token);
|
||||
}
|
||||
tracing::trace!("failed to get cancel token");
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
|
||||
NonZeroUsize::new(self.shards.len())
|
||||
.and_then(|len| self.shards[rng % len].lock().take(rng / len))
|
||||
}
|
||||
|
||||
pub fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
|
||||
let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
|
||||
let hash = self.hasher.hash_one(id) as usize;
|
||||
let shard = &self.shards[hash % len];
|
||||
shard.lock().insert(id, token);
|
||||
shard
|
||||
});
|
||||
CancelGuard { shard, id }
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelShard {
|
||||
fn take(&mut self, rng: usize) -> Option<CancellationToken> {
|
||||
NonZeroUsize::new(self.tokens.len()).and_then(|len| {
|
||||
// 10 second grace period so we don't cancel new connections
|
||||
if self.tokens.get_index(rng % len)?.1 .0.elapsed() < Duration::from_secs(10) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (_key, (_insert, token)) = self.tokens.swap_remove_index(rng % len)?;
|
||||
Some(token)
|
||||
})
|
||||
}
|
||||
|
||||
fn remove(&mut self, id: uuid::Uuid) {
|
||||
self.tokens.swap_remove(&id);
|
||||
}
|
||||
|
||||
fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
|
||||
self.tokens.insert(id, (Instant::now(), token));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CancelGuard<'a> {
|
||||
shard: Option<&'a Mutex<CancelShard>>,
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
impl Drop for CancelGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(shard) = self.shard {
|
||||
shard.lock().remove(self.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
865
proxy/core/src/serverless/conn_pool.rs
Normal file
865
proxy/core/src/serverless/conn_pool.rs
Normal file
@@ -0,0 +1,865 @@
|
||||
use dashmap::DashMap;
|
||||
use futures::{future::poll_fn, Future};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use smallvec::SmallVec;
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
task::{ready, Poll},
|
||||
};
|
||||
use std::{
|
||||
ops::Deref,
|
||||
sync::atomic::{self, AtomicUsize},
|
||||
};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
|
||||
};
|
||||
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnInfo {
|
||||
pub user_info: ComputeUserInfo,
|
||||
pub dbname: DbName,
|
||||
pub password: SmallVec<[u8; 16]>,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
// hm, change to hasher to avoid cloning?
|
||||
pub fn db_and_user(&self) -> (DbName, RoleName) {
|
||||
(self.dbname.clone(), self.user_info.user.clone())
|
||||
}
|
||||
|
||||
pub fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
|
||||
// We don't want to cache http connections for ephemeral endpoints.
|
||||
if self.user_info.options.is_ephemeral() {
|
||||
None
|
||||
} else {
|
||||
Some(self.user_info.endpoint_cache_key())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnInfo {
|
||||
// use custom display to avoid logging password
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}@{}/{}?{}",
|
||||
self.user_info.user,
|
||||
self.user_info.endpoint,
|
||||
self.dbname,
|
||||
self.user_info.options.get_cache_key("")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
conn: ClientInner<C>,
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub struct EndpointConnPool<C: ClientInnerExt> {
|
||||
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
|
||||
total_conns: usize,
|
||||
max_conns: usize,
|
||||
_guard: HttpEndpointPoolsGuard<'static>,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
global_pool_size_max_conns: usize,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
|
||||
let Self {
|
||||
pools,
|
||||
total_conns,
|
||||
global_connections_count,
|
||||
..
|
||||
} = self;
|
||||
pools.get_mut(&db_user).and_then(|pool_entries| {
|
||||
pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
|
||||
})
|
||||
}
|
||||
|
||||
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
pools,
|
||||
total_conns,
|
||||
global_connections_count,
|
||||
..
|
||||
} = self;
|
||||
if let Some(pool) = pools.get_mut(&db_user) {
|
||||
let old_len = pool.conns.len();
|
||||
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
|
||||
let new_len = pool.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
*total_conns -= removed;
|
||||
removed > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
|
||||
let conn_id = client.conn_id;
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool
|
||||
.read()
|
||||
.global_connections_count
|
||||
.load(atomic::Ordering::Relaxed)
|
||||
>= global_max_conn
|
||||
{
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
// return connection to the pool
|
||||
let mut returned = false;
|
||||
let mut per_db_size = 0;
|
||||
let total_conns = {
|
||||
let mut pool = pool.write();
|
||||
|
||||
if pool.total_conns < pool.max_conns {
|
||||
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
pool.global_connections_count
|
||||
.fetch_add(1, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.inc();
|
||||
}
|
||||
|
||||
pool.total_conns
|
||||
};
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
fn drop(&mut self) {
|
||||
if self.total_conns > 0 {
|
||||
self.global_connections_count
|
||||
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.total_conns as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DbUserConnPool<C: ClientInnerExt> {
|
||||
conns: Vec<ConnPoolEntry<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self { conns: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
|
||||
let old_len = self.conns.len();
|
||||
|
||||
self.conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = self.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
*conns -= removed;
|
||||
removed
|
||||
}
|
||||
|
||||
fn get_conn_entry(
|
||||
&mut self,
|
||||
conns: &mut usize,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
) -> Option<ConnPoolEntry<C>> {
|
||||
let mut removed = self.clear_closed_clients(conns);
|
||||
let conn = self.conns.pop();
|
||||
if conn.is_some() {
|
||||
*conns -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
conn
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GlobalConnPool<C: ClientInnerExt> {
|
||||
// endpoint -> per-endpoint connection pool
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
|
||||
|
||||
/// Number of endpoint-connection pools
|
||||
///
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
/// It's only used for diagnostics.
|
||||
global_pool_size: AtomicUsize,
|
||||
|
||||
/// Total number of connections in the pool
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GlobalConnPoolOptions {
|
||||
// Maximum number of connections per one endpoint.
|
||||
// Can mix different (dbname, username) connections.
|
||||
// When running out of free slots for a particular endpoint,
|
||||
// falls back to opening a new connection for each request.
|
||||
pub max_conns_per_endpoint: usize,
|
||||
|
||||
pub gc_epoch: Duration,
|
||||
|
||||
pub pool_shards: usize,
|
||||
|
||||
pub idle_timeout: Duration,
|
||||
|
||||
pub opt_in: bool,
|
||||
|
||||
// Total number of connections in the pool.
|
||||
pub max_total_conns: usize,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
let shards = config.pool_options.pool_shards;
|
||||
Arc::new(Self {
|
||||
global_pool: DashMap::with_shard_amount(shards),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
config,
|
||||
global_connections_count: Arc::new(AtomicUsize::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn get_global_connections_count(&self) -> usize {
|
||||
self.global_connections_count
|
||||
.load(atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_idle_timeout(&self) -> Duration {
|
||||
self.config.pool_options.idle_timeout
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
// drops all strong references to endpoint-pools
|
||||
self.global_pool.clear();
|
||||
}
|
||||
|
||||
pub async fn gc_worker(&self, mut rng: impl Rng) {
|
||||
let epoch = self.config.pool_options.gc_epoch;
|
||||
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let shard = rng.gen_range(0..self.global_pool.shards().len());
|
||||
self.gc(shard);
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self, shard: usize) {
|
||||
debug!(shard, "pool: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut shard = self.global_pool.shards()[shard].write();
|
||||
|
||||
let timer = Metrics::get()
|
||||
.proxy
|
||||
.http_pool_reclaimation_lag_seconds
|
||||
.start_timer();
|
||||
let current_len = shard.len();
|
||||
let mut clients_removed = 0;
|
||||
shard.retain(|endpoint, x| {
|
||||
// if the current endpoint pool is unique (no other strong or weak references)
|
||||
// then it is currently not in use by any connections.
|
||||
if let Some(pool) = Arc::get_mut(x.get_mut()) {
|
||||
let EndpointConnPool {
|
||||
pools, total_conns, ..
|
||||
} = pool.get_mut();
|
||||
|
||||
// ensure that closed clients are removed
|
||||
pools.iter_mut().for_each(|(_, db_pool)| {
|
||||
clients_removed += db_pool.clear_closed_clients(total_conns);
|
||||
});
|
||||
|
||||
// we only remove this pool if it has no active connections
|
||||
if *total_conns == 0 {
|
||||
info!("pool: discarding pool for endpoint {endpoint}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let new_len = shard.len();
|
||||
drop(shard);
|
||||
timer.observe();
|
||||
|
||||
// Do logging outside of the lock.
|
||||
if clients_removed > 0 {
|
||||
let size = self
|
||||
.global_connections_count
|
||||
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
|
||||
- clients_removed;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(clients_removed as i64);
|
||||
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
|
||||
}
|
||||
let removed = current_len - new_len;
|
||||
|
||||
if removed > 0 {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_sub(removed, atomic::Ordering::Relaxed)
|
||||
- removed;
|
||||
info!("pool: performed global pool gc. size now {global_pool_size}");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
let Some(endpoint) = conn_info.endpoint_cache_key() else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
if let Some(entry) = endpoint_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn)
|
||||
}
|
||||
let endpoint_pool = Arc::downgrade(&endpoint_pool);
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
if client.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
} else {
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
client.session.send(ctx.session_id())?;
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
self: &Arc<Self>,
|
||||
endpoint: &EndpointCacheKey,
|
||||
) -> Arc<RwLock<EndpointConnPool<C>>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
}
|
||||
|
||||
// slow path
|
||||
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
max_conns: self.config.pool_options.max_conns_per_endpoint,
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count: self.global_connections_count.clone(),
|
||||
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
let mut created = false;
|
||||
let pool = self
|
||||
.global_pool
|
||||
.entry(endpoint.clone())
|
||||
.or_insert_with(|| {
|
||||
created = true;
|
||||
new_pool
|
||||
})
|
||||
.clone();
|
||||
|
||||
// log new global pool size
|
||||
if created {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||
+ 1;
|
||||
info!(
|
||||
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||
);
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_client<C: ClientInnerExt>(
|
||||
global_pool: Arc<GlobalConnPool<C>>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
let pool = match conn_info.endpoint_cache_key() {
|
||||
Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
|
||||
None => Weak::new(),
|
||||
};
|
||||
let pool_clone = pool.clone();
|
||||
|
||||
let db_user = conn_info.db_and_user();
|
||||
let idle = global_pool.get_idle_timeout();
|
||||
let cancel = CancellationToken::new();
|
||||
let cancelled = cancel.clone().cancelled_owned();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let mut idle_timeout = pin!(tokio::time::sleep(idle));
|
||||
let mut cancelled = pin!(cancelled);
|
||||
|
||||
poll_fn(move |cx| {
|
||||
if cancelled.as_mut().poll(cx).is_ready() {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
|
||||
match rx.has_changed() {
|
||||
Ok(true) => {
|
||||
session_id = *rx.borrow_and_update();
|
||||
info!(%session_id, "changed session");
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
}
|
||||
Err(_) => {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// 5 minute idle connection timeout
|
||||
if idle_timeout.as_mut().poll(cx).is_ready() {
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
info!("connection idle");
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
// remove client from pool - should close the connection if it's idle.
|
||||
// does nothing if the client is currently checked-out and in-use
|
||||
if pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("idle connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let message = ready!(connection.poll_message(cx));
|
||||
|
||||
match message {
|
||||
Some(Ok(AsyncMessage::Notice(notice))) => {
|
||||
info!(%session_id, "notice: {}", notice);
|
||||
}
|
||||
Some(Ok(AsyncMessage::Notification(notif))) => {
|
||||
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
warn!(%session_id, "unknown message");
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!(%session_id, "connection error: {}", e);
|
||||
break
|
||||
}
|
||||
None => {
|
||||
info!("connection closed");
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(())
|
||||
}).await;
|
||||
|
||||
}
|
||||
.instrument(span));
|
||||
let inner = ClientInner {
|
||||
inner: client,
|
||||
session: tx,
|
||||
cancel,
|
||||
aux,
|
||||
conn_id,
|
||||
};
|
||||
Client::new(inner, conn_info, pool_clone)
|
||||
}
|
||||
|
||||
struct ClientInner<C: ClientInnerExt> {
|
||||
inner: C,
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInner<C> {
|
||||
fn drop(&mut self) {
|
||||
// on client drop, tell the conn to shut down
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ClientInnerExt: Sync + Send + 'static {
|
||||
fn is_closed(&self) -> bool;
|
||||
fn get_process_id(&self) -> i32;
|
||||
}
|
||||
|
||||
impl ClientInnerExt for tokio_postgres::Client {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.is_closed()
|
||||
}
|
||||
fn get_process_id(&self) -> i32 {
|
||||
self.get_process_id()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInner<C> {
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.inner.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Client<C> {
|
||||
pub fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Client<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInner<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
pub struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Client<C> {
|
||||
pub(self) fn new(
|
||||
inner: ClientInner<C>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<RwLock<EndpointConnPool<C>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
span: Span::current(),
|
||||
conn_info,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { pool, conn_info })
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
|
||||
}
|
||||
}
|
||||
pub fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Deref for Client<C> {
|
||||
type Target = C;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Client<C> {
|
||||
fn do_drop(&mut self) -> Option<impl FnOnce()> {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("client inner should not be removed");
|
||||
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
|
||||
let current_span = self.span.clone();
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for Client<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
tokio::task::spawn_blocking(drop);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{mem, sync::atomic::AtomicBool};
|
||||
|
||||
use crate::{serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId};
|
||||
|
||||
use super::*;
|
||||
|
||||
struct MockClient(Arc<AtomicBool>);
|
||||
impl MockClient {
|
||||
fn new(is_closed: bool) -> Self {
|
||||
MockClient(Arc::new(is_closed.into()))
|
||||
}
|
||||
}
|
||||
impl ClientInnerExt for MockClient {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.0.load(atomic::Ordering::Relaxed)
|
||||
}
|
||||
fn get_process_id(&self) -> i32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn create_inner() -> ClientInner<MockClient> {
|
||||
create_inner_with(MockClient::new(false))
|
||||
}
|
||||
|
||||
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
|
||||
ClientInner {
|
||||
inner: client,
|
||||
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
|
||||
cancel: CancellationToken::new(),
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
conn_id: uuid::Uuid::new_v4(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pool() {
|
||||
let _ = env_logger::try_init();
|
||||
let config = Box::leak(Box::new(crate::config::HttpConfig {
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: 2,
|
||||
gc_epoch: Duration::from_secs(1),
|
||||
pool_shards: 2,
|
||||
idle_timeout: Duration::from_secs(1),
|
||||
opt_in: false,
|
||||
max_total_conns: 3,
|
||||
},
|
||||
cancel_set: CancelSet::new(0),
|
||||
client_conn_threshold: u64::MAX,
|
||||
}));
|
||||
let pool = GlobalConnPool::new(config);
|
||||
let conn_info = ConnInfo {
|
||||
user_info: ComputeUserInfo {
|
||||
user: "user".into(),
|
||||
endpoint: "endpoint".into(),
|
||||
options: Default::default(),
|
||||
},
|
||||
dbname: "dbname".into(),
|
||||
password: "password".as_bytes().into(),
|
||||
};
|
||||
let ep_pool = Arc::downgrade(
|
||||
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
|
||||
);
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
client.inner().1.discard();
|
||||
// Discard should not add the connection from the pool.
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
assert_eq!(1, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut closed_client = Client::new(
|
||||
create_inner_with(MockClient::new(true)),
|
||||
conn_info.clone(),
|
||||
ep_pool.clone(),
|
||||
);
|
||||
closed_client.do_drop().unwrap()();
|
||||
mem::forget(closed_client); // drop the client
|
||||
// The closed client shouldn't be added to the pool.
|
||||
assert_eq!(1, pool.get_global_connections_count());
|
||||
}
|
||||
let is_closed: Arc<AtomicBool> = Arc::new(false.into());
|
||||
{
|
||||
let mut client = Client::new(
|
||||
create_inner_with(MockClient(is_closed.clone())),
|
||||
conn_info.clone(),
|
||||
ep_pool.clone(),
|
||||
);
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
|
||||
// The client should be added to the pool.
|
||||
assert_eq!(2, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info, ep_pool);
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
|
||||
// The client shouldn't be added to the pool. Because the ep-pool is full.
|
||||
assert_eq!(2, pool.get_global_connections_count());
|
||||
}
|
||||
|
||||
let conn_info = ConnInfo {
|
||||
user_info: ComputeUserInfo {
|
||||
user: "user".into(),
|
||||
endpoint: "endpoint-2".into(),
|
||||
options: Default::default(),
|
||||
},
|
||||
dbname: "dbname".into(),
|
||||
password: "password".as_bytes().into(),
|
||||
};
|
||||
let ep_pool = Arc::downgrade(
|
||||
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
|
||||
);
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
assert_eq!(3, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
|
||||
// The client shouldn't be added to the pool. Because the global pool is full.
|
||||
assert_eq!(3, pool.get_global_connections_count());
|
||||
}
|
||||
|
||||
is_closed.store(true, atomic::Ordering::Relaxed);
|
||||
// Do gc for all shards.
|
||||
pool.gc(0);
|
||||
pool.gc(1);
|
||||
// Closed client should be removed from the pool.
|
||||
assert_eq!(2, pool.get_global_connections_count());
|
||||
}
|
||||
}
|
||||
96
proxy/core/src/serverless/http_util.rs
Normal file
96
proxy/core/src/serverless/http_util.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
|
||||
//! Will merge back in at some point in the future.
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
StatusCode::BAD_REQUEST,
|
||||
),
|
||||
ApiError::Forbidden(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
|
||||
}
|
||||
ApiError::Unauthorized(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
ApiError::NotFound(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
|
||||
}
|
||||
ApiError::Conflict(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
|
||||
}
|
||||
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
|
||||
this.to_string(),
|
||||
StatusCode::PRECONDITION_FAILED,
|
||||
),
|
||||
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
|
||||
"Shutting down".to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::REQUEST_TIMEOUT,
|
||||
),
|
||||
ApiError::Cancelled => HttpErrorBody::response_from_msg_and_status(
|
||||
this.to_string(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody`]
|
||||
#[derive(Serialize)]
|
||||
struct HttpErrorBody {
|
||||
pub msg: String,
|
||||
}
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::json::json_response`]
|
||||
pub fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
462
proxy/core/src/serverless/json.rs
Normal file
462
proxy/core/src/serverless/json.rs
Normal file
@@ -0,0 +1,462 @@
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use tokio_postgres::types::Kind;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::Row;
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
// as parameters.
|
||||
//
|
||||
pub fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
|
||||
json.iter().map(json_value_to_pg_text).collect()
|
||||
}
|
||||
|
||||
fn json_value_to_pg_text(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
|
||||
// convert to text with escaping
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
|
||||
|
||||
// avoid escaping here, as we pass this as a parameter
|
||||
Value::String(s) => Some(s.to_string()),
|
||||
|
||||
// special care for arrays
|
||||
Value::Array(_) => json_array_to_pg_array(value),
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
|
||||
// in the array we need to escape the strings. Postgres is okay with arrays of form
|
||||
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
|
||||
// it for Postgres to check.
|
||||
//
|
||||
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
|
||||
//
|
||||
fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
|
||||
// convert to text with escaping
|
||||
// here string needs to be escaped, as it is part of the array
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
|
||||
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
|
||||
|
||||
// recurse into array
|
||||
Value::Array(arr) => {
|
||||
let vals = arr
|
||||
.iter()
|
||||
.map(json_array_to_pg_array)
|
||||
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
Some(format!("{{{}}}", vals))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum JsonConversionError {
|
||||
#[error("internal error compute returned invalid data: {0}")]
|
||||
AsTextError(tokio_postgres::Error),
|
||||
#[error("parse int error: {0}")]
|
||||
ParseIntError(#[from] std::num::ParseIntError),
|
||||
#[error("parse float error: {0}")]
|
||||
ParseFloatError(#[from] std::num::ParseFloatError),
|
||||
#[error("parse json error: {0}")]
|
||||
ParseJsonError(#[from] serde_json::Error),
|
||||
#[error("unbalanced array")]
|
||||
UnbalancedArray,
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres row with text-encoded values to JSON object
|
||||
//
|
||||
pub fn pg_text_row_to_json(
|
||||
row: &Row,
|
||||
columns: &[Type],
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> Result<Value, JsonConversionError> {
|
||||
let iter = row
|
||||
.columns()
|
||||
.iter()
|
||||
.zip(columns)
|
||||
.enumerate()
|
||||
.map(|(i, (column, typ))| {
|
||||
let name = column.name();
|
||||
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
|
||||
let json_value = if raw_output {
|
||||
match pg_value {
|
||||
Some(v) => Value::String(v.to_string()),
|
||||
None => Value::Null,
|
||||
}
|
||||
} else {
|
||||
pg_text_to_json(pg_value, typ)?
|
||||
};
|
||||
Ok((name.to_string(), json_value))
|
||||
});
|
||||
|
||||
if array_mode {
|
||||
// drop keys and aggregate into array
|
||||
let arr = iter
|
||||
.map(|r| r.map(|(_key, val)| val))
|
||||
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres text-encoded value to JSON value
|
||||
//
|
||||
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
|
||||
if let Some(val) = pg_value {
|
||||
if let Kind::Array(elem_type) = pg_type.kind() {
|
||||
return pg_array_parse(val, elem_type);
|
||||
}
|
||||
|
||||
match *pg_type {
|
||||
Type::BOOL => Ok(Value::Bool(val == "t")),
|
||||
Type::INT2 | Type::INT4 => {
|
||||
let val = val.parse::<i32>()?;
|
||||
Ok(Value::Number(serde_json::Number::from(val)))
|
||||
}
|
||||
Type::FLOAT4 | Type::FLOAT8 => {
|
||||
let fval = val.parse::<f64>()?;
|
||||
let num = serde_json::Number::from_f64(fval);
|
||||
if let Some(num) = num {
|
||||
Ok(Value::Number(num))
|
||||
} else {
|
||||
// Pass Nan, Inf, -Inf as strings
|
||||
// JS JSON.stringify() does converts them to null, but we
|
||||
// want to preserve them, so we pass them as strings
|
||||
Ok(Value::String(val.to_string()))
|
||||
}
|
||||
}
|
||||
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
|
||||
_ => Ok(Value::String(val.to_string())),
|
||||
}
|
||||
} else {
|
||||
Ok(Value::Null)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Parse postgres array into JSON array.
|
||||
//
|
||||
// This is a bit involved because we need to handle nested arrays and quoted
|
||||
// values. Unlike postgres we don't check that all nested arrays have the same
|
||||
// dimensions, we just return them as is.
|
||||
//
|
||||
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
|
||||
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
|
||||
}
|
||||
|
||||
fn _pg_array_parse(
|
||||
pg_array: &str,
|
||||
elem_type: &Type,
|
||||
nested: bool,
|
||||
) -> Result<(Value, usize), JsonConversionError> {
|
||||
let mut pg_array_chr = pg_array.char_indices();
|
||||
let mut level = 0;
|
||||
let mut quote = false;
|
||||
let mut entries: Vec<Value> = Vec::new();
|
||||
let mut entry = String::new();
|
||||
|
||||
// skip bounds decoration
|
||||
if let Some('[') = pg_array.chars().next() {
|
||||
for (_, c) in pg_array_chr.by_ref() {
|
||||
if c == '=' {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn push_checked(
|
||||
entry: &mut String,
|
||||
entries: &mut Vec<Value>,
|
||||
elem_type: &Type,
|
||||
) -> Result<(), JsonConversionError> {
|
||||
if !entry.is_empty() {
|
||||
// While in usual postgres response we get nulls as None and everything else
|
||||
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
|
||||
// string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
|
||||
// here while we have quotation info and convert them to None.
|
||||
if entry == "NULL" {
|
||||
entries.push(pg_text_to_json(None, elem_type)?);
|
||||
} else {
|
||||
entries.push(pg_text_to_json(Some(entry), elem_type)?);
|
||||
}
|
||||
entry.clear();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
while let Some((mut i, mut c)) = pg_array_chr.next() {
|
||||
let mut escaped = false;
|
||||
|
||||
if c == '\\' {
|
||||
escaped = true;
|
||||
(i, c) = pg_array_chr.next().unwrap();
|
||||
}
|
||||
|
||||
match c {
|
||||
'{' if !quote => {
|
||||
level += 1;
|
||||
if level > 1 {
|
||||
let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
|
||||
entries.push(res);
|
||||
for _ in 0..off - 1 {
|
||||
pg_array_chr.next();
|
||||
}
|
||||
}
|
||||
}
|
||||
'}' if !quote => {
|
||||
level -= 1;
|
||||
if level == 0 {
|
||||
push_checked(&mut entry, &mut entries, elem_type)?;
|
||||
if nested {
|
||||
return Ok((Value::Array(entries), i));
|
||||
}
|
||||
}
|
||||
}
|
||||
'"' if !escaped => {
|
||||
if quote {
|
||||
// end of quoted string, so push it manually without any checks
|
||||
// for emptiness or nulls
|
||||
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
|
||||
entry.clear();
|
||||
}
|
||||
quote = !quote;
|
||||
}
|
||||
',' if !quote => {
|
||||
push_checked(&mut entry, &mut entries, elem_type)?;
|
||||
}
|
||||
_ => {
|
||||
entry.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if level != 0 {
|
||||
return Err(JsonConversionError::UnbalancedArray);
|
||||
}
|
||||
|
||||
Ok((Value::Array(entries), 0))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_atomic_types_to_pg_params() {
|
||||
let json = vec![Value::Bool(true), Value::Bool(false)];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some("true".to_owned()), Some("false".to_owned())]
|
||||
);
|
||||
|
||||
let json = vec![Value::Number(serde_json::Number::from(42))];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
assert_eq!(pg_params, vec![Some("42".to_owned())]);
|
||||
|
||||
let json = vec![Value::String("foo\"".to_string())];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
|
||||
|
||||
let json = vec![Value::Null];
|
||||
let pg_params = json_to_pg_text(json);
|
||||
assert_eq!(pg_params, vec![None]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_array_to_pg_array() {
|
||||
// atoms and escaping
|
||||
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(
|
||||
"{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
|
||||
)]
|
||||
);
|
||||
|
||||
// nested arrays
|
||||
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(
|
||||
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
|
||||
)]
|
||||
);
|
||||
// array of objects
|
||||
let json = r#"[{"foo": 1},{"bar": 2}]"#;
|
||||
let json: Value = serde_json::from_str(json).unwrap();
|
||||
let pg_params = json_to_pg_text(vec![json]);
|
||||
assert_eq!(
|
||||
pg_params,
|
||||
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_types_parse() {
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
|
||||
json!("foo")
|
||||
);
|
||||
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
|
||||
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
|
||||
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
|
||||
json!("42")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
|
||||
json!(42.42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
|
||||
json!(42.42)
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
|
||||
json!("NaN")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
|
||||
json!("Infinity")
|
||||
);
|
||||
assert_eq!(
|
||||
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
|
||||
json!("-Infinity")
|
||||
);
|
||||
|
||||
let json: Value =
|
||||
serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
pg_text_to_json(
|
||||
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
|
||||
&Type::JSONB
|
||||
)
|
||||
.unwrap(),
|
||||
json
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_array_parse_text() {
|
||||
fn pt(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
|
||||
}
|
||||
assert_eq!(
|
||||
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
|
||||
json!(["aa\"\\,a", "cha", "bbbb"])
|
||||
);
|
||||
assert_eq!(
|
||||
pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
|
||||
json!([["foo", "bar"], ["bee", "bop"]])
|
||||
);
|
||||
assert_eq!(
|
||||
pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
|
||||
json!([[[["foo", null, "bop", "bup"]]]])
|
||||
);
|
||||
assert_eq!(
|
||||
pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
|
||||
json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_array_parse_bool() {
|
||||
fn pb(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
|
||||
}
|
||||
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
|
||||
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
|
||||
assert_eq!(
|
||||
pb(r#"{{t,f},{f,t}}"#),
|
||||
json!([[true, false], [false, true]])
|
||||
);
|
||||
assert_eq!(
|
||||
pb(r#"{{t,NULL},{NULL,f}}"#),
|
||||
json!([[true, null], [null, false]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_array_parse_numbers() {
|
||||
fn pn(pg_arr: &str, ty: &Type) -> Value {
|
||||
pg_array_parse(pg_arr, ty).unwrap()
|
||||
}
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
|
||||
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
|
||||
assert_eq!(
|
||||
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
|
||||
json!([1.1, 2.2, 3.3])
|
||||
);
|
||||
assert_eq!(
|
||||
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
|
||||
json!([1.1, 2.2, 3.3])
|
||||
);
|
||||
assert_eq!(
|
||||
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
|
||||
json!(["NaN", "Infinity", "-Infinity"])
|
||||
);
|
||||
assert_eq!(
|
||||
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
|
||||
json!(["NaN", "Infinity", "-Infinity"])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_array_with_decoration() {
|
||||
fn p(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::INT2).unwrap()
|
||||
}
|
||||
assert_eq!(
|
||||
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
|
||||
json!([[[1, 2, 3], [4, 5, 6]]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pg_array_parse_json() {
|
||||
fn pt(pg_arr: &str) -> Value {
|
||||
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
|
||||
}
|
||||
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
|
||||
assert_eq!(
|
||||
pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
|
||||
json!([{"foo": 1, "bar": 2}])
|
||||
);
|
||||
assert_eq!(
|
||||
pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
|
||||
json!([{"foo": 1}, {"bar": 2}])
|
||||
);
|
||||
assert_eq!(
|
||||
pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
|
||||
json!([[{"foo": 1}, {"bar": 2}]])
|
||||
);
|
||||
}
|
||||
}
|
||||
874
proxy/core/src/serverless/sql_over_http.rs
Normal file
874
proxy/core/src/serverless/sql_over_http.rs
Normal file
@@ -0,0 +1,874 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::select;
|
||||
use futures::future::try_join;
|
||||
use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
use hyper1::body::Incoming;
|
||||
use hyper1::header;
|
||||
use hyper1::http::HeaderName;
|
||||
use hyper1::http::HeaderValue;
|
||||
use hyper1::Response;
|
||||
use hyper1::StatusCode;
|
||||
use hyper1::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::error::ErrorPosition;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::GenericClient;
|
||||
use tokio_postgres::IsolationLevel;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio_postgres::ReadyForQueryStatus;
|
||||
use tokio_postgres::Transaction;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ErrorKind;
|
||||
use crate::error::ReportableError;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::metrics::HttpDirection;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::Client;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::http_util::json_response;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
params: Vec<Option<String>>,
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct BatchQueryData {
|
||||
queries: Vec<QueryData>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Payload {
|
||||
Single(QueryData),
|
||||
Batch(BatchQueryData),
|
||||
}
|
||||
|
||||
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
|
||||
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
|
||||
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnInfoError {
|
||||
#[error("invalid header: {0}")]
|
||||
InvalidHeader(&'static str),
|
||||
#[error("invalid connection string: {0}")]
|
||||
UrlParseError(#[from] url::ParseError),
|
||||
#[error("incorrect scheme")]
|
||||
IncorrectScheme,
|
||||
#[error("missing database name")]
|
||||
MissingDbName,
|
||||
#[error("invalid database name")]
|
||||
InvalidDbName,
|
||||
#[error("missing username")]
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
InvalidEndpoint(#[from] ComputeUserInfoParseError),
|
||||
#[error("malformed endpoint")]
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ConnInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, ConnInfoError> {
|
||||
// HTTP only uses cleartext (for now and likely always)
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
|
||||
|
||||
let connection_url = Url::parse(connection_string)?;
|
||||
|
||||
let protocol = connection_url.scheme();
|
||||
if protocol != "postgres" && protocol != "postgresql" {
|
||||
return Err(ConnInfoError::IncorrectScheme);
|
||||
}
|
||||
|
||||
let mut url_path = connection_url
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
return Err(ConnInfoError::MissingUsername);
|
||||
}
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let password = connection_url
|
||||
.password()
|
||||
.ok_or(ConnInfoError::MissingPassword)?;
|
||||
let password = urlencoding::decode_binary(password.as_bytes());
|
||||
|
||||
let hostname = connection_url
|
||||
.host_str()
|
||||
.ok_or(ConnInfoError::MissingHostname)?;
|
||||
|
||||
let endpoint =
|
||||
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParamsBuilder::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
Ok(ConnInfo {
|
||||
user_info,
|
||||
dbname,
|
||||
password: match password {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
r
|
||||
}
|
||||
Err(e @ SqlOverHttpError::Cancelled(_)) => {
|
||||
let error_kind = e.get_error_kind();
|
||||
ctx.set_error_kind(error_kind);
|
||||
|
||||
let message = "Query cancelled, connection was terminated";
|
||||
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%e,
|
||||
msg=message,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({ "message": message, "code": SqlState::PROTOCOL_VIOLATION.code() }),
|
||||
)?
|
||||
}
|
||||
Err(e) => {
|
||||
let error_kind = e.get_error_kind();
|
||||
ctx.set_error_kind(error_kind);
|
||||
|
||||
let mut message = e.to_string_client();
|
||||
let db_error = match &e {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
|
||||
db.map(x).unwrap_or_default()
|
||||
}
|
||||
|
||||
if let Some(db_error) = db_error {
|
||||
db_error.message().clone_into(&mut message);
|
||||
}
|
||||
|
||||
let position = db_error.and_then(|db| db.position());
|
||||
let (position, internal_position, internal_query) = match position {
|
||||
Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
|
||||
Some(ErrorPosition::Internal { position, query }) => {
|
||||
(None, Some(position.to_string()), Some(query.clone()))
|
||||
}
|
||||
None => (None, None, None),
|
||||
};
|
||||
|
||||
let code = get(db_error, |db| db.code().code());
|
||||
let severity = get(db_error, |db| db.severity());
|
||||
let detail = get(db_error, |db| db.detail());
|
||||
let hint = get(db_error, |db| db.hint());
|
||||
let where_ = get(db_error, |db| db.where_());
|
||||
let table = get(db_error, |db| db.table());
|
||||
let column = get(db_error, |db| db.column());
|
||||
let schema = get(db_error, |db| db.schema());
|
||||
let datatype = get(db_error, |db| db.datatype());
|
||||
let constraint = get(db_error, |db| db.constraint());
|
||||
let file = get(db_error, |db| db.file());
|
||||
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
|
||||
let routine = get(db_error, |db| db.routine());
|
||||
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%e,
|
||||
msg=message,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// TODO: this shouldn't always be bad request.
|
||||
json_response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({
|
||||
"message": message,
|
||||
"code": code,
|
||||
"detail": detail,
|
||||
"hint": hint,
|
||||
"position": position,
|
||||
"internalPosition": internal_position,
|
||||
"internalQuery": internal_query,
|
||||
"severity": severity,
|
||||
"where": where_,
|
||||
"table": table,
|
||||
"column": column,
|
||||
"schema": schema,
|
||||
"dataType": datatype,
|
||||
"constraint": constraint,
|
||||
"file": file,
|
||||
"line": line,
|
||||
"routine": routine,
|
||||
}),
|
||||
)?
|
||||
}
|
||||
};
|
||||
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SqlOverHttpError {
|
||||
#[error("{0}")]
|
||||
ReadPayload(#[from] ReadPayloadError),
|
||||
#[error("{0}")]
|
||||
ConnectCompute(#[from] HttpConnError),
|
||||
#[error("{0}")]
|
||||
ConnInfo(#[from] ConnInfoError),
|
||||
#[error("request is too large (max is {MAX_REQUEST_SIZE} bytes)")]
|
||||
RequestTooLarge,
|
||||
#[error("response is too large (max is {MAX_RESPONSE_SIZE} bytes)")]
|
||||
ResponseTooLarge,
|
||||
#[error("invalid isolation level")]
|
||||
InvalidIsolationLevel,
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
#[error("{0}")]
|
||||
JsonConversion(#[from] JsonConversionError),
|
||||
#[error("{0}")]
|
||||
Cancelled(SqlOverHttpCancel),
|
||||
}
|
||||
|
||||
impl ReportableError for SqlOverHttpError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
SqlOverHttpError::ReadPayload(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::ConnectCompute(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
|
||||
SqlOverHttpError::RequestTooLarge => ErrorKind::User,
|
||||
SqlOverHttpError::ResponseTooLarge => ErrorKind::User,
|
||||
SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
|
||||
SqlOverHttpError::Postgres(p) => p.get_error_kind(),
|
||||
SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
|
||||
SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for SqlOverHttpError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
SqlOverHttpError::ReadPayload(p) => p.to_string(),
|
||||
SqlOverHttpError::ConnectCompute(c) => c.to_string_client(),
|
||||
SqlOverHttpError::ConnInfo(c) => c.to_string_client(),
|
||||
SqlOverHttpError::RequestTooLarge => self.to_string(),
|
||||
SqlOverHttpError::ResponseTooLarge => self.to_string(),
|
||||
SqlOverHttpError::InvalidIsolationLevel => self.to_string(),
|
||||
SqlOverHttpError::Postgres(p) => p.to_string(),
|
||||
SqlOverHttpError::JsonConversion(_) => "could not parse postgres response".to_string(),
|
||||
SqlOverHttpError::Cancelled(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper1::Error),
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for ReadPayloadError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
|
||||
ReadPayloadError::Parse(_) => ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SqlOverHttpCancel {
|
||||
#[error("query was cancelled")]
|
||||
Postgres,
|
||||
#[error("query was cancelled while stuck trying to connect to the database")]
|
||||
Connect,
|
||||
}
|
||||
|
||||
impl ReportableError for SqlOverHttpCancel {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
|
||||
SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct HttpHeaders {
|
||||
raw_output: bool,
|
||||
default_array_mode: bool,
|
||||
txn_isolation_level: Option<IsolationLevel>,
|
||||
txn_read_only: bool,
|
||||
txn_deferrable: bool,
|
||||
}
|
||||
|
||||
impl HttpHeaders {
|
||||
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
|
||||
// Determine the output options. Default behaviour is 'false'. Anything that is not
|
||||
// strictly 'true' assumed to be false.
|
||||
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
|
||||
let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
// isolation level, read only and deferrable
|
||||
let txn_isolation_level = match headers.get(&TXN_ISOLATION_LEVEL) {
|
||||
Some(x) => Some(
|
||||
map_header_to_isolation_level(x).ok_or(SqlOverHttpError::InvalidIsolationLevel)?,
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
||||
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
Ok(Self {
|
||||
raw_output,
|
||||
default_array_mode,
|
||||
txn_isolation_level,
|
||||
txn_read_only,
|
||||
txn_deferrable,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn map_header_to_isolation_level(level: &HeaderValue) -> Option<IsolationLevel> {
|
||||
match level.as_bytes() {
|
||||
b"Serializable" => Some(IsolationLevel::Serializable),
|
||||
b"ReadUncommitted" => Some(IsolationLevel::ReadUncommitted),
|
||||
b"ReadCommitted" => Some(IsolationLevel::ReadCommitted),
|
||||
b"RepeatableRead" => Some(IsolationLevel::RepeatableRead),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue> {
|
||||
match level {
|
||||
IsolationLevel::ReadUncommitted => Some(HeaderValue::from_static("ReadUncommitted")),
|
||||
IsolationLevel::ReadCommitted => Some(HeaderValue::from_static("ReadCommitted")),
|
||||
IsolationLevel::RepeatableRead => Some(HeaderValue::from_static("RepeatableRead")),
|
||||
IsolationLevel::Serializable => Some(HeaderValue::from_static("Serializable")),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
.guard(ctx.protocol());
|
||||
info!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
|
||||
info!(user = conn_info.user_info.user.as_str(), "credentials");
|
||||
|
||||
// Allow connection pooling only if explicitly requested
|
||||
// or if we have decided that http pool is no longer opt-in
|
||||
let allow_pool = !config.http_config.pool_options.opt_in
|
||||
|| headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
let parsed_headers = HttpHeaders::try_parse(headers)?;
|
||||
|
||||
let request_content_length = match request.body().size_hint().upper() {
|
||||
Some(v) => v,
|
||||
None => MAX_REQUEST_SIZE + 1,
|
||||
};
|
||||
info!(request_content_length, "request size in bytes");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_conn_content_length_bytes
|
||||
.observe(HttpDirection::Request, request_content_length as f64);
|
||||
|
||||
// we don't have a streaming request support yet so this is to prevent OOM
|
||||
// from a malicious user sending an extremely large request body
|
||||
if request_content_length > MAX_REQUEST_SIZE {
|
||||
return Err(SqlOverHttpError::RequestTooLarge);
|
||||
}
|
||||
|
||||
let fetch_and_process_request = Box::pin(
|
||||
async {
|
||||
let body = request.into_body().collect().await?.to_bytes();
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = backend
|
||||
.authenticate(ctx, &config.authentication_config, &conn_info)
|
||||
.await?;
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
try_join(
|
||||
pin!(fetch_and_process_request),
|
||||
pin!(authenticate_and_connect),
|
||||
),
|
||||
&cancel,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Some(result) => result?,
|
||||
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
|
||||
};
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
|
||||
// Now execute the query and return the result.
|
||||
let json_output = match payload {
|
||||
Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
|
||||
Payload::Batch(statements) => {
|
||||
if parsed_headers.txn_read_only {
|
||||
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
if parsed_headers.txn_deferrable {
|
||||
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
if let Some(txn_isolation_level) = parsed_headers
|
||||
.txn_isolation_level
|
||||
.and_then(map_isolation_level_to_headers)
|
||||
{
|
||||
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
|
||||
}
|
||||
|
||||
statements
|
||||
.process(cancel, &mut client, parsed_headers)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let metrics = client.metrics();
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
|
||||
// count the egress bytes - we miss the TLS and header overhead but oh well...
|
||||
// moving this later in the stack is going to be a lot of effort and ehhhh
|
||||
metrics.record_egress(len as u64);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_conn_content_length_bytes
|
||||
.observe(HttpDirection::Response, len as f64);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
async fn process(
|
||||
self,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
|
||||
let res = match select(
|
||||
pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
|
||||
pin!(cancel.cancelled()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
// The query successfully completed.
|
||||
Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
|
||||
discard.check_idle(status);
|
||||
|
||||
let json_output =
|
||||
serde_json::to_string(&results).expect("json serialization should not fail");
|
||||
Ok(json_output)
|
||||
}
|
||||
// The query failed with an error
|
||||
Either::Left((Err(e), __not_yet_cancelled)) => {
|
||||
discard.discard();
|
||||
return Err(e);
|
||||
}
|
||||
// The query was cancelled.
|
||||
Either::Right((_cancelled, query)) => {
|
||||
tracing::info!("cancelling query");
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
// wait for the query cancellation
|
||||
match time::timeout(time::Duration::from_millis(100), query).await {
|
||||
// query successed before it was cancelled.
|
||||
Ok(Ok((status, results))) => {
|
||||
discard.check_idle(status);
|
||||
|
||||
let json_output = serde_json::to_string(&results)
|
||||
.expect("json serialization should not fail");
|
||||
Ok(json_output)
|
||||
}
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// if errored for some other reason, it might not be safe to return
|
||||
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
|
||||
discard.discard();
|
||||
}
|
||||
|
||||
Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
|
||||
}
|
||||
Err(_timeout) => {
|
||||
discard.discard();
|
||||
Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchQueryData {
|
||||
async fn process(
|
||||
self,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
info!("starting transaction");
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
let mut builder = inner.build_transaction();
|
||||
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
|
||||
builder = builder.isolation_level(isolation_level);
|
||||
}
|
||||
if parsed_headers.txn_read_only {
|
||||
builder = builder.read_only(true);
|
||||
}
|
||||
if parsed_headers.txn_deferrable {
|
||||
builder = builder.deferrable(true);
|
||||
}
|
||||
|
||||
let transaction = builder.start().await.map_err(|e| {
|
||||
// if we cannot start a transaction, we should return immediately
|
||||
// and not return to the pool. connection is clearly broken
|
||||
discard.discard();
|
||||
e
|
||||
})?;
|
||||
|
||||
let json_output =
|
||||
match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
|
||||
Ok(json_output) => {
|
||||
info!("commit");
|
||||
let status = transaction.commit().await.map_err(|e| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
e
|
||||
})?;
|
||||
discard.check_idle(status);
|
||||
json_output
|
||||
}
|
||||
Err(SqlOverHttpError::Cancelled(_)) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
|
||||
discard.discard();
|
||||
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
let status = transaction.rollback().await.map_err(|e| {
|
||||
// if we cannot rollback - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
e
|
||||
})?;
|
||||
discard.check_idle(status);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(json_output)
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_batch(
|
||||
cancel: CancellationToken,
|
||||
transaction: &Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let mut results = Vec::with_capacity(queries.queries.len());
|
||||
let mut current_size = 0;
|
||||
for stmt in queries.queries {
|
||||
let query = pin!(query_to_json(
|
||||
transaction,
|
||||
stmt,
|
||||
&mut current_size,
|
||||
parsed_headers,
|
||||
));
|
||||
let cancelled = pin!(cancel.cancelled());
|
||||
let res = select(query, cancelled).await;
|
||||
match res {
|
||||
// TODO: maybe we should check that the transaction bit is set here
|
||||
Either::Left((Ok((_, values)), _cancelled)) => {
|
||||
results.push(values);
|
||||
}
|
||||
Either::Left((Err(e), _cancelled)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Either::Right((_cancelled, _)) => {
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let results = json!({ "results": results });
|
||||
let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
|
||||
|
||||
Ok(json_output)
|
||||
}
|
||||
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
client: &T,
|
||||
data: QueryData,
|
||||
current_size: &mut usize,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
|
||||
info!("executing query");
|
||||
let query_params = data.params;
|
||||
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
|
||||
info!("finished executing query");
|
||||
|
||||
// Manually drain the stream into a vector to leave row_stream hanging
|
||||
// around to get a command tag. Also check that the response is not too
|
||||
// big.
|
||||
let mut rows: Vec<tokio_postgres::Row> = Vec::new();
|
||||
while let Some(row) = row_stream.next().await {
|
||||
let row = row?;
|
||||
*current_size += row.body_len();
|
||||
rows.push(row);
|
||||
// we don't have a streaming response support yet so this is to prevent OOM
|
||||
// from a malicious query (eg a cross join)
|
||||
if *current_size > MAX_RESPONSE_SIZE {
|
||||
return Err(SqlOverHttpError::ResponseTooLarge);
|
||||
}
|
||||
}
|
||||
|
||||
let ready = row_stream.ready_status();
|
||||
|
||||
// grab the command tag and number of rows affected
|
||||
let command_tag = row_stream.command_tag().unwrap_or_default();
|
||||
let mut command_tag_split = command_tag.split(' ');
|
||||
let command_tag_name = command_tag_split.next().unwrap_or_default();
|
||||
let command_tag_count = if command_tag_name == "INSERT" {
|
||||
// INSERT returns OID first and then number of rows
|
||||
command_tag_split.nth(1)
|
||||
} else {
|
||||
// other commands return number of rows (if any)
|
||||
command_tag_split.next()
|
||||
}
|
||||
.and_then(|s| s.parse::<i64>().ok());
|
||||
|
||||
info!(
|
||||
rows = rows.len(),
|
||||
?ready,
|
||||
command_tag,
|
||||
"finished reading rows"
|
||||
);
|
||||
|
||||
let columns_len = row_stream.columns().len();
|
||||
let mut fields = Vec::with_capacity(columns_len);
|
||||
let mut columns = Vec::with_capacity(columns_len);
|
||||
|
||||
for c in row_stream.columns() {
|
||||
fields.push(json!({
|
||||
"name": c.name().to_owned(),
|
||||
"dataTypeID": c.type_().oid(),
|
||||
"tableID": c.table_oid(),
|
||||
"columnID": c.column_id(),
|
||||
"dataTypeSize": c.type_size(),
|
||||
"dataTypeModifier": c.type_modifier(),
|
||||
"format": "text",
|
||||
}));
|
||||
columns.push(client.get_type(c.type_oid()).await?);
|
||||
}
|
||||
|
||||
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
|
||||
|
||||
// convert rows to JSON
|
||||
let rows = rows
|
||||
.iter()
|
||||
.map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Resulting JSON format is based on the format of node-postgres result.
|
||||
let results = json!({
|
||||
"command": command_tag_name.to_string(),
|
||||
"rowCount": command_tag_count,
|
||||
"rows": rows,
|
||||
"fields": fields,
|
||||
"rowAsArray": array_mode,
|
||||
});
|
||||
|
||||
Ok((ready, results))
|
||||
}
|
||||
233
proxy/core/src/serverless/websocket.rs
Normal file
233
proxy/core/src/serverless/websocket.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use crate::proxy::ErrorSource;
|
||||
use crate::{
|
||||
cancellation::CancellationHandlerMain,
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
error::{io_error, ReportableError},
|
||||
metrics::Metrics,
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use framed_websockets::{Frame, OpCode, WebSocketServer};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper1::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::warn;
|
||||
|
||||
pin_project! {
|
||||
/// This is a wrapper around a [`WebSocketStream`] that
|
||||
/// implements [`AsyncRead`] and [`AsyncWrite`].
|
||||
pub struct WebSocketRw<S> {
|
||||
#[pin]
|
||||
stream: WebSocketServer<S>,
|
||||
recv: Bytes,
|
||||
send: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> WebSocketRw<S> {
|
||||
pub fn new(stream: WebSocketServer<S>) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
recv: Bytes::new(),
|
||||
send: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
let mut stream = this.stream;
|
||||
|
||||
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
|
||||
|
||||
this.send.put(buf);
|
||||
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
|
||||
Ok(()) => Poll::Ready(Ok(buf.len())),
|
||||
Err(e) => Poll::Ready(Err(io_error(e))),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream;
|
||||
stream.poll_flush(cx).map_err(io_error)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream;
|
||||
stream.poll_close(cx).map_err(io_error)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
buf.put_slice(&bytes[..len]);
|
||||
self.consume(len);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
// Please refer to poll_fill_buf's documentation.
|
||||
const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
|
||||
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
if !this.recv.chunk().is_empty() {
|
||||
let chunk = (*this.recv).chunk();
|
||||
return Poll::Ready(Ok(chunk));
|
||||
}
|
||||
|
||||
let res = ready!(this.stream.as_mut().poll_next(cx));
|
||||
match res.transpose().map_err(io_error)? {
|
||||
Some(message) => match message.opcode {
|
||||
OpCode::Ping => {}
|
||||
OpCode::Pong => {}
|
||||
OpCode::Text => {
|
||||
// We expect to see only binary messages.
|
||||
let error = "unexpected text message in the websocket";
|
||||
warn!(length = message.payload.len(), error);
|
||||
return Poll::Ready(Err(io_error(error)));
|
||||
}
|
||||
OpCode::Binary | OpCode::Continuation => {
|
||||
debug_assert!(this.recv.is_empty());
|
||||
*this.recv = message.payload.freeze();
|
||||
}
|
||||
OpCode::Close => return EOF,
|
||||
},
|
||||
None => return EOF,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amount: usize) {
|
||||
self.project().recv.advance(amount);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestMonitoring,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Ws);
|
||||
|
||||
let res = Box::pin(handle_client(
|
||||
config,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
))
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
// todo: log and push to ctx the error kind
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
Err(e.into())
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
use framed_websockets::WebSocketServer;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::{
|
||||
io::{duplex, AsyncReadExt, AsyncWriteExt},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
|
||||
use super::WebSocketRw;
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_stream_wrapper_happy_path() {
|
||||
let (stream1, stream2) = duplex(1024);
|
||||
|
||||
let mut js = JoinSet::new();
|
||||
|
||||
js.spawn(async move {
|
||||
let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
|
||||
|
||||
client
|
||||
.send(Message::Binary(b"hello world".to_vec()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let message = client.next().await.unwrap().unwrap();
|
||||
assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
|
||||
|
||||
client.close(None).await.unwrap();
|
||||
});
|
||||
|
||||
js.spawn(async move {
|
||||
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
|
||||
|
||||
let mut buf = vec![0; 1024];
|
||||
let n = rw.read(&mut buf).await.unwrap();
|
||||
assert_eq!(&buf[..n], b"hello world");
|
||||
|
||||
rw.write_all(b"websockets are cool").await.unwrap();
|
||||
rw.flush().await.unwrap();
|
||||
|
||||
let n = rw.read_to_end(&mut buf).await.unwrap();
|
||||
assert_eq!(n, 0);
|
||||
});
|
||||
|
||||
js.join_next().await.unwrap().unwrap();
|
||||
js.join_next().await.unwrap().unwrap();
|
||||
}
|
||||
}
|
||||
289
proxy/core/src/stream.rs
Normal file
289
proxy/core/src/stream.rs
Normal file
@@ -0,0 +1,289 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use bytes::BytesMut;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.framed.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
fn err_connection() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
self.framed
|
||||
.read_startup_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
self.framed
|
||||
.read_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||
match self.read_message().await? {
|
||||
FeMessage::PasswordMessage(msg) => Ok(msg),
|
||||
bad => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("unexpected message type: {:?}", bad),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportedError {
|
||||
source: anyhow::Error,
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.source.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ReportedError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
self.source.source()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ReportedError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
self.error_kind
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
self.framed
|
||||
.write_message(message)
|
||||
.map_err(ProtocolError::into_io_error)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
pub async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.framed.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
pub async fn throw_error_str<T>(
|
||||
&mut self,
|
||||
msg: &'static str,
|
||||
error_kind: ErrorKind,
|
||||
) -> Result<T, ReportedError> {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// already error case, ignore client IO error
|
||||
let _: Result<_, std::io::Error> = self
|
||||
.write_message(&BeMessage::ErrorResponse(msg, None))
|
||||
.await;
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(msg),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%error,
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
// already error case, ignore client IO error
|
||||
let _: Result<_, std::io::Error> = self
|
||||
.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.await;
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(error),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { raw: S },
|
||||
Tls {
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
tls: Box<TlsStream<S>>,
|
||||
/// Channel binding parameter
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
},
|
||||
}
|
||||
|
||||
impl<S: Unpin> Unpin for Stream<S> {}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
}
|
||||
|
||||
/// Return SNI hostname when it's available.
|
||||
pub fn sni_hostname(&self) -> Option<&str> {
|
||||
match self {
|
||||
Stream::Raw { .. } => None,
|
||||
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
|
||||
match self {
|
||||
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
|
||||
Stream::Tls {
|
||||
tls_server_end_point,
|
||||
..
|
||||
} => *tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Can't upgrade TLS stream")]
|
||||
pub enum StreamUpgradeError {
|
||||
#[error("Bad state reached: can't upgrade TLS stream")]
|
||||
AlreadyTls,
|
||||
|
||||
#[error("Can't upgrade stream: IO error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(
|
||||
self,
|
||||
cfg: Arc<ServerConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<TlsStream<S>, StreamUpgradeError> {
|
||||
match self {
|
||||
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
|
||||
.accept(raw)
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc()
|
||||
}
|
||||
})?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
|
||||
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
74
proxy/core/src/url.rs
Normal file
74
proxy/core/src/url.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use anyhow::bail;
|
||||
|
||||
/// A [url](url::Url) type with additional guarantees.
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ApiUrl(url::Url);
|
||||
|
||||
impl ApiUrl {
|
||||
/// Consume the wrapper and return inner [url](url::Url).
|
||||
pub fn into_inner(self) -> url::Url {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// See [`url::Url::path_segments_mut`].
|
||||
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
|
||||
// We've already verified that it works during construction.
|
||||
self.0.path_segments_mut().expect("bad API url")
|
||||
}
|
||||
}
|
||||
|
||||
/// This instance imposes additional requirements on the url.
|
||||
impl std::str::FromStr for ApiUrl {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
let mut url: url::Url = s.parse()?;
|
||||
|
||||
// Make sure that we can build upon this URL.
|
||||
if url.path_segments_mut().is_err() {
|
||||
bail!("bad API url provided");
|
||||
}
|
||||
|
||||
Ok(Self(url))
|
||||
}
|
||||
}
|
||||
|
||||
/// This instance is safe because it doesn't allow us to modify the object.
|
||||
impl std::ops::Deref for ApiUrl {
|
||||
type Target = url::Url;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ApiUrl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bad_url() {
|
||||
let url = "test:foobar";
|
||||
url.parse::<url::Url>().expect("unexpected parsing failure");
|
||||
let _ = url.parse::<ApiUrl>().expect_err("should not parse");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn good_url() {
|
||||
let url = "test://foobar";
|
||||
let mut a = url.parse::<url::Url>().expect("unexpected parsing failure");
|
||||
let mut b = url.parse::<ApiUrl>().expect("unexpected parsing failure");
|
||||
|
||||
a.path_segments_mut().unwrap().push("method");
|
||||
b.path_segments_mut().push("method");
|
||||
|
||||
assert_eq!(a, b.into_inner());
|
||||
}
|
||||
}
|
||||
584
proxy/core/src/usage_metrics.rs
Normal file
584
proxy/core/src/usage_metrics.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
//! Periodically collect proxy consumption metrics
|
||||
//! and push them to a HTTP endpoint.
|
||||
use crate::{
|
||||
config::{MetricBackupCollectionConfig, MetricCollectionConfig},
|
||||
context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD},
|
||||
http,
|
||||
intern::{BranchIdInt, EndpointIdInt},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use async_compression::tokio::write::GzipEncoder;
|
||||
use bytes::Bytes;
|
||||
use chrono::{DateTime, Datelike, Timelike, Utc};
|
||||
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use futures::future::select;
|
||||
use once_cell::sync::Lazy;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, instrument, trace};
|
||||
use utils::backoff;
|
||||
use uuid::{NoContext, Timestamp};
|
||||
|
||||
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
|
||||
|
||||
const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
/// Key that uniquely identifies the object, this metric describes.
|
||||
/// Currently, endpoint_id is enough, but this may change later,
|
||||
/// so keep it in a named struct.
|
||||
///
|
||||
/// Both the proxy and the ingestion endpoint will live in the same region (or cell)
|
||||
/// so while the project-id is unique across regions the whole pipeline will work correctly
|
||||
/// because we enrich the event with project_id in the control-plane endpoint.
|
||||
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Ids {
|
||||
pub endpoint_id: EndpointIdInt,
|
||||
pub branch_id: BranchIdInt,
|
||||
}
|
||||
|
||||
pub trait MetricCounterRecorder {
|
||||
/// Record that some bytes were sent from the proxy to the client
|
||||
fn record_egress(&self, bytes: u64);
|
||||
/// Record that some connections were opened
|
||||
fn record_connection(&self, count: usize);
|
||||
}
|
||||
|
||||
trait MetricCounterReporter {
|
||||
fn get_metrics(&mut self) -> (u64, usize);
|
||||
fn move_metrics(&self) -> (u64, usize);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MetricBackupCounter {
|
||||
transmitted: AtomicU64,
|
||||
opened_connections: AtomicUsize,
|
||||
}
|
||||
|
||||
impl MetricCounterRecorder for MetricBackupCounter {
|
||||
fn record_egress(&self, bytes: u64) {
|
||||
self.transmitted.fetch_add(bytes, Ordering::AcqRel);
|
||||
}
|
||||
|
||||
fn record_connection(&self, count: usize) {
|
||||
self.opened_connections.fetch_add(count, Ordering::AcqRel);
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricCounterReporter for MetricBackupCounter {
|
||||
fn get_metrics(&mut self) -> (u64, usize) {
|
||||
(
|
||||
*self.transmitted.get_mut(),
|
||||
*self.opened_connections.get_mut(),
|
||||
)
|
||||
}
|
||||
fn move_metrics(&self) -> (u64, usize) {
|
||||
(
|
||||
self.transmitted.swap(0, Ordering::AcqRel),
|
||||
self.opened_connections.swap(0, Ordering::AcqRel),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MetricCounter {
|
||||
transmitted: AtomicU64,
|
||||
opened_connections: AtomicUsize,
|
||||
backup: Arc<MetricBackupCounter>,
|
||||
}
|
||||
|
||||
impl MetricCounterRecorder for MetricCounter {
|
||||
/// Record that some bytes were sent from the proxy to the client
|
||||
fn record_egress(&self, bytes: u64) {
|
||||
self.transmitted.fetch_add(bytes, Ordering::AcqRel);
|
||||
self.backup.record_egress(bytes);
|
||||
}
|
||||
|
||||
/// Record that some connections were opened
|
||||
fn record_connection(&self, count: usize) {
|
||||
self.opened_connections.fetch_add(count, Ordering::AcqRel);
|
||||
self.backup.record_connection(count);
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricCounterReporter for MetricCounter {
|
||||
fn get_metrics(&mut self) -> (u64, usize) {
|
||||
(
|
||||
*self.transmitted.get_mut(),
|
||||
*self.opened_connections.get_mut(),
|
||||
)
|
||||
}
|
||||
fn move_metrics(&self) -> (u64, usize) {
|
||||
(
|
||||
self.transmitted.swap(0, Ordering::AcqRel),
|
||||
self.opened_connections.swap(0, Ordering::AcqRel),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
trait Clearable {
|
||||
/// extract the value that should be reported
|
||||
fn should_report(self: &Arc<Self>) -> Option<u64>;
|
||||
/// Determine whether the counter should be cleared from the global map.
|
||||
fn should_clear(self: &mut Arc<Self>) -> bool;
|
||||
}
|
||||
|
||||
impl<C: MetricCounterReporter> Clearable for C {
|
||||
fn should_report(self: &Arc<Self>) -> Option<u64> {
|
||||
// heuristic to see if the branch is still open
|
||||
// if a clone happens while we are observing, the heuristic will be incorrect.
|
||||
//
|
||||
// Worst case is that we won't report an event for this endpoint.
|
||||
// However, for the strong count to be 1 it must have occured that at one instant
|
||||
// all the endpoints were closed, so missing a report because the endpoints are closed is valid.
|
||||
let is_open = Arc::strong_count(self) > 1;
|
||||
|
||||
// update cached metrics eagerly, even if they can't get sent
|
||||
// (to avoid sending the same metrics twice)
|
||||
// see the relevant discussion on why to do so even if the status is not success:
|
||||
// https://github.com/neondatabase/neon/pull/4563#discussion_r1246710956
|
||||
let (value, opened) = self.move_metrics();
|
||||
|
||||
// Our only requirement is that we report in every interval if there was an open connection
|
||||
// if there were no opened connections since, then we don't need to report
|
||||
if value == 0 && !is_open && opened == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(value)
|
||||
}
|
||||
}
|
||||
fn should_clear(self: &mut Arc<Self>) -> bool {
|
||||
// we can't clear this entry if it's acquired elsewhere
|
||||
let Some(counter) = Arc::get_mut(self) else {
|
||||
return false;
|
||||
};
|
||||
let (opened, value) = counter.get_metrics();
|
||||
// clear if there's no data to report
|
||||
value == 0 && opened == 0
|
||||
}
|
||||
}
|
||||
|
||||
// endpoint and branch IDs are not user generated so we don't run the risk of hash-dos
|
||||
type FastHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Metrics {
|
||||
endpoints: DashMap<Ids, Arc<MetricCounter>, FastHasher>,
|
||||
backup_endpoints: DashMap<Ids, Arc<MetricBackupCounter>, FastHasher>,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
/// Register a new byte metrics counter for this endpoint
|
||||
pub fn register(&self, ids: Ids) -> Arc<MetricCounter> {
|
||||
let backup = if let Some(entry) = self.backup_endpoints.get(&ids) {
|
||||
entry.clone()
|
||||
} else {
|
||||
self.backup_endpoints
|
||||
.entry(ids.clone())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(MetricBackupCounter {
|
||||
transmitted: AtomicU64::new(0),
|
||||
opened_connections: AtomicUsize::new(0),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
|
||||
let entry = if let Some(entry) = self.endpoints.get(&ids) {
|
||||
entry.clone()
|
||||
} else {
|
||||
self.endpoints
|
||||
.entry(ids)
|
||||
.or_insert_with(|| {
|
||||
Arc::new(MetricCounter {
|
||||
transmitted: AtomicU64::new(0),
|
||||
opened_connections: AtomicUsize::new(0),
|
||||
backup: backup.clone(),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
|
||||
entry.record_connection(1);
|
||||
entry
|
||||
}
|
||||
}
|
||||
|
||||
pub static USAGE_METRICS: Lazy<Metrics> = Lazy::new(Metrics::default);
|
||||
|
||||
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
|
||||
info!("metrics collector config: {config:?}");
|
||||
scopeguard::defer! {
|
||||
info!("metrics collector has shut down");
|
||||
}
|
||||
|
||||
let http_client = http::new_client_with_timeout(DEFAULT_HTTP_REPORTING_TIMEOUT);
|
||||
let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
|
||||
|
||||
let mut prev = Utc::now();
|
||||
let mut ticker = tokio::time::interval(config.interval);
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let now = Utc::now();
|
||||
collect_metrics_iteration(
|
||||
&USAGE_METRICS.endpoints,
|
||||
&http_client,
|
||||
&config.endpoint,
|
||||
&hostname,
|
||||
prev,
|
||||
now,
|
||||
)
|
||||
.await;
|
||||
prev = now;
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_and_clear_metrics<C: Clearable>(
|
||||
endpoints: &DashMap<Ids, Arc<C>, FastHasher>,
|
||||
) -> Vec<(Ids, u64)> {
|
||||
let mut metrics_to_clear = Vec::new();
|
||||
|
||||
let metrics_to_send: Vec<(Ids, u64)> = endpoints
|
||||
.iter()
|
||||
.filter_map(|counter| {
|
||||
let key = counter.key().clone();
|
||||
let Some(value) = counter.should_report() else {
|
||||
metrics_to_clear.push(key);
|
||||
return None;
|
||||
};
|
||||
Some((key, value))
|
||||
})
|
||||
.collect();
|
||||
|
||||
for metric in metrics_to_clear {
|
||||
match endpoints.entry(metric) {
|
||||
Entry::Occupied(mut counter) => {
|
||||
if counter.get_mut().should_clear() {
|
||||
counter.remove_entry();
|
||||
}
|
||||
}
|
||||
Entry::Vacant(_) => {}
|
||||
}
|
||||
}
|
||||
metrics_to_send
|
||||
}
|
||||
|
||||
fn create_event_chunks<'a>(
|
||||
metrics_to_send: &'a [(Ids, u64)],
|
||||
hostname: &'a str,
|
||||
prev: DateTime<Utc>,
|
||||
now: DateTime<Utc>,
|
||||
chunk_size: usize,
|
||||
) -> impl Iterator<Item = EventChunk<'a, Event<Ids, &'static str>>> + 'a {
|
||||
// Split into chunks of 1000 metrics to avoid exceeding the max request size
|
||||
metrics_to_send
|
||||
.chunks(chunk_size)
|
||||
.map(move |chunk| EventChunk {
|
||||
events: chunk
|
||||
.iter()
|
||||
.map(|(ids, value)| Event {
|
||||
kind: EventType::Incremental {
|
||||
start_time: prev,
|
||||
stop_time: now,
|
||||
},
|
||||
metric: PROXY_IO_BYTES_PER_CLIENT,
|
||||
idempotency_key: idempotency_key(hostname),
|
||||
value: *value,
|
||||
extra: ids.clone(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn collect_metrics_iteration(
|
||||
endpoints: &DashMap<Ids, Arc<MetricCounter>, FastHasher>,
|
||||
client: &http::ClientWithMiddleware,
|
||||
metric_collection_endpoint: &reqwest::Url,
|
||||
hostname: &str,
|
||||
prev: DateTime<Utc>,
|
||||
now: DateTime<Utc>,
|
||||
) {
|
||||
info!(
|
||||
"starting collect_metrics_iteration. metric_collection_endpoint: {}",
|
||||
metric_collection_endpoint
|
||||
);
|
||||
|
||||
let metrics_to_send = collect_and_clear_metrics(endpoints);
|
||||
|
||||
if metrics_to_send.is_empty() {
|
||||
trace!("no new metrics to send");
|
||||
}
|
||||
|
||||
// Send metrics.
|
||||
for chunk in create_event_chunks(&metrics_to_send, hostname, prev, now, CHUNK_SIZE) {
|
||||
let res = client
|
||||
.post(metric_collection_endpoint.clone())
|
||||
.json(&chunk)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let res = match res {
|
||||
Ok(x) => x,
|
||||
Err(err) => {
|
||||
error!("failed to send metrics: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !res.status().is_success() {
|
||||
error!("metrics endpoint refused the sent metrics: {:?}", res);
|
||||
for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) {
|
||||
// Report if the metric value is suspiciously large
|
||||
error!("potentially abnormal metric value: {:?}", metric);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_backup(
|
||||
backup_config: &MetricBackupCollectionConfig,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
info!("metrics backup config: {backup_config:?}");
|
||||
scopeguard::defer! {
|
||||
info!("metrics backup has shut down");
|
||||
}
|
||||
// Even if the remote storage is not configured, we still want to clear the metrics.
|
||||
let storage = if let Some(config) = backup_config.remote_storage_config.as_ref() {
|
||||
Some(
|
||||
GenericRemoteStorage::from_config(config)
|
||||
.await
|
||||
.context("remote storage init")?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut ticker = tokio::time::interval(backup_config.interval);
|
||||
let mut prev = Utc::now();
|
||||
let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
|
||||
loop {
|
||||
select(pin!(ticker.tick()), pin!(cancellation_token.cancelled())).await;
|
||||
let now = Utc::now();
|
||||
collect_metrics_backup_iteration(
|
||||
&USAGE_METRICS.backup_endpoints,
|
||||
&storage,
|
||||
&hostname,
|
||||
prev,
|
||||
now,
|
||||
backup_config.chunk_size,
|
||||
)
|
||||
.await;
|
||||
|
||||
prev = now;
|
||||
if cancellation_token.is_cancelled() {
|
||||
info!("metrics backup has been cancelled");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn collect_metrics_backup_iteration(
|
||||
endpoints: &DashMap<Ids, Arc<MetricBackupCounter>, FastHasher>,
|
||||
storage: &Option<GenericRemoteStorage>,
|
||||
hostname: &str,
|
||||
prev: DateTime<Utc>,
|
||||
now: DateTime<Utc>,
|
||||
chunk_size: usize,
|
||||
) {
|
||||
let year = now.year();
|
||||
let month = now.month();
|
||||
let day = now.day();
|
||||
let hour = now.hour();
|
||||
let minute = now.minute();
|
||||
let second = now.second();
|
||||
let cancel = CancellationToken::new();
|
||||
|
||||
info!("starting collect_metrics_backup_iteration");
|
||||
|
||||
let metrics_to_send = collect_and_clear_metrics(endpoints);
|
||||
|
||||
if metrics_to_send.is_empty() {
|
||||
trace!("no new metrics to send");
|
||||
}
|
||||
|
||||
// Send metrics.
|
||||
for chunk in create_event_chunks(&metrics_to_send, hostname, prev, now, chunk_size) {
|
||||
let real_now = Utc::now();
|
||||
let id = uuid::Uuid::new_v7(Timestamp::from_unix(
|
||||
NoContext,
|
||||
real_now.second().into(),
|
||||
real_now.nanosecond(),
|
||||
));
|
||||
let path = format!("year={year:04}/month={month:02}/day={day:02}/{hour:02}:{minute:02}:{second:02}Z_{id}.json.gz");
|
||||
let remote_path = match RemotePath::from_string(&path) {
|
||||
Ok(remote_path) => remote_path,
|
||||
Err(e) => {
|
||||
error!("failed to create remote path from str {path}: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let res = upload_events_chunk(storage, chunk, &remote_path, &cancel).await;
|
||||
|
||||
if let Err(e) = res {
|
||||
error!(
|
||||
"failed to upload consumption events to remote storage: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn upload_events_chunk(
|
||||
storage: &Option<GenericRemoteStorage>,
|
||||
chunk: EventChunk<'_, Event<Ids, &'static str>>,
|
||||
remote_path: &RemotePath,
|
||||
cancel: &CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let storage = match storage {
|
||||
Some(storage) => storage,
|
||||
None => {
|
||||
error!("no remote storage configured");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let data = serde_json::to_vec(&chunk).context("serialize metrics")?;
|
||||
let mut encoder = GzipEncoder::new(Vec::new());
|
||||
encoder.write_all(&data).await.context("compress metrics")?;
|
||||
encoder.shutdown().await.context("compress metrics")?;
|
||||
let compressed_data: Bytes = encoder.get_ref().clone().into();
|
||||
backoff::retry(
|
||||
|| async {
|
||||
let stream = futures::stream::once(futures::future::ready(Ok(compressed_data.clone())));
|
||||
storage
|
||||
.upload(stream, compressed_data.len(), remote_path, None, cancel)
|
||||
.await
|
||||
},
|
||||
TimeoutOrCancel::caused_by_cancel,
|
||||
FAILED_UPLOAD_WARN_THRESHOLD,
|
||||
FAILED_UPLOAD_MAX_RETRIES,
|
||||
"request_data_upload",
|
||||
cancel,
|
||||
)
|
||||
.await
|
||||
.ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel))
|
||||
.and_then(|x| x)
|
||||
.context("request_data_upload")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
net::TcpListener,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::Error;
|
||||
use chrono::Utc;
|
||||
use consumption_metrics::{Event, EventChunk};
|
||||
use hyper::{
|
||||
service::{make_service_fn, service_fn},
|
||||
Body, Response,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::*;
|
||||
use crate::{http, BranchId, EndpointId};
|
||||
|
||||
#[tokio::test]
|
||||
async fn metrics() {
|
||||
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
|
||||
|
||||
let reports = Arc::new(Mutex::new(vec![]));
|
||||
let reports2 = reports.clone();
|
||||
|
||||
let server = hyper::server::Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(make_service_fn(move |_| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
Ok::<_, Error>(service_fn(move |req| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
let bytes = hyper::body::to_bytes(req.into_body()).await?;
|
||||
let events: EventChunk<'static, Event<Ids, String>> =
|
||||
serde_json::from_slice(&bytes)?;
|
||||
reports.lock().unwrap().push(events);
|
||||
Ok::<_, Error>(Response::new(Body::from(vec![])))
|
||||
}
|
||||
}))
|
||||
}
|
||||
}));
|
||||
let addr = server.local_addr();
|
||||
tokio::spawn(server);
|
||||
|
||||
let metrics = Metrics::default();
|
||||
let client = http::new_client();
|
||||
let endpoint = Url::parse(&format!("http://{addr}")).unwrap();
|
||||
let now = Utc::now();
|
||||
|
||||
// no counters have been registered
|
||||
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
|
||||
let r = std::mem::take(&mut *reports2.lock().unwrap());
|
||||
assert!(r.is_empty());
|
||||
|
||||
// register a new counter
|
||||
|
||||
let counter = metrics.register(Ids {
|
||||
endpoint_id: (&EndpointId::from("e1")).into(),
|
||||
branch_id: (&BranchId::from("b1")).into(),
|
||||
});
|
||||
|
||||
// the counter should be observed despite 0 egress
|
||||
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
|
||||
let r = std::mem::take(&mut *reports2.lock().unwrap());
|
||||
assert_eq!(r.len(), 1);
|
||||
assert_eq!(r[0].events.len(), 1);
|
||||
assert_eq!(r[0].events[0].value, 0);
|
||||
|
||||
// record egress
|
||||
counter.record_egress(1);
|
||||
|
||||
// egress should be observered
|
||||
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
|
||||
let r = std::mem::take(&mut *reports2.lock().unwrap());
|
||||
assert_eq!(r.len(), 1);
|
||||
assert_eq!(r[0].events.len(), 1);
|
||||
assert_eq!(r[0].events[0].value, 1);
|
||||
|
||||
// release counter
|
||||
drop(counter);
|
||||
|
||||
// we do not observe the counter
|
||||
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
|
||||
let r = std::mem::take(&mut *reports2.lock().unwrap());
|
||||
assert!(r.is_empty());
|
||||
|
||||
// counter is unregistered
|
||||
assert!(metrics.endpoints.is_empty());
|
||||
|
||||
collect_metrics_backup_iteration(&metrics.backup_endpoints, &None, "foo", now, now, 1000)
|
||||
.await;
|
||||
assert!(!metrics.backup_endpoints.is_empty());
|
||||
collect_metrics_backup_iteration(&metrics.backup_endpoints, &None, "foo", now, now, 1000)
|
||||
.await;
|
||||
// backup counter is unregistered after the second iteration
|
||||
assert!(metrics.backup_endpoints.is_empty());
|
||||
}
|
||||
}
|
||||
121
proxy/core/src/waiters.rs
Normal file
121
proxy/core/src/waiters.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use hashbrown::HashMap;
|
||||
use parking_lot::Mutex;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::pin::Pin;
|
||||
use std::task;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RegisterError {
|
||||
#[error("Waiter `{0}` already registered")]
|
||||
Occupied(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NotifyError {
|
||||
#[error("Notify failed: waiter `{0}` not registered")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Notify failed: channel hangup")]
|
||||
Hangup,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WaitError {
|
||||
#[error("Wait failed: channel hangup")]
|
||||
Hangup,
|
||||
}
|
||||
|
||||
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
|
||||
|
||||
impl<T> Default for Waiters<T> {
|
||||
fn default() -> Self {
|
||||
Waiters(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Waiters<T> {
|
||||
pub fn register(&self, key: String) -> Result<Waiter<T>, RegisterError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
self.0
|
||||
.lock()
|
||||
.try_insert(key.clone(), tx)
|
||||
.map_err(|e| RegisterError::Occupied(e.entry.key().clone()))?;
|
||||
|
||||
Ok(Waiter {
|
||||
receiver: rx,
|
||||
guard: DropKey {
|
||||
registry: self,
|
||||
key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
let tx = self
|
||||
.0
|
||||
.lock()
|
||||
.remove(key)
|
||||
.ok_or_else(|| NotifyError::NotFound(key.to_string()))?;
|
||||
|
||||
tx.send(value).map_err(|_| NotifyError::Hangup)
|
||||
}
|
||||
}
|
||||
|
||||
struct DropKey<'a, T> {
|
||||
key: String,
|
||||
registry: &'a Waiters<T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for DropKey<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.registry.0.lock().remove(&self.key);
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct Waiter<'a, T> {
|
||||
#[pin]
|
||||
receiver: oneshot::Receiver<T>,
|
||||
guard: DropKey<'a, T>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::future::Future for Waiter<'_, T> {
|
||||
type Output = Result<T, WaitError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
self.project()
|
||||
.receiver
|
||||
.poll(cx)
|
||||
.map_err(|_| WaitError::Hangup)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_waiter() -> anyhow::Result<()> {
|
||||
let waiters = Arc::new(Waiters::default());
|
||||
|
||||
let key = "Key";
|
||||
let waiter = waiters.register(key.to_owned())?;
|
||||
|
||||
let waiters = Arc::clone(&waiters);
|
||||
let notifier = tokio::spawn(async move {
|
||||
waiters.notify(key, ())?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
waiter.await?;
|
||||
notifier.await?
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user