proxy: connect redis with AWS IAM (#7189)

## Problem

Support of IAM Roles for Service Accounts for authentication.

## Summary of changes

* Obtain aws 15m-long credentials
* Retrieve redis password from credentials
* Update every 1h to keep connection for more than 12h
* For now allow to have different endpoints for pubsub/stream redis.

TODOs: 
* PubSub doesn't support credentials refresh, consider using stream
instead.
* We need an AWS role for proxy to be able to connect to both: S3 and
elasticache.

Credentials obtaining and connection refresh was tested on xenon
preview.

https://github.com/neondatabase/cloud/issues/10365
This commit is contained in:
Anna Khanova
2024-03-22 09:38:04 +01:00
committed by GitHub
parent 3ee34a3f26
commit 6770ddba2e
18 changed files with 803 additions and 264 deletions

View File

@@ -11,6 +11,10 @@ testing = []
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
aws-config.workspace = true
aws-sdk-iam.workspace = true
aws-sigv4.workspace = true
aws-types.workspace = true
base64.workspace = true
bstr.workspace = true
bytes = { workspace = true, features = ["serde"] }
@@ -27,6 +31,7 @@ hashlink.workspace = true
hex.workspace = true
hmac.workspace = true
hostname.workspace = true
http.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true

View File

@@ -1,3 +1,10 @@
use aws_config::environment::EnvironmentVariableCredentialsProvider;
use aws_config::imds::credentials::ImdsCredentialsProvider;
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
@@ -10,11 +17,14 @@ use proxy::config::ProjectInfoCacheOptions;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -150,9 +160,24 @@ struct ProxyCliArgs {
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool,
/// redis url for notifications.
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// redis host for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
@@ -216,6 +241,61 @@ async fn main() -> anyhow::Result<()> {
let config = build_config(&args)?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", config.aws_region);
let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed
let provider_conf =
ProviderConfig::without_region().with_region(region_provider.region().await);
let aws_credentials_provider = {
// uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new())
// uses "AWS_PROFILE" / `aws sso login --profile <profile>`
.or_else(
"profile-sso",
ProfileFileCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
// uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
// needed to access remote extensions bucket
.or_else(
"token",
WebIdentityTokenCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
// uses imds v2
.or_else("imds", ImdsCredentialsProvider::builder().build())
};
let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new(
elasticache::AWSIRSAConfig::new(
config.aws_region.clone(),
args.redis_cluster_name,
args.redis_user_id,
),
aws_credentials_provider,
));
let redis_notifications_client =
match (args.redis_notifications, (args.redis_host, args.redis_port)) {
(Some(url), _) => {
info!("Starting redis notifications listener ({url})");
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
}
(None, (Some(host), Some(port))) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache_credentials_provider.clone(),
),
),
(None, (None, None)) => {
warn!("Redis is disabled");
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
};
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
@@ -233,17 +313,22 @@ async fn main() -> anyhow::Result<()> {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
// let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x)));
let redis_publisher = match &redis_notifications_client {
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
redis_publisher.clone(),
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
>::new(
cancel_map.clone(),
redis_publisher,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT,
));
// client facing tasks. these will exit on error or on cancellation
@@ -290,17 +375,16 @@ async fn main() -> anyhow::Result<()> {
if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
redis_notifications_client.clone(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
@@ -445,8 +529,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
aws_region: args.aws_region.clone(),
}));
Ok(config)

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
@@ -10,18 +9,26 @@ use tracing::info;
use uuid::Uuid;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
error::ReportableError,
metrics::NUM_CANCELLATION_REQUESTS,
redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
},
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
pub type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
client: P,
/// This field used for the monitoring purposes.
/// Represents the source of the cancellation request.
from: &'static str,
}
#[derive(Debug, Error)]
@@ -44,49 +51,9 @@ impl ReportableError for CancelError {
}
}
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub fn get_session(self: Arc<Self>) -> Session {
pub fn get_session(self: Arc<Self>) -> Session<P> {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -112,9 +79,39 @@ impl CancellationHandler {
cancellation_handler: self,
}
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "not_found"])
.inc();
match self.client.try_publish(key, session_id).await {
Ok(()) => {} // do nothing
Err(e) => {
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
fn contains(&self, session: &Session<P>) -> bool {
self.map.contains_key(&session.key)
}
@@ -124,31 +121,19 @@ impl CancellationHandler {
}
}
#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: &'static str) -> Self {
Self {
map,
client: (),
from,
}
}
}
#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: &'static str) -> Self {
Self { map, client, from }
}
}
@@ -178,14 +163,14 @@ impl CancelClosure {
}
/// Helper for registering query cancellation tokens.
pub struct Session {
pub struct Session<P> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandler<P>>,
}
impl Session {
impl<P> Session<P> {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
@@ -198,7 +183,7 @@ impl Session {
}
}
impl Drop for Session {
impl<P> Drop for Session<P> {
fn drop(&mut self) {
self.cancellation_handler.map.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
@@ -207,14 +192,16 @@ impl Drop for Session {
#[cfg(test)]
mod tests {
use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS;
use super::*;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
CancelMap::default(),
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
));
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));

View File

@@ -28,6 +28,7 @@ pub struct ProxyConfig {
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
pub aws_region: String,
}
#[derive(Debug)]

View File

@@ -161,6 +161,9 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";
pub enum Waiting {
Cplane,
Client,

View File

@@ -10,7 +10,7 @@ pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancellationHandler},
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
@@ -62,7 +62,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -233,12 +233,12 @@ impl ReportableError for ClientRequestError {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: IntCounterPairGuard,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!("handling interactive connection from client");
let proto = ctx.protocol;
@@ -338,9 +338,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
async fn prepare_client_connection<P>(
node: &compute::PostgresConnection,
session: &cancellation::Session,
session: &cancellation::Session<P>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
// Register compute's query cancellation token and produce a new, unique one.

View File

@@ -55,17 +55,17 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<S> {
pub struct ProxyPassthrough<P, S> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session,
pub cancel: cancellation::Session<P>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
self.compute.cancel_closure.try_cancel_query().await?;

View File

@@ -1,2 +1,4 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod notifications;
pub mod publisher;

View File

@@ -0,0 +1,167 @@
use std::sync::Arc;
use async_trait::async_trait;
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
};
#[async_trait]
pub trait CancellationPublisherMut: Send + Sync + 'static {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()>;
}
#[async_trait]
pub trait CancellationPublisher: Send + Sync + 'static {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()>;
}
#[async_trait]
impl CancellationPublisherMut for () {
async fn try_publish(
&mut self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
) -> anyhow::Result<()> {
Ok(())
}
}
#[async_trait]
impl<P: CancellationPublisherMut> CancellationPublisher for P {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
) -> anyhow::Result<()> {
self.try_publish(_cancel_key_data, _session_id).await
}
}
#[async_trait]
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id).await
} else {
Ok(())
}
}
}
#[async_trait]
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id)
.await
}
}
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
client: ConnectionWithCredentialsProvider,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
Ok(Self {
client,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
}
async fn try_publish_internal(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
}
#[async_trait]
impl CancellationPublisherMut for RedisPublisherClient {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
tracing::info!("publishing cancellation key to Redis");
match self.try_publish_internal(cancel_key_data, session_id).await {
Ok(()) => {
tracing::info!("cancellation key successfuly published to Redis");
Ok(())
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
Err(e)
}
}
}
}

View File

@@ -0,0 +1,225 @@
use std::{sync::Arc, time::Duration};
use futures::FutureExt;
use redis::{
aio::{ConnectionLike, MultiplexedConnection},
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{error, info};
use super::elasticache::CredentialsProvider;
enum Credentials {
Static(ConnectionInfo),
Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
}
impl Clone for Credentials {
fn clone(&self) -> Self {
match self {
Credentials::Static(info) => Credentials::Static(info.clone()),
Credentials::Dynamic(provider, addr) => {
Credentials::Dynamic(Arc::clone(provider), addr.clone())
}
}
}
}
/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
/// Provides PubSub connection without credentials refresh.
pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
con: Option<MultiplexedConnection>,
refresh_token_task: Option<JoinHandle<()>>,
mutex: tokio::sync::Mutex<()>,
}
impl Clone for ConnectionWithCredentialsProvider {
fn clone(&self) -> Self {
Self {
credentials: self.credentials.clone(),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
}
impl ConnectionWithCredentialsProvider {
pub fn new_with_credentials_provider(
host: String,
port: u16,
credentials_provider: Arc<CredentialsProvider>,
) -> Self {
Self {
credentials: Credentials::Dynamic(
credentials_provider,
redis::ConnectionAddr::TcpTls {
host,
port,
insecure: false,
tls_params: None,
},
),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
Self {
credentials: Credentials::Static(params.into_connection_info().unwrap()),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
pub async fn connect(&mut self) -> anyhow::Result<()> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
match redis::cmd("PING").query_async(con).await {
Ok(()) => {
return Ok(());
}
Err(e) => {
error!("Error during PING: {e:?}");
}
}
} else {
info!("Connection is not established");
}
info!("Establishing a new connection...");
self.con = None;
if let Some(f) = self.refresh_token_task.take() {
f.abort()
}
let con = self
.get_client()
.await?
.get_multiplexed_tokio_connection()
.await?;
if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let f = tokio::spawn(async move {
let _ = Self::keep_connection(con2, credentials_provider).await;
});
self.refresh_token_task = Some(f);
}
self.con = Some(con);
Ok(())
}
async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
match &self.credentials {
Credentials::Static(info) => Ok(info.clone()),
Credentials::Dynamic(provider, addr) => {
let (username, password) = provider.provide_credentials().await?;
Ok(ConnectionInfo {
addr: addr.clone(),
redis: RedisConnectionInfo {
db: 0,
username: Some(username),
password: Some(password.clone()),
},
})
}
}
}
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
Ok(client)
}
// PubSub does not support credentials refresh.
// Requires manual reconnection every 12h.
pub async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
Ok(self.get_client().await?.get_async_pubsub().await?)
}
// The connection lives for 12h.
// It can be prolonged with sending `AUTH` commands with the refreshed token.
// https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
}
Err(e) => {
error!("Error during token refresh: {e:?}");
}
}
}
}
async fn refresh_token(
con: &mut MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
let (user, password) = credentials_provider.provide_credentials().await?;
redis::cmd("AUTH")
.arg(user)
.arg(password)
.query_async(con)
.await?;
Ok(())
}
/// Sends an already encoded (packed) command into the TCP socket and
/// reads the single response from it.
pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult<redis::Value> {
// Clone connection to avoid having to lock the ArcSwap in write mode
let con = self.con.as_mut().ok_or(redis::RedisError::from((
redis::ErrorKind::IoError,
"Connection not established",
)))?;
con.send_packed_command(cmd).await
}
/// Sends multiple already encoded (packed) command into the TCP socket
/// and reads `count` responses from it. This is used to implement
/// pipelining.
pub async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> RedisResult<Vec<redis::Value>> {
// Clone shared connection future to avoid having to lock the ArcSwap in write mode
let con = self.con.as_mut().ok_or(redis::RedisError::from((
redis::ErrorKind::IoError,
"Connection not established",
)))?;
con.send_packed_commands(cmd, offset, count).await
}
}
impl ConnectionLike for ConnectionWithCredentialsProvider {
fn req_packed_command<'a>(
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
}
fn get_db(&self) -> i64 {
0
}
}

View File

@@ -0,0 +1,110 @@
use std::time::{Duration, SystemTime};
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_sdk_iam::config::ProvideCredentials;
use aws_sigv4::http_request::{
self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
};
use tracing::info;
#[derive(Debug)]
pub struct AWSIRSAConfig {
region: String,
service_name: String,
cluster_name: String,
user_id: String,
token_ttl: Duration,
action: String,
}
impl AWSIRSAConfig {
pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
AWSIRSAConfig {
region,
service_name: "elasticache".to_string(),
cluster_name: cluster_name.unwrap_or_default(),
user_id: user_id.unwrap_or_default(),
// "The IAM authentication token is valid for 15 minutes"
// https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits
token_ttl: Duration::from_secs(15 * 60),
action: "connect".to_string(),
}
}
}
/// Credentials provider for AWS elasticache authentication.
///
/// Official documentation:
/// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
///
/// Useful resources:
/// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
pub struct CredentialsProvider {
config: AWSIRSAConfig,
credentials_provider: CredentialsProviderChain,
}
impl CredentialsProvider {
pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self {
CredentialsProvider {
config,
credentials_provider,
}
}
pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
let aws_credentials = self
.credentials_provider
.provide_credentials()
.await?
.into();
info!("AWS credentials successfully obtained");
info!("Connecting to Redis with configuration: {:?}", self.config);
let mut settings = SigningSettings::default();
settings.signature_location = SignatureLocation::QueryParams;
settings.expires_in = Some(self.config.token_ttl);
let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
.identity(&aws_credentials)
.region(&self.config.region)
.name(&self.config.service_name)
.time(SystemTime::now())
.settings(settings)
.build()?
.into();
let auth_params = [
("Action", &self.config.action),
("User", &self.config.user_id),
];
let auth_params = url::form_urlencoded::Serializer::new(String::new())
.extend_pairs(auth_params)
.finish();
let auth_uri = http::Uri::builder()
.scheme("http")
.authority(self.config.cluster_name.as_bytes())
.path_and_query(format!("/?{auth_params}"))
.build()?;
info!("{}", auth_uri);
// Convert the HTTP request into a signable request
let signable_request = SignableRequest::new(
"GET",
auth_uri.to_string(),
std::iter::empty(),
SignableBody::Bytes(&[]),
)?;
// Sign and then apply the signature to the request
let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
let mut signable_request = http::Request::builder()
.method("GET")
.uri(auth_uri)
.body(())?;
si.apply_to_request_http1x(&mut signable_request);
Ok((
self.config.user_id.clone(),
signable_request
.uri()
.to_string()
.replacen("http://", "", 1),
))
}
}

View File

@@ -6,11 +6,12 @@ use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
cancellation::{CancelMap, CancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
metrics::REDIS_BROKEN_MESSAGES,
metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES},
};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -18,23 +19,13 @@ pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
struct RedisConsumerClient {
client: redis::Client,
}
impl RedisConsumerClient {
pub fn new(url: &str) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self { client })
}
async fn try_connect(&self) -> anyhow::Result<PubSub> {
let mut conn = self.client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
Ok(conn)
}
async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
let mut conn = client.get_async_pubsub().await?;
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
Ok(conn)
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
@@ -80,21 +71,18 @@ where
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
struct MessageHandler<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> {
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
cancellation_handler: Arc<H>,
cancellation_handler: Arc<CancellationHandler<()>>,
region_id: String,
}
impl<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> MessageHandler<C, H>
{
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub fn new(
cache: Arc<C>,
cancellation_handler: Arc<CancellationHandler<()>>,
region_id: String,
) -> Self {
Self {
cache,
cancellation_handler,
@@ -139,7 +127,7 @@ impl<
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
match self
.cancellation_handler
.cancel_session_no_publish(cancel_session.cancel_key_data)
.cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
.await
{
Ok(()) => {}
@@ -182,7 +170,7 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(
url: String,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
@@ -193,13 +181,15 @@ where
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::new(cancel_map, None)),
Arc::new(CancellationHandler::<()>::new(
cancel_map,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
)),
region_id,
);
loop {
let redis = RedisConsumerClient::new(&url)?;
let conn = match redis.try_connect().await {
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.disable_ttl();
conn
@@ -212,7 +202,7 @@ where
continue;
}
};
let mut stream = conn.into_on_message();
let mut stream = conn.on_message();
while let Some(msg) = stream.next().await {
match handler.handle_message(msg).await {
Ok(()) => {}

View File

@@ -1,80 +0,0 @@
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
pub struct RedisPublisherClient {
client: redis::Client,
publisher: Option<redis::aio::Connection>,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
url: &str,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self {
client,
publisher: None,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
pub async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
self.publisher = None;
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let conn = self
.publisher
.as_mut()
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.get_async_connection().await {
Ok(conn) => {
self.publisher = Some(conn);
}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e.into());
}
}
Ok(())
}
}

View File

@@ -21,11 +21,12 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use hyper::{
server::conn::{AddrIncoming, AddrStream},
Body, Method, Request, Response,
@@ -47,7 +48,7 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -237,7 +238,7 @@ async fn request_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation::CancellationHandler,
cancellation::CancellationHandlerMain,
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
@@ -134,7 +134,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {