From a318213e722772f43e08cff1ce4cc8e703bccb39 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 11 Jan 2024 15:27:38 +0000 Subject: [PATCH] maybe works with ws over http2? --- proxy/src/auth/credentials.rs | 2 +- proxy/src/serverless.rs | 151 +++++++++++++++++++++++++----- proxy/src/serverless/websocket.rs | 7 +- 3 files changed, 136 insertions(+), 24 deletions(-) diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 546eb516d4..7ec003aad1 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -58,7 +58,7 @@ pub fn endpoint_sni<'a>( common_names: &HashSet, ) -> Result<(&'a str, &'a str), ComputeUserInfoParseError> { let Some((subdomain, common_name)) = sni.split_once('.') else { - return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() }); + return Ok((sni, "")); }; if !common_names.contains(common_name) { return Err(ComputeUserInfoParseError::UnknownCommonName { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 2f632ec8da..59281bbbb5 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -8,10 +8,12 @@ mod websocket; pub use conn_pool::GlobalConnPoolOptions; -use anyhow::{bail, Context}; +use anyhow::bail; use hyper::ext::Protocol; -use hyper::upgrade::OnUpgrade; use hyper::StatusCode; +use hyper_tungstenite::tungstenite::error::{Error as WSError, ProtocolError}; +use hyper_tungstenite::tungstenite::protocol::Role; +use hyper_tungstenite::WebSocketStream; use metrics::IntCounterPairGuard; use rand::rngs::StdRng; use rand::SeedableRng; @@ -35,6 +37,7 @@ use hyper::{ }; use std::net::IpAddr; +use std::pin::Pin; use std::task::Poll; use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; @@ -249,25 +252,43 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) - } else if request.method() == Method::CONNECT { - // request. - dbg!(request.headers()); - let _upgrade = request - .extensions_mut() - .remove::() - .context("missing upgrade") - .map_err(ApiError::InternalServerError)?; - let protocol = request - .extensions_mut() - .remove::() - .context("missing protocol") - .map_err(ApiError::InternalServerError)?; + } else if is_upgrade2_request(&request) { + info!(session_id = ?session_id, "performing http2 websocket upgrade"); - tracing::info!(protocol = protocol.as_str(), "http2 connect???"); + let (response, websocket) = + upgrade_http2(&mut request).map_err(|e| ApiError::BadRequest(e.into()))?; - Err(ApiError::InternalServerError(anyhow::anyhow!( - "not yet supported" - ))) + let host = request + .headers() + .get("neon-endpoint") + .map(|t| t.to_str()) + .transpose() + .map_err(|e| ApiError::BadRequest(e.into()))? + .map(|s| s.to_owned()) + .or(host); + + ws_connections.spawn( + async move { + let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region); + + if let Err(e) = websocket::serve_websocket( + config, + &mut ctx, + websocket, + &cancel_map, + host, + endpoint_rate_limiter, + ) + .await + { + error!(session_id = ?session_id, "error in http2 websocket connection: {e:#}"); + } + } + .in_current_span(), + ); + + // Return the response so the spawned future can continue. + Ok(response) } else if request.uri().path() == "/sql" && request.method() == Method::POST { let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); @@ -282,11 +303,11 @@ async fn request_handler( .await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { Response::builder() - .header("Allow", "OPTIONS, POST") + .header("Allow", "OPTIONS, POST, CONNECT") .header("Access-Control-Allow-Origin", "*") .header( "Access-Control-Allow-Headers", - "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", + "Neon-Endpoint, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code @@ -296,3 +317,91 @@ async fn request_handler( json_response(StatusCode::BAD_REQUEST, "query is not supported") } } + +pin_project_lite::pin_project! { + /// A future that resolves to a websocket stream when the associated HTTP2 connect completes. + #[derive(Debug)] + pub struct HyperWebsocket2 { + #[pin] + inner: hyper::upgrade::OnUpgrade, + } +} + +/// Try to upgrade a received `hyper::Request` to a websocket connection. +/// +/// The function returns a HTTP response and a future that resolves to the websocket stream. +/// The response body *MUST* be sent to the client before the future can be resolved. +/// +/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers. +/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers. +/// You can inspect the headers manually before calling this function, +/// and modify the response headers appropriately. +/// +/// This function also does not look at the `Connection` or `Upgrade` headers. +/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`]. +/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually. +/// +fn upgrade_http2( + mut request: impl std::borrow::BorrowMut>, +) -> Result<(Response, HyperWebsocket2), ProtocolError> { + let request = request.borrow_mut(); + + if request + .headers() + .get("Sec-WebSocket-Version") + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + let response = Response::builder() + .status(hyper::StatusCode::OK) + .body(Body::from("switching to websocket protocol")) + .expect("bug: failed to build response"); + + let stream = HyperWebsocket2 { + inner: hyper::upgrade::on(request), + }; + + Ok((response, stream)) +} + +/// Check if a request is a websocket upgrade request. +/// +/// If the `Upgrade` header lists multiple protocols, +/// this function returns true if of them are `"websocket"`, +/// If the server supports multiple upgrade protocols, +/// it would be more appropriate to try each listed protocol in order. +pub fn is_upgrade2_request(request: &hyper::Request) -> bool { + request.method() == Method::CONNECT + && request + .extensions() + .get::() + .is_some_and(|protocol| protocol.as_str() == "websocket") +} + +impl std::future::Future for HyperWebsocket2 { + type Output = Result, WSError>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { + let this = self.project(); + let upgraded = match this.inner.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => x, + }; + + let upgraded = + upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?; + + let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None); + tokio::pin!(stream); + + // The future returned by `from_raw_socket` is always ready. + // Not sure why it is a future in the first place. + match stream.as_mut().poll(cx) { + Poll::Pending => unreachable!("from_raw_socket should always be created ready"), + Poll::Ready(x) => Poll::Ready(Ok(x)), + } + } +} diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index a6529c920a..bb8690bf8a 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -9,10 +9,11 @@ use crate::{ use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; use hyper::upgrade::Upgraded; -use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; +use hyper_tungstenite::{tungstenite::Message, WebSocketStream}; use pin_project_lite::pin_project; use std::{ + future::IntoFuture, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -132,7 +133,9 @@ impl AsyncBufRead for WebSocketRw { pub async fn serve_websocket( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - websocket: HyperWebsocket, + websocket: impl IntoFuture< + Output = Result, hyper_tungstenite::tungstenite::Error>, + >, cancel_map: &CancelMap, hostname: Option, endpoint_rate_limiter: Arc,