Proxy: cancel query on connection drop (#6832)

## Problem

https://github.com/neondatabase/cloud/issues/10259

## Summary of changes

Make sure that the request is dropped once the connection was dropped.
This commit is contained in:
Anna Khanova
2024-02-21 23:43:55 +01:00
committed by GitHub
parent 8107ae8377
commit 1718c0b59b
3 changed files with 69 additions and 46 deletions

View File

@@ -168,12 +168,11 @@ impl CancelClosure {
cancel_token,
}
}
/// Cancels the query running on user's compute node.
async fn try_cancel_query(self) -> Result<(), CancelError> {
pub async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
info!("query was cancelled");
Ok(())
}
}

View File

@@ -1,4 +1,5 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
use std::future::poll_fn;
use std::io;
@@ -39,42 +40,51 @@ where
}
}
pub(super) async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
#[tracing::instrument(skip_all)]
pub(super) async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut a_to_b = TransferState::Running(CopyBuffer::new());
let mut b_to_a = TransferState::Running(CopyBuffer::new());
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
poll_fn(|cx| {
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)?;
// Early termination checks
if let TransferState::Done(_) = a_to_b {
if let TransferState::Running(buf) = &b_to_a {
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
b_to_a = TransferState::ShuttingDown(buf.amt);
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
}
}
if let TransferState::Done(_) = b_to_a {
if let TransferState::Running(buf) = &a_to_b {
// Early termination checks from compute to client.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
a_to_b = TransferState::ShuttingDown(buf.amt);
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, client, compute)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let a_to_b = ready!(a_to_b_result);
let b_to_a = ready!(b_to_a_result);
let client_to_compute = ready!(client_to_compute_result);
let compute_to_client = ready!(compute_to_client_result);
Poll::Ready(Ok((a_to_b, b_to_a)))
Poll::Ready(Ok((client_to_compute, compute_to_client)))
})
.await
}
@@ -219,38 +229,46 @@ mod tests {
use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn test_early_termination_a_to_d() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
async fn test_client_to_compute() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
a_mock.write_all(b"hello").await.unwrap();
a_mock.shutdown().await.unwrap();
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
client_client.write_all(b"hello").await.unwrap();
client_client.shutdown().await.unwrap();
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
}
#[tokio::test]
async fn test_early_termination_d_to_a() {
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
async fn test_compute_to_client() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
d_mock.write_all(b"hello").await.unwrap();
d_mock.shutdown().await.unwrap();
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
compute_client.write_all(b"hello").await.unwrap();
compute_client.shutdown().await.unwrap();
client_client
.write_all(b"Neon Serverless Postgres")
.await
.unwrap();
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
// Assert correct transferred amounts
let (a_to_d_count, d_to_a_count) = result;
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
}
}

View File

@@ -46,7 +46,11 @@ pub async fn proxy_pass(
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
&mut client,
&mut compute,
)
.await?;
Ok(())
}
@@ -63,6 +67,8 @@ pub struct ProxyPassthrough<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
proxy_pass(self.client, self.compute.stream, self.aux).await
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
self.compute.cancel_closure.try_cancel_query().await?;
res
}
}