Proxy reconnect pubsub before expiration (#7562)

## Problem

The proxy reconnects to Redis only after the existing pubsub connection has already become unavailable.

## Summary of changes

Proactively reconnect the pubsub connection every 6 hours, before it expires, instead of waiting for it to fail.
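
The change follows a simple pattern: a fresh pubsub listener task is spawned on a fixed interval slightly shorter than each listener's lifetime, and the old task is cancelled shortly after the new one starts. Below is a minimal sketch of that pattern under simplified, hypothetical names (`listen` stands in for the real `handle_messages` task; the timing constants match the values in the diff):

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

const LISTENER_LIFETIME: Duration = Duration::from_secs(6 * 60 * 60); // 6h
const OVERLAP: Duration = Duration::from_secs(60); // 1 minute of overlap

// Stand-in for the real `handle_messages` task: subscribe to the pubsub
// channel and process invalidation messages until cancelled.
async fn listen(cancel: CancellationToken) {
    cancel.cancelled().await;
}

async fn reconnect_loop() {
    // Tick slightly more often than a listener lives, so the next listener
    // is already subscribed before the previous one is cancelled.
    let mut interval = tokio::time::interval(LISTENER_LIFETIME - OVERLAP);
    loop {
        interval.tick().await;
        let cancel = CancellationToken::new();
        tokio::spawn(listen(cancel.clone()));
        // Cancel this listener after its full lifetime, i.e. one minute
        // after the next listener has started.
        tokio::spawn(async move {
            tokio::time::sleep(LISTENER_LIFETIME).await;
            cancel.cancel();
        });
    }
}

#[tokio::main]
async fn main() {
    reconnect_loop().await;
}
```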
Author: Anna Khanova
Date: 2024-05-03 10:00:29 +02:00 (committed by GitHub)
Parent: 5f099dc760
Commit: 240efb82f9
2 changed files with 95 additions and 40 deletions

Changed file 1 of 2

@@ -5,9 +5,11 @@ use std::{
time::Duration,
};
use async_trait::async_trait;
use dashmap::DashMap;
use rand::{thread_rng, Rng};
use smol_str::SmolStr;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
@@ -21,11 +23,12 @@ use crate::{
use super::{Cache, Cached};
#[async_trait]
pub trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
fn enable_ttl(&self);
fn disable_ttl(&self);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
}
struct Entry<T> {
@@ -116,8 +119,10 @@ pub struct ProjectInfoCacheImpl {
start_time: Instant,
ttl_disabled_since_us: AtomicU64,
active_listeners_lock: Mutex<usize>,
}
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
@@ -148,15 +153,27 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
fn enable_ttl(&self) {
self.ttl_disabled_since_us
.store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
tracing::error!("active_listeners count is already 0, something is broken");
return;
}
*listeners_guard -= 1;
if *listeners_guard == 0 {
self.ttl_disabled_since_us
.store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
}
}
fn disable_ttl(&self) {
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
self.ttl_disabled_since_us
.store(new_ttl, std::sync::atomic::Ordering::Relaxed);
async fn increment_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
*listeners_guard += 1;
if *listeners_guard == 1 {
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
self.ttl_disabled_since_us
.store(new_ttl, std::sync::atomic::Ordering::SeqCst);
}
}
}
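
Why a listener count instead of the previous `enable_ttl`/`disable_ttl` pair: with two listener tasks overlapping during the handover, the outgoing task would re-enable TTL-based expiry while the new task is still subscribed. Counting active listeners keeps expiry switched off as long as at least one listener is connected. A minimal sketch of that reference-counted toggle, using a hypothetical `TtlGate` type rather than the proxy's actual struct:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

use tokio::sync::Mutex;

// u64::MAX means "TTL expiry is active"; any other value is the cutoff in
// microseconds (corresponding to `(start_time.elapsed() + ttl).as_micros()`
// in the diff above) used by the cache's read path.
struct TtlGate {
    active_listeners: Mutex<usize>,
    ttl_disabled_since_us: AtomicU64,
}

impl TtlGate {
    async fn increment(&self, cutoff_us: u64) {
        let mut n = self.active_listeners.lock().await;
        *n += 1;
        if *n == 1 {
            // First listener connected: stop expiring entries.
            self.ttl_disabled_since_us.store(cutoff_us, Ordering::SeqCst);
        }
    }

    async fn decrement(&self) {
        let mut n = self.active_listeners.lock().await;
        if *n == 0 {
            return; // already at zero, nothing to do
        }
        *n -= 1;
        if *n == 0 {
            // Last listener gone: fall back to TTL-based expiry.
            self.ttl_disabled_since_us.store(u64::MAX, Ordering::SeqCst);
        }
    }
}
```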
@@ -168,6 +185,7 @@ impl ProjectInfoCacheImpl {
config,
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
start_time: Instant::now(),
active_listeners_lock: Mutex::new(0),
}
}
@@ -432,7 +450,7 @@ mod tests {
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
cache.clone().disable_ttl();
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_secs(2)).await;
let project_id: ProjectId = "project".into();
@@ -489,7 +507,7 @@ mod tests {
}
#[tokio::test]
async fn test_disable_ttl_invalidate_added_before() {
async fn test_increment_active_listeners_invalidate_added_before() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
@@ -514,7 +532,7 @@ mod tests {
(&user1).into(),
secret1.clone(),
);
cache.clone().disable_ttl();
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_millis(100)).await;
cache.insert_role_secret(
(&project_id).into(),

Changed file 2 of 2

@@ -4,6 +4,7 @@ use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
@@ -77,6 +78,16 @@ struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
cancellation_handler: self.cancellation_handler.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub fn new(
cache: Arc<C>,
@@ -89,11 +100,11 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
region_id,
}
}
pub fn disable_ttl(&self) {
self.cache.disable_ttl();
pub async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub fn enable_ttl(&self) {
self.cache.enable_ttl();
pub async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
@@ -182,37 +193,24 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
}
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(
async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
handler: MessageHandler<C>,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::<()>::new(
cancel_map,
crate::metrics::CancellationSource::FromRedis,
)),
region_id,
);
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
loop {
if cancellation_token.is_cancelled() {
return Ok(());
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.disable_ttl();
handler.increment_active_listeners().await;
conn
}
Err(e) => {
tracing::error!(
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
);
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
);
tokio::time::sleep(RECONNECT_TIMEOUT).await;
continue;
}
@@ -226,8 +224,47 @@ where
break;
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
return Ok(());
}
}
handler.enable_ttl();
handler.decrement_active_listeners().await;
}
}
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));
let handler = MessageHandler::new(cache, cancellation_handler, region_id);
// 6h - 1m.
// There will be a 1-minute overlap between two consecutive tasks, but at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
loop {
let cancellation_token = CancellationToken::new();
interval.tick().await;
tokio::spawn(handle_messages(
handler.clone(),
redis.clone(),
cancellation_token.clone(),
));
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
cancellation_token.cancel();
});
}
}
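
Because the spawn interval (6h - 1m) is one minute shorter than each listener's lifetime (6h), consecutive `handle_messages` tasks overlap by about a minute, so invalidation messages keep being processed while the old pubsub connection is replaced; handling the same invalidation twice during the overlap should be harmless, since it only invalidates cache entries. `task_main` itself now only orchestrates these tasks and never returns.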