diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 0efead8c2e..f6fcfe395e 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -34,7 +34,7 @@ pub enum HandshakeError { #[error("{0}")] StreamUpgradeError(#[from] StreamUpgradeError), - #[cfg(target_os = "linux")] + #[cfg(all(target_os = "linux", not(test)))] #[error("{0}")] KtlsUpgradeError(#[from] ktls::Error), @@ -199,9 +199,9 @@ where framed: Framed { stream: Stream::Tls { #[cfg(any(not(target_os = "linux"), test))] - tls: Box::new(tls_stream), + tls: Box::pin(tls_stream), #[cfg(all(target_os = "linux", not(test)))] - tls: ktls::config_ktls_server(tls_stream)?, + tls: Box::pin(ktls::config_ktls_server(tls_stream)?), tls_server_end_point, }, read_buf, diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index cabbfecfab..d9bf3b86cb 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -179,10 +179,10 @@ pub enum Stream { Tls { /// We box [`TlsStream`] since it can be quite large. #[cfg(any(not(target_os = "linux"), test))] - tls: Box>, + tls: Pin>>, #[cfg(all(target_os = "linux", not(test)))] - tls: ktls::KtlsStream, + tls: Pin>>, /// Channel binding parameter tls_server_end_point: TlsServerEndPoint, @@ -226,18 +226,14 @@ impl Stream { record_handshake_error: bool, ) -> Result, StreamUpgradeError> { match self { - Stream::Raw { raw } => { - let stream = tokio_rustls::TlsAcceptor::from(cfg) - .accept(raw) - .await - .inspect_err(|_| { - if record_handshake_error { - Metrics::get().proxy.tls_handshake_failures.inc(); - } - })?; - - Ok(stream) - } + Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg) + .accept(raw) + .await + .inspect_err(|_| { + if record_handshake_error { + Metrics::get().proxy.tls_handshake_failures.inc(); + } + })?), Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } } @@ -251,7 +247,7 @@ impl AsyncRead for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_read(context, buf), - Self::Tls { tls, .. } => Pin::new(tls).poll_read(context, buf), + Self::Tls { tls, .. } => tls.as_mut().poll_read(context, buf), } } } @@ -264,7 +260,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_write(context, buf), - Self::Tls { tls, .. } => Pin::new(tls).poll_write(context, buf), + Self::Tls { tls, .. } => tls.as_mut().poll_write(context, buf), } } @@ -274,7 +270,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_flush(context), - Self::Tls { tls, .. } => Pin::new(tls).poll_flush(context), + Self::Tls { tls, .. } => tls.as_mut().poll_flush(context), } } @@ -284,7 +280,7 @@ impl AsyncWrite for Stream { ) -> task::Poll> { match &mut *self { Self::Raw { raw } => Pin::new(raw).poll_shutdown(context), - Self::Tls { tls, .. } => Pin::new(tls).poll_shutdown(context), + Self::Tls { tls, .. } => tls.as_mut().poll_shutdown(context), } } }