diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs
index 34512e9f5b..8e569e4370 100644
--- a/proxy/src/cancellation.rs
+++ b/proxy/src/cancellation.rs
@@ -151,21 +151,34 @@ impl CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>> {
 #[derive(Clone)]
 pub struct CancelClosure {
     socket_addr: SocketAddr,
-    cancel_token: CancelToken,
+    cancel_token: Option<CancelToken>,
 }
 
 impl CancelClosure {
     pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
         Self {
             socket_addr,
-            cancel_token,
+            cancel_token: Some(cancel_token),
         }
     }
+
+    #[cfg(test)]
+    pub fn test() -> Self {
+        use std::net::{Ipv4Addr, SocketAddrV4};
+
+        Self {
+            socket_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from_bits(0), 0)),
+            cancel_token: None,
+        }
+    }
+
     /// Cancels the query running on user's compute node.
     pub async fn try_cancel_query(self) -> Result<(), CancelError> {
-        let socket = TcpStream::connect(self.socket_addr).await?;
-        self.cancel_token.cancel_query_raw(socket, NoTls).await?;
-        info!("query was cancelled");
+        if let Some(cancel_token) = self.cancel_token {
+            let socket = TcpStream::connect(self.socket_addr).await?;
+            cancel_token.cancel_query_raw(socket, NoTls).await?;
+            info!("query was cancelled");
+        }
         Ok(())
     }
 }
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 977d7bda82..e8d1cc3b3a 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -5,6 +5,7 @@ use tracing::{field::display, info};
 
 use crate::{
     auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
+    cancellation::CancelClosure,
     compute::{self, ConnectionError},
     config::{AuthenticationConfig, ProxyConfig},
     console::{
@@ -229,13 +230,16 @@ impl ConnectMechanism for TokioMechanism {
         let host = node_info.config.get_host()?;
         let permit = self.locks.get_permit(&host).await?;
 
-        let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
-        let res = node_info
-            .config
-            .connect2(ctx, timeout, &mut tokio_postgres::NoTls)
-            .await;
-        drop(pause);
-        let (_, client, connection) = permit.release_result(res)?;
+        let (socket_addr, client, connection) = permit.release_result(
+            node_info
+                .config
+                .connect2(ctx, timeout, &mut tokio_postgres::NoTls)
+                .await,
+        )?;
+
+        // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
+        // Yet another reason to rework the connection establishing code.
+        let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
 
         tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
         Ok(poll_client(
@@ -246,6 +250,7 @@ impl ConnectMechanism for TokioMechanism {
             connection,
             self.conn_id,
             node_info.aux.clone(),
+            cancel_closure,
         ))
     }
 
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index 5bb136072d..ce0d416509 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -18,6 +18,7 @@ use tokio_postgres::tls::NoTlsStream;
 use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
 use tokio_util::sync::CancellationToken;
 
+use crate::cancellation::CancelClosure;
 use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
 use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
 use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
@@ -464,6 +465,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
     }
 }
 
+#[allow(clippy::too_many_arguments)]
 pub fn poll_client<C: ClientInnerExt>(
     global_pool: Arc<GlobalConnPool<C>>,
     ctx: &RequestMonitoring,
@@ -472,6 +474,7 @@
     mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
     conn_id: uuid::Uuid,
     aux: MetricsAuxInfo,
+    cancel_closure: CancelClosure,
 ) -> Client<C> {
     let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
     let mut session_id = ctx.session_id();
@@ -573,6 +576,7 @@
         cancel,
         aux,
         conn_id,
+        cancel_closure,
     };
     Client::new(inner, conn_info, pool_clone)
 }
@@ -583,6 +587,7 @@ struct ClientInner<C: ClientInnerExt> {
     cancel: CancellationToken,
     aux: MetricsAuxInfo,
     conn_id: uuid::Uuid,
+    cancel_closure: CancelClosure,
 }
 
 impl<C: ClientInnerExt> Drop for ClientInner<C> {
@@ -647,7 +652,7 @@ impl<C: ClientInnerExt> Client<C> {
             pool,
         }
     }
-    pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
+    pub fn inner(&mut self) -> (&mut C, &CancelClosure, Discard<'_, C>) {
         let Self {
             inner,
             pool,
@@ -655,7 +660,11 @@
             span: _,
         } = self;
         let inner = inner.as_mut().expect("client inner should not be removed");
-        (&mut inner.inner, Discard { pool, conn_info })
+        (
+            &mut inner.inner,
+            &inner.cancel_closure,
+            Discard { pool, conn_info },
+        )
     }
 }
 
@@ -752,6 +761,7 @@ mod tests {
                 cold_start_info: crate::console::messages::ColdStartInfo::Warm,
             },
             conn_id: uuid::Uuid::new_v4(),
+            cancel_closure: CancelClosure::test(),
         }
     }
 
@@ -786,7 +796,7 @@
         {
             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
             assert_eq!(0, pool.get_global_connections_count());
-            client.inner().1.discard();
+            client.inner().2.discard();
             // Discard should not add the connection from the pool.
             assert_eq!(0, pool.get_global_connections_count());
         }
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 5d8f0bd6c4..abbb67460d 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -26,7 +26,6 @@ use tokio_postgres::error::ErrorPosition;
 use tokio_postgres::error::SqlState;
 use tokio_postgres::GenericClient;
 use tokio_postgres::IsolationLevel;
-use tokio_postgres::NoTls;
 use tokio_postgres::ReadyForQueryStatus;
 use tokio_postgres::Transaction;
 use tokio_util::sync::CancellationToken;
@@ -40,7 +39,6 @@ use utils::http::error::ApiError;
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::endpoint_sni;
 use crate::auth::ComputeUserInfoParseError;
-use crate::compute::ConnectionError;
 use crate::config::ProxyConfig;
 use crate::config::TlsConfig;
 use crate::context::RequestMonitoring;
@@ -263,7 +261,7 @@ pub async fn handle(
     let mut message = e.to_string_client();
     let db_error = match &e {
         SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
-            ConnectionError::Postgres(e),
+            crate::compute::ConnectionError::Postgres(e),
         ))
         | SqlOverHttpError::Postgres(e) => e.as_db_error(),
         _ => None,
@@ -625,8 +623,7 @@ impl QueryData {
         client: &mut Client<tokio_postgres::Client>,
         parsed_headers: HttpHeaders,
     ) -> Result<serde_json::Value, SqlOverHttpError> {
-        let (inner, mut discard) = client.inner();
-        let cancel_token = inner.cancel_token();
+        let (inner, cancel_token, mut discard) = client.inner();
 
         let res = match select(
             pin!(query_to_json(&*inner, self, &mut 0, parsed_headers)),
@@ -650,7 +647,7 @@ impl QueryData {
             // The query was cancelled.
             Either::Right((_cancelled, query)) => {
                 tracing::info!("cancelling query");
-                if let Err(err) = cancel_token.cancel_query(NoTls).await {
+                if let Err(err) = cancel_token.clone().try_cancel_query().await {
                     tracing::error!(?err, "could not cancel query");
                 }
                 // wait for the query cancellation
@@ -667,7 +664,7 @@ impl QueryData {
             Ok(Err(error)) => {
                 let db_error = match &error {
                     SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
-                        ConnectionError::Postgres(e),
+                        crate::compute::ConnectionError::Postgres(e),
                     ))
                     | SqlOverHttpError::Postgres(e) => e.as_db_error(),
                     _ => None,
@@ -699,8 +696,7 @@ impl BatchQueryData {
         parsed_headers: HttpHeaders,
     ) -> Result<serde_json::Value, SqlOverHttpError> {
         info!("starting transaction");
-        let (inner, mut discard) = client.inner();
-        let cancel_token = inner.cancel_token();
+        let (inner, cancel_token, mut discard) = client.inner();
         let mut builder = inner.build_transaction();
         if let Some(isolation_level) = parsed_headers.txn_isolation_level {
             builder = builder.isolation_level(isolation_level);
@@ -733,7 +729,7 @@ impl BatchQueryData {
                 json_output
             }
             Err(SqlOverHttpError::Cancelled(_)) => {
-                if let Err(err) = cancel_token.cancel_query(NoTls).await {
+                if let Err(err) = cancel_token.clone().try_cancel_query().await {
                     tracing::error!(?err, "could not cancel query");
                 }
                 // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.