From 87179e26b3c18d9cd09b9eecbfed1db742b391ab Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 1 Jun 2025 19:41:45 +0100 Subject: [PATCH] completely rewrite pq_proto (#12085) libs/pqproto is designed for safekeeper/pageserver with maximum throughput. proxy only needs it for handshakes/authentication where throughput is not a concern but memory efficiency is. For this reason, we switch to using read_exact and only allocating as much memory as we need to. All reads return a `&'a [u8]` instead of a `Bytes` because accidental sharing of bytes can cause fragmentation. Returning the reference enforces all callers only hold onto the bytes they absolutely need. For example, before this change, `pqproto` was allocating 8KiB for the initial read `BytesMut`, and proxy was holding the `Bytes` in the `StartupMessageParams` for the entire connection through to passthrough. --- proxy/src/auth/backend/classic.rs | 26 +- proxy/src/auth/backend/console_redirect.rs | 16 +- proxy/src/auth/backend/hacks.rs | 30 +- proxy/src/auth/backend/mod.rs | 9 +- proxy/src/auth/credentials.rs | 2 +- proxy/src/auth/flow.rs | 118 ++-- proxy/src/binary/pg_sni_router.rs | 87 +-- proxy/src/binary/proxy.rs | 5 +- proxy/src/cancellation.rs | 2 +- proxy/src/compute.rs | 2 +- proxy/src/console_redirect_proxy.rs | 20 +- proxy/src/context/mod.rs | 2 +- proxy/src/context/parquet.rs | 2 +- proxy/src/lib.rs | 1 + proxy/src/pqproto.rs | 693 +++++++++++++++++++++ proxy/src/proxy/connect_compute.rs | 2 +- proxy/src/proxy/handshake.rs | 90 ++- proxy/src/proxy/mod.rs | 68 +- proxy/src/proxy/retry.rs | 3 +- proxy/src/proxy/tests/mitm.rs | 7 +- proxy/src/proxy/tests/mod.rs | 16 +- proxy/src/redis/cancellation_publisher.rs | 3 +- proxy/src/redis/keys.rs | 21 +- proxy/src/redis/notifications.rs | 10 - proxy/src/sasl/messages.rs | 22 - proxy/src/sasl/mod.rs | 6 +- proxy/src/sasl/stream.rs | 136 ++-- proxy/src/serverless/sql_over_http.rs | 4 +- proxy/src/stream.rs | 319 +++++----- 29 files changed, 1122 insertions(+), 600 deletions(-) create mode 100644 proxy/src/pqproto.rs diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..dcc500f2c8 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,27 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout( - config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { + let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { + AuthFlow::new(client, scram) + .authenticate() + .await + .inspect_err(|error| { warn!(?error, "error processing scram messages"); - error }) - } - ) + }) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages 
error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)??; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..a50c30257f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,6 +15,7 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? 
+ }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? + let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..8c892d90a0 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -31,6 +31,7 @@ use crate::control_plane::{ }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; +use crate::pqproto::BeMessage; use crate::protocol2::ConnectionInfoExtra; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; @@ -402,7 +403,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -702,7 +703,7 @@ mod tests { #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +785,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +839,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. 
-use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. - pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. 
- pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. 
+ match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. +use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? 
}; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? 
} } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..9a3903ba9a 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, cancellation_token.clone(), )); @@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..0bff901376 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; +use crate::pqproto::CancelKeyData; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..9499aba61b 100644 --- 
a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -221,12 +221,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +236,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +244,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..de4600951e 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new 
file mode 100644 index 0000000000..d68d9f9474 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). +/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. + #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. + // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. 
+ direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: usize, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. + let Some(len) = (header.len.get() as usize).checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. + if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. 
+ buf.resize(len, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is 2x the size of our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. 
+ pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. 
+ buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. + payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. 
+ let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. 
+ assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..57785c9ec5 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; @@ -15,6 +14,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..13ee8c7dd2 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,8 +1,3 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +7,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -71,33 +69,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? 
+ }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +95,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -123,7 +112,6 @@ pub(crate) async fn handshake( res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +145,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +165,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +175,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +197,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. 
- - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +219,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..26ac6a89e7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,15 +10,14 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; @@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; @@ -38,6 +38,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -329,7 +341,7 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); @@ -349,10 +361,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +377,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +389,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) 
=> node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +422,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. 
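
The buffered-write pattern above is the key behavioural change in `handle_client`: `prepare_client_connection` no longer awaits anything, it only appends messages to the `PqStream` write buffer, and the caller pays for a single `flush_and_into_inner` before passthrough. A minimal caller-side sketch of that pattern, assuming the crate-internal `PqStream` API from this patch over any `AsyncRead + AsyncWrite` stream; the helper name and the parameter-status values are made up for illustration:

use std::io;

use tokio::io::{AsyncRead, AsyncWrite};

use crate::pqproto::BeMessage;
use crate::stream::PqStream;

// Illustrative sketch only: queue the final backend messages in memory,
// flush them in one go, and unwrap the raw stream for passthrough.
async fn finish_handshake<S>(mut stream: PqStream<S>) -> io::Result<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // These calls only append to the in-memory write buffer; no I/O happens here.
    stream.write_message(BeMessage::ParameterStatus {
        name: b"server_version",
        value: b"16.3",
    });
    stream.write_message(BeMessage::ReadyForQuery);

    // One flush for everything queued above, then hand back the underlying
    // stream so the proxy can switch to byte-for-byte copying.
    stream.flush_and_into_inner().await
}

Keeping the queueing side synchronous is also why `prepare_client_connection` could drop its `io::Result` return type: it can no longer fail.
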
diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..3cc053e0ad 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -128,7 +128,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +157,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -185,10 +183,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? 
- .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 5f9f2509e2..769d519d94 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- 
a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. -/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. - Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..cb15132673 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. 
- async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. - fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,62 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. - pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let sasl = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. + let msg = stream.read_password_message().await?; + super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? 
+ }; + + let mut mechanism = mechanism(sasl.method)?; + let mut input = sasl.message; + loop { + let step = mechanism + .exchange(input) + .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; + + match step { + Step::Continue(moved_mechanism, reply) => { + mechanism = moved_mechanism; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + } + Step::Success(result, reply) => { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Step::Failure(reason) => break Ok(Outcome::Failure(reason)), } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..7126430a85 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which 
implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. - pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. + self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). 
+ const MAX_PASSWORD_LENGTH: usize = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. + self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. 
- /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. 
+ self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } }
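
One convention worth calling out in the new `stream.rs`: `throw_error` now returns a `ReportedError` value instead of a `Result`, so call sites choose where to wrap it, which is why the patch rewrites them as `Err(stream.throw_error(e, Some(ctx)).await)?`. A small usage sketch, assuming a crate-internal caller; the helper name is made up, and `TlsRequired` is the user-facing error type added in `proxy/mod.rs` above:

use tokio::io::{AsyncRead, AsyncWrite};

use crate::context::RequestContext;
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, ReportedError};

// Illustrative sketch only: report a user-facing error to the client and bail out.
async fn require_tls<S>(
    stream: &mut PqStream<S>,
    ctx: &RequestContext,
) -> Result<(), ReportedError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // throw_error serializes the message as an ErrorResponse, flushes on a
    // best-effort basis, and returns a ReportedError carrying the original
    // error and its kind; turning it into Err is left to the caller.
    Err(stream.throw_error(TlsRequired, Some(ctx)).await)
}

This keeps the `UserFacingError` allowlist as the single place that decides which error text is allowed to reach the client.
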