diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs index f9cbcf4991..0613dfa9d0 100644 --- a/libs/proxy/tokio-postgres2/src/tls.rs +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -14,7 +14,7 @@ pub(crate) mod private { /// Channel binding information returned from a TLS handshake. pub struct ChannelBinding { - pub(crate) tls_server_end_point: Option>, + pub tls_server_end_point: Option>, } impl ChannelBinding { diff --git a/proxy/src/compute/authenticate.rs b/proxy/src/compute/authenticate.rs new file mode 100644 index 0000000000..a6689fa1e5 --- /dev/null +++ b/proxy/src/compute/authenticate.rs @@ -0,0 +1,146 @@ +use bytes::BufMut; +use postgres_client::tls::{ChannelBinding, TlsStream}; +use postgres_protocol::authentication::sasl; +use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::{Auth, MaybeRustlsStream}; +use crate::compute::RustlsStream; +use crate::pqproto::{ + AUTH_OK, AUTH_SASL, AUTH_SASL_CONT, AUTH_SASL_FINAL, FE_PASSWORD_MESSAGE, StartupMessageParams, +}; +use crate::stream::{PostgresError, PqBeStream}; + +pub async fn authenticate( + stream: MaybeRustlsStream, + auth: Option<&Auth>, + params: &StartupMessageParams, +) -> Result>, PostgresError> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + RustlsStream: TlsStream + Unpin, +{ + let mut stream = PqBeStream::new(stream, params); + stream.flush().await?; + + let channel_binding = stream.get_ref().channel_binding(); + + // TODO: rather than checking for SASL, maybe we can just assume it. + // With SCRAM_SHA_256 if we're not using TLS, + // and SCRAM_SHA_256_PLUS if we are using TLS. + + let (channel_binding, mechanism) = match stream.read_auth_message().await? { + (AUTH_OK, _) => return Ok(stream), + (AUTH_SASL, mechanisms) => { + let mut has_scram = false; + let mut has_scram_plus = false; + for mechanism in mechanisms.split(|&b| b == b'\0') { + match mechanism { + b"SCRAM-SHA-256" => has_scram = true, + b"SCRAM-SHA-256-PLUS" => has_scram_plus = true, + _ => {} + } + } + + match (channel_binding, has_scram, has_scram_plus) { + (cb, true, false) => { + if cb.tls_server_end_point.is_some() { + // I don't think this can happen in our setup, but I would like to monitor it. + tracing::warn!( + "TLS is enabled, but compute doesn't support SCRAM-SHA-256-PLUS." + ); + } + (sasl::ChannelBinding::unrequested(), SCRAM_SHA_256) + } + ( + ChannelBinding { + tls_server_end_point: None, + }, + true, + _, + ) => (sasl::ChannelBinding::unsupported(), SCRAM_SHA_256), + ( + ChannelBinding { + tls_server_end_point: Some(h), + }, + _, + true, + ) => ( + sasl::ChannelBinding::tls_server_end_point(h), + SCRAM_SHA_256_PLUS, + ), + (_, false, _) => { + tracing::error!( + "compute responded with unsupported auth mechanisms: {}", + String::from_utf8_lossy(mechanisms) + ); + return Err(PostgresError::InvalidAuthMessage); + } + } + } + (tag, msg) => { + tracing::error!( + "compute responded with unexpected auth message with tag[{tag}]: {}", + String::from_utf8_lossy(msg) + ); + return Err(PostgresError::InvalidAuthMessage); + } + }; + + let mut scram = match auth { + // We only touch passwords when it comes to console-redirect. + Some(Auth::Password(pw)) => sasl::ScramSha256::new(pw, channel_binding), + Some(Auth::Scram(keys)) => sasl::ScramSha256::new_with_keys(**keys, channel_binding), + None => { + // local_proxy does not set credentials, since it relies on trust and expects an OK message above + tracing::error!("compute requested SASL auth, but there are no credentials available",); + return Err(PostgresError::InvalidAuthMessage); + } + }; + + stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| { + buf.put_slice(mechanism.as_bytes()); + buf.put_u8(b'\0'); + + let data = scram.message(); + buf.put_u32(data.len() as u32); + buf.put_slice(data); + }); + stream.flush().await?; + + loop { + // wait for SASLContinue or SASLFinal. + match stream.read_auth_message().await? { + (AUTH_SASL_CONT, data) => scram.update(data).await?, + (AUTH_SASL_FINAL, data) => { + scram.finish(data)?; + break; + } + (tag, msg) => { + tracing::error!( + "compute responded with unexpected auth message with tag[{tag}]: {}", + String::from_utf8_lossy(msg) + ); + return Err(PostgresError::InvalidAuthMessage); + } + } + + stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| { + buf.put_slice(scram.message()); + }); + stream.flush().await?; + } + + match stream.read_auth_message().await? { + (AUTH_OK, _) => {} + (tag, msg) => { + tracing::error!( + "compute responded with unexpected auth message with tag[{tag}]: {}", + String::from_utf8_lossy(msg) + ); + return Err(PostgresError::InvalidAuthMessage); + } + } + + Ok(stream) +} diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 24d294a762..8e20b491ea 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -1,3 +1,4 @@ +mod authenticate; mod tls; use std::fmt::Debug; @@ -9,8 +10,6 @@ use itertools::Itertools; use postgres_client::config::{AuthKeys, SslMode}; use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{NoTls, RawCancelToken, RawConnection}; -use postgres_protocol::message::backend::NoticeResponseBody; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; @@ -27,6 +26,7 @@ use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; +use crate::stream::{PostgresError, PqBeStream}; use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -36,7 +36,7 @@ pub(crate) enum ConnectionError { /// This error doesn't seem to reveal any secrets; for instance, /// `postgres_client::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] - Postgres(#[from] postgres_client::Error), + Postgres(#[from] PostgresError), #[error("{COULD_NOT_CONNECT}: {0}")] TlsError(#[from] TlsError), @@ -53,20 +53,21 @@ impl UserFacingError for ConnectionError { match self { // This helps us drop irrelevant library-specific prefixes. // TODO: propagate severity level and other parameters. - ConnectionError::Postgres(err) => match err.as_db_error() { - Some(err) => { - let msg = err.message(); + ConnectionError::Postgres(PostgresError::Error(err)) => { + let (_code, msg) = err.parse(); + let msg = String::from_utf8_lossy(msg); - if msg.starts_with("unsupported startup parameter: ") - || msg.starts_with("unsupported startup parameter in options: ") - { - format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter") - } else { - msg.to_owned() - } + if msg.starts_with("unsupported startup parameter: ") + || msg.starts_with("unsupported startup parameter in options: ") + { + format!( + "{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter" + ) + } else { + msg.into_owned() } - None => err.to_string(), - }, + } + ConnectionError::Postgres(err) => err.to_string(), ConnectionError::WakeComputeError(err) => err.to_string_client(), ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() @@ -79,10 +80,12 @@ impl UserFacingError for ConnectionError { impl ReportableError for ConnectionError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - ConnectionError::Postgres(e) if e.as_db_error().is_some() => { - crate::error::ErrorKind::Postgres - } - ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, + ConnectionError::Postgres(PostgresError::Io(_)) => crate::error::ErrorKind::Compute, + ConnectionError::Postgres( + PostgresError::Error(_) + | PostgresError::InvalidAuthMessage + | PostgresError::Unexpected(_), + ) => crate::error::ErrorKind::Postgres, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -161,18 +164,6 @@ impl ConnectInfo { } impl AuthInfo { - fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config { - match &self.auth { - Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)), - Some(Auth::Password(pw)) => config.password(pw), - None => &mut config, - }; - for (k, v) in self.server_params.iter() { - config.set_param(k, v); - } - config - } - /// Apply startup message params to the connection config. pub(crate) fn set_startup_params( &mut self, @@ -212,7 +203,7 @@ impl ConnectInfo { async fn connect_raw( &self, config: &ComputeConfig, - ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { + ) -> Result<(SocketAddr, MaybeRustlsStream), TlsError> { let timeout = config.timeout; // wrap TcpStream::connect with timeout @@ -264,25 +255,19 @@ impl ConnectInfo { } } -pub type RustlsStream = >::Stream; -pub type MaybeRustlsStream = MaybeTlsStream; +pub type RustlsStream = >::Stream; +pub type MaybeRustlsStream = MaybeTlsStream>; -pub(crate) struct PostgresConnection { +pub struct PostgresConnection { /// Socket connected to a compute node. - pub(crate) stream: MaybeTlsStream, - /// PostgreSQL connection parameters. - pub(crate) params: std::collections::HashMap, + pub stream: PqBeStream>, pub socket_addr: SocketAddr, - pub cancel_token: RawCancelToken, pub hostname: String, + pub ssl_mode: SslMode, + pub aux: MetricsAuxInfo, - /// Labels for proxy's metrics. - pub(crate) aux: MetricsAuxInfo, - /// Notices received from compute after authenticating - pub(crate) delayed_notice: Vec, - - pub(crate) guage: NumDbConnectionsGuard<'static>, + pub guage: NumDbConnectionsGuard<'static>, } impl ConnectInfo { @@ -290,30 +275,18 @@ impl ConnectInfo { pub(crate) async fn connect( &self, ctx: &RequestContext, - aux: MetricsAuxInfo, + aux: &MetricsAuxInfo, auth: &AuthInfo, config: &ComputeConfig, ) -> Result { - let mut tmp_config = auth.enrich(self.to_postgres_client_config()); - // we setup SSL early in `ConnectInfo::connect_raw`. - tmp_config.ssl_mode(SslMode::Disable); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (socket_addr, stream) = self.connect_raw(config).await?; - let connection = tmp_config.connect_raw(stream, NoTls).await?; + let stream = + authenticate::authenticate(stream, auth.auth.as_ref(), &auth.server_params).await?; drop(pause); - let RawConnection { - stream, - parameters, - delayed_notice, - process_id, - secret_key, - } = connection; - - tracing::Span::current().record("pid", tracing::field::display(process_id)); + // tracing::Span::current().record("pid", tracing::field::display(process_id)); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); - let MaybeTlsStream::Raw(stream) = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) info!( @@ -327,18 +300,10 @@ impl ConnectInfo { let connection = PostgresConnection { stream, - params: parameters, - delayed_notice, - socket_addr, - cancel_token: RawCancelToken { - ssl_mode: self.ssl_mode, - process_id, - secret_key, - }, hostname: self.host.to_string(), - - aux, + ssl_mode: self.ssl_mode, + aux: aux.clone(), guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), }; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 863156296d..a8d4e206a2 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -6,7 +6,7 @@ use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; use crate::auth::backend::ConsoleRedirectBackend; -use crate::cancellation::{CancelClosure, CancellationHandler}; +use crate::cancellation::CancellationHandler; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; @@ -177,7 +177,7 @@ pub(crate) async fn handle_client( let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let do_handshake = handshake(ctx, stream, tls, record_handshake_error); - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) + let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) .await?? { HandshakeData::Startup(stream, params) => (stream, params), @@ -210,15 +210,15 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); let (node_info, mut auth_info, user_info) = match backend - .authenticate(ctx, &config.authentication_config, &mut stream) + .authenticate(ctx, &config.authentication_config, &mut client) .await { Ok(auth_result) => auth_result, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; auth_info.set_startup_params(¶ms, true); - let node = connect_to_compute( + let mut node = connect_to_compute( ctx, &TcpMechanism { auth: auth_info, @@ -228,24 +228,17 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) + .or_else(|e| async { Err(client.throw_error(e, Some(ctx)).await) }) .await?; let session = cancellation_handler.get_key(); - prepare_client_connection(&node, *session.key(), &mut stream); - let stream = stream.flush_and_into_inner().await?; + let cancel_closure = + prepare_client_connection(&mut node, session.key(), &mut client, user_info).await?; let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - let cancel_closure = CancelClosure::new( - node.socket_addr, - node.cancel_token, - node.hostname, - user_info, - ); - session .maintain_cancel_key( session_id, @@ -256,9 +249,12 @@ pub(crate) async fn handle_client( .await; }); + let client = client.flush_and_into_inner().await?; + let compute = node.stream.flush_and_into_inner().await?; + Ok(Some(ProxyPassthrough { - client: stream, - compute: node.stream, + client, + compute, aux: node.aux, private_link_id: None, diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index d80caee9f3..12fb7ace99 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -79,9 +79,7 @@ impl NodeInfo { auth: &compute::AuthInfo, config: &ComputeConfig, ) -> Result { - self.conn_info - .connect(ctx, self.aux.clone(), auth, config) - .await + self.conn_info.connect(ctx, &self.aux, auth, config).await } } diff --git a/proxy/src/pglb/passthrough.rs b/proxy/src/pglb/passthrough.rs index d4c029f6d9..1ef41f7182 100644 --- a/proxy/src/pglb/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -2,6 +2,7 @@ use std::convert::Infallible; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tracing::debug; use utils::measured_stream::MeasuredStream; @@ -66,8 +67,7 @@ pub(crate) async fn proxy_pass( pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, - pub(crate) compute: MaybeRustlsStream, - + pub(crate) compute: MaybeRustlsStream, pub(crate) aux: MetricsAuxInfo, pub(crate) private_link_id: Option, diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 3af7560b7c..8e3041eefa 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -11,16 +11,16 @@ use rand::distributions::{Distribution, Standard}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; -#[derive(Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct ErrorCode(pub [u8; 5]); -#[derive(Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct FeTag(pub u8); -#[derive(Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct BeTag(pub u8); -#[derive(Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct AuthTag(pub i32); pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p'); @@ -32,7 +32,6 @@ pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z'); pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v'); pub const AUTH_OK: AuthTag = AuthTag(0); -pub const AUTH_CLEAR: AuthTag = AuthTag(3); pub const AUTH_SASL: AuthTag = AuthTag(10); pub const AUTH_SASL_CONT: AuthTag = AuthTag(11); pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12); @@ -356,6 +355,10 @@ impl WriteBuf { Self(Cursor::new(Vec::new())) } + pub const fn len(&self) -> usize { + self.0.get_ref().len() + } + /// Use a heuristic to determine if we should shrink the write buffer. #[inline] fn should_shrink(&self) -> bool { @@ -557,11 +560,11 @@ pub enum BeMessage<'a> { AuthenticationOk, AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, - BackendKeyData(CancelKeyData), ParameterStatus { name: &'a [u8], value: &'a [u8], }, + #[cfg(test)] ReadyForQuery, NoticeResponse(&'a str), NegotiateProtocolVersion { @@ -617,13 +620,6 @@ impl BeMessage<'_> { }); } - // - BeMessage::BackendKeyData(key_data) => { - buf.write_raw(8, BE_KEY_MESSAGE.0, |buf| { - buf.put_slice(key_data.as_bytes()) - }); - } - // // BeMessage::NoticeResponse(msg) => { @@ -655,6 +651,7 @@ impl BeMessage<'_> { }); } + #[cfg(test)] // BeMessage::ReadyForQuery => { buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I')); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 8aa398b20f..48c0d9b98d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,6 +10,7 @@ use std::sync::Arc; use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; +use postgres_client::RawCancelToken; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; @@ -18,6 +19,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; +use crate::auth::backend::ComputeUserInfo; use crate::cancellation::{self, CancelClosure, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; @@ -26,11 +28,11 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard}; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; use crate::pglb::passthrough::ProxyPassthrough; -use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; +use crate::pqproto::{CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::stream::{PqFeStream, Stream}; +use crate::stream::{PostgresError, PqFeStream, Stream}; use crate::types::EndpointCacheKey; use crate::util::run_until_cancelled; use crate::{auth, compute}; @@ -253,7 +255,7 @@ pub(crate) async fn handle_client( auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, cancellation_handler: Arc, - stream: S, + client: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, @@ -273,9 +275,9 @@ pub(crate) async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); + let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error); - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) + let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) .await?? { HandshakeData::Startup(stream, params) => (stream, params), @@ -307,7 +309,7 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let hostname = mode.hostname(stream.get_ref()); + let hostname = mode.hostname(client.get_ref()); let common_names = tls.map(|tls| &tls.common_names); @@ -319,14 +321,14 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( ctx, - &mut stream, + &mut client, mode.allow_cleartext(), &config.authentication_config, endpoint_rate_limiter, @@ -339,7 +341,7 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return Err(stream + return Err(client .throw_error(e, Some(ctx)) .instrument(params_span) .await)?; @@ -366,26 +368,19 @@ pub(crate) async fn handle_client( ) .await; - let node = match res { + let mut node = match res { Ok(node) => node, - Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; let session = cancellation_handler.get_key(); - prepare_client_connection(&node, *session.key(), &mut stream); - let stream = stream.flush_and_into_inner().await?; + let cancel_closure = + prepare_client_connection(&mut node, session.key(), &mut client, creds.info).await?; let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - let cancel_closure = CancelClosure::new( - node.socket_addr, - node.cancel_token, - node.hostname, - creds.info, - ); - session .maintain_cancel_key( session_id, @@ -396,6 +391,9 @@ pub(crate) async fn handle_client( .await; }); + let client = client.flush_and_into_inner().await?; + let compute = node.stream.flush_and_into_inner().await?; + let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), @@ -403,8 +401,8 @@ pub(crate) async fn handle_client( }; Ok(Some(ProxyPassthrough { - client: stream, - compute: node.stream, + client, + compute, aux: node.aux, private_link_id, @@ -418,28 +416,60 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -pub(crate) fn prepare_client_connection( - node: &compute::PostgresConnection, - cancel_key_data: CancelKeyData, +pub(crate) async fn prepare_client_connection( + node: &mut compute::PostgresConnection, + key_data: &CancelKeyData, stream: &mut PqFeStream, -) { - // Forward all deferred notices to the client. - for notice in &node.delayed_notice { - stream.write_raw(notice.as_bytes().len(), b'N', |buf| { - buf.extend_from_slice(notice.as_bytes()); - }); + user_info: ComputeUserInfo, +) -> Result { + use zerocopy::{FromBytes, IntoBytes}; + + use crate::pqproto::{BE_KEY_MESSAGE, BE_READY_MESSAGE}; + + let mut process_id = 0; + let mut secret_key = 0; + + loop { + match node.stream.read_raw_be(1024).await { + // parse backend keys, and substitute our own. + Ok((tag @ BE_KEY_MESSAGE, msg)) => { + stream.write_raw(8, tag, |b| b.extend_from_slice(key_data.as_bytes())); + + let key_data = CancelKeyData::read_from_bytes(msg) + .map_err(|_| std::io::Error::other("invalid msg len"))?; + + process_id = (key_data.0.get() >> 32) as i32; + secret_key = (key_data.0.get() & 0xffff_ffff) as i32; + } + // ready for query, we're done :) + Ok((tag @ BE_READY_MESSAGE, msg)) => { + stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes())); + break; + } + // either a notice or a parameter status. + Ok((tag, msg)) => { + stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes())); + } + Err(PostgresError::Io(io)) => return Err(io), + Err(PostgresError::Error(e)) => return Err(std::io::Error::other(e)), + Err(_) => unreachable!("read_raw_be only returns IO or BackendError types"), + } + + if stream.write_buf_len() > 512 { + stream.flush().await?; + } } - // Forward all postgres connection params to the client. - for (name, value) in &node.params { - stream.write_message(BeMessage::ParameterStatus { - name: name.as_bytes(), - value: value.as_bytes(), - }); - } - - stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); - stream.write_message(BeMessage::ReadyForQuery); + Ok(CancelClosure::new( + node.socket_addr, + RawCancelToken { + ssl_mode: node.ssl_mode, + process_id, + secret_key, + }, + node.hostname.clone(), + user_info, + )) } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0f19944afa..00726fb7a1 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -1,10 +1,12 @@ use std::error::Error; use std::io; +use bstr::ByteSlice; use tokio::time; use crate::compute; use crate::config::RetryConfig; +use crate::stream::{BackendError, PostgresError}; pub(crate) trait CouldRetry { /// Returns true if the error could be retried @@ -96,10 +98,55 @@ impl ShouldRetryWakeCompute for postgres_client::Error { } } +impl CouldRetry for BackendError { + fn could_retry(&self) -> bool { + let (code, _message) = self.parse(); + matches!( + code, + crate::pqproto::CONNECTION_FAILURE + | crate::pqproto::CONNECTION_EXCEPTION + | crate::pqproto::CONNECTION_DOES_NOT_EXIST + | crate::pqproto::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, + ) + } +} + +impl ShouldRetryWakeCompute for BackendError { + fn should_retry_wake_compute(&self) -> bool { + let (code, message) = self.parse(); + + // Here are errors that happens after the user successfully authenticated to the database. + let non_retriable_pg_errors = matches!( + code, + crate::pqproto::TOO_MANY_CONNECTIONS + | crate::pqproto::OUT_OF_MEMORY + | crate::pqproto::SYNTAX_ERROR + | crate::pqproto::T_R_SERIALIZATION_FAILURE + | crate::pqproto::INVALID_CATALOG_NAME + | crate::pqproto::INVALID_SCHEMA_NAME + | crate::pqproto::INVALID_PARAMETER_VALUE, + ); + if non_retriable_pg_errors { + return false; + } + + // PGBouncer errors that should not trigger a wake_compute retry. + if code == crate::pqproto::PROTOCOL_VIOLATION { + // Source for the error message: + // https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070 + return message.contains_str("no more connections allowed (max_client_conn)"); + } + true + } +} + impl CouldRetry for compute::ConnectionError { fn could_retry(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.could_retry(), + compute::ConnectionError::Postgres(PostgresError::Error(err)) => err.could_retry(), + compute::ConnectionError::Postgres(PostgresError::Io(err)) => err.could_retry(), + compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false, + compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false, compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), compute::ConnectionError::TooManyConnectionAttempts(_) => false, @@ -109,7 +156,12 @@ impl CouldRetry for compute::ConnectionError { impl ShouldRetryWakeCompute for compute::ConnectionError { fn should_retry_wake_compute(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(), + compute::ConnectionError::Postgres(PostgresError::Error(err)) => { + err.should_retry_wake_compute() + } + compute::ConnectionError::Postgres(PostgresError::Io(_)) => true, + compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false, + compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false, // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, _ => true, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index daf2394a77..692b3b7946 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -25,6 +25,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; +use crate::pqproto::BeMessage; use crate::proxy::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; diff --git a/proxy/src/stream/mod.rs b/proxy/src/stream/mod.rs index 4a0f13be8c..7baabe7091 100644 --- a/proxy/src/stream/mod.rs +++ b/proxy/src/stream/mod.rs @@ -1,9 +1,11 @@ +mod pq_backend; mod pq_frontend; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; +pub use pq_backend::{BackendError, PostgresError, PqBeStream}; pub use pq_frontend::PqFeStream; use rustls::ServerConfig; use thiserror::Error; diff --git a/proxy/src/stream/pq_backend.rs b/proxy/src/stream/pq_backend.rs new file mode 100644 index 0000000000..00b182bf9a --- /dev/null +++ b/proxy/src/stream/pq_backend.rs @@ -0,0 +1,165 @@ +//! Postgres connection from backend, proxy is the frontend. + +use std::io; + +use bytes::Bytes; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + +use crate::pqproto::{ + AuthTag, BE_AUTH_MESSAGE, BE_ERR_MESSAGE, BeTag, ErrorCode, SQLSTATE_INTERNAL_ERROR, + StartupMessageParams, WriteBuf, read_message, +}; + +/// Stream wrapper which implements libpq's protocol. +pub struct PqBeStream { + stream: S, + read: Vec, + write: WriteBuf, +} + +impl PqBeStream { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper and write the first startup message. + pub fn new(stream: S, params: &StartupMessageParams) -> Self { + let mut write = WriteBuf::new(); + write.startup(params); + Self { + stream, + read: Vec::new(), + write, + } + } +} + +impl PqBeStream { + /// Read a raw postgres packet from the backend, which will respect the max length requested, + /// as well as handling postgres error messages. + /// + /// This is not cancel safe. + pub async fn read_raw_be(&mut self, max: u32) -> Result<(BeTag, &mut [u8]), PostgresError> { + let (tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + match BeTag(tag) { + BE_ERR_MESSAGE => Err(PostgresError::Error(BackendError { + data: msg.to_vec().into(), + })), + tag => Ok((tag, msg)), + } + } + + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_be_expect( + &mut self, + tag: BeTag, + max: u32, + ) -> Result<&mut [u8], PostgresError> { + let (actual_tag, msg) = self.read_raw_be(max).await?; + if actual_tag != tag { + return Err(PostgresError::Unexpected(UnexpectedMessage { + expected: tag, + tag: actual_tag, + data: msg.to_vec().into(), + })); + } + Ok(msg) + } + + /// Read a postgres backend auth message. + /// This is not cancel safe. + pub async fn read_auth_message(&mut self) -> Result<(AuthTag, &mut [u8]), PostgresError> { + const MAX_AUTH_LENGTH: u32 = 512; + + self.read_raw_be_expect(BE_AUTH_MESSAGE, MAX_AUTH_LENGTH) + .await? + .split_first_chunk_mut() + .map(|(tag, msg)| (AuthTag(i32::from_be_bytes(*tag)), msg)) + .ok_or(PostgresError::InvalidAuthMessage) + } +} + +impl PqBeStream { + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) + } + + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) + } +} + +#[derive(Debug, Error)] +pub enum PostgresError { + #[error("postgres responded with error {0}")] + Error(#[from] BackendError), + #[error("postgres responded with an unexpected message: {0}")] + Unexpected(#[from] UnexpectedMessage), + #[error("postgres responded with an invalid authentication message")] + InvalidAuthMessage, + #[error("IO error from compute: {0}")] + Io(#[from] io::Error), +} + +#[derive(Debug, Error)] +#[error("expected {expected}, got {tag} with data {data:?}")] +pub struct UnexpectedMessage { + expected: BeTag, + tag: BeTag, + data: Bytes, +} + +pub struct BackendError { + data: Bytes, +} + +impl BackendError { + pub fn parse(&self) -> (ErrorCode, &[u8]) { + let mut code = &[] as &[u8]; + let mut message = &[] as &[u8]; + + for param in self.data.split(|b| *b == 0) { + match param { + [b'M', rest @ ..] => message = rest, + [b'C', rest @ ..] => code = rest, + _ => {} + } + } + + let code = code.try_into().map_or(SQLSTATE_INTERNAL_ERROR, ErrorCode); + + (code, message) + } +} + +impl std::fmt::Debug for BackendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl std::fmt::Display for BackendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", &self.data) + } +} +impl std::error::Error for BackendError {} diff --git a/proxy/src/stream/pq_frontend.rs b/proxy/src/stream/pq_frontend.rs index 0dc5e05688..ef2619041b 100644 --- a/proxy/src/stream/pq_frontend.rs +++ b/proxy/src/stream/pq_frontend.rs @@ -6,8 +6,8 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use crate::error::{ErrorKind, UserFacingError}; use crate::pqproto::{ - BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR, WriteBuf, - read_message, read_startup, + BeMessage, BeTag, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR, + WriteBuf, read_message, read_startup, }; use crate::stream::ReportedError; @@ -32,6 +32,10 @@ impl PqFeStream { write: WriteBuf::new(), } } + + pub fn write_buf_len(&self) -> usize { + self.write.len() + } } impl PqFeStream { @@ -103,8 +107,8 @@ impl PqFeStream { } /// Write a raw message to the internal buffer. - pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { - self.write.write_raw(size_hint, tag, f); + pub fn write_raw(&mut self, size_hint: usize, tag: BeTag, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag.0, f); } /// Write the message into an internal buffer @@ -150,6 +154,7 @@ impl PqFeStream { if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( kind = error_kind.to_metric_label(), + %error, msg, "forwarding error to user" );