mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-25 17:10:38 +00:00
re-impl auth startup to compute
This commit is contained in:
146
proxy/src/compute/authenticate.rs
Normal file
146
proxy/src/compute/authenticate.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use bytes::BufMut;
|
||||
use postgres_client::tls::{ChannelBinding, TlsStream};
|
||||
use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::{Auth, MaybeRustlsStream};
|
||||
use crate::compute::RustlsStream;
|
||||
use crate::pqproto::{
|
||||
AUTH_OK, AUTH_SASL, AUTH_SASL_CONT, AUTH_SASL_FINAL, FE_PASSWORD_MESSAGE, StartupMessageParams,
|
||||
};
|
||||
use crate::stream::{PostgresError, PqBeStream};
|
||||
|
||||
pub async fn authenticate<S>(
|
||||
stream: MaybeRustlsStream<S>,
|
||||
auth: Option<&Auth>,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<PqBeStream<MaybeRustlsStream<S>>, PostgresError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
RustlsStream<S>: TlsStream + Unpin,
|
||||
{
|
||||
let mut stream = PqBeStream::new(stream, params);
|
||||
stream.flush().await?;
|
||||
|
||||
let channel_binding = stream.get_ref().channel_binding();
|
||||
|
||||
// TODO: rather than checking for SASL, maybe we can just assume it.
|
||||
// With SCRAM_SHA_256 if we're not using TLS,
|
||||
// and SCRAM_SHA_256_PLUS if we are using TLS.
|
||||
|
||||
let (channel_binding, mechanism) = match stream.read_auth_message().await? {
|
||||
(AUTH_OK, _) => return Ok(stream),
|
||||
(AUTH_SASL, mechanisms) => {
|
||||
let mut has_scram = false;
|
||||
let mut has_scram_plus = false;
|
||||
for mechanism in mechanisms.split(|&b| b == b'\0') {
|
||||
match mechanism {
|
||||
b"SCRAM-SHA-256" => has_scram = true,
|
||||
b"SCRAM-SHA-256-PLUS" => has_scram_plus = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
match (channel_binding, has_scram, has_scram_plus) {
|
||||
(cb, true, false) => {
|
||||
if cb.tls_server_end_point.is_some() {
|
||||
// I don't think this can happen in our setup, but I would like to monitor it.
|
||||
tracing::warn!(
|
||||
"TLS is enabled, but compute doesn't support SCRAM-SHA-256-PLUS."
|
||||
);
|
||||
}
|
||||
(sasl::ChannelBinding::unrequested(), SCRAM_SHA_256)
|
||||
}
|
||||
(
|
||||
ChannelBinding {
|
||||
tls_server_end_point: None,
|
||||
},
|
||||
true,
|
||||
_,
|
||||
) => (sasl::ChannelBinding::unsupported(), SCRAM_SHA_256),
|
||||
(
|
||||
ChannelBinding {
|
||||
tls_server_end_point: Some(h),
|
||||
},
|
||||
_,
|
||||
true,
|
||||
) => (
|
||||
sasl::ChannelBinding::tls_server_end_point(h),
|
||||
SCRAM_SHA_256_PLUS,
|
||||
),
|
||||
(_, false, _) => {
|
||||
tracing::error!(
|
||||
"compute responded with unsupported auth mechanisms: {}",
|
||||
String::from_utf8_lossy(mechanisms)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
};
|
||||
|
||||
let mut scram = match auth {
|
||||
// We only touch passwords when it comes to console-redirect.
|
||||
Some(Auth::Password(pw)) => sasl::ScramSha256::new(pw, channel_binding),
|
||||
Some(Auth::Scram(keys)) => sasl::ScramSha256::new_with_keys(**keys, channel_binding),
|
||||
None => {
|
||||
// local_proxy does not set credentials, since it relies on trust and expects an OK message above
|
||||
tracing::error!("compute requested SASL auth, but there are no credentials available",);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
};
|
||||
|
||||
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
|
||||
buf.put_slice(mechanism.as_bytes());
|
||||
buf.put_u8(b'\0');
|
||||
|
||||
let data = scram.message();
|
||||
buf.put_u32(data.len() as u32);
|
||||
buf.put_slice(data);
|
||||
});
|
||||
stream.flush().await?;
|
||||
|
||||
loop {
|
||||
// wait for SASLContinue or SASLFinal.
|
||||
match stream.read_auth_message().await? {
|
||||
(AUTH_SASL_CONT, data) => scram.update(data).await?,
|
||||
(AUTH_SASL_FINAL, data) => {
|
||||
scram.finish(data)?;
|
||||
break;
|
||||
}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
|
||||
stream.write_raw(0, FE_PASSWORD_MESSAGE.0, |buf| {
|
||||
buf.put_slice(scram.message());
|
||||
});
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
match stream.read_auth_message().await? {
|
||||
(AUTH_OK, _) => {}
|
||||
(tag, msg) => {
|
||||
tracing::error!(
|
||||
"compute responded with unexpected auth message with tag[{tag}]: {}",
|
||||
String::from_utf8_lossy(msg)
|
||||
);
|
||||
return Err(PostgresError::InvalidAuthMessage);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
mod authenticate;
|
||||
mod tls;
|
||||
|
||||
use std::fmt::Debug;
|
||||
@@ -9,8 +10,6 @@ use itertools::Itertools;
|
||||
use postgres_client::config::{AuthKeys, SslMode};
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{NoTls, RawCancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -27,6 +26,7 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::stream::{PostgresError, PqBeStream};
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
@@ -36,7 +36,7 @@ pub(crate) enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
Postgres(#[from] PostgresError),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] TlsError),
|
||||
@@ -53,20 +53,21 @@ impl UserFacingError for ConnectionError {
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
ConnectionError::Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
ConnectionError::Postgres(PostgresError::Error(err)) => {
|
||||
let (_code, msg) = err.parse();
|
||||
let msg = String::from_utf8_lossy(msg);
|
||||
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
|
||||
} else {
|
||||
msg.to_owned()
|
||||
}
|
||||
if msg.starts_with("unsupported startup parameter: ")
|
||||
|| msg.starts_with("unsupported startup parameter in options: ")
|
||||
{
|
||||
format!(
|
||||
"{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter"
|
||||
)
|
||||
} else {
|
||||
msg.into_owned()
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
}
|
||||
ConnectionError::Postgres(err) => err.to_string(),
|
||||
ConnectionError::WakeComputeError(err) => err.to_string_client(),
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
@@ -79,10 +80,12 @@ impl UserFacingError for ConnectionError {
|
||||
impl ReportableError for ConnectionError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
|
||||
crate::error::ErrorKind::Postgres
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::Postgres(PostgresError::Io(_)) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::Postgres(
|
||||
PostgresError::Error(_)
|
||||
| PostgresError::InvalidAuthMessage
|
||||
| PostgresError::Unexpected(_),
|
||||
) => crate::error::ErrorKind::Postgres,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -161,18 +164,6 @@ impl ConnectInfo {
|
||||
}
|
||||
|
||||
impl AuthInfo {
|
||||
fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config {
|
||||
match &self.auth {
|
||||
Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)),
|
||||
Some(Auth::Password(pw)) => config.password(pw),
|
||||
None => &mut config,
|
||||
};
|
||||
for (k, v) in self.server_params.iter() {
|
||||
config.set_param(k, v);
|
||||
}
|
||||
config
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(
|
||||
&mut self,
|
||||
@@ -212,7 +203,7 @@ impl ConnectInfo {
|
||||
async fn connect_raw(
|
||||
&self,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
|
||||
) -> Result<(SocketAddr, MaybeRustlsStream<TcpStream>), TlsError> {
|
||||
let timeout = config.timeout;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
@@ -264,25 +255,19 @@ impl ConnectInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
|
||||
pub type RustlsStream<S> = <ComputeConfig as MakeTlsConnect<S>>::Stream;
|
||||
pub type MaybeRustlsStream<S> = MaybeTlsStream<S, RustlsStream<S>>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
pub struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
pub stream: PqBeStream<MaybeRustlsStream<TcpStream>>,
|
||||
|
||||
pub socket_addr: SocketAddr,
|
||||
pub cancel_token: RawCancelToken,
|
||||
pub hostname: String,
|
||||
pub ssl_mode: SslMode,
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
/// Notices received from compute after authenticating
|
||||
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
|
||||
|
||||
pub(crate) guage: NumDbConnectionsGuard<'static>,
|
||||
pub guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
@@ -290,30 +275,18 @@ impl ConnectInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
aux: MetricsAuxInfo,
|
||||
aux: &MetricsAuxInfo,
|
||||
auth: &AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let mut tmp_config = auth.enrich(self.to_postgres_client_config());
|
||||
// we setup SSL early in `ConnectInfo::connect_raw`.
|
||||
tmp_config.ssl_mode(SslMode::Disable);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let connection = tmp_config.connect_raw(stream, NoTls).await?;
|
||||
let stream =
|
||||
authenticate::authenticate(stream, auth.auth.as_ref(), &auth.server_params).await?;
|
||||
drop(pause);
|
||||
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connection;
|
||||
|
||||
tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
// tracing::Span::current().record("pid", tracing::field::display(process_id));
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
let MaybeTlsStream::Raw(stream) = stream.into_inner();
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
@@ -327,18 +300,10 @@ impl ConnectInfo {
|
||||
|
||||
let connection = PostgresConnection {
|
||||
stream,
|
||||
params: parameters,
|
||||
delayed_notice,
|
||||
|
||||
socket_addr,
|
||||
cancel_token: RawCancelToken {
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
hostname: self.host.to_string(),
|
||||
|
||||
aux,
|
||||
ssl_mode: self.ssl_mode,
|
||||
aux: aux.clone(),
|
||||
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
use crate::cancellation::{CancelClosure, CancellationHandler};
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
@@ -177,7 +177,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
@@ -210,15 +210,15 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, mut auth_info, user_info) = match backend
|
||||
.authenticate(ctx, &config.authentication_config, &mut stream)
|
||||
.authenticate(ctx, &config.authentication_config, &mut client)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
let node = connect_to_compute(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
auth: auth_info,
|
||||
@@ -228,24 +228,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.or_else(|e| async { Err(client.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let cancel_closure =
|
||||
prepare_client_connection(&mut node, session.key(), &mut client, user_info).await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let cancel_closure = CancelClosure::new(
|
||||
node.socket_addr,
|
||||
node.cancel_token,
|
||||
node.hostname,
|
||||
user_info,
|
||||
);
|
||||
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
@@ -256,9 +249,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await;
|
||||
});
|
||||
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
let compute = node.stream.flush_and_into_inner().await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
compute: node.stream,
|
||||
client,
|
||||
compute,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
|
||||
@@ -79,9 +79,7 @@ impl NodeInfo {
|
||||
auth: &compute::AuthInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.conn_info
|
||||
.connect(ctx, self.aux.clone(), auth, config)
|
||||
.await
|
||||
self.conn_info.connect(ctx, &self.aux, auth, config).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::convert::Infallible;
|
||||
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
@@ -66,8 +67,7 @@ pub(crate) async fn proxy_pass(
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: MaybeRustlsStream,
|
||||
|
||||
pub(crate) compute: MaybeRustlsStream<TcpStream>,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
|
||||
|
||||
@@ -11,16 +11,16 @@ use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct ErrorCode(pub [u8; 5]);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct FeTag(pub u8);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct BeTag(pub u8);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct AuthTag(pub i32);
|
||||
|
||||
pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p');
|
||||
@@ -32,7 +32,6 @@ pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z');
|
||||
pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v');
|
||||
|
||||
pub const AUTH_OK: AuthTag = AuthTag(0);
|
||||
pub const AUTH_CLEAR: AuthTag = AuthTag(3);
|
||||
pub const AUTH_SASL: AuthTag = AuthTag(10);
|
||||
pub const AUTH_SASL_CONT: AuthTag = AuthTag(11);
|
||||
pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12);
|
||||
@@ -356,6 +355,10 @@ impl WriteBuf {
|
||||
Self(Cursor::new(Vec::new()))
|
||||
}
|
||||
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.get_ref().len()
|
||||
}
|
||||
|
||||
/// Use a heuristic to determine if we should shrink the write buffer.
|
||||
#[inline]
|
||||
fn should_shrink(&self) -> bool {
|
||||
@@ -557,11 +560,11 @@ pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||
AuthenticationCleartextPassword,
|
||||
BackendKeyData(CancelKeyData),
|
||||
ParameterStatus {
|
||||
name: &'a [u8],
|
||||
value: &'a [u8],
|
||||
},
|
||||
#[cfg(test)]
|
||||
ReadyForQuery,
|
||||
NoticeResponse(&'a str),
|
||||
NegotiateProtocolVersion {
|
||||
@@ -617,13 +620,6 @@ impl BeMessage<'_> {
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.write_raw(8, BE_KEY_MESSAGE.0, |buf| {
|
||||
buf.put_slice(key_data.as_bytes())
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
BeMessage::NoticeResponse(msg) => {
|
||||
@@ -655,6 +651,7 @@ impl BeMessage<'_> {
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I'));
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::sync::Arc;
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::RawCancelToken;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
@@ -18,6 +19,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::cancellation::{self, CancelClosure, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
@@ -26,11 +28,11 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::pqproto::{CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqFeStream, Stream};
|
||||
use crate::stream::{PostgresError, PqFeStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
@@ -253,7 +255,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
client: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
@@ -273,9 +275,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
@@ -307,7 +309,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
let hostname = mode.hostname(client.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
@@ -319,14 +321,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
&mut client,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
@@ -339,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return Err(stream
|
||||
return Err(client
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await)?;
|
||||
@@ -366,26 +368,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
)
|
||||
.await;
|
||||
|
||||
let node = match res {
|
||||
let mut node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
let cancel_closure =
|
||||
prepare_client_connection(&mut node, session.key(), &mut client, creds.info).await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
let cancel_closure = CancelClosure::new(
|
||||
node.socket_addr,
|
||||
node.cancel_token,
|
||||
node.hostname,
|
||||
creds.info,
|
||||
);
|
||||
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
session_id,
|
||||
@@ -396,6 +391,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await;
|
||||
});
|
||||
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
let compute = node.stream.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
@@ -403,8 +401,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
compute: node.stream,
|
||||
client,
|
||||
compute,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
@@ -418,28 +416,60 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
cancel_key_data: CancelKeyData,
|
||||
pub(crate) async fn prepare_client_connection(
|
||||
node: &mut compute::PostgresConnection,
|
||||
key_data: &CancelKeyData,
|
||||
stream: &mut PqFeStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
user_info: ComputeUserInfo,
|
||||
) -> Result<CancelClosure, std::io::Error> {
|
||||
use zerocopy::{FromBytes, IntoBytes};
|
||||
|
||||
use crate::pqproto::{BE_KEY_MESSAGE, BE_READY_MESSAGE};
|
||||
|
||||
let mut process_id = 0;
|
||||
let mut secret_key = 0;
|
||||
|
||||
loop {
|
||||
match node.stream.read_raw_be(1024).await {
|
||||
// parse backend keys, and substitute our own.
|
||||
Ok((tag @ BE_KEY_MESSAGE, msg)) => {
|
||||
stream.write_raw(8, tag, |b| b.extend_from_slice(key_data.as_bytes()));
|
||||
|
||||
let key_data = CancelKeyData::read_from_bytes(msg)
|
||||
.map_err(|_| std::io::Error::other("invalid msg len"))?;
|
||||
|
||||
process_id = (key_data.0.get() >> 32) as i32;
|
||||
secret_key = (key_data.0.get() & 0xffff_ffff) as i32;
|
||||
}
|
||||
// ready for query, we're done :)
|
||||
Ok((tag @ BE_READY_MESSAGE, msg)) => {
|
||||
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
|
||||
break;
|
||||
}
|
||||
// either a notice or a parameter status.
|
||||
Ok((tag, msg)) => {
|
||||
stream.write_raw(msg.len(), tag, |b| b.extend_from_slice(msg.as_bytes()));
|
||||
}
|
||||
Err(PostgresError::Io(io)) => return Err(io),
|
||||
Err(PostgresError::Error(e)) => return Err(std::io::Error::other(e)),
|
||||
Err(_) => unreachable!("read_raw_be only returns IO or BackendError types"),
|
||||
}
|
||||
|
||||
if stream.write_buf_len() > 512 {
|
||||
stream.flush().await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &node.params {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
});
|
||||
}
|
||||
|
||||
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
Ok(CancelClosure::new(
|
||||
node.socket_addr,
|
||||
RawCancelToken {
|
||||
ssl_mode: node.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
},
|
||||
node.hostname.clone(),
|
||||
user_info,
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use std::error::Error;
|
||||
use std::io;
|
||||
|
||||
use bstr::ByteSlice;
|
||||
use tokio::time;
|
||||
|
||||
use crate::compute;
|
||||
use crate::config::RetryConfig;
|
||||
use crate::stream::{BackendError, PostgresError};
|
||||
|
||||
pub(crate) trait CouldRetry {
|
||||
/// Returns true if the error could be retried
|
||||
@@ -96,10 +98,55 @@ impl ShouldRetryWakeCompute for postgres_client::Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for BackendError {
|
||||
fn could_retry(&self) -> bool {
|
||||
let (code, _message) = self.parse();
|
||||
matches!(
|
||||
code,
|
||||
crate::pqproto::CONNECTION_FAILURE
|
||||
| crate::pqproto::CONNECTION_EXCEPTION
|
||||
| crate::pqproto::CONNECTION_DOES_NOT_EXIST
|
||||
| crate::pqproto::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ShouldRetryWakeCompute for BackendError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
let (code, message) = self.parse();
|
||||
|
||||
// Here are errors that happens after the user successfully authenticated to the database.
|
||||
let non_retriable_pg_errors = matches!(
|
||||
code,
|
||||
crate::pqproto::TOO_MANY_CONNECTIONS
|
||||
| crate::pqproto::OUT_OF_MEMORY
|
||||
| crate::pqproto::SYNTAX_ERROR
|
||||
| crate::pqproto::T_R_SERIALIZATION_FAILURE
|
||||
| crate::pqproto::INVALID_CATALOG_NAME
|
||||
| crate::pqproto::INVALID_SCHEMA_NAME
|
||||
| crate::pqproto::INVALID_PARAMETER_VALUE,
|
||||
);
|
||||
if non_retriable_pg_errors {
|
||||
return false;
|
||||
}
|
||||
|
||||
// PGBouncer errors that should not trigger a wake_compute retry.
|
||||
if code == crate::pqproto::PROTOCOL_VIOLATION {
|
||||
// Source for the error message:
|
||||
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
|
||||
return message.contains_str("no more connections allowed (max_client_conn)");
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for compute::ConnectionError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Error(err)) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Io(err)) => err.could_retry(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
|
||||
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
|
||||
compute::ConnectionError::TlsError(err) => err.could_retry(),
|
||||
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
@@ -109,7 +156,12 @@ impl CouldRetry for compute::ConnectionError {
|
||||
impl ShouldRetryWakeCompute for compute::ConnectionError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(),
|
||||
compute::ConnectionError::Postgres(PostgresError::Error(err)) => {
|
||||
err.should_retry_wake_compute()
|
||||
}
|
||||
compute::ConnectionError::Postgres(PostgresError::Io(_)) => true,
|
||||
compute::ConnectionError::Postgres(PostgresError::Unexpected(_)) => false,
|
||||
compute::ConnectionError::Postgres(PostgresError::InvalidAuthMessage) => false,
|
||||
// the cache entry was not checked for validity
|
||||
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
mod pq_backend;
|
||||
mod pq_frontend;
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
pub use pq_backend::{BackendError, PostgresError, PqBeStream};
|
||||
pub use pq_frontend::PqFeStream;
|
||||
use rustls::ServerConfig;
|
||||
use thiserror::Error;
|
||||
|
||||
165
proxy/src/stream/pq_backend.rs
Normal file
165
proxy/src/stream/pq_backend.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
//! Postgres connection from backend, proxy is the frontend.
|
||||
|
||||
use std::io;
|
||||
|
||||
use bytes::Bytes;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::pqproto::{
|
||||
AuthTag, BE_AUTH_MESSAGE, BE_ERR_MESSAGE, BeTag, ErrorCode, SQLSTATE_INTERNAL_ERROR,
|
||||
StartupMessageParams, WriteBuf, read_message,
|
||||
};
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
pub struct PqBeStream<S> {
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqBeStream<S> {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper and write the first startup message.
|
||||
pub fn new(stream: S, params: &StartupMessageParams) -> Self {
|
||||
let mut write = WriteBuf::new();
|
||||
write.startup(params);
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqBeStream<S> {
|
||||
/// Read a raw postgres packet from the backend, which will respect the max length requested,
|
||||
/// as well as handling postgres error messages.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_raw_be(&mut self, max: u32) -> Result<(BeTag, &mut [u8]), PostgresError> {
|
||||
let (tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
match BeTag(tag) {
|
||||
BE_ERR_MESSAGE => Err(PostgresError::Error(BackendError {
|
||||
data: msg.to_vec().into(),
|
||||
})),
|
||||
tag => Ok((tag, msg)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_be_expect(
|
||||
&mut self,
|
||||
tag: BeTag,
|
||||
max: u32,
|
||||
) -> Result<&mut [u8], PostgresError> {
|
||||
let (actual_tag, msg) = self.read_raw_be(max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(PostgresError::Unexpected(UnexpectedMessage {
|
||||
expected: tag,
|
||||
tag: actual_tag,
|
||||
data: msg.to_vec().into(),
|
||||
}));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres backend auth message.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_auth_message(&mut self) -> Result<(AuthTag, &mut [u8]), PostgresError> {
|
||||
const MAX_AUTH_LENGTH: u32 = 512;
|
||||
|
||||
self.read_raw_be_expect(BE_AUTH_MESSAGE, MAX_AUTH_LENGTH)
|
||||
.await?
|
||||
.split_first_chunk_mut()
|
||||
.map(|(tag, msg)| (AuthTag(i32::from_be_bytes(*tag)), msg))
|
||||
.ok_or(PostgresError::InvalidAuthMessage)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqBeStream<S> {
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PostgresError {
|
||||
#[error("postgres responded with error {0}")]
|
||||
Error(#[from] BackendError),
|
||||
#[error("postgres responded with an unexpected message: {0}")]
|
||||
Unexpected(#[from] UnexpectedMessage),
|
||||
#[error("postgres responded with an invalid authentication message")]
|
||||
InvalidAuthMessage,
|
||||
#[error("IO error from compute: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("expected {expected}, got {tag} with data {data:?}")]
|
||||
pub struct UnexpectedMessage {
|
||||
expected: BeTag,
|
||||
tag: BeTag,
|
||||
data: Bytes,
|
||||
}
|
||||
|
||||
pub struct BackendError {
|
||||
data: Bytes,
|
||||
}
|
||||
|
||||
impl BackendError {
|
||||
pub fn parse(&self) -> (ErrorCode, &[u8]) {
|
||||
let mut code = &[] as &[u8];
|
||||
let mut message = &[] as &[u8];
|
||||
|
||||
for param in self.data.split(|b| *b == 0) {
|
||||
match param {
|
||||
[b'M', rest @ ..] => message = rest,
|
||||
[b'C', rest @ ..] => code = rest,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let code = code.try_into().map_or(SQLSTATE_INTERNAL_ERROR, ErrorCode);
|
||||
|
||||
(code, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BackendError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", &self.data)
|
||||
}
|
||||
}
|
||||
impl std::error::Error for BackendError {}
|
||||
@@ -6,8 +6,8 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::error::{ErrorKind, UserFacingError};
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
BeMessage, BeTag, FE_PASSWORD_MESSAGE, FeStartupPacket, FeTag, SQLSTATE_INTERNAL_ERROR,
|
||||
WriteBuf, read_message, read_startup,
|
||||
};
|
||||
use crate::stream::ReportedError;
|
||||
|
||||
@@ -32,6 +32,10 @@ impl<S> PqFeStream<S> {
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_buf_len(&self) -> usize {
|
||||
self.write.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqFeStream<S> {
|
||||
@@ -103,8 +107,8 @@ impl<S: AsyncWrite + Unpin> PqFeStream<S> {
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: BeTag, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag.0, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
@@ -150,6 +154,7 @@ impl<S: AsyncWrite + Unpin> PqFeStream<S> {
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind = error_kind.to_metric_label(),
|
||||
%error,
|
||||
msg,
|
||||
"forwarding error to user"
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user