From 222cc181e9314fe6c7596f06d677a92875f4aa8b Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Wed, 29 Jan 2025 13:19:10 +0200 Subject: [PATCH] impr(proxy): Move the CancelMap to Redis hashes (#10364) ## Problem The approach of having CancelMap as an in-memory structure increases code complexity, as well as putting additional load for Redis streams. ## Summary of changes - Implement a set of KV ops for Redis client; - Remove cancel notifications code; - Send KV ops over the bounded channel to the handling background task for removing and adding the cancel keys. Closes #9660 --- Cargo.lock | 1 + libs/pq_proto/src/lib.rs | 7 + libs/proxy/tokio-postgres2/Cargo.toml | 1 + .../proxy/tokio-postgres2/src/cancel_token.rs | 3 +- libs/proxy/tokio-postgres2/src/client.rs | 3 +- libs/proxy/tokio-postgres2/src/config.rs | 5 +- proxy/src/auth/backend/mod.rs | 3 +- proxy/src/bin/local_proxy.rs | 10 +- proxy/src/bin/proxy.rs | 46 +- proxy/src/cancellation.rs | 563 +++++++++--------- proxy/src/compute.rs | 1 - proxy/src/console_redirect_proxy.rs | 29 +- proxy/src/metrics.rs | 30 +- proxy/src/proxy/mod.rs | 39 +- proxy/src/proxy/passthrough.rs | 9 +- proxy/src/rate_limiter/limiter.rs | 6 + proxy/src/redis/cancellation_publisher.rs | 72 +-- .../connection_with_credentials_provider.rs | 1 + proxy/src/redis/keys.rs | 88 +++ proxy/src/redis/kv_ops.rs | 185 ++++++ proxy/src/redis/mod.rs | 2 + proxy/src/redis/notifications.rs | 105 +--- proxy/src/serverless/mod.rs | 8 +- proxy/src/serverless/websocket.rs | 4 +- 24 files changed, 674 insertions(+), 547 deletions(-) create mode 100644 proxy/src/redis/keys.rs create mode 100644 proxy/src/redis/kv_ops.rs diff --git a/Cargo.lock b/Cargo.lock index 3c33901247..c19fdc0941 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6935,6 +6935,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol2", "postgres-types2", + "serde", "tokio", "tokio-util", ] diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 50b2c69d24..f99128b76a 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -182,6 +182,13 @@ pub struct CancelKeyData { pub cancel_key: i32, } +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData { + backend_pid: (id >> 32) as i32, + cancel_key: (id & 0xffffffff) as i32, + } +} + impl fmt::Display for CancelKeyData { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let hi = (self.backend_pid as u64) << 32; diff --git a/libs/proxy/tokio-postgres2/Cargo.toml b/libs/proxy/tokio-postgres2/Cargo.toml index 56e7c4da47..ade0ffc9f6 100644 --- a/libs/proxy/tokio-postgres2/Cargo.toml +++ b/libs/proxy/tokio-postgres2/Cargo.toml @@ -19,3 +19,4 @@ postgres-protocol2 = { path = "../postgres-protocol2" } postgres-types2 = { path = "../postgres-types2" } tokio = { workspace = true, features = ["io-util", "time", "net"] } tokio-util = { workspace = true, features = ["codec"] } +serde = { workspace = true, features = ["derive"] } \ No newline at end of file diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs index a10e8bf5c3..718f903a92 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_token.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -3,12 +3,13 @@ use crate::tls::TlsConnect; use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect}; use crate::{cancel_query_raw, Error}; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; /// The capability to request cancellation of in-progress queries on a /// 
connection. -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct CancelToken { pub socket_config: Option, pub ssl_mode: SslMode, diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index a7cd53afc3..9bbbd4c260 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -18,6 +18,7 @@ use fallible_iterator::FallibleIterator; use futures_util::{future, ready, TryStreamExt}; use parking_lot::Mutex; use postgres_protocol2::message::{backend::Message, frontend}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; use std::sync::Arc; @@ -137,7 +138,7 @@ impl InnerClient { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct SocketConfig { pub host: Host, pub port: u16, diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 11a361a81b..47cc45ac80 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -7,6 +7,7 @@ use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::{Client, Connection, Error}; use postgres_protocol2::message::frontend::StartupMessageParams; +use serde::{Deserialize, Serialize}; use std::fmt; use std::str; use std::time::Duration; @@ -16,7 +17,7 @@ pub use postgres_protocol2::authentication::sasl::ScramKeys; use tokio::net::TcpStream; /// TLS configuration. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] #[non_exhaustive] pub enum SslMode { /// Do not use TLS. @@ -50,7 +51,7 @@ pub enum ReplicationMode { } /// A host specification. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Host { /// A TCP hostname. 
Tcp(String), diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index de48be2952..d17d91a56d 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -12,6 +12,7 @@ pub(crate) use console_redirect::ConsoleRedirectError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -133,7 +134,7 @@ pub(crate) struct ComputeUserInfoNoEndpoint { pub(crate) options: NeonOptions, } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub(crate) struct ComputeUserInfo { pub(crate) endpoint: EndpointId, pub(crate) user: RoleName, diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 644f670f88..ee8b3d4ef5 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -7,12 +7,11 @@ use std::time::Duration; use anyhow::{bail, ensure, Context}; use camino::{Utf8Path, Utf8PathBuf}; use compute_api::spec::LocalProxySpec; -use dashmap::DashMap; use futures::future::Either; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP}; use proxy::auth::{self}; -use proxy::cancellation::CancellationHandlerMain; +use proxy::cancellation::CancellationHandler; use proxy::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, }; @@ -211,12 +210,7 @@ async fn main() -> anyhow::Result<()> { auth_backend, http_listener, shutdown.clone(), - Arc::new(CancellationHandlerMain::new( - &config.connect_to_compute, - Arc::new(DashMap::new()), - None, - proxy::metrics::CancellationSource::Local, - )), + Arc::new(CancellationHandler::new(&config.connect_to_compute, None)), endpoint_rate_limiter, ); diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 70b50436bf..e1affe8391 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -7,7 +7,7 @@ use anyhow::bail; use futures::future::Either; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; -use proxy::cancellation::{CancelMap, CancellationHandler}; +use proxy::cancellation::{handle_cancel_messages, CancellationHandler}; use proxy::config::{ self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, @@ -18,8 +18,8 @@ use proxy::metrics::Metrics; use proxy::rate_limiter::{ EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, }; -use proxy::redis::cancellation_publisher::RedisPublisherClient; use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; +use proxy::redis::kv_ops::RedisKVClient; use proxy::redis::{elasticache, notifications}; use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; @@ -28,7 +28,6 @@ use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::{auth, control_plane, http, serverless, usage_metrics}; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; -use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::{info, warn, Instrument}; @@ -158,8 +157,11 @@ struct ProxyCliArgs { #[clap(long, default_value_t = 64)] auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. 
- #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, + /// Cancellation channel size (max queue size for redis kv client) + #[clap(long, default_value = "1024")] + cancellation_ch_size: usize, /// cache for `allowed_ips` (use `size=0` to disable) #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] allowed_ips_cache: String, @@ -382,27 +384,19 @@ async fn main() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); - let cancel_map = CancelMap::default(); - let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone()); RateBucketInfo::validate(redis_rps_limit)?; - let redis_publisher = match ®ional_redis_client { - Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( - redis_publisher.clone(), - args.region.clone(), - redis_rps_limit, - )?))), - None => None, - }; + let redis_kv_client = regional_redis_client + .as_ref() + .map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit)); - let cancellation_handler = Arc::new(CancellationHandler::< - Option>>, - >::new( + // channel size should be higher than redis client limit to avoid blocking + let cancel_ch_size = args.cancellation_ch_size; + let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size); + let cancellation_handler = Arc::new(CancellationHandler::new( &config.connect_to_compute, - cancel_map.clone(), - redis_publisher, - proxy::metrics::CancellationSource::FromClient, + Some(tx_cancel), )); // bit of a hack - find the min rps and max rps supported and turn it into @@ -495,25 +489,29 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(client) = client1 { maintenance_tasks.spawn(notifications::task_main( - config, client, cache.clone(), - cancel_map.clone(), args.region.clone(), )); } if let Some(client) = client2 { maintenance_tasks.spawn(notifications::task_main( - config, client, cache.clone(), - cancel_map.clone(), args.region.clone(), )); } maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } + + if let Some(mut redis_kv_client) = redis_kv_client { + maintenance_tasks.spawn(async move { + redis_kv_client.try_connect().await?; + handle_cancel_messages(&mut redis_kv_client, rx_cancel).await + }); + } + if let Some(regional_redis_client) = regional_redis_client { let cache = api.caches.endpoints_cache.clone(); let con = regional_redis_client; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a96c43f2ce..34f708a36b 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,48 +1,124 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; -use dashmap::DashMap; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::tls::MakeTlsConnect; use postgres_client::CancelToken; use pq_proto::CancelKeyData; +use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::TcpStream; -use tokio::sync::Mutex; +use tokio::sync::mpsc; use tracing::{debug, info}; -use uuid::Uuid; use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo}; -use crate::auth::{check_peer_addr_is_in_list, AuthError, IpPattern}; +use crate::auth::{check_peer_addr_is_in_list, AuthError}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::ext::LockExt; -use crate::metrics::{CancellationRequest, CancellationSource, Metrics}; +use crate::metrics::CancelChannelSizeGuard; 
+use crate::metrics::{CancellationRequest, Metrics, RedisMsgKind}; use crate::rate_limiter::LeakyBucketRateLimiter; -use crate::redis::cancellation_publisher::{ - CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, -}; +use crate::redis::keys::KeyPrefix; +use crate::redis::kv_ops::RedisKVClient; use crate::tls::postgres_rustls::MakeRustlsConnect; - -pub type CancelMap = Arc>>; -pub type CancellationHandlerMain = CancellationHandler>>>; -pub(crate) type CancellationHandlerMainInternal = Option>>; +use std::convert::Infallible; +use tokio::sync::oneshot; type IpSubnetKey = IpNet; +const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time +const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10); + +// Message types for sending through mpsc channel +pub enum CancelKeyOp { + StoreCancelKey { + key: String, + field: String, + value: String, + resp_tx: Option>>, + _guard: CancelChannelSizeGuard<'static>, + expire: i64, // TTL for key + }, + GetCancelData { + key: String, + resp_tx: oneshot::Sender>>, + _guard: CancelChannelSizeGuard<'static>, + }, + RemoveCancelKey { + key: String, + field: String, + resp_tx: Option>>, + _guard: CancelChannelSizeGuard<'static>, + }, +} + +// Running as a separate task to accept messages through the rx channel +// In case of problems with RTT: switch to recv_many() + redis pipeline +pub async fn handle_cancel_messages( + client: &mut RedisKVClient, + mut rx: mpsc::Receiver, +) -> anyhow::Result { + loop { + if let Some(msg) = rx.recv().await { + match msg { + CancelKeyOp::StoreCancelKey { + key, + field, + value, + resp_tx, + _guard, + expire: _, + } => { + if let Some(resp_tx) = resp_tx { + resp_tx + .send(client.hset(key, field, value).await) + .inspect_err(|e| { + tracing::debug!("failed to send StoreCancelKey response: {:?}", e); + }) + .ok(); + } else { + drop(client.hset(key, field, value).await); + } + } + CancelKeyOp::GetCancelData { + key, + resp_tx, + _guard, + } => { + drop(resp_tx.send(client.hget_all(key).await)); + } + CancelKeyOp::RemoveCancelKey { + key, + field, + resp_tx, + _guard, + } => { + if let Some(resp_tx) = resp_tx { + resp_tx + .send(client.hdel(key, field).await) + .inspect_err(|e| { + tracing::debug!("failed to send StoreCancelKey response: {:?}", e); + }) + .ok(); + } else { + drop(client.hdel(key, field).await); + } + } + } + } + } +} + /// Enables serving `CancelRequest`s. /// /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. -pub struct CancellationHandler
<P>
{ +pub struct CancellationHandler { compute_config: &'static ComputeConfig, - map: CancelMap, - client: P, - /// This field used for the monitoring purposes. - /// Represents the source of the cancellation request. - from: CancellationSource, // rate limiter of cancellation requests limiter: Arc>>, + tx: Option>, // send messages to the redis KV client task } #[derive(Debug, Error)] @@ -61,6 +137,12 @@ pub(crate) enum CancelError { #[error("Authentication backend error")] AuthError(#[from] AuthError), + + #[error("key not found")] + NotFound, + + #[error("proxy service error")] + InternalError, } impl ReportableError for CancelError { @@ -73,274 +155,191 @@ impl ReportableError for CancelError { CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, CancelError::IpNotAllowed => crate::error::ErrorKind::User, + CancelError::NotFound => crate::error::ErrorKind::User, CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::InternalError => crate::error::ErrorKind::Service, } } } -impl CancellationHandler
<P>
{ - /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub(crate) fn get_session(self: Arc) -> Session
<P>
{ +impl CancellationHandler { + pub fn new( + compute_config: &'static ComputeConfig, + tx: Option>, + ) -> Self { + Self { + compute_config, + tx, + limiter: Arc::new(std::sync::Mutex::new( + LeakyBucketRateLimiter::::new_with_shards( + LeakyBucketRateLimiter::::DEFAULT, + 64, + ), + )), + } + } + + pub(crate) fn get_key(self: &Arc) -> Session { // we intentionally generate a random "backend pid" and "secret key" here. // we use the corresponding u64 as an identifier for the // actual endpoint+pid+secret for postgres/pgbouncer. // // if we forwarded the backend_pid from postgres to the client, there would be a lot // of overlap between our computes as most pids are small (~100). - let key = loop { - let key = rand::random(); - // Random key collisions are unlikely to happen here, but they're still possible, - // which is why we have to take care not to rewrite an existing key. - match self.map.entry(key) { - dashmap::mapref::entry::Entry::Occupied(_) => continue, - dashmap::mapref::entry::Entry::Vacant(e) => { - e.insert(None); - } - } - break key; - }; + let key: CancelKeyData = rand::random(); + + let prefix_key: KeyPrefix = KeyPrefix::Cancel(key); + let redis_key = prefix_key.build_redis_key(); debug!("registered new query cancellation key {key}"); Session { key, - cancellation_handler: self, + redis_key, + cancellation_handler: Arc::clone(self), } } - /// Cancelling only in notification, will be removed - pub(crate) async fn cancel_session( + async fn get_cancel_key( &self, key: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - check_allowed: bool, - ) -> Result<(), CancelError> { - // TODO: check for unspecified address is only for backward compatibility, should be removed - if !peer_addr.is_unspecified() { - let subnet_key = match peer_addr { - IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here - IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), - }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { - // log only the subnet part of the IP address to know which subnet is rate limited - tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); - Metrics::get() - .proxy - .cancellation_requests_total - .inc(CancellationRequest { - source: self.from, - kind: crate::metrics::CancellationOutcome::RateLimitExceeded, - }); - return Err(CancelError::RateLimit); - } - } + ) -> Result, CancelError> { + let prefix_key: KeyPrefix = KeyPrefix::Cancel(key); + let redis_key = prefix_key.build_redis_key(); - // NB: we should immediately release the lock after cloning the token. 
- let cancel_state = self.map.get(&key).and_then(|x| x.clone()); - let Some(cancel_closure) = cancel_state else { - tracing::warn!("query cancellation key not found: {key}"); - Metrics::get() + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + let op = CancelKeyOp::GetCancelData { + key: redis_key, + resp_tx, + _guard: Metrics::get() .proxy - .cancellation_requests_total - .inc(CancellationRequest { - source: self.from, - kind: crate::metrics::CancellationOutcome::NotFound, - }); - - if session_id == Uuid::nil() { - // was already published, do not publish it again - return Ok(()); - } - - match self.client.try_publish(key, session_id, peer_addr).await { - Ok(()) => {} // do nothing - Err(e) => { - // log it here since cancel_session could be spawned in a task - tracing::error!("failed to publish cancellation key: {key}, error: {e}"); - return Err(CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - ))); - } - } - return Ok(()); + .cancel_channel_size + .guard(RedisMsgKind::HGetAll), }; - if check_allowed - && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice()) - { - // log it here since cancel_session could be spawned in a task - tracing::warn!("IP is not allowed to cancel the query: {key}"); - return Err(CancelError::IpNotAllowed); - } + let Some(tx) = &self.tx else { + tracing::warn!("cancellation handler is not available"); + return Err(CancelError::InternalError); + }; - Metrics::get() - .proxy - .cancellation_requests_total - .inc(CancellationRequest { - source: self.from, - kind: crate::metrics::CancellationOutcome::Found, - }); - info!( - "cancelling query per user's request using key {key}, hostname {}, address: {}", - cancel_closure.hostname, cancel_closure.socket_addr - ); - cancel_closure.try_cancel_query(self.compute_config).await + tx.send_timeout(op, REDIS_SEND_TIMEOUT) + .await + .map_err(|e| { + tracing::warn!("failed to send GetCancelData for {key}: {e}"); + }) + .map_err(|()| CancelError::InternalError)?; + + let result = resp_rx.await.map_err(|e| { + tracing::warn!("failed to receive GetCancelData response: {e}"); + CancelError::InternalError + })?; + + let cancel_state_str: Option = match result { + Ok(mut state) => { + if state.len() == 1 { + Some(state.remove(0).1) + } else { + tracing::warn!("unexpected number of entries in cancel state: {state:?}"); + return Err(CancelError::InternalError); + } + } + Err(e) => { + tracing::warn!("failed to receive cancel state from redis: {e}"); + return Err(CancelError::InternalError); + } + }; + + let cancel_state: Option = match cancel_state_str { + Some(state) => { + let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| { + tracing::warn!("failed to deserialize cancel state: {e}"); + CancelError::InternalError + })?; + Some(cancel_closure) + } + None => None, + }; + Ok(cancel_state) } - /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. /// check_allowed - if true, check if the IP is allowed to cancel the query. /// Will fetch IP allowlist internally. 
/// /// return Result primarily for tests - pub(crate) async fn cancel_session_auth( + pub(crate) async fn cancel_session( &self, key: CancelKeyData, ctx: RequestContext, check_allowed: bool, auth_backend: &T, ) -> Result<(), CancelError> { - // TODO: check for unspecified address is only for backward compatibility, should be removed - if !ctx.peer_addr().is_unspecified() { - let subnet_key = match ctx.peer_addr() { - IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here - IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), - }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { - // log only the subnet part of the IP address to know which subnet is rate limited - tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); - Metrics::get() - .proxy - .cancellation_requests_total - .inc(CancellationRequest { - source: self.from, - kind: crate::metrics::CancellationOutcome::RateLimitExceeded, - }); - return Err(CancelError::RateLimit); - } + let subnet_key = match ctx.peer_addr() { + IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here + IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), + }; + if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + // log only the subnet part of the IP address to know which subnet is rate limited + tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); + Metrics::get() + .proxy + .cancellation_requests_total + .inc(CancellationRequest { + kind: crate::metrics::CancellationOutcome::RateLimitExceeded, + }); + return Err(CancelError::RateLimit); } - // NB: we should immediately release the lock after cloning the token. - let cancel_state = self.map.get(&key).and_then(|x| x.clone()); + let cancel_state = self.get_cancel_key(key).await.map_err(|e| { + tracing::warn!("failed to receive RedisOp response: {e}"); + CancelError::InternalError + })?; + let Some(cancel_closure) = cancel_state else { tracing::warn!("query cancellation key not found: {key}"); Metrics::get() .proxy .cancellation_requests_total .inc(CancellationRequest { - source: self.from, kind: crate::metrics::CancellationOutcome::NotFound, }); - - if ctx.session_id() == Uuid::nil() { - // was already published, do not publish it again - return Ok(()); - } - - match self - .client - .try_publish(key, ctx.session_id(), ctx.peer_addr()) - .await - { - Ok(()) => {} // do nothing - Err(e) => { - // log it here since cancel_session could be spawned in a task - tracing::error!("failed to publish cancellation key: {key}, error: {e}"); - return Err(CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - ))); - } - } - return Ok(()); + return Err(CancelError::NotFound); }; - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(CancelError::AuthError)?; + if check_allowed { + let ip_allowlist = auth_backend + .get_allowed_ips(&ctx, &cancel_closure.user_info) + .await + .map_err(CancelError::AuthError)?; - if check_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!("IP is not allowed to cancel the query: {key}"); - return Err(CancelError::IpNotAllowed); + if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { + // log it here since cancel_session could be spawned in a task + tracing::warn!( + "IP is not allowed to cancel the query: {key}, 
address: {}", + ctx.peer_addr() + ); + return Err(CancelError::IpNotAllowed); + } } Metrics::get() .proxy .cancellation_requests_total .inc(CancellationRequest { - source: self.from, kind: crate::metrics::CancellationOutcome::Found, }); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query(self.compute_config).await } - - #[cfg(test)] - fn contains(&self, session: &Session
<P>
) -> bool { - self.map.contains_key(&session.key) - } - - #[cfg(test)] - fn is_empty(&self) -> bool { - self.map.is_empty() - } -} - -impl CancellationHandler<()> { - pub fn new( - compute_config: &'static ComputeConfig, - map: CancelMap, - from: CancellationSource, - ) -> Self { - Self { - compute_config, - map, - client: (), - from, - limiter: Arc::new(std::sync::Mutex::new( - LeakyBucketRateLimiter::::new_with_shards( - LeakyBucketRateLimiter::::DEFAULT, - 64, - ), - )), - } - } -} - -impl CancellationHandler>>> { - pub fn new( - compute_config: &'static ComputeConfig, - map: CancelMap, - client: Option>>, - from: CancellationSource, - ) -> Self { - Self { - compute_config, - map, - client, - from, - limiter: Arc::new(std::sync::Mutex::new( - LeakyBucketRateLimiter::::new_with_shards( - LeakyBucketRateLimiter::::DEFAULT, - 64, - ), - )), - } - } } /// This should've been a [`std::future::Future`], but /// it's impossible to name a type of an unboxed future /// (we'd need something like `#![feature(type_alias_impl_trait)]`). -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct CancelClosure { socket_addr: SocketAddr, cancel_token: CancelToken, - ip_allowlist: Vec, hostname: String, // for pg_sni router user_info: ComputeUserInfo, } @@ -349,14 +348,12 @@ impl CancelClosure { pub(crate) fn new( socket_addr: SocketAddr, cancel_token: CancelToken, - ip_allowlist: Vec, hostname: String, user_info: ComputeUserInfo, ) -> Self { Self { socket_addr, cancel_token, - ip_allowlist, hostname, user_info, } @@ -385,99 +382,75 @@ impl CancelClosure { debug!("query was cancelled"); Ok(()) } - - /// Obsolete (will be removed after moving CancelMap to Redis), only for notifications - pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec) { - self.ip_allowlist = ip_allowlist; - } } /// Helper for registering query cancellation tokens. -pub(crate) struct Session
<P>
{ +pub(crate) struct Session { /// The user-facing key identifying this session. key: CancelKeyData, - /// The [`CancelMap`] this session belongs to. - cancellation_handler: Arc>, + redis_key: String, + cancellation_handler: Arc, } -impl
<P>
Session
<P>
{ - /// Store the cancel token for the given session. - /// This enables query cancellation in `crate::proxy::prepare_client_connection`. - pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { - debug!("enabling query cancellation for this session"); - self.cancellation_handler - .map - .insert(self.key, Some(cancel_closure)); - - self.key +impl Session { + pub(crate) fn key(&self) -> &CancelKeyData { + &self.key } -} -impl
<P>
Drop for Session
<P>
{ - fn drop(&mut self) { - self.cancellation_handler.map.remove(&self.key); - debug!("dropped query cancellation key {}", &self.key); - } -} - -#[cfg(test)] -#[expect(clippy::unwrap_used)] -mod tests { - use std::time::Duration; - - use super::*; - use crate::config::RetryConfig; - use crate::tls::client_config::compute_client_config_with_certs; - - fn config() -> ComputeConfig { - let retry = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, + // Send the store key op to the cancellation handler + pub(crate) async fn write_cancel_key( + &self, + cancel_closure: CancelClosure, + ) -> Result<(), CancelError> { + let Some(tx) = &self.cancellation_handler.tx else { + tracing::warn!("cancellation handler is not available"); + return Err(CancelError::InternalError); }; - ComputeConfig { - retry, - tls: Arc::new(compute_client_config_with_certs(std::iter::empty())), - timeout: Duration::from_secs(2), - } - } + let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| { + tracing::warn!("failed to serialize cancel closure: {e}"); + CancelError::InternalError + })?; - #[tokio::test] - async fn check_session_drop() -> anyhow::Result<()> { - let cancellation_handler = Arc::new(CancellationHandler::<()>::new( - Box::leak(Box::new(config())), - CancelMap::default(), - CancellationSource::FromRedis, - )); - - let session = cancellation_handler.clone().get_session(); - assert!(cancellation_handler.contains(&session)); - drop(session); - // Check that the session has been dropped. - assert!(cancellation_handler.is_empty()); + let op = CancelKeyOp::StoreCancelKey { + key: self.redis_key.clone(), + field: "data".to_string(), + value: closure_json, + resp_tx: None, + _guard: Metrics::get() + .proxy + .cancel_channel_size + .guard(RedisMsgKind::HSet), + expire: CANCEL_KEY_TTL, + }; + let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| { + let key = self.key; + tracing::warn!("failed to send StoreCancelKey for {key}: {e}"); + }); Ok(()) } - #[tokio::test] - async fn cancel_session_noop_regression() { - let handler = CancellationHandler::<()>::new( - Box::leak(Box::new(config())), - CancelMap::default(), - CancellationSource::Local, - ); - handler - .cancel_session( - CancelKeyData { - backend_pid: 0, - cancel_key: 0, - }, - Uuid::new_v4(), - "127.0.0.1".parse().unwrap(), - true, - ) - .await - .unwrap(); + pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> { + let Some(tx) = &self.cancellation_handler.tx else { + tracing::warn!("cancellation handler is not available"); + return Err(CancelError::InternalError); + }; + + let op = CancelKeyOp::RemoveCancelKey { + key: self.redis_key.clone(), + field: "data".to_string(), + resp_tx: None, + _guard: Metrics::get() + .proxy + .cancel_channel_size + .guard(RedisMsgKind::HSet), + }; + + let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| { + let key = self.key; + tracing::warn!("failed to send RemoveCancelKey for {key}: {e}"); + }); + Ok(()) } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index aff796bbab..d71465765f 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -296,7 +296,6 @@ impl ConnCfg { process_id, secret_key, }, - vec![], // TODO: deprecated, will be removed host.to_string(), user_info, ); diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 0c6755063f..78bfb6deac 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -6,7 +6,7 @@ use 
tokio_util::sync::CancellationToken; use tracing::{debug, error, info, Instrument}; use crate::auth::backend::ConsoleRedirectBackend; -use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal}; +use crate::cancellation::CancellationHandler; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; @@ -24,7 +24,7 @@ pub async fn task_main( backend: &'static ConsoleRedirectBackend, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -140,15 +140,16 @@ pub async fn task_main( Ok(()) } +#[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, - cancellation_handler: Arc, + cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, cancellations: tokio_util::task::task_tracker::TaskTracker, -) -> Result>, ClientRequestError> { +) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), "handling interactive connection from client" @@ -171,13 +172,13 @@ pub(crate) async fn handle_client( HandshakeData::Cancel(cancel_key_data) => { // spawn a task to cancel the session, but don't wait for it cancellations.spawn({ - let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let cancellation_handler_clone = Arc::clone(&cancellation_handler); let ctx = ctx.clone(); let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); async move { cancellation_handler_clone - .cancel_session_auth( + .cancel_session( cancel_key_data, ctx, config.authentication_config.ip_allowlist_check_enabled, @@ -195,7 +196,7 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let (node_info, user_info, ip_allowlist) = match backend + let (node_info, user_info, _ip_allowlist) = match backend .authenticate(ctx, &config.authentication_config, &mut stream) .await { @@ -220,10 +221,14 @@ pub(crate) async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; - node.cancel_closure - .set_ip_allowlist(ip_allowlist.unwrap_or_default()); - let session = cancellation_handler.get_session(); - prepare_client_connection(&node, &session, &mut stream).await?; + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session = cancellation_handler_clone.get_key(); + + session + .write_cancel_key(node.cancel_closure.clone()) + .await?; + + prepare_client_connection(&node, *session.key(), &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the // PqStream input buffer. 
Normally there is none, but our serverless npm @@ -237,8 +242,8 @@ pub(crate) async fn handle_client( aux: node.aux.clone(), compute: node, session_id: ctx.session_id(), + cancel: session, _req: request_gauge, _conn: conn_gauge, - _cancel: session, })) } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 659c57c865..f3d281a26b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -56,6 +56,8 @@ pub struct ProxyMetrics { pub connection_requests: CounterPairVec, #[metric(flatten)] pub http_endpoint_pools: HttpEndpointPools, + #[metric(flatten)] + pub cancel_channel_size: CounterPairVec, /// Time it took for proxy to establish a connection to the compute endpoint. // largest bucket = 2^16 * 0.5ms = 32s @@ -294,6 +296,16 @@ impl CounterPairAssoc for NumConnectionRequestsGauge { pub type NumConnectionRequestsGuard<'a> = metrics::MeasuredCounterPairGuard<'a, NumConnectionRequestsGauge>; +pub struct CancelChannelSizeGauge; +impl CounterPairAssoc for CancelChannelSizeGauge { + const INC_NAME: &'static MetricName = MetricName::from_str("opened_msgs_cancel_channel_total"); + const DEC_NAME: &'static MetricName = MetricName::from_str("closed_msgs_cancel_channel_total"); + const INC_HELP: &'static str = "Number of processing messages in the cancellation channel."; + const DEC_HELP: &'static str = "Number of closed messages in the cancellation channel."; + type LabelGroupSet = StaticLabelSet; +} +pub type CancelChannelSizeGuard<'a> = metrics::MeasuredCounterPairGuard<'a, CancelChannelSizeGauge>; + #[derive(LabelGroup)] #[label(set = ComputeConnectionLatencySet)] pub struct ComputeConnectionLatencyGroup { @@ -340,13 +352,6 @@ pub struct RedisErrors<'a> { pub channel: &'a str, } -#[derive(FixedCardinalityLabel, Copy, Clone)] -pub enum CancellationSource { - FromClient, - FromRedis, - Local, -} - #[derive(FixedCardinalityLabel, Copy, Clone)] pub enum CancellationOutcome { NotFound, @@ -357,7 +362,6 @@ pub enum CancellationOutcome { #[derive(LabelGroup)] #[label(set = CancellationRequestSet)] pub struct CancellationRequest { - pub source: CancellationSource, pub kind: CancellationOutcome, } @@ -369,6 +373,16 @@ pub enum Waiting { RetryTimeout, } +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "kind")] +pub enum RedisMsgKind { + HSet, + HSetMultiple, + HGet, + HGetAll, + HDel, +} + #[derive(Default)] struct Accumulated { cplane: time::Duration, diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 63f93f0a91..ab173bd0d0 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -13,8 +13,9 @@ pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, StartupMessageParams}; +use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; +use serde::{Deserialize, Serialize}; use smol_str::{format_smolstr, SmolStr}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -23,7 +24,7 @@ use tracing::{debug, error, info, warn, Instrument}; use self::connect_compute::{connect_to_compute, TcpMechanism}; use self::passthrough::ProxyPassthrough; -use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}; +use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::ReportableError; @@ -57,7 +58,7 @@ pub async fn 
task_main( auth_backend: &'static auth::Backend<'static, ()>, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, - cancellation_handler: Arc, + cancellation_handler: Arc, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -243,13 +244,13 @@ pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, - cancellation_handler: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, cancellations: tokio_util::task::task_tracker::TaskTracker, -) -> Result>, ClientRequestError> { +) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), "handling interactive connection from client" @@ -278,7 +279,7 @@ pub(crate) async fn handle_client( cancel_span.follows_from(tracing::Span::current()); async move { cancellation_handler_clone - .cancel_session_auth( + .cancel_session( cancel_key_data, ctx, config.authentication_config.ip_allowlist_check_enabled, @@ -312,7 +313,7 @@ pub(crate) async fn handle_client( }; let user = user_info.get_user().to_owned(); - let (user_info, ip_allowlist) = match user_info + let (user_info, _ip_allowlist) = match user_info .authenticate( ctx, &mut stream, @@ -356,10 +357,14 @@ pub(crate) async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; - node.cancel_closure - .set_ip_allowlist(ip_allowlist.unwrap_or_default()); - let session = cancellation_handler.get_session(); - prepare_client_connection(&node, &session, &mut stream).await?; + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session = cancellation_handler_clone.get_key(); + + session + .write_cancel_key(node.cancel_closure.clone()) + .await?; + + prepare_client_connection(&node, *session.key(), &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the // PqStream input buffer. Normally there is none, but our serverless npm @@ -373,23 +378,19 @@ pub(crate) async fn handle_client( aux: node.aux.clone(), compute: node, session_id: ctx.session_id(), + cancel: session, _req: request_gauge, _conn: conn_gauge, - _cancel: session, })) } /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection
<P>
( +pub(crate) async fn prepare_client_connection( node: &compute::PostgresConnection, - session: &cancellation::Session
<P>
, + cancel_key_data: CancelKeyData, stream: &mut PqStream, ) -> Result<(), std::io::Error> { - // Register compute's query cancellation token and produce a new, unique one. - // The new token (cancel_key_data) will be sent to the client. - let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); - // Forward all deferred notices to the client. for notice in &node.delayed_notice { stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; @@ -411,7 +412,7 @@ pub(crate) async fn prepare_client_connection
<P>
( Ok(()) } -#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index a42f9aad39..08871380d6 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -56,18 +56,18 @@ pub(crate) async fn proxy_pass( Ok(()) } -pub(crate) struct ProxyPassthrough { +pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, pub(crate) compute: PostgresConnection, pub(crate) aux: MetricsAuxInfo, pub(crate) session_id: uuid::Uuid, + pub(crate) cancel: cancellation::Session, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, - pub(crate) _cancel: cancellation::Session
<P>
, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub(crate) async fn proxy_pass( self, compute_config: &ComputeConfig, @@ -81,6 +81,9 @@ impl ProxyPassthrough { { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } + + drop(self.cancel.remove_cancel_key().await); // we don't need a result. If the queue is full, we just log the error + res } } diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 6f6a8c9d47..ec080f270b 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -138,6 +138,12 @@ impl RateBucketInfo { Self::new(200, Duration::from_secs(600)), ]; + // For all the sessions will be cancel key. So this limit is essentially global proxy limit. + pub const DEFAULT_REDIS_SET: [Self; 2] = [ + Self::new(100_000, Duration::from_secs(1)), + Self::new(50_000, Duration::from_secs(10)), + ]; + /// All of these are per endpoint-maskedip pair. /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). /// diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 228dbb7f64..30d8b83e60 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -2,12 +2,10 @@ use core::net::IpAddr; use std::sync::Arc; use pq_proto::CancelKeyData; -use redis::AsyncCommands; use tokio::sync::Mutex; use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; -use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; pub trait CancellationPublisherMut: Send + Sync + 'static { @@ -83,9 +81,10 @@ impl CancellationPublisher for Arc> { } pub struct RedisPublisherClient { + #[allow(dead_code)] client: ConnectionWithCredentialsProvider, - region_id: String, - limiter: GlobalRateLimiter, + _region_id: String, + _limiter: GlobalRateLimiter, } impl RedisPublisherClient { @@ -96,26 +95,12 @@ impl RedisPublisherClient { ) -> anyhow::Result { Ok(Self { client, - region_id, - limiter: GlobalRateLimiter::new(info.into()), + _region_id: region_id, + _limiter: GlobalRateLimiter::new(info.into()), }) } - async fn publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { - let payload = serde_json::to_string(&Notification::Cancel(CancelSession { - region_id: Some(self.region_id.clone()), - cancel_key_data, - session_id, - peer_addr: Some(peer_addr), - }))?; - let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?; - Ok(()) - } + #[allow(dead_code)] pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> { match self.client.connect().await { Ok(()) => {} @@ -126,49 +111,4 @@ impl RedisPublisherClient { } Ok(()) } - async fn try_publish_internal( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { - // TODO: review redundant error duplication logs. - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping cancellation message"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - match self.publish(cancel_key_data, session_id, peer_addr).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - } - } - tracing::info!("Publisher is disconnected. 
Reconnectiong..."); - self.try_connect().await?; - self.publish(cancel_key_data, session_id, peer_addr).await - } -} - -impl CancellationPublisherMut for RedisPublisherClient { - async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { - tracing::info!("publishing cancellation key to Redis"); - match self - .try_publish_internal(cancel_key_data, session_id, peer_addr) - .await - { - Ok(()) => { - tracing::debug!("cancellation key successfuly published to Redis"); - Ok(()) - } - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - Err(e) - } - } - } } diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 0f6e765b02..b5c3d13216 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -29,6 +29,7 @@ impl Clone for Credentials { /// Provides PubSub connection without credentials refresh. pub struct ConnectionWithCredentialsProvider { credentials: Credentials, + // TODO: with more load on the connection, we should consider using a connection pool con: Option, refresh_token_task: Option>, mutex: tokio::sync::Mutex<()>, diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs new file mode 100644 index 0000000000..dddc7e2054 --- /dev/null +++ b/proxy/src/redis/keys.rs @@ -0,0 +1,88 @@ +use anyhow::Ok; +use pq_proto::{id_to_cancel_key, CancelKeyData}; +use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; + +pub mod keyspace { + pub const CANCEL_PREFIX: &str = "cancel"; +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) enum KeyPrefix { + #[serde(untagged)] + Cancel(CancelKeyData), +} + +impl KeyPrefix { + pub(crate) fn build_redis_key(&self) -> String { + match self { + KeyPrefix::Cancel(key) => { + let hi = (key.backend_pid as u64) << 32; + let lo = (key.cancel_key as u64) & 0xffff_ffff; + let id = hi | lo; + let keyspace = keyspace::CANCEL_PREFIX; + format!("{keyspace}:{id:x}") + } + } + } + + #[allow(dead_code)] + pub(crate) fn as_str(&self) -> &'static str { + match self { + KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX, + } + } +} + +#[allow(dead_code)] +pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result { + let (prefix, key_str) = key.split_once(':').ok_or_else(|| { + anyhow::anyhow!(std::io::Error::new( + ErrorKind::InvalidData, + "missing prefix" + )) + })?; + + match prefix { + keyspace::CANCEL_PREFIX => { + let id = u64::from_str_radix(key_str, 16)?; + + Ok(KeyPrefix::Cancel(id_to_cancel_key(id))) + } + _ => Err(anyhow::anyhow!(std::io::Error::new( + ErrorKind::InvalidData, + "unknown prefix" + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_redis_key() { + let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { + backend_pid: 12345, + cancel_key: 54321, + }); + + let redis_key = cancel_key.build_redis_key(); + assert_eq!(redis_key, "cancel:30390000d431"); + } + + #[test] + fn test_parse_redis_key() { + let redis_key = "cancel:30390000d431"; + let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); + + let ref_key = CancelKeyData { + backend_pid: 12345, + cancel_key: 54321, + }; + + assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); + let KeyPrefix::Cancel(cancel_key) = key; + assert_eq!(ref_key, cancel_key); + } +} diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs new file mode 100644 index 
0000000000..dcc6aac51b --- /dev/null +++ b/proxy/src/redis/kv_ops.rs @@ -0,0 +1,185 @@ +use redis::{AsyncCommands, ToRedisArgs}; + +use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; + +use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; + +pub struct RedisKVClient { + client: ConnectionWithCredentialsProvider, + limiter: GlobalRateLimiter, +} + +impl RedisKVClient { + pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self { + Self { + client, + limiter: GlobalRateLimiter::new(info.into()), + } + } + + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.connect().await { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e); + } + } + Ok(()) + } + + pub(crate) async fn hset(&mut self, key: K, field: F, value: V) -> anyhow::Result<()> + where + K: ToRedisArgs + Send + Sync, + F: ToRedisArgs + Send + Sync, + V: ToRedisArgs + Send + Sync, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping hset"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.hset(&key, &field, &value).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to set a key-value pair: {e}"); + } + } + + tracing::info!("Redis client is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.client + .hset(key, field, value) + .await + .map_err(anyhow::Error::new) + } + + #[allow(dead_code)] + pub(crate) async fn hset_multiple( + &mut self, + key: &str, + items: &[(K, V)], + ) -> anyhow::Result<()> + where + K: ToRedisArgs + Send + Sync, + V: ToRedisArgs + Send + Sync, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping hset_multiple"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.hset_multiple(key, items).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to set a key-value pair: {e}"); + } + } + + tracing::info!("Redis client is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.client + .hset_multiple(key, items) + .await + .map_err(anyhow::Error::new) + } + + #[allow(dead_code)] + pub(crate) async fn expire(&mut self, key: K, seconds: i64) -> anyhow::Result<()> + where + K: ToRedisArgs + Send + Sync, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping expire"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.expire(&key, seconds).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to set a key-value pair: {e}"); + } + } + + tracing::info!("Redis client is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.client + .expire(key, seconds) + .await + .map_err(anyhow::Error::new) + } + + #[allow(dead_code)] + pub(crate) async fn hget(&mut self, key: K, field: F) -> anyhow::Result + where + K: ToRedisArgs + Send + Sync, + F: ToRedisArgs + Send + Sync, + V: redis::FromRedisValue, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping hget"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.hget(&key, &field).await { + Ok(value) => return Ok(value), + Err(e) => { + tracing::error!("failed to get a value: {e}"); + } + } + + tracing::info!("Redis client is disconnected. 
Reconnectiong..."); + self.try_connect().await?; + self.client + .hget(key, field) + .await + .map_err(anyhow::Error::new) + } + + pub(crate) async fn hget_all(&mut self, key: K) -> anyhow::Result + where + K: ToRedisArgs + Send + Sync, + V: redis::FromRedisValue, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping hgetall"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.hgetall(&key).await { + Ok(value) => return Ok(value), + Err(e) => { + tracing::error!("failed to get a value: {e}"); + } + } + + tracing::info!("Redis client is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.client.hgetall(key).await.map_err(anyhow::Error::new) + } + + pub(crate) async fn hdel(&mut self, key: K, field: F) -> anyhow::Result<()> + where + K: ToRedisArgs + Send + Sync, + F: ToRedisArgs + Send + Sync, + { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping hdel"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + + match self.client.hdel(&key, &field).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to delete a key-value pair: {e}"); + } + } + + tracing::info!("Redis client is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.client + .hdel(key, field) + .await + .map_err(anyhow::Error::new) + } +} diff --git a/proxy/src/redis/mod.rs b/proxy/src/redis/mod.rs index a322f0368c..8b46a8e6ca 100644 --- a/proxy/src/redis/mod.rs +++ b/proxy/src/redis/mod.rs @@ -1,4 +1,6 @@ pub mod cancellation_publisher; pub mod connection_with_credentials_provider; pub mod elasticache; +pub mod keys; +pub mod kv_ops; pub mod notifications; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 63cdf6176c..19fdd3280d 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -6,18 +6,14 @@ use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use tracing::Instrument; use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::cancellation::{CancelMap, CancellationHandler}; -use crate::config::ProxyConfig; use crate::intern::{ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; -pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); @@ -25,8 +21,6 @@ async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Resu let mut conn = client.get_async_pubsub().await?; tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); conn.subscribe(CPLANE_CHANNEL_NAME).await?; - tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); - conn.subscribe(PROXY_CHANNEL_NAME).await?; Ok(conn) } @@ -71,8 +65,6 @@ pub(crate) enum Notification { deserialize_with = "deserialize_json_string" )] PasswordUpdate { password_update: PasswordUpdate }, - #[serde(rename = "/cancel_session")] - Cancel(CancelSession), #[serde( other, @@ -138,7 +130,6 @@ where struct MessageHandler { cache: Arc, - cancellation_handler: Arc>, region_id: String, } @@ -146,23 +137,14 @@ impl Clone for MessageHandler { fn clone(&self) -> Self { Self { cache: 
self.cache.clone(), - cancellation_handler: self.cancellation_handler.clone(), region_id: self.region_id.clone(), } } } impl MessageHandler { - pub(crate) fn new( - cache: Arc, - cancellation_handler: Arc>, - region_id: String, - ) -> Self { - Self { - cache, - cancellation_handler, - region_id, - } + pub(crate) fn new(cache: Arc, region_id: String) -> Self { + Self { cache, region_id } } pub(crate) async fn increment_active_listeners(&self) { @@ -207,46 +189,6 @@ impl MessageHandler { tracing::debug!(?msg, "received a message"); match msg { - Notification::Cancel(cancel_session) => { - tracing::Span::current().record( - "session_id", - tracing::field::display(cancel_session.session_id), - ); - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::CancelSession); - if let Some(cancel_region) = cancel_session.region_id { - // If the message is not for this region, ignore it. - if cancel_region != self.region_id { - return Ok(()); - } - } - - // TODO: Remove unspecified peer_addr after the complete migration to the new format - let peer_addr = cancel_session - .peer_addr - .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)); - let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id); - cancel_span.follows_from(tracing::Span::current()); - // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. - match self - .cancellation_handler - .cancel_session( - cancel_session.cancel_key_data, - uuid::Uuid::nil(), - peer_addr, - cancel_session.peer_addr.is_some(), - ) - .instrument(cancel_span) - .await - { - Ok(()) => {} - Err(e) => { - tracing::warn!("failed to cancel session: {e}"); - } - } - } Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } | Notification::BlockPublicOrVpcAccessUpdated { .. } @@ -293,7 +235,6 @@ fn invalidate_cache(cache: Arc, msg: Notification) { password_update.project_id, password_update.role_name, ), - Notification::Cancel(_) => unreachable!("cancel message should be handled separately"), Notification::BlockPublicOrVpcAccessUpdated { .. } => { // https://github.com/neondatabase/neon/pull/10073 } @@ -323,8 +264,8 @@ async fn handle_messages( } Err(e) => { tracing::error!( - "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}" - ); + "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}" + ); tokio::time::sleep(RECONNECT_TIMEOUT).await; continue; } @@ -350,21 +291,14 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] pub async fn task_main( - config: &'static ProxyConfig, redis: ConnectionWithCredentialsProvider, cache: Arc, - cancel_map: CancelMap, region_id: String, ) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { - let cancellation_handler = Arc::new(CancellationHandler::<()>::new( - &config.connect_to_compute, - cancel_map, - crate::metrics::CancellationSource::FromRedis, - )); - let handler = MessageHandler::new(cache, cancellation_handler, region_id); + let handler = MessageHandler::new(cache, region_id); // 6h - 1m. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. 
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60)); @@ -442,35 +376,6 @@ mod tests { Ok(()) } - #[test] - fn parse_cancel_session() -> anyhow::Result<()> { - let cancel_key_data = CancelKeyData { - backend_pid: 42, - cancel_key: 41, - }; - let uuid = uuid::Uuid::new_v4(); - let msg = Notification::Cancel(CancelSession { - cancel_key_data, - region_id: None, - session_id: uuid, - peer_addr: None, - }); - let text = serde_json::to_string(&msg)?; - let result: Notification = serde_json::from_str(&text)?; - assert_eq!(msg, result); - - let msg = Notification::Cancel(CancelSession { - cancel_key_data, - region_id: Some("region".to_string()), - session_id: uuid, - peer_addr: None, - }); - let text = serde_json::to_string(&msg)?; - let result: Notification = serde_json::from_str(&text)?; - assert_eq!(msg, result,); - - Ok(()) - } #[test] fn parse_unknown_topic() -> anyhow::Result<()> { diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index c2623e0eca..6888772362 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -43,7 +43,7 @@ use tokio_util::task::TaskTracker; use tracing::{info, warn, Instrument}; use utils::http::error::ApiError; -use crate::cancellation::CancellationHandlerMain; +use crate::cancellation::CancellationHandler; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::ext::TaskExt; @@ -61,7 +61,7 @@ pub async fn task_main( auth_backend: &'static crate::auth::Backend<'static, ()>, ws_listener: TcpListener, cancellation_token: CancellationToken, - cancellation_handler: Arc, + cancellation_handler: Arc, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -318,7 +318,7 @@ async fn connection_handler( backend: Arc, connections: TaskTracker, cancellations: TaskTracker, - cancellation_handler: Arc, + cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, conn: AsyncRW, @@ -412,7 +412,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancellation_handler: Arc, + cancellation_handler: Arc, session_id: uuid::Uuid, conn_info: ConnectionInfo, // used to cancel in-flight HTTP requests. not used to cancel websockets diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 47326c1181..585a7d63b2 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -12,7 +12,7 @@ use pin_project_lite::pin_project; use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use tracing::warn; -use crate::cancellation::CancellationHandlerMain; +use crate::cancellation::CancellationHandler; use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::{io_error, ReportableError}; @@ -129,7 +129,7 @@ pub(crate) async fn serve_websocket( auth_backend: &'static crate::auth::Backend<'static, ()>, ctx: RequestContext, websocket: OnUpgrade, - cancellation_handler: Arc, + cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, cancellations: tokio_util::task::task_tracker::TaskTracker,
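To summarize the flow this patch introduces: the proxy serializes a `CancelClosure` to JSON and stores it under field `data` of a Redis hash keyed `cancel:{id:x}`, and every Redis operation is sent as a `CancelKeyOp` over a bounded mpsc channel that a single background task (`handle_cancel_messages` driving `RedisKVClient`) drains. The sketch below only illustrates that shape and is not code from the patch: it substitutes an in-memory `HashMap` for Redis, and the names `Op`, `kv_task`, and the placeholder payload are illustrative assumptions.

```rust
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};

// Simplified stand-ins for the patch's CancelKeyOp variants.
enum Op {
    Store { key: String, field: String, value: String },
    Get { key: String, resp: oneshot::Sender<Option<HashMap<String, String>>> },
    Remove { key: String, field: String },
}

// Background task: the single owner of the (fake) KV store, mirroring
// handle_cancel_messages() applying ops received over the mpsc channel.
async fn kv_task(mut rx: mpsc::Receiver<Op>) {
    let mut kv: HashMap<String, HashMap<String, String>> = HashMap::new();
    while let Some(op) = rx.recv().await {
        match op {
            Op::Store { key, field, value } => {
                // ~ HSET (the real op also carries a TTL value)
                kv.entry(key).or_default().insert(field, value);
            }
            Op::Get { key, resp } => {
                // ~ HGETALL, answered through a oneshot channel
                let _ = resp.send(kv.get(&key).cloned());
            }
            Op::Remove { key, field } => {
                // ~ HDEL
                if let Some(hash) = kv.get_mut(&key) {
                    hash.remove(&field);
                }
            }
        }
    }
}

#[tokio::main]
async fn main() {
    // Bounded channel, as configured by --cancellation-ch-size (default 1024).
    let (tx, rx) = mpsc::channel(1024);
    tokio::spawn(kv_task(rx));

    // Key layout used by the patch: "cancel:{id:x}" with one field "data"
    // holding the JSON-serialized CancelClosure.
    let redis_key = format!("cancel:{:x}", 0x3039_0000_d431u64);
    tx.send(Op::Store {
        key: redis_key.clone(),
        field: "data".into(),
        value: r#"{"hostname":"example"}"#.into(), // placeholder payload
    })
    .await
    .unwrap();

    // On a CancelRequest, the closure is looked up again.
    let (resp_tx, resp_rx) = oneshot::channel();
    tx.send(Op::Get { key: redis_key.clone(), resp: resp_tx }).await.unwrap();
    println!("{:?}", resp_rx.await.unwrap());

    // When the session ends, the field is removed (or left to expire via TTL).
    tx.send(Op::Remove { key: redis_key, field: "data".into() }).await.unwrap();
}
```

Routing every operation through one task keeps the Redis client single-owner and lets the bounded channel apply backpressure instead of blocking connection handlers, which is why the sends in the patch use `send_timeout` rather than waiting indefinitely.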