mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
[proxy/tokio-postgres] garbage collection for codec buffers (#12701)
## Problem A large insert or a large row will cause the codec to allocate a large buffer. The codec never shrinks the buffer however. LKB-2496 ## Summary of changes 1. Introduce a naive GC system for codec buffers 2. Try and reduce copies as much as possible
This commit is contained in:
@@ -15,6 +15,7 @@ use tokio::sync::mpsc;
|
||||
use crate::cancel_token::RawCancelToken;
|
||||
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
|
||||
use crate::config::{Host, SslMode};
|
||||
use crate::connection::gc_bytesmut;
|
||||
use crate::query::RowStream;
|
||||
use crate::simple_query::SimpleQueryStream;
|
||||
use crate::types::{Oid, Type};
|
||||
@@ -95,20 +96,13 @@ impl InnerClient {
|
||||
Ok(PartialQuery(Some(self)))
|
||||
}
|
||||
|
||||
// pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
|
||||
// where
|
||||
// F: FnOnce(&mut BytesMut) -> Result<(), Error>,
|
||||
// {
|
||||
// self.start()?.send_with_sync(f)
|
||||
// }
|
||||
|
||||
pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
|
||||
self.responses.waiting += 1;
|
||||
|
||||
self.buffer.clear();
|
||||
// simple queries do not need sync.
|
||||
frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
|
||||
let buf = self.buffer.split().freeze();
|
||||
let buf = self.buffer.split();
|
||||
self.send_message(FrontendMessage::Raw(buf))
|
||||
}
|
||||
|
||||
@@ -125,7 +119,7 @@ impl Drop for PartialQuery<'_> {
|
||||
if let Some(client) = self.0.take() {
|
||||
client.buffer.clear();
|
||||
frontend::sync(&mut client.buffer);
|
||||
let buf = client.buffer.split().freeze();
|
||||
let buf = client.buffer.split();
|
||||
let _ = client.send_message(FrontendMessage::Raw(buf));
|
||||
}
|
||||
}
|
||||
@@ -141,7 +135,7 @@ impl<'a> PartialQuery<'a> {
|
||||
client.buffer.clear();
|
||||
f(&mut client.buffer)?;
|
||||
frontend::flush(&mut client.buffer);
|
||||
let buf = client.buffer.split().freeze();
|
||||
let buf = client.buffer.split();
|
||||
client.send_message(FrontendMessage::Raw(buf))
|
||||
}
|
||||
|
||||
@@ -154,7 +148,7 @@ impl<'a> PartialQuery<'a> {
|
||||
client.buffer.clear();
|
||||
f(&mut client.buffer)?;
|
||||
frontend::sync(&mut client.buffer);
|
||||
let buf = client.buffer.split().freeze();
|
||||
let buf = client.buffer.split();
|
||||
let _ = client.send_message(FrontendMessage::Raw(buf));
|
||||
|
||||
Ok(&mut self.0.take().unwrap().responses)
|
||||
@@ -317,6 +311,9 @@ impl Client {
|
||||
DISCARD SEQUENCES;",
|
||||
)?;
|
||||
|
||||
// Clean up memory usage.
|
||||
gc_bytesmut(&mut self.inner_mut().buffer);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use std::io;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol2::message::backend;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
pub enum FrontendMessage {
|
||||
Raw(Bytes),
|
||||
Raw(BytesMut),
|
||||
RecordNotices(RecordNotices),
|
||||
}
|
||||
|
||||
@@ -17,7 +17,10 @@ pub struct RecordNotices {
|
||||
}
|
||||
|
||||
pub enum BackendMessage {
|
||||
Normal { messages: BackendMessages },
|
||||
Normal {
|
||||
messages: BackendMessages,
|
||||
ready: bool,
|
||||
},
|
||||
Async(backend::Message),
|
||||
}
|
||||
|
||||
@@ -40,11 +43,18 @@ impl FallibleIterator for BackendMessages {
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder<Bytes> for PostgresCodec {
|
||||
impl Encoder<BytesMut> for PostgresCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
|
||||
dst.extend_from_slice(&item);
|
||||
fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> {
|
||||
// When it comes to request/response workflows, we usually flush the entire write
|
||||
// buffer in order to wait for the response before we send a new request.
|
||||
// Therefore we can avoid the copy and just replace the buffer.
|
||||
if dst.is_empty() {
|
||||
*dst = item;
|
||||
} else {
|
||||
dst.extend_from_slice(&item);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -56,6 +66,7 @@ impl Decoder for PostgresCodec {
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
|
||||
let mut idx = 0;
|
||||
|
||||
let mut ready = false;
|
||||
while let Some(header) = backend::Header::parse(&src[idx..])? {
|
||||
let len = header.len() as usize + 1;
|
||||
if src[idx..].len() < len {
|
||||
@@ -79,6 +90,7 @@ impl Decoder for PostgresCodec {
|
||||
idx += len;
|
||||
|
||||
if header.tag() == backend::READY_FOR_QUERY_TAG {
|
||||
ready = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -88,6 +100,7 @@ impl Decoder for PostgresCodec {
|
||||
} else {
|
||||
Ok(Some(BackendMessage::Normal {
|
||||
messages: BackendMessages(src.split_to(idx)),
|
||||
ready,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,19 +250,20 @@ impl Config {
|
||||
{
|
||||
let stream = connect_tls(stream, self.ssl_mode, tls).await?;
|
||||
let mut stream = StartupStream::new(stream);
|
||||
connect_raw::startup(&mut stream, self).await?;
|
||||
connect_raw::authenticate(&mut stream, self).await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn authenticate<S, T>(&self, stream: &mut StartupStream<S, T>) -> Result<(), Error>
|
||||
pub fn authenticate<S, T>(
|
||||
&self,
|
||||
stream: &mut StartupStream<S, T>,
|
||||
) -> impl Future<Output = Result<(), Error>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
{
|
||||
connect_raw::startup(stream, self).await?;
|
||||
connect_raw::authenticate(stream, self).await
|
||||
connect_raw::authenticate(stream, self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,51 +2,28 @@ use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll, ready};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
|
||||
use futures_util::{SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::codec::{Framed, FramedParts, FramedWrite};
|
||||
use tokio_util::codec::{Framed, FramedParts};
|
||||
|
||||
use crate::Error;
|
||||
use crate::codec::PostgresCodec;
|
||||
use crate::config::{self, AuthKeys, Config};
|
||||
use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY};
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::TlsStream;
|
||||
|
||||
pub struct StartupStream<S, T> {
|
||||
inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
read_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S, T> Sink<Bytes> for StartupStream<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
|
||||
Pin::new(&mut self.inner).start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Stream for StartupStream<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
@@ -55,6 +32,8 @@ where
|
||||
type Item = io::Result<Message>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
// We don't use `self.inner.poll_next()` as that might over-read into the read buffer.
|
||||
|
||||
// read 1 byte tag, 4 bytes length.
|
||||
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
|
||||
|
||||
@@ -121,36 +100,28 @@ where
|
||||
}
|
||||
|
||||
pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
|
||||
let write_buf = std::mem::take(self.inner.write_buffer_mut());
|
||||
let io = self.inner.into_inner();
|
||||
let mut parts = FramedParts::new(io, PostgresCodec);
|
||||
parts.read_buf = self.read_buf;
|
||||
parts.write_buf = write_buf;
|
||||
Framed::from_parts(parts)
|
||||
*self.inner.read_buffer_mut() = self.read_buf;
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub fn new(io: MaybeTlsStream<S, T>) -> Self {
|
||||
let mut parts = FramedParts::new(io, PostgresCodec);
|
||||
parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY);
|
||||
|
||||
let mut inner = Framed::from_parts(parts);
|
||||
|
||||
// This is the default already, but nice to be explicit.
|
||||
// We divide by two because writes will overshoot the boundary.
|
||||
// We don't want constant overshoots to cause us to constantly re-shrink the buffer.
|
||||
inner.set_backpressure_boundary(GC_THRESHOLD / 2);
|
||||
|
||||
Self {
|
||||
inner: FramedWrite::new(io, PostgresCodec),
|
||||
read_buf: BytesMut::new(),
|
||||
inner,
|
||||
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn startup<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
|
||||
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)
|
||||
}
|
||||
|
||||
pub(crate) async fn authenticate<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
@@ -159,6 +130,10 @@ where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
{
|
||||
frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut())
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
stream.inner.flush().await.map_err(Error::io)?;
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::AuthenticationOk) => {
|
||||
can_skip_channel_binding(config)?;
|
||||
@@ -172,7 +147,8 @@ where
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("password missing".into()))?;
|
||||
|
||||
authenticate_password(stream, pass).await?;
|
||||
frontend::password_message(pass, stream.inner.write_buffer_mut())
|
||||
.map_err(Error::encode)?;
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
authenticate_sasl(stream, body, config).await?;
|
||||
@@ -191,6 +167,7 @@ where
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
|
||||
stream.inner.flush().await.map_err(Error::io)?;
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::AuthenticationOk) => Ok(()),
|
||||
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
|
||||
@@ -208,20 +185,6 @@ fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate_password<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
password: &[u8],
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
|
||||
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)
|
||||
}
|
||||
|
||||
async fn authenticate_sasl<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
body: AuthenticationSaslBody,
|
||||
@@ -276,10 +239,10 @@ where
|
||||
return Err(Error::config("password or auth keys missing".into()));
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)?;
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut())
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
stream.inner.flush().await.map_err(Error::io)?;
|
||||
let body = match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::AuthenticationSaslContinue(body)) => body,
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
@@ -292,10 +255,10 @@ where
|
||||
.await
|
||||
.map_err(|e| Error::authentication(e.into()))?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)?;
|
||||
frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut())
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
stream.inner.flush().await.map_err(Error::io)?;
|
||||
let body = match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::AuthenticationSaslFinal(body)) => body,
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
|
||||
@@ -44,6 +44,27 @@ pub struct Connection<S, T> {
|
||||
state: State,
|
||||
}
|
||||
|
||||
pub const INITIAL_CAPACITY: usize = 2 * 1024;
|
||||
pub const GC_THRESHOLD: usize = 16 * 1024;
|
||||
|
||||
/// Gargabe collect the [`BytesMut`] if it has too much spare capacity.
|
||||
pub fn gc_bytesmut(buf: &mut BytesMut) {
|
||||
// We use a different mode to shrink the buf when above the threshold.
|
||||
// When above the threshold, we only re-allocate when the buf has 2x spare capacity.
|
||||
let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len());
|
||||
|
||||
// `try_reclaim` tries to get the capacity from any shared `BytesMut`s,
|
||||
// before then comparing the length against the capacity.
|
||||
if buf.try_reclaim(reclaim) {
|
||||
let capacity = usize::max(buf.len(), INITIAL_CAPACITY);
|
||||
|
||||
// Allocate a new `BytesMut` so that we deallocate the old version.
|
||||
let mut new = BytesMut::with_capacity(capacity);
|
||||
new.extend_from_slice(buf);
|
||||
*buf = new;
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Never {}
|
||||
|
||||
impl<S, T> Connection<S, T>
|
||||
@@ -86,7 +107,14 @@ where
|
||||
continue;
|
||||
}
|
||||
BackendMessage::Async(_) => continue,
|
||||
BackendMessage::Normal { messages } => messages,
|
||||
BackendMessage::Normal { messages, ready } => {
|
||||
// if we read a ReadyForQuery from postgres, let's try GC the read buffer.
|
||||
if ready {
|
||||
gc_bytesmut(self.stream.read_buffer_mut());
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -177,12 +205,7 @@ where
|
||||
// Send a terminate message to postgres
|
||||
Poll::Ready(None) => {
|
||||
trace!("poll_write: at eof, terminating");
|
||||
let mut request = BytesMut::new();
|
||||
frontend::terminate(&mut request);
|
||||
|
||||
Pin::new(&mut self.stream)
|
||||
.start_send(request.freeze())
|
||||
.map_err(Error::io)?;
|
||||
frontend::terminate(self.stream.write_buffer_mut());
|
||||
|
||||
trace!("poll_write: sent eof, closing");
|
||||
trace!("poll_write: done");
|
||||
@@ -205,6 +228,10 @@ where
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
trace!("poll_flush: flushed");
|
||||
|
||||
// GC the write buffer if we managed to flush
|
||||
gc_bytesmut(self.stream.write_buffer_mut());
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => {
|
||||
|
||||
Reference in New Issue
Block a user