mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-16 20:50:37 +00:00
Merge pull request #7909 from neondatabase/rc/proxy/2024-05-30
Proxy release 2024-05-30
This commit is contained in:
@@ -38,6 +38,7 @@ hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
http.workspace = true
|
||||
humantime.workspace = true
|
||||
humantime-serde.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
|
||||
hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
|
||||
@@ -82,6 +83,7 @@ thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
tokio-postgres.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
@@ -94,10 +96,8 @@ url.workspace = true
|
||||
urlencoding.workspace = true
|
||||
utils.workspace = true
|
||||
uuid.workspace = true
|
||||
webpki-roots.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ use crate::{
|
||||
},
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
|
||||
@@ -100,6 +100,7 @@ pub(super) async fn authenticate(
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
ctx.set_dbname(db_info.dbname.into());
|
||||
ctx.set_user(db_info.user.into());
|
||||
ctx.set_project(db_info.aux.clone());
|
||||
info!("woken up a compute node");
|
||||
|
||||
@@ -11,7 +11,6 @@ use crate::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashSet, net::IpAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
@@ -96,13 +95,6 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user: RoleName = get_param("user")?.into();
|
||||
|
||||
// record the values if we have them
|
||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||
ctx.set_user(user.clone());
|
||||
if let Some(dbname) = params.get("database") {
|
||||
ctx.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let endpoint_option = params
|
||||
.options_raw()
|
||||
|
||||
@@ -557,14 +557,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
permits,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse()?;
|
||||
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
|
||||
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
|
||||
let locks = Box::leak(Box::new(console::locks::ApiLocks::new(
|
||||
"wake_compute_lock",
|
||||
permits,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
@@ -603,14 +603,19 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
permits,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.connect_compute_lock.parse()?;
|
||||
info!(permits, shards, ?epoch, "Using NodeLocks (connect_compute)");
|
||||
info!(
|
||||
?limiter,
|
||||
shards,
|
||||
?epoch,
|
||||
"Using NodeLocks (connect_compute)"
|
||||
);
|
||||
let connect_compute_locks = console::locks::ApiLocks::new(
|
||||
"connect_compute_lock",
|
||||
permits,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
|
||||
@@ -10,11 +10,14 @@ use crate::{
|
||||
};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::{io, net::SocketAddr, time::Duration};
|
||||
use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
|
||||
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -30,7 +33,7 @@ pub enum ConnectionError {
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] native_tls::Error),
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
WakeComputeError(#[from] WakeComputeError),
|
||||
@@ -257,7 +260,7 @@ pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
|
||||
tokio::net::TcpStream,
|
||||
postgres_native_tls::TlsStream<tokio::net::TcpStream>,
|
||||
tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
|
||||
>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
@@ -282,12 +285,23 @@ impl ConnCfg {
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
let tls_connector = native_tls::TlsConnector::builder()
|
||||
.danger_accept_invalid_certs(allow_self_signed_compute)
|
||||
.build()
|
||||
.unwrap();
|
||||
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
|
||||
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
|
||||
let client_config = if allow_self_signed_compute {
|
||||
// Allow all certificates for creating the connection
|
||||
let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(verifier)
|
||||
} else {
|
||||
let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone();
|
||||
rustls::ClientConfig::builder().with_root_certificates(root_store)
|
||||
};
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
host,
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
@@ -340,6 +354,58 @@ fn filtered_options(params: &StartupMessageParams) -> Option<String> {
|
||||
Some(options)
|
||||
}
|
||||
|
||||
fn load_certs() -> Result<Arc<rustls::RootCertStore>, io::Error> {
|
||||
let der_certs = rustls_native_certs::load_native_certs()?;
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs);
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AcceptEverythingVerifier;
|
||||
impl ServerCertVerifier for AcceptEverythingVerifier {
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
use rustls::SignatureScheme::*;
|
||||
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
|
||||
vec![
|
||||
ECDSA_NISTP521_SHA512,
|
||||
ECDSA_NISTP384_SHA384,
|
||||
ECDSA_NISTP256_SHA256,
|
||||
RSA_PSS_SHA512,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA256,
|
||||
ED25519,
|
||||
]
|
||||
}
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::RateBucketInfo,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
|
||||
Host,
|
||||
@@ -580,14 +580,18 @@ impl RetryConfig {
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct ConcurrencyLockOptions {
|
||||
/// The number of shards the lock map should have
|
||||
pub shards: usize,
|
||||
/// The number of allowed concurrent requests for each endpoitn
|
||||
pub permits: usize,
|
||||
#[serde(flatten)]
|
||||
pub limiter: RateLimiterConfig,
|
||||
/// Garbage collection epoch
|
||||
#[serde(deserialize_with = "humantime_serde::deserialize")]
|
||||
pub epoch: Duration,
|
||||
/// Lock timeout
|
||||
#[serde(deserialize_with = "humantime_serde::deserialize")]
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
@@ -596,13 +600,18 @@ impl ConcurrencyLockOptions {
|
||||
pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
|
||||
/// Default options for [`crate::console::provider::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
|
||||
"shards=64,permits=10,epoch=10m,timeout=10ms";
|
||||
"shards=64,permits=100,epoch=10m,timeout=10ms";
|
||||
|
||||
// pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
|
||||
|
||||
/// Parse lock options passed via cmdline.
|
||||
/// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let options = options.trim();
|
||||
if options.starts_with('{') && options.ends_with('}') {
|
||||
return Ok(serde_json::from_str(options)?);
|
||||
}
|
||||
|
||||
let mut shards = None;
|
||||
let mut permits = None;
|
||||
let mut epoch = None;
|
||||
@@ -629,9 +638,13 @@ impl ConcurrencyLockOptions {
|
||||
shards = Some(2);
|
||||
}
|
||||
|
||||
let permits = permits.context("missing `permits`")?;
|
||||
let out = Self {
|
||||
shards: shards.context("missing `shards`")?,
|
||||
permits: permits.context("missing `permits`")?,
|
||||
limiter: RateLimiterConfig {
|
||||
algorithm: RateLimitAlgorithm::Fixed,
|
||||
initial_limit: permits,
|
||||
},
|
||||
epoch: epoch.context("missing `epoch`")?,
|
||||
timeout: timeout.context("missing `timeout`")?,
|
||||
};
|
||||
@@ -657,6 +670,8 @@ impl FromStr for ConcurrencyLockOptions {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::rate_limiter::Aimd;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -684,36 +699,68 @@ mod tests {
|
||||
fn test_parse_lock_options() -> anyhow::Result<()> {
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
permits,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(10 * 60));
|
||||
assert_eq!(timeout, Duration::from_secs(1));
|
||||
assert_eq!(shards, 32);
|
||||
assert_eq!(permits, 4);
|
||||
assert_eq!(limiter.initial_limit, 4);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
permits,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(60));
|
||||
assert_eq!(timeout, Duration::from_millis(100));
|
||||
assert_eq!(shards, 16);
|
||||
assert_eq!(permits, 8);
|
||||
assert_eq!(limiter.initial_limit, 8);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
permits,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = "permits=0".parse()?;
|
||||
assert_eq!(epoch, Duration::ZERO);
|
||||
assert_eq!(timeout, Duration::ZERO);
|
||||
assert_eq!(shards, 2);
|
||||
assert_eq!(permits, 0);
|
||||
assert_eq!(limiter.initial_limit, 0);
|
||||
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_json_lock_options() -> anyhow::Result<()> {
|
||||
let ConcurrencyLockOptions {
|
||||
epoch,
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
} = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
|
||||
.parse()?;
|
||||
assert_eq!(epoch, Duration::from_secs(10 * 60));
|
||||
assert_eq!(timeout, Duration::from_secs(1));
|
||||
assert_eq!(shards, 32);
|
||||
assert_eq!(limiter.initial_limit, 44);
|
||||
assert_eq!(
|
||||
limiter.algorithm,
|
||||
RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 5,
|
||||
max: 500,
|
||||
dec: 0.9,
|
||||
inc: 10,
|
||||
utilisation: 0.8
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ use crate::{
|
||||
error::ReportableError,
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
@@ -443,8 +443,8 @@ impl ApiCaches {
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiLocks<K> {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<K, Arc<Semaphore>>,
|
||||
permits: usize,
|
||||
node_locks: DashMap<K, Arc<DynamicLimiter>>,
|
||||
config: RateLimiterConfig,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
@@ -452,8 +452,6 @@ pub struct ApiLocks<K> {
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ApiLockError {
|
||||
#[error("lock was closed")]
|
||||
AcquireError(#[from] tokio::sync::AcquireError),
|
||||
#[error("permit could not be acquired")]
|
||||
TimeoutError(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
@@ -461,7 +459,6 @@ pub enum ApiLockError {
|
||||
impl ReportableError for ApiLockError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiLockError::AcquireError(_) => crate::error::ErrorKind::Service,
|
||||
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
@@ -470,7 +467,7 @@ impl ReportableError for ApiLockError {
|
||||
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
permits: usize,
|
||||
config: RateLimiterConfig,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
@@ -479,7 +476,7 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
permits,
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
@@ -487,8 +484,10 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
}
|
||||
|
||||
pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
if self.permits == 0 {
|
||||
return Ok(WakeComputePermit { permit: None });
|
||||
if self.config.initial_limit == 0 {
|
||||
return Ok(WakeComputePermit {
|
||||
permit: Token::disabled(),
|
||||
});
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
@@ -500,24 +499,22 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.metrics.semaphores_registered.inc();
|
||||
Arc::new(Semaphore::new(self.permits))
|
||||
DynamicLimiter::new(self.config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
|
||||
let permit = semaphore.acquire_deadline(now + self.timeout).await;
|
||||
|
||||
self.metrics
|
||||
.semaphore_acquire_seconds
|
||||
.observe(now.elapsed().as_secs_f64());
|
||||
|
||||
Ok(WakeComputePermit {
|
||||
permit: Some(permit??),
|
||||
})
|
||||
Ok(WakeComputePermit { permit: permit? })
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self) {
|
||||
if self.permits == 0 {
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let mut interval =
|
||||
@@ -547,12 +544,21 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
}
|
||||
|
||||
pub struct WakeComputePermit {
|
||||
// None if the lock is disabled
|
||||
permit: Option<OwnedSemaphorePermit>,
|
||||
permit: Token,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub fn should_check_cache(&self) -> bool {
|
||||
self.permit.is_some()
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome)
|
||||
}
|
||||
pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
Ok(_) => self.release(Outcome::Success),
|
||||
Err(_) => self.release(Outcome::Overload),
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
scram, EndpointCacheKey, Normalize,
|
||||
scram, EndpointCacheKey,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use futures::TryFutureExt;
|
||||
@@ -281,14 +281,6 @@ impl super::Api for Api {
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !self
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize().into(), 1)
|
||||
{
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
}
|
||||
|
||||
let permit = self.locks.get_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
@@ -301,7 +293,16 @@ impl super::Api for Api {
|
||||
}
|
||||
}
|
||||
|
||||
let mut node = self.do_wake_compute(ctx, user_info).await?;
|
||||
// check rate limit
|
||||
if !self
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize_intern(), 1)
|
||||
{
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
}
|
||||
|
||||
let mut node = permit.release_result(self.do_wake_compute(ctx, user_info).await)?;
|
||||
ctx.set_project(node.aux.clone());
|
||||
let cold_start_info = node.aux.cold_start_info;
|
||||
info!("woken up a compute node");
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -46,6 +47,7 @@ pub struct RequestMonitoring {
|
||||
pub(crate) auth_method: Option<AuthMethod>,
|
||||
success: bool,
|
||||
pub(crate) cold_start_info: ColdStartInfo,
|
||||
pg_options: Option<StartupMessageParams>,
|
||||
|
||||
// extra
|
||||
// This sender is here to keep the request monitoring channel open while requests are taking place.
|
||||
@@ -102,6 +104,7 @@ impl RequestMonitoring {
|
||||
success: false,
|
||||
rejected: None,
|
||||
cold_start_info: ColdStartInfo::Unknown,
|
||||
pg_options: None,
|
||||
|
||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
|
||||
@@ -132,6 +135,18 @@ impl RequestMonitoring {
|
||||
self.latency_timer.cold_start_info(info);
|
||||
}
|
||||
|
||||
pub fn set_db_options(&mut self, options: StartupMessageParams) {
|
||||
self.set_application(options.get("application_name").map(SmolStr::from));
|
||||
if let Some(user) = options.get("user") {
|
||||
self.set_user(user.into());
|
||||
}
|
||||
if let Some(dbname) = options.get("database") {
|
||||
self.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
self.pg_options = Some(options);
|
||||
}
|
||||
|
||||
pub fn set_project(&mut self, x: MetricsAuxInfo) {
|
||||
if self.endpoint_id.is_none() {
|
||||
self.set_endpoint_id(x.endpoint_id.as_str().into())
|
||||
@@ -155,8 +170,10 @@ impl RequestMonitoring {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_application(&mut self, app: Option<SmolStr>) {
|
||||
self.application = app.or_else(|| self.application.clone());
|
||||
fn set_application(&mut self, app: Option<SmolStr>) {
|
||||
if let Some(app) = app {
|
||||
self.application = Some(app);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dbname(&mut self, dbname: DbName) {
|
||||
|
||||
@@ -13,7 +13,9 @@ use parquet::{
|
||||
},
|
||||
record::RecordWriter,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
|
||||
use serde::ser::SerializeMap;
|
||||
use tokio::{sync::mpsc, time};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, Span};
|
||||
@@ -87,6 +89,7 @@ pub struct RequestData {
|
||||
database: Option<String>,
|
||||
project: Option<String>,
|
||||
branch: Option<String>,
|
||||
pg_options: Option<String>,
|
||||
auth_method: Option<&'static str>,
|
||||
error: Option<&'static str>,
|
||||
/// Success is counted if we form a HTTP response with sql rows inside
|
||||
@@ -101,6 +104,23 @@ pub struct RequestData {
|
||||
disconnect_timestamp: Option<chrono::NaiveDateTime>,
|
||||
}
|
||||
|
||||
struct Options<'a> {
|
||||
options: &'a StartupMessageParams,
|
||||
}
|
||||
|
||||
impl<'a> serde::Serialize for Options<'a> {
|
||||
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut state = s.serialize_map(None)?;
|
||||
for (k, v) in self.options.iter() {
|
||||
state.serialize_entry(k, v)?;
|
||||
}
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RequestMonitoring> for RequestData {
|
||||
fn from(value: &RequestMonitoring) -> Self {
|
||||
Self {
|
||||
@@ -113,6 +133,10 @@ impl From<&RequestMonitoring> for RequestData {
|
||||
database: value.dbname.as_deref().map(String::from),
|
||||
project: value.project.as_deref().map(String::from),
|
||||
branch: value.branch.as_deref().map(String::from),
|
||||
pg_options: value
|
||||
.pg_options
|
||||
.as_ref()
|
||||
.and_then(|options| serde_json::to_string(&Options { options }).ok()),
|
||||
auth_method: value.auth_method.as_ref().map(|x| match x {
|
||||
super::AuthMethod::Web => "web",
|
||||
super::AuthMethod::ScramSha256 => "scram_sha_256",
|
||||
@@ -494,6 +518,7 @@ mod tests {
|
||||
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
pg_options: None,
|
||||
auth_method: None,
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
@@ -570,15 +595,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315314, 3, 6000),
|
||||
(1315307, 3, 6000),
|
||||
(1315367, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(1315454, 3, 6000),
|
||||
(1315296, 3, 6000),
|
||||
(1315088, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(438713, 1, 2000)
|
||||
(1315874, 3, 6000),
|
||||
(1315867, 3, 6000),
|
||||
(1315927, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(1316014, 3, 6000),
|
||||
(1315856, 3, 6000),
|
||||
(1315648, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(438913, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -608,11 +633,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1222212, 5, 10000),
|
||||
(1228362, 5, 10000),
|
||||
(1230156, 5, 10000),
|
||||
(1229518, 5, 10000),
|
||||
(1220796, 5, 10000)
|
||||
(1223214, 5, 10000),
|
||||
(1229364, 5, 10000),
|
||||
(1231158, 5, 10000),
|
||||
(1230520, 5, 10000),
|
||||
(1221798, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -644,11 +669,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1207859, 5, 10000),
|
||||
(1207590, 5, 10000),
|
||||
(1207883, 5, 10000),
|
||||
(1207871, 5, 10000),
|
||||
(1208126, 5, 10000)
|
||||
(1208861, 5, 10000),
|
||||
(1208592, 5, 10000),
|
||||
(1208885, 5, 10000),
|
||||
(1208873, 5, 10000),
|
||||
(1209128, 5, 10000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -673,15 +698,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1315314, 3, 6000),
|
||||
(1315307, 3, 6000),
|
||||
(1315367, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(1315454, 3, 6000),
|
||||
(1315296, 3, 6000),
|
||||
(1315088, 3, 6000),
|
||||
(1315324, 3, 6000),
|
||||
(438713, 1, 2000)
|
||||
(1315874, 3, 6000),
|
||||
(1315867, 3, 6000),
|
||||
(1315927, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(1316014, 3, 6000),
|
||||
(1315856, 3, 6000),
|
||||
(1315648, 3, 6000),
|
||||
(1315884, 3, 6000),
|
||||
(438913, 1, 2000)
|
||||
]
|
||||
);
|
||||
|
||||
@@ -718,7 +743,7 @@ mod tests {
|
||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[(659462, 2, 3001), (659176, 2, 3000), (658972, 2, 2999)]
|
||||
[(659836, 2, 3001), (659550, 2, 3000), (659346, 2, 2999)]
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
@@ -129,20 +130,22 @@ macro_rules! smol_str_wrapper {
|
||||
|
||||
const POOLER_SUFFIX: &str = "-pooler";
|
||||
|
||||
pub trait Normalize {
|
||||
fn normalize(&self) -> Self;
|
||||
}
|
||||
|
||||
impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
|
||||
impl EndpointId {
|
||||
fn normalize(&self) -> Self {
|
||||
if self.as_ref().ends_with(POOLER_SUFFIX) {
|
||||
let mut s = self.as_ref().to_string();
|
||||
s.truncate(s.len() - POOLER_SUFFIX.len());
|
||||
s.into()
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
stripped.into()
|
||||
} else {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_intern(&self) -> EndpointIdInt {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
EndpointIdTag::get_interner().get_or_intern(stripped)
|
||||
} else {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 90% of role name strings are 20 characters or less.
|
||||
|
||||
@@ -267,6 +267,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
@@ -84,8 +84,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let _permit = self.locks.get_permit(&host).await?;
|
||||
node_info.connect(ctx, timeout).await
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, timeout).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
|
||||
@@ -1,2 +1,6 @@
|
||||
mod limit_algorithm;
|
||||
mod limiter;
|
||||
pub use limit_algorithm::{
|
||||
aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
|
||||
};
|
||||
pub use limiter::{BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
|
||||
|
||||
275
proxy/src/rate_limiter/limit_algorithm.rs
Normal file
275
proxy/src/rate_limiter/limit_algorithm.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
//! Algorithms for controlling concurrency limits.
|
||||
use parking_lot::Mutex;
|
||||
use std::{pin::pin, sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::Notify,
|
||||
time::{error::Elapsed, timeout_at, Instant},
|
||||
};
|
||||
|
||||
use self::aimd::Aimd;
|
||||
|
||||
pub mod aimd;
|
||||
|
||||
/// Whether a job succeeded or failed as a result of congestion/overload.
|
||||
///
|
||||
/// Errors not considered to be caused by overload should be ignored.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Outcome {
|
||||
/// The job succeeded, or failed in a way unrelated to overload.
|
||||
Success,
|
||||
/// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
|
||||
/// was observed.
|
||||
Overload,
|
||||
}
|
||||
|
||||
/// An algorithm for controlling a concurrency limit.
|
||||
pub trait LimitAlgorithm: Send + Sync + 'static {
|
||||
/// Update the concurrency limit in response to a new job completion.
|
||||
fn update(&self, old_limit: usize, sample: Sample) -> usize;
|
||||
}
|
||||
|
||||
/// The result of a job (or jobs), including the [`Outcome`] (loss) and latency (delay).
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub struct Sample {
|
||||
pub(crate) latency: Duration,
|
||||
/// Jobs in flight when the sample was taken.
|
||||
pub(crate) in_flight: usize,
|
||||
pub(crate) outcome: Outcome,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, serde::Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RateLimitAlgorithm {
|
||||
#[default]
|
||||
Fixed,
|
||||
Aimd {
|
||||
#[serde(flatten)]
|
||||
conf: Aimd,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Fixed;
|
||||
|
||||
impl LimitAlgorithm for Fixed {
|
||||
fn update(&self, old_limit: usize, _sample: Sample) -> usize {
|
||||
old_limit
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
|
||||
pub struct RateLimiterConfig {
|
||||
#[serde(flatten)]
|
||||
pub algorithm: RateLimitAlgorithm,
|
||||
pub initial_limit: usize,
|
||||
}
|
||||
|
||||
impl RateLimiterConfig {
|
||||
pub fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
|
||||
match self.algorithm {
|
||||
RateLimitAlgorithm::Fixed => Box::new(Fixed),
|
||||
RateLimitAlgorithm::Aimd { conf } => Box::new(conf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LimiterInner {
|
||||
alg: Box<dyn LimitAlgorithm>,
|
||||
available: usize,
|
||||
limit: usize,
|
||||
in_flight: usize,
|
||||
}
|
||||
|
||||
impl LimiterInner {
|
||||
fn update(&mut self, latency: Duration, outcome: Option<Outcome>) {
|
||||
if let Some(outcome) = outcome {
|
||||
let sample = Sample {
|
||||
latency,
|
||||
in_flight: self.in_flight,
|
||||
outcome,
|
||||
};
|
||||
self.limit = self.alg.update(self.limit, sample);
|
||||
}
|
||||
}
|
||||
|
||||
fn take(&mut self, ready: &Notify) -> Option<()> {
|
||||
if self.available > 1 {
|
||||
self.available -= 1;
|
||||
self.in_flight += 1;
|
||||
|
||||
// tell the next in the queue that there is a permit ready
|
||||
if self.available > 1 {
|
||||
ready.notify_one();
|
||||
}
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Limits the number of concurrent jobs.
|
||||
///
|
||||
/// Concurrency is limited through the use of [`Token`]s. Acquire a token to run a job, and release the
|
||||
/// token once the job is finished.
|
||||
///
|
||||
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
|
||||
/// caused by overload (loss).
|
||||
pub struct DynamicLimiter {
|
||||
config: RateLimiterConfig,
|
||||
inner: Mutex<LimiterInner>,
|
||||
// to notify when a token is available
|
||||
ready: Notify,
|
||||
}
|
||||
|
||||
/// A concurrency token, required to run a job.
|
||||
///
|
||||
/// Release the token back to the [`DynamicLimiter`] after the job is complete.
|
||||
pub struct Token {
|
||||
start: Instant,
|
||||
limiter: Option<Arc<DynamicLimiter>>,
|
||||
}
|
||||
|
||||
/// A snapshot of the state of the [`DynamicLimiter`].
|
||||
///
|
||||
/// Not guaranteed to be consistent under high concurrency.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct LimiterState {
|
||||
limit: usize,
|
||||
in_flight: usize,
|
||||
}
|
||||
|
||||
impl DynamicLimiter {
|
||||
/// Create a limiter with a given limit control algorithm.
|
||||
pub fn new(config: RateLimiterConfig) -> Arc<Self> {
|
||||
let ready = Notify::new();
|
||||
ready.notify_one();
|
||||
|
||||
Arc::new(Self {
|
||||
inner: Mutex::new(LimiterInner {
|
||||
alg: config.create_rate_limit_algorithm(),
|
||||
available: config.initial_limit,
|
||||
limit: config.initial_limit,
|
||||
in_flight: 0,
|
||||
}),
|
||||
ready,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
|
||||
///
|
||||
/// Returns `None` if there are none available after `duration`.
|
||||
pub async fn acquire_timeout(self: &Arc<Self>, duration: Duration) -> Result<Token, Elapsed> {
|
||||
self.acquire_deadline(Instant::now() + duration).await
|
||||
}
|
||||
|
||||
/// Try to acquire a concurrency [Token], waiting until `deadline` if there are none available.
|
||||
///
|
||||
/// Returns `None` if there are none available after `deadline`.
|
||||
pub async fn acquire_deadline(self: &Arc<Self>, deadline: Instant) -> Result<Token, Elapsed> {
|
||||
if self.config.initial_limit == 0 {
|
||||
// If the rate limiter is disabled, we can always acquire a token.
|
||||
Ok(Token::disabled())
|
||||
} else {
|
||||
let mut notified = pin!(self.ready.notified());
|
||||
let mut ready = notified.as_mut().enable();
|
||||
loop {
|
||||
let mut limit = None;
|
||||
if ready {
|
||||
let mut inner = self.inner.lock();
|
||||
if inner.take(&self.ready).is_some() {
|
||||
break Ok(Token::new(self.clone()));
|
||||
}
|
||||
limit = Some(inner.limit);
|
||||
}
|
||||
match timeout_at(deadline, notified.as_mut()).await {
|
||||
Ok(()) => ready = true,
|
||||
Err(e) => {
|
||||
let limit = limit.unwrap_or_else(|| self.inner.lock().limit);
|
||||
tracing::info!(limit, "could not acquire token in time");
|
||||
break Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the concurrency [Token], along with the outcome of the job.
|
||||
///
|
||||
/// The [Outcome] of the job, and the time taken to perform it, may be used
|
||||
/// to update the concurrency limit.
|
||||
///
|
||||
/// Set the outcome to `None` to ignore the job.
|
||||
fn release_inner(&self, start: Instant, outcome: Option<Outcome>) {
|
||||
tracing::info!("outcome is {:?}", outcome);
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut inner = self.inner.lock();
|
||||
|
||||
inner.update(start.elapsed(), outcome);
|
||||
if inner.in_flight < inner.limit {
|
||||
inner.available = inner.limit - inner.in_flight;
|
||||
// At least 1 permit is now available
|
||||
self.ready.notify_one();
|
||||
}
|
||||
|
||||
inner.in_flight -= 1;
|
||||
}
|
||||
|
||||
/// The current state of the limiter.
|
||||
pub fn state(&self) -> LimiterState {
|
||||
let inner = self.inner.lock();
|
||||
LimiterState {
|
||||
limit: inner.limit,
|
||||
in_flight: inner.in_flight,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Token {
|
||||
fn new(limiter: Arc<DynamicLimiter>) -> Self {
|
||||
Self {
|
||||
start: Instant::now(),
|
||||
limiter: Some(limiter),
|
||||
}
|
||||
}
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
start: Instant::now(),
|
||||
limiter: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_disabled(&self) -> bool {
|
||||
self.limiter.is_none()
|
||||
}
|
||||
|
||||
pub fn release(mut self, outcome: Outcome) {
|
||||
self.release_mut(Some(outcome))
|
||||
}
|
||||
|
||||
pub fn release_mut(&mut self, outcome: Option<Outcome>) {
|
||||
if let Some(limiter) = self.limiter.take() {
|
||||
limiter.release_inner(self.start, outcome);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Token {
|
||||
fn drop(&mut self) {
|
||||
self.release_mut(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl LimiterState {
|
||||
/// The current concurrency limit.
|
||||
pub fn limit(&self) -> usize {
|
||||
self.limit
|
||||
}
|
||||
/// The number of jobs in flight.
|
||||
pub fn in_flight(&self) -> usize {
|
||||
self.in_flight
|
||||
}
|
||||
}
|
||||
184
proxy/src/rate_limiter/limit_algorithm/aimd.rs
Normal file
184
proxy/src/rate_limiter/limit_algorithm/aimd.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
use std::usize;
|
||||
|
||||
use super::{LimitAlgorithm, Outcome, Sample};
|
||||
|
||||
/// Loss-based congestion avoidance.
|
||||
///
|
||||
/// Additive-increase, multiplicative decrease.
|
||||
///
|
||||
/// Adds available currency when:
|
||||
/// 1. no load-based errors are observed, and
|
||||
/// 2. the utilisation of the current limit is high.
|
||||
///
|
||||
/// Reduces available concurrency by a factor when load-based errors are detected.
|
||||
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
|
||||
pub struct Aimd {
|
||||
/// Minimum limit for AIMD algorithm.
|
||||
pub min: usize,
|
||||
/// Maximum limit for AIMD algorithm.
|
||||
pub max: usize,
|
||||
/// Decrease AIMD decrease by value in case of error.
|
||||
pub dec: f32,
|
||||
/// Increase AIMD increase by value in case of success.
|
||||
pub inc: usize,
|
||||
/// A threshold below which the limit won't be increased.
|
||||
pub utilisation: f32,
|
||||
}
|
||||
|
||||
impl LimitAlgorithm for Aimd {
|
||||
fn update(&self, old_limit: usize, sample: Sample) -> usize {
|
||||
use Outcome::*;
|
||||
match sample.outcome {
|
||||
Success => {
|
||||
let utilisation = sample.in_flight as f32 / old_limit as f32;
|
||||
|
||||
if utilisation > self.utilisation {
|
||||
let limit = old_limit + self.inc;
|
||||
let increased_limit = limit.clamp(self.min, self.max);
|
||||
if increased_limit > old_limit {
|
||||
tracing::info!(increased_limit, "limit increased");
|
||||
}
|
||||
|
||||
increased_limit
|
||||
} else {
|
||||
old_limit
|
||||
}
|
||||
}
|
||||
Overload => {
|
||||
let limit = old_limit as f32 * self.dec;
|
||||
|
||||
// Floor instead of round, so the limit reduces even with small numbers.
|
||||
// E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
|
||||
let limit = limit.floor() as usize;
|
||||
|
||||
limit.clamp(self.min, self.max)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::rate_limiter::limit_algorithm::{
|
||||
DynamicLimiter, RateLimitAlgorithm, RateLimiterConfig,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_decrease_limit_on_overload() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 10,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.8,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
token.release(Outcome::Overload);
|
||||
|
||||
assert_eq!(limiter.state().limit(), 5, "overload: decrease");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 4,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 1,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
let _token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
let _token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
token.release(Outcome::Success);
|
||||
assert_eq!(limiter.state().limit(), 5, "success: increase");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 4,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
token.release(Outcome::Success);
|
||||
assert_eq!(
|
||||
limiter.state().limit(),
|
||||
4,
|
||||
"success: ignore when < half limit"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn should_not_change_limit_when_no_outcome() {
|
||||
let config = RateLimiterConfig {
|
||||
initial_limit: 10,
|
||||
algorithm: RateLimitAlgorithm::Aimd {
|
||||
conf: Aimd {
|
||||
min: 1,
|
||||
max: 1500,
|
||||
inc: 10,
|
||||
dec: 0.5,
|
||||
utilisation: 0.5,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let limiter = DynamicLimiter::new(config);
|
||||
|
||||
let token = limiter
|
||||
.acquire_timeout(Duration::from_millis(1))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(token);
|
||||
assert_eq!(limiter.state().limit(), 10, "ignore");
|
||||
}
|
||||
}
|
||||
@@ -232,9 +232,9 @@ impl ConnectMechanism for TokioMechanism {
|
||||
.connect_timeout(timeout);
|
||||
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
drop(pause);
|
||||
drop(permit);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
|
||||
Ok(poll_client(
|
||||
|
||||
@@ -17,6 +17,7 @@ use hyper1::http::HeaderValue;
|
||||
use hyper1::Response;
|
||||
use hyper1::StatusCode;
|
||||
use hyper1::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
@@ -192,13 +193,13 @@ fn get_conn_info(
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParamsBuilder::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
match &*key {
|
||||
"options" => {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
"application_name" => ctx.set_application(Some(value.into())),
|
||||
_ => {}
|
||||
params.insert(&key, &value);
|
||||
if key == "options" {
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user