do not replace cancelkeydata

This commit is contained in:
Conrad Ludgate
2025-06-27 16:50:42 +01:00
committed by Conrad Ludgate
parent d0e579c026
commit 725aed694b
2 changed files with 63 additions and 56 deletions

View File

@@ -15,6 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::ClientRequestError; use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough; use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::CancelKeyData;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::{ use crate::proxy::{
ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting, ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting,
@@ -207,7 +208,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx.set_db_options(params.clone()); ctx.set_db_options(params.clone());
let (node_info, mut auth_info, user_info) = match backend let (node_info, mut auth_info, _user_info) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream) .authenticate(ctx, &config.authentication_config, &mut stream)
.await .await
{ {
@@ -231,35 +232,34 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await?; .await?;
send_client_greeting(ctx, &config.greetings, &mut stream); send_client_greeting(ctx, &config.greetings, &mut stream);
let session = cancellation_handler.get_key(); // let session = cancellation_handler.get_key();
let (process_id, secret_key) = let (_process_id, _secret_key) =
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream) forward_compute_params_to_client(ctx, None, &mut stream, &mut node.stream).await?;
.await?;
let stream = stream.flush_and_into_inner().await?; let stream = stream.flush_and_into_inner().await?;
let hostname = node.hostname.to_string(); // let hostname = node.hostname.to_string();
let session_id = ctx.session_id(); // let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); let (cancel_on_shutdown, _cancel) = tokio::sync::oneshot::channel();
tokio::spawn(async move { // tokio::spawn(async move {
session // session
.maintain_cancel_key( // .maintain_cancel_key(
session_id, // session_id,
cancel, // cancel,
&CancelClosure { // &CancelClosure {
socket_addr: node.socket_addr, // socket_addr: node.socket_addr,
cancel_token: RawCancelToken { // cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode, // ssl_mode: node.ssl_mode,
process_id, // process_id,
secret_key, // secret_key,
}, // },
hostname, // hostname,
user_info, // user_info,
}, // },
&config.connect_to_compute, // &config.connect_to_compute,
) // )
.await; // .await;
}); // });
Ok(Some(ProxyPassthrough { Ok(Some(ProxyPassthrough {
client: stream, client: stream,

View File

@@ -42,7 +42,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig, config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>, auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext, ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>, _cancellation_handler: Arc<CancellationHandler>,
client: &mut PqStream<Stream<S>>, client: &mut PqStream<Stream<S>>,
mode: &ClientMode, mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -114,7 +114,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
send_client_greeting(ctx, &config.greetings, client); send_client_greeting(ctx, &config.greetings, client);
let auth::Backend::ControlPlane(_, user_info) = backend else { let auth::Backend::ControlPlane(_, _user_info) = backend else {
unreachable!("ensured above"); unreachable!("ensured above");
}; };
@@ -124,33 +124,33 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
client.write_message(BeMessage::AuthenticationOk); client.write_message(BeMessage::AuthenticationOk);
} }
let session = cancellation_handler.get_key(); // let session = cancellation_handler.get_key();
let (process_id, secret_key) = let (_process_id, _secret_key) =
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?; forward_compute_params_to_client(ctx, None, client, &mut node.stream).await?;
let hostname = node.hostname.to_string(); // let hostname = node.hostname.to_string();
let session_id = ctx.session_id(); // let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = oneshot::channel(); let (cancel_on_shutdown, _cancel) = oneshot::channel();
tokio::spawn(async move { // tokio::spawn(async move {
session // session
.maintain_cancel_key( // .maintain_cancel_key(
session_id, // session_id,
cancel, // cancel,
&CancelClosure { // &CancelClosure {
socket_addr: node.socket_addr, // socket_addr: node.socket_addr,
cancel_token: RawCancelToken { // cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode, // ssl_mode: node.ssl_mode,
process_id, // process_id,
secret_key, // secret_key,
}, // },
hostname, // hostname,
user_info, // user_info,
}, // },
&config.connect_to_compute, // &config.connect_to_compute,
) // )
.await; // .await;
}); // });
Ok((node, cancel_on_shutdown)) Ok((node, cancel_on_shutdown))
} }
@@ -200,7 +200,7 @@ pub(crate) fn send_client_greeting(
pub(crate) async fn forward_compute_params_to_client( pub(crate) async fn forward_compute_params_to_client(
ctx: &RequestContext, ctx: &RequestContext,
cancel_key_data: CancelKeyData, cancel_key_data: Option<CancelKeyData>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>, client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
compute: &mut StartupStream<TcpStream, RustlsStream>, compute: &mut StartupStream<TcpStream, RustlsStream>,
) -> Result<(i32, i32), ClientRequestError> { ) -> Result<(i32, i32), ClientRequestError> {
@@ -219,9 +219,16 @@ pub(crate) async fn forward_compute_params_to_client(
match msg { match msg {
// Send our cancellation key data instead. // Send our cancellation key data instead.
Some(Message::BackendKeyData(body)) => { Some(Message::BackendKeyData(body)) => {
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
process_id = body.process_id(); process_id = body.process_id();
secret_key = body.secret_key(); secret_key = body.secret_key();
let cancel_key_data = cancel_key_data.unwrap_or_else(|| {
let pid = process_id as u32;
let key = secret_key as u32;
CancelKeyData(((pid as u64) << 32 | (key as u64)).into())
});
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
} }
// Forward all postgres connection params to the client. // Forward all postgres connection params to the client.
Some(Message::ParameterStatus(body)) => { Some(Message::ParameterStatus(body)) => {