This commit is contained in:
Conrad Ludgate
2024-08-21 16:29:52 +01:00
parent fbd4b91169
commit 471b3b300d
2 changed files with 17 additions and 21 deletions

View File

@@ -34,7 +34,7 @@ pub enum HandshakeError {
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[cfg(target_os = "linux")]
#[cfg(all(target_os = "linux", not(test)))]
#[error("{0}")]
KtlsUpgradeError(#[from] ktls::Error),
@@ -199,9 +199,9 @@ where
framed: Framed {
stream: Stream::Tls {
#[cfg(any(not(target_os = "linux"), test))]
tls: Box::new(tls_stream),
tls: Box::pin(tls_stream),
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::config_ktls_server(tls_stream)?,
tls: Box::pin(ktls::config_ktls_server(tls_stream)?),
tls_server_end_point,
},
read_buf,

View File

@@ -179,10 +179,10 @@ pub enum Stream<S> {
Tls {
/// We box [`TlsStream`] since it can be quite large.
#[cfg(any(not(target_os = "linux"), test))]
tls: Box<TlsStream<S>>,
tls: Pin<Box<TlsStream<S>>>,
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::KtlsStream<S>,
tls: Pin<Box<ktls::KtlsStream<S>>>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
@@ -226,18 +226,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
record_handshake_error: bool,
) -> Result<TlsStream<S>, StreamUpgradeError> {
match self {
Stream::Raw { raw } => {
let stream = tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?;
Ok(stream)
}
Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg)
.accept(raw)
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
}
}
@@ -251,7 +247,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_read(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf),
Self::Tls { tls, .. } => tls.as_mut().poll_read(context, buf),
}
}
}
@@ -264,7 +260,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
) -> task::Poll<io::Result<usize>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_write(context, buf),
Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf),
Self::Tls { tls, .. } => tls.as_mut().poll_write(context, buf),
}
}
@@ -274,7 +270,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_flush(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context),
Self::Tls { tls, .. } => tls.as_mut().poll_flush(context),
}
}
@@ -284,7 +280,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
) -> task::Poll<io::Result<()>> {
match &mut *self {
Self::Raw { raw } => Pin::new(raw).poll_shutdown(context),
Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context),
Self::Tls { tls, .. } => tls.as_mut().poll_shutdown(context),
}
}
}