proxy: swap tungstenite for a simpler impl (#7353)

## Problem

I wanted to do a deep dive of the tungstenite codebase.
tokio-tungstenite is incredibly convoluted... In my searching I found
[fastwebsockets by deno](https://github.com/denoland/fastwebsockets),
but it wasn't quite sufficient.

This also removes the default 16MB/64MB frame/message size limitation.
framed-websockets solves this by inserting continuation frames for
partially received messages, so the whole message does not need to be
entirely read into memory.

## Summary of changes

I took the fastwebsockets code as a starting off point and rewrote it to
be simpler, server-only, and be poll-based to support our Read/Write
wrappers.

I have replaced our tungstenite code with my framed-websockets fork.

<https://github.com/neondatabase/framed-websockets>
This commit is contained in:
Conrad Ludgate
2024-05-16 12:05:50 +01:00
committed by GitHub
parent 923cf91aa4
commit 790c05d675
6 changed files with 107 additions and 124 deletions

View File

@@ -102,7 +102,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
let server = Builder::new(TokioExecutor::new());
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
@@ -255,7 +255,6 @@ async fn connection_handler(
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let res = handler.await;
cancel_request.disarm();
@@ -301,7 +300,7 @@ async fn request_handler(
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
if framed_websockets::upgrade::is_upgrade_request(&request) {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -312,7 +311,7 @@ async fn request_handler(
let span = ctx.span.clone();
info!(parent: &span, "performing websocket upgrade");
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
@@ -334,7 +333,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response)
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,

View File

@@ -7,10 +7,11 @@ use crate::{
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::{
@@ -21,25 +22,23 @@ use std::{
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
// TODO: use `std::sync::Exclusive` once it's stabilized.
// Tracking issue: https://github.com/rust-lang/rust/issues/98407.
use sync_wrapper::SyncWrapper;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
pub struct WebSocketRw<S = Upgraded> {
pub struct WebSocketRw<S> {
#[pin]
stream: SyncWrapper<WebSocketStream<S>>,
bytes: Bytes,
stream: WebSocketServer<S>,
recv: Bytes,
send: BytesMut,
}
}
impl<S> WebSocketRw<S> {
pub fn new(stream: WebSocketStream<S>) -> Self {
pub fn new(stream: WebSocketServer<S>) -> Self {
Self {
stream: stream.into(),
bytes: Bytes::new(),
stream,
recv: Bytes::new(),
send: BytesMut::new(),
}
}
}
@@ -50,22 +49,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut stream = self.project().stream.get_pin_mut();
let this = self.project();
let mut stream = this.stream;
this.send.put(buf);
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
match stream.as_mut().start_send(Message::Binary(buf.into())) {
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(io_error(e))),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream.get_pin_mut();
let stream = self.project().stream;
stream.poll_flush(cx).map_err(io_error)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = self.project().stream.get_pin_mut();
let stream = self.project().stream;
stream.poll_close(cx).map_err(io_error)
}
}
@@ -76,13 +77,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() > 0 {
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.consume(len);
}
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.consume(len);
Poll::Ready(Ok(()))
}
}
@@ -94,31 +92,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
let mut this = self.project();
loop {
if !this.bytes.chunk().is_empty() {
let chunk = (*this.bytes).chunk();
if !this.recv.chunk().is_empty() {
let chunk = (*this.recv).chunk();
return Poll::Ready(Ok(chunk));
}
let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
let res = ready!(this.stream.as_mut().poll_next(cx));
match res.transpose().map_err(io_error)? {
Some(message) => match message {
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Text(text) => {
Some(message) => match message.opcode {
OpCode::Ping => {}
OpCode::Pong => {}
OpCode::Text => {
// We expect to see only binary messages.
let error = "unexpected text message in the websocket";
warn!(length = text.len(), error);
warn!(length = message.payload.len(), error);
return Poll::Ready(Err(io_error(error)));
}
Message::Frame(_) => {
// This case is impossible according to Frame's doc.
panic!("unexpected raw frame in the websocket");
OpCode::Binary | OpCode::Continuation => {
debug_assert!(this.recv.is_empty());
*this.recv = message.payload.freeze();
}
Message::Binary(chunk) => {
assert!(this.bytes.is_empty());
*this.bytes = Bytes::from(chunk);
}
Message::Close(_) => return EOF,
OpCode::Close => return EOF,
},
None => return EOF,
}
@@ -126,19 +120,21 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}
fn consume(self: Pin<&mut Self>, amount: usize) {
self.project().bytes.advance(amount);
self.project().recv.advance(amount);
}
}
pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
.proxy
.client_connections
@@ -177,15 +173,16 @@ pub async fn serve_websocket(
mod tests {
use std::pin::pin;
use framed_websockets::WebSocketServer;
use futures::{SinkExt, StreamExt};
use hyper_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt},
task::JoinSet,
};
use tokio_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use super::WebSocketRw;
@@ -210,9 +207,7 @@ mod tests {
});
js.spawn(async move {
let mut rw = pin!(WebSocketRw::new(
WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
));
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
let mut buf = vec![0; 1024];
let n = rw.read(&mut buf).await.unwrap();