diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a9e5fbc85b..d1f8430b8a 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -91,6 +91,7 @@ mod jemalloc; mod logging; mod metrics; mod parse; +mod pglb; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pglb/inprocess.rs b/proxy/src/pglb/inprocess.rs new file mode 100644 index 0000000000..905f82f909 --- /dev/null +++ b/proxy/src/pglb/inprocess.rs @@ -0,0 +1,193 @@ +#![allow(dead_code, reason = "TODO: work in progress")] + +use std::pin::{Pin, pin}; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf}; +use tokio::sync::mpsc; + +const STREAM_CHANNEL_SIZE: usize = 16; +const MAX_STREAM_BUFFER_SIZE: usize = 4096; + +#[derive(Debug)] +pub struct Connection { + stream_sender: mpsc::Sender, + stream_receiver: mpsc::Receiver, + stream_id_counter: Arc, +} + +impl Connection { + pub fn new() -> (Connection, Connection) { + let (sender_a, receiver_a) = mpsc::channel(STREAM_CHANNEL_SIZE); + let (sender_b, receiver_b) = mpsc::channel(STREAM_CHANNEL_SIZE); + + let stream_id_counter = Arc::new(AtomicUsize::new(1)); + + let conn_a = Connection { + stream_sender: sender_a, + stream_receiver: receiver_b, + stream_id_counter: Arc::clone(&stream_id_counter), + }; + let conn_b = Connection { + stream_sender: sender_b, + stream_receiver: receiver_a, + stream_id_counter, + }; + + (conn_a, conn_b) + } + + #[inline] + fn next_stream_id(&self) -> StreamId { + StreamId(self.stream_id_counter.fetch_add(1, Ordering::Relaxed)) + } + + #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))] + pub async fn open_stream(&self) -> io::Result { + let (local, remote) = tokio::io::duplex(MAX_STREAM_BUFFER_SIZE); + let stream_id = self.next_stream_id(); + tracing::Span::current().record("stream_id", stream_id.0); + + let local = Stream { + inner: local, + id: stream_id, + }; + let remote = Stream { + inner: remote, + id: stream_id, + }; + + self.stream_sender + .send(remote) + .await + .map_err(io::Error::other)?; + + Ok(local) + } + + #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))] + pub async fn accept_stream(&mut self) -> io::Result> { + Ok(self.stream_receiver.recv().await.inspect(|stream| { + tracing::Span::current().record("stream_id", stream.id.0); + })) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct StreamId(usize); + +impl fmt::Display for StreamId { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +// TODO: Proper closing. Currently Streams can outlive their Connections. +// Carry WeakSender and check strong_count? +#[derive(Debug)] +pub struct Stream { + inner: DuplexStream, + id: StreamId, +} + +impl Stream { + #[inline] + pub fn id(&self) -> StreamId { + self.id + } +} + +impl AsyncRead for Stream { + #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))] + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + pin!(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for Stream { + #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))] + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + pin!(&mut self.inner).poll_write(cx, buf) + } + + #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))] + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&mut self.inner).poll_flush(cx) + } + + #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))] + #[inline] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + pin!(&mut self.inner).poll_shutdown(cx) + } + + #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))] + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + pin!(&mut self.inner).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + + #[tokio::test] + async fn test_simple_roundtrip() { + let (client, mut server) = Connection::new(); + + let server_task = tokio::spawn(async move { + while let Some(mut stream) = server.accept_stream().await.unwrap() { + tokio::spawn(async move { + let mut buf = [0; 64]; + loop { + match stream.read(&mut buf).await.unwrap() { + 0 => break, + n => stream.write(&buf[..n]).await.unwrap(), + }; + } + }); + } + }); + + let mut stream = client.open_stream().await.unwrap(); + stream.write_all(b"hello!").await.unwrap(); + let mut buf = [0; 64]; + let n = stream.read(&mut buf).await.unwrap(); + assert_eq!(n, 6); + assert_eq!(&buf[..n], b"hello!"); + + drop(stream); + drop(client); + server_task.await.unwrap(); + } +} diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs new file mode 100644 index 0000000000..1088859fb9 --- /dev/null +++ b/proxy/src/pglb/mod.rs @@ -0,0 +1 @@ +pub mod inprocess;