proxy: pg17 fixes (#8321)

## Problem

#7809 - we do not support sslnegotiation=direct
#7810 - we do not support negotiating down the protocol extensions.

## Summary of changes

1. Same as postgres, check the first startup packet byte for tls header
`0x16`, and check the ALPN.
2. Tell clients using protocol >3.0 to downgrade
This commit is contained in:
Conrad Ludgate
2024-07-10 09:10:29 +01:00
committed by GitHub
parent 1a49f1c15c
commit fe13fccdc2
6 changed files with 222 additions and 54 deletions

View File

@@ -216,10 +216,11 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
use pq_proto::FeStartupPacket::*;
match msg {
SslRequest => {
SslRequest { direct: false } => {
stream
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
.await?;
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.

View File

@@ -75,6 +75,9 @@ impl TlsConfig {
}
}
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L159>
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &str,
@@ -111,16 +114,17 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);
// allow TLS 1.2 to be compatible with older client libraries
let config = rustls::ServerConfig::builder_with_protocol_versions(&[
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone())
.into();
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
Ok(TlsConfig {
config,
config: Arc::new(config),
common_names,
cert_resolver,
})

View File

@@ -1,11 +1,17 @@
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use tracing::{info, warn};
use crate::{
config::TlsConfig,
auth::endpoint_sni,
config::{TlsConfig, PG_ALPN_PROTOCOL},
error::ReportableError,
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
@@ -68,6 +74,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
@@ -75,40 +84,96 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
SslRequest { direct } => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
let have_tls = tls.is_some();
if !direct {
stream
.write_message(&Be::EncryptionResponse(have_tls))
.await?;
} else if !have_tls {
return Err(HandshakeError::ProtocolViolation);
}
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
let Framed {
stream: raw,
read_buf,
write_buf,
} = stream.framed;
let Stream::Raw { raw } = raw else {
return Err(HandshakeError::StreamUpgradeError(
StreamUpgradeError::AlreadyTls,
));
};
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
.accept_with(raw, |session| {
// push the early data to the tls session
while !read_buf.get_ref().is_empty() {
match session.read_tls(&mut read_buf) {
Ok(_) => {}
Err(e) => {
res = Err(e);
break;
}
}
}
});
res?;
let read_buf = read_buf.into_inner();
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}
let tls_stream = raw
.upgrade(tls.to_server_config(), record_handshake_error)
.await?;
let tls_stream = accept.await.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc()
}
})?;
let conn_info = tls_stream.get_ref().1;
// check the ALPN, if exists, as required.
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
// try parse ep for better error
let ep = conn_info.server_name().and_then(|sni| {
endpoint_sni(sni, &tls.common_names).ok().flatten()
});
let alpn = String::from_utf8_lossy(other);
warn!(?ep, %alpn, "unexpected ALPN");
return Err(HandshakeError::ProtocolViolation);
}
}
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.resolve(conn_info.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
},
};
}
}
_ => return Err(HandshakeError::ProtocolViolation),
@@ -122,7 +187,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, .. } => {
StartupMessage { params, version }
if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
{
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
@@ -131,9 +198,48 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
.await?;
}
info!(session_type = "normal", "successful handshake");
info!(?version, session_type = "normal", "successful handshake");
break Ok(HandshakeData::Startup(stream, params));
}
// downgrade protocol version
StartupMessage { params, version }
if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
{
warn!(?version, "unsupported minor version");
// no protocol extensions are supported.
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
let mut unsupported = vec![];
for (k, _) in params.iter() {
if k.starts_with("_pq_.") {
unsupported.push(k);
}
}
// TODO: remove unsupported options so we don't send them to compute.
stream
.write_message(&Be::NegotiateProtocolVersion {
version: PG_PROTOCOL_LATEST,
options: &unsupported,
})
.await?;
info!(
?version,
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
}
StartupMessage { version, .. } => {
warn!(
?version,
session_type = "normal",
"unsuccessful handshake; unsupported version"
);
return Err(HandshakeError::ProtocolViolation);
}
CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));