Merge pull request #7853 from neondatabase/rc/proxy/2024-05-23

Proxy release 2024-05-23
This commit is contained in:
Anna Khanova
2024-05-23 12:09:13 +02:00
committed by GitHub
260 changed files with 11048 additions and 4860 deletions

View File

@@ -13,7 +13,7 @@ use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
@@ -23,7 +23,7 @@ use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, RateBucketInfo};
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
@@ -280,6 +280,7 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
@@ -305,6 +306,10 @@ async fn auth_quirks(
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => api.get_role_secret(ctx, &info).await?,
@@ -360,7 +365,10 @@ async fn authenticate_with_secret(
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret).await?;
let ep = EndpointIdInt::from(&info.endpoint);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -381,7 +389,7 @@ async fn authenticate_with_secret(
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
@@ -417,6 +425,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
use BackendType::*;
@@ -428,8 +437,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
"performing authentication using the console"
);
let credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let credentials = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
@@ -539,8 +556,8 @@ mod tests {
},
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::RateBucketInfo,
scram::ServerSecret,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
@@ -582,6 +599,7 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
@@ -699,10 +717,20 @@ mod tests {
_ => panic!("wrong message"),
}
});
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
.await
.unwrap();
let _creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
false,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
handle.await.unwrap();
}
@@ -739,10 +767,20 @@ mod tests {
frontend::password_message(b"my-secret-password", &mut write).unwrap();
client.write_all(&write).await.unwrap();
});
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();
let _creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
true,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
handle.await.unwrap();
}
@@ -780,9 +818,20 @@ mod tests {
client.write_all(&write).await.unwrap();
});
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
true,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");

View File

@@ -3,8 +3,10 @@ use super::{
};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
@@ -20,6 +22,7 @@ pub async fn authenticate_cleartext(
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -27,8 +30,14 @@ pub async fn authenticate_cleartext(
// pause the timer while we communicate with the client
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
.begin(auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.

View File

@@ -5,12 +5,14 @@ use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -53,7 +55,11 @@ impl AuthMethod for PasswordHack {
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword(pub AuthSecret);
pub struct CleartextPassword {
pub pool: Arc<ThreadPool>,
pub endpoint: EndpointIdInt,
pub secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
@@ -126,7 +132,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let outcome = validate_password_and_exchange(password, self.state.0).await?;
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
password,
self.state.secret,
)
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
@@ -181,6 +193,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
@@ -194,7 +208,7 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(&scram_secret, password).await?;
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -132,6 +133,9 @@ struct ProxyCliArgs {
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
@@ -144,6 +148,9 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
@@ -154,7 +161,7 @@ struct ProxyCliArgs {
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
@@ -365,6 +372,10 @@ async fn main() -> anyhow::Result<()> {
proxy::metrics::CancellationSource::FromClient,
));
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -373,6 +384,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
@@ -387,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
@@ -480,6 +493,9 @@ async fn main() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
@@ -559,11 +575,16 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
let api =
console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter);
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(wake_compute_rps_limit));
let api = console::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = console::provider::ConsoleBackend::Console(api);
auth::BackendType::Console(MaybeOwned::Owned(api), ())
}
@@ -610,6 +631,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
};
let authentication_config = AuthenticationConfig {
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),

View File

@@ -2,6 +2,7 @@ use crate::{
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
rate_limiter::RateBucketInfo,
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
@@ -61,6 +62,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,

View File

@@ -26,7 +26,7 @@ pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks<EndpointCacheKey>,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
jwt: String,
}
@@ -36,7 +36,7 @@ impl Api {
endpoint: http::Endpoint,
caches: &'static ApiCaches,
locks: &'static ApiLocks<EndpointCacheKey>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
Ok(v) => v,
@@ -46,7 +46,7 @@ impl Api {
endpoint,
caches,
locks,
endpoint_rate_limiter,
wake_compute_endpoint_rate_limiter,
jwt,
}
}
@@ -283,7 +283,7 @@ impl super::Api for Api {
// check rate limit
if !self
.endpoint_rate_limiter
.wake_compute_endpoint_rate_limiter
.check(user_info.endpoint.normalize().into(), 1)
{
return Err(WakeComputeError::TooManyConnections);

View File

@@ -307,7 +307,7 @@ where
}
async fn upload_parquet(
w: SerializedFileWriter<Writer<BytesMut>>,
mut w: SerializedFileWriter<Writer<BytesMut>>,
len: i64,
storage: &GenericRemoteStorage,
) -> anyhow::Result<Writer<BytesMut>> {
@@ -319,11 +319,15 @@ async fn upload_parquet(
// I don't know how compute intensive this is, although it probably isn't much... better be safe than sorry.
// finish method only available on the fork: https://github.com/apache/arrow-rs/issues/5253
let (writer, metadata) = tokio::task::spawn_blocking(move || w.finish())
let (mut buffer, metadata) =
tokio::task::spawn_blocking(move || -> parquet::errors::Result<_> {
let metadata = w.finish()?;
let buffer = std::mem::take(w.inner_mut().get_mut());
Ok((buffer, metadata))
})
.await
.unwrap()?;
let mut buffer = writer.into_inner();
let data = buffer.split().freeze();
let compression = len as f64 / len_uncompressed as f64;
@@ -351,7 +355,7 @@ async fn upload_parquet(
"{year:04}/{month:02}/{day:02}/{hour:02}/requests_{id}.parquet"
))?;
let cancel = CancellationToken::new();
backoff::retry(
let maybe_err = backoff::retry(
|| async {
let stream = futures::stream::once(futures::future::ready(Ok(data.clone())));
storage
@@ -368,7 +372,12 @@ async fn upload_parquet(
.await
.ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel))
.and_then(|x| x)
.context("request_data_upload")?;
.context("request_data_upload")
.err();
if let Some(err) = maybe_err {
tracing::warn!(%id, %err, "failed to upload request data");
}
Ok(buffer.writer())
}
@@ -474,10 +483,11 @@ mod tests {
RequestData {
session_id: uuid::Builder::from_random_bytes(rng.gen()).into_uuid(),
peer_addr: Ipv4Addr::from(rng.gen::<[u8; 4]>()).to_string(),
timestamp: chrono::NaiveDateTime::from_timestamp_millis(
timestamp: chrono::DateTime::from_timestamp_millis(
rng.gen_range(1703862754..1803862754),
)
.unwrap(),
.unwrap()
.naive_utc(),
application_name: Some("test".to_owned()),
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
@@ -560,15 +570,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1315008, 3, 6000),
(1315001, 3, 6000),
(1315061, 3, 6000),
(1315018, 3, 6000),
(1315148, 3, 6000),
(1314990, 3, 6000),
(1314782, 3, 6000),
(1315018, 3, 6000),
(438575, 1, 2000)
(1315314, 3, 6000),
(1315307, 3, 6000),
(1315367, 3, 6000),
(1315324, 3, 6000),
(1315454, 3, 6000),
(1315296, 3, 6000),
(1315088, 3, 6000),
(1315324, 3, 6000),
(438713, 1, 2000)
]
);
@@ -598,11 +608,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1221738, 5, 10000),
(1227888, 5, 10000),
(1229682, 5, 10000),
(1229044, 5, 10000),
(1220322, 5, 10000)
(1222212, 5, 10000),
(1228362, 5, 10000),
(1230156, 5, 10000),
(1229518, 5, 10000),
(1220796, 5, 10000)
]
);
@@ -634,11 +644,11 @@ mod tests {
assert_eq!(
file_stats,
[
(1207385, 5, 10000),
(1207116, 5, 10000),
(1207409, 5, 10000),
(1207397, 5, 10000),
(1207652, 5, 10000)
(1207859, 5, 10000),
(1207590, 5, 10000),
(1207883, 5, 10000),
(1207871, 5, 10000),
(1208126, 5, 10000)
]
);
@@ -663,15 +673,15 @@ mod tests {
assert_eq!(
file_stats,
[
(1315008, 3, 6000),
(1315001, 3, 6000),
(1315061, 3, 6000),
(1315018, 3, 6000),
(1315148, 3, 6000),
(1314990, 3, 6000),
(1314782, 3, 6000),
(1315018, 3, 6000),
(438575, 1, 2000)
(1315314, 3, 6000),
(1315307, 3, 6000),
(1315367, 3, 6000),
(1315324, 3, 6000),
(1315454, 3, 6000),
(1315296, 3, 6000),
(1315088, 3, 6000),
(1315324, 3, 6000),
(438713, 1, 2000)
]
);
@@ -708,7 +718,7 @@ mod tests {
// files are smaller than the size threshold, but they took too long to fill so were flushed early
assert_eq!(
file_stats,
[(659240, 2, 3001), (658954, 2, 3000), (658750, 2, 2999)]
[(659462, 2, 3001), (659176, 2, 3000), (658972, 2, 2999)]
);
tmpdir.close().unwrap();

View File

@@ -1,11 +1,11 @@
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::{
label::StaticLabelSet,
label::{FixedCardinalitySet, LabelName, LabelSet, LabelValue, StaticLabelSet},
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
@@ -14,26 +14,36 @@ use tokio::time::{self, Instant};
use crate::console::messages::ColdStartInfo;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
SELF.set(Metrics::new(thread_pool))
.ok()
.expect("proxy metrics must not be installed more than once");
}
pub fn get() -> &'static Self {
static SELF: OnceLock<Metrics> = OnceLock::new();
SELF.get_or_init(|| Metrics {
proxy: ProxyMetrics::default(),
wake_compute_lock: ApiLockMetrics::new(),
})
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
}
}
#[derive(MetricGroup)]
#[metric(new())]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -129,6 +139,10 @@ pub struct ProxyMetrics {
#[metric(namespace = "connect_compute_lock")]
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
}
#[derive(MetricGroup)]
@@ -146,12 +160,6 @@ pub struct ApiLockMetrics {
pub semaphore_acquire_seconds: Histogram<16>,
}
impl Default for ProxyMetrics {
fn default() -> Self {
Self::new()
}
}
impl Default for ApiLockMetrics {
fn default() -> Self {
Self::new()
@@ -553,3 +561,52 @@ pub enum RedisEventsCount {
PasswordUpdate,
AllowedIpsUpdate,
}
pub struct ThreadPoolWorkers(usize);
pub struct ThreadPoolWorkerId(pub usize);
impl LabelValue for ThreadPoolWorkerId {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0 as i64)
}
}
impl LabelGroup for ThreadPoolWorkerId {
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
v.write_value(LabelName::from_str("worker"), self);
}
}
impl LabelSet for ThreadPoolWorkers {
type Value<'a> = ThreadPoolWorkerId;
fn dynamic_cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
(value.0 < self.0).then_some(value.0)
}
fn decode(&self, value: usize) -> Self::Value<'_> {
ThreadPoolWorkerId(value)
}
}
impl FixedCardinalitySet for ThreadPoolWorkers {
fn cardinality(&self) -> usize {
self.0
}
}
#[derive(MetricGroup)]
#[metric(new(workers: usize))]
pub struct ThreadPoolMetrics {
pub injector_queue_depth: Gauge,
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}

View File

@@ -19,6 +19,7 @@ use crate::{
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
};
@@ -61,6 +62,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -86,6 +88,7 @@ pub async fn task_main(
let cancellation_handler = Arc::clone(&cancellation_handler);
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await{
@@ -123,6 +126,7 @@ pub async fn task_main(
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone())
@@ -234,6 +238,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
@@ -243,7 +248,6 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol;
// let _client_gauge = metrics.client_connections.guard(proto);
let _request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
@@ -286,6 +290,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{

View File

@@ -128,12 +128,18 @@ impl std::str::FromStr for RateBucketInfo {
}
impl RateBucketInfo {
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
pub const DEFAULT_SET: [Self; 3] = [
Self::new(300, Duration::from_secs(1)),
Self::new(200, Duration::from_secs(60)),
Self::new(100, Duration::from_secs(600)),
];
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
Self::new(500, Duration::from_secs(1)),
Self::new(300, Duration::from_secs(60)),
Self::new(200, Duration::from_secs(600)),
];
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
info.sort_unstable_by_key(|info| info.interval);
let invalid = info
@@ -266,7 +272,7 @@ mod tests {
#[test]
fn default_rate_buckets() {
let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
let mut defaults = RateBucketInfo::DEFAULT_SET;
RateBucketInfo::validate(&mut defaults[..]).unwrap();
}
@@ -333,11 +339,8 @@ mod tests {
let rand = rand::rngs::StdRng::from_seed([1; 32]);
let hasher = BuildHasherDefault::<FxHasher>::default();
let limiter = BucketRateLimiter::new_with_rand_and_hasher(
&RateBucketInfo::DEFAULT_ENDPOINT_SET,
rand,
hasher,
);
let limiter =
BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
for i in 0..1_000_000 {
limiter.check(i, 1);
}

View File

@@ -6,11 +6,14 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod countmin;
mod exchange;
mod key;
mod messages;
mod pbkdf2;
mod secret;
mod signature;
pub mod threadpool;
pub use exchange::{exchange, Exchange};
pub use key::ScramKey;
@@ -56,9 +59,13 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use crate::sasl::{Mechanism, Step};
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
};
use super::{Exchange, ServerSecret};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
#[test]
fn snapshot() {
@@ -112,8 +119,13 @@ mod tests {
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&scram_secret, client_password.as_bytes())
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
.await
.unwrap();

173
proxy/src/scram/countmin.rs Normal file
View File

@@ -0,0 +1,173 @@
use std::hash::Hash;
/// estimator of hash jobs per second.
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
pub struct CountMinSketch {
// one for each depth
hashers: Vec<ahash::RandomState>,
width: usize,
depth: usize,
// buckets, width*depth
buckets: Vec<u32>,
}
impl CountMinSketch {
/// Given parameters (ε, δ),
/// set width = ceil(e/ε)
/// set depth = ceil(ln(1/δ))
///
/// guarantees:
/// actual <= estimate
/// estimate <= actual + ε * N with probability 1 - δ
/// where N is the cardinality of the stream
pub fn with_params(epsilon: f64, delta: f64) -> Self {
CountMinSketch::new(
(std::f64::consts::E / epsilon).ceil() as usize,
(1.0_f64 / delta).ln().ceil() as usize,
)
}
fn new(width: usize, depth: usize) -> Self {
Self {
#[cfg(test)]
hashers: (0..depth)
.map(|i| {
// digits of pi for good randomness
ahash::RandomState::with_seeds(
314159265358979323,
84626433832795028,
84197169399375105,
82097494459230781 + i as u64,
)
})
.collect(),
#[cfg(not(test))]
hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
width,
depth,
buckets: vec![0; width * depth],
}
}
pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
let mut min = u32::MAX;
for row in 0..self.depth {
let col = (self.hashers[row].hash_one(t) as usize) % self.width;
let row = &mut self.buckets[row * self.width..][..self.width];
row[col] = row[col].saturating_add(x);
min = std::cmp::min(min, row[col]);
}
min
}
pub fn reset(&mut self) {
self.buckets.clear();
self.buckets.resize(self.width * self.depth, 0);
}
}
#[cfg(test)]
mod tests {
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use super::CountMinSketch;
fn eval_precision(n: usize, p: f64, q: f64) -> usize {
// fixed value of phi for consistent test
let mut rng = StdRng::seed_from_u64(16180339887498948482);
#[allow(non_snake_case)]
let mut N = 0;
let mut ids = vec![];
for _ in 0..n {
// number of insert operations
let n = rng.gen_range(1..100);
// number to insert at once
let m = rng.gen_range(1..4096);
let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
ids.push((id, n, m));
// N = sum(actual)
N += n * m;
}
// q% of counts will be within p of the actual value
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
dbg!(sketch.buckets.len());
// insert a bunch of entries in a random order
let mut ids2 = ids.clone();
while !ids2.is_empty() {
ids2.shuffle(&mut rng);
let mut i = 0;
while i < ids2.len() {
sketch.inc_and_return(&ids2[i].0, ids2[i].1);
ids2[i].2 -= 1;
if ids2[i].2 == 0 {
ids2.remove(i);
} else {
i += 1;
}
}
}
let mut within_p = 0;
for (id, n, m) in ids {
let actual = n * m;
let estimate = sketch.inc_and_return(&id, 0);
// This estimate has the guarantee that actual <= estimate
assert!(actual <= estimate);
// This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
// ε = p / N, δ = 1 - q;
// therefore, estimate <= actual + p with probability q.
if estimate as f64 <= actual as f64 + p {
within_p += 1;
}
}
within_p
}
#[test]
fn precision() {
assert_eq!(eval_precision(100, 100.0, 0.99), 100);
assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
// seems to be more precise than the literature indicates?
// probably numbers are too small to truly represent the probabilities.
assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
}
// returns memory usage in bytes, and the time complexity per insert.
fn eval_cost(p: f64, q: f64) -> (usize, usize) {
#[allow(non_snake_case)]
// N = sum(actual)
// Let's assume 1021 samples, all of 4096
let N = 1021 * 4096;
let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
let memory = std::mem::size_of::<u32>() * sketch.buckets.len();
let time = sketch.depth;
(memory, time)
}
#[test]
fn memory_usage() {
assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
}
}

View File

@@ -4,15 +4,17 @@ use std::convert::Infallible;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tokio::task::yield_now;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
@@ -74,37 +76,18 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
let mut hi = prev;
for i in 1..iterations {
prev = hmac.clone().chain_update(prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
let salted_password = pbkdf2(password, salt, iterations).await;
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
.expect("job should not be cancelled");
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
@@ -119,11 +102,13 @@ async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> Scr
}
pub async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(password, &salt, secret.iterations).await;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))

89
proxy/src/scram/pbkdf2.rs Normal file
View File

@@ -0,0 +1,89 @@
use hmac::{
digest::{consts::U32, generic_array::GenericArray},
Hmac, Mac,
};
use sha2::Sha256;
pub struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
iterations: u32,
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
let hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
Self {
hmac,
// one consumed for the hash above
iterations: iterations - 1,
hi: prev,
prev,
}
}
pub fn cost(&self) -> u32 {
(self.iterations).clamp(0, 4096)
}
pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
let Self {
hmac,
prev,
hi,
iterations,
} = self;
// only do 4096 iterations per turn before sharing the thread for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
*prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(*prev) {
*hi ^= prev;
}
}
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
} else {
std::task::Poll::Pending
}
}
}
#[cfg(test)]
mod tests {
use super::Pbkdf2;
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;
#[test]
fn works() {
let salt = b"sodium chloride";
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 600000);
let hash = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
assert_eq!(hash, expected)
}
}

View File

@@ -0,0 +1,321 @@
//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use crossbeam_deque::{Injector, Stealer, Worker};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use tokio::sync::oneshot;
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use super::pbkdf2::Pbkdf2;
pub struct ThreadPool {
queue: Injector<JobSpec>,
stealers: Vec<Stealer<JobSpec>>,
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
/// bitpacked representation.
/// lower 8 bits = number of sleeping threads
/// next 8 bits = number of idle threads (searching for work)
counters: AtomicU64,
pub metrics: Arc<ThreadPoolMetrics>,
}
#[derive(PartialEq)]
enum ThreadState {
Parked,
Active,
}
impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
let parkers = (0..n_workers)
.map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
.collect_vec();
let pool = Arc::new(Self {
queue: Injector::new(),
stealers,
parkers,
// threads start searching for work
counters: AtomicU64::new((n_workers as u64) << 8),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
for (i, worker) in workers.into_iter().enumerate() {
let pool = Arc::clone(&pool);
std::thread::spawn(move || thread_rt(pool, worker, i));
}
pool
}
pub fn spawn_job(
&self,
endpoint: EndpointIdInt,
pbkdf2: Pbkdf2,
) -> oneshot::Receiver<[u8; 32]> {
let (tx, rx) = oneshot::channel();
let queue_was_empty = self.queue.is_empty();
self.metrics.injector_queue_depth.inc();
self.queue.push(JobSpec {
response: tx,
pbkdf2,
endpoint,
});
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
let counts = self.counters.load(Ordering::SeqCst);
let num_awake_but_idle = (counts >> 8) & 0xff;
let num_sleepers = counts & 0xff;
// If the queue is non-empty, then we always wake up a worker
// -- clearly the existing idle jobs aren't enough. Otherwise,
// check to see if we have enough idle workers.
if !queue_was_empty || num_awake_but_idle == 0 {
let num_to_wake = Ord::min(1, num_sleepers);
self.wake_any_threads(num_to_wake);
}
rx
}
#[cold]
fn wake_any_threads(&self, mut num_to_wake: u64) {
if num_to_wake > 0 {
for i in 0..self.parkers.len() {
if self.wake_specific_thread(i) {
num_to_wake -= 1;
if num_to_wake == 0 {
return;
}
}
}
}
}
fn wake_specific_thread(&self, index: usize) -> bool {
let (condvar, lock) = &self.parkers[index];
let mut state = lock.lock();
if *state == ThreadState::Parked {
condvar.notify_one();
// When the thread went to sleep, it will have incremented
// this value. When we wake it, its our job to decrement
// it. We could have the thread do it, but that would
// introduce a delay between when the thread was
// *notified* and when this counter was decremented. That
// might mislead people with new work into thinking that
// there are sleeping threads that they should try to
// wake, when in fact there is nothing left for them to
// do.
self.counters.fetch_sub(1, Ordering::SeqCst);
*state = ThreadState::Active;
true
} else {
false
}
}
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
// announce thread as idle
self.counters.fetch_add(256, Ordering::SeqCst);
// try steal from the global queue
loop {
match self.queue.steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => {
self.metrics
.injector_queue_depth
.set(self.queue.len() as i64);
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return Some(job);
}
crossbeam_deque::Steal::Retry => continue,
crossbeam_deque::Steal::Empty => break,
}
}
// try steal from our neighbours
loop {
let mut retry = false;
let start = rng.gen_range(0..self.stealers.len());
let job = (start..self.stealers.len())
.chain(0..start)
.filter(|i| *i != skip)
.find_map(
|victim| match self.stealers[victim].steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => Some(job),
crossbeam_deque::Steal::Empty => None,
crossbeam_deque::Steal::Retry => {
retry = true;
None
}
},
);
if job.is_some() {
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return job;
}
if !retry {
return None;
}
}
}
}
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
/// interval when we should steal from the global queue
/// so that tail latencies are managed appropriately
const STEAL_INTERVAL: usize = 61;
/// How often to reset the sketch values
const SKETCH_RESET_INTERVAL: usize = 1021;
let mut rng = SmallRng::from_entropy();
// used to determine whether we should temporarily skip tasks for fairness.
// 99% of estimates will overcount by no more than 4096 samples
let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
let (condvar, lock) = &pool.parkers[index];
'wait: loop {
// wait for notification of work
{
let mut lock = lock.lock();
// queue is empty
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), 0);
// subtract 1 from idle count, add 1 to sleeping count.
pool.counters.fetch_sub(255, Ordering::SeqCst);
*lock = ThreadState::Parked;
condvar.wait(&mut lock);
}
for i in 0.. {
let mut job = match worker
.pop()
.or_else(|| pool.steal(&mut rng, index, &worker))
{
Some(job) => job,
None => continue 'wait,
};
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), worker.len() as i64);
// receiver is closed, cancel the task
if !job.response.is_closed() {
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
const P: f64 = 2000.0;
// probability decreases as rate increases.
// lower probability, higher chance of being skipped
//
// estimates (rate in terms of 4096 rounds):
// rate = 0 => probability = 100%
// rate = 10 => probability = 71.3%
// rate = 50 => probability = 62.1%
// rate = 500 => probability = 52.3%
// rate = 1021 => probability = 49.8%
//
// My expectation is that the pool queue will only begin backing up at ~1000rps
// in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
// are in requests per second.
let probability = P.ln() / (P + rate as f64).ln();
if pool.queue.len() > 32 || rng.gen_bool(probability) {
pool.metrics
.worker_task_turns_total
.inc(ThreadPoolWorkerId(index));
match job.pbkdf2.turn() {
std::task::Poll::Ready(result) => {
let _ = job.response.send(result);
}
std::task::Poll::Pending => worker.push(job),
}
} else {
pool.metrics
.worker_task_skips_total
.inc(ThreadPoolWorkerId(index));
// skip for now
worker.push(job)
}
}
// if we get stuck with a few long lived jobs in the queue
// it's better to try and steal from the queue too for fairness
if i % STEAL_INTERVAL == 0 {
let _ = pool.queue.steal_batch(&worker);
}
if i % SKETCH_RESET_INTERVAL == 0 {
sketch.reset();
}
}
}
}
struct JobSpec {
response: oneshot::Sender<[u8; 32]>,
pbkdf2: Pbkdf2,
endpoint: EndpointIdInt,
}
#[cfg(test)]
mod tests {
use crate::EndpointId;
use super::*;
#[tokio::test]
async fn hash_is_correct() {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let salt = [0x55; 32];
let actual = pool
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await
.unwrap();
let expected = [
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected)
}
}

View File

@@ -36,6 +36,7 @@ use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::read_proxy_protocol;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
@@ -54,6 +55,7 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -82,6 +84,7 @@ pub async fn task_main(
let backend = Arc::new(PoolingBackend {
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_config = match config.tls_config.as_ref() {
@@ -99,7 +102,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
let server = Builder::new(TokioExecutor::new());
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
@@ -129,6 +132,7 @@ pub async fn task_main(
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
conn_token.clone(),
server.clone(),
tls_acceptor.clone(),
@@ -162,6 +166,7 @@ async fn connection_handler(
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
@@ -245,11 +250,11 @@ async fn connection_handler(
session_id,
peer_addr,
http_request_token,
endpoint_rate_limiter.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let res = handler.await;
cancel_request.disarm();
@@ -285,6 +290,7 @@ async fn request_handler(
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
@@ -294,7 +300,7 @@ async fn request_handler(
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
if framed_websockets::upgrade::is_upgrade_request(&request) {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -305,14 +311,20 @@ async fn request_handler(
let span = ctx.span.clone();
info!(parent: &span, "performing websocket upgrade");
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
if let Err(e) =
websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
.await
if let Err(e) = websocket::serve_websocket(
config,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.await
{
error!("error in websocket connection: {e:#}");
}
@@ -321,7 +333,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response)
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,

View File

@@ -15,7 +15,9 @@ use crate::{
},
context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
rate_limiter::EndpointRateLimiter,
Host,
};
@@ -24,6 +26,7 @@ use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub config: &'static ProxyConfig,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
impl PoolingBackend {
@@ -39,6 +42,12 @@ impl PoolingBackend {
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
}
if !self
.endpoint_rate_limiter
.check(conn_info.user_info.endpoint.clone().into(), 1)
{
return Err(AuthError::too_many_connections());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,
@@ -58,8 +67,14 @@ impl PoolingBackend {
return Err(AuthError::auth_failed(&*user_info.user));
}
};
let auth_outcome =
crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?;
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&config.thread_pool,
ep,
&conn_info.password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");

View File

@@ -5,11 +5,13 @@ use crate::{
error::{io_error, ReportableError},
metrics::Metrics,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::{
@@ -20,25 +22,23 @@ use std::{
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
// TODO: use `std::sync::Exclusive` once it's stabilized.
// Tracking issue: https://github.com/rust-lang/rust/issues/98407.
use sync_wrapper::SyncWrapper;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
pub struct WebSocketRw<S = Upgraded> {
pub struct WebSocketRw<S> {
#[pin]
stream: SyncWrapper<WebSocketStream<S>>,
bytes: Bytes,
stream: WebSocketServer<S>,
recv: Bytes,
send: BytesMut,
}
}
impl<S> WebSocketRw<S> {
pub fn new(stream: WebSocketStream<S>) -> Self {
pub fn new(stream: WebSocketServer<S>) -> Self {
Self {
stream: stream.into(),
bytes: Bytes::new(),
stream,
recv: Bytes::new(),
send: BytesMut::new(),
}
}
}
@@ -49,22 +49,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut stream = self.project().stream.get_pin_mut();
let this = self.project();
let mut stream = this.stream;
this.send.put(buf);
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
match stream.as_mut().start_send(Message::Binary(buf.into())) {
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(io_error(e))),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream.get_pin_mut();
let stream = self.project().stream;
stream.poll_flush(cx).map_err(io_error)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream.get_pin_mut();
let stream = self.project().stream;
stream.poll_close(cx).map_err(io_error)
}
}
@@ -75,13 +77,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() > 0 {
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.consume(len);
}
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.consume(len);
Poll::Ready(Ok(()))
}
}
@@ -93,31 +92,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
let mut this = self.project();
loop {
if !this.bytes.chunk().is_empty() {
let chunk = (*this.bytes).chunk();
if !this.recv.chunk().is_empty() {
let chunk = (*this.recv).chunk();
return Poll::Ready(Ok(chunk));
}
let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
let res = ready!(this.stream.as_mut().poll_next(cx));
match res.transpose().map_err(io_error)? {
Some(message) => match message {
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Text(text) => {
Some(message) => match message.opcode {
OpCode::Ping => {}
OpCode::Pong => {}
OpCode::Text => {
// We expect to see only binary messages.
let error = "unexpected text message in the websocket";
warn!(length = text.len(), error);
warn!(length = message.payload.len(), error);
return Poll::Ready(Err(io_error(error)));
}
Message::Frame(_) => {
// This case is impossible according to Frame's doc.
panic!("unexpected raw frame in the websocket");
OpCode::Binary | OpCode::Continuation => {
debug_assert!(this.recv.is_empty());
*this.recv = message.payload.freeze();
}
Message::Binary(chunk) => {
assert!(this.bytes.is_empty());
*this.bytes = Bytes::from(chunk);
}
Message::Close(_) => return EOF,
OpCode::Close => return EOF,
},
None => return EOF,
}
@@ -125,18 +120,21 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}
fn consume(self: Pin<&mut Self>, amount: usize) {
self.project().bytes.advance(amount);
self.project().recv.advance(amount);
}
}
pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
.proxy
.client_connections
@@ -148,6 +146,7 @@ pub async fn serve_websocket(
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
)
.await;
@@ -174,15 +173,16 @@ pub async fn serve_websocket(
mod tests {
use std::pin::pin;
use framed_websockets::WebSocketServer;
use futures::{SinkExt, StreamExt};
use hyper_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt},
task::JoinSet,
};
use tokio_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use super::WebSocketRw;
@@ -207,9 +207,7 @@ mod tests {
});
js.spawn(async move {
let mut rw = pin!(WebSocketRw::new(
WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
));
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
let mut buf = vec![0; 1024];
let n = rw.read(&mut buf).await.unwrap();