diff --git a/proxy/src/proxy/conntrack.rs b/proxy/src/proxy/conntrack.rs index 12357b3b1f..7cf41559dc 100644 --- a/proxy/src/proxy/conntrack.rs +++ b/proxy/src/proxy/conntrack.rs @@ -12,9 +12,12 @@ pub struct ConnectionTracking { } impl ConnectionTracking { - pub fn new_tracker(self: &Arc) -> ConnectionTracker> { + pub fn new_tracker(self: &Arc) -> ConnectionTracker { let conn_id = self.new_conn_id(); - ConnectionTracker::new(conn_id, Arc::clone(self)) + ConnectionTracker::new(Conn { + conn_id, + tracking: Arc::clone(self), + }) } fn new_conn_id(&self) -> ConnId { @@ -43,31 +46,28 @@ impl ConnectionTracking { } } -impl StateChangeObserver for Arc { - type ConnId = ConnId; - fn change( - &self, - conn_id: Self::ConnId, - _old_state: ConnectionState, - new_state: ConnectionState, - ) { +pub struct Conn { + conn_id: ConnId, + tracking: Arc, +} + +impl StateChangeObserver for Conn { + fn change(&mut self, _old_state: ConnectionState, new_state: ConnectionState) { match new_state { ConnectionState::Init | ConnectionState::Idle | ConnectionState::Transaction | ConnectionState::Busy - | ConnectionState::Unknown => self.update(conn_id, new_state), - ConnectionState::Closed => self.remove(conn_id), + | ConnectionState::Unknown => self.tracking.update(self.conn_id, new_state), + ConnectionState::Closed => self.tracking.remove(self.conn_id), } } } /// Called by `ConnectionTracker` whenever the `ConnectionState` changed. pub trait StateChangeObserver { - /// Identifier of the connection passed back on state change. - type ConnId: Copy; /// Called iff the connection's state changed. - fn change(&self, conn_id: Self::ConnId, old_state: ConnectionState, new_state: ConnectionState); + fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState); } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] @@ -101,20 +101,17 @@ impl fmt::Display for ConnectionState { pub struct ConnectionTracker { state: ConnectionState, observer: SCO, - conn_id: SCO::ConnId, } impl Drop for ConnectionTracker { fn drop(&mut self) { - self.observer - .change(self.conn_id, self.state, ConnectionState::Closed); + self.observer.change(self.state, ConnectionState::Closed); } } impl ConnectionTracker { - pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self { + pub fn new(observer: SCO) -> Self { ConnectionTracker { - conn_id, state: ConnectionState::default(), observer, } @@ -132,7 +129,7 @@ impl ConnectionTracker { let old_state = self.state; let new_state = new_state_fn(old_state); if old_state != new_state { - self.observer.change(self.conn_id, old_state, new_state); + self.observer.change(old_state, new_state); self.state = new_state; } } @@ -358,114 +355,28 @@ mod tests { use std::cell::RefCell; use std::io; use std::pin::Pin; - use std::pin::pin; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncReadExt, BufReader}; use super::*; - pub struct TrackedStream { - stream: S, - scanner: StreamScanner, - } - - impl Unpin for TrackedStream {} - - impl TrackedStream { - pub const fn new(stream: S, midstream: bool, observer: TO) -> Self { - TrackedStream { - stream, - scanner: StreamScanner::new(midstream, observer), - } - } - } - - impl AsyncRead for TrackedStream { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let Self { stream, scanner } = Pin::into_inner(self); - let StreamScanner { observer, state } = scanner; - - let old_len = buf.filled().len(); - match Pin::new(stream).poll_read(cx, buf) { - Poll::Ready(Ok(())) => { - let new_len = buf.filled().len(); - state.scan_bytes(&buf.filled()[old_len..new_len], observer); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } - } - - impl AsyncWrite for TrackedStream { - #[inline(always)] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.stream).poll_write(cx, buf) - } - - #[inline(always)] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_flush(cx) - } - - #[inline(always)] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_shutdown(cx) - } - } - - #[derive(Debug)] - struct StreamScanner { - observer: TO, - state: StreamScannerState, - } - - impl StreamScanner { - const fn new(midstream: bool, observer: TO) -> Self { - StreamScanner { - observer, - state: if midstream { - StreamScannerState::Tag - } else { - StreamScannerState::Start - }, - } - } - } - - impl StreamScanner { - fn scan_bytes(&mut self, buf: &[u8]) { - self.state.scan_bytes(buf, &mut self.observer); - } - } - #[test] fn test_stream_scanner() { let tags = Rc::new(RefCell::new(Vec::new())); let observer_tags = tags.clone(); - let observer = move |tag| { + let mut observer = move |tag| { observer_tags.borrow_mut().push(tag); }; - let mut scanner = StreamScanner::new(false, observer); + let mut state = StreamScannerState::Start; - scanner.scan_bytes(&[0, 0]); + state.scan_bytes(&[0, 0], &mut observer); assert_eq!(tags.borrow().as_slice(), &[]); assert_eq!( - scanner.state, + state, StreamScannerState::Length { tag: Tag::Start, length_bytes_missing: 2, @@ -473,10 +384,10 @@ mod tests { } ); - scanner.scan_bytes(&[0x01, 0x01, 0x00]); + state.scan_bytes(&[0x01, 0x01, 0x00], &mut observer); assert_eq!(tags.borrow().as_slice(), &[]); assert_eq!( - scanner.state, + state, StreamScannerState::Payload { tag: Tag::Start, first: false, @@ -484,10 +395,10 @@ mod tests { } ); - scanner.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice()); + state.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice(), &mut observer); assert_eq!(tags.borrow().as_slice(), &[]); assert_eq!( - scanner.state, + state, StreamScannerState::Payload { tag: Tag::Start, first: false, @@ -495,10 +406,10 @@ mod tests { } ); - scanner.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08]); + state.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08], &mut observer); assert_eq!(tags.borrow().as_slice(), &[Tag::Start]); assert_eq!( - scanner.state, + state, StreamScannerState::Payload { tag: Tag::Message(b'A'), first: true, @@ -506,18 +417,18 @@ mod tests { } ); - scanner.scan_bytes(&[0, 0, 0, 0]); + state.scan_bytes(&[0, 0, 0, 0], &mut observer); assert_eq!(tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A')]); - assert_eq!(scanner.state, StreamScannerState::Tag); + assert_eq!(state, StreamScannerState::Tag); - scanner.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T']); + state.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T'], &mut observer); assert_eq!( tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A'), Tag::ReadyForQuery(b'T')] ); - assert_eq!(scanner.state, StreamScannerState::Tag); + assert_eq!(state, StreamScannerState::Tag); - scanner.scan_bytes(&[]); + state.scan_bytes(&[], &mut observer); assert_eq!( tags.borrow().as_slice(), &[ @@ -527,7 +438,7 @@ mod tests { Tag::End ] ); - assert_eq!(scanner.state, StreamScannerState::End); + assert_eq!(state, StreamScannerState::End); } #[tokio::test] @@ -535,20 +446,13 @@ mod tests { let transitions: Arc>> = Arc::default(); struct Observer(Arc>>); impl StateChangeObserver for Observer { - type ConnId = usize; - fn change( - &self, - conn_id: Self::ConnId, - old_state: ConnectionState, - new_state: ConnectionState, - ) { - assert_eq!(conn_id, 42); + fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) { self.0.lock().unwrap().push((old_state, new_state)); } } - let mut tracker = ConnectionTracker::new(42, Observer(transitions.clone())); + let mut tracker = ConnectionTracker::new(Observer(transitions.clone())); - let stream = TestStream::new( + let stream = BufReader::new( &[ 0, 0, 0, 4, // Init b'Z', 0, 0, 0, 5, b'I', // Init -> Idle @@ -557,7 +461,7 @@ mod tests { ][..], ); // AsyncRead - let mut stream = TrackedStream::new(stream, false, |tag| tracker.backend_message_tag(tag)); + let mut stream = TrackedStream::new(stream, |tag| tracker.backend_message_tag(tag)); let mut readbuf = [0; 2]; let n = stream.read_exact(&mut readbuf).await.unwrap(); @@ -614,41 +518,47 @@ mod tests { ); } - struct TestStream { - stream: BufReader<&'static [u8]>, + pub struct TrackedStream { + stream: S, + observer: TO, + state: StreamScannerState, } - impl TestStream { - fn new(data: &'static [u8]) -> Self { - TestStream { - stream: BufReader::new(data), + + impl Unpin for TrackedStream {} + + impl TrackedStream { + pub const fn new(stream: S, observer: TO) -> Self { + TrackedStream { + stream, + observer, + state: StreamScannerState::Start, } } } - impl AsyncRead for TestStream { + + impl AsyncRead for TrackedStream { + #[inline] fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - pin!(&mut self.stream).poll_read(cx, buf) - } - } - impl AsyncWrite for TestStream { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Poll::Ready(Ok(buf.len())) - } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) + let Self { + stream, + observer, + state, + } = Pin::into_inner(self); + + let old_len = buf.filled().len(); + match Pin::new(stream).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + let new_len = buf.filled().len(); + state.scan_bytes(&buf.filled()[old_len..new_len], observer); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } } } } diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index b85d1629eb..8a0c53ef54 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -303,8 +303,6 @@ impl CopyBuffer { #[cfg(test)] mod tests { - use std::sync::Mutex; - use tokio::io::AsyncWriteExt; use crate::proxy::conntrack::ConnectionState; @@ -312,12 +310,11 @@ mod tests { use super::*; #[derive(Default)] - struct Observer(Mutex>); + struct Observer(Vec<(ConnectionState, ConnectionState)>); impl StateChangeObserver for Observer { - type ConnId = (); - fn change(&self, (): Self::ConnId, old_state: ConnectionState, new_state: ConnectionState) { - self.0.lock().unwrap().push((old_state, new_state)); + fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) { + self.0.push((old_state, new_state)); } } @@ -332,7 +329,7 @@ mod tests { compute_client.write_all(b"Neon").await.unwrap(); compute_client.shutdown().await.unwrap(); - let mut conn_tracker = ConnectionTracker::new((), Observer::default()); + let mut conn_tracker = ConnectionTracker::new(Observer::default()); let result = copy_bidirectional_client_compute( &mut client_proxy, @@ -361,7 +358,7 @@ mod tests { .await .unwrap(); - let mut conn_tracker = ConnectionTracker::new((), Observer::default()); + let mut conn_tracker = ConnectionTracker::new(Observer::default()); let result = copy_bidirectional_client_compute( &mut client_proxy,