Compare commits

...

1 Commits

Author SHA1 Message Date
Conrad Ludgate
d7e6a319bb add handshake timeouts 2023-12-15 16:06:32 +00:00
4 changed files with 34 additions and 18 deletions

View File

@@ -33,6 +33,8 @@ pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>, pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>, pub common_names: Option<HashSet<String>>,
pub cert_resolver: Arc<CertResolver>, pub cert_resolver: Arc<CertResolver>,
pub handshake_timeout: Duration,
pub max_handshaking: usize,
} }
pub struct HttpConfig { pub struct HttpConfig {
@@ -98,6 +100,8 @@ pub fn configure_tls(
config, config,
common_names: Some(common_names), common_names: Some(common_names),
cert_resolver, cert_resolver,
handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES,
}) })
} }

View File

@@ -28,7 +28,7 @@ use prometheus::{
IntGaugeVec, IntGaugeVec,
}; };
use regex::Regex; use regex::Regex;
use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant}; use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc};
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt}, io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time, time,
@@ -154,7 +154,7 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
pub struct LatencyTimer { pub struct LatencyTimer {
// time since the stopwatch was started // time since the stopwatch was started
start: Option<Instant>, start: Option<time::Instant>,
// accumulated time on the stopwatch // accumulated time on the stopwatch
accumulated: std::time::Duration, accumulated: std::time::Duration,
// label data // label data
@@ -171,7 +171,7 @@ pub struct LatencyTimerPause<'a> {
impl LatencyTimer { impl LatencyTimer {
pub fn new(protocol: &'static str) -> Self { pub fn new(protocol: &'static str) -> Self {
Self { Self {
start: Some(Instant::now()), start: Some(time::Instant::now()),
accumulated: std::time::Duration::ZERO, accumulated: std::time::Duration::ZERO,
protocol, protocol,
cache_miss: false, cache_miss: false,
@@ -205,7 +205,7 @@ impl LatencyTimer {
impl Drop for LatencyTimerPause<'_> { impl Drop for LatencyTimerPause<'_> {
fn drop(&mut self) { fn drop(&mut self) {
// start the stopwatch again // start the stopwatch again
self.timer.start = Some(Instant::now()); self.timer.start = Some(time::Instant::now());
} }
} }
@@ -467,9 +467,14 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Client may try upgrading to each protocol only once // Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false); let (mut tried_ssl, mut tried_gss) = (false, false);
let handshake_timeout = tls
.map(|tls| tls.handshake_timeout)
.unwrap_or(tls_listener::DEFAULT_HANDSHAKE_TIMEOUT);
let deadline = time::Instant::now() + handshake_timeout;
let mut stream = PqStream::new(Stream::from_raw(stream)); let mut stream = PqStream::new(Stream::from_raw(stream));
loop { loop {
let msg = stream.read_startup_packet().await?; let msg = tokio::time::timeout_at(deadline, stream.read_startup_packet()).await??;
info!("received {msg:?}"); info!("received {msg:?}");
use FeStartupPacket::*; use FeStartupPacket::*;
@@ -495,7 +500,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if !read_buf.is_empty() { if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse"); bail!("data is sent before server replied with EncryptionResponse");
} }
let tls_stream = raw.upgrade(tls.to_server_config()).await?; let tls_stream =
tokio::time::timeout_at(deadline, raw.upgrade(tls.to_server_config()))
.await??;
let (_, tls_server_end_point) = tls let (_, tls_server_end_point) = tls
.cert_resolver .cert_resolver

View File

@@ -85,6 +85,8 @@ fn generate_tls_config<'a>(
config, config,
common_names, common_names,
cert_resolver: Arc::new(cert_resolver), cert_resolver: Arc::new(cert_resolver),
handshake_timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
max_handshaking: tls_listener::DEFAULT_MAX_HANDSHAKES,
} }
}; };

View File

@@ -29,7 +29,6 @@ use hyper::{
use std::net::IpAddr; use std::net::IpAddr;
use std::task::Poll; use std::task::Poll;
use std::{future::ready, sync::Arc}; use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument}; use tracing::{error, info, info_span, warn, Instrument};
@@ -59,14 +58,15 @@ pub async fn task_main(
} }
}); });
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); // let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { let tls_config = match config.tls_config.as_ref() {
Some(config) => config.into(), Some(config) => config,
None => { None => {
warn!("TLS config is missing, WebSocket Secure server will not be started"); warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(()); return Ok(());
} }
}; };
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true); let _ = addr_incoming.set_nodelay(true);
@@ -77,14 +77,17 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete` ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { let tls_listener = tls_listener::builder(tls_acceptor)
if let Err(err) = conn { .handshake_timeout(tls_config.handshake_timeout)
error!("failed to accept TLS connection for websockets: {err:?}"); .listen(addr_incoming)
ready(false) .filter(|conn| {
} else { if let Err(err) = conn {
ready(true) error!("failed to accept TLS connection for websockets: {err:?}");
} ready(false)
}); } else {
ready(true)
}
});
let make_svc = hyper::service::make_service_fn( let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| { |stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {