use crate::{ cancellation::CancelMap, config::ProxyConfig, context::RequestMonitoring, error::io_error, proxy::{handle_client, ClientMode}, rate_limiter::EndpointRateLimiter, }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; use hyper::{ext::Protocol, upgrade::Upgraded, Body, Method, Request, Response}; use pin_project_lite::pin_project; use tokio_tungstenite::WebSocketStream; use tungstenite::{ error::{Error as WSError, ProtocolError}, handshake::derive_accept_key, protocol::{Role, WebSocketConfig}, Message, }; use std::{ pin::Pin, sync::Arc, task::{ready, Context, Poll}, }; use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use tracing::warn; // TODO: use `std::sync::Exclusive` once it's stabilized. // Tracking issue: https://github.com/rust-lang/rust/issues/98407. use sync_wrapper::SyncWrapper; pin_project! { /// This is a wrapper around a [`WebSocketStream`] that /// implements [`AsyncRead`] and [`AsyncWrite`]. pub struct WebSocketRw { #[pin] stream: SyncWrapper>, bytes: Bytes, } } impl WebSocketRw { pub fn new(stream: WebSocketStream) -> Self { Self { stream: stream.into(), bytes: Bytes::new(), } } } impl AsyncWrite for WebSocketRw { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let mut stream = self.project().stream.get_pin_mut(); ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?; match stream.as_mut().start_send(Message::Binary(buf.into())) { Ok(()) => Poll::Ready(Ok(buf.len())), Err(e) => Poll::Ready(Err(io_error(e))), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let stream = self.project().stream.get_pin_mut(); stream.poll_flush(cx).map_err(io_error) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let stream = self.project().stream.get_pin_mut(); stream.poll_close(cx).map_err(io_error) } } impl AsyncRead for WebSocketRw { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { if 
buf.remaining() > 0 { let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; let len = std::cmp::min(bytes.len(), buf.remaining()); buf.put_slice(&bytes[..len]); self.consume(len); } Poll::Ready(Ok(())) } } impl AsyncBufRead for WebSocketRw { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Please refer to poll_fill_buf's documentation. const EOF: Poll> = Poll::Ready(Ok(&[])); let mut this = self.project(); loop { if !this.bytes.chunk().is_empty() { let chunk = (*this.bytes).chunk(); return Poll::Ready(Ok(chunk)); } let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx)); match res.transpose().map_err(io_error)? { Some(message) => match message { Message::Ping(_) => {} Message::Pong(_) => {} Message::Text(text) => { // We expect to see only binary messages. let error = "unexpected text message in the websocket"; warn!(length = text.len(), error); return Poll::Ready(Err(io_error(error))); } Message::Frame(_) => { // This case is impossible according to Frame's doc. panic!("unexpected raw frame in the websocket"); } Message::Binary(chunk) => { assert!(this.bytes.is_empty()); *this.bytes = Bytes::from(chunk); } Message::Close(_) => return EOF, }, None => return EOF, } } } fn consume(self: Pin<&mut Self>, amount: usize) { self.project().bytes.advance(amount); } } pub async fn serve_websocket( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, websocket: HyperWebsocket, cancel_map: &CancelMap, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; handle_client( config, ctx, cancel_map, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, ) .await?; Ok(()) } /// Try to upgrade a received `hyper::Request` to a websocket connection. /// /// The function returns a HTTP response and a future that resolves to the websocket stream. /// The response body *MUST* be sent to the client before the future can be resolved. 
///
/// This function checks the `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
/// # Errors
///
/// Returns [`ProtocolError::MissingSecWebSocketKey`] if the request has no
/// `Sec-WebSocket-Key` header, and
/// [`ProtocolError::MissingSecWebSocketVersionHeader`] if the
/// `Sec-WebSocket-Version` header is absent or is not exactly `13`.
pub fn upgrade<B>(
    mut request: impl std::borrow::BorrowMut<Request<B>>,
    config: Option<WebSocketConfig>,
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
    let request = request.borrow_mut();

    let key = request
        .headers()
        .get("Sec-WebSocket-Key")
        .ok_or(ProtocolError::MissingSecWebSocketKey)?;
    if request
        .headers()
        .get("Sec-WebSocket-Version")
        .map(|v| v.as_bytes())
        != Some(b"13")
    {
        return Err(ProtocolError::MissingSecWebSocketVersionHeader);
    }

    // 101 response carrying the accept key derived from the client's key.
    let response = Response::builder()
        .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
        .header(hyper::header::CONNECTION, "upgrade")
        .header(hyper::header::UPGRADE, "websocket")
        .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
        .body(Body::from("switching to websocket protocol"))
        .expect("bug: failed to build response");

    // Resolves to the websocket stream once hyper has upgraded the connection.
    let stream = HyperWebsocket {
        inner: hyper::upgrade::on(request),
        config,
    };

    Ok((response, stream))
}

/// Check if a request is a websocket upgrade request.
///
/// If the `Upgrade` header lists multiple protocols,
/// this function returns true if one of them is `"websocket"`.
/// If the server supports multiple upgrade protocols,
/// it would be more appropriate to try each listed protocol in order.
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
    header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
        && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
}

/// Check if there is a header of the given name containing the wanted value.
///
/// Every header line with that name is split on commas; each item is trimmed
/// of ASCII whitespace and compared to `value` ignoring ASCII case.
fn header_contains_value(
    headers: &hyper::HeaderMap,
    header: impl hyper::header::AsHeaderName,
    value: impl AsRef<[u8]>,
) -> bool {
    let value = value.as_ref();
    headers.get_all(header).into_iter().any(|header| {
        header
            .as_bytes()
            .split(|&c| c == b',')
            .any(|x| trim(x).eq_ignore_ascii_case(value))
    })
}

/// Strip leading and trailing ASCII whitespace from `data`.
fn trim(data: &[u8]) -> &[u8] {
    trim_end(trim_start(data))
}

/// Strip leading ASCII whitespace; yields an empty slice if `data` is all whitespace.
fn trim_start(data: &[u8]) -> &[u8] {
    match data.iter().position(|x| !x.is_ascii_whitespace()) {
        Some(start) => &data[start..],
        None => b"",
    }
}

/// Strip trailing ASCII whitespace; yields an empty slice if `data` is all whitespace.
fn trim_end(data: &[u8]) -> &[u8] {
    match data.iter().rposition(|x| !x.is_ascii_whitespace()) {
        Some(last) => &data[..last + 1],
        None => b"",
    }
}

/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns a HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This function checks the `Sec-WebSocket-Version` header.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket connect request, you can use [`is_connect_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
/// pub fn connect( mut request: impl std::borrow::BorrowMut>, config: Option, ) -> Result<(Response, HyperWebsocket), ProtocolError> { let request = request.borrow_mut(); if request .headers() .get("Sec-WebSocket-Version") .map(|v| v.as_bytes()) != Some(b"13") { return Err(ProtocolError::MissingSecWebSocketVersionHeader); } let response = Response::builder() .status(hyper::StatusCode::OK) .body(Body::from("switching to websocket protocol")) .expect("bug: failed to build response"); let stream = HyperWebsocket { inner: hyper::upgrade::on(request), config, }; Ok((response, stream)) } /// Check if a request is a websocket connect request. pub fn is_connect_request(request: &hyper::Request) -> bool { request.method() == Method::CONNECT && request .extensions() .get::() .is_some_and(|protocol| protocol.as_str() == "websocket") } pin_project_lite::pin_project! { /// A future that resolves to a websocket stream when the associated connection completes. #[derive(Debug)] pub struct HyperWebsocket { #[pin] inner: hyper::upgrade::OnUpgrade, config: Option } } impl std::future::Future for HyperWebsocket { type Output = Result, WSError>; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { let this = self.project(); let upgraded = match this.inner.poll(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(x) => x, }; let upgraded = upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?; let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None); tokio::pin!(stream); // The future returned by `from_raw_socket` is always ready. // Not sure why it is a future in the first place. 
match stream.as_mut().poll(cx) { Poll::Pending => unreachable!("from_raw_socket should always be created ready"), Poll::Ready(x) => Poll::Ready(Ok(x)), } } } #[cfg(test)] mod tests { use std::pin::pin; use futures::{SinkExt, StreamExt}; use tokio::{ io::{duplex, AsyncReadExt, AsyncWriteExt}, task::JoinSet, }; use tokio_tungstenite::WebSocketStream; use tungstenite::{protocol::Role, Message}; use super::WebSocketRw; #[tokio::test] async fn websocket_stream_wrapper_happy_path() { let (stream1, stream2) = duplex(1024); let mut js = JoinSet::new(); js.spawn(async move { let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await; client .send(Message::Binary(b"hello world".to_vec())) .await .unwrap(); let message = client.next().await.unwrap().unwrap(); assert_eq!(message, Message::Binary(b"websockets are cool".to_vec())); client.close(None).await.unwrap(); }); js.spawn(async move { let mut rw = pin!(WebSocketRw::new( WebSocketStream::from_raw_socket(stream2, Role::Server, None).await )); let mut buf = vec![0; 1024]; let n = rw.read(&mut buf).await.unwrap(); assert_eq!(&buf[..n], b"hello world"); rw.write_all(b"websockets are cool").await.unwrap(); rw.flush().await.unwrap(); let n = rw.read_to_end(&mut buf).await.unwrap(); assert_eq!(n, 0); }); js.join_next().await.unwrap().unwrap(); js.join_next().await.unwrap().unwrap(); } }