Mirror of https://github.com/neondatabase/neon.git (synced 2026-05-13 19:20:36 +00:00)

Compare commits: conrad/jso...conrad/rem (13 commits)
| SHA1 |
|---|
| 19cbffc796 |
| c05f105035 |
| a17a882895 |
| 49b6ee6c57 |
| 69d21d85bc |
| 78f37e6f11 |
| e5fd708495 |
| f1e3f2259e |
| 29c4db1658 |
| 727c333831 |
| 14312f1a9a |
| 008cd84e7b |
| 68d561664b |
@@ -383,12 +383,19 @@ async fn handle_client(
     info!("performing the proxy pass...");

     let res = match client {
-        Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
-        Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
+        Connection::Raw(mut c) => {
+            copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
+        }
+        Connection::Tls(mut c) => {
+            copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
+        }
     };

     match res {
-        Ok(_) => Ok(()),
+        Ok(()) => Ok(()),
+        Err(ErrorSource::Timeout(_)) => Err(anyhow!(
+            "timed out while gracefully shutting down the connection"
+        )),
         Err(ErrorSource::Client(err)) => Err(err).context("client"),
         Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
     }
@@ -129,6 +129,12 @@ pub async fn task_main(
     let _disconnect = ctx.log_connect();
     match p.proxy_pass(&config.connect_to_compute).await {
         Ok(()) => {}
+        Err(ErrorSource::Timeout(_)) => {
+            info!(
+                ?session_id,
+                "per-client task timed out while gracefully shutting down the connection"
+            );
+        }
         Err(ErrorSource::Client(e)) => {
             error!(
                 ?session_id,
@@ -200,8 +200,10 @@ pub enum HttpDirection {
 #[derive(FixedCardinalityLabel, Copy, Clone)]
 #[label(singleton = "direction")]
 pub enum Direction {
-    Tx,
-    Rx,
+    #[label(rename = "tx")]
+    ComputeToClient,
+    #[label(rename = "rx")]
+    ClientToCompute,
 }

 #[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
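As a standalone illustration (not part of the diff): the `#[label(rename = ...)]` attributes presumably keep the exported label values at `"tx"` and `"rx"` while giving the variants self-describing names, and these are also the directions handed to the traffic-inspection callback introduced in the copy_bidirectional rewrite below. The enum here is a local mirror and `label_value` is a hypothetical helper:

```rust
// Illustrative mirror of crate::metrics::Direction after this change; not code from the PR.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Direction {
    ComputeToClient, // was `Tx`; still expected to serialize as "tx" via #[label(rename = "tx")]
    ClientToCompute, // was `Rx`; still expected to serialize as "rx" via #[label(rename = "rx")]
}

// Hypothetical helper showing the label value each variant maps to.
fn label_value(dir: Direction) -> &'static str {
    match dir {
        Direction::ComputeToClient => "tx",
        Direction::ClientToCompute => "rx",
    }
}

fn main() {
    assert_eq!(label_value(Direction::ComputeToClient), "tx");
    assert_eq!(label_value(Direction::ClientToCompute), "rx");
}
```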
@@ -1,313 +1,394 @@
 use std::future::poll_fn;
 use std::io;
+use std::ops::Range;
 use std::pin::Pin;
 use std::task::{Context, Poll, ready};
+use std::time::Duration;

 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 use tracing::info;

+use crate::metrics::Direction;

-#[derive(Debug)]
-enum TransferState {
-    Running(CopyBuffer),
-    ShuttingDown(u64),
-    Done(u64),
-}
-
-#[derive(Debug)]
-pub(crate) enum ErrorDirection {
-    Read(io::Error),
-    Write(io::Error),
-}
-
-impl ErrorSource {
-    fn from_client(err: ErrorDirection) -> ErrorSource {
-        match err {
-            ErrorDirection::Read(client) => Self::Client(client),
-            ErrorDirection::Write(compute) => Self::Compute(compute),
-        }
-    }
-    fn from_compute(err: ErrorDirection) -> ErrorSource {
-        match err {
-            ErrorDirection::Write(client) => Self::Client(client),
-            ErrorDirection::Read(compute) => Self::Compute(compute),
-        }
-    }
-}
+const DISCONNECT_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// Mark a value as being unlikely.
+#[cold]
+#[inline(always)]
+fn cold<I>(i: I) -> I {
+    i
+}

 #[derive(Debug)]
 pub enum ErrorSource {
     Client(io::Error),
     Compute(io::Error),
+    Timeout(tokio::time::error::Elapsed),
 }
-fn transfer_one_direction<A, B>(
-    cx: &mut Context<'_>,
-    state: &mut TransferState,
-    r: &mut A,
-    w: &mut B,
-) -> Poll<Result<u64, ErrorDirection>>
-where
-    A: AsyncRead + AsyncWrite + Unpin + ?Sized,
-    B: AsyncRead + AsyncWrite + Unpin + ?Sized,
-{
-    let mut r = Pin::new(r);
-    let mut w = Pin::new(w);
-    loop {
-        match state {
-            TransferState::Running(buf) => {
-                let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
-                *state = TransferState::ShuttingDown(count);
-            }
-            TransferState::ShuttingDown(count) => {
-                ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
-                *state = TransferState::Done(*count);
-            }
-            TransferState::Done(count) => return Poll::Ready(Ok(*count)),
-        }
-    }
-}
+impl ErrorSource {
+    fn read(dir: Direction, err: io::Error) -> Self {
+        match dir {
+            Direction::ComputeToClient => ErrorSource::Compute(err),
+            Direction::ClientToCompute => ErrorSource::Client(err),
+        }
+    }
+
+    fn write(dir: Direction, err: io::Error) -> Self {
+        match dir {
+            Direction::ComputeToClient => ErrorSource::Client(err),
+            Direction::ClientToCompute => ErrorSource::Compute(err),
+        }
+    }
+}

 #[tracing::instrument(skip_all)]
 pub async fn copy_bidirectional_client_compute<Client, Compute>(
     client: &mut Client,
     compute: &mut Compute,
-) -> Result<(u64, u64), ErrorSource>
+    mut f: impl for<'a> FnMut(Direction, &'a [u8]),
+) -> Result<(), ErrorSource>
 where
     Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
     Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
 {
-    let mut client_to_compute = TransferState::Running(CopyBuffer::new());
-    let mut compute_to_client = TransferState::Running(CopyBuffer::new());
-
-    poll_fn(|cx| {
-        let mut client_to_compute_result =
-            transfer_one_direction(cx, &mut client_to_compute, client, compute)
-                .map_err(ErrorSource::from_client)?;
-        let mut compute_to_client_result =
-            transfer_one_direction(cx, &mut compute_to_client, compute, client)
-                .map_err(ErrorSource::from_compute)?;
-
-        // Early termination checks from compute to client.
-        if let TransferState::Done(_) = compute_to_client {
-            if let TransferState::Running(buf) = &client_to_compute {
-                info!("Compute is done, terminate client");
-                // Initiate shutdown
-                client_to_compute = TransferState::ShuttingDown(buf.amt);
-                client_to_compute_result =
-                    transfer_one_direction(cx, &mut client_to_compute, client, compute)
-                        .map_err(ErrorSource::from_client)?;
-            }
-        }
-
-        // Early termination checks from client to compute.
-        if let TransferState::Done(_) = client_to_compute {
-            if let TransferState::Running(buf) = &compute_to_client {
-                info!("Client is done, terminate compute");
-                // Initiate shutdown
-                compute_to_client = TransferState::ShuttingDown(buf.amt);
-                compute_to_client_result =
-                    transfer_one_direction(cx, &mut compute_to_client, compute, client)
-                        .map_err(ErrorSource::from_compute)?;
-            }
-        }
-
-        // It is not a problem if ready! returns early ... (comment remains the same)
-        let client_to_compute = ready!(client_to_compute_result);
-        let compute_to_client = ready!(compute_to_client_result);
-
-        Poll::Ready(Ok((client_to_compute, compute_to_client)))
-    })
-    .await
+    let f = &mut f;
+    let mut client_to_compute = CopyBuffer::new(Direction::ClientToCompute);
+    let mut compute_to_client = CopyBuffer::new(Direction::ComputeToClient);
+
+    let mut client = Pin::new(client);
+    let mut compute = Pin::new(compute);
+
+    // TODO: 1 info log, with a enum label for close direction.
+    // Initial copy hot path
+    let close_dir = poll_fn(|cx| -> Poll<Result<_, ErrorSource>> {
+        let copy1 = client_to_compute.poll_copy(cx, f, client.as_mut(), compute.as_mut())?;
+        let copy2 = compute_to_client.poll_copy(cx, f, compute.as_mut(), client.as_mut())?;
+
+        match (copy1, copy2) {
+            (Poll::Pending, Poll::Pending) => Poll::Pending,
+            (Poll::Ready(()), _) => Poll::Ready(Ok(client_to_compute.dir)),
+            (_, Poll::Ready(())) => Poll::Ready(Ok(compute_to_client.dir)),
+        }
+    })
+    .await?;
+    // initiate shutdown.
+    match close_dir {
+        Direction::ClientToCompute => {
+            info!("Client is done, terminate compute");
+
+            // we will never write anymore data to the client.
+            compute_to_client.filled = 0..0;
+
+            // make sure to shutdown the client conn.
+            compute_to_client.need_flush = true;
+        }
+        Direction::ComputeToClient => {
+            info!("Compute is done, terminate client");
+
+            // we will never write anymore data to the compute.
+            client_to_compute.filled = 0..0;
+
+            // make sure to shutdown the compute conn.
+            client_to_compute.need_flush = true;
+        }
+    }
+
+    // Finish sending the rest of the data to client/compute before shutting it down.
+    //
+    // Edge case:
+    // * peer has filled the TCP buffers and is blocking on a `write()`,
+    // * proxy has filled the TCP buffers and is waiting on a `write()`.
+    // Since no side is reading from the buffers, no progress will be made.
+    let shutdown = poll_fn(|cx| {
+        let res1 = client_to_compute.poll_empty(cx, compute.as_mut())?;
+        let res2 = compute_to_client.poll_empty(cx, client.as_mut())?;
+
+        if res1.is_ready() && res2.is_ready() {
+            Poll::Ready(Ok(()))
+        } else {
+            Poll::Pending
+        }
+    });
+
+    // We assume most peers will have enough buffer space so this issue doesn't arise, but we apply
+    // a timeout just in case.
+    //
+    // We could also update `poll_empty` to try and read the data, but I think this is not an edge case
+    // worth overcomplicating.
+    let res = tokio::time::timeout(DISCONNECT_TIMEOUT, shutdown).await;
+
+    match res {
+        Ok(res) => res,
+        Err(timeout) => Err(ErrorSource::Timeout(timeout)),
+    }
 }
-#[derive(Debug)]
-pub(super) struct CopyBuffer {
-    read_done: bool,
-    need_flush: bool,
-    pos: usize,
-    cap: usize,
-    amt: u64,
-    buf: Box<[u8]>,
-}
 const DEFAULT_BUF_SIZE: usize = 1024;
+
+pub(super) struct CopyBuffer {
+    dir: Direction,
+    need_flush: bool,
+    filled: Range<usize>,
+    buf: [u8; DEFAULT_BUF_SIZE],
+}

 impl CopyBuffer {
-    pub(super) fn new() -> Self {
+    pub(super) const fn new(dir: Direction) -> Self {
         Self {
-            read_done: false,
+            dir,
             need_flush: false,
-            pos: 0,
-            cap: 0,
-            amt: 0,
-            buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
+            filled: 0..0,
+            buf: [0; DEFAULT_BUF_SIZE],
         }
     }
-    fn poll_fill_buf<R>(
-        &mut self,
-        cx: &mut Context<'_>,
-        reader: Pin<&mut R>,
-    ) -> Poll<io::Result<()>>
-    where
-        R: AsyncRead + ?Sized,
-    {
-        let me = &mut *self;
-        let mut buf = ReadBuf::new(&mut me.buf);
-        buf.set_filled(me.cap);
-
-        let res = reader.poll_read(cx, &mut buf);
-        if let Poll::Ready(Ok(())) = res {
-            let filled_len = buf.filled().len();
-            me.read_done = me.cap == filled_len;
-            me.cap = filled_len;
-        }
-        res
-    }
-
-    fn poll_write_buf<R, W>(
-        &mut self,
-        cx: &mut Context<'_>,
-        mut reader: Pin<&mut R>,
-        mut writer: Pin<&mut W>,
-    ) -> Poll<Result<usize, ErrorDirection>>
-    where
-        R: AsyncRead + ?Sized,
-        W: AsyncWrite + ?Sized,
-    {
-        let me = &mut *self;
-        match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
-            Poll::Pending => {
-                // Top up the buffer towards full if we can read a bit more
-                // data - this should improve the chances of a large write
-                if !me.read_done && me.cap < me.buf.len() {
-                    ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
-                }
-                Poll::Pending
-            }
-            res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
-        }
-    }
-
-    pub(super) fn poll_copy<R, W>(
-        &mut self,
-        cx: &mut Context<'_>,
-        mut reader: Pin<&mut R>,
-        mut writer: Pin<&mut W>,
-    ) -> Poll<Result<u64, ErrorDirection>>
-    where
-        R: AsyncRead + ?Sized,
-        W: AsyncWrite + ?Sized,
-    {
-        loop {
-            // If there is some space left in our buffer, then we try to read some
-            // data to continue, thus maximizing the chances of a large write.
-            if self.cap < self.buf.len() && !self.read_done {
-                match self.poll_fill_buf(cx, reader.as_mut()) {
-                    Poll::Ready(Ok(())) => (),
-                    Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
-                    Poll::Pending => {
-                        // Ignore pending reads when our buffer is not empty, because
-                        // we can try to write data immediately.
-                        if self.pos == self.cap {
-                            // Try flushing when the reader has no progress to avoid deadlock
-                            // when the reader depends on buffered writer.
-                            if self.need_flush {
-                                ready!(writer.as_mut().poll_flush(cx))
-                                    .map_err(ErrorDirection::Write)?;
-                                self.need_flush = false;
-                            }
-
-                            return Poll::Pending;
-                        }
-                    }
-                }
-            }
-
-            // If our buffer has some data, let's write it out!
-            while self.pos < self.cap {
-                let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
-                if i == 0 {
-                    return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
-                        io::ErrorKind::WriteZero,
-                        "write zero byte into writer",
-                    ))));
-                }
-                self.pos += i;
-                self.amt += i as u64;
-                self.need_flush = true;
-            }
-
-            // If pos larger than cap, this loop will never stop.
-            // In particular, user's wrong poll_write implementation returning
-            // incorrect written length may lead to thread blocking.
-            debug_assert!(
-                self.pos <= self.cap,
-                "writer returned length larger than input slice"
-            );
-
-            // All data has been written, the buffer can be considered empty again
-            self.pos = 0;
-            self.cap = 0;
-
-            // If we've written all the data and we've seen EOF, flush out the
-            // data and finish the transfer.
-            if self.read_done {
-                ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
-                return Poll::Ready(Ok(self.amt));
-            }
-        }
-    }
+    /// Returns Ready(Ok(())) when no more writes could progress, and the buffer has space to read.
+    #[inline(always)]
+    fn poll_write_loop<W>(
+        &mut self,
+        cx: &mut Context<'_>,
+        mut writer: Pin<&mut W>,
+    ) -> Poll<Result<(), ErrorSource>>
+    where
+        W: AsyncWrite + ?Sized,
+    {
+        debug_assert!(!self.filled.is_empty());
+
+        loop {
+            let filled_buf = &self.buf[self.filled.clone()];
+            match writer.as_mut().poll_write(cx, filled_buf) {
+                Poll::Ready(Err(err)) => {
+                    return Poll::Ready(Err(ErrorSource::write(self.dir, cold(err))));
+                }
+                Poll::Ready(Ok(0)) => {
+                    let err =
+                        io::Error::new(io::ErrorKind::WriteZero, "write zero byte into writer");
+                    return Poll::Ready(Err(ErrorSource::write(self.dir, cold(err))));
+                }
+                Poll::Ready(Ok(i)) => {
+                    // update the write head.
+                    self.filled.start += i;
+                    self.need_flush = true;
+
+                    // we wrote some data, but the filled buffer might not be fully empty yet.
+                    if !self.filled.is_empty() {
+                        continue;
+                    }
+
+                    // the buffer is definitely empty. reset positions.
+                    self.filled = 0..0;
+                    break;
+                }
+                // While we couldn't write, we might be able to read.
+                Poll::Pending if self.filled.end < self.buf.len() => break,
+                // We couldn't write, and have no space to read. Just exit.
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+
+        Poll::Ready(Ok(()))
+    }
+    /// Returns Ready(Ok((true))) when read returns EOF.
+    /// Returns Ready(Ok((false))) when read returns data.
+    #[inline(always)]
+    fn poll_read_once<R>(
+        &mut self,
+        cx: &mut Context<'_>,
+        f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
+        reader: Pin<&mut R>,
+    ) -> Poll<Result<bool, ErrorSource>>
+    where
+        R: AsyncRead + ?Sized,
+    {
+        debug_assert!(self.filled.end < self.buf.len());
+        let mut buf = ReadBuf::new(&mut self.buf[self.filled.end..]);
+
+        match reader.poll_read(cx, &mut buf) {
+            Poll::Ready(Ok(())) => {
+                let filled = buf.filled();
+                // no more data to read, switch to shutdown mode.
+                if filled.is_empty() {
+                    self.need_flush = true;
+                    return Poll::Ready(Ok(true));
+                }
+
+                // run our inspection callback.
+                f(self.dir, filled);
+
+                // update the read head.
+                self.filled.end += filled.len();
+
+                // read more data
+                Poll::Ready(Ok(false))
+            }
+            // cannot continue on error.
+            Poll::Ready(Err(e)) => Poll::Ready(Err(ErrorSource::read(self.dir, cold(e)))),
+            // No more data to read, and no more data to write.
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    /// Returns Ready(Ok(())) when read returns EOF.
+    fn poll_copy<R, W>(
+        &mut self,
+        cx: &mut Context<'_>,
+        f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
+        mut reader: Pin<&mut R>,
+        mut writer: Pin<&mut W>,
+    ) -> Poll<Result<(), ErrorSource>>
+    where
+        R: AsyncRead + ?Sized,
+        W: AsyncWrite + ?Sized,
+    {
+        // this register eliminates a branch in the hot loop.
+        let mut empty = self.filled.is_empty();
+
+        // write then read hot loop
+        loop {
+            if !empty {
+                ready!(self.poll_write_loop(cx, writer.as_mut())?);
+            }
+
+            // If empty is true, there is guaranteed space to read.
+            // If empty is false, and the write loop returned ready, then we know there's space for more reads.
+            match self.poll_read_once(cx, f, reader.as_mut())? {
+                // EOF
+                Poll::Ready(true) => return Poll::Ready(Ok(())),
+                // Needs write.
+                Poll::Ready(false) => empty = false,
+                // Cannot read. The peer might not send us anything until
+                // they receive data from us, so let's switch to flushing.
+                Poll::Pending => break,
+            }
+        }
+
+        if self.need_flush {
+            let flush = writer.as_mut().poll_flush(cx);
+            ready!(flush.map_err(|e| ErrorSource::write(self.dir, e))?);
+            self.need_flush = false;
+        }
+
+        // there might be more data still to read.
+        Poll::Pending
+    }
+
+    /// Returns Ready(Ok(())) when the conn is fully shutdown.
+    pub(super) fn poll_empty<W>(
+        &mut self,
+        cx: &mut Context<'_>,
+        mut writer: Pin<&mut W>,
+    ) -> Poll<Result<(), ErrorSource>>
+    where
+        W: AsyncWrite + ?Sized,
+    {
+        if !self.filled.is_empty() {
+            ready!(self.poll_write_loop(cx, writer.as_mut())?);
+
+            if !self.filled.is_empty() {
+                // still some data to write
+                return Poll::Pending;
+            }
+        }
+
+        if self.need_flush {
+            let res = writer.poll_shutdown(cx);
+            ready!(res.map_err(|e| ErrorSource::write(self.dir, e))?);
+            self.need_flush = false;
+        }
+
+        // no data to read, no data to write.
+        Poll::Ready(Ok(()))
+    }
 }
 #[cfg(test)]
 mod tests {
-    use tokio::io::AsyncWriteExt;
+    use tokio::io::{AsyncReadExt, AsyncWriteExt};

     use super::*;

     #[tokio::test]
     async fn test_client_to_compute() {
-        let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
-        let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
-
-        // Simulate 'a' finishing while there's still data for 'b'
-        client_client.write_all(b"hello").await.unwrap();
-        client_client.shutdown().await.unwrap();
-        compute_client.write_all(b"Neon").await.unwrap();
-        compute_client.shutdown().await.unwrap();
-
-        let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
-            .await
-            .unwrap();
-
-        // Assert correct transferred amounts
-        let (client_to_compute_count, compute_to_client_count) = result;
-        assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
-        assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
+        let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
+        let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
+
+        client.write_all(b"Neon Serverless Postgres").await.unwrap();
+        compute.write_all(b"is amazing").await.unwrap();
+
+        client.shutdown().await.unwrap();
+
+        let copy = tokio::spawn(async move {
+            copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
+                .await
+                .unwrap();
+        });
+
+        let mut client_recv = String::new();
+        let mut compute_recv = String::new();
+
+        client.read_to_string(&mut client_recv).await.unwrap();
+        compute.read_to_string(&mut compute_recv).await.unwrap();
+
+        assert_eq!(compute_recv, "Neon Serverless Postgres");
+        assert_eq!(client_recv, "is amazing");
+
+        copy.await.unwrap();
     }

     #[tokio::test]
     async fn test_compute_to_client() {
-        let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
-        let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
-
-        // Simulate 'a' finishing while there's still data for 'b'
-        compute_client.write_all(b"hello").await.unwrap();
-        compute_client.shutdown().await.unwrap();
-        client_client
-            .write_all(b"Neon Serverless Postgres")
-            .await
-            .unwrap();
-
-        let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
-            .await
-            .unwrap();
-
-        // Assert correct transferred amounts
-        let (client_to_compute_count, compute_to_client_count) = result;
-        assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
-        assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
+        let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
+        let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
+
+        client.write_all(b"Neon Serverless Postgres").await.unwrap();
+        compute.write_all(b"is amazing").await.unwrap();
+
+        compute.shutdown().await.unwrap();
+
+        let copy = tokio::spawn(async move {
+            copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
+                .await
+                .unwrap();
+        });
+
+        let mut client_recv = String::new();
+        let mut compute_recv = String::new();
+
+        client.read_to_string(&mut client_recv).await.unwrap();
+        compute.read_to_string(&mut compute_recv).await.unwrap();
+
+        assert_eq!(compute_recv, "Neon Serverless ");
+        assert_eq!(client_recv, "is amazing");
+
+        copy.await.unwrap();
     }
+
+    #[tokio::test(start_paused = true)]
+    async fn test_timeout() {
+        let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
+        let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
+
+        // Try to send 24 bytes to compute, but compute only has space for 16 bytes.
+        // Writes will not succeed.
+        client.write_all(b"Neon Serverless Postgres").await.unwrap();
+        client.shutdown().await.unwrap();
+
+        let copy = tokio::spawn(async move {
+            copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
+                .await
+                .unwrap_err()
+        });
+
+        tokio::time::advance(DISCONNECT_TIMEOUT).await;
+
+        let res = copy.await.unwrap();
+        assert!(matches!(res, ErrorSource::Timeout(_)));
+
+        // Assert correct transferred amounts
+        let mut compute_recv = String::new();
+        compute.read_to_string(&mut compute_recv).await.unwrap();
+        assert_eq!(compute_recv, "Neon Serverless ");
+    }
 }
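For orientation, a minimal usage sketch of the rewritten API, modeled on the tests above. It assumes it runs inside the proxy crate so that `copy_bidirectional_client_compute` and `crate::metrics::Direction` resolve; the byte counters are illustrative, not part of the PR:

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::metrics::Direction;
use crate::proxy::copy_bidirectional_client_compute;

// Hypothetical demo: proxy two in-memory streams and count bytes per direction
// via the new inspection callback.
#[tokio::test]
async fn demo_inspect_callback() {
    let (mut client, mut client_proxy) = tokio::io::duplex(32);
    let (mut proxy_compute, mut compute) = tokio::io::duplex(32);

    client.write_all(b"hello").await.unwrap();
    client.shutdown().await.unwrap();
    compute.shutdown().await.unwrap();

    let (mut tx, mut rx) = (0u64, 0u64);
    let inspect = |dir: Direction, bytes: &[u8]| match dir {
        Direction::ComputeToClient => tx += bytes.len() as u64,
        Direction::ClientToCompute => rx += bytes.len() as u64,
    };

    // Runs until both sides hit EOF, then gracefully shuts both writers down
    // (bounded by DISCONNECT_TIMEOUT).
    copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, inspect)
        .await
        .unwrap();

    assert_eq!(rx, 5); // "hello" flowed client -> compute
    assert_eq!(tx, 0); // compute sent nothing back
}
```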
@@ -167,6 +167,12 @@ pub async fn task_main(
     let _disconnect = ctx.log_connect();
     match p.proxy_pass(&config.connect_to_compute).await {
         Ok(()) => {}
+        Err(ErrorSource::Timeout(_)) => {
+            info!(
+                ?session_id,
+                "per-client task timed out while gracefully shutting down the connection"
+            );
+        }
         Err(ErrorSource::Client(e)) => {
             warn!(
                 ?session_id,
@@ -1,7 +1,6 @@
 use smol_str::SmolStr;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::debug;
-use utils::measured_stream::MeasuredStream;

 use super::copy_bidirectional::ErrorSource;
 use crate::cancellation;
@@ -9,14 +8,15 @@ use crate::compute::PostgresConnection;
 use crate::config::ComputeConfig;
 use crate::control_plane::messages::MetricsAuxInfo;
 use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
+use crate::proxy::copy_bidirectional_client_compute;
+use crate::stream::Stream;
 use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};

 /// Forward bytes in both directions (client <-> compute).
-#[tracing::instrument(skip_all)]
+#[tracing::instrument(level = "debug", skip_all)]
 pub(crate) async fn proxy_pass(
-    client: impl AsyncRead + AsyncWrite + Unpin,
-    compute: impl AsyncRead + AsyncWrite + Unpin,
+    mut client: Stream<impl AsyncRead + AsyncWrite + Unpin>,
+    mut compute: impl AsyncRead + AsyncWrite + Unpin,
     aux: MetricsAuxInfo,
     private_link_id: Option<SmolStr>,
 ) -> Result<(), ErrorSource> {
@@ -28,37 +28,30 @@ pub(crate) async fn proxy_pass(
     });

     let metrics = &Metrics::get().proxy.io_bytes;
-    let m_sent = metrics.with_labels(Direction::Tx);
-    let mut client = MeasuredStream::new(
-        client,
-        |_| {},
-        |cnt| {
-            // Number of bytes we sent to the client (outbound).
-            metrics.get_metric(m_sent).inc_by(cnt as u64);
-            usage_tx.record_egress(cnt as u64);
-        },
-    );
-
-    let m_recv = metrics.with_labels(Direction::Rx);
-    let mut compute = MeasuredStream::new(
-        compute,
-        |_| {},
-        |cnt| {
-            // Number of bytes the client sent to the compute node (inbound).
-            metrics.get_metric(m_recv).inc_by(cnt as u64);
-            usage_tx.record_ingress(cnt as u64);
-        },
-    );
+    let m_sent = metrics.with_labels(Direction::ComputeToClient);
+    let m_recv = metrics.with_labels(Direction::ClientToCompute);
+
+    let inspect = |direction, bytes: &[u8]| match direction {
+        Direction::ComputeToClient => {
+            metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
+            usage_tx.record_egress(bytes.len() as u64);
+        }
+        Direction::ClientToCompute => {
+            metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
+            usage_tx.record_ingress(bytes.len() as u64);
+        }
+    };

     // Starting from here we only proxy the client's traffic.
     debug!("performing the proxy pass...");
-    let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
-        &mut client,
-        &mut compute,
-    )
-    .await?;
-
-    Ok(())
+    // reduce branching internal to the hot path.
+    match &mut client {
+        Stream::Raw { raw } => copy_bidirectional_client_compute(raw, &mut compute, inspect).await,
+        Stream::Tls { tls, .. } => {
+            copy_bidirectional_client_compute(&mut *tls, &mut compute, inspect).await
+        }
+    }
 }

 pub(crate) struct ProxyPassthrough<S> {
@@ -2,7 +2,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll, ready};

-use anyhow::Context as _;
+use anyhow::{Context as _, anyhow};
 use bytes::{Buf, BufMut, Bytes, BytesMut};
 use framed_websockets::{Frame, OpCode, WebSocketServer};
 use futures::{Sink, Stream};
@@ -169,6 +169,9 @@ pub(crate) async fn serve_websocket(
     ctx.log_connect();
     match p.proxy_pass(&config.connect_to_compute).await {
         Ok(()) => Ok(()),
+        Err(ErrorSource::Timeout(_)) => Err(anyhow!(
+            "timed out while gracefully shutting down the connection"
+        )),
         Err(ErrorSource::Client(err)) => Err(err).context("client"),
         Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
     }