mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-13 16:32:56 +00:00
Proxy: send cancel notifications to all instances (#6719)
## Problem If cancel request ends up on the wrong proxy instance, it doesn't take an effect. ## Summary of changes Send redis notifications to all proxy pods about the cancel request. Related issue: https://github.com/neondatabase/neon/issues/5839, https://github.com/neondatabase/cloud/issues/10262
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -2263,11 +2263,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.8.2"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
|
||||
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
|
||||
dependencies = [
|
||||
"hashbrown 0.13.2",
|
||||
"hashbrown 0.14.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3952,6 +3952,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
|
||||
@@ -81,7 +81,7 @@ futures-core = "0.3"
|
||||
futures-util = "0.3"
|
||||
git-version = "0.3"
|
||||
hashbrown = "0.13"
|
||||
hashlink = "0.8.1"
|
||||
hashlink = "0.8.4"
|
||||
hdrhistogram = "7.5.2"
|
||||
hex = "0.4"
|
||||
hex-literal = "0.4"
|
||||
|
||||
@@ -13,5 +13,6 @@ rand.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod framed;
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
|
||||
|
||||
// re-export for use in utils pageserver_feedback.rs
|
||||
@@ -123,7 +124,7 @@ impl StartupMessageParams {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct CancelKeyData {
|
||||
pub backend_pid: i32,
|
||||
pub cancel_key: i32,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::AuthenticationConfig;
|
||||
use proxy::config::CacheOptions;
|
||||
use proxy::config::HttpConfig;
|
||||
@@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::RateLimiterConfig;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::redis::publisher::RedisPublisherClient;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
|
||||
@@ -22,6 +25,7 @@ use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
@@ -129,6 +133,9 @@ struct ProxyCliArgs {
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||
#[clap(long, default_value_t = 100)]
|
||||
initial_limit: usize,
|
||||
@@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
|
||||
let cancel_map = CancelMap::default();
|
||||
let redis_publisher = match &args.redis_notifications {
|
||||
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||
url,
|
||||
args.region.clone(),
|
||||
&config.redis_rps_limit,
|
||||
)?))),
|
||||
None => None,
|
||||
};
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
@@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
@@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(url) = args.redis_notifications {
|
||||
info!("Starting redis notifications listener ({url})");
|
||||
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
url.to_owned(),
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
@@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
let mut redis_rps_limit = args.redis_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut redis_rps_limit)?;
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
@@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
endpoint_rps_limit,
|
||||
redis_rps_limit,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
// TODO: add this argument
|
||||
region: args.region.clone(),
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use pq_proto::CancelKeyData;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ReportableError;
|
||||
use crate::{
|
||||
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
|
||||
redis::publisher::RedisPublisherClient,
|
||||
};
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
|
||||
///
|
||||
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler {
|
||||
map: CancelMap,
|
||||
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CancelError {
|
||||
@@ -32,15 +44,43 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelMap {
|
||||
impl CancellationHandler {
|
||||
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
|
||||
Self { map, redis_client }
|
||||
}
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||
pub async fn cancel_session(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> Result<(), CancelError> {
|
||||
let from = "from_client";
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
|
||||
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
if let Some(redis_client) = &self.redis_client {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "not_found"])
|
||||
.inc();
|
||||
info!("publishing cancellation key to Redis");
|
||||
match redis_client.lock().await.try_publish(key, session_id).await {
|
||||
Ok(()) => {
|
||||
info!("cancellation key successfuly published to Redis");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "found"])
|
||||
.inc();
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
@@ -57,7 +97,7 @@ impl CancelMap {
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
match self.0.entry(key) {
|
||||
match self.map.entry(key) {
|
||||
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||
e.insert(None);
|
||||
@@ -69,18 +109,46 @@ impl CancelMap {
|
||||
info!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
cancel_map: self,
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn contains(&self, session: &Session) -> bool {
|
||||
self.0.contains_key(&session.key)
|
||||
self.map.contains_key(&session.key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait NotificationsCancellationHandler {
|
||||
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl NotificationsCancellationHandler for CancellationHandler {
|
||||
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||
let from = "from_redis";
|
||||
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
|
||||
match cancel_closure {
|
||||
Some(cancel_closure) => {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "found"])
|
||||
.inc();
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
None => {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "not_found"])
|
||||
.inc();
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +183,7 @@ pub struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -123,7 +191,9 @@ impl Session {
|
||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
info!("enabling query cancellation for this session");
|
||||
self.cancel_map.0.insert(self.key, Some(cancel_closure));
|
||||
self.cancellation_handler
|
||||
.map
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
}
|
||||
@@ -131,7 +201,7 @@ impl Session {
|
||||
|
||||
impl Drop for Session {
|
||||
fn drop(&mut self) {
|
||||
self.cancel_map.0.remove(&self.key);
|
||||
self.cancellation_handler.map.remove(&self.key);
|
||||
info!("dropped query cancellation key {}", &self.key);
|
||||
}
|
||||
}
|
||||
@@ -142,13 +212,16 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancel_map: Arc<CancelMap> = Default::default();
|
||||
let cancellation_handler = Arc::new(CancellationHandler {
|
||||
map: CancelMap::default(),
|
||||
redis_client: None,
|
||||
});
|
||||
|
||||
let session = cancel_map.clone().get_session();
|
||||
assert!(cancel_map.contains(&session));
|
||||
let session = cancellation_handler.clone().get_session();
|
||||
assert!(cancellation_handler.contains(&session));
|
||||
drop(session);
|
||||
// Check that the session has been dropped.
|
||||
assert!(cancel_map.is_empty());
|
||||
assert!(cancellation_handler.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ pub struct ProxyConfig {
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
pub redis_rps_limit: Vec<RateBucketInfo>,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
}
|
||||
|
||||
@@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_cancellation_requests_total",
|
||||
"Number of cancellation requests (per found/not_found).",
|
||||
&["source", "kind"],
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LatencyTimer {
|
||||
// time since the stopwatch was started
|
||||
|
||||
@@ -10,7 +10,7 @@ pub mod wake_compute;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
cancellation::{self, CancelMap},
|
||||
cancellation::{self, CancellationHandler},
|
||||
compute,
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
context::RequestMonitoring,
|
||||
@@ -62,6 +62,7 @@ pub async fn task_main(
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -72,7 +73,6 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -80,7 +80,7 @@ pub async fn task_main(
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let session_span = info_span!(
|
||||
@@ -113,7 +113,7 @@ pub async fn task_main(
|
||||
let res = handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter,
|
||||
@@ -227,7 +227,7 @@ impl ReportableError for ClientRequestError {
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -253,8 +253,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancel_map
|
||||
.cancel_session(cancel_key_data)
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(cancel_key_data, ctx.session_id)
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
}
|
||||
@@ -315,7 +315,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let session = cancel_map.get_session();
|
||||
let session = cancellation_handler.get_session();
|
||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
|
||||
@@ -4,4 +4,4 @@ mod limiter;
|
||||
pub use aimd::Aimd;
|
||||
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
||||
pub use limiter::Limiter;
|
||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo};
|
||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
|
||||
|
||||
@@ -22,6 +22,44 @@ use super::{
|
||||
RateLimiterConfig,
|
||||
};
|
||||
|
||||
pub struct RedisRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: &'static [RateBucketInfo],
|
||||
}
|
||||
|
||||
impl RedisRateLimiter {
|
||||
pub fn new(info: &'static [RateBucketInfo]) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
RateBucket {
|
||||
start: Instant::now(),
|
||||
count: 0,
|
||||
};
|
||||
info.len()
|
||||
],
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections is below `max_rps` rps.
|
||||
pub fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let should_allow_request = self
|
||||
.data
|
||||
.iter_mut()
|
||||
.zip(self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
self.data.iter_mut().for_each(RateBucket::inc);
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
}
|
||||
}
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
//
|
||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod notifications;
|
||||
pub mod publisher;
|
||||
|
||||
@@ -1,38 +1,44 @@
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use futures::StreamExt;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
cache::project_info::ProjectInfoCache,
|
||||
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
|
||||
intern::{ProjectIdInt, RoleNameInt},
|
||||
};
|
||||
|
||||
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
|
||||
struct ConsoleRedisClient {
|
||||
struct RedisConsumerClient {
|
||||
client: redis::Client,
|
||||
}
|
||||
|
||||
impl ConsoleRedisClient {
|
||||
impl RedisConsumerClient {
|
||||
pub fn new(url: &str) -> anyhow::Result<Self> {
|
||||
let client = redis::Client::open(url)?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
async fn try_connect(&self) -> anyhow::Result<PubSub> {
|
||||
let mut conn = self.client.get_async_connection().await?.into_pubsub();
|
||||
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
|
||||
conn.subscribe(CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
|
||||
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
|
||||
conn.subscribe(PROXY_CHANNEL_NAME).await?;
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
enum Notification {
|
||||
pub(crate) enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
@@ -45,16 +51,25 @@ enum Notification {
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
#[serde(rename = "/cancel_session")]
|
||||
Cancel(CancelSession),
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct AllowedIpsUpdate {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedIpsUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct PasswordUpdate {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct PasswordUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct CancelSession {
|
||||
pub region_id: Option<String>,
|
||||
pub cancel_key_data: CancelKeyData,
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
@@ -64,6 +79,88 @@ where
|
||||
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
||||
}
|
||||
|
||||
struct MessageHandler<
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||
> {
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<H>,
|
||||
region_id: String,
|
||||
}
|
||||
|
||||
impl<
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||
> MessageHandler<C, H>
|
||||
{
|
||||
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
|
||||
Self {
|
||||
cache,
|
||||
cancellation_handler,
|
||||
region_id,
|
||||
}
|
||||
}
|
||||
pub fn disable_ttl(&self) {
|
||||
self.cache.disable_ttl();
|
||||
}
|
||||
pub fn enable_ttl(&self) {
|
||||
self.cache.enable_ttl();
|
||||
}
|
||||
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
|
||||
use Notification::*;
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
&tracing::field::display(cancel_session.session_id),
|
||||
);
|
||||
if let Some(cancel_region) = cancel_session.region_id {
|
||||
// If the message is not for this region, ignore it.
|
||||
if cancel_region != self.region_id {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
|
||||
match self
|
||||
.cancellation_handler
|
||||
.cancel_session_no_publish(cancel_session.cancel_key_data)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to cancel session: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
use Notification::*;
|
||||
match msg {
|
||||
@@ -74,50 +171,33 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(cache))]
|
||||
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
invalidate_cache(cache.clone(), msg.clone());
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg.clone());
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "console_notifications", skip_all)]
|
||||
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
|
||||
pub async fn task_main<C>(
|
||||
url: String,
|
||||
cache: Arc<C>,
|
||||
cancel_map: CancelMap,
|
||||
region_id: String,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
cache.enable_ttl();
|
||||
let handler = MessageHandler::new(
|
||||
cache,
|
||||
Arc::new(CancellationHandler::new(cancel_map, None)),
|
||||
region_id,
|
||||
);
|
||||
|
||||
loop {
|
||||
let redis = ConsoleRedisClient::new(&url)?;
|
||||
let redis = RedisConsumerClient::new(&url)?;
|
||||
let conn = match redis.try_connect().await {
|
||||
Ok(conn) => {
|
||||
cache.disable_ttl();
|
||||
handler.disable_ttl();
|
||||
conn
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -130,7 +210,7 @@ where
|
||||
};
|
||||
let mut stream = conn.into_on_message();
|
||||
while let Some(msg) = stream.next().await {
|
||||
match handle_message(msg, cache.clone()) {
|
||||
match handler.handle_message(msg).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
||||
@@ -138,7 +218,7 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
cache.enable_ttl();
|
||||
handler.enable_ttl();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,6 +278,33 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn parse_cancel_session() -> anyhow::Result<()> {
|
||||
let cancel_key_data = CancelKeyData {
|
||||
backend_pid: 42,
|
||||
cancel_key: 41,
|
||||
};
|
||||
let uuid = uuid::Uuid::new_v4();
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: None,
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result);
|
||||
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: Some("region".to_string()),
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result,);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
80
proxy/src/redis/publisher.rs
Normal file
80
proxy/src/redis/publisher.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::AsyncCommands;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
|
||||
|
||||
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
|
||||
|
||||
pub struct RedisPublisherClient {
|
||||
client: redis::Client,
|
||||
publisher: Option<redis::aio::Connection>,
|
||||
region_id: String,
|
||||
limiter: RedisRateLimiter,
|
||||
}
|
||||
|
||||
impl RedisPublisherClient {
|
||||
pub fn new(
|
||||
url: &str,
|
||||
region_id: String,
|
||||
info: &'static [RateBucketInfo],
|
||||
) -> anyhow::Result<Self> {
|
||||
let client = redis::Client::open(url)?;
|
||||
Ok(Self {
|
||||
client,
|
||||
publisher: None,
|
||||
region_id,
|
||||
limiter: RedisRateLimiter::new(info),
|
||||
})
|
||||
}
|
||||
pub async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping cancellation message");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
match self.publish(cancel_key_data, session_id).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
self.publisher = None;
|
||||
}
|
||||
}
|
||||
tracing::info!("Publisher is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.publish(cancel_key_data, session_id).await
|
||||
}
|
||||
|
||||
async fn publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
let conn = self
|
||||
.publisher
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
|
||||
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
|
||||
region_id: Some(self.region_id.clone()),
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
}))?;
|
||||
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.get_async_connection().await {
|
||||
Ok(conn) => {
|
||||
self.publisher = Some(conn);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
server::{
|
||||
@@ -50,6 +50,7 @@ pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
@@ -115,7 +116,7 @@ pub async fn task_main(
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
@@ -127,9 +128,9 @@ pub async fn task_main(
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
|
||||
async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
request_handler(
|
||||
@@ -137,7 +138,7 @@ pub async fn task_main(
|
||||
config,
|
||||
backend,
|
||||
ws_connections,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
session_id,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
@@ -205,7 +206,7 @@ async fn request_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -232,7 +233,7 @@ async fn request_handler(
|
||||
config,
|
||||
ctx,
|
||||
websocket,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
host,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
cancellation::CancelMap,
|
||||
cancellation::CancellationHandler,
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
error::{io_error, ReportableError},
|
||||
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
hostname: Option<String>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -141,7 +141,7 @@ pub async fn serve_websocket(
|
||||
let res = handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
|
||||
@@ -38,7 +38,7 @@ futures-io = { version = "0.3" }
|
||||
futures-sink = { version = "0.3" }
|
||||
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
|
||||
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
|
||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
|
||||
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
|
||||
hex = { version = "0.4", features = ["serde"] }
|
||||
hmac = { version = "0.12", default-features = false, features = ["reset"] }
|
||||
@@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
|
||||
either = { version = "1" }
|
||||
getrandom = { version = "0.2", default-features = false, features = ["std"] }
|
||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
|
||||
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
itertools = { version = "0.10" }
|
||||
libc = { version = "0.2", features = ["extra_traits", "use_std"] }
|
||||
|
||||
Reference in New Issue
Block a user