diff --git a/proxy/src/config.rs b/proxy/src/config.rs index f932df4058..9d2d478965 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -33,6 +33,8 @@ pub struct TlsConfig { pub config: Arc, pub common_names: Option>, pub cert_resolver: Arc, + pub handshake_timeout: Duration, + pub max_handshaking: usize, } pub struct HttpConfig { @@ -98,6 +100,8 @@ pub fn configure_tls( config, common_names: Some(common_names), cert_resolver, + handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT, + max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES, }) } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index da65065179..34ae32d176 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -28,7 +28,7 @@ use prometheus::{ IntGaugeVec, }; use regex::Regex; -use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant}; +use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, time, @@ -154,7 +154,7 @@ pub static ALLOWED_IPS_NUMBER: Lazy = Lazy::new(|| { pub struct LatencyTimer { // time since the stopwatch was started - start: Option, + start: Option, // accumulated time on the stopwatch accumulated: std::time::Duration, // label data @@ -171,7 +171,7 @@ pub struct LatencyTimerPause<'a> { impl LatencyTimer { pub fn new(protocol: &'static str) -> Self { Self { - start: Some(Instant::now()), + start: Some(time::Instant::now()), accumulated: std::time::Duration::ZERO, protocol, cache_miss: false, @@ -205,7 +205,7 @@ impl LatencyTimer { impl Drop for LatencyTimerPause<'_> { fn drop(&mut self) { // start the stopwatch again - self.timer.start = Some(Instant::now()); + self.timer.start = Some(time::Instant::now()); } } @@ -467,9 +467,14 @@ async fn handshake( // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); + let handshake_timeout = tls + .map(|tls| tls.handshake_timeout) + .unwrap_or(tls_listener::DEFAULT_HANDSHAKE_TIMEOUT); + let deadline = time::Instant::now() + handshake_timeout; + let mut stream = PqStream::new(Stream::from_raw(stream)); loop { - let msg = stream.read_startup_packet().await?; + let msg = tokio::time::timeout_at(deadline, stream.read_startup_packet()).await??; info!("received {msg:?}"); use FeStartupPacket::*; @@ -495,7 +500,9 @@ async fn handshake( if !read_buf.is_empty() { bail!("data is sent before server replied with EncryptionResponse"); } - let tls_stream = raw.upgrade(tls.to_server_config()).await?; + let tls_stream = + tokio::time::timeout_at(deadline, raw.upgrade(tls.to_server_config())) + .await??; let (_, tls_server_end_point) = tls .cert_resolver diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 4691abbfb9..47418b8e94 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -85,6 +85,8 @@ fn generate_tls_config<'a>( config, common_names, cert_resolver: Arc::new(cert_resolver), + handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT, + max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES, } }; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 870e9c1103..72835b6d48 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -29,7 +29,6 @@ use hyper::{ use std::net::IpAddr; use std::task::Poll; use std::{future::ready, sync::Arc}; -use tls_listener::TlsListener; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; @@ -59,14 +58,15 @@ pub async fn task_main( } }); - let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); - let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { - Some(config) => config.into(), + // let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); + let tls_config = match config.tls_config.as_ref() { + Some(config) => config, None => { warn!("TLS config is missing, WebSocket Secure server will not be started"); return Ok(()); } }; + let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into(); let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; let _ = addr_incoming.set_nodelay(true); @@ -77,14 +77,17 @@ pub async fn task_main( let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { - if let Err(err) = conn { - error!("failed to accept TLS connection for websockets: {err:?}"); - ready(false) - } else { - ready(true) - } - }); + let tls_listener = tls_listener::builder(tls_acceptor) + .handshake_timeout(tls_config.handshake_timeout) + .listen(addr_incoming) + .filter(|conn| { + if let Err(err) = conn { + error!("failed to accept TLS connection for websockets: {err:?}"); + ready(false) + } else { + ready(true) + } + }); let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream>| {