start integrating ktls

This commit is contained in:
Conrad Ludgate
2024-08-21 14:11:58 +01:00
parent e171fd805b
commit 987a859352
7 changed files with 74 additions and 144 deletions

View File

@@ -86,8 +86,7 @@ impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
endpoint_from_domain: Option<EndpointId>,
) -> Result<Self, ComputeUserInfoParseError> {
// Some parameters are stored in the startup message.
let get_param = |key| {
@@ -111,16 +110,7 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let is_sni = endpoint_from_domain.is_some();
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
(Some(option), Some(domain)) if option != domain => {
@@ -143,7 +133,7 @@ impl ComputeUserInfoMaybeEndpoint {
let metrics = Metrics::get();
info!(%user, "credentials");
if sni.is_some() {
if is_sni {
info!("Connection with sni");
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
} else if endpoint.is_some() {
@@ -255,7 +245,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -270,7 +260,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -281,12 +271,8 @@ mod tests {
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -302,7 +288,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -317,7 +303,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -335,7 +321,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -350,7 +336,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -361,49 +347,21 @@ mod tests {
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("baz".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
Ok(())
}
#[test]
fn parse_multi_common_names() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
}
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("second".into()))
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
@@ -414,24 +372,6 @@ mod tests {
}
}
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]
fn parse_neon_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
@@ -439,11 +379,9 @@ mod tests {
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
]);
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("project".into()))?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File

@@ -7,7 +7,6 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
@@ -20,6 +19,7 @@ use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
@@ -72,7 +72,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -102,19 +102,14 @@ async fn main() -> anyhow::Result<()> {
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
Arc::new(
rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?,
)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -129,7 +124,6 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -151,7 +145,6 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -183,7 +176,7 @@ async fn task_main(
proxy::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
handle_client(ctx, dest_suffix, tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -208,8 +201,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
) -> anyhow::Result<Box<TlsStream<S>>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
@@ -235,13 +227,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
))
}
unexpected => {
info!(
@@ -259,15 +248,18 @@ async fn handle_client(
ctx: RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File

@@ -21,7 +21,7 @@ use crate::{
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
stream::PqStream,
EndpointCacheKey,
};
use futures::TryFutureExt;
@@ -191,13 +191,6 @@ impl ClientMode {
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
@@ -261,9 +254,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) =
let (mut stream, ep, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, ep, params) => (stream, ep, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
@@ -275,15 +268,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, ep))
.transpose();
let user_info = match result {

View File

@@ -15,6 +15,7 @@ use crate::{
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
EndpointId,
};
#[derive(Error, Debug)]
@@ -58,7 +59,11 @@ impl ReportableError for HandshakeError {
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Startup(
PqStream<Stream<S>>,
Option<EndpointId>,
StartupMessageParams,
),
Cancel(CancelKeyData),
}
@@ -80,6 +85,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut ep = None;
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -145,11 +151,11 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let conn_info = tls_stream.get_ref().1;
// try parse endpoint
let ep = conn_info
ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
if let Some(ep) = &ep {
ctx.set_endpoint_id(ep.clone());
}
// check the ALPN, if exists, as required.
@@ -170,7 +176,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
#[cfg(not(target_os = "linux"))]
tls: Box::new(tls_stream),
#[cfg(target_os = "linux")]
tls: {},
tls_server_end_point,
},
read_buf,
@@ -207,7 +216,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
// downgrade protocol version
FeStartupPacket::StartupMessage { params, version }
@@ -238,7 +247,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
FeStartupPacket::StartupMessage { version, .. } => {
warn!(

View File

@@ -16,6 +16,7 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::stream::Stream;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
@@ -180,7 +181,7 @@ async fn dummy_proxy(
let (client, _) = read_proxy_protocol(client).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Startup(stream, ..) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};

View File

@@ -43,7 +43,7 @@ async fn proxy_mitm(
.await
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, _ep, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};

View File

@@ -178,7 +178,12 @@ pub enum Stream<S> {
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
#[cfg(not(target_os = "linux"))]
tls: Box<TlsStream<S>>,
#[cfg(target_os = "linux")]
tls: ktls::KtlsStream<S>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
@@ -192,14 +197,6 @@ impl<S> Stream<S> {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
@@ -229,14 +226,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
record_handshake_error: bool,
) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?),
Stream::Raw { raw } => {
let stream = tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?;
Ok(stream)
}
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}