mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
proxy: swap tungstenite for a simpler impl (#7353)
## Problem I wanted to do a deep dive of the tungstenite codebase. tokio-tungstenite is incredibly convoluted... In my searching I found [fastwebsockets by deno](https://github.com/denoland/fastwebsockets), but it wasn't quite sufficient. This also removes the default 16MB/64MB frame/message size limitation. framed-websockets solves this by inserting continuation frames for partially received messages, so the whole message does not need to be entirely read into memory. ## Summary of changes I took the fastwebsockets code as a starting off point and rewrote it to be simpler, server-only, and be poll-based to support our Read/Write wrappers. I have replaced our tungstenite code with my framed-websockets fork. <https://github.com/neondatabase/framed-websockets>
This commit is contained in:
@@ -102,7 +102,7 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
@@ -255,7 +255,6 @@ async fn connection_handler(
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
);
|
||||
|
||||
async move {
|
||||
let res = handler.await;
|
||||
cancel_request.disarm();
|
||||
@@ -301,7 +300,7 @@ async fn request_handler(
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
if framed_websockets::upgrade::is_upgrade_request(&request) {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
@@ -312,7 +311,7 @@ async fn request_handler(
|
||||
let span = ctx.span.clone();
|
||||
info!(parent: &span, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
@@ -334,7 +333,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
|
||||
@@ -7,10 +7,11 @@ use crate::{
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use framed_websockets::{Frame, OpCode, WebSocketServer};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use hyper1::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
@@ -21,25 +22,23 @@ use std::{
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::warn;
|
||||
|
||||
// TODO: use `std::sync::Exclusive` once it's stabilized.
|
||||
// Tracking issue: https://github.com/rust-lang/rust/issues/98407.
|
||||
use sync_wrapper::SyncWrapper;
|
||||
|
||||
pin_project! {
|
||||
/// This is a wrapper around a [`WebSocketStream`] that
|
||||
/// implements [`AsyncRead`] and [`AsyncWrite`].
|
||||
pub struct WebSocketRw<S = Upgraded> {
|
||||
pub struct WebSocketRw<S> {
|
||||
#[pin]
|
||||
stream: SyncWrapper<WebSocketStream<S>>,
|
||||
bytes: Bytes,
|
||||
stream: WebSocketServer<S>,
|
||||
recv: Bytes,
|
||||
send: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> WebSocketRw<S> {
|
||||
pub fn new(stream: WebSocketStream<S>) -> Self {
|
||||
pub fn new(stream: WebSocketServer<S>) -> Self {
|
||||
Self {
|
||||
stream: stream.into(),
|
||||
bytes: Bytes::new(),
|
||||
stream,
|
||||
recv: Bytes::new(),
|
||||
send: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -50,22 +49,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut stream = self.project().stream.get_pin_mut();
|
||||
let this = self.project();
|
||||
let mut stream = this.stream;
|
||||
this.send.put(buf);
|
||||
|
||||
ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?;
|
||||
match stream.as_mut().start_send(Message::Binary(buf.into())) {
|
||||
match stream.as_mut().start_send(Frame::binary(this.send.split())) {
|
||||
Ok(()) => Poll::Ready(Ok(buf.len())),
|
||||
Err(e) => Poll::Ready(Err(io_error(e))),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream.get_pin_mut();
|
||||
let stream = self.project().stream;
|
||||
stream.poll_flush(cx).map_err(io_error)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let stream = self.project().stream.get_pin_mut();
|
||||
let stream = self.project().stream;
|
||||
stream.poll_close(cx).map_err(io_error)
|
||||
}
|
||||
}
|
||||
@@ -76,13 +77,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for WebSocketRw<S> {
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if buf.remaining() > 0 {
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
buf.put_slice(&bytes[..len]);
|
||||
self.consume(len);
|
||||
}
|
||||
|
||||
let bytes = ready!(self.as_mut().poll_fill_buf(cx))?;
|
||||
let len = std::cmp::min(bytes.len(), buf.remaining());
|
||||
buf.put_slice(&bytes[..len]);
|
||||
self.consume(len);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
@@ -94,31 +92,27 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
if !this.bytes.chunk().is_empty() {
|
||||
let chunk = (*this.bytes).chunk();
|
||||
if !this.recv.chunk().is_empty() {
|
||||
let chunk = (*this.recv).chunk();
|
||||
return Poll::Ready(Ok(chunk));
|
||||
}
|
||||
|
||||
let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx));
|
||||
let res = ready!(this.stream.as_mut().poll_next(cx));
|
||||
match res.transpose().map_err(io_error)? {
|
||||
Some(message) => match message {
|
||||
Message::Ping(_) => {}
|
||||
Message::Pong(_) => {}
|
||||
Message::Text(text) => {
|
||||
Some(message) => match message.opcode {
|
||||
OpCode::Ping => {}
|
||||
OpCode::Pong => {}
|
||||
OpCode::Text => {
|
||||
// We expect to see only binary messages.
|
||||
let error = "unexpected text message in the websocket";
|
||||
warn!(length = text.len(), error);
|
||||
warn!(length = message.payload.len(), error);
|
||||
return Poll::Ready(Err(io_error(error)));
|
||||
}
|
||||
Message::Frame(_) => {
|
||||
// This case is impossible according to Frame's doc.
|
||||
panic!("unexpected raw frame in the websocket");
|
||||
OpCode::Binary | OpCode::Continuation => {
|
||||
debug_assert!(this.recv.is_empty());
|
||||
*this.recv = message.payload.freeze();
|
||||
}
|
||||
Message::Binary(chunk) => {
|
||||
assert!(this.bytes.is_empty());
|
||||
*this.bytes = Bytes::from(chunk);
|
||||
}
|
||||
Message::Close(_) => return EOF,
|
||||
OpCode::Close => return EOF,
|
||||
},
|
||||
None => return EOF,
|
||||
}
|
||||
@@ -126,19 +120,21 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amount: usize) {
|
||||
self.project().bytes.advance(amount);
|
||||
self.project().recv.advance(amount);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
@@ -177,15 +173,16 @@ pub async fn serve_websocket(
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
use framed_websockets::WebSocketServer;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use hyper_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio::{
|
||||
io::{duplex, AsyncReadExt, AsyncWriteExt},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::{protocol::Role, Message},
|
||||
WebSocketStream,
|
||||
};
|
||||
|
||||
use super::WebSocketRw;
|
||||
|
||||
@@ -210,9 +207,7 @@ mod tests {
|
||||
});
|
||||
|
||||
js.spawn(async move {
|
||||
let mut rw = pin!(WebSocketRw::new(
|
||||
WebSocketStream::from_raw_socket(stream2, Role::Server, None).await
|
||||
));
|
||||
let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2)));
|
||||
|
||||
let mut buf = vec![0; 1024];
|
||||
let n = rw.read(&mut buf).await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user