asyncreadready

This commit is contained in:
Conrad Ludgate
2024-08-21 16:16:49 +01:00
parent 8cc45ad9bd
commit fbd4b91169
4 changed files with 34 additions and 3 deletions

View File

@@ -27,6 +27,17 @@ impl<S: AsRawFd> AsRawFd for ChainRW<S> {
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S: ktls::AsyncReadReady> AsRawFd for ChainRW<S> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.inner.poll_read_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}
}
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
#[inline]
fn poll_write(

View File

@@ -9,6 +9,7 @@ pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use handshake::KtlsAsyncReadReady;
use crate::{
auth,
@@ -232,7 +233,7 @@ impl ReportableError for ClientRequestError {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,

View File

@@ -73,17 +73,30 @@ pub enum HandshakeData<S> {
Cancel(CancelKeyData),
}
#[cfg(any(not(target_os = "linux"), test))]
pub trait KtlsAsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
pub trait KtlsAsyncReadReady: ktls::AsyncReadReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<K: AsyncRead> KtlsAsyncReadReady for K {}
#[cfg(all(target_os = "linux", not(test)))]
impl<K: ktls::AsyncReadReady> KtlsAsyncReadReady for K {}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
pub async fn handshake<S>(
ctx: &RequestMonitoring,
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
) -> Result<HandshakeData<S>, HandshakeError>
where
S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
{
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

View File

@@ -51,6 +51,12 @@ impl<S> AsRawFd for WebSocketRw<S> {
unreachable!("ktls should not need to be used for websocket rw")
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S: ktls::AsyncReadReady> AsRawFd for ChainRW<S> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!("ktls should not need to be used for websocket rw")
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
fn poll_write(