diff --git a/Cargo.lock b/Cargo.lock index 512145f6c8..7af5518fc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2389,19 +2389,6 @@ dependencies = [ "tokio-native-tls", ] -[[package]] -name = "hyper-tungstenite" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" -dependencies = [ - "hyper", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - [[package]] name = "iana-time-zone" version = "0.1.56" @@ -3893,7 +3880,6 @@ dependencies = [ "hostname", "humantime", "hyper", - "hyper-tungstenite", "ipnet", "itertools", "md5", @@ -3939,11 +3925,13 @@ dependencies = [ "tokio-postgres", "tokio-postgres-rustls", "tokio-rustls", + "tokio-tungstenite", "tokio-util", "tracing", "tracing-opentelemetry", "tracing-subscriber", "tracing-utils", + "tungstenite", "url", "utils", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 2d8fbaffa8..a910e8364a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,6 @@ http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" hyper = "0.14" -hyper-tungstenite = "0.11" inotify = "0.10.2" ipnet = "2.9.0" itertools = "0.10" @@ -156,6 +155,7 @@ tokio-rustls = "0.24" tokio-stream = "0.1" tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } +tokio-tungstenite = "0.20" toml = "0.7" toml_edit = "0.19" tonic = {version = "0.9", features = ["tls", "tls-roots"]} @@ -163,6 +163,7 @@ tracing = "0.1" tracing-error = "0.2.0" tracing-opentelemetry = "0.19.0" tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } +tungstenite = "0.20" url = "2.2" uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 23a9bb178d..a4557b2bd4 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -27,7 +27,6 @@ hex.workspace = true hmac.workspace = true hostname.workspace = true humantime.workspace = true -hyper-tungstenite.workspace = true hyper.workspace = true ipnet.workspace = true itertools.workspace = true @@ -66,11 +65,13 @@ tls-listener.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true tokio-util.workspace = true +tokio-tungstenite.workspace = true tokio = { workspace = true, features = ["signal"] } tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true tracing.workspace = true +tungstenite.workspace = true url.workspace = true utils.workspace = true uuid.workspace = true diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 03337cd867..78e93590af 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -9,11 +9,7 @@ mod websocket; pub use conn_pool::GlobalConnPoolOptions; use anyhow::bail; -use hyper::ext::Protocol; use hyper::StatusCode; -use hyper_tungstenite::tungstenite::error::{Error as WSError, ProtocolError}; -use hyper_tungstenite::tungstenite::protocol::Role; -use hyper_tungstenite::WebSocketStream; use metrics::IntCounterPairGuard; use rand::rngs::StdRng; use rand::SeedableRng; @@ -37,7 +33,6 @@ use hyper::{ }; use std::net::IpAddr; -use std::pin::Pin; use std::task::Poll; use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; @@ -223,11 +218,13 @@ async fn request_handler( .and_then(|h| h.split(':').next()) .map(|s| s.to_string()); + let ws_config = None; + // Check if the request is a websocket upgrade request. - if hyper_tungstenite::is_upgrade_request(&request) { + if websocket::is_upgrade_request(&request) { info!(session_id = ?session_id, "performing websocket upgrade"); - let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) + let (response, websocket) = websocket::upgrade(&mut request, ws_config) .map_err(|e| ApiError::BadRequest(e.into()))?; ws_connections.spawn( @@ -252,11 +249,11 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) - } else if is_upgrade2_request(&request) { + } else if websocket::is_connect_request(&request) { info!(session_id = ?session_id, "performing http2 websocket upgrade"); - let (response, websocket) = - upgrade_http2(&mut request).map_err(|e| ApiError::BadRequest(e.into()))?; + let (response, websocket) = websocket::connect(&mut request, ws_config) + .map_err(|e| ApiError::BadRequest(e.into()))?; ws_connections.spawn( async move { @@ -298,7 +295,7 @@ async fn request_handler( .header("Access-Control-Allow-Origin", "*") .header( "Access-Control-Allow-Headers", - "Neon-Endpoint, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", + "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code @@ -308,86 +305,3 @@ async fn request_handler( json_response(StatusCode::BAD_REQUEST, "query is not supported") } } - -pin_project_lite::pin_project! { - /// A future that resolves to a websocket stream when the associated HTTP2 connect completes. - #[derive(Debug)] - pub struct HyperWebsocket2 { - #[pin] - inner: hyper::upgrade::OnUpgrade, - } -} - -/// Try to upgrade a received `hyper::Request` to a websocket connection. -/// -/// The function returns a HTTP response and a future that resolves to the websocket stream. -/// The response body *MUST* be sent to the client before the future can be resolved. -/// -/// This functions checks `Sec-WebSocket-Version` header. -/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers. -/// You can inspect the headers manually before calling this function, -/// and modify the response headers appropriately. -/// -/// This function also does not look at the `Connection` or `Upgrade` headers. -/// To check if a request is a websocket upgrade request, you can use [`is_upgrade2_request`]. -/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually. -/// -fn upgrade_http2( - mut request: impl std::borrow::BorrowMut>, -) -> Result<(Response, HyperWebsocket2), ProtocolError> { - let request = request.borrow_mut(); - - if request - .headers() - .get("Sec-WebSocket-Version") - .map(|v| v.as_bytes()) - != Some(b"13") - { - return Err(ProtocolError::MissingSecWebSocketVersionHeader); - } - - let response = Response::builder() - .status(hyper::StatusCode::OK) - .body(Body::from("switching to websocket protocol")) - .expect("bug: failed to build response"); - - let stream = HyperWebsocket2 { - inner: hyper::upgrade::on(request), - }; - - Ok((response, stream)) -} - -/// Check if a request is a websocket connect request. -pub fn is_upgrade2_request(request: &hyper::Request) -> bool { - request.method() == Method::CONNECT - && request - .extensions() - .get::() - .is_some_and(|protocol| protocol.as_str() == "websocket") -} - -impl std::future::Future for HyperWebsocket2 { - type Output = Result, WSError>; - - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - let this = self.project(); - let upgraded = match this.inner.poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(x) => x, - }; - - let upgraded = - upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?; - - let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None); - tokio::pin!(stream); - - // The future returned by `from_raw_socket` is always ready. - // Not sure why it is a future in the first place. - match stream.as_mut().poll(cx) { - Poll::Pending => unreachable!("from_raw_socket should always be created ready"), - Poll::Ready(x) => Poll::Ready(Ok(x)), - } - } -} diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index bb8690bf8a..cfdaf93cd4 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -8,12 +8,17 @@ use crate::{ }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; -use hyper::upgrade::Upgraded; -use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; +use hyper::{ext::Protocol, upgrade::Upgraded, Body, Method, Request, Response}; use pin_project_lite::pin_project; +use tokio_tungstenite::WebSocketStream; +use tungstenite::{ + error::{Error as WSError, ProtocolError}, + handshake::derive_accept_key, + protocol::{Role, WebSocketConfig}, + Message, +}; use std::{ - future::IntoFuture, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -133,9 +138,7 @@ impl AsyncBufRead for WebSocketRw { pub async fn serve_websocket( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - websocket: impl IntoFuture< - Output = Result, hyper_tungstenite::tungstenite::Error>, - >, + websocket: HyperWebsocket, cancel_map: &CancelMap, hostname: Option, endpoint_rate_limiter: Arc, @@ -153,19 +156,202 @@ pub async fn serve_websocket( Ok(()) } +/// Try to upgrade a received `hyper::Request` to a websocket connection. +/// +/// The function returns a HTTP response and a future that resolves to the websocket stream. +/// The response body *MUST* be sent to the client before the future can be resolved. +/// +/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers. +/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers. +/// You can inspect the headers manually before calling this function, +/// and modify the response headers appropriately. +/// +/// This function also does not look at the `Connection` or `Upgrade` headers. +/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`]. +/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually. +/// +pub fn upgrade( + mut request: impl std::borrow::BorrowMut>, + config: Option, +) -> Result<(Response, HyperWebsocket), ProtocolError> { + let request = request.borrow_mut(); + + let key = request + .headers() + .get("Sec-WebSocket-Key") + .ok_or(ProtocolError::MissingSecWebSocketKey)?; + if request + .headers() + .get("Sec-WebSocket-Version") + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + let response = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, "upgrade") + .header(hyper::header::UPGRADE, "websocket") + .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes())) + .body(Body::from("switching to websocket protocol")) + .expect("bug: failed to build response"); + + let stream = HyperWebsocket { + inner: hyper::upgrade::on(request), + config, + }; + + Ok((response, stream)) +} + +/// Check if a request is a websocket upgrade request. +/// +/// If the `Upgrade` header lists multiple protocols, +/// this function returns true if of them are `"websocket"`, +/// If the server supports multiple upgrade protocols, +/// it would be more appropriate to try each listed protocol in order. +pub fn is_upgrade_request(request: &hyper::Request) -> bool { + header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade") + && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket") +} + +/// Check if there is a header of the given name containing the wanted value. +fn header_contains_value( + headers: &hyper::HeaderMap, + header: impl hyper::header::AsHeaderName, + value: impl AsRef<[u8]>, +) -> bool { + let value = value.as_ref(); + for header in headers.get_all(header) { + if header + .as_bytes() + .split(|&c| c == b',') + .any(|x| trim(x).eq_ignore_ascii_case(value)) + { + return true; + } + } + false +} + +fn trim(data: &[u8]) -> &[u8] { + trim_end(trim_start(data)) +} + +fn trim_start(data: &[u8]) -> &[u8] { + if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) { + &data[start..] + } else { + b"" + } +} + +fn trim_end(data: &[u8]) -> &[u8] { + if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) { + &data[..last + 1] + } else { + b"" + } +} + +/// Try to upgrade a received `hyper::Request` to a websocket connection. +/// +/// The function returns a HTTP response and a future that resolves to the websocket stream. +/// The response body *MUST* be sent to the client before the future can be resolved. +/// +/// This functions checks `Sec-WebSocket-Version` header. +/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers. +/// You can inspect the headers manually before calling this function, +/// and modify the response headers appropriately. +/// +/// This function also does not look at the `Connection` or `Upgrade` headers. +/// To check if a request is a websocket upgrade request, you can use [`is_upgrade2_request`]. +/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually. +/// +pub fn connect( + mut request: impl std::borrow::BorrowMut>, + config: Option, +) -> Result<(Response, HyperWebsocket), ProtocolError> { + let request = request.borrow_mut(); + + if request + .headers() + .get("Sec-WebSocket-Version") + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + let response = Response::builder() + .status(hyper::StatusCode::OK) + .body(Body::from("switching to websocket protocol")) + .expect("bug: failed to build response"); + + let stream = HyperWebsocket { + inner: hyper::upgrade::on(request), + config, + }; + + Ok((response, stream)) +} + +/// Check if a request is a websocket connect request. +pub fn is_connect_request(request: &hyper::Request) -> bool { + request.method() == Method::CONNECT + && request + .extensions() + .get::() + .is_some_and(|protocol| protocol.as_str() == "websocket") +} + +pin_project_lite::pin_project! { + /// A future that resolves to a websocket stream when the associated connection completes. + #[derive(Debug)] + pub struct HyperWebsocket { + #[pin] + inner: hyper::upgrade::OnUpgrade, + config: Option + } +} + +impl std::future::Future for HyperWebsocket { + type Output = Result, WSError>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { + let this = self.project(); + let upgraded = match this.inner.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => x, + }; + + let upgraded = + upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?; + + let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None); + tokio::pin!(stream); + + // The future returned by `from_raw_socket` is always ready. + // Not sure why it is a future in the first place. + match stream.as_mut().poll(cx) { + Poll::Pending => unreachable!("from_raw_socket should always be created ready"), + Poll::Ready(x) => Poll::Ready(Ok(x)), + } + } +} + #[cfg(test)] mod tests { use std::pin::pin; use futures::{SinkExt, StreamExt}; - use hyper_tungstenite::{ - tungstenite::{protocol::Role, Message}, - WebSocketStream, - }; use tokio::{ io::{duplex, AsyncReadExt, AsyncWriteExt}, task::JoinSet, }; + use tokio_tungstenite::WebSocketStream; + use tungstenite::{protocol::Role, Message}; use super::WebSocketRw;