add handshake timeouts

This commit is contained in:
Conrad Ludgate
2023-12-15 16:06:32 +00:00
parent 98629841e0
commit d7e6a319bb
4 changed files with 34 additions and 18 deletions

View File

@@ -33,6 +33,8 @@ pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>,
pub cert_resolver: Arc<CertResolver>,
pub handshake_timeout: Duration,
pub max_handshaking: usize,
}
pub struct HttpConfig {
@@ -98,6 +100,8 @@ pub fn configure_tls(
config,
common_names: Some(common_names),
cert_resolver,
handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES,
})
}

View File

@@ -28,7 +28,7 @@ use prometheus::{
IntGaugeVec,
};
use regex::Regex;
use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant};
use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
@@ -154,7 +154,7 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
pub struct LatencyTimer {
// time since the stopwatch was started
start: Option<Instant>,
start: Option<time::Instant>,
// accumulated time on the stopwatch
accumulated: std::time::Duration,
// label data
@@ -171,7 +171,7 @@ pub struct LatencyTimerPause<'a> {
impl LatencyTimer {
pub fn new(protocol: &'static str) -> Self {
Self {
start: Some(Instant::now()),
start: Some(time::Instant::now()),
accumulated: std::time::Duration::ZERO,
protocol,
cache_miss: false,
@@ -205,7 +205,7 @@ impl LatencyTimer {
impl Drop for LatencyTimerPause<'_> {
fn drop(&mut self) {
// start the stopwatch again
self.timer.start = Some(Instant::now());
self.timer.start = Some(time::Instant::now());
}
}
@@ -467,9 +467,14 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let handshake_timeout = tls
.map(|tls| tls.handshake_timeout)
.unwrap_or(tls_listener::DEFAULT_HANDSHAKE_TIMEOUT);
let deadline = time::Instant::now() + handshake_timeout;
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
let msg = tokio::time::timeout_at(deadline, stream.read_startup_packet()).await??;
info!("received {msg:?}");
use FeStartupPacket::*;
@@ -495,7 +500,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let tls_stream =
tokio::time::timeout_at(deadline, raw.upgrade(tls.to_server_config()))
.await??;
let (_, tls_server_end_point) = tls
.cert_resolver

View File

@@ -85,6 +85,8 @@ fn generate_tls_config<'a>(
config,
common_names,
cert_resolver: Arc::new(cert_resolver),
handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES,
}
};

View File

@@ -29,7 +29,6 @@ use hyper::{
use std::net::IpAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
@@ -59,14 +58,15 @@ pub async fn task_main(
}
});
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
Some(config) => config.into(),
// let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
}
};
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
@@ -77,14 +77,17 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let tls_listener = tls_listener::builder(tls_acceptor)
.handshake_timeout(tls_config.handshake_timeout)
.listen(addr_incoming)
.filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {