mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-03 19:42:55 +00:00
[proxy] move read_info from the compute connection to be as late as possible (#12660)
Second attempt at #12130, now with a smaller diff. This allows us to skip allocating for things like parameter status and notices that we will either just forward untouched, or discard. LKB-2494
This commit is contained in:
@@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::connect::connect;
|
||||
use crate::connect_raw::{RawConnection, connect_raw};
|
||||
use crate::connect_raw::{self, StartupStream};
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
|
||||
use crate::{Client, Connection, Error};
|
||||
|
||||
@@ -244,24 +243,26 @@ impl Config {
|
||||
&self,
|
||||
stream: S,
|
||||
tls: T,
|
||||
) -> Result<RawConnection<S, T::Stream>, Error>
|
||||
) -> Result<StartupStream<S, T::Stream>, Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
let stream = connect_tls(stream, self.ssl_mode, tls).await?;
|
||||
connect_raw(stream, self).await
|
||||
let mut stream = StartupStream::new(stream);
|
||||
connect_raw::startup(&mut stream, self).await?;
|
||||
connect_raw::authenticate(&mut stream, self).await?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn authenticate<S, T>(
|
||||
&self,
|
||||
stream: MaybeTlsStream<S, T>,
|
||||
) -> Result<RawConnection<S, T>, Error>
|
||||
pub async fn authenticate<S, T>(&self, stream: &mut StartupStream<S, T>) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
{
|
||||
connect_raw(stream, self).await
|
||||
connect_raw::startup(stream, self).await?;
|
||||
connect_raw::authenticate(stream, self).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
use futures_util::TryStreamExt;
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::client::SocketConfig;
|
||||
use crate::config::Host;
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_raw::StartupStream;
|
||||
use crate::connect_socket::connect_socket;
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection};
|
||||
use crate::{Client, Config, Connection, Error};
|
||||
|
||||
pub async fn connect<T>(
|
||||
tls: &T,
|
||||
@@ -43,14 +45,8 @@ where
|
||||
T: TlsConnect<TcpStream>,
|
||||
{
|
||||
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
|
||||
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters: _,
|
||||
delayed_notice: _,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connect_raw(stream, config).await?;
|
||||
let mut stream = config.tls_and_authenticate(socket, tls).await?;
|
||||
let (process_id, secret_key) = wait_until_ready(&mut stream).await?;
|
||||
|
||||
let socket_config = SocketConfig {
|
||||
host_addr,
|
||||
@@ -70,7 +66,32 @@ where
|
||||
secret_key,
|
||||
);
|
||||
|
||||
let stream = stream.into_framed();
|
||||
let connection = Connection::new(stream, conn_tx, conn_rx);
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
|
||||
loop {
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
}
|
||||
// These values are currently not used by `Client`/`Connection`. Ignore them.
|
||||
Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{Context, Poll, ready};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
|
||||
use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
|
||||
use postgres_protocol2::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::codec::Framed;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::codec::{Framed, FramedParts, FramedWrite};
|
||||
|
||||
use crate::Error;
|
||||
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
|
||||
use crate::codec::PostgresCodec;
|
||||
use crate::config::{self, AuthKeys, Config};
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::TlsStream;
|
||||
|
||||
pub struct StartupStream<S, T> {
|
||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BackendMessages,
|
||||
delayed_notice: Vec<NoticeResponseBody>,
|
||||
inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
read_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S, T> Sink<Bytes> for StartupStream<S, T>
|
||||
@@ -56,63 +54,93 @@ where
|
||||
{
|
||||
type Item = io::Result<Message>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<io::Result<Message>>> {
|
||||
loop {
|
||||
match self.buf.next() {
|
||||
Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
|
||||
Ok(None) => {}
|
||||
Err(e) => return Poll::Ready(Some(Err(e))),
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
// read 1 byte tag, 4 bytes length.
|
||||
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
|
||||
|
||||
let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
|
||||
if len < 4 {
|
||||
return Poll::Ready(Some(Err(std::io::Error::other(
|
||||
"postgres message too small",
|
||||
))));
|
||||
}
|
||||
if len >= 65536 {
|
||||
return Poll::Ready(Some(Err(std::io::Error::other(
|
||||
"postgres message too large",
|
||||
))));
|
||||
}
|
||||
|
||||
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
|
||||
Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
|
||||
Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
|
||||
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
// the tag is an additional byte.
|
||||
let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
|
||||
|
||||
// Message::parse will remove the all the bytes from the buffer.
|
||||
Poll::Ready(Message::parse(&mut self.read_buf).transpose())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RawConnection<S, T> {
|
||||
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub delayed_notice: Vec<NoticeResponseBody>,
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
}
|
||||
|
||||
pub async fn connect_raw<S, T>(
|
||||
stream: MaybeTlsStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<RawConnection<S, T>, Error>
|
||||
impl<S, T> StartupStream<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut stream = StartupStream {
|
||||
inner: Framed::new(stream, PostgresCodec),
|
||||
buf: BackendMessages::empty(),
|
||||
delayed_notice: Vec::new(),
|
||||
};
|
||||
/// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
|
||||
///
|
||||
/// If the current buffer length is greater, nothing happens.
|
||||
fn poll_fill_buf_exact(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
len: usize,
|
||||
) -> Poll<Result<&[u8], std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
let mut stream = Pin::new(this.inner.get_mut());
|
||||
|
||||
startup(&mut stream, config).await?;
|
||||
authenticate(&mut stream, config).await?;
|
||||
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
|
||||
let mut n = this.read_buf.len();
|
||||
while n < len {
|
||||
this.read_buf.resize(len, 0);
|
||||
|
||||
Ok(RawConnection {
|
||||
stream: stream.inner,
|
||||
parameters,
|
||||
delayed_notice: stream.delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
})
|
||||
let mut buf = ReadBuf::new(&mut this.read_buf[..]);
|
||||
buf.set_filled(n);
|
||||
|
||||
if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
|
||||
this.read_buf.truncate(n);
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
if buf.filled().len() == n {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"early eof",
|
||||
)));
|
||||
}
|
||||
n = buf.filled().len();
|
||||
|
||||
this.read_buf.truncate(n);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(&this.read_buf[..len]))
|
||||
}
|
||||
|
||||
pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
|
||||
let write_buf = std::mem::take(self.inner.write_buffer_mut());
|
||||
let io = self.inner.into_inner();
|
||||
let mut parts = FramedParts::new(io, PostgresCodec);
|
||||
parts.read_buf = self.read_buf;
|
||||
parts.write_buf = write_buf;
|
||||
Framed::from_parts(parts)
|
||||
}
|
||||
|
||||
pub fn new(io: MaybeTlsStream<S, T>) -> Self {
|
||||
Self {
|
||||
inner: FramedWrite::new(io, PostgresCodec),
|
||||
read_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||
pub(crate) async fn startup<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
@@ -123,7 +151,10 @@ where
|
||||
stream.send(buf.freeze()).await.map_err(Error::io)
|
||||
}
|
||||
|
||||
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||
pub(crate) async fn authenticate<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
config: &Config,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsStream + Unpin,
|
||||
@@ -278,35 +309,3 @@ where
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_info<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
) -> Result<(i32, i32, HashMap<String, String>), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
let mut parameters = HashMap::new();
|
||||
|
||||
loop {
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
}
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
parameters.insert(
|
||||
body.name().map_err(Error::parse)?.to_string(),
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
}
|
||||
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,16 +452,16 @@ impl Error {
|
||||
Error(Box::new(ErrorInner { kind, cause }))
|
||||
}
|
||||
|
||||
pub(crate) fn closed() -> Error {
|
||||
pub fn closed() -> Error {
|
||||
Error::new(Kind::Closed, None)
|
||||
}
|
||||
|
||||
pub(crate) fn unexpected_message() -> Error {
|
||||
pub fn unexpected_message() -> Error {
|
||||
Error::new(Kind::UnexpectedMessage, None)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
pub(crate) fn db(error: ErrorResponseBody) -> Error {
|
||||
pub fn db(error: ErrorResponseBody) -> Error {
|
||||
match DbError::parse(&mut error.fields()) {
|
||||
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
|
||||
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
|
||||
@@ -493,7 +493,7 @@ impl Error {
|
||||
Error::new(Kind::Tls, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn io(e: io::Error) -> Error {
|
||||
pub fn io(e: io::Error) -> Error {
|
||||
Error::new(Kind::Io, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
pub use crate::cancel_token::{CancelToken, RawCancelToken};
|
||||
pub use crate::client::{Client, SocketConfig};
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
pub use crate::connection::Connection;
|
||||
pub use crate::error::Error;
|
||||
pub use crate::generic_client::GenericClient;
|
||||
@@ -50,7 +49,7 @@ mod client;
|
||||
mod codec;
|
||||
pub mod config;
|
||||
mod connect;
|
||||
mod connect_raw;
|
||||
pub mod connect_raw;
|
||||
mod connect_socket;
|
||||
mod connect_tls;
|
||||
mod connection;
|
||||
|
||||
@@ -429,26 +429,13 @@ impl CancellationHandler {
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String, // for pg_sni router
|
||||
user_info: ComputeUserInfo,
|
||||
pub socket_addr: SocketAddr,
|
||||
pub cancel_token: RawCancelToken,
|
||||
pub hostname: String, // for pg_sni router
|
||||
pub user_info: ComputeUserInfo,
|
||||
}
|
||||
|
||||
impl CancelClosure {
|
||||
pub(crate) fn new(
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: RawCancelToken,
|
||||
hostname: String,
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket_addr,
|
||||
cancel_token,
|
||||
hostname,
|
||||
user_info,
|
||||
}
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(
|
||||
&self,
|
||||
|
||||
@@ -7,17 +7,15 @@ use std::net::{IpAddr, SocketAddr};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
|
||||
use postgres_client::connect_raw::StartupStream;
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{NoTls, RawCancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::compute::tls::TlsError;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
@@ -236,8 +234,7 @@ impl AuthInfo {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
compute: &mut ComputeConnection,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<PostgresSettings, PostgresError> {
|
||||
) -> Result<(), PostgresError> {
|
||||
// client config with stubbed connect info.
|
||||
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
|
||||
// utilising pqproto.rs.
|
||||
@@ -247,39 +244,10 @@ impl AuthInfo {
|
||||
let tmp_config = self.enrich(tmp_config);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let connection = tmp_config
|
||||
.tls_and_authenticate(&mut compute.stream, NoTls)
|
||||
.await?;
|
||||
tmp_config.authenticate(&mut compute.stream).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
stream: _,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
|
||||
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
|
||||
// Yet another reason to rework the connection establishing code.
|
||||
let cancel_closure = CancelClosure::new(
|
||||
compute.socket_addr,
|
||||
RawCancelToken {
|
||||
ssl_mode: compute.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
compute.hostname.to_string(),
|
||||
user_info.clone(),
|
||||
);
|
||||
|
||||
Ok(PostgresSettings {
|
||||
params: parameters,
|
||||
cancel_closure,
|
||||
delayed_notice,
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,21 +311,9 @@ impl ConnectInfo {
|
||||
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
|
||||
|
||||
// TODO(conrad): we don't need to parse these.
|
||||
// These are just immediately forwarded back to the client.
|
||||
// We could instead stream them out instead of reading them into memory.
|
||||
pub struct PostgresSettings {
|
||||
/// PostgreSQL connection parameters.
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub cancel_closure: CancelClosure,
|
||||
/// Notices received from compute after authenticating
|
||||
pub delayed_notice: Vec<NoticeResponseBody>,
|
||||
}
|
||||
|
||||
pub struct ComputeConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
pub hostname: Host,
|
||||
@@ -390,6 +346,7 @@ impl ConnectInfo {
|
||||
ctx.get_testodrome_id().unwrap_or_default(),
|
||||
);
|
||||
|
||||
let stream = StartupStream::new(stream);
|
||||
let connection = ComputeConnection {
|
||||
stream,
|
||||
socket_addr,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use postgres_client::RawCancelToken;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::cancellation::{CancelClosure, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
@@ -16,7 +17,7 @@ use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::{ErrorSource, finish_client_init};
|
||||
use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub async fn task_main(
|
||||
@@ -226,21 +227,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let pg_settings = auth_info
|
||||
.authenticate(ctx, &mut node, &user_info)
|
||||
auth_info
|
||||
.authenticate(ctx, &mut node)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
send_client_greeting(ctx, &config.greetings, &mut stream);
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
finish_client_init(
|
||||
ctx,
|
||||
&pg_settings,
|
||||
*session.key(),
|
||||
&mut stream,
|
||||
&config.greetings,
|
||||
);
|
||||
let (process_id, secret_key) =
|
||||
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
|
||||
.await?;
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let hostname = node.hostname.to_string();
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
@@ -249,7 +248,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&pg_settings.cancel_closure,
|
||||
&CancelClosure {
|
||||
socket_addr: node.socket_addr,
|
||||
cancel_token: RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
hostname,
|
||||
user_info,
|
||||
},
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
@@ -257,7 +265,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
compute: node.stream,
|
||||
compute: node.stream.into_framed().into_inner(),
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
|
||||
@@ -319,7 +319,7 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client,
|
||||
compute: node.stream,
|
||||
compute: node.stream.into_framed().into_inner(),
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
|
||||
@@ -313,6 +313,14 @@ impl WriteBuf {
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// Shrinks the buffer if efficient to do so, and returns the remaining size.
|
||||
pub fn occupied_len(&mut self) -> usize {
|
||||
if self.should_shrink() {
|
||||
self.shrink();
|
||||
}
|
||||
self.0.get_mut().len()
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
///
|
||||
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
|
||||
|
||||
@@ -9,18 +9,23 @@ use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryStreamExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::RawCancelToken;
|
||||
use postgres_client::connect_raw::StartupStream;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::cache::Cache;
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::compute::ComputeConnection;
|
||||
use crate::cancellation::{CancelClosure, CancellationHandler};
|
||||
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
@@ -105,7 +110,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
// the compute was cached, and we connected, but the compute cache was actually stale
|
||||
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
|
||||
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
|
||||
let pg_settings = loop {
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
// TODO: callback to pglb
|
||||
@@ -127,9 +132,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
unreachable!("ensured above");
|
||||
};
|
||||
|
||||
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
|
||||
let res = auth_info.authenticate(ctx, &mut node).await;
|
||||
match res {
|
||||
Ok(pg_settings) => break pg_settings,
|
||||
Ok(()) => {
|
||||
send_client_greeting(ctx, &config.greetings, client);
|
||||
break;
|
||||
}
|
||||
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
|
||||
tracing::warn!(error = ?e, "retrying wake compute");
|
||||
|
||||
@@ -141,11 +149,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
}
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
}
|
||||
}
|
||||
|
||||
let auth::Backend::ControlPlane(_, user_info) = backend else {
|
||||
unreachable!("ensured above");
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings);
|
||||
let (process_id, secret_key) =
|
||||
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
|
||||
let hostname = node.hostname.to_string();
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = oneshot::channel();
|
||||
@@ -154,7 +168,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
cancel,
|
||||
&pg_settings.cancel_closure,
|
||||
&CancelClosure {
|
||||
socket_addr: node.socket_addr,
|
||||
cancel_token: RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
hostname,
|
||||
user_info,
|
||||
},
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
@@ -163,35 +186,18 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
Ok((node, cancel_on_shutdown))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
pub(crate) fn finish_client_init(
|
||||
/// Greet the client with any useful information.
|
||||
pub(crate) fn send_client_greeting(
|
||||
ctx: &RequestContext,
|
||||
settings: &compute::PostgresSettings,
|
||||
cancel_key_data: CancelKeyData,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
greetings: &String,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &settings.delayed_notice {
|
||||
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
|
||||
// Expose session_id to clients if we have a greeting message.
|
||||
if !greetings.is_empty() {
|
||||
let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
|
||||
client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &settings.params {
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
});
|
||||
}
|
||||
|
||||
// Forward recorded latencies for probing requests
|
||||
if let Some(testodrome_id) = ctx.get_testodrome_id() {
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
@@ -221,9 +227,63 @@ pub(crate) fn finish_client_init(
|
||||
value: latency_measured.retry.as_micros().to_string().as_bytes(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn forward_compute_params_to_client(
|
||||
ctx: &RequestContext,
|
||||
cancel_key_data: CancelKeyData,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
compute: &mut StartupStream<TcpStream, RustlsStream>,
|
||||
) -> Result<(i32, i32), ClientRequestError> {
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
|
||||
let err = loop {
|
||||
// if the client buffer is too large, let's write out some bytes now to save some space
|
||||
client.write_if_full().await?;
|
||||
|
||||
let msg = match compute.try_next().await {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => break postgres_client::Error::io(e),
|
||||
};
|
||||
|
||||
match msg {
|
||||
// Send our cancellation key data instead.
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
process_id = body.process_id();
|
||||
secret_key = body.secret_key();
|
||||
}
|
||||
// Forward all postgres connection params to the client.
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
if let Ok(name) = body.name()
|
||||
&& let Ok(value) = body.value()
|
||||
{
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
});
|
||||
}
|
||||
}
|
||||
// Forward all notices to the client.
|
||||
Some(Message::NoticeResponse(notice)) => {
|
||||
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
Some(Message::ReadyForQuery(_)) => {
|
||||
client.write_message(BeMessage::ReadyForQuery);
|
||||
return Ok((process_id, secret_key));
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body),
|
||||
Some(_) => break postgres_client::Error::unexpected_message(),
|
||||
None => break postgres_client::Error::closed(),
|
||||
}
|
||||
};
|
||||
|
||||
Err(client
|
||||
.throw_error(PostgresError::Postgres(err), Some(ctx))
|
||||
.await)?
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
|
||||
@@ -154,6 +154,15 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Write the buffer to the socket until we have some more space again.
|
||||
pub async fn write_if_full(&mut self) -> io::Result<()> {
|
||||
while self.write.occupied_len() > 2048 {
|
||||
self.stream.write_buf(&mut self.write).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
|
||||
Reference in New Issue
Block a user