maybe works with ws over http2?

This commit is contained in:
Conrad Ludgate
2024-01-11 15:27:38 +00:00
parent 520171f17a
commit a318213e72
3 changed files with 136 additions and 24 deletions

View File

@@ -58,7 +58,7 @@ pub fn endpoint_sni<'a>(
common_names: &HashSet<String>,
) -> Result<(&'a str, &'a str), ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
return Ok((sni, ""));
};
if !common_names.contains(common_name) {
return Err(ComputeUserInfoParseError::UnknownCommonName {

View File

@@ -8,10 +8,12 @@ mod websocket;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::{bail, Context};
use anyhow::bail;
use hyper::ext::Protocol;
use hyper::upgrade::OnUpgrade;
use hyper::StatusCode;
use hyper_tungstenite::tungstenite::error::{Error as WSError, ProtocolError};
use hyper_tungstenite::tungstenite::protocol::Role;
use hyper_tungstenite::WebSocketStream;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
@@ -35,6 +37,7 @@ use hyper::{
};
use std::net::IpAddr;
use std::pin::Pin;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
@@ -249,25 +252,43 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.method() == Method::CONNECT {
// request.
dbg!(request.headers());
let _upgrade = request
.extensions_mut()
.remove::<OnUpgrade>()
.context("missing upgrade")
.map_err(ApiError::InternalServerError)?;
let protocol = request
.extensions_mut()
.remove::<Protocol>()
.context("missing protocol")
.map_err(ApiError::InternalServerError)?;
} else if is_upgrade2_request(&request) {
info!(session_id = ?session_id, "performing http2 websocket upgrade");
tracing::info!(protocol = protocol.as_str(), "http2 connect???");
let (response, websocket) =
upgrade_http2(&mut request).map_err(|e| ApiError::BadRequest(e.into()))?;
Err(ApiError::InternalServerError(anyhow::anyhow!(
"not yet supported"
)))
let host = request
.headers()
.get("neon-endpoint")
.map(|t| t.to_str())
.transpose()
.map_err(|e| ApiError::BadRequest(e.into()))?
.map(|s| s.to_owned())
.or(host);
ws_connections.spawn(
async move {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
if let Err(e) = websocket::serve_websocket(
config,
&mut ctx,
websocket,
&cancel_map,
host,
endpoint_rate_limiter,
)
.await
{
error!(session_id = ?session_id, "error in http2 websocket connection: {e:#}");
}
}
.in_current_span(),
);
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
@@ -282,11 +303,11 @@ async fn request_handler(
.await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Allow", "OPTIONS, POST, CONNECT")
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
"Neon-Endpoint, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
@@ -296,3 +317,91 @@ async fn request_handler(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
pin_project_lite::pin_project! {
/// A future that resolves to a websocket stream when the associated HTTP2 connect completes.
#[derive(Debug)]
pub struct HyperWebsocket2 {
#[pin]
inner: hyper::upgrade::OnUpgrade,
}
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns a HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
fn upgrade_http2<B>(
mut request: impl std::borrow::BorrowMut<Request<B>>,
) -> Result<(Response<Body>, HyperWebsocket2), ProtocolError> {
let request = request.borrow_mut();
if request
.headers()
.get("Sec-WebSocket-Version")
.map(|v| v.as_bytes())
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
let response = Response::builder()
.status(hyper::StatusCode::OK)
.body(Body::from("switching to websocket protocol"))
.expect("bug: failed to build response");
let stream = HyperWebsocket2 {
inner: hyper::upgrade::on(request),
};
Ok((response, stream))
}
/// Check if a request is a websocket upgrade request.
///
/// If the `Upgrade` header lists multiple protocols,
/// this function returns true if of them are `"websocket"`,
/// If the server supports multiple upgrade protocols,
/// it would be more appropriate to try each listed protocol in order.
pub fn is_upgrade2_request<B>(request: &hyper::Request<B>) -> bool {
request.method() == Method::CONNECT
&& request
.extensions()
.get::<Protocol>()
.is_some_and(|protocol| protocol.as_str() == "websocket")
}
impl std::future::Future for HyperWebsocket2 {
type Output = Result<WebSocketStream<hyper::upgrade::Upgraded>, WSError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
let this = self.project();
let upgraded = match this.inner.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(x) => x,
};
let upgraded =
upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?;
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None);
tokio::pin!(stream);
// The future returned by `from_raw_socket` is always ready.
// Not sure why it is a future in the first place.
match stream.as_mut().poll(cx) {
Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
Poll::Ready(x) => Poll::Ready(Ok(x)),
}
}
}

View File

@@ -9,10 +9,11 @@ use crate::{
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use pin_project_lite::pin_project;
use std::{
future::IntoFuture,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
@@ -132,7 +133,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
websocket: HyperWebsocket,
websocket: impl IntoFuture<
Output = Result<WebSocketStream<Upgraded>, hyper_tungstenite::tungstenite::Error>,
>,
cancel_map: &CancelMap,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,