From 331935df91abe03a1e8a081bc96b6ef871f71bb1 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Tue, 13 Feb 2024 17:58:58 +0100 Subject: [PATCH] Proxy: send cancel notifications to all instances (#6719) ## Problem If cancel request ends up on the wrong proxy instance, it doesn't take an effect. ## Summary of changes Send redis notifications to all proxy pods about the cancel request. Related issue: https://github.com/neondatabase/neon/issues/5839, https://github.com/neondatabase/cloud/issues/10262 --- Cargo.lock | 7 +- Cargo.toml | 2 +- libs/pq_proto/Cargo.toml | 1 + libs/pq_proto/src/lib.rs | 3 +- proxy/src/bin/proxy.rs | 32 ++++- proxy/src/cancellation.rs | 109 ++++++++++++++--- proxy/src/config.rs | 1 + proxy/src/metrics.rs | 9 ++ proxy/src/proxy.rs | 16 +-- proxy/src/rate_limiter.rs | 2 +- proxy/src/rate_limiter/limiter.rs | 38 ++++++ proxy/src/redis.rs | 1 + proxy/src/redis/notifications.rs | 197 +++++++++++++++++++++++------- proxy/src/redis/publisher.rs | 80 ++++++++++++ proxy/src/serverless.rs | 13 +- proxy/src/serverless/websocket.rs | 6 +- workspace_hack/Cargo.toml | 4 +- 17 files changed, 432 insertions(+), 89 deletions(-) create mode 100644 proxy/src/redis/publisher.rs diff --git a/Cargo.lock b/Cargo.lock index f11c774016..45a313a72b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2263,11 +2263,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.2" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] @@ -3952,6 +3952,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "rand 0.8.5", + "serde", "thiserror", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 8df9ca9988..8952f7627f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ futures-core = "0.3" futures-util = "0.3" git-version = "0.3" hashbrown = "0.13" -hashlink = "0.8.1" +hashlink = "0.8.4" hdrhistogram = "7.5.2" hex = "0.4" hex-literal = "0.4" diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index b286eb0358..6eeb3bafef 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -13,5 +13,6 @@ rand.workspace = true tokio.workspace = true tracing.workspace = true thiserror.workspace = true +serde.workspace = true workspace_hack.workspace = true diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index c52a21bcd3..522b65f5d1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -7,6 +7,7 @@ pub mod framed; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; use std::{borrow::Cow, collections::HashMap, fmt, io, str}; // re-export for use in utils pageserver_feedback.rs @@ -123,7 +124,7 @@ impl StartupMessageParams { } } -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub struct CancelKeyData { pub backend_pid: i32, pub cancel_key: i32, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 00a229c135..b3d4fc0411 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,6 +1,8 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; +use proxy::cancellation::CancelMap; +use proxy::cancellation::CancellationHandler; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; use proxy::redis::notifications; +use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -22,6 +25,7 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::info; @@ -129,6 +133,9 @@ struct ProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] endpoint_rps_limit: Vec, + /// Redis rate limiter max number of requests per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + redis_rps_limit: Vec, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] initial_limit: usize, @@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); + let cancel_map = CancelMap::default(); + let redis_publisher = match &args.redis_notifications { + Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + url, + args.region.clone(), + &config.redis_rps_limit, + )?))), + None => None, + }; + let cancellation_handler = Arc::new(CancellationHandler::new( + cancel_map.clone(), + redis_publisher, + )); // client facing tasks. these will exit on error or on cancellation // cancellation returns Ok(()) @@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); // TODO: rename the argument to something like serverless. @@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> { serverless_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); } @@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(url) = args.redis_notifications { info!("Starting redis notifications listener ({url})"); - maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + maintenance_tasks.spawn(notifications::task_main( + url.to_owned(), + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); } maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } @@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let mut redis_rps_limit = args.redis_rps_limit.clone(); + RateBucketInfo::validate(&mut redis_rps_limit)?; let config = Box::leak(Box::new(ProxyConfig { tls_config, @@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, endpoint_rps_limit, + redis_rps_limit, handshake_timeout: args.handshake_timeout, // TODO: add this argument region: args.region.clone(), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index fe614628d8..93a77bc4ae 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,16 +1,28 @@ +use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; use tokio::net::TcpStream; +use tokio::sync::Mutex; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use uuid::Uuid; -use crate::error::ReportableError; +use crate::{ + error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, + redis::publisher::RedisPublisherClient, +}; + +pub type CancelMap = Arc>>; /// Enables serving `CancelRequest`s. -#[derive(Default)] -pub struct CancelMap(DashMap>); +/// +/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler { + map: CancelMap, + redis_client: Option>>, +} #[derive(Debug, Error)] pub enum CancelError { @@ -32,15 +44,43 @@ impl ReportableError for CancelError { } } -impl CancelMap { +impl CancellationHandler { + pub fn new(map: CancelMap, redis_client: Option>>) -> Self { + Self { map, redis_client } + } /// Cancel a running query for the corresponding connection. - pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> { + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + let from = "from_client"; // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else { + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); + if let Some(redis_client) = &self.redis_client { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + info!("publishing cancellation key to Redis"); + match redis_client.lock().await.try_publish(key, session_id).await { + Ok(()) => { + info!("cancellation key successfuly published to Redis"); + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + } return Ok(()); }; - + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await } @@ -57,7 +97,7 @@ impl CancelMap { // Random key collisions are unlikely to happen here, but they're still possible, // which is why we have to take care not to rewrite an existing key. - match self.0.entry(key) { + match self.map.entry(key) { dashmap::mapref::entry::Entry::Occupied(_) => continue, dashmap::mapref::entry::Entry::Vacant(e) => { e.insert(None); @@ -69,18 +109,46 @@ impl CancelMap { info!("registered new query cancellation key {key}"); Session { key, - cancel_map: self, + cancellation_handler: self, } } #[cfg(test)] fn contains(&self, session: &Session) -> bool { - self.0.contains_key(&session.key) + self.map.contains_key(&session.key) } #[cfg(test)] fn is_empty(&self) -> bool { - self.0.is_empty() + self.map.is_empty() + } +} + +#[async_trait] +pub trait NotificationsCancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +} + +#[async_trait] +impl NotificationsCancellationHandler for CancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { + let from = "from_redis"; + let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); + match cancel_closure { + Some(cancel_closure) => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + cancel_closure.try_cancel_query().await + } + None => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + tracing::warn!("query cancellation key not found: {key}"); + Ok(()) + } + } } } @@ -115,7 +183,7 @@ pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancel_map: Arc, + cancellation_handler: Arc, } impl Session { @@ -123,7 +191,9 @@ impl Session { /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); - self.cancel_map.0.insert(self.key, Some(cancel_closure)); + self.cancellation_handler + .map + .insert(self.key, Some(cancel_closure)); self.key } @@ -131,7 +201,7 @@ impl Session { impl Drop for Session { fn drop(&mut self) { - self.cancel_map.0.remove(&self.key); + self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); } } @@ -142,13 +212,16 @@ mod tests { #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancel_map: Arc = Default::default(); + let cancellation_handler = Arc::new(CancellationHandler { + map: CancelMap::default(), + redis_client: None, + }); - let session = cancel_map.clone().get_session(); - assert!(cancel_map.contains(&session)); + let session = cancellation_handler.clone().get_session(); + assert!(cancellation_handler.contains(&session)); drop(session); // Check that the session has been dropped. - assert!(cancel_map.is_empty()); + assert!(cancellation_handler.is_empty()); Ok(()) } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 5fcb537834..9f276c3c24 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -21,6 +21,7 @@ pub struct ProxyConfig { pub require_client_ip: bool, pub disable_ip_check_for_http: bool, pub endpoint_rps_limit: Vec, + pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index f7f162a075..66031f5eb2 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { .unwrap() }); +pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_cancellation_requests_total", + "Number of cancellation requests (per found/not_found).", + &["source", "kind"], + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 5f65de4c98..ce77098a5f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancelMap}, + cancellation::{self, CancellationHandler}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,6 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -72,7 +73,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancel_map = Arc::new(CancelMap::default()); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -80,7 +80,7 @@ pub async fn task_main( let (socket, peer_addr) = accept_result?; let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); + let cancellation_handler = Arc::clone(&cancellation_handler); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let session_span = info_span!( @@ -113,7 +113,7 @@ pub async fn task_main( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, socket, ClientMode::Tcp, endpoint_rate_limiter, @@ -227,7 +227,7 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancel_map: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, @@ -253,8 +253,8 @@ pub async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancel_map - .cancel_session(cancel_key_data) + return Ok(cancellation_handler + .cancel_session(cancel_key_data, ctx.session_id) .await .map(|()| None)?) } @@ -315,7 +315,7 @@ pub async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; - let session = cancel_map.get_session(); + let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index b26386d159..f0da4ead23 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{EndpointRateLimiter, RateBucketInfo}; +pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index cbae72711c..3181060e2f 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -22,6 +22,44 @@ use super::{ RateLimiterConfig, }; +pub struct RedisRateLimiter { + data: Vec, + info: &'static [RateBucketInfo], +} + +impl RedisRateLimiter { + pub fn new(info: &'static [RateBucketInfo]) -> Self { + Self { + data: vec![ + RateBucket { + start: Instant::now(), + count: 0, + }; + info.len() + ], + info, + } + } + + /// Check that number of connections is below `max_rps` rps. + pub fn check(&mut self) -> bool { + let now = Instant::now(); + + let should_allow_request = self + .data + .iter_mut() + .zip(self.info) + .all(|(bucket, info)| bucket.should_allow_request(info, now)); + + if should_allow_request { + // only increment the bucket counts if the request will actually be accepted + self.data.iter_mut().for_each(RateBucket::inc); + } + + should_allow_request + } +} + // Simple per-endpoint rate limiter. // // Check that number of connections to the endpoint is below `max_rps` rps. diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index c2a91bed97..35d6db074e 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1 +1,2 @@ pub mod notifications; +pub mod publisher; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 158884aa17..b8297a206c 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -1,38 +1,44 @@ use std::{convert::Infallible, sync::Arc}; use futures::StreamExt; +use pq_proto::CancelKeyData; use redis::aio::PubSub; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::{ cache::project_info::ProjectInfoCache, + cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, }; -const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct ConsoleRedisClient { +struct RedisConsumerClient { client: redis::Client, } -impl ConsoleRedisClient { +impl RedisConsumerClient { pub fn new(url: &str) -> anyhow::Result { let client = redis::Client::open(url)?; Ok(Self { client }) } async fn try_connect(&self) -> anyhow::Result { let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CHANNEL_NAME}`"); - conn.subscribe(CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; Ok(conn) } } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -enum Notification { +pub(crate) enum Notification { #[serde( rename = "/allowed_ips_updated", deserialize_with = "deserialize_json_string" @@ -45,16 +51,25 @@ enum Notification { deserialize_with = "deserialize_json_string" )] PasswordUpdate { password_update: PasswordUpdate }, + #[serde(rename = "/cancel_session")] + Cancel(CancelSession), } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct AllowedIpsUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedIpsUpdate { project_id: ProjectIdInt, } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct PasswordUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct PasswordUpdate { project_id: ProjectIdInt, role_name: RoleNameInt, } +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct CancelSession { + pub region_id: Option, + pub cancel_key_data: CancelKeyData, + pub session_id: Uuid, +} + fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -64,6 +79,88 @@ where serde_json::from_str(&s).map_err(::custom) } +struct MessageHandler< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, +> { + cache: Arc, + cancellation_handler: Arc, + region_id: String, +} + +impl< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, + > MessageHandler +{ + pub fn new(cache: Arc, cancellation_handler: Arc, region_id: String) -> Self { + Self { + cache, + cancellation_handler, + region_id, + } + } + pub fn disable_ttl(&self) { + self.cache.disable_ttl(); + } + pub fn enable_ttl(&self) { + self.cache.enable_ttl(); + } + #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] + async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> { + use Notification::*; + let payload: String = msg.get_payload()?; + tracing::debug!(?payload, "received a message payload"); + + let msg: Notification = match serde_json::from_str(&payload) { + Ok(msg) => msg, + Err(e) => { + tracing::error!("broken message: {e}"); + return Ok(()); + } + }; + tracing::debug!(?msg, "received a message"); + match msg { + Cancel(cancel_session) => { + tracing::Span::current().record( + "session_id", + &tracing::field::display(cancel_session.session_id), + ); + if let Some(cancel_region) = cancel_session.region_id { + // If the message is not for this region, ignore it. + if cancel_region != self.region_id { + return Ok(()); + } + } + // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. + match self + .cancellation_handler + .cancel_session_no_publish(cancel_session.cancel_key_data) + .await + { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to cancel session: {e}"); + } + } + } + _ => { + invalidate_cache(self.cache.clone(), msg.clone()); + // It might happen that the invalid entry is on the way to be cached. + // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. + // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. + let cache = self.cache.clone(); + tokio::spawn(async move { + tokio::time::sleep(INVALIDATION_LAG).await; + invalidate_cache(cache, msg); + }); + } + } + + Ok(()) + } +} + fn invalidate_cache(cache: Arc, msg: Notification) { use Notification::*; match msg { @@ -74,50 +171,33 @@ fn invalidate_cache(cache: Arc, msg: Notification) { password_update.project_id, password_update.role_name, ), + Cancel(_) => unreachable!("cancel message should be handled separately"), } } -#[tracing::instrument(skip(cache))] -fn handle_message(msg: redis::Msg, cache: Arc) -> anyhow::Result<()> -where - C: ProjectInfoCache + Send + Sync + 'static, -{ - let payload: String = msg.get_payload()?; - tracing::debug!(?payload, "received a message payload"); - - let msg: Notification = match serde_json::from_str(&payload) { - Ok(msg) => msg, - Err(e) => { - tracing::error!("broken message: {e}"); - return Ok(()); - } - }; - tracing::debug!(?msg, "received a message"); - invalidate_cache(cache.clone(), msg.clone()); - // It might happen that the invalid entry is on the way to be cached. - // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. - // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. - tokio::spawn(async move { - tokio::time::sleep(INVALIDATION_LAG).await; - invalidate_cache(cache, msg.clone()); - }); - - Ok(()) -} - /// Handle console's invalidation messages. #[tracing::instrument(name = "console_notifications", skip_all)] -pub async fn task_main(url: String, cache: Arc) -> anyhow::Result +pub async fn task_main( + url: String, + cache: Arc, + cancel_map: CancelMap, + region_id: String, +) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { cache.enable_ttl(); + let handler = MessageHandler::new( + cache, + Arc::new(CancellationHandler::new(cancel_map, None)), + region_id, + ); loop { - let redis = ConsoleRedisClient::new(&url)?; + let redis = RedisConsumerClient::new(&url)?; let conn = match redis.try_connect().await { Ok(conn) => { - cache.disable_ttl(); + handler.disable_ttl(); conn } Err(e) => { @@ -130,7 +210,7 @@ where }; let mut stream = conn.into_on_message(); while let Some(msg) = stream.next().await { - match handle_message(msg, cache.clone()) { + match handler.handle_message(msg).await { Ok(()) => {} Err(e) => { tracing::error!("failed to handle message: {e}, will try to reconnect"); @@ -138,7 +218,7 @@ where } } } - cache.enable_ttl(); + handler.enable_ttl(); } } @@ -198,6 +278,33 @@ mod tests { } ); + Ok(()) + } + #[test] + fn parse_cancel_session() -> anyhow::Result<()> { + let cancel_key_data = CancelKeyData { + backend_pid: 42, + cancel_key: 41, + }; + let uuid = uuid::Uuid::new_v4(); + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: None, + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result); + + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: Some("region".to_string()), + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result,); + Ok(()) } } diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs new file mode 100644 index 0000000000..f85593afdd --- /dev/null +++ b/proxy/src/redis/publisher.rs @@ -0,0 +1,80 @@ +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; + +pub struct RedisPublisherClient { + client: redis::Client, + publisher: Option, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + url: &str, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { + client, + publisher: None, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + pub async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + self.publisher = None; + } + } + tracing::info!("Publisher is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let conn = self + .publisher + .as_mut() + .ok_or_else(|| anyhow::anyhow!("not connected"))?; + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + conn.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.get_async_connection().await { + Ok(conn) => { + self.publisher = Some(conn); + } + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e.into()); + } + } + Ok(()) + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a20600b94a..ee3e91495b 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancelMap, config::ProxyConfig}; +use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use futures::StreamExt; use hyper::{ server::{ @@ -50,6 +50,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -115,7 +116,7 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - + let cancellation_handler = cancellation_handler.clone(); async move { let peer_addr = match client_addr { Some(addr) => addr, @@ -127,9 +128,9 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); async move { - let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); request_handler( @@ -137,7 +138,7 @@ pub async fn task_main( config, backend, ws_connections, - cancel_map, + cancellation_handler, session_id, peer_addr.ip(), endpoint_rate_limiter, @@ -205,7 +206,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancel_map: Arc, + cancellation_handler: Arc, session_id: uuid::Uuid, peer_addr: IpAddr, endpoint_rate_limiter: Arc, @@ -232,7 +233,7 @@ async fn request_handler( config, ctx, websocket, - cancel_map, + cancellation_handler, host, endpoint_rate_limiter, ) diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 062dd440b2..24f2bb7e8c 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancelMap, + cancellation::CancellationHandler, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -133,7 +133,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancel_map: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { @@ -141,7 +141,7 @@ pub async fn serve_websocket( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 8e9cc43152..e808fabbe7 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -38,7 +38,7 @@ futures-io = { version = "0.3" } futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } @@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits", "use_std"] }