Files
neon/proxy/src/serverless/websocket.rs

233 lines
7.3 KiB
Rust

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
use crate::cancellation::CancellationHandler;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
pub(crate) struct WebSocketRw<S> {
#[pin]
stream: WebSocketServer<S>,
recv: Bytes,
send: BytesMut,
}
}
impl<S> WebSocketRw<S> {
pub(crate) fn new(stream: WebSocketServer<S>) -> Self {
Self {
stream,
recv: Bytes::new(),
send: BytesMut::new(),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
let mut stream = this.stream;
ready!(stream.as_mut().poll_ready(cx).map_err(io::Error::other))?;
this.send.put(buf);
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(io::Error::other(e))),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream;
stream.poll_flush(cx).map_err(io::Error::other)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream;
stream.poll_close(cx).map_err(io::Error::other)
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.consume(len);
Poll::Ready(Ok(()))
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
// Please refer to poll_fill_buf's documentation.
const EOF: Poll<io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
let mut this = self.project();
loop {
if !this.recv.chunk().is_empty() {
let chunk = (*this.recv).chunk();
return Poll::Ready(Ok(chunk));
}
let res = ready!(this.stream.as_mut().poll_next(cx));
match res.transpose().map_err(io::Error::other)? {
Some(message) => match message.opcode {
OpCode::Ping => {}
OpCode::Pong => {}
OpCode::Text => {
// We expect to see only binary messages.
let error = "unexpected text message in the websocket";
warn!(length = message.payload.len(), error);
return Poll::Ready(Err(io::Error::other(error)));
}
OpCode::Binary | OpCode::Continuation => {
debug_assert!(this.recv.is_empty());
*this.recv = message.payload.freeze();
}
OpCode::Close => return EOF,
},
None => return EOF,
}
}
}
fn consume(self: Pin<&mut Self>, amount: usize) {
self.project().recv.advance(amount);
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
cancellations,
))
.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
Err(e.into())
}
Ok(None) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::pin::pin;
use framed_websockets::WebSocketServer;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::protocol::Role;
use super::WebSocketRw;
#[tokio::test]
async fn websocket_stream_wrapper_happy_path() {
let (stream1, stream2) = duplex(1024);
let mut js = JoinSet::new();
js.spawn(async move {
let mut client = WebSocketStream::from_raw_socket(stream1, Role::Client, None).await;
client
.send(Message::Binary(b"hello world".to_vec()))
.await
.unwrap();
let message = client.next().await.unwrap().unwrap();
assert_eq!(message, Message::Binary(b"websockets are cool".to_vec()));
client.close(None).await.unwrap();
});
js.spawn(async move {
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
let mut buf = vec![0; 1024];
let n = rw.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello world");
rw.write_all(b"websockets are cool").await.unwrap();
rw.flush().await.unwrap();
let n = rw.read_to_end(&mut buf).await.unwrap();
assert_eq!(n, 0);
});
js.join_next().await.unwrap().unwrap();
js.join_next().await.unwrap().unwrap();
}
}