diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index f8aceb5263..f5aed010ef 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -15,6 +15,7 @@ use tokio::sync::mpsc; use crate::cancel_token::RawCancelToken; use crate::codec::{BackendMessages, FrontendMessage, RecordNotices}; use crate::config::{Host, SslMode}; +use crate::connection::gc_bytesmut; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; @@ -95,20 +96,13 @@ impl InnerClient { Ok(PartialQuery(Some(self))) } - // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> - // where - // F: FnOnce(&mut BytesMut) -> Result<(), Error>, - // { - // self.start()?.send_with_sync(f) - // } - pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { self.responses.waiting += 1; self.buffer.clear(); // simple queries do not need sync. frontend::query(query, &mut self.buffer).map_err(Error::encode)?; - let buf = self.buffer.split().freeze(); + let buf = self.buffer.split(); self.send_message(FrontendMessage::Raw(buf)) } @@ -125,7 +119,7 @@ impl Drop for PartialQuery<'_> { if let Some(client) = self.0.take() { client.buffer.clear(); frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); } } @@ -141,7 +135,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::flush(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); client.send_message(FrontendMessage::Raw(buf)) } @@ -154,7 +148,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); Ok(&mut 
self.0.take().unwrap().responses) @@ -317,6 +311,9 @@ impl Client { DISCARD SEQUENCES;", )?; + // Clean up memory usage. + gc_bytesmut(&mut self.inner_mut().buffer); + Ok(()) } diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index 813faa0e35..71fe062fca 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,13 +1,13 @@ use std::io; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; use tokio::sync::mpsc::UnboundedSender; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { - Raw(Bytes), + Raw(BytesMut), RecordNotices(RecordNotices), } @@ -17,7 +17,10 @@ pub struct RecordNotices { } pub enum BackendMessage { - Normal { messages: BackendMessages }, + Normal { + messages: BackendMessages, + ready: bool, + }, Async(backend::Message), } @@ -40,11 +43,18 @@ impl FallibleIterator for BackendMessages { pub struct PostgresCodec; -impl Encoder for PostgresCodec { +impl Encoder for PostgresCodec { type Error = io::Error; - fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> { - dst.extend_from_slice(&item); + fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> { + // When it comes to request/response workflows, we usually flush the entire write + // buffer in order to wait for the response before we send a new request. + // Therefore we can avoid the copy and just replace the buffer. + if dst.is_empty() { + *dst = item; + } else { + dst.extend_from_slice(&item); + } Ok(()) } } @@ -56,6 +66,7 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; + let mut ready = false; while let Some(header) = backend::Header::parse(&src[idx..])? 
{ let len = header.len() as usize + 1; if src[idx..].len() < len { @@ -79,6 +90,7 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { + ready = true; break; } } @@ -88,6 +100,7 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), + ready, })) } } diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index c619f92d13..3579dd94a2 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -250,19 +250,20 @@ impl Config { { let stream = connect_tls(stream, self.ssl_mode, tls).await?; let mut stream = StartupStream::new(stream); - connect_raw::startup(&mut stream, self).await?; connect_raw::authenticate(&mut stream, self).await?; Ok(stream) } - pub async fn authenticate(&self, stream: &mut StartupStream) -> Result<(), Error> + pub fn authenticate( + &self, + stream: &mut StartupStream, + ) -> impl Future> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - connect_raw::startup(stream, self).await?; - connect_raw::authenticate(stream, self).await + connect_raw::authenticate(stream, self) } } diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index bc35cef339..17237eeef5 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -2,51 +2,28 @@ use std::io; use std::pin::Pin; use std::task::{Context, Poll, ready}; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; -use futures_util::{Sink, SinkExt, Stream, TryStreamExt}; +use futures_util::{SinkExt, Stream, TryStreamExt}; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol2::message::frontend; use 
tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_util::codec::{Framed, FramedParts, FramedWrite}; +use tokio_util::codec::{Framed, FramedParts}; use crate::Error; use crate::codec::PostgresCodec; use crate::config::{self, AuthKeys, Config}; +use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::TlsStream; pub struct StartupStream { - inner: FramedWrite, PostgresCodec>, + inner: Framed, PostgresCodec>, read_buf: BytesMut, } -impl Sink for StartupStream -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - type Error = io::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> { - Pin::new(&mut self.inner).start_send(item) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx) - } -} - impl Stream for StartupStream where S: AsyncRead + AsyncWrite + Unpin, @@ -55,6 +32,8 @@ where type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We don't use `self.inner.poll_next()` as that might over-read into the read buffer. + // read 1 byte tag, 4 bytes length. 
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?); @@ -121,36 +100,28 @@ where } pub fn into_framed(mut self) -> Framed, PostgresCodec> { - let write_buf = std::mem::take(self.inner.write_buffer_mut()); - let io = self.inner.into_inner(); - let mut parts = FramedParts::new(io, PostgresCodec); - parts.read_buf = self.read_buf; - parts.write_buf = write_buf; - Framed::from_parts(parts) + *self.inner.read_buffer_mut() = self.read_buf; + self.inner } pub fn new(io: MaybeTlsStream) -> Self { + let mut parts = FramedParts::new(io, PostgresCodec); + parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY); + + let mut inner = Framed::from_parts(parts); + + // This is the default already, but nice to be explicit. + // We divide by two because writes will overshoot the boundary. + // We don't want constant overshoots to cause us to constantly re-shrink the buffer. + inner.set_backpressure_boundary(GC_THRESHOLD / 2); + Self { - inner: FramedWrite::new(io, PostgresCodec), - read_buf: BytesMut::new(), + inner, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), } } } -pub(crate) async fn startup( - stream: &mut StartupStream, - config: &Config, -) -> Result<(), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut buf = BytesMut::new(); - frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; - - stream.send(buf.freeze()).await.map_err(Error::io) -} - pub(crate) async fn authenticate( stream: &mut StartupStream, config: &Config, @@ -159,6 +130,10 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { + frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? 
{ Some(Message::AuthenticationOk) => { can_skip_channel_binding(config)?; @@ -172,7 +147,8 @@ where .as_ref() .ok_or_else(|| Error::config("password missing".into()))?; - authenticate_password(stream, pass).await?; + frontend::password_message(pass, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; @@ -191,6 +167,7 @@ where None => return Err(Error::closed()), } + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), @@ -208,20 +185,6 @@ fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { } } -async fn authenticate_password( - stream: &mut StartupStream, - password: &[u8], -) -> Result<(), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut buf = BytesMut::new(); - frontend::password_message(password, &mut buf).map_err(Error::encode)?; - - stream.send(buf.freeze()).await.map_err(Error::io) -} - async fn authenticate_sasl( stream: &mut StartupStream, body: AuthenticationSaslBody, @@ -276,10 +239,10 @@ where return Err(Error::config("password or auth keys missing".into())); }; - let mut buf = BytesMut::new(); - frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? 
{ Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), @@ -292,10 +255,10 @@ where .await .map_err(|e| Error::authentication(e.into()))?; - let mut buf = BytesMut::new(); - frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index c43a22ffe7..bee4b3372d 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -44,6 +44,27 @@ pub struct Connection { state: State, } +pub const INITIAL_CAPACITY: usize = 2 * 1024; +pub const GC_THRESHOLD: usize = 16 * 1024; + +/// Garbage collect the [`BytesMut`] if it has too much spare capacity. +pub fn gc_bytesmut(buf: &mut BytesMut) { + // We use a different mode to shrink the buf when above the threshold. + // When above the threshold, we only re-allocate when the buf has 2x spare capacity. + let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len()); + + // `try_reclaim` tries to get the capacity from any shared `BytesMut`s, + // before then comparing the length against the capacity. + if buf.try_reclaim(reclaim) { + let capacity = usize::max(buf.len(), INITIAL_CAPACITY); + + // Allocate a new `BytesMut` so that we deallocate the old version. 
+ let mut new = BytesMut::with_capacity(capacity); + new.extend_from_slice(buf); + *buf = new; + } +} + pub enum Never {} impl Connection @@ -86,7 +107,14 @@ where continue; } BackendMessage::Async(_) => continue, - BackendMessage::Normal { messages } => messages, + BackendMessage::Normal { messages, ready } => { + // if we read a ReadyForQuery from postgres, let's try GC the read buffer. + if ready { + gc_bytesmut(self.stream.read_buffer_mut()); + } + + messages + } } } }; @@ -177,12 +205,7 @@ where // Send a terminate message to postgres Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); - let mut request = BytesMut::new(); - frontend::terminate(&mut request); - - Pin::new(&mut self.stream) - .start_send(request.freeze()) - .map_err(Error::io)?; + frontend::terminate(self.stream.write_buffer_mut()); trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); @@ -205,6 +228,10 @@ where { Poll::Ready(()) => { trace!("poll_flush: flushed"); + + // GC the write buffer if we managed to flush + gc_bytesmut(self.stream.write_buffer_mut()); + Poll::Ready(Ok(())) } Poll::Pending => {