mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-25 00:50:36 +00:00
Merge pull request #7853 from neondatabase/rc/proxy/2024-05-23
Proxy release 2024-05-23
This commit is contained in:
@@ -13,7 +13,7 @@ use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::validate_password_and_exchange;
|
||||
use crate::auth::{validate_password_and_exchange, AuthError};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||
@@ -23,7 +23,7 @@ use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{BucketRateLimiter, RateBucketInfo};
|
||||
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ComputeUserInfoMaybeEndpoint},
|
||||
@@ -280,6 +280,7 @@ async fn auth_quirks(
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
@@ -305,6 +306,10 @@ async fn auth_quirks(
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
}
|
||||
|
||||
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => api.get_role_secret(ctx, &info).await?,
|
||||
@@ -360,7 +365,10 @@ async fn authenticate_with_secret(
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let auth_outcome = validate_password_and_exchange(&password, secret).await?;
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_outcome =
|
||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
@@ -381,7 +389,7 @@ async fn authenticate_with_secret(
|
||||
// Currently, we use it for websocket connections (latency).
|
||||
if allow_cleartext {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
|
||||
return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
|
||||
}
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
@@ -417,6 +425,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
@@ -428,8 +437,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials =
|
||||
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
&*api,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
BackendType::Console(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
@@ -539,8 +556,8 @@ mod tests {
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
proxy::NeonOptions,
|
||||
rate_limiter::RateBucketInfo,
|
||||
scram::ServerSecret,
|
||||
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
|
||||
scram::{threadpool::ThreadPool, ServerSecret},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
|
||||
@@ -582,6 +599,7 @@ mod tests {
|
||||
}
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
@@ -699,10 +717,20 @@ mod tests {
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
});
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
false,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
@@ -739,10 +767,20 @@ mod tests {
|
||||
frontend::password_message(b"my-secret-password", &mut write).unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
@@ -780,9 +818,20 @@ mod tests {
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
|
||||
let creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
true,
|
||||
&CONFIG,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(creds.info.endpoint, "my-endpoint");
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
@@ -20,6 +22,7 @@ pub async fn authenticate_cleartext(
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
@@ -27,8 +30,14 @@ pub async fn authenticate_cleartext(
|
||||
// pause the timer while we communicate with the client
|
||||
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_flow = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword(secret))
|
||||
.begin(auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
})
|
||||
.await?;
|
||||
drop(paused);
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
|
||||
@@ -5,12 +5,14 @@ use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
sasl, scram,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::io;
|
||||
use std::{io, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
@@ -53,7 +55,11 @@ impl AuthMethod for PasswordHack {
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword(pub AuthSecret);
|
||||
pub struct CleartextPassword {
|
||||
pub pool: Arc<ThreadPool>,
|
||||
pub endpoint: EndpointIdInt,
|
||||
pub secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
@@ -126,7 +132,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let outcome = validate_password_and_exchange(password, self.state.0).await?;
|
||||
let outcome = validate_password_and_exchange(
|
||||
&self.state.pool,
|
||||
self.state.endpoint,
|
||||
password,
|
||||
self.state.secret,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
@@ -181,6 +193,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
@@ -194,7 +208,7 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(&scram_secret, password).await?;
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
|
||||
@@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
|
||||
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use proxy::redis::elasticache;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
@@ -132,6 +133,9 @@ struct ProxyCliArgs {
|
||||
/// timeout for scram authentication protocol
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
scram_protocol_timeout: tokio::time::Duration,
|
||||
/// size of the threadpool for password hashing
|
||||
#[clap(long, default_value_t = 4)]
|
||||
scram_thread_pool_size: u8,
|
||||
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
require_client_ip: bool,
|
||||
@@ -144,6 +148,9 @@ struct ProxyCliArgs {
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Whether the auth rate limiter actually takes effect (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
auth_rate_limit_enabled: bool,
|
||||
@@ -154,7 +161,7 @@ struct ProxyCliArgs {
|
||||
#[clap(long, default_value_t = 64)]
|
||||
auth_rate_limit_ip_subnet: u8,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
@@ -365,6 +372,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
));
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
@@ -373,6 +384,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
@@ -387,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -480,6 +493,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
key_path,
|
||||
@@ -559,11 +575,16 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
|
||||
let api =
|
||||
console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter);
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = console::provider::neon::Api::new(
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
);
|
||||
let api = console::provider::ConsoleBackend::Console(api);
|
||||
auth::BackendType::Console(MaybeOwned::Owned(api), ())
|
||||
}
|
||||
@@ -610,6 +631,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::RateBucketInfo,
|
||||
scram::threadpool::ThreadPool,
|
||||
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
|
||||
Host,
|
||||
};
|
||||
@@ -61,6 +62,7 @@ pub struct HttpConfig {
|
||||
}
|
||||
|
||||
pub struct AuthenticationConfig {
|
||||
pub thread_pool: Arc<ThreadPool>,
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub rate_limiter_enabled: bool,
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
|
||||
@@ -26,7 +26,7 @@ pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
jwt: String,
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
@@ -46,7 +46,7 @@ impl Api {
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
endpoint_rate_limiter,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
jwt,
|
||||
}
|
||||
}
|
||||
@@ -283,7 +283,7 @@ impl super::Api for Api {
|
||||
|
||||
// check rate limit
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize().into(), 1)
|
||||
{
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
|
||||
@@ -307,7 +307,7 @@ where
|
||||
}
|
||||
|
||||
async fn upload_parquet(
|
||||
w: SerializedFileWriter<Writer<BytesMut>>,
|
||||
mut w: SerializedFileWriter<Writer<BytesMut>>,
|
||||
len: i64,
|
||||
storage: &GenericRemoteStorage,
|
||||
) -> anyhow::Result<Writer<BytesMut>> {
|
||||
@@ -319,11 +319,15 @@ async fn upload_parquet(
|
||||
|
||||
// I don't know how compute intensive this is, although it probably isn't much... better be safe than sorry.
|
||||
// finish method only available on the fork: https://github.com/apache/arrow-rs/issues/5253
|
||||
let (writer, metadata) = tokio::task::spawn_blocking(move || w.finish())
|
||||
let (mut buffer, metadata) =
|
||||
tokio::task::spawn_blocking(move || -> parquet::errors::Result<_> {
|
||||
let metadata = w.finish()?;
|
||||
let buffer = std::mem::take(w.inner_mut().get_mut());
|
||||
Ok((buffer, metadata))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
let mut buffer = writer.into_inner();
|
||||
let data = buffer.split().freeze();
|
||||
|
||||
let compression = len as f64 / len_uncompressed as f64;
|
||||
@@ -351,7 +355,7 @@ async fn upload_parquet(
|
||||
"{year:04}/{month:02}/{day:02}/{hour:02}/requests_{id}.parquet"
|
||||
))?;
|
||||
let cancel = CancellationToken::new();
|
||||
backoff::retry(
|
||||
let maybe_err = backoff::retry(
|
||||
|| async {
|
||||
let stream = futures::stream::once(futures::future::ready(Ok(data.clone())));
|
||||
storage
|
||||
@@ -368,7 +372,12 @@ async fn upload_parquet(
|
||||
.await
|
||||
.ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel))
|
||||
.and_then(|x| x)
|
||||
.context("request_data_upload")?;
|
||||
.context("request_data_upload")
|
||||
.err();
|
||||
|
||||
if let Some(err) = maybe_err {
|
||||
tracing::warn!(%id, %err, "failed to upload request data");
|
||||
}
|
||||
|
||||
Ok(buffer.writer())
|
||||
}
|
||||
@@ -474,10 +483,11 @@ mod tests {
|
||||
RequestData {
|
||||
session_id: uuid::Builder::from_random_bytes(rng.gen()).into_uuid(),
|
||||
peer_addr: Ipv4Addr::from(rng.gen::<[u8; 4]>()).to_string(),
|
||||
timestamp: chrono::NaiveDateTime::from_timestamp_millis(
|
||||
timestamp: chrono::DateTime::from_timestamp_millis(
|
||||
rng.gen_range(1703862754..1803862754),
|
||||
)
|
||||
.unwrap(),
|
||||
.unwrap()
|
||||
.naive_utc(),
|
||||
application_name: Some("test".to_owned()),
|
||||
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
||||
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
@@ -560,15 +570,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315008, 3, 6000),
|
||||
(1315001, 3, 6000),
|
||||
(1315061, 3, 6000),
|
||||
(1315018, 3, 6000),
|
||||
(1315148, 3, 6000),
|
||||
(1314990, 3, 6000),
|
||||
(1314782, 3, 6000),
|
||||
(1315018, 3, 6000),
|
||||
(438575, 1, 2000)
|
||||
(1315314, 3, 6000),
|
||||
(1315307, 3, 6000),
|
||||
(1315367, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(1315454, 3, 6000),
|
||||
(1315296, 3, 6000),
|
||||
(1315088, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(438713, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -598,11 +608,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1221738, 5, 10000),
|
||||
(1227888, 5, 10000),
|
||||
(1229682, 5, 10000),
|
||||
(1229044, 5, 10000),
|
||||
(1220322, 5, 10000)
|
||||
(1222212, 5, 10000),
|
||||
(1228362, 5, 10000),
|
||||
(1230156, 5, 10000),
|
||||
(1229518, 5, 10000),
|
||||
(1220796, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -634,11 +644,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1207385, 5, 10000),
|
||||
(1207116, 5, 10000),
|
||||
(1207409, 5, 10000),
|
||||
(1207397, 5, 10000),
|
||||
(1207652, 5, 10000)
|
||||
(1207859, 5, 10000),
|
||||
(1207590, 5, 10000),
|
||||
(1207883, 5, 10000),
|
||||
(1207871, 5, 10000),
|
||||
(1208126, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -663,15 +673,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315008, 3, 6000),
|
||||
(1315001, 3, 6000),
|
||||
(1315061, 3, 6000),
|
||||
(1315018, 3, 6000),
|
||||
(1315148, 3, 6000),
|
||||
(1314990, 3, 6000),
|
||||
(1314782, 3, 6000),
|
||||
(1315018, 3, 6000),
|
||||
(438575, 1, 2000)
|
||||
(1315314, 3, 6000),
|
||||
(1315307, 3, 6000),
|
||||
(1315367, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(1315454, 3, 6000),
|
||||
(1315296, 3, 6000),
|
||||
(1315088, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(438713, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -708,7 +718,7 @@ mod tests {
|
||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[(659240, 2, 3001), (658954, 2, 3000), (658750, 2, 2999)]
|
||||
[(659462, 2, 3001), (659176, 2, 3000), (658972, 2, 2999)]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use lasso::ThreadedRodeo;
|
||||
use measured::{
|
||||
label::StaticLabelSet,
|
||||
label::{FixedCardinalitySet, LabelName, LabelSet, LabelValue, StaticLabelSet},
|
||||
metric::{histogram::Thresholds, name::MetricName},
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
|
||||
MetricGroup,
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
|
||||
LabelGroup, MetricGroup,
|
||||
};
|
||||
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
|
||||
|
||||
@@ -14,26 +14,36 @@ use tokio::time::{self, Instant};
|
||||
use crate::console::messages::ColdStartInfo;
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
pub struct Metrics {
|
||||
#[metric(namespace = "proxy")]
|
||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
||||
pub proxy: ProxyMetrics,
|
||||
|
||||
#[metric(namespace = "wake_compute_lock")]
|
||||
pub wake_compute_lock: ApiLockMetrics,
|
||||
}
|
||||
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
impl Metrics {
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
||||
SELF.set(Metrics::new(thread_pool))
|
||||
.ok()
|
||||
.expect("proxy metrics must not be installed more than once");
|
||||
}
|
||||
|
||||
pub fn get() -> &'static Self {
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
SELF.get_or_init(|| Metrics {
|
||||
proxy: ProxyMetrics::default(),
|
||||
wake_compute_lock: ApiLockMetrics::new(),
|
||||
})
|
||||
#[cfg(test)]
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
||||
|
||||
#[cfg(not(test))]
|
||||
SELF.get()
|
||||
.expect("proxy metrics must be installed by the main() function")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new())]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
pub struct ProxyMetrics {
|
||||
#[metric(flatten)]
|
||||
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
||||
@@ -129,6 +139,10 @@ pub struct ProxyMetrics {
|
||||
|
||||
#[metric(namespace = "connect_compute_lock")]
|
||||
pub connect_compute_lock: ApiLockMetrics,
|
||||
|
||||
#[metric(namespace = "scram_pool")]
|
||||
#[metric(init = thread_pool)]
|
||||
pub scram_pool: Arc<ThreadPoolMetrics>,
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
@@ -146,12 +160,6 @@ pub struct ApiLockMetrics {
|
||||
pub semaphore_acquire_seconds: Histogram<16>,
|
||||
}
|
||||
|
||||
impl Default for ProxyMetrics {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ApiLockMetrics {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
@@ -553,3 +561,52 @@ pub enum RedisEventsCount {
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
pub struct ThreadPoolWorkerId(pub usize);
|
||||
|
||||
impl LabelValue for ThreadPoolWorkerId {
|
||||
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
|
||||
v.write_int(self.0 as i64)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroup for ThreadPoolWorkerId {
|
||||
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
|
||||
v.write_value(LabelName::from_str("worker"), self);
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelSet for ThreadPoolWorkers {
|
||||
type Value<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn dynamic_cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
|
||||
(value.0 < self.0).then_some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: usize) -> Self::Value<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedCardinalitySet for ThreadPoolWorkers {
|
||||
fn cardinality(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(workers: usize))]
|
||||
pub struct ThreadPoolMetrics {
|
||||
pub injector_queue_depth: Gauge,
|
||||
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::{
|
||||
metrics::{Metrics, NumClientConnectionsGuard},
|
||||
protocol2::read_proxy_protocol,
|
||||
proxy::handshake::{handshake, HandshakeData},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
EndpointCacheKey,
|
||||
};
|
||||
@@ -61,6 +62,7 @@ pub async fn task_main(
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -86,6 +88,7 @@ pub async fn task_main(
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
|
||||
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await{
|
||||
@@ -123,6 +126,7 @@ pub async fn task_main(
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
)
|
||||
.instrument(span.clone())
|
||||
@@ -234,6 +238,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
info!(
|
||||
@@ -243,7 +248,6 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol;
|
||||
// let _client_gauge = metrics.client_connections.guard(proto);
|
||||
let _request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
@@ -286,6 +290,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -128,12 +128,18 @@ impl std::str::FromStr for RateBucketInfo {
|
||||
}
|
||||
|
||||
impl RateBucketInfo {
|
||||
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
|
||||
pub const DEFAULT_SET: [Self; 3] = [
|
||||
Self::new(300, Duration::from_secs(1)),
|
||||
Self::new(200, Duration::from_secs(60)),
|
||||
Self::new(100, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
|
||||
Self::new(500, Duration::from_secs(1)),
|
||||
Self::new(300, Duration::from_secs(60)),
|
||||
Self::new(200, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
|
||||
info.sort_unstable_by_key(|info| info.interval);
|
||||
let invalid = info
|
||||
@@ -266,7 +272,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn default_rate_buckets() {
|
||||
let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
|
||||
let mut defaults = RateBucketInfo::DEFAULT_SET;
|
||||
RateBucketInfo::validate(&mut defaults[..]).unwrap();
|
||||
}
|
||||
|
||||
@@ -333,11 +339,8 @@ mod tests {
|
||||
let rand = rand::rngs::StdRng::from_seed([1; 32]);
|
||||
let hasher = BuildHasherDefault::<FxHasher>::default();
|
||||
|
||||
let limiter = BucketRateLimiter::new_with_rand_and_hasher(
|
||||
&RateBucketInfo::DEFAULT_ENDPOINT_SET,
|
||||
rand,
|
||||
hasher,
|
||||
);
|
||||
let limiter =
|
||||
BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
|
||||
for i in 0..1_000_000 {
|
||||
limiter.check(i, 1);
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod countmin;
|
||||
mod exchange;
|
||||
mod key;
|
||||
mod messages;
|
||||
mod pbkdf2;
|
||||
mod secret;
|
||||
mod signature;
|
||||
pub mod threadpool;
|
||||
|
||||
pub use exchange::{exchange, Exchange};
|
||||
pub use key::ScramKey;
|
||||
@@ -56,9 +59,13 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::sasl::{Mechanism, Step};
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
sasl::{Mechanism, Step},
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
use super::{Exchange, ServerSecret};
|
||||
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
|
||||
|
||||
#[test]
|
||||
fn snapshot() {
|
||||
@@ -112,8 +119,13 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
let outcome = super::exchange(&scram_secret, client_password.as_bytes())
|
||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
173
proxy/src/scram/countmin.rs
Normal file
173
proxy/src/scram/countmin.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::hash::Hash;
|
||||
|
||||
/// estimator of hash jobs per second.
|
||||
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
|
||||
pub struct CountMinSketch {
|
||||
// one for each depth
|
||||
hashers: Vec<ahash::RandomState>,
|
||||
width: usize,
|
||||
depth: usize,
|
||||
// buckets, width*depth
|
||||
buckets: Vec<u32>,
|
||||
}
|
||||
|
||||
impl CountMinSketch {
|
||||
/// Given parameters (ε, δ),
|
||||
/// set width = ceil(e/ε)
|
||||
/// set depth = ceil(ln(1/δ))
|
||||
///
|
||||
/// guarantees:
|
||||
/// actual <= estimate
|
||||
/// estimate <= actual + ε * N with probability 1 - δ
|
||||
/// where N is the cardinality of the stream
|
||||
pub fn with_params(epsilon: f64, delta: f64) -> Self {
|
||||
CountMinSketch::new(
|
||||
(std::f64::consts::E / epsilon).ceil() as usize,
|
||||
(1.0_f64 / delta).ln().ceil() as usize,
|
||||
)
|
||||
}
|
||||
|
||||
fn new(width: usize, depth: usize) -> Self {
|
||||
Self {
|
||||
#[cfg(test)]
|
||||
hashers: (0..depth)
|
||||
.map(|i| {
|
||||
// digits of pi for good randomness
|
||||
ahash::RandomState::with_seeds(
|
||||
314159265358979323,
|
||||
84626433832795028,
|
||||
84197169399375105,
|
||||
82097494459230781 + i as u64,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
#[cfg(not(test))]
|
||||
hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
|
||||
width,
|
||||
depth,
|
||||
buckets: vec![0; width * depth],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
|
||||
let mut min = u32::MAX;
|
||||
for row in 0..self.depth {
|
||||
let col = (self.hashers[row].hash_one(t) as usize) % self.width;
|
||||
|
||||
let row = &mut self.buckets[row * self.width..][..self.width];
|
||||
row[col] = row[col].saturating_add(x);
|
||||
min = std::cmp::min(min, row[col]);
|
||||
}
|
||||
min
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.buckets.clear();
|
||||
self.buckets.resize(self.width * self.depth, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
|
||||
|
||||
use super::CountMinSketch;
|
||||
|
||||
fn eval_precision(n: usize, p: f64, q: f64) -> usize {
|
||||
// fixed value of phi for consistent test
|
||||
let mut rng = StdRng::seed_from_u64(16180339887498948482);
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
let mut N = 0;
|
||||
|
||||
let mut ids = vec![];
|
||||
|
||||
for _ in 0..n {
|
||||
// number of insert operations
|
||||
let n = rng.gen_range(1..100);
|
||||
// number to insert at once
|
||||
let m = rng.gen_range(1..4096);
|
||||
|
||||
let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
|
||||
ids.push((id, n, m));
|
||||
|
||||
// N = sum(actual)
|
||||
N += n * m;
|
||||
}
|
||||
|
||||
// q% of counts will be within p of the actual value
|
||||
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
|
||||
|
||||
dbg!(sketch.buckets.len());
|
||||
|
||||
// insert a bunch of entries in a random order
|
||||
let mut ids2 = ids.clone();
|
||||
while !ids2.is_empty() {
|
||||
ids2.shuffle(&mut rng);
|
||||
|
||||
let mut i = 0;
|
||||
while i < ids2.len() {
|
||||
sketch.inc_and_return(&ids2[i].0, ids2[i].1);
|
||||
ids2[i].2 -= 1;
|
||||
if ids2[i].2 == 0 {
|
||||
ids2.remove(i);
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut within_p = 0;
|
||||
for (id, n, m) in ids {
|
||||
let actual = n * m;
|
||||
let estimate = sketch.inc_and_return(&id, 0);
|
||||
|
||||
// This estimate has the guarantee that actual <= estimate
|
||||
assert!(actual <= estimate);
|
||||
|
||||
// This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
|
||||
// ε = p / N, δ = 1 - q;
|
||||
// therefore, estimate <= actual + p with probability q.
|
||||
if estimate as f64 <= actual as f64 + p {
|
||||
within_p += 1;
|
||||
}
|
||||
}
|
||||
within_p
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn precision() {
|
||||
assert_eq!(eval_precision(100, 100.0, 0.99), 100);
|
||||
assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
|
||||
|
||||
// seems to be more precise than the literature indicates?
|
||||
// probably numbers are too small to truly represent the probabilities.
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
|
||||
assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
|
||||
assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
|
||||
}
|
||||
|
||||
// returns memory usage in bytes, and the time complexity per insert.
|
||||
fn eval_cost(p: f64, q: f64) -> (usize, usize) {
|
||||
#[allow(non_snake_case)]
|
||||
// N = sum(actual)
|
||||
// Let's assume 1021 samples, all of 4096
|
||||
let N = 1021 * 4096;
|
||||
let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
|
||||
|
||||
let memory = std::mem::size_of::<u32>() * sketch.buckets.len();
|
||||
let time = sketch.depth;
|
||||
(memory, time)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn memory_usage() {
|
||||
assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
|
||||
assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
|
||||
assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
|
||||
assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
|
||||
}
|
||||
}
|
||||
@@ -4,15 +4,17 @@ use std::convert::Infallible;
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use tokio::task::yield_now;
|
||||
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::ScramKey;
|
||||
use crate::config;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
@@ -74,37 +76,18 @@ impl<'a> Exchange<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||
async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
|
||||
let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
let mut prev = hmac
|
||||
.clone()
|
||||
.chain_update(salt)
|
||||
.chain_update(1u32.to_be_bytes())
|
||||
.finalize()
|
||||
.into_bytes();
|
||||
|
||||
let mut hi = prev;
|
||||
|
||||
for i in 1..iterations {
|
||||
prev = hmac.clone().chain_update(prev).finalize().into_bytes();
|
||||
|
||||
for (hi, prev) in hi.iter_mut().zip(prev) {
|
||||
*hi ^= prev;
|
||||
}
|
||||
// yield every ~250us
|
||||
// hopefully reduces tail latencies
|
||||
if i % 1024 == 0 {
|
||||
yield_now().await
|
||||
}
|
||||
}
|
||||
|
||||
hi.into()
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
||||
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
|
||||
let salted_password = pbkdf2(password, salt, iterations).await;
|
||||
async fn derive_client_key(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
salt: &[u8],
|
||||
iterations: u32,
|
||||
) -> ScramKey {
|
||||
let salted_password = pool
|
||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await
|
||||
.expect("job should not be cancelled");
|
||||
|
||||
let make_key = |name| {
|
||||
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
@@ -119,11 +102,13 @@ async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> Scr
|
||||
}
|
||||
|
||||
pub async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = base64::decode(&secret.salt_base64)?;
|
||||
let client_key = derive_client_key(password, &salt, secret.iterations).await;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
||||
|
||||
89
proxy/src/scram/pbkdf2.rs
Normal file
89
proxy/src/scram/pbkdf2.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use hmac::{
|
||||
digest::{consts::U32, generic_array::GenericArray},
|
||||
Hmac, Mac,
|
||||
};
|
||||
use sha2::Sha256;
|
||||
|
||||
pub struct Pbkdf2 {
|
||||
hmac: Hmac<Sha256>,
|
||||
prev: GenericArray<u8, U32>,
|
||||
hi: GenericArray<u8, U32>,
|
||||
iterations: u32,
|
||||
}
|
||||
|
||||
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||
impl Pbkdf2 {
|
||||
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
let hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
|
||||
let prev = hmac
|
||||
.clone()
|
||||
.chain_update(salt)
|
||||
.chain_update(1u32.to_be_bytes())
|
||||
.finalize()
|
||||
.into_bytes();
|
||||
|
||||
Self {
|
||||
hmac,
|
||||
// one consumed for the hash above
|
||||
iterations: iterations - 1,
|
||||
hi: prev,
|
||||
prev,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cost(&self) -> u32 {
|
||||
(self.iterations).clamp(0, 4096)
|
||||
}
|
||||
|
||||
pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
|
||||
let Self {
|
||||
hmac,
|
||||
prev,
|
||||
hi,
|
||||
iterations,
|
||||
} = self;
|
||||
|
||||
// only do 4096 iterations per turn before sharing the thread for fairness
|
||||
let n = (*iterations).clamp(0, 4096);
|
||||
for _ in 0..n {
|
||||
*prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
|
||||
|
||||
for (hi, prev) in hi.iter_mut().zip(*prev) {
|
||||
*hi ^= prev;
|
||||
}
|
||||
}
|
||||
|
||||
*iterations -= n;
|
||||
if *iterations == 0 {
|
||||
std::task::Poll::Ready((*hi).into())
|
||||
} else {
|
||||
std::task::Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Pbkdf2;
|
||||
use pbkdf2::pbkdf2_hmac_array;
|
||||
use sha2::Sha256;
|
||||
|
||||
#[test]
|
||||
fn works() {
|
||||
let salt = b"sodium chloride";
|
||||
let pass = b"Ne0n_!5_50_C007";
|
||||
|
||||
let mut job = Pbkdf2::start(pass, salt, 600000);
|
||||
let hash = loop {
|
||||
let std::task::Poll::Ready(hash) = job.turn() else {
|
||||
continue;
|
||||
};
|
||||
break hash;
|
||||
};
|
||||
|
||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
|
||||
assert_eq!(hash, expected)
|
||||
}
|
||||
}
|
||||
321
proxy/src/scram/threadpool.rs
Normal file
321
proxy/src/scram/threadpool.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
//! Custom threadpool implementation for password hashing.
|
||||
//!
|
||||
//! Requirements:
|
||||
//! 1. Fairness per endpoint.
|
||||
//! 2. Yield support for high iteration counts.
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use crossbeam_deque::{Injector, Stealer, Worker};
|
||||
use itertools::Itertools;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use rand::Rng;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
|
||||
scram::countmin::CountMinSketch,
|
||||
};
|
||||
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
|
||||
pub struct ThreadPool {
|
||||
queue: Injector<JobSpec>,
|
||||
stealers: Vec<Stealer<JobSpec>>,
|
||||
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
|
||||
/// bitpacked representation.
|
||||
/// lower 8 bits = number of sleeping threads
|
||||
/// next 8 bits = number of idle threads (searching for work)
|
||||
counters: AtomicU64,
|
||||
|
||||
pub metrics: Arc<ThreadPoolMetrics>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum ThreadState {
|
||||
Parked,
|
||||
Active,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
pub fn new(n_workers: u8) -> Arc<Self> {
|
||||
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
|
||||
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
|
||||
|
||||
let parkers = (0..n_workers)
|
||||
.map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
|
||||
.collect_vec();
|
||||
|
||||
let pool = Arc::new(Self {
|
||||
queue: Injector::new(),
|
||||
stealers,
|
||||
parkers,
|
||||
// threads start searching for work
|
||||
counters: AtomicU64::new((n_workers as u64) << 8),
|
||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||
});
|
||||
|
||||
for (i, worker) in workers.into_iter().enumerate() {
|
||||
let pool = Arc::clone(&pool);
|
||||
std::thread::spawn(move || thread_rt(pool, worker, i));
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
|
||||
pub fn spawn_job(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
pbkdf2: Pbkdf2,
|
||||
) -> oneshot::Receiver<[u8; 32]> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let queue_was_empty = self.queue.is_empty();
|
||||
|
||||
self.metrics.injector_queue_depth.inc();
|
||||
self.queue.push(JobSpec {
|
||||
response: tx,
|
||||
pbkdf2,
|
||||
endpoint,
|
||||
});
|
||||
|
||||
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
|
||||
let counts = self.counters.load(Ordering::SeqCst);
|
||||
let num_awake_but_idle = (counts >> 8) & 0xff;
|
||||
let num_sleepers = counts & 0xff;
|
||||
|
||||
// If the queue is non-empty, then we always wake up a worker
|
||||
// -- clearly the existing idle jobs aren't enough. Otherwise,
|
||||
// check to see if we have enough idle workers.
|
||||
if !queue_was_empty || num_awake_but_idle == 0 {
|
||||
let num_to_wake = Ord::min(1, num_sleepers);
|
||||
self.wake_any_threads(num_to_wake);
|
||||
}
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
#[cold]
|
||||
fn wake_any_threads(&self, mut num_to_wake: u64) {
|
||||
if num_to_wake > 0 {
|
||||
for i in 0..self.parkers.len() {
|
||||
if self.wake_specific_thread(i) {
|
||||
num_to_wake -= 1;
|
||||
if num_to_wake == 0 {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wake_specific_thread(&self, index: usize) -> bool {
|
||||
let (condvar, lock) = &self.parkers[index];
|
||||
|
||||
let mut state = lock.lock();
|
||||
if *state == ThreadState::Parked {
|
||||
condvar.notify_one();
|
||||
|
||||
// When the thread went to sleep, it will have incremented
|
||||
// this value. When we wake it, its our job to decrement
|
||||
// it. We could have the thread do it, but that would
|
||||
// introduce a delay between when the thread was
|
||||
// *notified* and when this counter was decremented. That
|
||||
// might mislead people with new work into thinking that
|
||||
// there are sleeping threads that they should try to
|
||||
// wake, when in fact there is nothing left for them to
|
||||
// do.
|
||||
self.counters.fetch_sub(1, Ordering::SeqCst);
|
||||
*state = ThreadState::Active;
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
|
||||
// announce thread as idle
|
||||
self.counters.fetch_add(256, Ordering::SeqCst);
|
||||
|
||||
// try steal from the global queue
|
||||
loop {
|
||||
match self.queue.steal_batch_and_pop(worker) {
|
||||
crossbeam_deque::Steal::Success(job) => {
|
||||
self.metrics
|
||||
.injector_queue_depth
|
||||
.set(self.queue.len() as i64);
|
||||
// no longer idle
|
||||
self.counters.fetch_sub(256, Ordering::SeqCst);
|
||||
return Some(job);
|
||||
}
|
||||
crossbeam_deque::Steal::Retry => continue,
|
||||
crossbeam_deque::Steal::Empty => break,
|
||||
}
|
||||
}
|
||||
|
||||
// try steal from our neighbours
|
||||
loop {
|
||||
let mut retry = false;
|
||||
let start = rng.gen_range(0..self.stealers.len());
|
||||
let job = (start..self.stealers.len())
|
||||
.chain(0..start)
|
||||
.filter(|i| *i != skip)
|
||||
.find_map(
|
||||
|victim| match self.stealers[victim].steal_batch_and_pop(worker) {
|
||||
crossbeam_deque::Steal::Success(job) => Some(job),
|
||||
crossbeam_deque::Steal::Empty => None,
|
||||
crossbeam_deque::Steal::Retry => {
|
||||
retry = true;
|
||||
None
|
||||
}
|
||||
},
|
||||
);
|
||||
if job.is_some() {
|
||||
// no longer idle
|
||||
self.counters.fetch_sub(256, Ordering::SeqCst);
|
||||
return job;
|
||||
}
|
||||
if !retry {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
/// interval when we should steal from the global queue
|
||||
/// so that tail latencies are managed appropriately
|
||||
const STEAL_INTERVAL: usize = 61;
|
||||
|
||||
/// How often to reset the sketch values
|
||||
const SKETCH_RESET_INTERVAL: usize = 1021;
|
||||
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
|
||||
// used to determine whether we should temporarily skip tasks for fairness.
|
||||
// 99% of estimates will overcount by no more than 4096 samples
|
||||
let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
|
||||
|
||||
let (condvar, lock) = &pool.parkers[index];
|
||||
|
||||
'wait: loop {
|
||||
// wait for notification of work
|
||||
{
|
||||
let mut lock = lock.lock();
|
||||
|
||||
// queue is empty
|
||||
pool.metrics
|
||||
.worker_queue_depth
|
||||
.set(ThreadPoolWorkerId(index), 0);
|
||||
|
||||
// subtract 1 from idle count, add 1 to sleeping count.
|
||||
pool.counters.fetch_sub(255, Ordering::SeqCst);
|
||||
|
||||
*lock = ThreadState::Parked;
|
||||
condvar.wait(&mut lock);
|
||||
}
|
||||
|
||||
for i in 0.. {
|
||||
let mut job = match worker
|
||||
.pop()
|
||||
.or_else(|| pool.steal(&mut rng, index, &worker))
|
||||
{
|
||||
Some(job) => job,
|
||||
None => continue 'wait,
|
||||
};
|
||||
|
||||
pool.metrics
|
||||
.worker_queue_depth
|
||||
.set(ThreadPoolWorkerId(index), worker.len() as i64);
|
||||
|
||||
// receiver is closed, cancel the task
|
||||
if !job.response.is_closed() {
|
||||
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
|
||||
|
||||
const P: f64 = 2000.0;
|
||||
// probability decreases as rate increases.
|
||||
// lower probability, higher chance of being skipped
|
||||
//
|
||||
// estimates (rate in terms of 4096 rounds):
|
||||
// rate = 0 => probability = 100%
|
||||
// rate = 10 => probability = 71.3%
|
||||
// rate = 50 => probability = 62.1%
|
||||
// rate = 500 => probability = 52.3%
|
||||
// rate = 1021 => probability = 49.8%
|
||||
//
|
||||
// My expectation is that the pool queue will only begin backing up at ~1000rps
|
||||
// in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
|
||||
// are in requests per second.
|
||||
let probability = P.ln() / (P + rate as f64).ln();
|
||||
if pool.queue.len() > 32 || rng.gen_bool(probability) {
|
||||
pool.metrics
|
||||
.worker_task_turns_total
|
||||
.inc(ThreadPoolWorkerId(index));
|
||||
|
||||
match job.pbkdf2.turn() {
|
||||
std::task::Poll::Ready(result) => {
|
||||
let _ = job.response.send(result);
|
||||
}
|
||||
std::task::Poll::Pending => worker.push(job),
|
||||
}
|
||||
} else {
|
||||
pool.metrics
|
||||
.worker_task_skips_total
|
||||
.inc(ThreadPoolWorkerId(index));
|
||||
|
||||
// skip for now
|
||||
worker.push(job)
|
||||
}
|
||||
}
|
||||
|
||||
// if we get stuck with a few long lived jobs in the queue
|
||||
// it's better to try and steal from the queue too for fairness
|
||||
if i % STEAL_INTERVAL == 0 {
|
||||
let _ = pool.queue.steal_batch(&worker);
|
||||
}
|
||||
|
||||
if i % SKETCH_RESET_INTERVAL == 0 {
|
||||
sketch.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct JobSpec {
|
||||
response: oneshot::Sender<[u8; 32]>,
|
||||
pbkdf2: Pbkdf2,
|
||||
endpoint: EndpointIdInt,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::EndpointId;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn hash_is_correct() {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
|
||||
let salt = [0x55; 32];
|
||||
let actual = pool
|
||||
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let expected = [
|
||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||
];
|
||||
assert_eq!(actual, expected)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@ use crate::context::RequestMonitoring;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::serverless::http_util::{api_error_into_response, json_response};
|
||||
|
||||
@@ -54,6 +55,7 @@ pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
@@ -82,6 +84,7 @@ pub async fn task_main(
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
});
|
||||
|
||||
let tls_config = match config.tls_config.as_ref() {
|
||||
@@ -99,7 +102,7 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
@@ -129,6 +132,7 @@ pub async fn task_main(
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
conn_token.clone(),
|
||||
server.clone(),
|
||||
tls_acceptor.clone(),
|
||||
@@ -162,6 +166,7 @@ async fn connection_handler(
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
server: Builder<TokioExecutor>,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
@@ -245,11 +250,11 @@ async fn connection_handler(
|
||||
session_id,
|
||||
peer_addr,
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
);
|
||||
|
||||
async move {
|
||||
let res = handler.await;
|
||||
cancel_request.disarm();
|
||||
@@ -285,6 +290,7 @@ async fn request_handler(
|
||||
peer_addr: IpAddr,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -294,7 +300,7 @@ async fn request_handler(
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
if framed_websockets::upgrade::is_upgrade_request(&request) {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
@@ -305,14 +311,20 @@ async fn request_handler(
|
||||
let span = ctx.span.clone();
|
||||
info!(parent: &span, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
if let Err(e) =
|
||||
websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
|
||||
.await
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
ctx,
|
||||
websocket,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("error in websocket connection: {e:#}");
|
||||
}
|
||||
@@ -321,7 +333,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
|
||||
@@ -15,7 +15,9 @@ use crate::{
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
error::{ErrorKind, ReportableError, UserFacingError},
|
||||
intern::EndpointIdInt,
|
||||
proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
Host,
|
||||
};
|
||||
|
||||
@@ -24,6 +26,7 @@ use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
pub struct PoolingBackend {
|
||||
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub config: &'static ProxyConfig,
|
||||
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
impl PoolingBackend {
|
||||
@@ -39,6 +42,12 @@ impl PoolingBackend {
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
}
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
.check(conn_info.user_info.endpoint.clone().into(), 1)
|
||||
{
|
||||
return Err(AuthError::too_many_connections());
|
||||
}
|
||||
let cached_secret = match maybe_secret {
|
||||
Some(secret) => secret,
|
||||
None => backend.get_role_secret(ctx).await?,
|
||||
@@ -58,8 +67,14 @@ impl PoolingBackend {
|
||||
return Err(AuthError::auth_failed(&*user_info.user));
|
||||
}
|
||||
};
|
||||
let auth_outcome =
|
||||
crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?;
|
||||
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&config.thread_pool,
|
||||
ep,
|
||||
&conn_info.password,
|
||||
secret,
|
||||
)
|
||||
.await?;
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => {
|
||||
info!("user successfully authenticated");
|
||||
|
||||
@@ -5,11 +5,13 @@ use crate::{
|
||||
error::{io_error, ReportableError},
|
||||
metrics::Metrics,
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use framed_websockets::{Frame, OpCode, WebSocketServer};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use hyper1::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
@@ -20,25 +22,23 @@ use std::{
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::warn;
|
||||
|
||||
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
||||
// Tracking issue: https://github.com/rust-lang/rust/issues/98407.
|
||||
use sync_wrapper::SyncWrapper;
|
||||
|
||||
pin_project! {
|
||||
/// This is a wrapper around a [`WebSocketStream`] that
|
||||
/// implements [`AsyncRead`] and [`AsyncWrite`].
|
||||
pub struct WebSocketRw<S = Upgraded> {
|
||||
pub struct WebSocketRw<S> {
|
||||
#[pin]
|
||||
stream: SyncWrapper<WebSocketStream<S>>,
|
||||
bytes: Bytes,
|
||||
stream: WebSocketServer<S>,
|
||||
recv: Bytes,
|
||||
send: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> WebSocketRw<S> {
|
||||
pub fn new(stream: WebSocketStream<S>) -> Self {
|
||||
pub fn new(stream: WebSocketServer<S>) -> Self {
|
||||
Self {
|
||||
stream: stream.into(),
|
||||
bytes: Bytes::new(),
|
||||
stream,
|
||||
recv: Bytes::new(),
|
||||
send: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -49,22 +49,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut stream = self.project().stream.get_pin_mut();
|
||||
let this = self.project();
|
||||
let mut stream = this.stream;
|
||||
this.send.put(buf);
|
||||
|
||||
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
|
||||
match stream.as_mut().start_send(Message::Binary(buf.into())) {
|
||||
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
|
||||
Ok(()) => Poll::Ready(Ok(buf.len())),
|
||||
Err(e) => Poll::Ready(Err(io_error(e))),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream.get_pin_mut();
|
||||
let stream = self.project().stream;
|
||||
stream.poll_flush(cx).map_err(io_error)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream.get_pin_mut();
|
||||
let stream = self.project().stream;
|
||||
stream.poll_close(cx).map_err(io_error)
|
||||
}
|
||||
}
|
||||
@@ -75,13 +77,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if buf.remaining() > 0 {
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
buf.put_slice(&bytes[..len]);
|
||||
self.consume(len);
|
||||
}
|
||||
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
buf.put_slice(&bytes[..len]);
|
||||
self.consume(len);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
@@ -93,31 +92,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
if !this.bytes.chunk().is_empty() {
|
||||
let chunk = (*this.bytes).chunk();
|
||||
if !this.recv.chunk().is_empty() {
|
||||
let chunk = (*this.recv).chunk();
|
||||
return Poll::Ready(Ok(chunk));
|
||||
}
|
||||
|
||||
let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
|
||||
let res = ready!(this.stream.as_mut().poll_next(cx));
|
||||
match res.transpose().map_err(io_error)? {
|
||||
Some(message) => match message {
|
||||
Message::Ping(_) => {}
|
||||
Message::Pong(_) => {}
|
||||
Message::Text(text) => {
|
||||
Some(message) => match message.opcode {
|
||||
OpCode::Ping => {}
|
||||
OpCode::Pong => {}
|
||||
OpCode::Text => {
|
||||
// We expect to see only binary messages.
|
||||
let error = "unexpected text message in the websocket";
|
||||
warn!(length = text.len(), error);
|
||||
warn!(length = message.payload.len(), error);
|
||||
return Poll::Ready(Err(io_error(error)));
|
||||
}
|
||||
Message::Frame(_) => {
|
||||
// This case is impossible according to Frame's doc.
|
||||
panic!("unexpected raw frame in the websocket");
|
||||
OpCode::Binary | OpCode::Continuation => {
|
||||
debug_assert!(this.recv.is_empty());
|
||||
*this.recv = message.payload.freeze();
|
||||
}
|
||||
Message::Binary(chunk) => {
|
||||
assert!(this.bytes.is_empty());
|
||||
*this.bytes = Bytes::from(chunk);
|
||||
}
|
||||
Message::Close(_) => return EOF,
|
||||
OpCode::Close => return EOF,
|
||||
},
|
||||
None => return EOF,
|
||||
}
|
||||
@@ -125,18 +120,21 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amount: usize) {
|
||||
self.project().bytes.advance(amount);
|
||||
self.project().recv.advance(amount);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
@@ -148,6 +146,7 @@ pub async fn serve_websocket(
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
)
|
||||
.await;
|
||||
@@ -174,15 +173,16 @@ pub async fn serve_websocket(
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
use framed_websockets::WebSocketServer;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use hyper_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio::{
|
||||
io::{duplex, AsyncReadExt, AsyncWriteExt},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
|
||||
use super::WebSocketRw;
|
||||
|
||||
@@ -207,9 +207,7 @@ mod tests {
|
||||
});
|
||||
|
||||
js.spawn(async move {
|
||||
let mut rw = pin!(WebSocketRw::new(
|
||||
WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
|
||||
));
|
||||
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
|
||||
|
||||
let mut buf = vec![0; 1024];
|
||||
let n = rw.read(&mut buf).await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user