[proxy] review and cleanup CLI args (#12167)

I was looking at how we could expose our proxy config as toml again, and
as I was writing out the schema format, I noticed some cruft in our CLI
args that no longer seem to be in use.

The redis change is the most complex, but I am pretty sure it's sound.
Since https://github.com/neondatabase/cloud/pull/15613 cplane longer
publishes to the global redis instance.
This commit is contained in:
Conrad Ludgate
2025-06-26 12:25:41 +01:00
committed by GitHub
parent be23eae3b6
commit fd1e8ec257
12 changed files with 70 additions and 159 deletions

View File

@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
}, },
proxy_protocol_v2: config::ProxyProtocolV2::Rejected, proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10), handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks, connect_compute_locks,
connect_to_compute: compute_config, connect_to_compute: compute_config,

View File

@@ -236,7 +236,6 @@ pub(super) async fn task_main(
extra: None, extra: None,
}, },
crate::metrics::Protocol::SniRouter, crate::metrics::Protocol::SniRouter,
"sni",
); );
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
} }

View File

@@ -123,12 +123,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake /// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration, handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable) /// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String, wake_compute_cache: String,
@@ -155,40 +149,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second. /// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>, wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client) /// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)] #[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize, cancellation_ch_size: usize,
/// Cancellation ops batch size for redis /// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)] #[clap(long, default_value_t = 8)]
cancellation_batch_size: usize, cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable) /// redis url for plain authentication
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] #[clap(long, alias("redis-notifications"))]
allowed_ips_cache: String, redis_plain: Option<String>,
/// cache for `role_secret` (use `size=0` to disable) /// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")] #[clap(long, default_value = "irsa")]
redis_auth_type: String, redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host) /// redis host for irsa authentication
#[clap(long)] #[clap(long)]
redis_host: Option<String>, redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host) /// redis port for irsa authentication
#[clap(long)] #[clap(long)]
redis_port: Option<u16>, redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache /// redis cluster name for irsa authentication
#[clap(long)] #[clap(long)]
redis_cluster_name: Option<String>, redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache /// redis user_id for irsa authentication
#[clap(long)] #[clap(long)]
redis_user_id: Option<String>, redis_user_id: Option<String>,
/// aws region to retrieve credentials /// aws region for irsa authentication
#[clap(long, default_value_t = String::new())] #[clap(long, default_value_t = String::new())]
aws_region: String, aws_region: String,
/// cache for `project_info` (use `size=0` to disable) /// cache for `project_info` (use `size=0` to disable)
@@ -200,6 +185,12 @@ struct ProxyCliArgs {
#[clap(flatten)] #[clap(flatten)]
parquet_upload: ParquetUploadArgs, parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection /// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)] #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration, metric_backup_collection_interval: std::time::Duration,
@@ -212,6 +203,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression. /// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")] #[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize, metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node /// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String, connect_to_compute_retry: String,
@@ -331,7 +323,7 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"), Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
} }
info!("Using region: {}", args.aws_region); info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?; let redis_client = configure_redis(&args).await?;
// Check that we can bind to address before further initialization // Check that we can bind to address before further initialization
info!("Starting http on {}", args.http); info!("Starting http on {}", args.http);
@@ -386,13 +378,6 @@ pub async fn run() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute)); let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -472,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> {
client_tasks.spawn(crate::context::parquet::worker( client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(), cancellation_token.clone(),
args.parquet_upload, args.parquet_upload,
args.region,
)); ));
// maintenance tasks. these never return unless there's an error // maintenance tasks. these never return unless there's an error
@@ -495,32 +481,17 @@ pub async fn run() -> anyhow::Result<()> {
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))] #[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend { if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api { if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) { if let Some(client) = redis_client {
(None, None) => {} // project info cache and invalidation of that cache.
(client1, client2) => {
let cache = api.caches.project_info.clone(); let cache = api.caches.project_info.clone();
if let Some(client) = client1 { maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval. // Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart, // This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues. // which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client { // cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() { for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await { match redis_kv_client.try_connect().await {
Ok(()) => { Ok(()) => {
@@ -545,14 +516,12 @@ pub async fn run() -> anyhow::Result<()> {
} }
} }
} }
}
if let Some(regional_redis_client) = regional_redis_client { // listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone(); let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache"); let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn( maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await } async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span), .instrument(span),
); );
} }
@@ -681,7 +650,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config, authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2, proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout, handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks, connect_compute_locks,
connect_to_compute: compute_config, connect_to_compute: compute_config,
@@ -843,21 +811,18 @@ fn build_auth_backend(
async fn configure_redis( async fn configure_redis(
args: &ProxyCliArgs, args: &ProxyCliArgs,
) -> anyhow::Result<( ) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args // TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) { let redis_client = match &*args.redis_auth_type {
("plain", redis_url) => match redis_url { "plain" => match &args.redis_plain {
None => { None => {
bail!("plain auth requires redis_notifications to be set"); bail!("plain auth requires redis_plain to be set");
} }
Some(url) => { Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())) Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
} }
}, },
("irsa", _) => match (&args.redis_host, args.redis_port) { "irsa" => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some( (Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider( ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(), host.clone(),
@@ -881,18 +846,12 @@ async fn configure_redis(
bail!("redis-host and redis-port must be specified together"); bail!("redis-host and redis-port must be specified together");
} }
}, },
_ => { auth_type => {
bail!("unknown auth type given"); bail!("unknown auth type {auth_type:?} given")
} }
}; };
let redis_notifications_client = if let Some(url) = &args.redis_notifications { Ok(redis_client)
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -22,7 +22,6 @@ pub struct ProxyConfig {
pub http_config: HttpConfig, pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig, pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2, pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration, pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig, pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>, pub connect_compute_locks: ApiLocks<Host>,

View File

@@ -89,12 +89,7 @@ pub async fn task_main(
} }
} }
let ctx = RequestContext::new( let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client( let res = handle_client(
config, config,

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid, pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol, pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>, first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span, pub(crate) span: Span,
// filled in as they are discovered // filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id, session_id: inner.session_id,
protocol: inner.protocol, protocol: inner.protocol,
first_packet: inner.first_packet, first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"), span: info_span!("background_task"),
project: inner.project, project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
} }
impl RequestContext { impl RequestContext {
pub fn new( pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
// TODO: be careful with long lived spans // TODO: be careful with long lived spans
let span = info_span!( let span = info_span!(
"connect_request", "connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id, session_id,
protocol, protocol,
first_packet: Utc::now(), first_packet: Utc::now(),
region,
span, span,
project: None, project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]); let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432); let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None }; let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test") RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
} }
pub(crate) fn console_application_name(&self) -> String { pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)] #[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData { pub(crate) struct RequestData {
region: &'static str, region: String,
protocol: &'static str, protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones /// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime, timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}), }),
jwt_issuer: value.jwt_issuer.clone(), jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(), protocol: value.protocol.as_str(),
region: value.region, region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()), error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success, success: value.success,
cold_start_info: value.cold_start_info.as_str(), cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker( pub async fn worker(
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
config: ParquetUploadArgs, config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else { let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured"); tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?; .context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone(); let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!( tokio::try_join!(
worker_inner(storage, rx, parquet_config), worker_inner(storage, rx, parquet_config, &region),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect) worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
) )
.map(|_| ()) .map(|_| ())
} else { } else {
worker_inner(storage, rx, parquet_config).await worker_inner(storage, rx, parquet_config, &region).await
} }
} }
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage, storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>, rx: impl Stream<Item = RequestData>,
config: ParquetConfig, config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))] #[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 { let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now(); let mut last_upload = time::Instant::now();
let mut len = 0; let mut len = 0;
while let Some(row) = rx.next().await { while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row); rows.push(row);
let force = last_upload.elapsed() > config.max_duration; let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force { if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None, auth_method: None,
jwt_issuer: None, jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1", region: String::new(),
error: None, error: None,
success: rng.r#gen(), success: rng.r#gen(),
cold_start_info: "no", cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await .await
.unwrap(); .unwrap();
worker_inner(storage, rx, config).await.unwrap(); worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path()) let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter() .into_iter()

View File

@@ -122,12 +122,7 @@ pub async fn task_main(
} }
} }
let ctx = RequestContext::new( let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client( let res = handle_client(
config, config,

View File

@@ -139,12 +139,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)), Self::new(200, Duration::from_secs(600)),
]; ];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 { pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64() (self.max_rpi as f64) / self.interval.as_secs_f64()
} }

View File

@@ -5,11 +5,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult}; use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient { pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider, client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
} }
#[allow(async_fn_in_trait)] #[allow(async_fn_in_trait)]
@@ -30,11 +28,8 @@ impl Queryable for Cmd {
} }
impl RedisKVClient { impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self { pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { Self { client }
client,
limiter: GlobalRateLimiter::new(info.into()),
}
} }
pub async fn try_connect(&mut self) -> anyhow::Result<()> { pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self, &mut self,
q: &impl Queryable, q: &impl Queryable,
) -> anyhow::Result<T> { ) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
let e = match q.query(&mut self.client).await { let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t), Ok(t) => return Ok(t),
Err(e) => e, Err(e) => e,

View File

@@ -141,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> { struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>, cache: Arc<C>,
region_id: String,
} }
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> { impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
cache: self.cache.clone(), cache: self.cache.clone(),
region_id: self.region_id.clone(),
} }
} }
} }
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> { impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self { pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache, region_id } Self { cache }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
} }
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
} }
let mut conn = match try_connect(&redis).await { let mut conn = match try_connect(&redis).await {
Ok(conn) => { Ok(conn) => {
handler.increment_active_listeners().await; handler.cache.increment_active_listeners().await;
conn conn
} }
Err(e) => { Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
} }
} }
if cancellation_token.is_cancelled() { if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await; handler.cache.decrement_active_listeners().await;
return Ok(()); return Ok(());
} }
} }
handler.decrement_active_listeners().await; handler.cache.decrement_active_listeners().await;
} }
} }
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>( pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider, redis: ConnectionWithCredentialsProvider,
cache: Arc<C>, cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible> ) -> anyhow::Result<Infallible>
where where
C: ProjectInfoCache + Send + Sync + 'static, C: ProjectInfoCache + Send + Sync + 'static,
{ {
let handler = MessageHandler::new(cache, region_id); let handler = MessageHandler::new(cache);
// 6h - 1m. // 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60)); let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -417,12 +417,7 @@ async fn request_handler(
if config.http_config.accept_websockets if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request) && framed_websockets::upgrade::is_upgrade_request(&request)
{ {
let ctx = RequestContext::new( let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
ctx.set_user_agent( ctx.set_user_agent(
request request
@@ -462,12 +457,7 @@ async fn request_handler(
// Return the response so the spawned future can continue. // Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed())) Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST { } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new( let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let span = ctx.span(); let span = ctx.span();
let testodrome_id = request let testodrome_id = request