mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
impr(proxy): Move the CancelMap to Redis hashes (#10364)
## Problem The approach of having CancelMap as an in-memory structure increases code complexity, as well as putting additional load for Redis streams. ## Summary of changes - Implement a set of KV ops for Redis client; - Remove cancel notifications code; - Send KV ops over the bounded channel to the handling background task for removing and adding the cancel keys. Closes #9660
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -6935,6 +6935,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"postgres-protocol2",
|
||||
"postgres-types2",
|
||||
"serde",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
@@ -182,6 +182,13 @@ pub struct CancelKeyData {
|
||||
pub cancel_key: i32,
|
||||
}
|
||||
|
||||
pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
|
||||
CancelKeyData {
|
||||
backend_pid: (id >> 32) as i32,
|
||||
cancel_key: (id & 0xffffffff) as i32,
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CancelKeyData {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let hi = (self.backend_pid as u64) << 32;
|
||||
|
||||
@@ -19,3 +19,4 @@ postgres-protocol2 = { path = "../postgres-protocol2" }
|
||||
postgres-types2 = { path = "../postgres-types2" }
|
||||
tokio = { workspace = true, features = ["io-util", "time", "net"] }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
@@ -3,12 +3,13 @@ use crate::tls::TlsConnect;
|
||||
|
||||
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect};
|
||||
use crate::{cancel_query_raw, Error};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// The capability to request cancellation of in-progress queries on a
|
||||
/// connection.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CancelToken {
|
||||
pub socket_config: Option<SocketConfig>,
|
||||
pub ssl_mode: SslMode,
|
||||
|
||||
@@ -18,6 +18,7 @@ use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{future, ready, TryStreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use postgres_protocol2::message::{backend::Message, frontend};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
@@ -137,7 +138,7 @@ impl InnerClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct SocketConfig {
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::tls::MakeTlsConnect;
|
||||
use crate::tls::TlsConnect;
|
||||
use crate::{Client, Connection, Error};
|
||||
use postgres_protocol2::message::frontend::StartupMessageParams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::str;
|
||||
use std::time::Duration;
|
||||
@@ -16,7 +17,7 @@ pub use postgres_protocol2::authentication::sasl::ScramKeys;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// TLS configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum SslMode {
|
||||
/// Do not use TLS.
|
||||
@@ -50,7 +51,7 @@ pub enum ReplicationMode {
|
||||
}
|
||||
|
||||
/// A host specification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Host {
|
||||
/// A TCP hostname.
|
||||
Tcp(String),
|
||||
|
||||
@@ -12,6 +12,7 @@ pub(crate) use console_redirect::ConsoleRedirectError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use local::LocalBackend;
|
||||
use postgres_client::config::AuthKeys;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
@@ -133,7 +134,7 @@ pub(crate) struct ComputeUserInfoNoEndpoint {
|
||||
pub(crate) options: NeonOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub(crate) struct ComputeUserInfo {
|
||||
pub(crate) endpoint: EndpointId,
|
||||
pub(crate) user: RoleName,
|
||||
|
||||
@@ -7,12 +7,11 @@ use std::time::Duration;
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use compute_api::spec::LocalProxySpec;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::Either;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
|
||||
use proxy::auth::{self};
|
||||
use proxy::cancellation::CancellationHandlerMain;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::{
|
||||
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
|
||||
};
|
||||
@@ -211,12 +210,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
auth_backend,
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandlerMain::new(
|
||||
&config.connect_to_compute,
|
||||
Arc::new(DashMap::new()),
|
||||
None,
|
||||
proxy::metrics::CancellationSource::Local,
|
||||
)),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use anyhow::bail;
|
||||
use futures::future::Either;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
|
||||
use proxy::cancellation::{CancelMap, CancellationHandler};
|
||||
use proxy::cancellation::{handle_cancel_messages, CancellationHandler};
|
||||
use proxy::config::{
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
|
||||
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
|
||||
@@ -18,8 +18,8 @@ use proxy::metrics::Metrics;
|
||||
use proxy::rate_limiter::{
|
||||
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
|
||||
};
|
||||
use proxy::redis::cancellation_publisher::RedisPublisherClient;
|
||||
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use proxy::redis::kv_ops::RedisKVClient;
|
||||
use proxy::redis::{elasticache, notifications};
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
@@ -28,7 +28,6 @@ use proxy::tls::client_config::compute_client_config_with_root_certs;
|
||||
use proxy::{auth, control_plane, http, serverless, usage_metrics};
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn, Instrument};
|
||||
@@ -158,8 +157,11 @@ struct ProxyCliArgs {
|
||||
#[clap(long, default_value_t = 64)]
|
||||
auth_rate_limit_ip_subnet: u8,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Cancellation channel size (max queue size for redis kv client)
|
||||
#[clap(long, default_value = "1024")]
|
||||
cancellation_ch_size: usize,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
@@ -382,27 +384,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
|
||||
RateBucketInfo::validate(redis_rps_limit)?;
|
||||
|
||||
let redis_publisher = match ®ional_redis_client {
|
||||
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||
redis_publisher.clone(),
|
||||
args.region.clone(),
|
||||
redis_rps_limit,
|
||||
)?))),
|
||||
None => None,
|
||||
};
|
||||
let redis_kv_client = regional_redis_client
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<
|
||||
Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||
>::new(
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
|
||||
// bit of a hack - find the min rps and max rps supported and turn it into
|
||||
@@ -495,25 +489,29 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(&mut redis_kv_client, rx_cancel).await
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
|
||||
@@ -1,48 +1,124 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::CancelToken;
|
||||
use pq_proto::CancelKeyData;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::backend::{BackendIpAllowlist, ComputeUserInfo};
|
||||
use crate::auth::{check_peer_addr_is_in_list, AuthError, IpPattern};
|
||||
use crate::auth::{check_peer_addr_is_in_list, AuthError};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
|
||||
use crate::metrics::CancelChannelSizeGuard;
|
||||
use crate::metrics::{CancellationRequest, Metrics, RedisMsgKind};
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::cancellation_publisher::{
|
||||
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
|
||||
};
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
use crate::redis::kv_ops::RedisKVClient;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
|
||||
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
|
||||
use std::convert::Infallible;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
|
||||
const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelKeyOp {
|
||||
StoreCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
value: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
expire: i64, // TTL for key
|
||||
},
|
||||
GetCancelData {
|
||||
key: String,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
}
|
||||
|
||||
// Running as a separate task to accept messages through the rx channel
|
||||
// In case of problems with RTT: switch to recv_many() + redis pipeline
|
||||
pub async fn handle_cancel_messages(
|
||||
client: &mut RedisKVClient,
|
||||
mut rx: mpsc::Receiver<CancelKeyOp>,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
loop {
|
||||
if let Some(msg) = rx.recv().await {
|
||||
match msg {
|
||||
CancelKeyOp::StoreCancelKey {
|
||||
key,
|
||||
field,
|
||||
value,
|
||||
resp_tx,
|
||||
_guard,
|
||||
expire: _,
|
||||
} => {
|
||||
if let Some(resp_tx) = resp_tx {
|
||||
resp_tx
|
||||
.send(client.hset(key, field, value).await)
|
||||
.inspect_err(|e| {
|
||||
tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
|
||||
})
|
||||
.ok();
|
||||
} else {
|
||||
drop(client.hset(key, field, value).await);
|
||||
}
|
||||
}
|
||||
CancelKeyOp::GetCancelData {
|
||||
key,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
drop(resp_tx.send(client.hget_all(key).await));
|
||||
}
|
||||
CancelKeyOp::RemoveCancelKey {
|
||||
key,
|
||||
field,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
if let Some(resp_tx) = resp_tx {
|
||||
resp_tx
|
||||
.send(client.hdel(key, field).await)
|
||||
.inspect_err(|e| {
|
||||
tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
|
||||
})
|
||||
.ok();
|
||||
} else {
|
||||
drop(client.hdel(key, field).await);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
///
|
||||
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler<P> {
|
||||
pub struct CancellationHandler {
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: P,
|
||||
/// This field used for the monitoring purposes.
|
||||
/// Represents the source of the cancellation request.
|
||||
from: CancellationSource,
|
||||
// rate limiter of cancellation requests
|
||||
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -61,6 +137,12 @@ pub(crate) enum CancelError {
|
||||
|
||||
#[error("Authentication backend error")]
|
||||
AuthError(#[from] AuthError),
|
||||
|
||||
#[error("key not found")]
|
||||
NotFound,
|
||||
|
||||
#[error("proxy service error")]
|
||||
InternalError,
|
||||
}
|
||||
|
||||
impl ReportableError for CancelError {
|
||||
@@ -73,274 +155,191 @@ impl ReportableError for CancelError {
|
||||
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
|
||||
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
|
||||
CancelError::NotFound => crate::error::ErrorKind::User,
|
||||
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
CancelError::InternalError => crate::error::ErrorKind::Service,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
|
||||
impl CancellationHandler {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
tx,
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
64,
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
//
|
||||
// if we forwarded the backend_pid from postgres to the client, there would be a lot
|
||||
// of overlap between our computes as most pids are small (~100).
|
||||
let key = loop {
|
||||
let key = rand::random();
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
match self.map.entry(key) {
|
||||
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||
e.insert(None);
|
||||
}
|
||||
}
|
||||
break key;
|
||||
};
|
||||
let key: CancelKeyData = rand::random();
|
||||
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
|
||||
debug!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
cancellation_handler: self,
|
||||
redis_key,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancelling only in notification, will be removed
|
||||
pub(crate) async fn cancel_session(
|
||||
async fn get_cancel_key(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
check_allowed: bool,
|
||||
) -> Result<(), CancelError> {
|
||||
// TODO: check for unspecified address is only for backward compatibility, should be removed
|
||||
if !peer_addr.is_unspecified() {
|
||||
let subnet_key = match peer_addr {
|
||||
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
|
||||
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
|
||||
};
|
||||
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
|
||||
// log only the subnet part of the IP address to know which subnet is rate limited
|
||||
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
|
||||
});
|
||||
return Err(CancelError::RateLimit);
|
||||
}
|
||||
}
|
||||
) -> Result<Option<CancelClosure>, CancelError> {
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let cancel_state = self.map.get(&key).and_then(|x| x.clone());
|
||||
let Some(cancel_closure) = cancel_state else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
Metrics::get()
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
let op = CancelKeyOp::GetCancelData {
|
||||
key: redis_key,
|
||||
resp_tx,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::NotFound,
|
||||
});
|
||||
|
||||
if session_id == Uuid::nil() {
|
||||
// was already published, do not publish it again
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.client.try_publish(key, session_id, peer_addr).await {
|
||||
Ok(()) => {} // do nothing
|
||||
Err(e) => {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::error!("failed to publish cancellation key: {key}, error: {e}");
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGetAll),
|
||||
};
|
||||
|
||||
if check_allowed
|
||||
&& !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
|
||||
{
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::warn!("IP is not allowed to cancel the query: {key}");
|
||||
return Err(CancelError::IpNotAllowed);
|
||||
}
|
||||
let Some(tx) = &self.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::Found,
|
||||
});
|
||||
info!(
|
||||
"cancelling query per user's request using key {key}, hostname {}, address: {}",
|
||||
cancel_closure.hostname, cancel_closure.socket_addr
|
||||
);
|
||||
cancel_closure.try_cancel_query(self.compute_config).await
|
||||
tx.send_timeout(op, REDIS_SEND_TIMEOUT)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("failed to send GetCancelData for {key}: {e}");
|
||||
})
|
||||
.map_err(|()| CancelError::InternalError)?;
|
||||
|
||||
let result = resp_rx.await.map_err(|e| {
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state_str: Option<String> = match result {
|
||||
Ok(mut state) => {
|
||||
if state.len() == 1 {
|
||||
Some(state.remove(0).1)
|
||||
} else {
|
||||
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to receive cancel state from redis: {e}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
};
|
||||
|
||||
let cancel_state: Option<CancelClosure> = match cancel_state_str {
|
||||
Some(state) => {
|
||||
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
Some(cancel_closure)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(cancel_state)
|
||||
}
|
||||
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
/// check_allowed - if true, check if the IP is allowed to cancel the query.
|
||||
/// Will fetch IP allowlist internally.
|
||||
///
|
||||
/// return Result primarily for tests
|
||||
pub(crate) async fn cancel_session_auth<T: BackendIpAllowlist>(
|
||||
pub(crate) async fn cancel_session<T: BackendIpAllowlist>(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
ctx: RequestContext,
|
||||
check_allowed: bool,
|
||||
auth_backend: &T,
|
||||
) -> Result<(), CancelError> {
|
||||
// TODO: check for unspecified address is only for backward compatibility, should be removed
|
||||
if !ctx.peer_addr().is_unspecified() {
|
||||
let subnet_key = match ctx.peer_addr() {
|
||||
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
|
||||
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
|
||||
};
|
||||
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
|
||||
// log only the subnet part of the IP address to know which subnet is rate limited
|
||||
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
|
||||
});
|
||||
return Err(CancelError::RateLimit);
|
||||
}
|
||||
let subnet_key = match ctx.peer_addr() {
|
||||
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
|
||||
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
|
||||
};
|
||||
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
|
||||
// log only the subnet part of the IP address to know which subnet is rate limited
|
||||
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
|
||||
});
|
||||
return Err(CancelError::RateLimit);
|
||||
}
|
||||
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let cancel_state = self.map.get(&key).and_then(|x| x.clone());
|
||||
let cancel_state = self.get_cancel_key(key).await.map_err(|e| {
|
||||
tracing::warn!("failed to receive RedisOp response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let Some(cancel_closure) = cancel_state else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::NotFound,
|
||||
});
|
||||
|
||||
if ctx.session_id() == Uuid::nil() {
|
||||
// was already published, do not publish it again
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self
|
||||
.client
|
||||
.try_publish(key, ctx.session_id(), ctx.peer_addr())
|
||||
.await
|
||||
{
|
||||
Ok(()) => {} // do nothing
|
||||
Err(e) => {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::error!("failed to publish cancellation key: {key}, error: {e}");
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
return Err(CancelError::NotFound);
|
||||
};
|
||||
|
||||
let ip_allowlist = auth_backend
|
||||
.get_allowed_ips(&ctx, &cancel_closure.user_info)
|
||||
.await
|
||||
.map_err(CancelError::AuthError)?;
|
||||
if check_allowed {
|
||||
let ip_allowlist = auth_backend
|
||||
.get_allowed_ips(&ctx, &cancel_closure.user_info)
|
||||
.await
|
||||
.map_err(CancelError::AuthError)?;
|
||||
|
||||
if check_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::warn!("IP is not allowed to cancel the query: {key}");
|
||||
return Err(CancelError::IpNotAllowed);
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::warn!(
|
||||
"IP is not allowed to cancel the query: {key}, address: {}",
|
||||
ctx.peer_addr()
|
||||
);
|
||||
return Err(CancelError::IpNotAllowed);
|
||||
}
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancellation_requests_total
|
||||
.inc(CancellationRequest {
|
||||
source: self.from,
|
||||
kind: crate::metrics::CancellationOutcome::Found,
|
||||
});
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query(self.compute_config).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn contains(&self, session: &Session<P>) -> bool {
|
||||
self.map.contains_key(&session.key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl CancellationHandler<()> {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client: (),
|
||||
from,
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
64,
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: Option<Arc<Mutex<P>>>,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client,
|
||||
from,
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
64,
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
ip_allowlist: Vec<IpPattern>,
|
||||
hostname: String, // for pg_sni router
|
||||
user_info: ComputeUserInfo,
|
||||
}
|
||||
@@ -349,14 +348,12 @@ impl CancelClosure {
|
||||
pub(crate) fn new(
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
ip_allowlist: Vec<IpPattern>,
|
||||
hostname: String,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket_addr,
|
||||
cancel_token,
|
||||
ip_allowlist,
|
||||
hostname,
|
||||
user_info,
|
||||
}
|
||||
@@ -385,99 +382,75 @@ impl CancelClosure {
|
||||
debug!("query was cancelled");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Obsolete (will be removed after moving CancelMap to Redis), only for notifications
|
||||
pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
|
||||
self.ip_allowlist = ip_allowlist;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for registering query cancellation tokens.
|
||||
pub(crate) struct Session<P> {
|
||||
pub(crate) struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancellation_handler: Arc<CancellationHandler<P>>,
|
||||
redis_key: String,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
impl<P> Session<P> {
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||
pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
debug!("enabling query cancellation for this session");
|
||||
self.cancellation_handler
|
||||
.map
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
impl Session {
|
||||
pub(crate) fn key(&self) -> &CancelKeyData {
|
||||
&self.key
|
||||
}
|
||||
}
|
||||
|
||||
impl<P> Drop for Session<P> {
|
||||
fn drop(&mut self) {
|
||||
self.cancellation_handler.map.remove(&self.key);
|
||||
debug!("dropped query cancellation key {}", &self.key);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
use crate::config::RetryConfig;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
|
||||
fn config() -> ComputeConfig {
|
||||
let retry = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
// Send the store key op to the cancellation handler
|
||||
pub(crate) async fn write_cancel_key(
|
||||
&self,
|
||||
cancel_closure: CancelClosure,
|
||||
) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
ComputeConfig {
|
||||
retry,
|
||||
tls: Arc::new(compute_client_config_with_certs(std::iter::empty())),
|
||||
timeout: Duration::from_secs(2),
|
||||
}
|
||||
}
|
||||
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
|
||||
tracing::warn!("failed to serialize cancel closure: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::FromRedis,
|
||||
));
|
||||
|
||||
let session = cancellation_handler.clone().get_session();
|
||||
assert!(cancellation_handler.contains(&session));
|
||||
drop(session);
|
||||
// Check that the session has been dropped.
|
||||
assert!(cancellation_handler.is_empty());
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
value: closure_json,
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HSet),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_session_noop_regression() {
|
||||
let handler = CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::Local,
|
||||
);
|
||||
handler
|
||||
.cancel_session(
|
||||
CancelKeyData {
|
||||
backend_pid: 0,
|
||||
cancel_key: 0,
|
||||
},
|
||||
Uuid::new_v4(),
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
pub(crate) async fn remove_cancel_key(&self) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
let op = CancelKeyOp::RemoveCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HSet),
|
||||
};
|
||||
|
||||
let _ = tx.send_timeout(op, REDIS_SEND_TIMEOUT).await.map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,7 +296,6 @@ impl ConnCfg {
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
vec![], // TODO: deprecated, will be removed
|
||||
host.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
@@ -6,7 +6,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, Instrument};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
@@ -24,7 +24,7 @@ pub async fn task_main(
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -140,15 +140,16 @@ pub async fn task_main(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
@@ -171,13 +172,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session_auth(
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
@@ -195,7 +196,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, ip_allowlist) = match backend
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.await
|
||||
{
|
||||
@@ -220,10 +221,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
node.cancel_closure
|
||||
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
|
||||
let session = cancellation_handler.get_session();
|
||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session
|
||||
.write_cancel_key(node.cancel_closure.clone())
|
||||
.await?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
@@ -237,8 +242,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
aux: node.aux.clone(),
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_cancel: session,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ pub struct ProxyMetrics {
|
||||
pub connection_requests: CounterPairVec<NumConnectionRequestsGauge>,
|
||||
#[metric(flatten)]
|
||||
pub http_endpoint_pools: HttpEndpointPools,
|
||||
#[metric(flatten)]
|
||||
pub cancel_channel_size: CounterPairVec<CancelChannelSizeGauge>,
|
||||
|
||||
/// Time it took for proxy to establish a connection to the compute endpoint.
|
||||
// largest bucket = 2^16 * 0.5ms = 32s
|
||||
@@ -294,6 +296,16 @@ impl CounterPairAssoc for NumConnectionRequestsGauge {
|
||||
pub type NumConnectionRequestsGuard<'a> =
|
||||
metrics::MeasuredCounterPairGuard<'a, NumConnectionRequestsGauge>;
|
||||
|
||||
pub struct CancelChannelSizeGauge;
|
||||
impl CounterPairAssoc for CancelChannelSizeGauge {
|
||||
const INC_NAME: &'static MetricName = MetricName::from_str("opened_msgs_cancel_channel_total");
|
||||
const DEC_NAME: &'static MetricName = MetricName::from_str("closed_msgs_cancel_channel_total");
|
||||
const INC_HELP: &'static str = "Number of processing messages in the cancellation channel.";
|
||||
const DEC_HELP: &'static str = "Number of closed messages in the cancellation channel.";
|
||||
type LabelGroupSet = StaticLabelSet<RedisMsgKind>;
|
||||
}
|
||||
pub type CancelChannelSizeGuard<'a> = metrics::MeasuredCounterPairGuard<'a, CancelChannelSizeGauge>;
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = ComputeConnectionLatencySet)]
|
||||
pub struct ComputeConnectionLatencyGroup {
|
||||
@@ -340,13 +352,6 @@ pub struct RedisErrors<'a> {
|
||||
pub channel: &'a str,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum CancellationSource {
|
||||
FromClient,
|
||||
FromRedis,
|
||||
Local,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
pub enum CancellationOutcome {
|
||||
NotFound,
|
||||
@@ -357,7 +362,6 @@ pub enum CancellationOutcome {
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = CancellationRequestSet)]
|
||||
pub struct CancellationRequest {
|
||||
pub source: CancellationSource,
|
||||
pub kind: CancellationOutcome,
|
||||
}
|
||||
|
||||
@@ -369,6 +373,16 @@ pub enum Waiting {
|
||||
RetryTimeout,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "kind")]
|
||||
pub enum RedisMsgKind {
|
||||
HSet,
|
||||
HSetMultiple,
|
||||
HGet,
|
||||
HGetAll,
|
||||
HDel,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Accumulated {
|
||||
cplane: time::Duration,
|
||||
|
||||
@@ -13,8 +13,9 @@ pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, StartupMessageParams};
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
@@ -23,7 +24,7 @@ use tracing::{debug, error, info, warn, Instrument};
|
||||
|
||||
use self::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
@@ -57,7 +58,7 @@ pub async fn task_main(
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
@@ -243,13 +244,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
@@ -278,7 +279,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session_auth(
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
@@ -312,7 +313,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let (user_info, ip_allowlist) = match user_info
|
||||
let (user_info, _ip_allowlist) = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
@@ -356,10 +357,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
node.cancel_closure
|
||||
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
|
||||
let session = cancellation_handler.get_session();
|
||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session
|
||||
.write_cancel_key(node.cancel_closure.clone())
|
||||
.await?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
@@ -373,23 +378,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
aux: node.aux.clone(),
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_cancel: session,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn prepare_client_connection<P>(
|
||||
pub(crate) async fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
session: &cancellation::Session<P>,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
// Register compute's query cancellation token and produce a new, unique one.
|
||||
// The new token (cancel_key_data) will be sent to the client.
|
||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
|
||||
@@ -411,7 +412,7 @@ pub(crate) async fn prepare_client_connection<P>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
|
||||
@@ -56,18 +56,18 @@ pub(crate) async fn proxy_pass(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) struct ProxyPassthrough<P, S> {
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
pub(crate) _cancel: cancellation::Session<P>,
|
||||
}
|
||||
|
||||
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
@@ -81,6 +81,9 @@ impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key().await); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,6 +138,12 @@ impl RateBucketInfo {
|
||||
Self::new(200, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
|
||||
pub const DEFAULT_REDIS_SET: [Self; 2] = [
|
||||
Self::new(100_000, Duration::from_secs(1)),
|
||||
Self::new(50_000, Duration::from_secs(10)),
|
||||
];
|
||||
|
||||
/// All of these are per endpoint-maskedip pair.
|
||||
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
|
||||
///
|
||||
|
||||
@@ -2,12 +2,10 @@ use core::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::AsyncCommands;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
|
||||
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
|
||||
|
||||
pub trait CancellationPublisherMut: Send + Sync + 'static {
|
||||
@@ -83,9 +81,10 @@ impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
|
||||
}
|
||||
|
||||
pub struct RedisPublisherClient {
|
||||
#[allow(dead_code)]
|
||||
client: ConnectionWithCredentialsProvider,
|
||||
region_id: String,
|
||||
limiter: GlobalRateLimiter,
|
||||
_region_id: String,
|
||||
_limiter: GlobalRateLimiter,
|
||||
}
|
||||
|
||||
impl RedisPublisherClient {
|
||||
@@ -96,26 +95,12 @@ impl RedisPublisherClient {
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
client,
|
||||
region_id,
|
||||
limiter: GlobalRateLimiter::new(info.into()),
|
||||
_region_id: region_id,
|
||||
_limiter: GlobalRateLimiter::new(info.into()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
|
||||
region_id: Some(self.region_id.clone()),
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
peer_addr: Some(peer_addr),
|
||||
}))?;
|
||||
let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
@@ -126,49 +111,4 @@ impl RedisPublisherClient {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
async fn try_publish_internal(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
// TODO: review redundant error duplication logs.
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping cancellation message");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
match self.publish(cancel_key_data, session_id, peer_addr).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
}
|
||||
}
|
||||
tracing::info!("Publisher is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.publish(cancel_key_data, session_id, peer_addr).await
|
||||
}
|
||||
}
|
||||
|
||||
impl CancellationPublisherMut for RedisPublisherClient {
|
||||
async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
tracing::info!("publishing cancellation key to Redis");
|
||||
match self
|
||||
.try_publish_internal(cancel_key_data, session_id, peer_addr)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
tracing::debug!("cancellation key successfuly published to Redis");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ impl Clone for Credentials {
|
||||
/// Provides PubSub connection without credentials refresh.
|
||||
pub struct ConnectionWithCredentialsProvider {
|
||||
credentials: Credentials,
|
||||
// TODO: with more load on the connection, we should consider using a connection pool
|
||||
con: Option<MultiplexedConnection>,
|
||||
refresh_token_task: Option<JoinHandle<()>>,
|
||||
mutex: tokio::sync::Mutex<()>,
|
||||
|
||||
88
proxy/src/redis/keys.rs
Normal file
88
proxy/src/redis/keys.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use anyhow::Ok;
|
||||
use pq_proto::{id_to_cancel_key, CancelKeyData};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::ErrorKind;
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) enum KeyPrefix {
|
||||
#[serde(untagged)]
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
impl KeyPrefix {
|
||||
pub(crate) fn build_redis_key(&self) -> String {
|
||||
match self {
|
||||
KeyPrefix::Cancel(key) => {
|
||||
let hi = (key.backend_pid as u64) << 32;
|
||||
let lo = (key.cancel_key as u64) & 0xffff_ffff;
|
||||
let id = hi | lo;
|
||||
let keyspace = keyspace::CANCEL_PREFIX;
|
||||
format!("{keyspace}:{id:x}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result<KeyPrefix> {
|
||||
let (prefix, key_str) = key.split_once(':').ok_or_else(|| {
|
||||
anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"missing prefix"
|
||||
))
|
||||
})?;
|
||||
|
||||
match prefix {
|
||||
keyspace::CANCEL_PREFIX => {
|
||||
let id = u64::from_str_radix(key_str, 16)?;
|
||||
|
||||
Ok(KeyPrefix::Cancel(id_to_cancel_key(id)))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unknown prefix"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_build_redis_key() {
|
||||
let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
});
|
||||
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_redis_key() {
|
||||
let redis_key = "cancel:30390000d431";
|
||||
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
|
||||
|
||||
let ref_key = CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
};
|
||||
|
||||
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
|
||||
let KeyPrefix::Cancel(cancel_key) = key;
|
||||
assert_eq!(ref_key, cancel_key);
|
||||
}
|
||||
}
|
||||
185
proxy/src/redis/kv_ops.rs
Normal file
185
proxy/src/redis/kv_ops.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use redis::{AsyncCommands, ToRedisArgs};
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
|
||||
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
|
||||
|
||||
pub struct RedisKVClient {
|
||||
client: ConnectionWithCredentialsProvider,
|
||||
limiter: GlobalRateLimiter,
|
||||
}
|
||||
|
||||
impl RedisKVClient {
|
||||
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
|
||||
Self {
|
||||
client,
|
||||
limiter: GlobalRateLimiter::new(info.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn hset<K, F, V>(&mut self, key: K, field: F, value: V) -> anyhow::Result<()>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
F: ToRedisArgs + Send + Sync,
|
||||
V: ToRedisArgs + Send + Sync,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping hset");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.hset(&key, &field, &value).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to set a key-value pair: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client
|
||||
.hset(key, field, value)
|
||||
.await
|
||||
.map_err(anyhow::Error::new)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn hset_multiple<K, V>(
|
||||
&mut self,
|
||||
key: &str,
|
||||
items: &[(K, V)],
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
V: ToRedisArgs + Send + Sync,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping hset_multiple");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.hset_multiple(key, items).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to set a key-value pair: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client
|
||||
.hset_multiple(key, items)
|
||||
.await
|
||||
.map_err(anyhow::Error::new)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn expire<K>(&mut self, key: K, seconds: i64) -> anyhow::Result<()>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping expire");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.expire(&key, seconds).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to set a key-value pair: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client
|
||||
.expire(key, seconds)
|
||||
.await
|
||||
.map_err(anyhow::Error::new)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn hget<K, F, V>(&mut self, key: K, field: F) -> anyhow::Result<V>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
F: ToRedisArgs + Send + Sync,
|
||||
V: redis::FromRedisValue,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping hget");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.hget(&key, &field).await {
|
||||
Ok(value) => return Ok(value),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to get a value: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client
|
||||
.hget(key, field)
|
||||
.await
|
||||
.map_err(anyhow::Error::new)
|
||||
}
|
||||
|
||||
pub(crate) async fn hget_all<K, V>(&mut self, key: K) -> anyhow::Result<V>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
V: redis::FromRedisValue,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping hgetall");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.hgetall(&key).await {
|
||||
Ok(value) => return Ok(value),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to get a value: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client.hgetall(key).await.map_err(anyhow::Error::new)
|
||||
}
|
||||
|
||||
pub(crate) async fn hdel<K, F>(&mut self, key: K, field: F) -> anyhow::Result<()>
|
||||
where
|
||||
K: ToRedisArgs + Send + Sync,
|
||||
F: ToRedisArgs + Send + Sync,
|
||||
{
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping hdel");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match self.client.hdel(&key, &field).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to delete a key-value pair: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.client
|
||||
.hdel(key, field)
|
||||
.await
|
||||
.map_err(anyhow::Error::new)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod cancellation_publisher;
|
||||
pub mod connection_with_credentials_provider;
|
||||
pub mod elasticache;
|
||||
pub mod keys;
|
||||
pub mod kv_ops;
|
||||
pub mod notifications;
|
||||
|
||||
@@ -6,18 +6,14 @@ use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::cancellation::{CancelMap, CancellationHandler};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::intern::{ProjectIdInt, RoleNameInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
|
||||
@@ -25,8 +21,6 @@ async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Resu
|
||||
let mut conn = client.get_async_pubsub().await?;
|
||||
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
|
||||
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
|
||||
conn.subscribe(PROXY_CHANNEL_NAME).await?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
@@ -71,8 +65,6 @@ pub(crate) enum Notification {
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
#[serde(rename = "/cancel_session")]
|
||||
Cancel(CancelSession),
|
||||
|
||||
#[serde(
|
||||
other,
|
||||
@@ -138,7 +130,6 @@ where
|
||||
|
||||
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<CancellationHandler<()>>,
|
||||
region_id: String,
|
||||
}
|
||||
|
||||
@@ -146,23 +137,14 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cache: self.cache.clone(),
|
||||
cancellation_handler: self.cancellation_handler.clone(),
|
||||
region_id: self.region_id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
pub(crate) fn new(
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<CancellationHandler<()>>,
|
||||
region_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
cache,
|
||||
cancellation_handler,
|
||||
region_id,
|
||||
}
|
||||
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
|
||||
Self { cache, region_id }
|
||||
}
|
||||
|
||||
pub(crate) async fn increment_active_listeners(&self) {
|
||||
@@ -207,46 +189,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Notification::Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
tracing::field::display(cancel_session.session_id),
|
||||
);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::CancelSession);
|
||||
if let Some(cancel_region) = cancel_session.region_id {
|
||||
// If the message is not for this region, ignore it.
|
||||
if cancel_region != self.region_id {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Remove unspecified peer_addr after the complete migration to the new format
|
||||
let peer_addr = cancel_session
|
||||
.peer_addr
|
||||
.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
|
||||
match self
|
||||
.cancellation_handler
|
||||
.cancel_session(
|
||||
cancel_session.cancel_key_data,
|
||||
uuid::Uuid::nil(),
|
||||
peer_addr,
|
||||
cancel_session.peer_addr.is_some(),
|
||||
)
|
||||
.instrument(cancel_span)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to cancel session: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Notification::AllowedIpsUpdate { .. }
|
||||
| Notification::PasswordUpdate { .. }
|
||||
| Notification::BlockPublicOrVpcAccessUpdated { .. }
|
||||
@@ -293,7 +235,6 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
Notification::BlockPublicOrVpcAccessUpdated { .. } => {
|
||||
// https://github.com/neondatabase/neon/pull/10073
|
||||
}
|
||||
@@ -323,8 +264,8 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
|
||||
);
|
||||
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
|
||||
);
|
||||
tokio::time::sleep(RECONNECT_TIMEOUT).await;
|
||||
continue;
|
||||
}
|
||||
@@ -350,21 +291,14 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "redis_notifications", skip_all)]
|
||||
pub async fn task_main<C>(
|
||||
config: &'static ProxyConfig,
|
||||
redis: ConnectionWithCredentialsProvider,
|
||||
cache: Arc<C>,
|
||||
cancel_map: CancelMap,
|
||||
region_id: String,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
&config.connect_to_compute,
|
||||
cancel_map,
|
||||
crate::metrics::CancellationSource::FromRedis,
|
||||
));
|
||||
let handler = MessageHandler::new(cache, cancellation_handler, region_id);
|
||||
let handler = MessageHandler::new(cache, region_id);
|
||||
// 6h - 1m.
|
||||
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
|
||||
@@ -442,35 +376,6 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn parse_cancel_session() -> anyhow::Result<()> {
|
||||
let cancel_key_data = CancelKeyData {
|
||||
backend_pid: 42,
|
||||
cancel_key: 41,
|
||||
};
|
||||
let uuid = uuid::Uuid::new_v4();
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: None,
|
||||
session_id: uuid,
|
||||
peer_addr: None,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result);
|
||||
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: Some("region".to_string()),
|
||||
session_id: uuid,
|
||||
peer_addr: None,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result,);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unknown_topic() -> anyhow::Result<()> {
|
||||
|
||||
@@ -43,7 +43,7 @@ use tokio_util::task::TaskTracker;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
@@ -61,7 +61,7 @@ pub async fn task_main(
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
@@ -318,7 +318,7 @@ async fn connection_handler(
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellations: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
conn: AsyncRW,
|
||||
@@ -412,7 +412,7 @@ async fn request_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
|
||||
@@ -12,7 +12,7 @@ use pin_project_lite::pin_project;
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{io_error, ReportableError};
|
||||
@@ -129,7 +129,7 @@ pub(crate) async fn serve_websocket(
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
ctx: RequestContext,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
|
||||
Reference in New Issue
Block a user