diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 849e7d65e8..8bc29642a8 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -86,8 +86,7 @@ impl ComputeUserInfoMaybeEndpoint { pub fn parse( ctx: &RequestMonitoring, params: &StartupMessageParams, - sni: Option<&str>, - common_names: Option<&HashSet>, + endpoint_from_domain: Option, ) -> Result { // Some parameters are stored in the startup message. let get_param = |key| { @@ -111,16 +110,7 @@ impl ComputeUserInfoMaybeEndpoint { }) .map(|name| name.into()); - let endpoint_from_domain = if let Some(sni_str) = sni { - if let Some(cn) = common_names { - endpoint_sni(sni_str, cn)? - } else { - None - } - } else { - None - }; - + let is_sni = endpoint_from_domain.is_some(); let endpoint = match (endpoint_option, endpoint_from_domain) { // Invariant: if we have both project name variants, they should match. (Some(option), Some(domain)) if option != domain => { @@ -143,7 +133,7 @@ impl ComputeUserInfoMaybeEndpoint { let metrics = Metrics::get(); info!(%user, "credentials"); - if sni.is_some() { + if is_sni { info!("Connection with sni"); metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni); } else if endpoint.is_some() { @@ -255,7 +245,7 @@ mod tests { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id, None); @@ -270,7 +260,7 @@ mod tests { ("foo", "bar"), // should be ignored ]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id, None); @@ -281,12 +271,8 @@ mod tests { fn parse_project_from_sni() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe")]); - let sni = Some("foo.localhost"); - let common_names = Some(["localhost".into()].into()); - let ctx = RequestMonitoring::test(); - let user_info = - ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); assert_eq!(user_info.options.get_cache_key("foo"), "foo"); @@ -302,7 +288,7 @@ mod tests { ]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); @@ -317,7 +303,7 @@ mod tests { ]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); @@ -335,7 +321,7 @@ mod tests { ]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert!(user_info.endpoint_id.is_none()); @@ -350,7 +336,7 @@ mod tests { ]); let ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?; assert_eq!(user_info.user, "john_doe"); assert!(user_info.endpoint_id.is_none()); @@ -361,49 +347,21 @@ mod tests { fn parse_projects_identical() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]); - let sni = Some("baz.localhost"); - let common_names = Some(["localhost".into()].into()); - let ctx = RequestMonitoring::test(); - let user_info = - ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("baz".into()))?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("baz")); Ok(()) } - #[test] - fn parse_multi_common_names() -> anyhow::Result<()> { - let options = StartupMessageParams::new([("user", "john_doe")]); - - let common_names = Some(["a.com".into(), "b.com".into()].into()); - let sni = Some("p1.a.com"); - let ctx = RequestMonitoring::test(); - let user_info = - ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; - assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); - - let common_names = Some(["a.com".into(), "b.com".into()].into()); - let sni = Some("p1.b.com"); - let ctx = RequestMonitoring::test(); - let user_info = - ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; - assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); - - Ok(()) - } - #[test] fn parse_projects_different() { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]); - let sni = Some("second.localhost"); - let common_names = Some(["localhost".into()].into()); - let ctx = RequestMonitoring::test(); - let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) + let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("second".into())) .expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { @@ -414,24 +372,6 @@ mod tests { } } - #[test] - fn parse_inconsistent_sni() { - let options = StartupMessageParams::new([("user", "john_doe")]); - - let sni = Some("project.localhost"); - let common_names = Some(["example.com".into()].into()); - - let ctx = RequestMonitoring::test(); - let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) - .expect_err("should fail"); - match err { - UnknownCommonName { cn } => { - assert_eq!(cn, "localhost"); - } - _ => panic!("bad error: {err:?}"), - } - } - #[test] fn parse_neon_options() -> anyhow::Result<()> { let options = StartupMessageParams::new([ @@ -439,11 +379,9 @@ mod tests { ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"), ]); - let sni = Some("project.localhost"); - let common_names = Some(["localhost".into()].into()); let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("project".into()))?; assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); assert_eq!( user_info.options.get_cache_key("project"), diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 1038fa5116..4c67b206b1 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -7,7 +7,6 @@ use std::{net::SocketAddr, sync::Arc}; use futures::future::Either; use itertools::Itertools; -use proxy::config::TlsServerEndPoint; use proxy::context::RequestMonitoring; use proxy::metrics::{Metrics, ThreadPoolMetrics}; use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource}; @@ -20,6 +19,7 @@ use futures::TryFutureExt; use proxy::stream::{PqStream, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use utils::{project_git_version, sentry_init::init_sentry}; @@ -72,7 +72,7 @@ async fn main() -> anyhow::Result<()> { let destination: String = args.get_one::("dest").unwrap().parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -102,19 +102,14 @@ async fn main() -> anyhow::Result<()> { })? }; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - - let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[ - &rustls::version::TLS13, - &rustls::version::TLS12, - ]) - .with_no_client_auth() - .with_single_cert(cert_chain, key)? - .into(); - - (tls_config, tls_server_end_point) + Arc::new( + rustls::ServerConfig::builder_with_protocol_versions(&[ + &rustls::version::TLS13, + &rustls::version::TLS12, + ]) + .with_no_client_auth() + .with_single_cert(cert_chain, key)?, + ) } _ => bail!("tls-key and tls-cert must be specified"), }; @@ -129,7 +124,6 @@ async fn main() -> anyhow::Result<()> { let main = tokio::spawn(task_main( Arc::new(destination), tls_config, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )); @@ -151,7 +145,6 @@ async fn main() -> anyhow::Result<()> { async fn task_main( dest_suffix: Arc, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -183,7 +176,7 @@ async fn task_main( proxy::metrics::Protocol::SniRouter, "sni", ); - handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await + handle_client(ctx, dest_suffix, tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -208,8 +201,7 @@ async fn ssl_handshake( ctx: &RequestMonitoring, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { +) -> anyhow::Result>> { let mut stream = PqStream::new(Stream::from_raw(raw_stream)); let msg = stream.read_startup_packet().await?; @@ -235,13 +227,10 @@ async fn ssl_handshake( bail!("data is sent before server replied with EncryptionResponse"); } - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(Box::new( + raw.upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?, + )) } unexpected => { info!( @@ -259,15 +248,18 @@ async fn handle_client( ctx: RequestMonitoring, dest_suffix: Arc, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 2182f38fe7..5d51f14d9a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -21,7 +21,7 @@ use crate::{ protocol2::read_proxy_protocol, proxy::handshake::{handshake, HandshakeData}, rate_limiter::EndpointRateLimiter, - stream::{PqStream, Stream}, + stream::PqStream, EndpointCacheKey, }; use futures::TryFutureExt; @@ -191,13 +191,6 @@ impl ClientMode { } } - fn hostname<'a, S>(&'a self, s: &'a Stream) -> Option<&'a str> { - match self { - ClientMode::Tcp => s.sni_hostname(), - ClientMode::Websockets { hostname } => hostname.as_deref(), - } - } - fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> { match self { ClientMode::Tcp => tls, @@ -261,9 +254,9 @@ pub async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); - let (mut stream, params) = + let (mut stream, ep, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { - HandshakeData::Startup(stream, params) => (stream, params), + HandshakeData::Startup(stream, ep, params) => (stream, ep, params), HandshakeData::Cancel(cancel_key_data) => { return Ok(cancellation_handler .cancel_session(cancel_key_data, ctx.session_id()) @@ -275,15 +268,11 @@ pub async fn handle_client( ctx.set_db_options(params.clone()); - let hostname = mode.hostname(stream.get_ref()); - - let common_names = tls.map(|tls| &tls.common_names); - // Extract credentials which we're going to use for auth. let result = config .auth_backend .as_ref() - .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) + .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, ep)) .transpose(); let user_info = match result { diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 27a72f8072..d6b6464fb9 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -15,6 +15,7 @@ use crate::{ metrics::Metrics, proxy::ERR_INSECURE_CONNECTION, stream::{PqStream, Stream, StreamUpgradeError}, + EndpointId, }; #[derive(Error, Debug)] @@ -58,7 +59,11 @@ impl ReportableError for HandshakeError { } pub enum HandshakeData { - Startup(PqStream>, StartupMessageParams), + Startup( + PqStream>, + Option, + StartupMessageParams, + ), Cancel(CancelKeyData), } @@ -80,6 +85,7 @@ pub async fn handshake( const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); let mut stream = PqStream::new(Stream::from_raw(stream)); + let mut ep = None; loop { let msg = stream.read_startup_packet().await?; match msg { @@ -145,11 +151,11 @@ pub async fn handshake( let conn_info = tls_stream.get_ref().1; // try parse endpoint - let ep = conn_info + ep = conn_info .server_name() .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten()); - if let Some(ep) = ep { - ctx.set_endpoint_id(ep); + if let Some(ep) = &ep { + ctx.set_endpoint_id(ep.clone()); } // check the ALPN, if exists, as required. @@ -170,7 +176,10 @@ pub async fn handshake( stream = PqStream { framed: Framed { stream: Stream::Tls { + #[cfg(not(target_os = "linux"))] tls: Box::new(tls_stream), + #[cfg(target_os = "linux")] + tls: {}, tls_server_end_point, }, read_buf, @@ -207,7 +216,7 @@ pub async fn handshake( session_type = "normal", "successful handshake" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, ep, params)); } // downgrade protocol version FeStartupPacket::StartupMessage { params, version } @@ -238,7 +247,7 @@ pub async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, ep, params)); } FeStartupPacket::StartupMessage { version, .. } => { warn!( diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index b52e21f5c7..98e2e11096 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -16,6 +16,7 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; +use crate::stream::Stream; use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; use async_trait::async_trait; @@ -180,7 +181,7 @@ async fn dummy_proxy( let (client, _) = read_proxy_protocol(client).await?; let mut stream = match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? { - HandshakeData::Startup(stream, _) => stream, + HandshakeData::Startup(stream, ..) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), }; diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 2d752b9183..16867d7473 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -43,7 +43,7 @@ async fn proxy_mitm( .await .unwrap() { - HandshakeData::Startup(stream, params) => (stream, params), + HandshakeData::Startup(stream, _ep, params) => (stream, params), HandshakeData::Cancel(_) => panic!("cancellation not supported"), }; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 7809d2e574..e8c68da3f8 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -178,7 +178,12 @@ pub enum Stream { Raw { raw: S }, Tls { /// We box [`TlsStream`] since it can be quite large. + #[cfg(not(target_os = "linux"))] tls: Box>, + + #[cfg(target_os = "linux")] + tls: ktls::KtlsStream, + /// Channel binding parameter tls_server_end_point: TlsServerEndPoint, }, @@ -192,14 +197,6 @@ impl Stream { Self::Raw { raw } } - /// Return SNI hostname when it's available. - pub fn sni_hostname(&self) -> Option<&str> { - match self { - Stream::Raw { .. } => None, - Stream::Tls { tls, .. } => tls.get_ref().1.server_name(), - } - } - pub fn tls_server_end_point(&self) -> TlsServerEndPoint { match self { Stream::Raw { .. } => TlsServerEndPoint::Undefined, @@ -229,14 +226,18 @@ impl Stream { record_handshake_error: bool, ) -> Result, StreamUpgradeError> { match self { - Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg) - .accept(raw) - .await - .inspect_err(|_| { - if record_handshake_error { - Metrics::get().proxy.tls_handshake_failures.inc(); - } - })?), + Stream::Raw { raw } => { + let stream = tokio_rustls::TlsAcceptor::from(cfg) + .accept(raw) + .await + .inspect_err(|_| { + if record_handshake_error { + Metrics::get().proxy.tls_handshake_failures.inc(); + } + })?; + + Ok(stream) + } Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } }