From 1718c0b59befddb84ebb9565d1ce7cc7cede804a Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Wed, 21 Feb 2024 23:43:55 +0100 Subject: [PATCH] Proxy: cancel query on connection drop (#6832) ## Problem https://github.com/neondatabase/cloud/issues/10259 ## Summary of changes Make sure that the request is dropped once the connection was dropped. --- proxy/src/cancellation.rs | 5 +- proxy/src/proxy/copy_bidirectional.rs | 100 +++++++++++++++----------- proxy/src/proxy/passthrough.rs | 10 ++- 3 files changed, 69 insertions(+), 46 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 93a77bc4ae..c9607909b3 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -168,12 +168,11 @@ impl CancelClosure { cancel_token, } } - /// Cancels the query running on user's compute node. - async fn try_cancel_query(self) -> Result<(), CancelError> { + pub async fn try_cancel_query(self) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; self.cancel_token.cancel_query_raw(socket, NoTls).await?; - + info!("query was cancelled"); Ok(()) } } diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 2ecc1151da..684be74f9a 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -1,4 +1,5 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::info; use std::future::poll_fn; use std::io; @@ -39,42 +40,51 @@ where } } -pub(super) async fn copy_bidirectional( - a: &mut A, - b: &mut B, +#[tracing::instrument(skip_all)] +pub(super) async fn copy_bidirectional_client_compute( + client: &mut Client, + compute: &mut Compute, ) -> Result<(u64, u64), std::io::Error> where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, + Client: AsyncRead + AsyncWrite + Unpin + ?Sized, + Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut a_to_b = TransferState::Running(CopyBuffer::new()); - let mut b_to_a = TransferState::Running(CopyBuffer::new()); + let mut client_to_compute = TransferState::Running(CopyBuffer::new()); + let mut compute_to_client = TransferState::Running(CopyBuffer::new()); poll_fn(|cx| { - let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; - let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + let mut client_to_compute_result = + transfer_one_direction(cx, &mut client_to_compute, client, compute)?; + let mut compute_to_client_result = + transfer_one_direction(cx, &mut compute_to_client, compute, client)?; - // Early termination checks - if let TransferState::Done(_) = a_to_b { - if let TransferState::Running(buf) = &b_to_a { + // Early termination checks from compute to client. + if let TransferState::Done(_) = compute_to_client { + if let TransferState::Running(buf) = &client_to_compute { + info!("Compute is done, terminate client"); // Initiate shutdown - b_to_a = TransferState::ShuttingDown(buf.amt); - b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + client_to_compute = TransferState::ShuttingDown(buf.amt); + client_to_compute_result = + transfer_one_direction(cx, &mut client_to_compute, client, compute)?; } } - if let TransferState::Done(_) = b_to_a { - if let TransferState::Running(buf) = &a_to_b { + + // Early termination checks from compute to client. + if let TransferState::Done(_) = client_to_compute { + if let TransferState::Running(buf) = &compute_to_client { + info!("Client is done, terminate compute"); // Initiate shutdown - a_to_b = TransferState::ShuttingDown(buf.amt); - a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + compute_to_client = TransferState::ShuttingDown(buf.amt); + compute_to_client_result = + transfer_one_direction(cx, &mut compute_to_client, client, compute)?; } } // It is not a problem if ready! returns early ... (comment remains the same) - let a_to_b = ready!(a_to_b_result); - let b_to_a = ready!(b_to_a_result); + let client_to_compute = ready!(client_to_compute_result); + let compute_to_client = ready!(compute_to_client_result); - Poll::Ready(Ok((a_to_b, b_to_a))) + Poll::Ready(Ok((client_to_compute, compute_to_client))) }) .await } @@ -219,38 +229,46 @@ mod tests { use tokio::io::AsyncWriteExt; #[tokio::test] - async fn test_early_termination_a_to_d() { - let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream - let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream + async fn test_client_to_compute() { + let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream + let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream // Simulate 'a' finishing while there's still data for 'b' - a_mock.write_all(b"hello").await.unwrap(); - a_mock.shutdown().await.unwrap(); - d_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + client_client.write_all(b"hello").await.unwrap(); + client_client.shutdown().await.unwrap(); + compute_client.write_all(b"Neon").await.unwrap(); + compute_client.shutdown().await.unwrap(); - let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) + .await + .unwrap(); // Assert correct transferred amounts - let (a_to_d_count, d_to_a_count) = result; - assert_eq!(a_to_d_count, 5); // 'hello' was transferred - assert!(d_to_a_count <= 8); // response only partially transferred or not at all + let (client_to_compute_count, compute_to_client_count) = result; + assert_eq!(client_to_compute_count, 5); // 'hello' was transferred + assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all } #[tokio::test] - async fn test_early_termination_d_to_a() { - let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream - let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream + async fn test_compute_to_client() { + let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream + let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream // Simulate 'a' finishing while there's still data for 'b' - d_mock.write_all(b"hello").await.unwrap(); - d_mock.shutdown().await.unwrap(); - a_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + compute_client.write_all(b"hello").await.unwrap(); + compute_client.shutdown().await.unwrap(); + client_client + .write_all(b"Neon Serverless Postgres") + .await + .unwrap(); - let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) + .await + .unwrap(); // Assert correct transferred amounts - let (a_to_d_count, d_to_a_count) = result; - assert_eq!(d_to_a_count, 5); // 'hello' was transferred - assert!(a_to_d_count <= 8); // response only partially transferred or not at all + let (client_to_compute_count, compute_to_client_count) = result; + assert_eq!(compute_to_client_count, 5); // 'hello' was transferred + assert!(client_to_compute_count <= 8); // response only partially transferred or not at all } } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 73c170fc0b..b2f682fd2f 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -46,7 +46,11 @@ pub async fn proxy_pass( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?; + let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( + &mut client, + &mut compute, + ) + .await?; Ok(()) } @@ -63,6 +67,8 @@ pub struct ProxyPassthrough { impl ProxyPassthrough { pub async fn proxy_pass(self) -> anyhow::Result<()> { - proxy_pass(self.client, self.compute.stream, self.aux).await + let res = proxy_pass(self.client, self.compute.stream, self.aux).await; + self.compute.cancel_closure.try_cancel_query().await?; + res } }