mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 18:10:37 +00:00
Merge branch 'main' into ruslan/subzero-integration
This commit is contained in:
180
proxy/src/batch.rs
Normal file
180
proxy/src/batch.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Batch processing system based on intrusive linked lists.
|
||||
//!
|
||||
//! Enqueuing a batch job requires no allocations, with
|
||||
//! direct support for cancelling jobs early.
|
||||
use std::collections::BTreeMap;
|
||||
use std::pin::pin;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
|
||||
use crate::ext::LockExt;
|
||||
|
||||
pub trait QueueProcessing: Send + 'static {
|
||||
type Req: Send + 'static;
|
||||
type Res: Send;
|
||||
|
||||
/// Get the desired batch size.
|
||||
fn batch_size(&self, queue_size: usize) -> usize;
|
||||
|
||||
/// This applies a full batch of events.
|
||||
/// Must respond with a full batch of replies.
|
||||
///
|
||||
/// If this apply can error, it's expected that errors be forwarded to each Self::Res.
|
||||
///
|
||||
/// Batching does not need to happen atomically.
|
||||
fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
|
||||
}
|
||||
|
||||
pub struct BatchQueue<P: QueueProcessing> {
|
||||
processor: tokio::sync::Mutex<P>,
|
||||
inner: Mutex<BatchQueueInner<P>>,
|
||||
}
|
||||
|
||||
struct BatchJob<P: QueueProcessing> {
|
||||
req: P::Req,
|
||||
res: tokio::sync::oneshot::Sender<P::Res>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueue<P> {
|
||||
pub fn new(p: P) -> Self {
|
||||
Self {
|
||||
processor: tokio::sync::Mutex::new(p),
|
||||
inner: Mutex::new(BatchQueueInner {
|
||||
version: 0,
|
||||
queue: BTreeMap::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a single request-response process, this may be batched internally.
|
||||
///
|
||||
/// This function is not cancel safe.
|
||||
pub async fn call<R>(
|
||||
&self,
|
||||
req: P::Req,
|
||||
cancelled: impl Future<Output = R>,
|
||||
) -> Result<P::Res, R> {
|
||||
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
|
||||
|
||||
let mut cancelled = pin!(cancelled);
|
||||
let resp = loop {
|
||||
// try become the leader, or try wait for success.
|
||||
let mut processor = tokio::select! {
|
||||
// try become leader.
|
||||
p = self.processor.lock() => p,
|
||||
// wait for success.
|
||||
resp = &mut rx => break resp.ok(),
|
||||
// wait for cancellation.
|
||||
cancel = cancelled.as_mut() => {
|
||||
let mut inner = self.inner.lock_propagate_poison();
|
||||
if inner.queue.remove(&id).is_some() {
|
||||
tracing::warn!("batched task cancelled before completion");
|
||||
}
|
||||
return Err(cancel);
|
||||
},
|
||||
};
|
||||
|
||||
tracing::debug!(id, "batch: became leader");
|
||||
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
|
||||
|
||||
// snitch incase the task gets cancelled.
|
||||
let cancel_safety = scopeguard::guard((), |()| {
|
||||
if !std::thread::panicking() {
|
||||
tracing::error!(
|
||||
id,
|
||||
"batch: leader cancelled, despite not being cancellation safe"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// apply a batch.
|
||||
// if this is cancelled, jobs will not be completed and will panic.
|
||||
let values = processor.apply(reqs).await;
|
||||
|
||||
// good: we didn't get cancelled.
|
||||
ScopeGuard::into_inner(cancel_safety);
|
||||
|
||||
if values.len() != resps.len() {
|
||||
tracing::error!(
|
||||
"batch: invalid response size, expected={}, got={}",
|
||||
resps.len(),
|
||||
values.len()
|
||||
);
|
||||
}
|
||||
|
||||
// send response values.
|
||||
for (tx, value) in std::iter::zip(resps, values) {
|
||||
if tx.send(value).is_err() {
|
||||
// receiver hung up but that's fine.
|
||||
}
|
||||
}
|
||||
|
||||
match rx.try_recv() {
|
||||
Ok(resp) => break Some(resp),
|
||||
Err(TryRecvError::Closed) => break None,
|
||||
// edge case - there was a race condition where
|
||||
// we became the leader but were not in the batch.
|
||||
//
|
||||
// Example:
|
||||
// thread 1: register job id=1
|
||||
// thread 2: register job id=2
|
||||
// thread 2: processor.lock().await
|
||||
// thread 1: processor.lock().await
|
||||
// thread 2: becomes leader, batch_size=1, jobs=[1].
|
||||
Err(TryRecvError::Empty) => {}
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!(id, "batch: job completed");
|
||||
|
||||
Ok(resp.expect("no response found. batch processer should not panic"))
|
||||
}
|
||||
}
|
||||
|
||||
struct BatchQueueInner<P: QueueProcessing> {
|
||||
version: u64,
|
||||
queue: BTreeMap<u64, BatchJob<P>>,
|
||||
}
|
||||
|
||||
impl<P: QueueProcessing> BatchQueueInner<P> {
|
||||
fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let id = self.version;
|
||||
|
||||
// Overflow concern:
|
||||
// This is a u64, and we might enqueue 2^16 tasks per second.
|
||||
// This gives us 2^48 seconds (9 million years).
|
||||
// Even if this does overflow, it will not break, but some
|
||||
// jobs with the higher version might never get prioritised.
|
||||
self.version += 1;
|
||||
|
||||
self.queue.insert(id, BatchJob { req, res: tx });
|
||||
|
||||
tracing::debug!(id, "batch: registered job in the queue");
|
||||
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
|
||||
let batch_size = p.batch_size(self.queue.len());
|
||||
let mut reqs = Vec::with_capacity(batch_size);
|
||||
let mut resps = Vec::with_capacity(batch_size);
|
||||
let mut ids = Vec::with_capacity(batch_size);
|
||||
|
||||
while reqs.len() < batch_size {
|
||||
let Some((id, job)) = self.queue.pop_first() else {
|
||||
break;
|
||||
};
|
||||
reqs.push(job.req);
|
||||
resps.push(job.res);
|
||||
ids.push(id);
|
||||
}
|
||||
|
||||
tracing::debug!(ids=?ids, "batch: acquired jobs");
|
||||
|
||||
(reqs, resps)
|
||||
}
|
||||
}
|
||||
@@ -206,7 +206,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
auth_backend,
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute, None)),
|
||||
Arc::new(CancellationHandler::new(&config.connect_to_compute)),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ use utils::{project_build_tag, project_git_version};
|
||||
|
||||
use crate::auth::backend::jwt::JwkCache;
|
||||
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
|
||||
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
|
||||
use crate::batch::BatchQueue;
|
||||
use crate::cancellation::{CancellationHandler, CancellationProcessor};
|
||||
use crate::config::{
|
||||
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
|
||||
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml, RestConfig,
|
||||
@@ -396,13 +397,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
|
||||
@@ -534,21 +529,11 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
match redis_kv_client.try_connect().await {
|
||||
Ok(()) => {
|
||||
info!("Connected to Redis KV client");
|
||||
maintenance_tasks.spawn(async move {
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
|
||||
client: redis_kv_client,
|
||||
batch_size: args.cancellation_batch_size,
|
||||
}));
|
||||
|
||||
drop(redis_kv_client);
|
||||
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
// so let's wait forever instead.
|
||||
std::future::pending().await
|
||||
});
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::pin::pin;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use anyhow::anyhow;
|
||||
use futures::FutureExt;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::CancelToken;
|
||||
use postgres_client::RawCancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use redis::{Cmd, FromRedisValue, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::auth::AuthError;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::batch::{BatchQueue, QueueProcessing};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::ControlPlaneApi;
|
||||
@@ -27,46 +32,36 @@ use crate::redis::kv_ops::RedisKVClient;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
|
||||
const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
|
||||
const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelKeyOp {
|
||||
StoreCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
value: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
expire: i64, // TTL for key
|
||||
key: CancelKeyData,
|
||||
value: Box<str>,
|
||||
expire: std::time::Duration,
|
||||
},
|
||||
GetCancelData {
|
||||
key: String,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
key: String,
|
||||
field: String,
|
||||
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
key: CancelKeyData,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Pipeline {
|
||||
inner: redis::Pipeline,
|
||||
replies: Vec<CancelReplyOp>,
|
||||
replies: usize,
|
||||
}
|
||||
|
||||
impl Pipeline {
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self {
|
||||
inner: redis::Pipeline::with_capacity(n),
|
||||
replies: Vec::with_capacity(n),
|
||||
replies: 0,
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(&mut self, client: &mut RedisKVClient) {
|
||||
let responses = self.replies.len();
|
||||
async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
|
||||
let responses = self.replies;
|
||||
let batch_size = self.inner.len();
|
||||
|
||||
match client.query(&self.inner).await {
|
||||
@@ -76,176 +71,72 @@ impl Pipeline {
|
||||
batch_size,
|
||||
responses, "successfully completed cancellation jobs",
|
||||
);
|
||||
for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
|
||||
reply.send_value(value);
|
||||
}
|
||||
values.into_iter().map(Ok).collect()
|
||||
}
|
||||
Ok(value) => {
|
||||
error!(batch_size, ?value, "unexpected redis return value");
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("incorrect response type from redis"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
Err(err) => {
|
||||
for reply in self.replies.drain(..) {
|
||||
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
|
||||
}
|
||||
std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
|
||||
.take(responses)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
self.inner.clear();
|
||||
self.replies.clear();
|
||||
}
|
||||
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
|
||||
fn add_command_with_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd);
|
||||
self.replies.push(reply);
|
||||
self.replies += 1;
|
||||
}
|
||||
|
||||
fn add_command_no_reply(&mut self, cmd: Cmd) {
|
||||
self.inner.add_command(cmd).ignore();
|
||||
}
|
||||
|
||||
fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
|
||||
match reply {
|
||||
Some(reply) => self.add_command_with_reply(cmd, reply),
|
||||
None => self.add_command_no_reply(cmd),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelKeyOp {
|
||||
fn register(self, pipe: &mut Pipeline) {
|
||||
#[allow(clippy::used_underscore_binding)]
|
||||
fn register(&self, pipe: &mut Pipeline) {
|
||||
match self {
|
||||
CancelKeyOp::StoreCancelKey {
|
||||
key,
|
||||
field,
|
||||
value,
|
||||
resp_tx,
|
||||
_guard,
|
||||
expire,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hset(&key, field, value), reply);
|
||||
pipe.add_command_no_reply(Cmd::expire(key, expire));
|
||||
CancelKeyOp::StoreCancelKey { key, value, expire } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
|
||||
pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
|
||||
}
|
||||
CancelKeyOp::GetCancelData {
|
||||
key,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
|
||||
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
|
||||
}
|
||||
CancelKeyOp::RemoveCancelKey {
|
||||
key,
|
||||
field,
|
||||
resp_tx,
|
||||
_guard,
|
||||
} => {
|
||||
let reply =
|
||||
resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
|
||||
pipe.add_command(Cmd::hdel(key, field), reply);
|
||||
CancelKeyOp::GetCancelData { key } => {
|
||||
let key = KeyPrefix::Cancel(*key).build_redis_key();
|
||||
pipe.add_command_with_reply(Cmd::hget(key, "data"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelReplyOp {
|
||||
StoreCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
GetCancelData {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
RemoveCancelKey {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
_guard: CancelChannelSizeGuard<'static>,
|
||||
},
|
||||
pub struct CancellationProcessor {
|
||||
pub client: RedisKVClient,
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
impl CancelReplyOp {
|
||||
fn send_err(self, e: anyhow::Error) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
resp_tx
|
||||
.send(Err(e))
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
impl QueueProcessing for CancellationProcessor {
|
||||
type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
|
||||
type Res = anyhow::Result<redis::Value>;
|
||||
|
||||
fn batch_size(&self, _queue_size: usize) -> usize {
|
||||
self.batch_size
|
||||
}
|
||||
|
||||
fn send_value(self, v: redis::Value) {
|
||||
match self {
|
||||
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
|
||||
let send =
|
||||
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
|
||||
resp_tx
|
||||
.send(send)
|
||||
.inspect_err(|_| tracing::debug!("could not send reply"))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Running as a separate task to accept messages through the rx channel
|
||||
pub async fn handle_cancel_messages(
|
||||
client: &mut RedisKVClient,
|
||||
mut rx: mpsc::Receiver<CancelKeyOp>,
|
||||
batch_size: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut batch = Vec::with_capacity(batch_size);
|
||||
let mut pipeline = Pipeline::with_capacity(batch_size);
|
||||
|
||||
loop {
|
||||
if rx.recv_many(&mut batch, batch_size).await == 0 {
|
||||
warn!("shutting down cancellation queue");
|
||||
break Ok(());
|
||||
}
|
||||
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
|
||||
let mut pipeline = Pipeline::with_capacity(batch.len());
|
||||
|
||||
let batch_size = batch.len();
|
||||
debug!(batch_size, "running cancellation jobs");
|
||||
|
||||
for msg in batch.drain(..) {
|
||||
msg.register(&mut pipeline);
|
||||
for (_, op) in &batch {
|
||||
op.register(&mut pipeline);
|
||||
}
|
||||
|
||||
pipeline.execute(client).await;
|
||||
pipeline.execute(&mut self.client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +147,7 @@ pub struct CancellationHandler {
|
||||
compute_config: &'static ComputeConfig,
|
||||
// rate limiter of cancellation requests
|
||||
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
|
||||
tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -296,13 +187,10 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
|
||||
impl CancellationHandler {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
tx: Option<mpsc::Sender<CancelKeyOp>>,
|
||||
) -> Self {
|
||||
pub fn new(compute_config: &'static ComputeConfig) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
tx,
|
||||
tx: OnceLock::new(),
|
||||
limiter: Arc::new(std::sync::Mutex::new(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
|
||||
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
|
||||
@@ -312,7 +200,14 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
|
||||
self.tx
|
||||
.set(queue)
|
||||
.map_err(|_| {})
|
||||
.expect("cancellation queue should be registered once");
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
@@ -322,83 +217,68 @@ impl CancellationHandler {
|
||||
|
||||
let key: CancelKeyData = rand::random();
|
||||
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
|
||||
debug!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
redis_key,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
|
||||
/// This is not cancel safe
|
||||
async fn get_cancel_key(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
) -> Result<Option<CancelClosure>, CancelError> {
|
||||
let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
|
||||
let redis_key = prefix_key.build_redis_key();
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGet);
|
||||
let op = CancelKeyOp::GetCancelData { key };
|
||||
|
||||
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
|
||||
let op = CancelKeyOp::GetCancelData {
|
||||
key: redis_key,
|
||||
resp_tx,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HGetAll),
|
||||
};
|
||||
|
||||
let Some(tx) = &self.tx else {
|
||||
let Some(tx) = self.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
|
||||
tx.try_send(op)
|
||||
.map_err(|e| {
|
||||
tracing::warn!("failed to send GetCancelData for {key}: {e}");
|
||||
})
|
||||
.map_err(|()| CancelError::InternalError)?;
|
||||
|
||||
let result = resp_rx.await.map_err(|e| {
|
||||
const TIMEOUT: Duration = Duration::from_secs(5);
|
||||
let result = timeout(
|
||||
TIMEOUT,
|
||||
tx.call((guard, op), std::future::pending::<Infallible>()),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
tracing::warn!("timed out waiting to receive GetCancelData response");
|
||||
CancelError::RateLimit
|
||||
})?
|
||||
// cannot be cancelled
|
||||
.unwrap_or_else(|x| match x {})
|
||||
.map_err(|e| {
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state_str: Option<String> = match result {
|
||||
Ok(mut state) => {
|
||||
if state.len() == 1 {
|
||||
Some(state.remove(0).1)
|
||||
} else {
|
||||
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to receive cancel state from redis: {e}");
|
||||
return Err(CancelError::InternalError);
|
||||
}
|
||||
};
|
||||
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
|
||||
tracing::warn!("failed to receive GetCancelData response: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
let cancel_state: Option<CancelClosure> = match cancel_state_str {
|
||||
Some(state) => {
|
||||
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
Some(cancel_closure)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(cancel_state)
|
||||
let cancel_closure: CancelClosure =
|
||||
serde_json::from_str(&cancel_state_str).map_err(|e| {
|
||||
tracing::warn!("failed to deserialize cancel state: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
|
||||
Ok(Some(cancel_closure))
|
||||
}
|
||||
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
/// check_allowed - if true, check if the IP is allowed to cancel the query.
|
||||
/// Will fetch IP allowlist internally.
|
||||
///
|
||||
/// return Result primarily for tests
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub(crate) async fn cancel_session<T: ControlPlaneApi>(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
@@ -467,10 +347,10 @@ impl CancellationHandler {
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String, // for pg_sni router
|
||||
user_info: ComputeUserInfo,
|
||||
}
|
||||
@@ -478,7 +358,7 @@ pub struct CancelClosure {
|
||||
impl CancelClosure {
|
||||
pub(crate) fn new(
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Self {
|
||||
@@ -491,7 +371,7 @@ impl CancelClosure {
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(
|
||||
self,
|
||||
&self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
@@ -512,7 +392,6 @@ impl CancelClosure {
|
||||
pub(crate) struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
redis_key: String,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
@@ -521,60 +400,75 @@ impl Session {
|
||||
&self.key
|
||||
}
|
||||
|
||||
// Send the store key op to the cancellation handler and set TTL for the key
|
||||
pub(crate) fn write_cancel_key(
|
||||
/// Ensure the cancel key is continously refreshed,
|
||||
/// but stop when the channel is dropped.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub(crate) async fn maintain_cancel_key(
|
||||
&self,
|
||||
cancel_closure: CancelClosure,
|
||||
) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
session_id: uuid::Uuid,
|
||||
cancel: tokio::sync::oneshot::Receiver<Infallible>,
|
||||
cancel_closure: &CancelClosure,
|
||||
compute_config: &ComputeConfig,
|
||||
) {
|
||||
let Some(tx) = self.cancellation_handler.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
// don't exit, as we only want to exit if cancelled externally.
|
||||
std::future::pending().await
|
||||
};
|
||||
|
||||
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
|
||||
tracing::warn!("failed to serialize cancel closure: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
let closure_json = serde_json::to_string(&cancel_closure)
|
||||
.expect("serialising to json string should not fail")
|
||||
.into_boxed_str();
|
||||
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
value: closure_json,
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
let mut cancel = pin!(cancel);
|
||||
|
||||
loop {
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HSet),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
.guard(RedisMsgKind::HSet);
|
||||
let op = CancelKeyOp::StoreCancelKey {
|
||||
key: self.key,
|
||||
value: closure_json.clone(),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send StoreCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"registering cancellation key"
|
||||
);
|
||||
|
||||
pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> {
|
||||
let Some(tx) = &self.cancellation_handler.tx else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
};
|
||||
match tx.call((guard, op), cancel.as_mut()).await {
|
||||
Ok(Ok(_)) => {
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"registered cancellation key"
|
||||
);
|
||||
|
||||
let op = CancelKeyOp::RemoveCancelKey {
|
||||
key: self.redis_key.clone(),
|
||||
field: "data".to_string(),
|
||||
resp_tx: None,
|
||||
_guard: Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::HDel),
|
||||
};
|
||||
// wait before continuing.
|
||||
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
|
||||
}
|
||||
// retry immediately.
|
||||
Ok(Err(error)) => {
|
||||
tracing::warn!(?error, "error registering cancellation key");
|
||||
}
|
||||
Err(Err(_cancelled)) => break,
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx.try_send(op).map_err(|e| {
|
||||
let key = self.key;
|
||||
tracing::warn!("failed to send RemoveCancelKey for {key}: {e}");
|
||||
});
|
||||
Ok(())
|
||||
if let Err(err) = cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
?session_id,
|
||||
?err,
|
||||
"could not cancel the query in the database"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, NoTls, RawConnection};
|
||||
use postgres_client::{NoTls, RawCancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
@@ -33,12 +33,51 @@ use crate::types::Host;
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ConnectionError {
|
||||
pub(crate) enum PostgresError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for PostgresError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
PostgresError::Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!(
|
||||
"{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
|
||||
)
|
||||
} else {
|
||||
msg.to_owned()
|
||||
}
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for PostgresError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
PostgresError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ConnectionError {
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] TlsError),
|
||||
|
||||
@@ -52,22 +91,6 @@ pub(crate) enum ConnectionError {
|
||||
impl UserFacingError for ConnectionError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
ConnectionError::Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
|
||||
} else {
|
||||
msg.to_owned()
|
||||
}
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
ConnectionError::WakeComputeError(err) => err.to_string_client(),
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
@@ -80,10 +103,6 @@ impl UserFacingError for ConnectionError {
|
||||
impl ReportableError for ConnectionError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -110,6 +129,8 @@ pub(crate) struct AuthInfo {
|
||||
auth: Option<Auth>,
|
||||
server_params: StartupMessageParams,
|
||||
|
||||
channel_binding: ChannelBinding,
|
||||
|
||||
/// Console redirect sets user and database, we shouldn't re-use those from the params.
|
||||
skip_db_user: bool,
|
||||
}
|
||||
@@ -133,6 +154,8 @@ impl AuthInfo {
|
||||
auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())),
|
||||
server_params,
|
||||
skip_db_user: true,
|
||||
// pg-sni-router is a mitm so this would fail.
|
||||
channel_binding: ChannelBinding::Disable,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +169,7 @@ impl AuthInfo {
|
||||
},
|
||||
server_params: StartupMessageParams::default(),
|
||||
skip_db_user: false,
|
||||
channel_binding: ChannelBinding::Prefer,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,6 +192,7 @@ impl AuthInfo {
|
||||
Some(Auth::Password(pw)) => config.password(pw),
|
||||
None => &mut config,
|
||||
};
|
||||
config.channel_binding(self.channel_binding);
|
||||
for (k, v) in self.server_params.iter() {
|
||||
config.set_param(k, v);
|
||||
}
|
||||
@@ -206,6 +231,56 @@ impl AuthInfo {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
compute: &mut ComputeConnection,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<PostgresSettings, PostgresError> {
|
||||
// client config with stubbed connect info.
|
||||
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
|
||||
// utilising pqproto.rs.
|
||||
let mut tmp_config = postgres_client::Config::new(String::new(), 0);
|
||||
// We have already established SSL if necessary.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
let tmp_config = self.enrich(tmp_config);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let connection = tmp_config
|
||||
.tls_and_authenticate(&mut compute.stream, NoTls)
|
||||
.await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
stream: _,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(
|
||||
compute.socket_addr,
|
||||
RawCancelToken {
|
||||
ssl_mode: compute.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
compute.hostname.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
Ok(PostgresSettings {
|
||||
params: parameters,
|
||||
cancel_closure,
|
||||
delayed_notice,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
@@ -265,53 +340,45 @@ impl ConnectInfo {
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
// TODO(conrad): we don't need to parse these.
|
||||
// These are just immediately forwarded back to the client.
|
||||
// We could instead stream them out instead of reading them into memory.
|
||||
pub struct PostgresSettings {
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub(crate) cancel_closure: CancelClosure,
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub cancel_closure: CancelClosure,
|
||||
/// Notices received from compute after authenticating
|
||||
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
|
||||
pub delayed_notice: Vec<NoticeResponseBody>,
|
||||
}
|
||||
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
pub struct ComputeConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
pub hostname: Host,
|
||||
pub ssl_mode: SslMode,
|
||||
pub socket_addr: SocketAddr,
|
||||
pub guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub(crate) async fn connect(
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
auth: &AuthInfo,
|
||||
aux: &MetricsAuxInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
|
||||
// we setup SSL early in `ConnectInfo::connect_raw`.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
|
||||
) -> Result<ComputeConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let connection = tmp_config.connect_raw(stream, NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let MaybeTlsStream::Raw(stream) = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
@@ -323,27 +390,13 @@ impl ConnectInfo {
|
||||
ctx.get_testodrome_id().unwrap_or_default(),
|
||||
);
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(
|
||||
socket_addr,
|
||||
CancelToken {
|
||||
socket_config: None,
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
self.host.to_string(),
|
||||
user_info,
|
||||
);
|
||||
|
||||
let connection = PostgresConnection {
|
||||
let connection = ComputeConnection {
|
||||
stream,
|
||||
params: parameters,
|
||||
delayed_notice,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
socket_addr,
|
||||
hostname: self.host.clone(),
|
||||
ssl_mode: self.ssl_mode,
|
||||
aux: aux.clone(),
|
||||
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -120,7 +120,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
@@ -218,11 +218,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
let node = connect_to_compute(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&node_info,
|
||||
@@ -232,22 +230,40 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
let pg_settings = auth_info
|
||||
.authenticate(ctx, &mut node, user_info)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&pg_settings.cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -76,13 +76,9 @@ impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
auth: &compute::AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.conn_info
|
||||
.connect(ctx, self.aux.clone(), auth, config, user_info)
|
||||
.await
|
||||
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
|
||||
self.conn_info.connect(ctx, &self.aux, config).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -61,6 +61,10 @@
|
||||
clippy::too_many_lines,
|
||||
clippy::unused_self
|
||||
)]
|
||||
#![allow(
|
||||
clippy::unsafe_derive_deserialize,
|
||||
reason = "false positive: https://github.com/rust-lang/rust-clippy/issues/15120"
|
||||
)]
|
||||
#![cfg_attr(
|
||||
any(test, feature = "testing"),
|
||||
allow(
|
||||
@@ -75,6 +79,7 @@
|
||||
pub mod binary;
|
||||
|
||||
mod auth;
|
||||
mod batch;
|
||||
mod cache;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use futures::FutureExt;
|
||||
use std::convert::Infallible;
|
||||
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::compute::MaybeRustlsStream;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::metrics::{
|
||||
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
|
||||
NumDbConnectionsGuard,
|
||||
};
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -64,40 +66,20 @@ pub(crate) async fn proxy_pass(
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) compute: MaybeRustlsStream,
|
||||
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,7 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
@@ -50,15 +49,13 @@ pub(crate) trait ConnectMechanism {
|
||||
}
|
||||
|
||||
pub(crate) struct TcpMechanism {
|
||||
pub(crate) auth: AuthInfo,
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
pub(crate) user_info: ComputeUserInfo,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TcpMechanism {
|
||||
type Connection = PostgresConnection;
|
||||
type Connection = ComputeConnection;
|
||||
type ConnectError = compute::ConnectionError;
|
||||
type Error = compute::ConnectionError;
|
||||
|
||||
@@ -71,13 +68,9 @@ impl ConnectMechanism for TcpMechanism {
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
) -> Result<ComputeConnection, Self::Error> {
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(
|
||||
node_info
|
||||
.connect(ctx, &self.auth, config, self.user_info.clone())
|
||||
.await,
|
||||
)
|
||||
permit.release_result(node_info.connect(ctx, config).await)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
@@ -357,28 +357,43 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: creds.info.clone(),
|
||||
auth: auth_info,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&auth::Backend::ControlPlane(cplane, creds.info),
|
||||
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
|
||||
let node = match res {
|
||||
let mut node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await;
|
||||
let pg_settings = match pg_settings {
|
||||
Ok(pg_settings) => pg_settings,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&pg_settings.cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
@@ -387,31 +402,34 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
settings: &compute::PostgresSettings,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
for notice in &settings.delayed_notice {
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &node.params {
|
||||
for (name, value) in &settings.params {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
|
||||
@@ -99,7 +99,6 @@ impl ShouldRetryWakeCompute for postgres_client::Error {
|
||||
impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::TlsError(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
@@ -109,7 +108,6 @@ impl CouldRetry for compute::ConnectionError {
|
||||
impl ShouldRetryWakeCompute for compute::ConnectionError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(),
|
||||
// the cache entry was not checked for validity
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
|
||||
@@ -169,7 +169,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -252,7 +252,7 @@ async fn connect_failure(
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err()
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
@@ -199,7 +199,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Disable)
|
||||
.connect_raw(server, NoTls)
|
||||
.tls_and_authenticate(server, NoTls)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
@@ -228,7 +228,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -245,7 +245,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
|
||||
.dbname("earth")
|
||||
.set_param("options", "project=generic-project-name")
|
||||
.ssl_mode(SslMode::Prefer)
|
||||
.connect_raw(server, NoTls)
|
||||
.tls_and_authenticate(server, NoTls)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -293,7 +293,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -317,7 +317,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -344,7 +344,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.tls_and_authenticate(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use anyhow::Ok;
|
||||
|
||||
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
@@ -23,40 +19,12 @@ impl KeyPrefix {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result<KeyPrefix> {
|
||||
let (prefix, key_str) = key.split_once(':').ok_or_else(|| {
|
||||
anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"missing prefix"
|
||||
))
|
||||
})?;
|
||||
|
||||
match prefix {
|
||||
keyspace::CANCEL_PREFIX => {
|
||||
let id = u64::from_str_radix(key_str, 16)?;
|
||||
|
||||
Ok(KeyPrefix::Cancel(id_to_cancel_key(id)))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!(std::io::Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unknown prefix"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pqproto::id_to_cancel_key;
|
||||
|
||||
#[test]
|
||||
fn test_build_redis_key() {
|
||||
@@ -65,16 +33,4 @@ mod tests {
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_redis_key() {
|
||||
let redis_key = "cancel:30390000d431";
|
||||
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
|
||||
|
||||
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
|
||||
|
||||
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
|
||||
let KeyPrefix::Cancel(cancel_key) = key;
|
||||
assert_eq!(ref_key, cancel_key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
use redis::aio::ConnectionLike;
|
||||
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
|
||||
|
||||
@@ -35,14 +38,11 @@ impl RedisKVClient {
|
||||
}
|
||||
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.connect().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
self.client
|
||||
.connect()
|
||||
.boxed()
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
|
||||
}
|
||||
|
||||
pub(crate) async fn query<T: FromRedisValue>(
|
||||
@@ -54,15 +54,25 @@ impl RedisKVClient {
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
match q.query(&mut self.client).await {
|
||||
let e = match q.query(&mut self.client).await {
|
||||
Ok(t) => return Ok(t),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to run query: {e}");
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
tracing::error!("failed to run query: {e}");
|
||||
match e.retry_method() {
|
||||
redis::RetryMethod::Reconnect => {
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
}
|
||||
redis::RetryMethod::RetryImmediately => {}
|
||||
redis::RetryMethod::WaitAndRetry => {
|
||||
// somewhat arbitrary.
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
_ => Err(e)?,
|
||||
}
|
||||
|
||||
tracing::info!("Redis client is disconnected. Reconnecting...");
|
||||
self.try_connect().await?;
|
||||
Ok(q.query(&mut self.client).await?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
@@ -399,7 +399,7 @@ async fn collect_metrics_iteration(
|
||||
|
||||
fn create_remote_path_prefix(now: DateTime<Utc>) -> String {
|
||||
format!(
|
||||
"year={year:04}/month={month:02}/day={day:02}/{hour:02}:{minute:02}:{second:02}Z",
|
||||
"year={year:04}/month={month:02}/day={day:02}/hour={hour:02}/{hour:02}:{minute:02}:{second:02}Z",
|
||||
year = now.year(),
|
||||
month = now.month(),
|
||||
day = now.day(),
|
||||
@@ -461,7 +461,7 @@ async fn upload_backup_events(
|
||||
real_now.second().into(),
|
||||
real_now.nanosecond(),
|
||||
));
|
||||
let path = format!("{path_prefix}_{id}.json.gz");
|
||||
let path = format!("{path_prefix}_{id}.ndjson.gz");
|
||||
let remote_path = match RemotePath::from_string(&path) {
|
||||
Ok(remote_path) => remote_path,
|
||||
Err(e) => {
|
||||
@@ -471,9 +471,12 @@ async fn upload_backup_events(
|
||||
|
||||
// TODO: This is async compression from Vec to Vec. Rewrite as byte stream.
|
||||
// Use sync compression in blocking threadpool.
|
||||
let data = serde_json::to_vec(chunk).context("serialize metrics")?;
|
||||
let mut encoder = GzipEncoder::new(Vec::new());
|
||||
encoder.write_all(&data).await.context("compress metrics")?;
|
||||
for event in chunk.events.iter() {
|
||||
let data = serde_json::to_vec(event).context("serialize metrics")?;
|
||||
encoder.write_all(&data).await.context("compress metrics")?;
|
||||
encoder.write_all(b"\n").await.context("compress metrics")?;
|
||||
}
|
||||
encoder.shutdown().await.context("compress metrics")?;
|
||||
let compressed_data: Bytes = encoder.get_ref().clone().into();
|
||||
backoff::retry(
|
||||
@@ -499,7 +502,7 @@ async fn upload_backup_events(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::io::BufReader;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use anyhow::Error;
|
||||
@@ -673,11 +676,22 @@ mod tests {
|
||||
{
|
||||
let path = local_fs_path.join(&path_prefix).to_string();
|
||||
if entry.path().to_str().unwrap().starts_with(&path) {
|
||||
let chunk = serde_json::from_reader(flate2::bufread::GzDecoder::new(
|
||||
BufReader::new(fs::File::open(entry.into_path()).unwrap()),
|
||||
))
|
||||
.unwrap();
|
||||
stored_chunks.push(chunk);
|
||||
let file = fs::File::open(entry.into_path()).unwrap();
|
||||
let decoder = flate2::bufread::GzDecoder::new(BufReader::new(file));
|
||||
let reader = BufReader::new(decoder);
|
||||
|
||||
let mut events: Vec<Event<Extra, String>> = Vec::new();
|
||||
for line in reader.lines() {
|
||||
let line = line.unwrap();
|
||||
let event: Event<Extra, String> = serde_json::from_str(&line).unwrap();
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
let report = Report {
|
||||
events: Cow::Owned(events),
|
||||
};
|
||||
|
||||
stored_chunks.push(report);
|
||||
}
|
||||
}
|
||||
storage_test_dir.close().ok();
|
||||
|
||||
Reference in New Issue
Block a user