mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-02 13:00:37 +00:00
embed into conn tracking into copy_bidirectional
This commit is contained in:
@@ -297,9 +297,11 @@ async fn handle_client(
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
// match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
// Ok(_) => Ok(()),
|
||||
// Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
// Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
use std::pin::Pin;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::SystemTime;
|
||||
use std::{fmt, io};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct ConnId(usize);
|
||||
@@ -87,24 +82,6 @@ pub enum ConnectionState {
|
||||
Unknown = 5,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
const fn into_repr(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
const fn from_repr(value: u8) -> Option<Self> {
|
||||
Some(match value {
|
||||
0 => Self::Init,
|
||||
1 => Self::Idle,
|
||||
2 => Self::Transaction,
|
||||
3 => Self::Busy,
|
||||
4 => Self::Closed,
|
||||
5 => Self::Unknown,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionState {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
@@ -118,26 +95,11 @@ impl fmt::Display for ConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the `ConnectionState`. Used by ConnectionTracker to avoid needing
|
||||
/// mutable references.
|
||||
#[derive(Debug, Default)]
|
||||
struct AtomicConnectionState(AtomicU8);
|
||||
|
||||
impl AtomicConnectionState {
|
||||
fn set(&self, state: ConnectionState) {
|
||||
self.0.store(state.into_repr(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn get(&self) -> ConnectionState {
|
||||
ConnectionState::from_repr(self.0.load(Ordering::Relaxed)).expect("only valid variants")
|
||||
}
|
||||
}
|
||||
|
||||
/// Tracks the `ConnectionState` of a connection by inspecting the frontend and
|
||||
/// backend stream and reacting to specific messages. Used in combination with
|
||||
/// two `TrackedStream`s.
|
||||
pub struct ConnectionTracker<SCO: StateChangeObserver> {
|
||||
state: AtomicConnectionState,
|
||||
state: ConnectionState,
|
||||
observer: SCO,
|
||||
conn_id: SCO::ConnId,
|
||||
}
|
||||
@@ -145,7 +107,7 @@ pub struct ConnectionTracker<SCO: StateChangeObserver> {
|
||||
impl<SCO: StateChangeObserver> Drop for ConnectionTracker<SCO> {
|
||||
fn drop(&mut self) {
|
||||
self.observer
|
||||
.change(self.conn_id, self.state.get(), ConnectionState::Closed);
|
||||
.change(self.conn_id, self.state, ConnectionState::Closed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,25 +115,25 @@ impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
|
||||
pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self {
|
||||
ConnectionTracker {
|
||||
conn_id,
|
||||
state: AtomicConnectionState::default(),
|
||||
state: ConnectionState::default(),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn frontend_message_tag(&self, tag: Tag) {
|
||||
pub fn frontend_message_tag(&mut self, tag: Tag) {
|
||||
self.update_state(|old_state| Self::state_from_frontend_tag(old_state, tag));
|
||||
}
|
||||
|
||||
pub fn backend_message_tag(&self, tag: Tag) {
|
||||
pub fn backend_message_tag(&mut self, tag: Tag) {
|
||||
self.update_state(|old_state| Self::state_from_backend_tag(old_state, tag));
|
||||
}
|
||||
|
||||
fn update_state(&self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) {
|
||||
let old_state = self.state.get();
|
||||
fn update_state(&mut self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) {
|
||||
let old_state = self.state;
|
||||
let new_state = new_state_fn(old_state);
|
||||
if old_state != new_state {
|
||||
self.observer.change(self.conn_id, old_state, new_state);
|
||||
self.state.set(new_state);
|
||||
self.state = new_state;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,73 +204,9 @@ impl<F: FnMut(Tag)> TagObserver for F {
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct TrackedStream<S, TO> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
scanner: StreamScanner<TO>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, TO: TagObserver> TrackedStream<S, TO> {
|
||||
pub const fn new(stream: S, midstream: bool, observer: TO) -> Self {
|
||||
TrackedStream {
|
||||
stream,
|
||||
scanner: StreamScanner::new(midstream, observer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
let old_len = buf.filled().len();
|
||||
match this.stream.poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let new_len = buf.filled().len();
|
||||
this.scanner.scan_bytes(&buf.filled()[old_len..new_len]);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, TO> AsyncWrite for TrackedStream<S, TO> {
|
||||
#[inline(always)]
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().stream.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().stream.poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().stream.poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StreamScanner<TO> {
|
||||
observer: TO,
|
||||
state: StreamScannerState,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
enum StreamScannerState {
|
||||
pub(super) enum StreamScannerState {
|
||||
#[allow(dead_code)]
|
||||
/// Initial state when no message has been read and we are looling for a
|
||||
/// message without a tag.
|
||||
Start,
|
||||
@@ -339,36 +237,23 @@ enum StreamScannerState {
|
||||
Lost,
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
const fn new(midstream: bool, observer: TO) -> Self {
|
||||
StreamScanner {
|
||||
observer,
|
||||
state: if midstream {
|
||||
StreamScannerState::Tag
|
||||
} else {
|
||||
StreamScannerState::Start
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
fn scan_bytes(&mut self, mut buf: &[u8]) {
|
||||
impl StreamScannerState {
|
||||
pub(super) fn scan_bytes<TO: TagObserver>(&mut self, mut buf: &[u8], observer: &mut TO) {
|
||||
use StreamScannerState as S;
|
||||
|
||||
if matches!(self.state, S::End | S::Lost) {
|
||||
if matches!(*self, S::End | S::Lost) {
|
||||
return;
|
||||
}
|
||||
if buf.is_empty() {
|
||||
match self.state {
|
||||
match *self {
|
||||
S::Start | S::Tag => {
|
||||
self.observer.observe(Tag::End);
|
||||
self.state = S::End;
|
||||
observer.observe(Tag::End);
|
||||
*self = S::End;
|
||||
return;
|
||||
}
|
||||
S::Length { .. } | S::Payload { .. } => {
|
||||
self.observer.observe(Tag::Lost);
|
||||
self.state = S::Lost;
|
||||
observer.observe(Tag::Lost);
|
||||
*self = S::Lost;
|
||||
return;
|
||||
}
|
||||
S::End | S::Lost => unreachable!(),
|
||||
@@ -376,9 +261,9 @@ impl<TO: TagObserver> StreamScanner<TO> {
|
||||
}
|
||||
|
||||
while !buf.is_empty() {
|
||||
match self.state {
|
||||
match *self {
|
||||
S::Start => {
|
||||
self.state = S::Length {
|
||||
*self = S::Length {
|
||||
tag: Tag::Start,
|
||||
length_bytes_missing: 4,
|
||||
calculated_length: 0,
|
||||
@@ -389,7 +274,7 @@ impl<TO: TagObserver> StreamScanner<TO> {
|
||||
let tag = buf.first().copied().expect("buf not empty");
|
||||
buf = &buf[1..];
|
||||
|
||||
self.state = S::Length {
|
||||
*self = S::Length {
|
||||
tag: Tag::Message(tag),
|
||||
length_bytes_missing: 4,
|
||||
calculated_length: 0,
|
||||
@@ -413,23 +298,23 @@ impl<TO: TagObserver> StreamScanner<TO> {
|
||||
length_bytes_missing -= consume;
|
||||
if length_bytes_missing == 0 {
|
||||
let Some(bytes_to_skip) = calculated_length.checked_sub(4) else {
|
||||
self.observer.observe(Tag::Lost);
|
||||
self.state = S::Lost;
|
||||
observer.observe(Tag::Lost);
|
||||
*self = S::Lost;
|
||||
return;
|
||||
};
|
||||
|
||||
if bytes_to_skip == 0 {
|
||||
self.observer.observe(tag);
|
||||
self.state = S::Tag;
|
||||
observer.observe(tag);
|
||||
*self = S::Tag;
|
||||
} else {
|
||||
self.state = S::Payload {
|
||||
*self = S::Payload {
|
||||
tag,
|
||||
first: true,
|
||||
bytes_to_skip,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
self.state = S::Length {
|
||||
*self = S::Length {
|
||||
tag,
|
||||
length_bytes_missing,
|
||||
calculated_length,
|
||||
@@ -447,13 +332,13 @@ impl<TO: TagObserver> StreamScanner<TO> {
|
||||
if bytes_to_skip == 0 {
|
||||
if tag == Tag::READY_FOR_QUERY && first && consume == 1 {
|
||||
let status = buf.first().copied().expect("buf not empty");
|
||||
self.observer.observe(Tag::ReadyForQuery(status));
|
||||
observer.observe(Tag::ReadyForQuery(status));
|
||||
} else {
|
||||
self.observer.observe(tag);
|
||||
observer.observe(tag);
|
||||
}
|
||||
self.state = S::Tag;
|
||||
*self = S::Tag;
|
||||
} else {
|
||||
self.state = S::Payload {
|
||||
*self = S::Payload {
|
||||
tag,
|
||||
first: false,
|
||||
bytes_to_skip,
|
||||
@@ -471,14 +356,103 @@ impl<TO: TagObserver> StreamScanner<TO> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::cell::RefCell;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::pin::pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct TrackedStream<S, TO> {
|
||||
stream: S,
|
||||
scanner: StreamScanner<TO>,
|
||||
}
|
||||
|
||||
impl<S: Unpin, TO> Unpin for TrackedStream<S, TO> {}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, TO: TagObserver> TrackedStream<S, TO> {
|
||||
pub const fn new(stream: S, midstream: bool, observer: TO) -> Self {
|
||||
TrackedStream {
|
||||
stream,
|
||||
scanner: StreamScanner::new(midstream, observer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let Self { stream, scanner } = Pin::into_inner(self);
|
||||
let StreamScanner { observer, state } = scanner;
|
||||
|
||||
let old_len = buf.filled().len();
|
||||
match Pin::new(stream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let new_len = buf.filled().len();
|
||||
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, TO> AsyncWrite for TrackedStream<S, TO> {
|
||||
#[inline(always)]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.stream).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.stream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StreamScanner<TO> {
|
||||
observer: TO,
|
||||
state: StreamScannerState,
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
const fn new(midstream: bool, observer: TO) -> Self {
|
||||
StreamScanner {
|
||||
observer,
|
||||
state: if midstream {
|
||||
StreamScannerState::Tag
|
||||
} else {
|
||||
StreamScannerState::Start
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
fn scan_bytes(&mut self, buf: &[u8]) {
|
||||
self.state.scan_bytes(buf, &mut self.observer);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_scanner() {
|
||||
let tags = Rc::new(RefCell::new(Vec::new()));
|
||||
@@ -572,7 +546,7 @@ mod tests {
|
||||
self.0.lock().unwrap().push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
let tracker = ConnectionTracker::new(42, Observer(transitions.clone()));
|
||||
let mut tracker = ConnectionTracker::new(42, Observer(transitions.clone()));
|
||||
|
||||
let stream = TestStream::new(
|
||||
&[
|
||||
|
||||
@@ -6,9 +6,11 @@ use std::task::{Context, Poll, ready};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::info;
|
||||
|
||||
use super::conntrack::{ConnectionTracker, StateChangeObserver, StreamScannerState, TagObserver};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransferState {
|
||||
Running(CopyBuffer),
|
||||
Running(CopyBuffer, StreamScannerState),
|
||||
ShuttingDown(u64),
|
||||
Done(u64),
|
||||
}
|
||||
@@ -43,6 +45,7 @@ pub enum ErrorSource {
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
mut observer: impl TagObserver,
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -54,8 +57,9 @@ where
|
||||
let mut w = Pin::new(w);
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf) => {
|
||||
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||
TransferState::Running(buf, stream_state) => {
|
||||
let count =
|
||||
ready!(buf.poll_copy(cx, stream_state, &mut observer, r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
@@ -71,45 +75,66 @@ where
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
conn_tracker: &mut ConnectionTracker<impl StateChangeObserver>,
|
||||
) -> Result<(u64, u64), ErrorSource>
|
||||
where
|
||||
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
let mut client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
|tag| conn_tracker.frontend_message_tag(tag),
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
|tag| conn_tracker.backend_message_tag(tag),
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
|
||||
// TODO: 1 info log, with a enum label for close direction.
|
||||
|
||||
// Early termination checks from compute to client.
|
||||
if let TransferState::Done(_) = compute_to_client {
|
||||
if let TransferState::Running(buf) = &client_to_compute {
|
||||
if let TransferState::Running(buf, _) = &client_to_compute {
|
||||
info!("Compute is done, terminate client");
|
||||
// Initiate shutdown
|
||||
client_to_compute = TransferState::ShuttingDown(buf.amt);
|
||||
client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
|tag| conn_tracker.frontend_message_tag(tag),
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Early termination checks from client to compute.
|
||||
if let TransferState::Done(_) = client_to_compute {
|
||||
if let TransferState::Running(buf) = &compute_to_client {
|
||||
if let TransferState::Running(buf, _) = &compute_to_client {
|
||||
info!("Client is done, terminate compute");
|
||||
// Initiate shutdown
|
||||
compute_to_client = TransferState::ShuttingDown(buf.amt);
|
||||
compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
|tag| conn_tracker.backend_message_tag(tag),
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +173,8 @@ impl CopyBuffer {
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
@@ -158,6 +185,8 @@ impl CopyBuffer {
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
state.scan_bytes(&buf.filled()[me.cap..], observer);
|
||||
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
me.read_done = me.cap == filled_len;
|
||||
@@ -169,6 +198,8 @@ impl CopyBuffer {
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<usize, ErrorDirection>>
|
||||
@@ -182,7 +213,8 @@ impl CopyBuffer {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
|
||||
ready!(me.poll_fill_buf(cx, state, observer, reader.as_mut()))
|
||||
.map_err(ErrorDirection::Read)?;
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
@@ -193,6 +225,8 @@ impl CopyBuffer {
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -204,7 +238,7 @@ impl CopyBuffer {
|
||||
// If there is some space left in our buffer, then we try to read some
|
||||
// data to continue, thus maximizing the chances of a large write.
|
||||
if self.cap < self.buf.len() && !self.read_done {
|
||||
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||
match self.poll_fill_buf(cx, state, observer, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
|
||||
Poll::Pending => {
|
||||
@@ -227,7 +261,13 @@ impl CopyBuffer {
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||
let i = ready!(self.poll_write_buf(
|
||||
cx,
|
||||
state,
|
||||
observer,
|
||||
reader.as_mut(),
|
||||
writer.as_mut()
|
||||
))?;
|
||||
if i == 0 {
|
||||
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
@@ -263,10 +303,24 @@ impl CopyBuffer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Mutex;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::proxy::conntrack::ConnectionState;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Default)]
|
||||
struct Observer(Mutex<Vec<(ConnectionState, ConnectionState)>>);
|
||||
|
||||
impl StateChangeObserver for Observer {
|
||||
type ConnId = ();
|
||||
fn change(&self, (): Self::ConnId, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.lock().unwrap().push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_to_compute() {
|
||||
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
@@ -278,9 +332,15 @@ mod tests {
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
&mut compute_proxy,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
@@ -301,9 +361,15 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
&mut compute_proxy,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::proxy::conntrack::{ConnectionTracking, TrackedStream};
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -31,11 +31,11 @@ pub(crate) async fn proxy_pass(
|
||||
private_link_id,
|
||||
});
|
||||
|
||||
let conn_tracker = conntracking.new_tracker();
|
||||
let mut conn_tracker = conntracking.new_tracker();
|
||||
|
||||
let metrics = &Metrics::get().proxy.io_bytes;
|
||||
let m_sent = metrics.with_labels(Direction::Tx);
|
||||
let client = MeasuredStream::new(
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
@@ -44,10 +44,9 @@ pub(crate) async fn proxy_pass(
|
||||
usage_tx.record_egress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let mut client = TrackedStream::new(client, true, |tag| conn_tracker.frontend_message_tag(tag));
|
||||
|
||||
let m_recv = metrics.with_labels(Direction::Rx);
|
||||
let compute = MeasuredStream::new(
|
||||
let mut compute = MeasuredStream::new(
|
||||
compute,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
@@ -56,14 +55,13 @@ pub(crate) async fn proxy_pass(
|
||||
usage_tx.record_ingress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let mut compute =
|
||||
TrackedStream::new(compute, true, |tag| conn_tracker.backend_message_tag(tag));
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
debug!("performing the proxy pass...");
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
|
||||
&mut client,
|
||||
&mut compute,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user