diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 3dbc1709fd..f81cdf12d3 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -180,7 +180,7 @@ pub(crate) async fn handle_client( .await?? { HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data) => { + HandshakeData::Cancel(_, cancel_key_data) => { // spawn a task to cancel the session, but don't wait for it cancellations.spawn({ let cancellation_handler_clone = Arc::clone(&cancellation_handler); diff --git a/proxy/src/pglb/handshake.rs b/proxy/src/pglb/handshake.rs index 25a2d01b4a..d06632f3db 100644 --- a/proxy/src/pglb/handshake.rs +++ b/proxy/src/pglb/handshake.rs @@ -50,7 +50,7 @@ impl ReportableError for HandshakeError { pub(crate) enum HandshakeData { Startup(PqStream>, StartupMessageParams), - Cancel(CancelKeyData), + Cancel(Option, CancelKeyData), } /// Establish a (most probably, secure) connection with the client. @@ -234,8 +234,17 @@ pub(crate) async fn handshake( return Err(HandshakeError::ProtocolViolation); } FeStartupPacket::CancelRequest(cancel_key_data) => { - info!(session_type = "cancellation", "successful handshake"); - break Ok(HandshakeData::Cancel(cancel_key_data)); + let server_name = match stream.get_ref() { + Stream::Raw { .. } => None, + Stream::Tls { tls, .. } => tls.get_ref().1.server_name().map(String::from), + }; + + info!( + session_type = "cancellation", + server_name, "successful handshake" + ); + + break Ok(HandshakeData::Cancel(server_name, cancel_key_data)); } } } diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 999fa6eb32..f7b196b717 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -3,19 +3,23 @@ pub mod handshake; pub mod inprocess; pub mod passthrough; +use std::net::IpAddr; +use std::str::FromStr; use std::sync::Arc; use futures::FutureExt; use smol_str::ToSmolStr; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; -use crate::auth; +use crate::auth::{self, Backend}; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; +use crate::control_plane::client::ControlPlaneClient; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; pub use crate::pglb::copy_bidirectional::ErrorSource; @@ -266,7 +270,28 @@ pub(crate) async fn handle_connection( .await?? { HandshakeData::Startup(client, params) => (client, params), - HandshakeData::Cancel(cancel_key_data) => { + HandshakeData::Cancel(server_name, cancel_key_data) => { + if let Backend::ControlPlane(api, ()) = auth_backend + && let ControlPlaneClient::LakebaseV1(lakebase) = &**api + { + let pod_suffix = format!(".{}.pod.cluster.local", lakebase.namespace); + + let pod_ip = server_name + .as_deref() + .and_then(|server_name| server_name.strip_suffix(&pod_suffix)) + .and_then(|pod_ip| IpAddr::from_str(&pod_ip.replace('-', ".")).ok()); + + if let Some(pod_ip) = pod_ip { + cancellations.spawn(async move { + let stream = TcpStream::connect((pod_ip, lakebase.port)).await?; + crate::pqproto::cancel(stream, cancel_key_data).await?; + anyhow::Ok(()) + }); + } + + return Ok(None); + } + // spawn a task to cancel the session, but don't wait for it cancellations.spawn({ let cancellation_handler_clone = Arc::clone(&cancellation_handler); diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 7a68d430db..5e985907ad 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -9,6 +9,7 @@ use bytes::{Buf, BufMut}; use itertools::Itertools; use rand::distr::{Distribution, StandardUniform}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; pub type ErrorCode = [u8; 5]; @@ -53,6 +54,18 @@ impl fmt::Debug for ProtocolVersion { } } +pub async fn cancel(mut s: TcpStream, key: CancelKeyData) -> io::Result<()> { + s.write_all( + StartupHeader { + len: 16_u32.into(), + version: CANCEL_REQUEST_CODE, + } + .as_bytes(), + ) + .await?; + s.write_all(key.as_bytes()).await +} + /// const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index b09d8edc4c..bd3ef00e08 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -48,7 +48,7 @@ async fn proxy_mitm( .unwrap() { HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(_) => panic!("cancellation not supported"), + HandshakeData::Cancel(_, _) => panic!("cancellation not supported"), }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index d923f4b260..46b9d1e5aa 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -181,7 +181,7 @@ async fn dummy_proxy( ) -> anyhow::Result<()> { let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, - HandshakeData::Cancel(_) => bail!("cancellation not supported"), + HandshakeData::Cancel(_, _) => bail!("cancellation not supported"), }; auth.authenticate(&mut stream).await?;