Files
neon/proxy/src/cache/endpoints.rs
Anna Khanova 40f15c3123 Read cplane events from regional redis (#7352)
## Problem

Actually read redis events.

## Summary of changes

This is a revert of https://github.com/neondatabase/neon/pull/7350, plus
the following fixes:
* Fixed events parsing
* Added timeout after connection failure
* Separated regional and global redis clients.
2024-04-11 18:24:34 +00:00

227 lines
7.5 KiB
Rust

use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::{Metrics, RedisErrors},
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};
/// One control-plane event as deserialized from the JSON stored in a redis
/// stream entry (see `EndpointsCache::parse_key_value`).
///
/// Every field is optional; `EndpointsCache::insert_event` records the id of
/// each field that is present. JSON keys not listed here (for example the
/// `type` key used in the unit test) are not mapped to any field.
#[derive(Deserialize, Debug, Clone)]
pub struct ControlPlaneEventKey {
/// Endpoint payload; when present its id is added to the endpoint set.
endpoint_created: Option<EndpointCreated>,
/// Branch payload; when present its id is added to the branch set.
branch_created: Option<BranchCreated>,
/// Project payload; when present its id is added to the project set.
project_created: Option<ProjectCreated>,
}
/// `endpoint_created` payload of a [`ControlPlaneEventKey`].
#[derive(Deserialize, Debug, Clone)]
struct EndpointCreated {
/// Endpoint id as received; interned into the endpoint set without any
/// normalization by `EndpointsCache::insert_event`.
endpoint_id: String,
}
/// `branch_created` payload of a [`ControlPlaneEventKey`].
#[derive(Deserialize, Debug, Clone)]
struct BranchCreated {
/// Branch id as received; interned into the branch set without any
/// normalization by `EndpointsCache::insert_event`.
branch_id: String,
}
/// `project_created` payload of a [`ControlPlaneEventKey`].
#[derive(Deserialize, Debug, Clone)]
struct ProjectCreated {
/// Project id as received; interned into the project set without any
/// normalization by `EndpointsCache::insert_event`.
project_id: String,
}
/// In-memory sets of endpoint, branch and project ids learned from a redis
/// stream, used by `is_valid` to decide whether a requested endpoint should be
/// rejected.
pub struct EndpointsCache {
/// Cache settings: stream name, batch sizes, timeouts, retry interval,
/// limiter parameters and the `disable_cache` switch.
config: EndpointCacheConfig,
/// Interned ids taken from `endpoint_created` events.
endpoints: DashSet<EndpointIdInt>,
/// Interned ids taken from `branch_created` events.
branches: DashSet<BranchIdInt>,
/// Interned ids taken from `project_created` events.
projects: DashSet<ProjectIdInt>,
/// Reset to `false` at the start of every connection attempt in `do_read`
/// and set to `true` once the initial stream read has completed. While it
/// is `false`, `is_valid` accepts every endpoint.
ready: AtomicBool,
/// Rate limiter consulted by `is_valid`; when its `check()` returns `true`
/// the cache lookup is skipped and the endpoint is accepted.
limiter: Arc<Mutex<GlobalRateLimiter>>,
}
impl EndpointsCache {
/// Builds an empty cache that is not yet ready.
///
/// The rate limiter is created from a clone of `config.limiter_info`; the
/// configuration itself is then stored in the cache.
pub fn new(config: EndpointCacheConfig) -> Self {
    // Build the limiter first, while `config` can still be borrowed.
    let rate_limiter = GlobalRateLimiter::new(config.limiter_info.clone());
    let limiter = Arc::new(Mutex::new(rate_limiter));

    Self {
        config,
        endpoints: DashSet::new(),
        branches: DashSet::new(),
        projects: DashSet::new(),
        ready: AtomicBool::new(false),
        limiter,
    }
}
/// Decides whether a request for `endpoint` may proceed.
///
/// Returns `true` (accept) when:
/// * the cache is not ready yet,
/// * the cache is disabled (the would-be verdict is still recorded on `ctx`),
/// * the rate limiter's `check()` returns `true` (no lookup is performed), or
/// * the endpoint is found in the cache.
///
/// Returns `false` only when the cache is ready and enabled, the limiter
/// check returned `false`, and the endpoint is unknown; that verdict is also
/// recorded on `ctx`.
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
    if self.ready.load(Ordering::Acquire) {
        if self.config.disable_cache {
            // Metrics-only mode: record what we would have done, never reject.
            let would_reject = self.should_reject(endpoint);
            ctx.set_rejected(would_reject);
        } else if !self.limiter.lock().await.check() {
            // The limiter did not let the request through, so the cache decides.
            let rejected = self.should_reject(endpoint);
            ctx.set_rejected(rejected);
            return !rejected;
        }
    }
    true
}
/// Returns `true` when `endpoint` is absent from the set that matches its
/// kind: the endpoint set if `is_endpoint()`, otherwise the branch set if
/// `is_branch()`, otherwise the project set.
fn should_reject(&self, endpoint: &EndpointId) -> bool {
    let known = if endpoint.is_endpoint() {
        let id = EndpointIdInt::from(endpoint);
        self.endpoints.contains(&id)
    } else if endpoint.is_branch() {
        let id = BranchIdInt::from(&endpoint.as_branch());
        self.branches.contains(&id)
    } else {
        let id = ProjectIdInt::from(&endpoint.as_project());
        self.projects.contains(&id)
    };
    !known
}
fn insert_event(&self, key: ControlPlaneEventKey) {
// Do not do normalization here, we expect the events to be normalized.
if let Some(endpoint_created) = key.endpoint_created {
self.endpoints
.insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
}
if let Some(branch_created) = key.branch_created {
self.branches
.insert(BranchIdInt::from(&branch_created.branch_id.into()));
}
if let Some(project_created) = key.project_created {
self.projects
.insert(ProjectIdInt::from(&project_created.project_id.into()));
}
}
/// Runs the redis reader loop forever.
///
/// Each iteration marks the cache as not ready, (re)connects `con`, and then
/// reads the stream via `read_from_stream`, resuming after the last entry id
/// that was processed (initially `"0-0"`). Connection and read errors are
/// logged, never propagated.
///
/// Every iteration ends with a sleep of `config.retry_interval`, including
/// the iteration where the connection attempt itself failed. Without that
/// back-off a redis outage would make this loop retry in a tight loop and
/// flood the error log.
///
/// The function contains no `return` or `break`, so it never yields a value;
/// the `Infallible` success type reflects that.
pub async fn do_read(
    &self,
    mut con: ConnectionWithCredentialsProvider,
) -> anyhow::Result<Infallible> {
    let mut last_id = "0-0".to_string();
    loop {
        self.ready.store(false, Ordering::Release);

        // Resolve the connection result before any further `.await` so the
        // error value is not kept alive across the sleep below.
        let connected = match con.connect().await {
            Ok(_) => true,
            Err(e) => {
                tracing::error!("error connecting to redis: {:?}", e);
                false
            }
        };

        if connected {
            if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
                tracing::error!("error reading from redis: {:?}", e);
            }
        }

        // Back off before the next attempt on every failure path.
        tokio::time::sleep(self.config.retry_interval).await;
    }
}
/// Reads the stream in two phases over an already connected `con`.
///
/// 1. Catch-up: non-blocking reads of up to `config.initial_batch_size`
///    entries per call, with `return_when_finish = true`. An error here is
///    returned before the cache is marked ready.
/// 2. Follow: after marking the cache ready, blocking reads of up to
///    `config.default_batch_size` entries per call, with
///    `return_when_finish = false`. The result of this phase is returned.
///
/// `last_id` is passed through to `batch_read`, which advances it.
async fn read_from_stream(
    &self,
    con: &mut ConnectionWithCredentialsProvider,
    last_id: &mut String,
) -> anyhow::Result<()> {
    tracing::info!("reading endpoints/branches/projects from redis");
    let catch_up_opts = StreamReadOptions::default().count(self.config.initial_batch_size);
    self.batch_read(con, catch_up_opts, last_id, true).await?;

    tracing::info!("ready to filter user requests");
    self.ready.store(true, Ordering::Release);

    // NOTE(review): `as usize` silently truncates an out-of-range timeout.
    let block_millis = self.config.xread_timeout.as_millis() as usize;
    let follow_opts = StreamReadOptions::default()
        .count(self.config.default_batch_size)
        .block(block_millis);
    self.batch_read(con, follow_opts, last_id, false).await
}
fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
let s: String = FromRedisValue::from_redis_value(value)?;
Ok(serde_json::from_str(&s)?)
}
/// Repeatedly reads entries that follow `last_id` from `config.stream_name`
/// and records every successfully parsed event in the cache.
///
/// * `conn` - connection used for the `xread_options` calls.
/// * `opts` - read options passed unchanged to every call.
/// * `last_id` - id to read after; overwritten with the id of each entry as
///   it is processed, so the caller can resume from it later.
/// * `return_when_finish` - when `true`, return once a batch holds at most
///   `config.default_batch_size` entries, or once a read comes back empty
///   after at least one entry was read. When `false`, keep reading and only
///   return on error.
///
/// Values that fail to parse are counted in the `redis_errors_total` metric,
/// logged, and skipped; they do not abort the read.
///
/// # Errors
/// * the redis call fails,
/// * `return_when_finish` is set and the very first read returns nothing,
/// * the reply contains more than one stream key.
async fn batch_read(
    &self,
    conn: &mut ConnectionWithCredentialsProvider,
    opts: StreamReadOptions,
    last_id: &mut String,
    return_when_finish: bool,
) -> anyhow::Result<()> {
    let mut total: usize = 0;
    loop {
        let mut reply: StreamReadReply = conn
            .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
            .await?;

        if reply.keys.is_empty() {
            if !return_when_finish {
                // If we are not returning when finish, we should wait for more data.
                continue;
            }
            if total != 0 {
                // Entries were already read during this call, so an empty
                // reply means we have caught up, not that the stream is
                // empty. Previously this case failed whenever the final
                // non-empty batch was larger than `default_batch_size`.
                break;
            }
            // NOTE(review): after a reconnect with `last_id` already at the
            // stream tail this still bails with total == 0 — confirm intent.
            anyhow::bail!(
                "Redis stream {} is empty, cannot be used to filter endpoints",
                self.config.stream_name
            );
        }
        if reply.keys.len() != 1 {
            anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
        }

        let stream = reply.keys.pop().expect("Checked length above");
        let len = stream.ids.len();
        for entry in stream.ids {
            total += 1;
            for (_, v) in entry.map {
                match Self::parse_key_value(&v) {
                    Ok(key) => self.insert_event(key),
                    Err(e) => {
                        Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
                            channel: &self.config.stream_name,
                        });
                        tracing::error!("error parsing value {v:?}: {e:?}");
                    }
                }
            }
            // Log progress at exponentially spaced counts only.
            if total.is_power_of_two() {
                tracing::debug!("endpoints read {}", total);
            }
            *last_id = entry.id;
        }

        if return_when_finish && len <= self.config.default_batch_size {
            break;
        }
    }
    tracing::info!("read {} endpoints/branches/projects from redis", total);
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::ControlPlaneEventKey;

    /// An `endpoint_created` event with explicit `null` siblings and an extra
    /// `type` key must deserialize, and must yield exactly the endpoint id
    /// that was sent.
    #[test]
    fn test() {
        let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
        let key: ControlPlaneEventKey = serde_json::from_str(s).unwrap();

        // Check the parsed contents, not only that parsing did not panic.
        assert!(key.branch_created.is_none());
        assert!(key.project_created.is_none());
        let endpoint_id = key.endpoint_created.map(|created| created.endpoint_id);
        assert_eq!(endpoint_id.as_deref(), Some("ep-rapid-thunder-w0qqw2q9"));
    }
}