Compare commits

...

3 Commits

Author SHA1 Message Date
Conrad Ludgate
4454662349 use hget instead of hgetall 2025-06-10 07:05:23 -07:00
Conrad Ludgate
2e6acf7bed box the connect future and respect the redis retry methods on err 2025-06-10 06:43:31 -07:00
Conrad Ludgate
6b528426e7 split CancelToken into RawCancelToken for smaller sizes and better typesafety 2025-06-10 06:29:43 -07:00
7 changed files with 65 additions and 70 deletions

View File

@@ -1,5 +1,3 @@
use std::io;
use tokio::net::TcpStream;
use crate::client::SocketConfig;
@@ -8,7 +6,7 @@ use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
config: SocketConfig,
ssl_mode: SslMode,
tls: T,
process_id: i32,
@@ -17,16 +15,6 @@ pub(crate) async fn cancel_query<T>(
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)));
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};

View File

@@ -9,9 +9,16 @@ use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub socket_config: SocketConfig,
pub raw: RawCancelToken,
}
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
@@ -36,14 +43,16 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
self.raw.ssl_mode,
tls,
self.process_id,
self.secret_key,
self.raw.process_id,
self.raw.secret_key,
)
.await
}
}
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>

View File

@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
@@ -331,10 +332,12 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
},
}
}

View File

@@ -3,7 +3,7 @@
use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::CancelToken;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
@@ -33,7 +33,6 @@ const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
pub enum CancelKeyOp {
StoreCancelKey {
key: String,
field: String,
value: String,
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
_guard: CancelChannelSizeGuard<'static>,
@@ -41,7 +40,7 @@ pub enum CancelKeyOp {
},
GetCancelData {
key: String,
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
resp_tx: oneshot::Sender<anyhow::Result<String>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -120,7 +119,6 @@ impl CancelKeyOp {
match self {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
@@ -128,7 +126,7 @@ impl CancelKeyOp {
} => {
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hset(&key, field, value), reply);
pipe.add_command(Cmd::hset(&key, "data", value), reply);
pipe.add_command_no_reply(Cmd::expire(key, expire));
}
CancelKeyOp::GetCancelData {
@@ -137,7 +135,7 @@ impl CancelKeyOp {
_guard,
} => {
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
pipe.add_command_with_reply(Cmd::hget(key, "data"), reply);
}
CancelKeyOp::RemoveCancelKey {
key,
@@ -160,7 +158,7 @@ pub enum CancelReplyOp {
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
resp_tx: oneshot::Sender<anyhow::Result<String>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -347,7 +345,7 @@ impl CancellationHandler {
_guard: Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGetAll),
.guard(RedisMsgKind::HGet),
};
let Some(tx) = &self.tx else {
@@ -366,32 +364,21 @@ impl CancellationHandler {
CancelError::InternalError
})?;
let cancel_state_str: Option<String> = match result {
Ok(mut state) => {
if state.len() == 1 {
Some(state.remove(0).1)
} else {
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
return Err(CancelError::InternalError);
}
}
let cancel_state_str: String = match result {
Ok(s) => s,
Err(e) => {
tracing::warn!("failed to receive cancel state from redis: {e}");
return Err(CancelError::InternalError);
}
};
let cancel_state: Option<CancelClosure> = match cancel_state_str {
Some(state) => {
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Some(cancel_closure)
}
None => None,
};
Ok(cancel_state)
let cancel_closure: CancelClosure =
serde_json::from_str(&cancel_state_str).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Ok(Some(cancel_closure))
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
@@ -470,7 +457,7 @@ impl CancellationHandler {
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}
@@ -478,7 +465,7 @@ pub struct CancelClosure {
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
cancel_token: RawCancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
@@ -538,7 +525,6 @@ impl Session {
let op = CancelKeyOp::StoreCancelKey {
key: self.redis_key.clone(),
field: "data".to_string(),
value: closure_json,
resp_tx: None,
_guard: Metrics::get()

View File

@@ -9,7 +9,7 @@ use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_client::{NoTls, RawCancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
@@ -327,8 +327,7 @@ impl ConnectInfo {
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
RawCancelToken {
ssl_mode: self.ssl_mode,
process_id,
secret_key,

View File

@@ -1,3 +1,6 @@
use std::time::Duration;
use futures::FutureExt;
use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
@@ -35,14 +38,11 @@ impl RedisKVClient {
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
self.client
.connect()
.boxed()
.await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
}
pub(crate) async fn query<T: FromRedisValue>(
@@ -54,15 +54,25 @@ impl RedisKVClient {
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match q.query(&mut self.client).await {
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {
tracing::error!("failed to run query: {e}");
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
}
redis::RetryMethod::RetryImmediately => {}
redis::RetryMethod::WaitAndRetry => {
// somewhat arbitrary.
tokio::time::sleep(Duration::from_millis(100)).await;
}
_ => Err(e)?,
}
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
Ok(q.query(&mut self.client).await?)
}
}