mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
[proxy] move read_info from the compute connection to be as late as possible (#12660)
Second attempt at #12130, now with a smaller diff. This allows us to skip allocating for things like parameter status and notices that we will either just forward untouched, or discard. LKB-2494
This commit is contained in:
@@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::connect::connect;
|
||||
use crate::connect_raw::{RawConnection, connect_raw};
|
||||
use crate::connect_raw::{self, StartupStream};
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
|
||||
use crate::{Client, Connection, Error};
|
||||
|
||||
@@ -244,24 +243,26 @@ impl Config {
|
||||
&self,
|
||||
stream: S,
|
||||
tls: T,
|
||||
) -> Result<RawConnection<S, T::Stream>, Error>
|
||||
) -> Result<StartupStream<S, T::Stream>, Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
let stream = connect_tls(stream, self.ssl_mode, tls).await?;
|
||||
connect_raw(stream, self).await
|
||||
let mut stream = StartupStream::new(stream);
|
||||
connect_raw::startup(&mut stream, self).await?;
|
||||
connect_raw::authenticate(&mut stream, self).await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn authenticate<S, T>(
|
||||
&self,
|
||||
stream: MaybeTlsStream<S, T>,
|
||||
) -> Result<RawConnection<S, T>, Error>
|
||||
pub async fn authenticate<S, T>(&self, stream: &mut StartupStream<S, T>) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
{
|
||||
connect_raw(stream, self).await
|
||||
connect_raw::startup(stream, self).await?;
|
||||
connect_raw::authenticate(stream, self).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
use futures_util::TryStreamExt;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::client::SocketConfig;
|
||||
use crate::config::Host;
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_raw::StartupStream;
|
||||
use crate::connect_socket::connect_socket;
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection};
|
||||
use crate::{Client, Config, Connection, Error};
|
||||
|
||||
pub async fn connect<T>(
|
||||
tls: &T,
|
||||
@@ -43,14 +45,8 @@ where
|
||||
T: TlsConnect<TcpStream>,
|
||||
{
|
||||
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
|
||||
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters: _,
|
||||
delayed_notice: _,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connect_raw(stream, config).await?;
|
||||
let mut stream = config.tls_and_authenticate(socket, tls).await?;
|
||||
let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
|
||||
|
||||
let socket_config = SocketConfig {
|
||||
host_addr,
|
||||
@@ -70,7 +66,32 @@ where
|
||||
secret_key,
|
||||
);
|
||||
|
||||
let stream = stream.into_framed();
|
||||
let connection = Connection::new(stream, conn_tx, conn_rx);
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
|
||||
loop {
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
}
|
||||
// These values are currently not used by `Client`/`Connection`. Ignore them.
|
||||
Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{Context, Poll, ready};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::codec::Framed;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::codec::{Framed, FramedParts, FramedWrite};
|
||||
|
||||
use crate::Error;
|
||||
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
|
||||
use crate::codec::PostgresCodec;
|
||||
use crate::config::{self, AuthKeys, Config};
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::TlsStream;
|
||||
|
||||
pub struct StartupStream<S, T> {
|
||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BackendMessages,
|
||||
delayed_notice: Vec<NoticeResponseBody>,
|
||||
inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
read_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S, T> Sink<Bytes> for StartupStream<S, T>
|
||||
@@ -56,63 +54,93 @@ where
|
||||
{
|
||||
type Item = io::Result<Message>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
// read 1 byte tag, 4 bytes length.
|
||||
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
|
||||
|
||||
let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
|
||||
if len < 4 {
|
||||
return Poll::Ready(Some(Err(std::io::Error::other(
|
||||
"postgres message too small",
|
||||
))));
|
||||
}
|
||||
if len >= 65536 {
|
||||
return Poll::Ready(Some(Err(std::io::Error::other(
|
||||
"postgres message too large",
|
||||
))));
|
||||
}
|
||||
|
||||
// the tag is an additional byte.
|
||||
let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
|
||||
|
||||
// Message::parse will remove the all the bytes from the buffer.
|
||||
Poll::Ready(Message::parse(&mut self.read_buf).transpose())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> StartupStream<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
/// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
|
||||
///
|
||||
/// If the current buffer length is greater, nothing happens.
|
||||
fn poll_fill_buf_exact(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<io::Result<Message>>> {
|
||||
loop {
|
||||
match self.buf.next() {
|
||||
Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
|
||||
Ok(None) => {}
|
||||
Err(e) => return Poll::Ready(Some(Err(e))),
|
||||
len: usize,
|
||||
) -> Poll<Result<&[u8], std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Pin::new(this.inner.get_mut());
|
||||
|
||||
let mut n = this.read_buf.len();
|
||||
while n < len {
|
||||
this.read_buf.resize(len, 0);
|
||||
|
||||
let mut buf = ReadBuf::new(&mut this.read_buf[..]);
|
||||
buf.set_filled(n);
|
||||
|
||||
if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
|
||||
this.read_buf.truncate(n);
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
|
||||
Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
|
||||
Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
|
||||
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
|
||||
None => return Poll::Ready(None),
|
||||
if buf.filled().len() == n {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"early eof",
|
||||
)));
|
||||
}
|
||||
n = buf.filled().len();
|
||||
|
||||
this.read_buf.truncate(n);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(&this.read_buf[..len]))
|
||||
}
|
||||
|
||||
pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
|
||||
let write_buf = std::mem::take(self.inner.write_buffer_mut());
|
||||
let io = self.inner.into_inner();
|
||||
let mut parts = FramedParts::new(io, PostgresCodec);
|
||||
parts.read_buf = self.read_buf;
|
||||
parts.write_buf = write_buf;
|
||||
Framed::from_parts(parts)
|
||||
}
|
||||
|
||||
pub fn new(io: MaybeTlsStream<S, T>) -> Self {
|
||||
Self {
|
||||
inner: FramedWrite::new(io, PostgresCodec),
|
||||
read_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RawConnection<S, T> {
|
||||
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub delayed_notice: Vec<NoticeResponseBody>,
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
}
|
||||
|
||||
pub async fn connect_raw<S, T>(
|
||||
stream: MaybeTlsStream<S, T>,
|
||||
pub(crate) async fn startup<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<RawConnection<S, T>, Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
{
|
||||
let mut stream = StartupStream {
|
||||
inner: Framed::new(stream, PostgresCodec),
|
||||
buf: BackendMessages::empty(),
|
||||
delayed_notice: Vec::new(),
|
||||
};
|
||||
|
||||
startup(&mut stream, config).await?;
|
||||
authenticate(&mut stream, config).await?;
|
||||
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
|
||||
|
||||
Ok(RawConnection {
|
||||
stream: stream.inner,
|
||||
parameters,
|
||||
delayed_notice: stream.delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
})
|
||||
}
|
||||
|
||||
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
@@ -123,7 +151,10 @@ where
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)
|
||||
}
|
||||
|
||||
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||
pub(crate) async fn authenticate<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
@@ -278,35 +309,3 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_info<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
) -> Result<(i32, i32, HashMap<String, String>), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
let mut parameters = HashMap::new();
|
||||
|
||||
loop {
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
}
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
parameters.insert(
|
||||
body.name().map_err(Error::parse)?.to_string(),
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
}
|
||||
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,16 +452,16 @@ impl Error {
|
||||
Error(Box::new(ErrorInner { kind, cause }))
|
||||
}
|
||||
|
||||
pub(crate) fn closed() -> Error {
|
||||
pub fn closed() -> Error {
|
||||
Error::new(Kind::Closed, None)
|
||||
}
|
||||
|
||||
pub(crate) fn unexpected_message() -> Error {
|
||||
pub fn unexpected_message() -> Error {
|
||||
Error::new(Kind::UnexpectedMessage, None)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
pub(crate) fn db(error: ErrorResponseBody) -> Error {
|
||||
pub fn db(error: ErrorResponseBody) -> Error {
|
||||
match DbError::parse(&mut error.fields()) {
|
||||
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
|
||||
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
|
||||
@@ -493,7 +493,7 @@ impl Error {
|
||||
Error::new(Kind::Tls, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn io(e: io::Error) -> Error {
|
||||
pub fn io(e: io::Error) -> Error {
|
||||
Error::new(Kind::Io, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
pub use crate::cancel_token::{CancelToken, RawCancelToken};
|
||||
pub use crate::client::{Client, SocketConfig};
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
pub use crate::connection::Connection;
|
||||
pub use crate::error::Error;
|
||||
pub use crate::generic_client::GenericClient;
|
||||
@@ -50,7 +49,7 @@ mod client;
|
||||
mod codec;
|
||||
pub mod config;
|
||||
mod connect;
|
||||
mod connect_raw;
|
||||
pub mod connect_raw;
|
||||
mod connect_socket;
|
||||
mod connect_tls;
|
||||
mod connection;
|
||||
|
||||
Reference in New Issue
Block a user