This commit is contained in:
Conrad Ludgate
2025-05-12 18:16:15 +01:00
parent c55742b437
commit b2e0ab5dc6
2 changed files with 77 additions and 170 deletions

View File

@@ -12,9 +12,12 @@ pub struct ConnectionTracking {
}
impl ConnectionTracking {
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Arc<Self>> {
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Conn> {
let conn_id = self.new_conn_id();
ConnectionTracker::new(conn_id, Arc::clone(self))
ConnectionTracker::new(Conn {
conn_id,
tracking: Arc::clone(self),
})
}
fn new_conn_id(&self) -> ConnId {
@@ -43,31 +46,28 @@ impl ConnectionTracking {
}
}
impl StateChangeObserver for Arc<ConnectionTracking> {
type ConnId = ConnId;
fn change(
&self,
conn_id: Self::ConnId,
_old_state: ConnectionState,
new_state: ConnectionState,
) {
pub struct Conn {
conn_id: ConnId,
tracking: Arc<ConnectionTracking>,
}
impl StateChangeObserver for Conn {
fn change(&mut self, _old_state: ConnectionState, new_state: ConnectionState) {
match new_state {
ConnectionState::Init
| ConnectionState::Idle
| ConnectionState::Transaction
| ConnectionState::Busy
| ConnectionState::Unknown => self.update(conn_id, new_state),
ConnectionState::Closed => self.remove(conn_id),
| ConnectionState::Unknown => self.tracking.update(self.conn_id, new_state),
ConnectionState::Closed => self.tracking.remove(self.conn_id),
}
}
}
/// Called by `ConnectionTracker` whenever the `ConnectionState` changed.
pub trait StateChangeObserver {
/// Identifier of the connection passed back on state change.
type ConnId: Copy;
/// Called iff the connection's state changed.
fn change(&self, conn_id: Self::ConnId, old_state: ConnectionState, new_state: ConnectionState);
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState);
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
@@ -101,20 +101,17 @@ impl fmt::Display for ConnectionState {
pub struct ConnectionTracker<SCO: StateChangeObserver> {
state: ConnectionState,
observer: SCO,
conn_id: SCO::ConnId,
}
impl<SCO: StateChangeObserver> Drop for ConnectionTracker<SCO> {
fn drop(&mut self) {
self.observer
.change(self.conn_id, self.state, ConnectionState::Closed);
self.observer.change(self.state, ConnectionState::Closed);
}
}
impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self {
pub fn new(observer: SCO) -> Self {
ConnectionTracker {
conn_id,
state: ConnectionState::default(),
observer,
}
@@ -132,7 +129,7 @@ impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
let old_state = self.state;
let new_state = new_state_fn(old_state);
if old_state != new_state {
self.observer.change(self.conn_id, old_state, new_state);
self.observer.change(old_state, new_state);
self.state = new_state;
}
}
@@ -358,114 +355,28 @@ mod tests {
use std::cell::RefCell;
use std::io;
use std::pin::Pin;
use std::pin::pin;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncReadExt, BufReader};
use super::*;
pub struct TrackedStream<S, TO> {
stream: S,
scanner: StreamScanner<TO>,
}
impl<S: Unpin, TO> Unpin for TrackedStream<S, TO> {}
impl<S: AsyncRead + AsyncWrite + Unpin, TO: TagObserver> TrackedStream<S, TO> {
pub const fn new(stream: S, midstream: bool, observer: TO) -> Self {
TrackedStream {
stream,
scanner: StreamScanner::new(midstream, observer),
}
}
}
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let Self { stream, scanner } = Pin::into_inner(self);
let StreamScanner { observer, state } = scanner;
let old_len = buf.filled().len();
match Pin::new(stream).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let new_len = buf.filled().len();
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
impl<S: AsyncWrite + Unpin, TO> AsyncWrite for TrackedStream<S, TO> {
#[inline(always)]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}
#[inline(always)]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
#[inline(always)]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}
#[derive(Debug)]
struct StreamScanner<TO> {
observer: TO,
state: StreamScannerState,
}
impl<TO: TagObserver> StreamScanner<TO> {
const fn new(midstream: bool, observer: TO) -> Self {
StreamScanner {
observer,
state: if midstream {
StreamScannerState::Tag
} else {
StreamScannerState::Start
},
}
}
}
impl<TO: TagObserver> StreamScanner<TO> {
fn scan_bytes(&mut self, buf: &[u8]) {
self.state.scan_bytes(buf, &mut self.observer);
}
}
#[test]
fn test_stream_scanner() {
let tags = Rc::new(RefCell::new(Vec::new()));
let observer_tags = tags.clone();
let observer = move |tag| {
let mut observer = move |tag| {
observer_tags.borrow_mut().push(tag);
};
let mut scanner = StreamScanner::new(false, observer);
let mut state = StreamScannerState::Start;
scanner.scan_bytes(&[0, 0]);
state.scan_bytes(&[0, 0], &mut observer);
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
state,
StreamScannerState::Length {
tag: Tag::Start,
length_bytes_missing: 2,
@@ -473,10 +384,10 @@ mod tests {
}
);
scanner.scan_bytes(&[0x01, 0x01, 0x00]);
state.scan_bytes(&[0x01, 0x01, 0x00], &mut observer);
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
state,
StreamScannerState::Payload {
tag: Tag::Start,
first: false,
@@ -484,10 +395,10 @@ mod tests {
}
);
scanner.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice());
state.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice(), &mut observer);
assert_eq!(tags.borrow().as_slice(), &[]);
assert_eq!(
scanner.state,
state,
StreamScannerState::Payload {
tag: Tag::Start,
first: false,
@@ -495,10 +406,10 @@ mod tests {
}
);
scanner.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08]);
state.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08], &mut observer);
assert_eq!(tags.borrow().as_slice(), &[Tag::Start]);
assert_eq!(
scanner.state,
state,
StreamScannerState::Payload {
tag: Tag::Message(b'A'),
first: true,
@@ -506,18 +417,18 @@ mod tests {
}
);
scanner.scan_bytes(&[0, 0, 0, 0]);
state.scan_bytes(&[0, 0, 0, 0], &mut observer);
assert_eq!(tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A')]);
assert_eq!(scanner.state, StreamScannerState::Tag);
assert_eq!(state, StreamScannerState::Tag);
scanner.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T']);
state.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T'], &mut observer);
assert_eq!(
tags.borrow().as_slice(),
&[Tag::Start, Tag::Message(b'A'), Tag::ReadyForQuery(b'T')]
);
assert_eq!(scanner.state, StreamScannerState::Tag);
assert_eq!(state, StreamScannerState::Tag);
scanner.scan_bytes(&[]);
state.scan_bytes(&[], &mut observer);
assert_eq!(
tags.borrow().as_slice(),
&[
@@ -527,7 +438,7 @@ mod tests {
Tag::End
]
);
assert_eq!(scanner.state, StreamScannerState::End);
assert_eq!(state, StreamScannerState::End);
}
#[tokio::test]
@@ -535,20 +446,13 @@ mod tests {
let transitions: Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>> = Arc::default();
struct Observer(Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>>);
impl StateChangeObserver for Observer {
type ConnId = usize;
fn change(
&self,
conn_id: Self::ConnId,
old_state: ConnectionState,
new_state: ConnectionState,
) {
assert_eq!(conn_id, 42);
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
self.0.lock().unwrap().push((old_state, new_state));
}
}
let mut tracker = ConnectionTracker::new(42, Observer(transitions.clone()));
let mut tracker = ConnectionTracker::new(Observer(transitions.clone()));
let stream = TestStream::new(
let stream = BufReader::new(
&[
0, 0, 0, 4, // Init
b'Z', 0, 0, 0, 5, b'I', // Init -> Idle
@@ -557,7 +461,7 @@ mod tests {
][..],
);
// AsyncRead
let mut stream = TrackedStream::new(stream, false, |tag| tracker.backend_message_tag(tag));
let mut stream = TrackedStream::new(stream, |tag| tracker.backend_message_tag(tag));
let mut readbuf = [0; 2];
let n = stream.read_exact(&mut readbuf).await.unwrap();
@@ -614,41 +518,47 @@ mod tests {
);
}
struct TestStream {
stream: BufReader<&'static [u8]>,
pub struct TrackedStream<S, TO> {
stream: S,
observer: TO,
state: StreamScannerState,
}
impl TestStream {
fn new(data: &'static [u8]) -> Self {
TestStream {
stream: BufReader::new(data),
impl<S: Unpin, TO> Unpin for TrackedStream<S, TO> {}
impl<S: AsyncRead + Unpin, TO: TagObserver> TrackedStream<S, TO> {
pub const fn new(stream: S, observer: TO) -> Self {
TrackedStream {
stream,
observer,
state: StreamScannerState::Start,
}
}
}
impl AsyncRead for TestStream {
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
pin!(&mut self.stream).poll_read(cx, buf)
}
}
impl AsyncWrite for TestStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
let Self {
stream,
observer,
state,
} = Pin::into_inner(self);
let old_len = buf.filled().len();
match Pin::new(stream).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let new_len = buf.filled().len();
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
}

View File

@@ -303,8 +303,6 @@ impl CopyBuffer {
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use tokio::io::AsyncWriteExt;
use crate::proxy::conntrack::ConnectionState;
@@ -312,12 +310,11 @@ mod tests {
use super::*;
#[derive(Default)]
struct Observer(Mutex<Vec<(ConnectionState, ConnectionState)>>);
struct Observer(Vec<(ConnectionState, ConnectionState)>);
impl StateChangeObserver for Observer {
type ConnId = ();
fn change(&self, (): Self::ConnId, old_state: ConnectionState, new_state: ConnectionState) {
self.0.lock().unwrap().push((old_state, new_state));
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
self.0.push((old_state, new_state));
}
}
@@ -332,7 +329,7 @@ mod tests {
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
let mut conn_tracker = ConnectionTracker::new(Observer::default());
let result = copy_bidirectional_client_compute(
&mut client_proxy,
@@ -361,7 +358,7 @@ mod tests {
.await
.unwrap();
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
let mut conn_tracker = ConnectionTracker::new(Observer::default());
let result = copy_bidirectional_client_compute(
&mut client_proxy,