mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 00:42:54 +00:00
move PqStream to pq_frontend.rs and rename to PgFeStream
This commit is contained in:
@@ -7,13 +7,13 @@ use crate::auth::{self, AuthFlow};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::{compute, sasl};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &RequestContext,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::stream::PqStream;
|
||||
use crate::stream::PqFeStream;
|
||||
use crate::types::RoleName;
|
||||
use crate::{auth, compute, waiters};
|
||||
|
||||
@@ -96,7 +96,7 @@ impl ConsoleRedirectBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
authenticate(ctx, auth_config, &self.console_uri, client)
|
||||
.await
|
||||
@@ -122,7 +122,7 @@ async fn authenticate(
|
||||
ctx: &RequestContext,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::stream::{self, Stream};
|
||||
pub(crate) async fn authenticate_cleartext(
|
||||
ctx: &RequestContext,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
@@ -61,7 +61,7 @@ pub(crate) async fn authenticate_cleartext(
|
||||
pub(crate) async fn password_hack_no_authentication(
|
||||
ctx: &RequestContext,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
) -> auth::Result<(ComputeUserInfo, Vec<u8>)> {
|
||||
debug!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
@@ -201,7 +201,7 @@ async fn auth_quirks(
|
||||
ctx: &RequestContext,
|
||||
api: &impl control_plane::ControlPlaneApi,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -267,7 +267,7 @@ async fn authenticate_with_secret(
|
||||
ctx: &RequestContext,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -318,7 +318,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
pub(crate) async fn authenticate(
|
||||
self,
|
||||
ctx: &RequestContext,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
client: &mut stream::PqFeStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -446,7 +446,7 @@ mod tests {
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
@@ -522,7 +522,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -604,7 +604,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -658,7 +658,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
let mut stream = PqFeStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram::{self};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
@@ -53,7 +53,7 @@ pub(crate) struct CleartextPassword {
|
||||
#[must_use]
|
||||
pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
stream: &'a mut PqFeStream<Stream<S>>,
|
||||
/// State might contain ancillary data.
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
@@ -62,7 +62,7 @@ pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
|
||||
pub(crate) fn new(stream: &'a mut PqFeStream<Stream<S>>, method: M) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
@@ -262,7 +262,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<TlsStream<S>> {
|
||||
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
let (mut stream, msg) = PqFeStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct: None } => {
|
||||
let raw = stream.accept_tls().await?;
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::stream::{PqFeStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@@ -49,7 +49,7 @@ impl ReportableError for HandshakeError {
|
||||
}
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Startup(PqFeStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
let (mut stream, mut msg) = PqFeStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
loop {
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
|
||||
@@ -152,7 +152,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
tls: tls_stream,
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
(stream, msg) = PqFeStream::parse_startup(tls).await?;
|
||||
} else {
|
||||
if direct.is_some() {
|
||||
// client sent us a ClientHello already, we can't do anything with it.
|
||||
|
||||
@@ -30,7 +30,7 @@ use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
@@ -415,7 +415,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
stream: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
|
||||
@@ -122,7 +122,7 @@ fn generate_tls_config<'a>(
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
stream: &mut PqFeStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
Ok(())
|
||||
@@ -151,7 +151,7 @@ impl Scram {
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
stream: &mut PqFeStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
|
||||
.authenticate()
|
||||
|
||||
@@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use super::{Mechanism, Step};
|
||||
use crate::context::RequestContext;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::stream::PqStream;
|
||||
use crate::stream::PqFeStream;
|
||||
|
||||
/// SASL authentication outcome.
|
||||
/// It's much easier to match on those two variants
|
||||
@@ -22,7 +22,7 @@ pub(crate) enum Outcome<R> {
|
||||
|
||||
pub async fn authenticate<S, F, M>(
|
||||
ctx: &RequestContext,
|
||||
stream: &mut PqStream<S>,
|
||||
stream: &mut PqFeStream<S>,
|
||||
mechanism: F,
|
||||
) -> super::Result<Outcome<M::Output>>
|
||||
where
|
||||
|
||||
@@ -1,100 +1,19 @@
|
||||
mod pq_frontend;
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
pub use pq_frontend::PqFeStream;
|
||||
use rustls::ServerConfig;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
///
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {:?}, got {:?}",
|
||||
tag as char, actual_tag as char,
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE.0, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportedError {
|
||||
source: anyhow::Error,
|
||||
@@ -129,110 +48,6 @@ impl ReportableError for ReportedError {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
}
|
||||
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
|
||||
192
proxy/src/stream/pq_frontend.rs
Normal file
192
proxy/src/stream/pq_frontend.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! Postgres connection from frontend, proxy is the backend.
|
||||
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::error::{ErrorKind, UserFacingError};
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
};
|
||||
use crate::stream::ReportedError;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
pub struct PqFeStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqFeStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqFeStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqFeStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: FeTag, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
let actual_tag = FeTag(actual_tag);
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {tag}, got {actual_tag}",
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqFeStream<S> {
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
}
|
||||
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user