diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index aef5c9383e..de74f606b5 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -297,9 +297,11 @@ async fn handle_client( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await { - Ok(_) => Ok(()), - Err(ErrorSource::Client(err)) => Err(err).context("client"), - Err(ErrorSource::Compute(err)) => Err(err).context("compute"), - } + // FIXME: the proxy pass is disabled here, so this returns `Ok(())` without + // forwarding any traffic between `tls_stream` and `client`. + // `copy_bidirectional_client_compute` now also takes a `&mut ConnectionTracker`; + // presumably none is available in this binary yet (confirm). Restore the call + // and its "client"/"compute" error contexts before merging. + + Ok(()) } diff --git a/proxy/src/proxy/conntrack.rs b/proxy/src/proxy/conntrack.rs index e31d6287cb..12357b3b1f 100644 --- a/proxy/src/proxy/conntrack.rs +++ b/proxy/src/proxy/conntrack.rs @@ -1,12 +1,7 @@ -use std::pin::Pin; +use std::fmt; use std::sync::Arc; -use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; -use std::task::{Context, Poll}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::SystemTime; -use std::{fmt, io}; - -use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ConnId(usize); @@ -87,24 +82,6 @@ pub enum ConnectionState { Unknown = 5, } -impl ConnectionState { - const fn into_repr(self) -> u8 { - self as u8 - } - - const fn from_repr(value: u8) -> Option { - Some(match value { - 0 => Self::Init, - 1 => Self::Idle, - 2 => Self::Transaction, - 3 => Self::Busy, - 4 => Self::Closed, - 5 => Self::Unknown, - _ => return None, - }) - } -} - impl fmt::Display for ConnectionState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { @@ -118,26 +95,11 @@ impl fmt::Display for ConnectionState { } } -/// Stores the 
`ConnectionState`. Used by ConnectionTracker to avoid needing -/// mutable references. -#[derive(Debug, Default)] -struct AtomicConnectionState(AtomicU8); - -impl AtomicConnectionState { - fn set(&self, state: ConnectionState) { - self.0.store(state.into_repr(), Ordering::Relaxed); - } - - fn get(&self) -> ConnectionState { - ConnectionState::from_repr(self.0.load(Ordering::Relaxed)).expect("only valid variants") - } -} - /// Tracks the `ConnectionState` of a connection by inspecting the frontend and /// backend stream and reacting to specific messages. Used in combination with /// two `TrackedStream`s. pub struct ConnectionTracker { - state: AtomicConnectionState, + state: ConnectionState, observer: SCO, conn_id: SCO::ConnId, } @@ -145,7 +107,7 @@ pub struct ConnectionTracker { impl Drop for ConnectionTracker { fn drop(&mut self) { self.observer - .change(self.conn_id, self.state.get(), ConnectionState::Closed); + .change(self.conn_id, self.state, ConnectionState::Closed); } } @@ -153,25 +115,25 @@ impl ConnectionTracker { pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self { ConnectionTracker { conn_id, - state: AtomicConnectionState::default(), + state: ConnectionState::default(), observer, } } - pub fn frontend_message_tag(&self, tag: Tag) { + pub fn frontend_message_tag(&mut self, tag: Tag) { self.update_state(|old_state| Self::state_from_frontend_tag(old_state, tag)); } - pub fn backend_message_tag(&self, tag: Tag) { + pub fn backend_message_tag(&mut self, tag: Tag) { self.update_state(|old_state| Self::state_from_backend_tag(old_state, tag)); } - fn update_state(&self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) { - let old_state = self.state.get(); + fn update_state(&mut self, new_state_fn: impl FnOnce(ConnectionState) -> ConnectionState) { + let old_state = self.state; let new_state = new_state_fn(old_state); if old_state != new_state { self.observer.change(self.conn_id, old_state, new_state); - self.state.set(new_state); + 
self.state = new_state; } } @@ -242,73 +204,9 @@ impl TagObserver for F { } } -pin_project! { - pub struct TrackedStream { - #[pin] - stream: S, - scanner: StreamScanner, - } -} - -impl TrackedStream { - pub const fn new(stream: S, midstream: bool, observer: TO) -> Self { - TrackedStream { - stream, - scanner: StreamScanner::new(midstream, observer), - } - } -} - -impl AsyncRead for TrackedStream { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let this = self.project(); - let old_len = buf.filled().len(); - match this.stream.poll_read(cx, buf) { - Poll::Ready(Ok(())) => { - let new_len = buf.filled().len(); - this.scanner.scan_bytes(&buf.filled()[old_len..new_len]); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } -} - -impl AsyncWrite for TrackedStream { - #[inline(always)] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().stream.poll_write(cx, buf) - } - - #[inline(always)] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().stream.poll_flush(cx) - } - - #[inline(always)] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().stream.poll_shutdown(cx) - } -} - -#[derive(Debug)] -struct StreamScanner { - observer: TO, - state: StreamScannerState, -} - #[derive(Copy, Clone, Debug, PartialEq, Eq)] -enum StreamScannerState { +pub(super) enum StreamScannerState { + #[allow(dead_code)] /// Initial state when no message has been read and we are looling for a /// message without a tag. 
Start, @@ -339,36 +237,23 @@ enum StreamScannerState { Lost, } -impl StreamScanner { - const fn new(midstream: bool, observer: TO) -> Self { - StreamScanner { - observer, - state: if midstream { - StreamScannerState::Tag - } else { - StreamScannerState::Start - }, - } - } -} - -impl StreamScanner { - fn scan_bytes(&mut self, mut buf: &[u8]) { +impl StreamScannerState { + pub(super) fn scan_bytes(&mut self, mut buf: &[u8], observer: &mut TO) { use StreamScannerState as S; - if matches!(self.state, S::End | S::Lost) { + if matches!(*self, S::End | S::Lost) { return; } if buf.is_empty() { - match self.state { + match *self { S::Start | S::Tag => { - self.observer.observe(Tag::End); - self.state = S::End; + observer.observe(Tag::End); + *self = S::End; return; } S::Length { .. } | S::Payload { .. } => { - self.observer.observe(Tag::Lost); - self.state = S::Lost; + observer.observe(Tag::Lost); + *self = S::Lost; return; } S::End | S::Lost => unreachable!(), @@ -376,9 +261,9 @@ impl StreamScanner { } while !buf.is_empty() { - match self.state { + match *self { S::Start => { - self.state = S::Length { + *self = S::Length { tag: Tag::Start, length_bytes_missing: 4, calculated_length: 0, @@ -389,7 +274,7 @@ impl StreamScanner { let tag = buf.first().copied().expect("buf not empty"); buf = &buf[1..]; - self.state = S::Length { + *self = S::Length { tag: Tag::Message(tag), length_bytes_missing: 4, calculated_length: 0, @@ -413,23 +298,23 @@ impl StreamScanner { length_bytes_missing -= consume; if length_bytes_missing == 0 { let Some(bytes_to_skip) = calculated_length.checked_sub(4) else { - self.observer.observe(Tag::Lost); - self.state = S::Lost; + observer.observe(Tag::Lost); + *self = S::Lost; return; }; if bytes_to_skip == 0 { - self.observer.observe(tag); - self.state = S::Tag; + observer.observe(tag); + *self = S::Tag; } else { - self.state = S::Payload { + *self = S::Payload { tag, first: true, bytes_to_skip, }; } } else { - self.state = S::Length { + *self = S::Length 
{ tag, length_bytes_missing, calculated_length, @@ -447,13 +332,13 @@ impl StreamScanner { if bytes_to_skip == 0 { if tag == Tag::READY_FOR_QUERY && first && consume == 1 { let status = buf.first().copied().expect("buf not empty"); - self.observer.observe(Tag::ReadyForQuery(status)); + observer.observe(Tag::ReadyForQuery(status)); } else { - self.observer.observe(tag); + observer.observe(tag); } - self.state = S::Tag; + *self = S::Tag; } else { - self.state = S::Payload { + *self = S::Payload { tag, first: false, bytes_to_skip, @@ -471,14 +356,103 @@ impl StreamScanner { #[cfg(test)] mod tests { use std::cell::RefCell; + use std::io; + use std::pin::Pin; use std::pin::pin; use std::rc::Rc; use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncReadExt, BufReader}; use super::*; + pub struct TrackedStream { + stream: S, + scanner: StreamScanner, + } + + impl Unpin for TrackedStream {} + + impl TrackedStream { + pub const fn new(stream: S, midstream: bool, observer: TO) -> Self { + TrackedStream { + stream, + scanner: StreamScanner::new(midstream, observer), + } + } + } + + impl AsyncRead for TrackedStream { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let Self { stream, scanner } = Pin::into_inner(self); + let StreamScanner { observer, state } = scanner; + + let old_len = buf.filled().len(); + match Pin::new(stream).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + let new_len = buf.filled().len(); + state.scan_bytes(&buf.filled()[old_len..new_len], observer); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + } + + impl AsyncWrite for TrackedStream { + #[inline(always)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + #[inline(always)] + fn 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + #[inline(always)] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } + } + + #[derive(Debug)] + struct StreamScanner { + observer: TO, + state: StreamScannerState, + } + + impl StreamScanner { + const fn new(midstream: bool, observer: TO) -> Self { + StreamScanner { + observer, + state: if midstream { + StreamScannerState::Tag + } else { + StreamScannerState::Start + }, + } + } + } + + impl StreamScanner { + fn scan_bytes(&mut self, buf: &[u8]) { + self.state.scan_bytes(buf, &mut self.observer); + } + } + #[test] fn test_stream_scanner() { let tags = Rc::new(RefCell::new(Vec::new())); @@ -572,7 +546,7 @@ mod tests { self.0.lock().unwrap().push((old_state, new_state)); } } - let tracker = ConnectionTracker::new(42, Observer(transitions.clone())); + let mut tracker = ConnectionTracker::new(42, Observer(transitions.clone())); let stream = TestStream::new( &[ diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 97f8d7c6af..b85d1629eb 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -6,9 +6,11 @@ use std::task::{Context, Poll, ready}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::info; +use super::conntrack::{ConnectionTracker, StateChangeObserver, StreamScannerState, TagObserver}; + #[derive(Debug)] enum TransferState { - Running(CopyBuffer), + Running(CopyBuffer, StreamScannerState), ShuttingDown(u64), Done(u64), } @@ -43,6 +45,7 @@ pub enum ErrorSource { fn transfer_one_direction( cx: &mut Context<'_>, state: &mut TransferState, + mut observer: impl TagObserver, r: &mut A, w: &mut B, ) -> Poll> @@ -54,8 +57,9 @@ where let mut w = Pin::new(w); loop { match state { - TransferState::Running(buf) => { - let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + 
TransferState::Running(buf, stream_state) => { + let count = + ready!(buf.poll_copy(cx, stream_state, &mut observer, r.as_mut(), w.as_mut()))?; *state = TransferState::ShuttingDown(count); } TransferState::ShuttingDown(count) => { @@ -71,45 +75,66 @@ where pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, + conn_tracker: &mut ConnectionTracker, ) -> Result<(u64, u64), ErrorSource> where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client_to_compute = TransferState::Running(CopyBuffer::new()); - let mut compute_to_client = TransferState::Running(CopyBuffer::new()); + let mut client_to_compute = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag); + let mut compute_to_client = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag); poll_fn(|cx| { - let mut client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute) - .map_err(ErrorSource::from_client)?; - let mut compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client) - .map_err(ErrorSource::from_compute)?; + let mut client_to_compute_result = transfer_one_direction( + cx, + &mut client_to_compute, + |tag| conn_tracker.frontend_message_tag(tag), + client, + compute, + ) + .map_err(ErrorSource::from_client)?; + let mut compute_to_client_result = transfer_one_direction( + cx, + &mut compute_to_client, + |tag| conn_tracker.backend_message_tag(tag), + compute, + client, + ) + .map_err(ErrorSource::from_compute)?; // TODO: 1 info log, with a enum label for close direction. // Early termination checks from compute to client. 
if let TransferState::Done(_) = compute_to_client { - if let TransferState::Running(buf) = &client_to_compute { + if let TransferState::Running(buf, _) = &client_to_compute { info!("Compute is done, terminate client"); // Initiate shutdown client_to_compute = TransferState::ShuttingDown(buf.amt); - client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute) - .map_err(ErrorSource::from_client)?; + client_to_compute_result = transfer_one_direction( + cx, + &mut client_to_compute, + |tag| conn_tracker.frontend_message_tag(tag), + client, + compute, + ) + .map_err(ErrorSource::from_client)?; } } // Early termination checks from client to compute. if let TransferState::Done(_) = client_to_compute { - if let TransferState::Running(buf) = &compute_to_client { + if let TransferState::Running(buf, _) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown compute_to_client = TransferState::ShuttingDown(buf.amt); - compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client) - .map_err(ErrorSource::from_compute)?; + compute_to_client_result = transfer_one_direction( + cx, + &mut compute_to_client, + |tag| conn_tracker.backend_message_tag(tag), + compute, + client, + ) + .map_err(ErrorSource::from_compute)?; } } @@ -148,6 +173,8 @@ impl CopyBuffer { fn poll_fill_buf( &mut self, cx: &mut Context<'_>, + state: &mut StreamScannerState, + observer: &mut impl TagObserver, reader: Pin<&mut R>, ) -> Poll> where @@ -158,6 +185,8 @@ impl CopyBuffer { buf.set_filled(me.cap); let res = reader.poll_read(cx, &mut buf); if let Poll::Ready(Ok(())) = res { + // Only scan completed reads; an empty slice on Pending/Err would look like EOF. + state.scan_bytes(&buf.filled()[me.cap..], observer); let filled_len = buf.filled().len(); me.read_done = me.cap == filled_len; @@ -169,6 +198,8 @@ impl CopyBuffer { fn poll_write_buf( &mut self, cx: &mut Context<'_>, + state: &mut StreamScannerState, + observer: &mut impl TagObserver, mut reader: Pin<&mut R>, mut writer: Pin<&mut 
W>, ) -> Poll> @@ -182,7 +213,8 @@ impl CopyBuffer { // Top up the buffer towards full if we can read a bit more // data - this should improve the chances of a large write if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?; + ready!(me.poll_fill_buf(cx, state, observer, reader.as_mut())) + .map_err(ErrorDirection::Read)?; } Poll::Pending } @@ -193,6 +225,8 @@ impl CopyBuffer { pub(super) fn poll_copy( &mut self, cx: &mut Context<'_>, + state: &mut StreamScannerState, + observer: &mut impl TagObserver, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> @@ -204,7 +238,7 @@ impl CopyBuffer { // If there is some space left in our buffer, then we try to read some // data to continue, thus maximizing the chances of a large write. if self.cap < self.buf.len() && !self.read_done { - match self.poll_fill_buf(cx, reader.as_mut()) { + match self.poll_fill_buf(cx, state, observer, reader.as_mut()) { Poll::Ready(Ok(())) => (), Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))), Poll::Pending => { @@ -227,7 +261,13 @@ impl CopyBuffer { // If our buffer has some data, let's write it out! 
while self.pos < self.cap { - let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + let i = ready!(self.poll_write_buf( + cx, + state, + observer, + reader.as_mut(), + writer.as_mut() + ))?; if i == 0 { return Poll::Ready(Err(ErrorDirection::Write(io::Error::new( io::ErrorKind::WriteZero, @@ -263,10 +303,24 @@ impl CopyBuffer { #[cfg(test)] mod tests { + use std::sync::Mutex; + use tokio::io::AsyncWriteExt; + use crate::proxy::conntrack::ConnectionState; + use super::*; + #[derive(Default)] + struct Observer(Mutex>); + + impl StateChangeObserver for Observer { + type ConnId = (); + fn change(&self, (): Self::ConnId, old_state: ConnectionState, new_state: ConnectionState) { + self.0.lock().unwrap().push((old_state, new_state)); + } + } + #[tokio::test] async fn test_client_to_compute() { let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream @@ -278,9 +332,15 @@ mod tests { compute_client.write_all(b"Neon").await.unwrap(); compute_client.shutdown().await.unwrap(); - let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) - .await - .unwrap(); + let mut conn_tracker = ConnectionTracker::new((), Observer::default()); + + let result = copy_bidirectional_client_compute( + &mut client_proxy, + &mut compute_proxy, + &mut conn_tracker, + ) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; @@ -301,9 +361,15 @@ mod tests { .await .unwrap(); - let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) - .await - .unwrap(); + let mut conn_tracker = ConnectionTracker::new((), Observer::default()); + + let result = copy_bidirectional_client_compute( + &mut client_proxy, + &mut compute_proxy, + &mut conn_tracker, + ) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; diff --git a/proxy/src/proxy/passthrough.rs 
b/proxy/src/proxy/passthrough.rs index 73b7343831..b0e75c79bc 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -11,7 +11,7 @@ use crate::compute::PostgresConnection; use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; -use crate::proxy::conntrack::{ConnectionTracking, TrackedStream}; +use crate::proxy::conntrack::ConnectionTracking; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; @@ -31,11 +31,11 @@ pub(crate) async fn proxy_pass( private_link_id, }); - let conn_tracker = conntracking.new_tracker(); + let mut conn_tracker = conntracking.new_tracker(); let metrics = &Metrics::get().proxy.io_bytes; let m_sent = metrics.with_labels(Direction::Tx); - let client = MeasuredStream::new( + let mut client = MeasuredStream::new( client, |_| {}, |cnt| { @@ -44,10 +44,9 @@ pub(crate) async fn proxy_pass( usage_tx.record_egress(cnt as u64); }, ); - let mut client = TrackedStream::new(client, true, |tag| conn_tracker.frontend_message_tag(tag)); let m_recv = metrics.with_labels(Direction::Rx); - let compute = MeasuredStream::new( + let mut compute = MeasuredStream::new( compute, |_| {}, |cnt| { @@ -56,14 +55,13 @@ pub(crate) async fn proxy_pass( usage_tx.record_ingress(cnt as u64); }, ); - let mut compute = - TrackedStream::new(compute, true, |tag| conn_tracker.backend_message_tag(tag)); // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( &mut client, &mut compute, + &mut conn_tracker, ) .await?;