mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 17:02:56 +00:00
cleanup
This commit is contained in:
@@ -12,9 +12,12 @@ pub struct ConnectionTracking {
|
||||
}
|
||||
|
||||
impl ConnectionTracking {
|
||||
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Arc<Self>> {
|
||||
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Conn> {
|
||||
let conn_id = self.new_conn_id();
|
||||
ConnectionTracker::new(conn_id, Arc::clone(self))
|
||||
ConnectionTracker::new(Conn {
|
||||
conn_id,
|
||||
tracking: Arc::clone(self),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_conn_id(&self) -> ConnId {
|
||||
@@ -43,31 +46,28 @@ impl ConnectionTracking {
|
||||
}
|
||||
}
|
||||
|
||||
impl StateChangeObserver for Arc<ConnectionTracking> {
|
||||
type ConnId = ConnId;
|
||||
fn change(
|
||||
&self,
|
||||
conn_id: Self::ConnId,
|
||||
_old_state: ConnectionState,
|
||||
new_state: ConnectionState,
|
||||
) {
|
||||
pub struct Conn {
|
||||
conn_id: ConnId,
|
||||
tracking: Arc<ConnectionTracking>,
|
||||
}
|
||||
|
||||
impl StateChangeObserver for Conn {
|
||||
fn change(&mut self, _old_state: ConnectionState, new_state: ConnectionState) {
|
||||
match new_state {
|
||||
ConnectionState::Init
|
||||
| ConnectionState::Idle
|
||||
| ConnectionState::Transaction
|
||||
| ConnectionState::Busy
|
||||
| ConnectionState::Unknown => self.update(conn_id, new_state),
|
||||
ConnectionState::Closed => self.remove(conn_id),
|
||||
| ConnectionState::Unknown => self.tracking.update(self.conn_id, new_state),
|
||||
ConnectionState::Closed => self.tracking.remove(self.conn_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Called by `ConnectionTracker` whenever the `ConnectionState` changed.
|
||||
pub trait StateChangeObserver {
|
||||
/// Identifier of the connection passed back on state change.
|
||||
type ConnId: Copy;
|
||||
/// Called iff the connection's state changed.
|
||||
fn change(&self, conn_id: Self::ConnId, old_state: ConnectionState, new_state: ConnectionState);
|
||||
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
|
||||
@@ -101,20 +101,17 @@ impl fmt::Display for ConnectionState {
|
||||
pub struct ConnectionTracker<SCO: StateChangeObserver> {
|
||||
state: ConnectionState,
|
||||
observer: SCO,
|
||||
conn_id: SCO::ConnId,
|
||||
}
|
||||
|
||||
impl<SCO: StateChangeObserver> Drop for ConnectionTracker<SCO> {
|
||||
fn drop(&mut self) {
|
||||
self.observer
|
||||
.change(self.conn_id, self.state, ConnectionState::Closed);
|
||||
self.observer.change(self.state, ConnectionState::Closed);
|
||||
}
|
||||
}
|
||||
|
||||
impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
|
||||
pub fn new(conn_id: SCO::ConnId, observer: SCO) -> Self {
|
||||
pub fn new(observer: SCO) -> Self {
|
||||
ConnectionTracker {
|
||||
conn_id,
|
||||
state: ConnectionState::default(),
|
||||
observer,
|
||||
}
|
||||
@@ -132,7 +129,7 @@ impl<SCO: StateChangeObserver> ConnectionTracker<SCO> {
|
||||
let old_state = self.state;
|
||||
let new_state = new_state_fn(old_state);
|
||||
if old_state != new_state {
|
||||
self.observer.change(self.conn_id, old_state, new_state);
|
||||
self.observer.change(old_state, new_state);
|
||||
self.state = new_state;
|
||||
}
|
||||
}
|
||||
@@ -358,114 +355,28 @@ mod tests {
|
||||
use std::cell::RefCell;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::pin::pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct TrackedStream<S, TO> {
|
||||
stream: S,
|
||||
scanner: StreamScanner<TO>,
|
||||
}
|
||||
|
||||
impl<S: Unpin, TO> Unpin for TrackedStream<S, TO> {}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, TO: TagObserver> TrackedStream<S, TO> {
|
||||
pub const fn new(stream: S, midstream: bool, observer: TO) -> Self {
|
||||
TrackedStream {
|
||||
stream,
|
||||
scanner: StreamScanner::new(midstream, observer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let Self { stream, scanner } = Pin::into_inner(self);
|
||||
let StreamScanner { observer, state } = scanner;
|
||||
|
||||
let old_len = buf.filled().len();
|
||||
match Pin::new(stream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let new_len = buf.filled().len();
|
||||
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, TO> AsyncWrite for TrackedStream<S, TO> {
|
||||
#[inline(always)]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.stream).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.stream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StreamScanner<TO> {
|
||||
observer: TO,
|
||||
state: StreamScannerState,
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
const fn new(midstream: bool, observer: TO) -> Self {
|
||||
StreamScanner {
|
||||
observer,
|
||||
state: if midstream {
|
||||
StreamScannerState::Tag
|
||||
} else {
|
||||
StreamScannerState::Start
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TO: TagObserver> StreamScanner<TO> {
|
||||
fn scan_bytes(&mut self, buf: &[u8]) {
|
||||
self.state.scan_bytes(buf, &mut self.observer);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_scanner() {
|
||||
let tags = Rc::new(RefCell::new(Vec::new()));
|
||||
let observer_tags = tags.clone();
|
||||
let observer = move |tag| {
|
||||
let mut observer = move |tag| {
|
||||
observer_tags.borrow_mut().push(tag);
|
||||
};
|
||||
let mut scanner = StreamScanner::new(false, observer);
|
||||
let mut state = StreamScannerState::Start;
|
||||
|
||||
scanner.scan_bytes(&[0, 0]);
|
||||
state.scan_bytes(&[0, 0], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
scanner.state,
|
||||
state,
|
||||
StreamScannerState::Length {
|
||||
tag: Tag::Start,
|
||||
length_bytes_missing: 2,
|
||||
@@ -473,10 +384,10 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
scanner.scan_bytes(&[0x01, 0x01, 0x00]);
|
||||
state.scan_bytes(&[0x01, 0x01, 0x00], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
scanner.state,
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Start,
|
||||
first: false,
|
||||
@@ -484,10 +395,10 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
scanner.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice());
|
||||
state.scan_bytes(vec![0; 0x00000101 - 4 - 1 - 1].as_slice(), &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[]);
|
||||
assert_eq!(
|
||||
scanner.state,
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Start,
|
||||
first: false,
|
||||
@@ -495,10 +406,10 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
scanner.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08]);
|
||||
state.scan_bytes(&[0x00, b'A', 0x00, 0x00, 0x00, 0x08], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[Tag::Start]);
|
||||
assert_eq!(
|
||||
scanner.state,
|
||||
state,
|
||||
StreamScannerState::Payload {
|
||||
tag: Tag::Message(b'A'),
|
||||
first: true,
|
||||
@@ -506,18 +417,18 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
scanner.scan_bytes(&[0, 0, 0, 0]);
|
||||
state.scan_bytes(&[0, 0, 0, 0], &mut observer);
|
||||
assert_eq!(tags.borrow().as_slice(), &[Tag::Start, Tag::Message(b'A')]);
|
||||
assert_eq!(scanner.state, StreamScannerState::Tag);
|
||||
assert_eq!(state, StreamScannerState::Tag);
|
||||
|
||||
scanner.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T']);
|
||||
state.scan_bytes(&[b'Z', 0x00, 0x00, 0x00, 0x05, b'T'], &mut observer);
|
||||
assert_eq!(
|
||||
tags.borrow().as_slice(),
|
||||
&[Tag::Start, Tag::Message(b'A'), Tag::ReadyForQuery(b'T')]
|
||||
);
|
||||
assert_eq!(scanner.state, StreamScannerState::Tag);
|
||||
assert_eq!(state, StreamScannerState::Tag);
|
||||
|
||||
scanner.scan_bytes(&[]);
|
||||
state.scan_bytes(&[], &mut observer);
|
||||
assert_eq!(
|
||||
tags.borrow().as_slice(),
|
||||
&[
|
||||
@@ -527,7 +438,7 @@ mod tests {
|
||||
Tag::End
|
||||
]
|
||||
);
|
||||
assert_eq!(scanner.state, StreamScannerState::End);
|
||||
assert_eq!(state, StreamScannerState::End);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -535,20 +446,13 @@ mod tests {
|
||||
let transitions: Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>> = Arc::default();
|
||||
struct Observer(Arc<Mutex<Vec<(ConnectionState, ConnectionState)>>>);
|
||||
impl StateChangeObserver for Observer {
|
||||
type ConnId = usize;
|
||||
fn change(
|
||||
&self,
|
||||
conn_id: Self::ConnId,
|
||||
old_state: ConnectionState,
|
||||
new_state: ConnectionState,
|
||||
) {
|
||||
assert_eq!(conn_id, 42);
|
||||
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.lock().unwrap().push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
let mut tracker = ConnectionTracker::new(42, Observer(transitions.clone()));
|
||||
let mut tracker = ConnectionTracker::new(Observer(transitions.clone()));
|
||||
|
||||
let stream = TestStream::new(
|
||||
let stream = BufReader::new(
|
||||
&[
|
||||
0, 0, 0, 4, // Init
|
||||
b'Z', 0, 0, 0, 5, b'I', // Init -> Idle
|
||||
@@ -557,7 +461,7 @@ mod tests {
|
||||
][..],
|
||||
);
|
||||
// AsyncRead
|
||||
let mut stream = TrackedStream::new(stream, false, |tag| tracker.backend_message_tag(tag));
|
||||
let mut stream = TrackedStream::new(stream, |tag| tracker.backend_message_tag(tag));
|
||||
|
||||
let mut readbuf = [0; 2];
|
||||
let n = stream.read_exact(&mut readbuf).await.unwrap();
|
||||
@@ -614,41 +518,47 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
struct TestStream {
|
||||
stream: BufReader<&'static [u8]>,
|
||||
pub struct TrackedStream<S, TO> {
|
||||
stream: S,
|
||||
observer: TO,
|
||||
state: StreamScannerState,
|
||||
}
|
||||
impl TestStream {
|
||||
fn new(data: &'static [u8]) -> Self {
|
||||
TestStream {
|
||||
stream: BufReader::new(data),
|
||||
|
||||
impl<S: Unpin, TO> Unpin for TrackedStream<S, TO> {}
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> TrackedStream<S, TO> {
|
||||
pub const fn new(stream: S, observer: TO) -> Self {
|
||||
TrackedStream {
|
||||
stream,
|
||||
observer,
|
||||
state: StreamScannerState::Start,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for TestStream {
|
||||
|
||||
impl<S: AsyncRead + Unpin, TO: TagObserver> AsyncRead for TrackedStream<S, TO> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
pin!(&mut self.stream).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
impl AsyncWrite for TestStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
let Self {
|
||||
stream,
|
||||
observer,
|
||||
state,
|
||||
} = Pin::into_inner(self);
|
||||
|
||||
let old_len = buf.filled().len();
|
||||
match Pin::new(stream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let new_len = buf.filled().len();
|
||||
state.scan_bytes(&buf.filled()[old_len..new_len], observer);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,8 +303,6 @@ impl CopyBuffer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Mutex;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::proxy::conntrack::ConnectionState;
|
||||
@@ -312,12 +310,11 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[derive(Default)]
|
||||
struct Observer(Mutex<Vec<(ConnectionState, ConnectionState)>>);
|
||||
struct Observer(Vec<(ConnectionState, ConnectionState)>);
|
||||
|
||||
impl StateChangeObserver for Observer {
|
||||
type ConnId = ();
|
||||
fn change(&self, (): Self::ConnId, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.lock().unwrap().push((old_state, new_state));
|
||||
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,7 +329,7 @@ mod tests {
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
|
||||
let mut conn_tracker = ConnectionTracker::new(Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
@@ -361,7 +358,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut conn_tracker = ConnectionTracker::new((), Observer::default());
|
||||
let mut conn_tracker = ConnectionTracker::new(Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
|
||||
Reference in New Issue
Block a user