completely rewrite pq_proto (#12085)

libs/pqproto is designed for safekeeper/pageserver with maximum
throughput.

proxy only needs it for handshakes/authentication where throughput is
not a concern but memory efficiency is. For this reason, we switch to
using read_exact and only allocating as much memory as we need to.

All reads return a `&'a [u8]` instead of a `Bytes` because accidental
sharing of bytes can cause fragmentation. Returning the reference
enforces all callers only hold onto the bytes they absolutely need. For
example, before this change, `pqproto` was allocating 8KiB for the
initial read `BytesMut`, and proxy was holding the `Bytes` in the
`StartupMessageParams` for the entire connection through to passthrough.
This commit is contained in:
Conrad Ludgate
2025-06-01 19:41:45 +01:00
committed by GitHub
parent f05df409bd
commit 87179e26b3
29 changed files with 1122 additions and 600 deletions

View File

@@ -17,35 +17,27 @@ pub(super) async fn authenticate(
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
AuthFlow::new(client, scram)
.authenticate()
.await
.inspect_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
})
.await
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,

View File

@@ -2,7 +2,6 @@ use std::fmt;
use async_trait::async_trait;
use postgres_client::config::SslMode;
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
@@ -16,6 +15,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
@@ -154,11 +154,13 @@ async fn authenticate(
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
client.write_message(BeMessage::AuthenticationOk);
client.write_message(BeMessage::ParameterStatus {
name: b"client_encoding",
value: b"UTF8",
});
client.write_message(BeMessage::NoticeResponse(&greeting));
client.flush().await?;
// Wait for console response via control plane (see `mgmt`).
info!(parent: &span, "waiting for console's reply...");
@@ -188,7 +190,7 @@ async fn authenticate(
}
}
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.

View File

@@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext(
debug!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword {
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
let auth_outcome = auth_flow.authenticate().await?;
},
);
let auth_outcome = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
auth_flow.authenticate().await?
};
let keys = match auth_outcome {
sasl::Outcome::Success(key) => key,
@@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication(
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
.await?
let payload = AuthFlow::new(client, auth::PasswordHack)
.get_password()
.await?;

View File

@@ -31,6 +31,7 @@ use crate::control_plane::{
};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::pqproto::BeMessage;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
@@ -402,7 +403,7 @@ async fn authenticate_with_secret(
};
// we have authenticated the password
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
client.write_message(BeMessage::AuthenticationOk);
return Ok(ComputeCredentials { info, keys });
}
@@ -702,7 +703,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -784,7 +785,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -838,7 +839,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {

View File

@@ -5,7 +5,6 @@ use std::net::IpAddr;
use std::str::FromStr;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use thiserror::Error;
use tracing::{debug, warn};
@@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
use crate::types::{EndpointId, RoleName};

View File

@@ -1,10 +1,8 @@
//! Main authentication flow.
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(
pub(crate) &'a scram::ServerSecret,
pub(crate) &'a RequestContext,
);
impl AuthMethod for Scram<'_> {
impl Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
@@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> {
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub(crate) struct CleartextPassword {
@@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword {
pub(crate) secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
/// State might contain ancillary data.
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {
stream,
state: Begin,
state: method,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -133,6 +99,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -147,7 +117,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
self.stream.write_message(BeMessage::AuthenticationOk);
}
Ok(outcome)
@@ -159,42 +129,36 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
let channel_binding = self.tls_server_end_point;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// send sasl message.
{
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(super::AuthError::bad_auth_method(sasl.method));
let sasl = self.state.first_message(channel_binding.supported());
self.stream.write_message(sasl);
self.stream.flush().await?;
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
_ => {}
}
// complete sasl handshake.
sasl::authenticate(ctx, self.stream, |method| {
// Currently, the only supported SASL method is SCRAM.
match method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
}
method => return Err(sasl::Error::BadAuthMethod(method.into())),
}
// TODO: make this a metric instead
info!("client chooses {}", sasl.method);
// TODO: make this a metric instead
info!("client chooses {}", method);
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
Ok(scram::Exchange::new(secret, rand::random, channel_binding))
})
.await
.map_err(AuthError::Sasl)
}
}

View File

@@ -4,8 +4,9 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::net::SocketAddr;
use std::path::Path;
use std::{net::SocketAddr, sync::Arc};
use std::sync::Arc;
use anyhow::{Context, anyhow, bail, ensure};
use clap::Arg;
@@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
@@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
project_git_version!(GIT_VERSION);
@@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
.parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
))
@@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(compute_tls_config),
tls_server_end_point,
proxy_listener_compute_tls,
cancellation_token.clone(),
))
@@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
pub(super) fn parse_tls(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
@@ -187,10 +189,6 @@ pub(super) fn parse_tls(
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
@@ -199,14 +197,13 @@ pub(super) fn parse_tls(
.with_single_cert(cert_chain, key)?
.into();
Ok((tls_config, tls_server_end_point))
Ok(tls_config)
}
pub(super) async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -242,15 +239,7 @@ pub(super) async fn task_main(
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(
ctx,
dest_suffix,
tls_config,
compute_tls_config,
tls_server_end_point,
socket,
)
.await
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -269,55 +258,26 @@ pub(super) async fn task_main(
Ok(())
}
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
match msg {
SslRequest { direct: false } => {
stream
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
.await?;
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(raw
.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?)
}
unexpected => {
info!(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
.await?
Err(stream.throw_error(TlsRequired, None).await)?
}
}
}
@@ -327,15 +287,18 @@ async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File

@@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> {
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let (tls_config, tls_server_end_point) =
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
@@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
listen,
cancellation_token.clone(),
));
@@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
tls_server_end_point,
listen_tls,
cancellation_token.clone(),
));

View File

@@ -5,7 +5,6 @@ use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use pq_proto::CancelKeyData;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
@@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::pqproto::CancelKeyData;
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;

View File

@@ -8,7 +8,6 @@ use itertools::Itertools;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
@@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};
@@ -221,12 +221,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e, Some(ctx)).await?;
}
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let mut node = connect_to_compute(
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
@@ -238,7 +236,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
@@ -246,14 +244,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
Ok(Some(ProxyPassthrough {
client: stream,

View File

@@ -4,7 +4,6 @@ use std::net::IpAddr;
use chrono::Utc;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use tokio::sync::mpsc;
use tracing::field::display;
@@ -20,6 +19,7 @@ use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
Waiting,
};
use crate::pqproto::StartupMessageParams;
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
use crate::types::{DbName, EndpointId, RoleName};

View File

@@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr;
use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr};
use parquet::file::writer::SerializedFileWriter;
use parquet::record::RecordWriter;
use pq_proto::StartupMessageParams;
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
use serde::ser::SerializeMap;
use tokio::sync::mpsc;
@@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner};
use crate::config::remote_storage_from_toml;
use crate::context::LOG_CHAN_DISCONNECT;
use crate::ext::TaskExt;
use crate::pqproto::StartupMessageParams;
#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {

View File

@@ -92,6 +92,7 @@ mod logging;
mod metrics;
mod parse;
mod pglb;
mod pqproto;
mod protocol2;
mod proxy;
mod rate_limiter;

693
proxy/src/pqproto.rs Normal file
View File

@@ -0,0 +1,693 @@
//! Postgres protocol codec
//!
//! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
use std::fmt;
use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
/// The protocol version number.
///
/// The most significant 16 bits are the major version number (3 for the protocol described here).
/// The least significant 16 bits are the minor version number (0 for the protocol described here).
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
pub struct ProtocolVersion {
major: big_endian::U16,
minor: big_endian::U16,
}
impl ProtocolVersion {
pub const fn new(major: u16, minor: u16) -> Self {
Self {
major: big_endian::U16::new(major),
minor: big_endian::U16::new(minor),
}
}
pub const fn minor(self) -> u16 {
self.minor.get()
}
pub const fn major(self) -> u16 {
self.major.get()
}
}
impl fmt::Debug for ProtocolVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entry(&self.major())
.entry(&self.minor())
.finish()
}
}
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
macro_rules! read {
($s:expr => $t:ty) => {{
// cannot be implemented as a function due to lack of const-generic-expr
let mut buf = [0; size_of::<$t>()];
$s.read_exact(&mut buf).await?;
let res: $t = zerocopy::transmute!(buf);
res
}};
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, is 8 bytes.
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
// First byte indicates standard SSL handshake message
// (It can't be a Postgres startup length because in network byte order
// that would be a startup packet hundreds of megabytes long)
if header.as_bytes()[0] == 0x16 {
return Ok(FeStartupPacket::SslRequest {
// The bytes we read for the header are actually part of a TLS ClientHello.
// In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
// In practice though, I see no world where a ClientHello is less than 8 bytes
// since it includes ephemeral keys etc.
direct: Some(zerocopy::transmute!(header)),
});
}
let Some(len) = (header.len.get() as usize).checked_sub(8) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 8.",
header.len,
)));
};
// TODO: add a histogram for startup packet lengths
if len > MAX_STARTUP_PACKET_LENGTH {
tracing::warn!("large startup message detected: {len} bytes");
return Err(io::Error::other(format!(
"invalid startup message length {len}"
)));
}
match header.version {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
CANCEL_REQUEST_CODE => {
if len != 8 {
return Err(io::Error::other(
"CancelRequest message is malformed, backend PID / secret key missing",
));
}
Ok(FeStartupPacket::CancelRequest(
read!(stream => CancelKeyData),
))
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
NEGOTIATE_SSL_CODE => {
// Requested upgrade to SSL (aka TLS)
Ok(FeStartupPacket::SslRequest { direct: None })
}
NEGOTIATE_GSS_CODE => {
// Requested upgrade to GSSAPI
Ok(FeStartupPacket::GssEncRequest)
}
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
format!("Unrecognized request code {version:?}"),
)),
// StartupMessage
version => {
// The protocol version number is followed by one or more pairs of parameter name and value strings.
// A zero byte is required as a terminator after the last name/value pair.
// Parameters can appear in any order. user is required, others are optional.
let mut buf = vec![0; len];
stream.read_exact(&mut buf).await?;
if buf.pop() != Some(b'\0') {
return Err(io::Error::other(
"StartupMessage params: missing null terminator",
));
}
// TODO: Don't do this.
// There's no guarantee that these messages are utf8,
// but they usually happen to be simple ascii.
let params = String::from_utf8(buf)
.map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
Ok(FeStartupPacket::StartupMessage {
version,
params: StartupMessageParams { params },
})
}
}
}
/// Read a raw postgres packet, which will respect the max length requested.
///
/// This returns the message tag, as well as the message body. The message
/// body is written into `buf`, and it is otherwise completely overwritten.
///
/// This is not cancel safe.
pub async fn read_message<'a, S>(
stream: &mut S,
buf: &'a mut Vec<u8>,
max: usize,
) -> io::Result<(u8, &'a mut [u8])>
where
S: AsyncRead + Unpin,
{
/// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
/// The first byte is a message tag, and the next 4 bytes is a big-endian length.
///
/// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
/// an empty message will always have length 4.
#[derive(Clone, Copy, FromBytes)]
#[repr(C)]
struct Header {
tag: u8,
len: big_endian::U32,
}
let header = read!(stream => Header);
// as described above, the length must be at least 4.
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 4.",
header.len,
)));
};
// TODO: add a histogram for message lengths
// check if the message exceeds our desired max.
if len > max {
tracing::warn!("large postgres message detected: {len} bytes");
return Err(io::Error::other(format!("invalid message length {len}")));
}
// read in our entire message.
buf.resize(len, 0);
stream.read_exact(buf).await?;
Ok((header.tag, buf))
}
pub struct WriteBuf(Cursor<Vec<u8>>);
impl Buf for WriteBuf {
#[inline]
fn remaining(&self) -> usize {
self.0.remaining()
}
#[inline]
fn chunk(&self) -> &[u8] {
self.0.chunk()
}
#[inline]
fn advance(&mut self, cnt: usize) {
self.0.advance(cnt);
}
}
impl WriteBuf {
pub const fn new() -> Self {
Self(Cursor::new(Vec::new()))
}
/// Use a heuristic to determine if we should shrink the write buffer.
#[inline]
fn should_shrink(&self) -> bool {
let n = self.0.position() as usize;
let len = self.0.get_ref().len();
// the unused space at the front of our buffer is 2x the size of our filled portion.
n + n > len
}
/// Shrink the write buffer so that subsequent writes have more spare capacity.
#[cold]
fn shrink(&mut self) {
let n = self.0.position() as usize;
let buf = self.0.get_mut();
// buf repr:
// [----unused------|-----filled-----|-----uninit-----]
// ^ n ^ buf.len() ^ buf.capacity()
let filled = n..buf.len();
let filled_len = filled.len();
buf.copy_within(filled, 0);
buf.truncate(filled_len);
self.0.set_position(0);
}
/// clear the write buffer.
pub fn reset(&mut self) {
let buf = self.0.get_mut();
buf.clear();
self.0.set_position(0);
}
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
/// we calculate the length after the fact.
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
if self.should_shrink() {
self.shrink();
}
let buf = self.0.get_mut();
buf.reserve(5 + size_hint);
buf.push(tag);
let start = buf.len();
buf.extend_from_slice(&[0, 0, 0, 0]);
f(buf);
let end = buf.len();
let len = (end - start) as u32;
buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
}
/// Write an encryption response message.
pub fn encryption(&mut self, m: u8) {
self.0.get_mut().push(m);
}
pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
self.shrink();
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
// "SERROR\0CXXXXX\0M\0\0".len() == 17
self.write_raw(17 + msg.len(), b'E', |buf| {
// Severity: ERROR
buf.put_slice(b"SERROR\0");
// Code: error_code
buf.put_u8(b'C');
buf.put_slice(&error_code);
buf.put_u8(0);
// Message: msg
buf.put_u8(b'M');
buf.put_slice(msg.as_bytes());
buf.put_u8(0);
// End.
buf.put_u8(0);
});
}
}
#[derive(Debug)]
pub enum FeStartupPacket {
CancelRequest(CancelKeyData),
SslRequest {
direct: Option<[u8; 8]>,
},
GssEncRequest,
StartupMessage {
version: ProtocolVersion,
params: StartupMessageParams,
},
}
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
pub params: String,
}
impl StartupMessageParams {
/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.iter().find_map(|(k, v)| (k == name).then_some(v))
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
self.get("options").map(Self::parse_options_raw)
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
// See `postgres: pg_split_opts`.
let mut last_was_escape = false;
input
.split(move |c: char| {
// We split by non-escaped whitespace symbols.
let should_split = c.is_ascii_whitespace() && !last_was_escape;
last_was_escape = c == '\\' && !last_was_escape;
should_split
})
.filter(|s| !s.is_empty())
}
/// Iterate through key-value pairs in an arbitrary order.
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.params.split_terminator('\0').tuples()
}
// This function is mostly useful in tests.
#[cfg(test)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
let mut b = Self {
params: String::new(),
};
for (k, v) in pairs {
b.insert(k, v);
}
b
}
/// Set parameter's value by its name.
/// name and value must not contain a \0 byte
pub fn insert(&mut self, name: &str, value: &str) {
self.params.reserve(name.len() + value.len() + 2);
self.params.push_str(name);
self.params.push('\0');
self.params.push_str(value);
self.params.push('\0');
}
}
/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
/// opaque bytes.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
pub struct CancelKeyData(pub big_endian::U64);
pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
CancelKeyData(big_endian::U64::new(id))
}
impl fmt::Display for CancelKeyData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let id = self.0;
f.debug_tuple("CancelKeyData")
.field(&format_args!("{id:x}"))
.finish()
}
}
impl Distribution<CancelKeyData> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
id_to_cancel_key(rng.r#gen())
}
}
pub enum BeMessage<'a> {
AuthenticationOk,
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
AuthenticationCleartextPassword,
BackendKeyData(CancelKeyData),
ParameterStatus {
name: &'a [u8],
value: &'a [u8],
},
ReadyForQuery,
NoticeResponse(&'a str),
NegotiateProtocolVersion {
version: ProtocolVersion,
options: &'a [&'a str],
},
}
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
Methods(&'a [&'a str]),
Continue(&'a [u8]),
Final(&'a [u8]),
}
impl BeMessage<'_> {
/// Write the message into an internal buffer
pub fn write_message(self, buf: &mut WriteBuf) {
match self {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationOk => {
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationCleartextPassword => {
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
buf.write_raw(len + 2, b'R', |buf| {
buf.put_i32(10); // Specifies that SASL auth method is used.
for method in methods {
buf.put_slice(method.as_bytes());
buf.put_u8(0);
}
buf.put_u8(0); // zero terminator for the list
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.put_i32(11); // Continue SASL auth.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.put_i32(12); // Send final SASL message.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
BeMessage::BackendKeyData(key_data) => {
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
BeMessage::NoticeResponse(msg) => {
// 'N' signalizes NoticeResponse messages
buf.write_raw(18 + msg.len(), b'N', |buf| {
// Severity: NOTICE
buf.put_slice(b"SNOTICE\0");
// Code: XX000 (ignored for notice, but still required)
buf.put_slice(b"CXX000\0");
// Message: msg
buf.put_u8(b'M');
buf.put_slice(msg.as_bytes());
buf.put_u8(0);
// End notice.
buf.put_u8(0);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
BeMessage::ParameterStatus { name, value } => {
buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
buf.put_slice(name.as_bytes());
buf.put_u8(0);
buf.put_slice(value.as_bytes());
buf.put_u8(0);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::ReadyForQuery => {
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::NegotiateProtocolVersion { version, options } => {
let len: usize = options.iter().map(|o| o.len() + 1).sum();
buf.write_raw(8 + len, b'v', |buf| {
buf.put_slice(version.as_bytes());
buf.put_u32(options.len() as u32);
for option in options {
buf.put_slice(option.as_bytes());
buf.put_u8(0);
}
});
}
}
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
#[tokio::test]
async fn reject_large_startup() {
// we're going to define a v3.0 startup message with far too many parameters.
let mut payload = vec![];
// 10001 + 8 bytes.
payload.extend_from_slice(&10009_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
payload.resize(10009, b'a');
let (mut server, mut client) = duplex(128);
#[rustfmt::skip]
let (server, client) = tokio::join!(
async move { read_startup(&mut server).await.unwrap_err() },
async move { client.write_all(&payload).await.unwrap_err() },
);
assert_eq!(server.to_string(), "invalid startup message length 10001");
assert_eq!(client.to_string(), "broken pipe");
}
#[tokio::test]
async fn reject_large_password() {
// we're going to define a password message that is far too long.
let mut payload = vec![];
payload.push(b'p');
payload.extend_from_slice(&517_u32.to_be_bytes());
payload.resize(518, b'a');
let (mut server, mut client) = duplex(128);
#[rustfmt::skip]
let (server, client) = tokio::join!(
async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
async move { client.write_all(&payload).await.unwrap_err() },
);
assert_eq!(server.to_string(), "invalid message length 513");
assert_eq!(client.to_string(), "broken pipe");
}
#[tokio::test]
async fn read_startup_message() {
let mut payload = vec![];
payload.extend_from_slice(&17_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
payload.extend_from_slice(b"abc\0def\0\0");
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
let FeStartupPacket::StartupMessage { version, params } = startup else {
panic!("unexpected startup message: {startup:?}");
};
assert_eq!(version.major(), 3);
assert_eq!(version.minor(), 0);
assert_eq!(params.params, "abc\0def\0");
}
#[tokio::test]
async fn read_ssl_message() {
let mut payload = vec![];
payload.extend_from_slice(&8_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
let FeStartupPacket::SslRequest { direct: None } = startup else {
panic!("unexpected startup message: {startup:?}");
};
}
#[tokio::test]
async fn read_tls_message() {
// sample client hello taken from <https://tls13.xargs.org/#client-hello>
let client_hello = [
0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
0x54,
];
let mut cursor = Cursor::new(&client_hello);
let startup = read_startup(&mut cursor).await.unwrap();
let FeStartupPacket::SslRequest {
direct: Some(prefix),
} = startup
else {
panic!("unexpected startup message: {startup:?}");
};
// check that no data is lost.
assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
assert_eq!(cursor.position(), 8);
}
#[tokio::test]
async fn read_message_success() {
let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
let mut cursor = Cursor::new(&query);
let mut buf = vec![];
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
assert_eq!(tag, b'Q');
assert_eq!(message, b"SELECT 1");
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
assert_eq!(tag, b'Q');
assert_eq!(message, b"SELECT 2");
}
}

View File

@@ -1,5 +1,4 @@
use async_trait::async_trait;
use pq_proto::StartupMessageParams;
use tokio::time;
use tracing::{debug, info, warn};
@@ -15,6 +14,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;

View File

@@ -1,8 +1,3 @@
use bytes::Buf;
use pq_proto::framed::Framed;
use pq_proto::{
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
@@ -12,7 +7,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
@@ -71,33 +69,25 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
loop {
let msg = stream.read_startup_packet().await?;
match msg {
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let have_tls = tls.is_some();
if !direct {
stream
.write_message(&Be::EncryptionResponse(have_tls))
.await?;
} else if !have_tls {
return Err(HandshakeError::ProtocolViolation);
}
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let Framed {
stream: raw,
read_buf,
write_buf,
} = stream.framed;
let mut read_buf;
let raw = if let Some(direct) = &direct {
read_buf = &direct[..];
stream.accept_direct_tls()
} else {
read_buf = &[];
stream.accept_tls().await?
};
let Stream::Raw { raw } = raw else {
return Err(HandshakeError::StreamUpgradeError(
@@ -105,12 +95,11 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
));
};
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
.accept_with(raw, |session| {
// push the early data to the tls session
while !read_buf.get_ref().is_empty() {
while !read_buf.is_empty() {
match session.read_tls(&mut read_buf) {
Ok(_) => {}
Err(e) => {
@@ -123,7 +112,6 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
res?;
let read_buf = read_buf.into_inner();
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}
@@ -157,16 +145,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
},
let tls = Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;
} else {
if direct.is_some() {
// client sent us a ClientHello already, we can't do anything with it.
return Err(HandshakeError::ProtocolViolation);
}
msg = stream.reject_encryption().await?;
}
}
_ => return Err(HandshakeError::ProtocolViolation),
@@ -176,7 +165,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
msg = stream.reject_encryption().await?;
}
_ => return Err(HandshakeError::ProtocolViolation),
},
@@ -186,13 +175,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
return stream
.throw_error_str(
ERR_INSECURE_CONNECTION,
crate::error::ErrorKind::User,
None,
)
.await?;
Err(stream.throw_error(TlsRequired, None).await)?;
}
// This log highlights the start of the connection.
@@ -214,20 +197,21 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// no protocol extensions are supported.
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
let mut unsupported = vec![];
for (k, _) in params.iter() {
let mut supported = StartupMessageParams::default();
for (k, v) in params.iter() {
if k.starts_with("_pq_.") {
unsupported.push(k);
} else {
supported.insert(k, v);
}
}
// TODO: remove unsupported options so we don't send them to compute.
stream
.write_message(&Be::NegotiateProtocolVersion {
version: PG_PROTOCOL_LATEST,
options: &unsupported,
})
.await?;
stream.write_message(BeMessage::NegotiateProtocolVersion {
version: PG_PROTOCOL_LATEST,
options: &unsupported,
});
stream.flush().await?;
info!(
?version,
@@ -235,7 +219,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, supported));
}
FeStartupPacket::StartupMessage { version, params } => {
warn!(

View File

@@ -10,15 +10,14 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::{FutureExt, TryFutureExt};
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
@@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
@@ -38,6 +38,18 @@ use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
@@ -329,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
@@ -349,10 +361,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream
return Err(stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
.await)?;
}
};
@@ -365,7 +377,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let mut node = connect_to_compute(
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
@@ -377,22 +389,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
.await;
let node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
@@ -413,31 +422,28 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
pub(crate) async fn prepare_client_connection(
pub(crate) fn prepare_client_connection(
node: &compute::PostgresConnection,
cancel_key_data: CancelKeyData,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
) {
// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Forward all postgres connection params to the client.
for (name, value) in &node.params {
stream.write_message_noflush(&Be::ParameterStatus {
stream.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
});
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
stream.write_message(BeMessage::ReadyForQuery);
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]

View File

@@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
#[cfg(test)]
mod tests {
use super::ShouldRetryWakeCompute;
use postgres_client::error::{DbError, SqlState};
use super::ShouldRetryWakeCompute;
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.

View File

@@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_client::tls::TlsConnect;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
@@ -49,15 +49,14 @@ async fn proxy_mitm(
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let end_client = end_client.flush_and_into_inner().await.unwrap();
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(
&postgres_protocol::message::frontend::StartupMessageParams {
params: startup.params.into(),
params: startup.params.as_bytes().into(),
},
&mut buf,
)

View File

@@ -128,7 +128,7 @@ trait TestAuth: Sized {
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
stream.write_message_noflush(&Be::AuthenticationOk)?;
stream.write_message(BeMessage::AuthenticationOk);
Ok(())
}
}
@@ -157,9 +157,7 @@ impl TestAuth for Scram {
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0, &RequestContext::test()))
.await?
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
.authenticate()
.await?;
@@ -185,10 +183,12 @@ async fn dummy_proxy(
auth.authenticate(&mut stream).await?;
stream
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::ReadyForQuery)
.await?;
stream.write_message(BeMessage::ParameterStatus {
name: b"client_encoding",
value: b"UTF8",
});
stream.write_message(BeMessage::ReadyForQuery);
stream.flush().await?;
Ok(())
}

View File

@@ -1,10 +1,11 @@
use core::net::IpAddr;
use std::sync::Arc;
use pq_proto::CancelKeyData;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(

View File

@@ -1,16 +1,15 @@
use std::io::ErrorKind;
use anyhow::Ok;
use pq_proto::{CancelKeyData, id_to_cancel_key};
use serde::{Deserialize, Serialize};
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
pub mod keyspace {
pub const CANCEL_PREFIX: &str = "cancel";
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum KeyPrefix {
#[serde(untagged)]
Cancel(CancelKeyData),
}
@@ -18,9 +17,7 @@ impl KeyPrefix {
pub(crate) fn build_redis_key(&self) -> String {
match self {
KeyPrefix::Cancel(key) => {
let hi = (key.backend_pid as u64) << 32;
let lo = (key.cancel_key as u64) & 0xffff_ffff;
let id = hi | lo;
let id = key.0.get();
let keyspace = keyspace::CANCEL_PREFIX;
format!("{keyspace}:{id:x}")
}
@@ -63,10 +60,7 @@ mod tests {
#[test]
fn test_build_redis_key() {
let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData {
backend_pid: 12345,
cancel_key: 54321,
});
let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321));
let redis_key = cancel_key.build_redis_key();
assert_eq!(redis_key, "cancel:30390000d431");
@@ -77,10 +71,7 @@ mod tests {
let redis_key = "cancel:30390000d431";
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
let ref_key = CancelKeyData {
backend_pid: 12345,
cancel_key: 54321,
};
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
let KeyPrefix::Cancel(cancel_key) = key;

View File

@@ -2,11 +2,9 @@ use std::convert::Infallible;
use std::sync::Arc;
use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
@@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate {
role_name: RoleNameInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
pub(crate) region_id: Option<String>,
pub(crate) cancel_key_data: CancelKeyData,
pub(crate) session_id: Uuid,
pub(crate) peer_addr: Option<std::net::IpAddr>,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,

View File

@@ -1,7 +1,5 @@
//! Definitions for SASL messages.
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
use crate::parse::split_cstr;
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
@@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> {
}
}
/// A single SASL message.
/// This struct is deliberately decoupled from lower-level
/// [`BeAuthenticationSaslMessage`].
#[derive(Debug)]
pub(super) enum ServerMessage<T> {
/// We expect to see more steps.
Continue(T),
/// This is the final step.
Final(T),
}
impl<'a> ServerMessage<&'a str> {
pub(super) fn to_reply(&self) -> BeMessage<'a> {
BeMessage::AuthenticationSasl(match self {
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -14,7 +14,7 @@ use std::io;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
pub(crate) use stream::{Outcome, authenticate};
use thiserror::Error;
use crate::error::{ReportableError, UserFacingError};
@@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub(crate) enum Error {
#[error("Unsupported authentication method: {0}")]
BadAuthMethod(Box<str>),
#[error("Channel binding failed: {0}")]
ChannelBindingFailed(&'static str),
@@ -54,6 +57,7 @@ impl UserFacingError for Error {
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::BadAuthMethod(_) => crate::error::ErrorKind::User,
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,

View File

@@ -3,61 +3,12 @@
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::Mechanism;
use super::messages::ServerMessage;
use super::{Mechanism, Step};
use crate::context::RequestContext;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::stream::PqStream;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream<'a, S> {
/// The underlying stream.
stream: &'a mut PqStream<S>,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a, S> SaslStream<'a, S> {
pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
// Queue a SASL message for the client.
fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message_noflush(&msg.to_reply())?;
Ok(())
}
}
/// SASL authentication outcome.
/// It's much easier to match on those two variants
/// than to peek into a noisy protocol error type.
@@ -69,33 +20,62 @@ pub(crate) enum Outcome<R> {
Failure(&'static str),
}
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub(crate) async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> super::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let step = mechanism.exchange(input).map_err(|error| {
info!(?error, "error during SASL exchange");
error
})?;
pub async fn authenticate<S, F, M>(
ctx: &RequestContext,
stream: &mut PqStream<S>,
mechanism: F,
) -> super::Result<Outcome<M::Output>>
where
S: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&str) -> super::Result<M>,
M: Mechanism,
{
let sasl = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
use super::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved_mechanism;
continue;
}
Step::Success(result, reply) => {
self.send_noflush(&ServerMessage::Final(&reply))?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),
});
// Initial client message contains the chosen auth method's name.
let msg = stream.read_password_message().await?;
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
};
let mut mechanism = mechanism(sasl.method)?;
let mut input = sasl.message;
loop {
let step = mechanism
.exchange(input)
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
match step {
Step::Continue(moved_mechanism, reply) => {
mechanism = moved_mechanism;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
Step::Success(result, reply) => {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
stream.write_message(BeMessage::AuthenticationOk);
// exit with success
break Ok(Outcome::Success(result));
}
// exit with failure
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
}
}
}

View File

@@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
@@ -41,6 +40,7 @@ use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
@@ -219,7 +219,7 @@ fn get_conn_info(
let mut options = Option::None;
let mut params = StartupMessageParamsBuilder::default();
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {

View File

@@ -2,19 +2,17 @@ use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::pqproto::{
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
read_message, read_startup,
};
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
@@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
stream: S,
read: Vec<u8>,
write: WriteBuf,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
#[cfg(test)]
pub fn new_skip_handshake(stream: S) -> Self {
Self {
framed: Framed::new(stream),
stream,
read: Vec::new(),
write: WriteBuf::new(),
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
}
/// Get a shared reference to the underlying stream.
pub(crate) fn get_ref(&self) -> &S {
self.framed.get_ref()
}
}
fn err_connection() -> io::Error {
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
/// Construct a new libpq protocol wrapper and read the first startup message.
///
/// This is not cancel safe.
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
let startup = read_startup(&mut stream).await?;
Ok((
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
},
startup,
))
}
/// Tell the client that encryption is not supported.
///
/// This is not cancel safe
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
// N for No.
self.write.encryption(b'N');
self.flush().await?;
read_startup(&mut self.stream).await
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
self.framed
.read_startup_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
self.framed
.read_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {bad:?}"),
)),
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
"incorrect message tag, expected {:?}, got {:?}",
tag as char, actual_tag as char,
)));
}
Ok(msg)
}
/// Read a postgres password message, which will respect the max length requested.
/// This is not cancel safe.
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: usize = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}
}
@@ -84,6 +101,16 @@ pub struct ReportedError {
error_kind: ErrorKind,
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),
error_kind,
}
}
}
impl std::fmt::Display for ReportedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.source.fmt(f)
@@ -102,109 +129,65 @@ impl ReportableError for ReportedError {
}
}
#[derive(Serialize, Deserialize, Debug)]
enum ErrorTag {
#[serde(rename = "proxy")]
Proxy,
#[serde(rename = "compute")]
Compute,
#[serde(rename = "client")]
Client,
#[serde(rename = "controlplane")]
ControlPlane,
#[serde(rename = "other")]
Other,
}
impl From<ErrorKind> for ErrorTag {
fn from(error_kind: ErrorKind) -> Self {
match error_kind {
ErrorKind::User => Self::Client,
ErrorKind::ClientDisconnect => Self::Client,
ErrorKind::RateLimit => Self::Proxy,
ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
ErrorKind::Quota => Self::Proxy,
ErrorKind::Service => Self::Proxy,
ErrorKind::ControlPlane => Self::ControlPlane,
ErrorKind::Postgres => Self::Other,
ErrorKind::Compute => Self::Compute,
}
}
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
struct ProbeErrorData {
tag: ErrorTag,
msg: String,
cold_start_info: Option<ColdStartInfo>,
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub(crate) fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> io::Result<&mut Self> {
self.framed
.write_message(message)
.map_err(ProtocolError::into_io_error)?;
Ok(self)
/// Tell the client that we are willing to accept SSL.
/// This is not cancel safe
pub async fn accept_tls(mut self) -> io::Result<S> {
// S for SSL.
self.write.encryption(b'S');
self.flush().await?;
Ok(self.stream)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
/// Assert that we are using direct TLS.
pub fn accept_direct_tls(self) -> S {
self.stream
}
/// Write a raw message to the internal buffer.
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
self.write.write_raw(size_hint, tag, f);
}
/// Write the message into an internal buffer
pub fn write_message(&mut self, message: BeMessage<'_>) {
message.write_message(&mut self.write);
}
/// Flush the output buffer into the underlying stream.
pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
self.framed.flush().await?;
Ok(self)
///
/// This is cancel safe.
pub async fn flush(&mut self) -> io::Result<()> {
self.stream.write_all_buf(&mut self.write).await?;
self.write.reset();
self.stream.flush().await?;
Ok(())
}
/// Writes message with the given error kind to the stream.
/// Used only for probe queries
async fn write_format_message(
&mut self,
msg: &str,
error_kind: ErrorKind,
ctx: Option<&crate::context::RequestContext>,
) -> String {
let formatted_msg = match ctx {
Some(ctx) if ctx.get_testodrome_id().is_some() => {
serde_json::to_string(&ProbeErrorData {
tag: ErrorTag::from(error_kind),
msg: msg.to_string(),
cold_start_info: Some(ctx.cold_start_info()),
})
.unwrap_or_default()
}
_ => msg.to_string(),
};
// already error case, ignore client IO error
self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
.await
.inspect_err(|e| debug!("write_message failed: {e}"))
.ok();
formatted_msg
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
self.flush().await?;
Ok(self.stream)
}
/// Write the error message using [`Self::write_format_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
/// Write the error message to the client, then re-throw it.
///
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub async fn throw_error_str<T>(
pub(crate) async fn throw_error<E>(
&mut self,
msg: &'static str,
error_kind: ErrorKind,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> Result<T, ReportedError> {
self.write_format_message(msg, error_kind, ctx).await;
) -> ReportedError
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
@@ -214,39 +197,39 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
);
}
Err(ReportedError {
source: anyhow::anyhow!(msg),
error_kind,
})
}
/// Write the error message using [`Self::write_format_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub(crate) async fn throw_error<T, E>(
&mut self,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
self.write_format_message(&msg, error_kind, ctx).await;
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
kind=error_kind.to_metric_label(),
error=%error,
msg,
"forwarding error to user",
);
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
}
Err(ReportedError {
source: anyhow::anyhow!(error),
error_kind,
})
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
self.flush()
.await
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
ReportedError::new(error)
}
}