diff --git a/Cargo.lock b/Cargo.lock index b1f53404ea..e1edd53fea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -708,7 +708,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite 0.20.0", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -979,6 +979,12 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +[[package]] +name = "bytemuck" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" + [[package]] name = "byteorder" version = "1.4.3" @@ -1598,7 +1604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d" dependencies = [ "cfg-if", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core 0.9.8", @@ -1999,6 +2005,27 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "framed-websockets" +version = "0.1.0" +source = "git+https://github.com/neondatabase/framed-websockets#34eff3d6f8cfccbc5f35e4f65314ff7328621127" +dependencies = [ + "base64 0.21.1", + "bytemuck", + "bytes", + "futures-core", + "futures-sink", + "http-body-util", + "hyper 1.2.0", + "hyper-util", + "pin-project", + "rand 0.8.5", + "sha1", + "thiserror", + "tokio", + "tokio-util", +] + [[package]] name = "fs2" version = "0.4.3" @@ -2277,9 +2304,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", "allocator-api2", @@ -2287,11 +2314,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" dependencies = [ - "hashbrown 0.14.0", + "hashbrown 0.14.5", ] [[package]] @@ -2600,21 +2627,6 @@ dependencies = [ "tokio-native-tls", ] -[[package]] -name = "hyper-tungstenite" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" -dependencies = [ - "http-body-util", - "hyper 1.2.0", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite 0.21.0", - "tungstenite 0.21.0", -] - [[package]] name = "hyper-util" version = "0.1.3" @@ -2692,7 +2704,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.5", ] [[package]] @@ -2954,7 +2966,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" dependencies = [ - "hashbrown 0.14.0", + "hashbrown 0.14.5", ] [[package]] @@ -3007,7 +3019,7 @@ checksum = "652bc741286361c06de8cb4d89b21a6437f120c508c51713663589eeb9928ac5" dependencies = [ "bytes", "crossbeam-utils", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "itoa", "lasso", "measured-derive", @@ -3569,7 +3581,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" dependencies = [ "dlv-list", - "hashbrown 0.14.0", + "hashbrown 0.14.5", ] [[package]] @@ -3896,7 +3908,7 @@ dependencies = [ "ahash", "bytes", "chrono", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", "num-bigint", "paste", @@ -4380,9 +4392,10 @@ dependencies = [ "dashmap", "env_logger", "fallible-iterator", + "framed-websockets", "futures", "git-version", - "hashbrown 0.13.2", + "hashbrown 0.14.5", "hashlink", "hex", "hmac", @@ -4392,7 +4405,6 @@ dependencies = [ "humantime", "hyper 0.14.26", "hyper 1.2.0", - "hyper-tungstenite", "hyper-util", "indexmap 2.0.1", "ipnet", @@ -4437,7 +4449,6 @@ dependencies = [ "smol_str", "socket2 0.5.5", "subtle", - "sync_wrapper", "task-local-extensions", "thiserror", "tikv-jemalloc-ctl", @@ -4446,6 +4457,7 @@ dependencies = [ "tokio-postgres", "tokio-postgres-rustls", "tokio-rustls 0.25.0", + "tokio-tungstenite", "tokio-util", "tower-service", "tracing", @@ -6382,19 +6394,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite 0.20.1", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.21.0", + "tungstenite", ] [[package]] @@ -6408,7 +6408,7 @@ dependencies = [ "futures-io", "futures-sink", "futures-util", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "pin-project-lite", "tokio", "tracing", @@ -6690,25 +6690,6 @@ dependencies = [ "utf-8", ] -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 1.1.0", - "httparse", - "log", - "rand 0.8.5", - "sha1", - "thiserror", - "url", - "utf-8", -] - [[package]] name = "twox-hash" version = "1.6.3" @@ -7504,7 +7485,7 @@ dependencies = [ "futures-sink", "futures-util", "getrandom 0.2.11", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "hex", "hmac", "hyper 0.14.26", diff --git a/Cargo.toml b/Cargo.toml index 3ccdabee18..b59a5dcd6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,13 +81,14 @@ enum-map = "2.4.2" enumset = "1.0.12" fail = "0.5.0" fallible-iterator = "0.2" +framed-websockets = { version = "0.1.0", git = "https://github.com/neondatabase/framed-websockets" } fs2 = "0.4.3" futures = "0.3" futures-core = "0.3" futures-util = "0.3" git-version = "0.3" -hashbrown = "0.13" -hashlink = "0.8.4" +hashbrown = "0.14" +hashlink = "0.9.1" hdrhistogram = "7.5.2" hex = "0.4" hex-literal = "0.4" @@ -98,7 +99,7 @@ http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" hyper = "0.14" -hyper-tungstenite = "0.13.0" +tokio-tungstenite = "0.20.0" indexmap = "2" inotify = "0.10.2" ipnet = "2.9.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3002006aed..5f9b0aa75b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -26,6 +26,7 @@ clap.workspace = true consumption_metrics.workspace = true dashmap.workspace = true env_logger.workspace = true +framed-websockets.workspace = true futures.workspace = true git-version.workspace = true hashbrown.workspace = true @@ -35,7 +36,6 @@ hmac.workspace = true hostname.workspace = true http.workspace = true humantime.workspace = true -hyper-tungstenite.workspace = true hyper.workspace = true hyper1 = { package = "hyper", version = "1.2", features = ["server"] } hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } @@ -76,7 +76,6 @@ smol_str.workspace = true smallvec.workspace = true socket2.workspace = true subtle.workspace = true -sync_wrapper.workspace = true task-local-extensions.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true @@ -106,6 +105,7 @@ workspace_hack.workspace = true [dev-dependencies] camino-tempfile.workspace = true fallible-iterator.workspace = true +tokio-tungstenite.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index f634ab4e98..24ee749e6e 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -102,7 +102,7 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` - let server = Builder::new(hyper_util::rt::TokioExecutor::new()); + let server = Builder::new(TokioExecutor::new()); while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; @@ -255,7 +255,6 @@ async fn connection_handler( .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), ); - async move { let res = handler.await; cancel_request.disarm(); @@ -301,7 +300,7 @@ async fn request_handler( .map(|s| s.to_string()); // Check if the request is a websocket upgrade request. - if hyper_tungstenite::is_upgrade_request(&request) { + if framed_websockets::upgrade::is_upgrade_request(&request) { let ctx = RequestMonitoring::new( session_id, peer_addr, @@ -312,7 +311,7 @@ async fn request_handler( let span = ctx.span.clone(); info!(parent: &span, "performing websocket upgrade"); - let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) + let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; ws_connections.spawn( @@ -334,7 +333,7 @@ async fn request_handler( ); // Return the response so the spawned future can continue. - Ok(response) + Ok(response.map(|_: http_body_util::Empty| Full::new(Bytes::new()))) } else if request.uri().path() == "/sql" && *request.method() == Method::POST { let ctx = RequestMonitoring::new( session_id, diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 649bec2c7c..61d6d60dbe 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -7,10 +7,11 @@ use crate::{ proxy::{handle_client, ClientMode}, rate_limiter::EndpointRateLimiter, }; -use bytes::{Buf, Bytes}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use framed_websockets::{Frame, OpCode, WebSocketServer}; use futures::{Sink, Stream}; -use hyper::upgrade::Upgraded; -use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; +use hyper1::upgrade::OnUpgrade; +use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use std::{ @@ -21,25 +22,23 @@ use std::{ use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use tracing::warn; -// TODO: use `std::sync::Exclusive` once it's stabilized. -// Tracking issue: https://github.com/rust-lang/rust/issues/98407. -use sync_wrapper::SyncWrapper; - pin_project! { /// This is a wrapper around a [`WebSocketStream`] that /// implements [`AsyncRead`] and [`AsyncWrite`]. - pub struct WebSocketRw { + pub struct WebSocketRw { #[pin] - stream: SyncWrapper>, - bytes: Bytes, + stream: WebSocketServer, + recv: Bytes, + send: BytesMut, } } impl WebSocketRw { - pub fn new(stream: WebSocketStream) -> Self { + pub fn new(stream: WebSocketServer) -> Self { Self { - stream: stream.into(), - bytes: Bytes::new(), + stream, + recv: Bytes::new(), + send: BytesMut::new(), } } } @@ -50,22 +49,24 @@ impl AsyncWrite for WebSocketRw { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let mut stream = self.project().stream.get_pin_mut(); + let this = self.project(); + let mut stream = this.stream; + this.send.put(buf); ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?; - match stream.as_mut().start_send(Message::Binary(buf.into())) { + match stream.as_mut().start_send(Frame::binary(this.send.split())) { Ok(()) => Poll::Ready(Ok(buf.len())), Err(e) => Poll::Ready(Err(io_error(e))), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let stream = self.project().stream.get_pin_mut(); + let stream = self.project().stream; stream.poll_flush(cx).map_err(io_error) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let stream = self.project().stream.get_pin_mut(); + let stream = self.project().stream; stream.poll_close(cx).map_err(io_error) } } @@ -76,13 +77,10 @@ impl AsyncRead for WebSocketRw { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - if buf.remaining() > 0 { - let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; - let len = std::cmp::min(bytes.len(), buf.remaining()); - buf.put_slice(&bytes[..len]); - self.consume(len); - } - + let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; + let len = std::cmp::min(bytes.len(), buf.remaining()); + buf.put_slice(&bytes[..len]); + self.consume(len); Poll::Ready(Ok(())) } } @@ -94,31 +92,27 @@ impl AsyncBufRead for WebSocketRw { let mut this = self.project(); loop { - if !this.bytes.chunk().is_empty() { - let chunk = (*this.bytes).chunk(); + if !this.recv.chunk().is_empty() { + let chunk = (*this.recv).chunk(); return Poll::Ready(Ok(chunk)); } - let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx)); + let res = ready!(this.stream.as_mut().poll_next(cx)); match res.transpose().map_err(io_error)? { - Some(message) => match message { - Message::Ping(_) => {} - Message::Pong(_) => {} - Message::Text(text) => { + Some(message) => match message.opcode { + OpCode::Ping => {} + OpCode::Pong => {} + OpCode::Text => { // We expect to see only binary messages. let error = "unexpected text message in the websocket"; - warn!(length = text.len(), error); + warn!(length = message.payload.len(), error); return Poll::Ready(Err(io_error(error))); } - Message::Frame(_) => { - // This case is impossible according to Frame's doc. - panic!("unexpected raw frame in the websocket"); + OpCode::Binary | OpCode::Continuation => { + debug_assert!(this.recv.is_empty()); + *this.recv = message.payload.freeze(); } - Message::Binary(chunk) => { - assert!(this.bytes.is_empty()); - *this.bytes = Bytes::from(chunk); - } - Message::Close(_) => return EOF, + OpCode::Close => return EOF, }, None => return EOF, } @@ -126,19 +120,21 @@ impl AsyncBufRead for WebSocketRw { } fn consume(self: Pin<&mut Self>, amount: usize) { - self.project().bytes.advance(amount); + self.project().recv.advance(amount); } } pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, - websocket: HyperWebsocket, + websocket: OnUpgrade, cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, ) -> anyhow::Result<()> { let websocket = websocket.await?; + let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); + let conn_gauge = Metrics::get() .proxy .client_connections @@ -177,15 +173,16 @@ pub async fn serve_websocket( mod tests { use std::pin::pin; + use framed_websockets::WebSocketServer; use futures::{SinkExt, StreamExt}; - use hyper_tungstenite::{ - tungstenite::{protocol::Role, Message}, - WebSocketStream, - }; use tokio::{ io::{duplex, AsyncReadExt, AsyncWriteExt}, task::JoinSet, }; + use tokio_tungstenite::{ + tungstenite::{protocol::Role, Message}, + WebSocketStream, + }; use super::WebSocketRw; @@ -210,9 +207,7 @@ mod tests { }); js.spawn(async move { - let mut rw = pin!(WebSocketRw::new( - WebSocketStream::from_raw_socket(stream2, Role::Server, None).await - )); + let mut rw = pin!(WebSocketRw::new(WebSocketServer::after_handshake(stream2))); let mut buf = vec![0; 1024]; let n = rw.read(&mut buf).await.unwrap(); diff --git a/test_runner/regress/test_proxy_websockets.py b/test_runner/regress/test_proxy_websockets.py index 6d1cb9765a..6211446a40 100644 --- a/test_runner/regress/test_proxy_websockets.py +++ b/test_runner/regress/test_proxy_websockets.py @@ -135,7 +135,14 @@ async def test_websockets_pipelined(static_proxy: NeonProxy): query_message = "SELECT 1".encode("utf-8") + b"\0" length2 = (4 + len(query_message)).to_bytes(4, byteorder="big") await websocket.send( - [length0, startup_message, b"p", length1, auth_message, b"Q", length2, query_message] + length0 + + startup_message + + b"p" + + length1 + + auth_message + + b"Q" + + length2 + + query_message ) startup_response = await websocket.recv()