remove hyper-tungstenite

This commit is contained in:
Conrad Ludgate
2024-01-11 18:01:11 +00:00
parent 07ccaa7575
commit 081e794878
5 changed files with 210 additions and 120 deletions

16
Cargo.lock generated
View File

@@ -2389,19 +2389,6 @@ dependencies = [
"tokio-native-tls",
]
[[package]]
name = "hyper-tungstenite"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
dependencies = [
"hyper",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
name = "iana-time-zone"
version = "0.1.56"
@@ -3893,7 +3880,6 @@ dependencies = [
"hostname",
"humantime",
"hyper",
"hyper-tungstenite",
"ipnet",
"itertools",
"md5",
@@ -3939,11 +3925,13 @@ dependencies = [
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tokio-tungstenite",
"tokio-util",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"tracing-utils",
"tungstenite",
"url",
"utils",
"uuid",

View File

@@ -89,7 +89,6 @@ http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
@@ -156,6 +155,7 @@ tokio-rustls = "0.24"
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
tokio-tungstenite = "0.20"
toml = "0.7"
toml_edit = "0.19"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
@@ -163,6 +163,7 @@ tracing = "0.1"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
tungstenite = "0.20"
url = "2.2"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"

View File

@@ -27,7 +27,6 @@ hex.workspace = true
hmac.workspace = true
hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
ipnet.workspace = true
itertools.workspace = true
@@ -66,11 +65,13 @@ tls-listener.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio-tungstenite.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
tungstenite.workspace = true
url.workspace = true
utils.workspace = true
uuid.workspace = true

View File

@@ -9,11 +9,7 @@ mod websocket;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::bail;
use hyper::ext::Protocol;
use hyper::StatusCode;
use hyper_tungstenite::tungstenite::error::{Error as WSError, ProtocolError};
use hyper_tungstenite::tungstenite::protocol::Role;
use hyper_tungstenite::WebSocketStream;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
@@ -37,7 +33,6 @@ use hyper::{
};
use std::net::IpAddr;
use std::pin::Pin;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
@@ -223,11 +218,13 @@ async fn request_handler(
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
let ws_config = None;
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
if websocket::is_upgrade_request(&request) {
info!(session_id = ?session_id, "performing websocket upgrade");
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
let (response, websocket) = websocket::upgrade(&mut request, ws_config)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
@@ -252,11 +249,11 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if is_upgrade2_request(&request) {
} else if websocket::is_connect_request(&request) {
info!(session_id = ?session_id, "performing http2 websocket upgrade");
let (response, websocket) =
upgrade_http2(&mut request).map_err(|e| ApiError::BadRequest(e.into()))?;
let (response, websocket) = websocket::connect(&mut request, ws_config)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
@@ -298,7 +295,7 @@ async fn request_handler(
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Neon-Endpoint, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
@@ -308,86 +305,3 @@ async fn request_handler(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
pin_project_lite::pin_project! {
/// A future that resolves to a websocket stream when the associated HTTP2 connect completes.
#[derive(Debug)]
pub struct HyperWebsocket2 {
#[pin]
inner: hyper::upgrade::OnUpgrade,
}
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns a HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This functions checks `Sec-WebSocket-Version` header.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade2_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
fn upgrade_http2<B>(
mut request: impl std::borrow::BorrowMut<Request<B>>,
) -> Result<(Response<Body>, HyperWebsocket2), ProtocolError> {
let request = request.borrow_mut();
if request
.headers()
.get("Sec-WebSocket-Version")
.map(|v| v.as_bytes())
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
let response = Response::builder()
.status(hyper::StatusCode::OK)
.body(Body::from("switching to websocket protocol"))
.expect("bug: failed to build response");
let stream = HyperWebsocket2 {
inner: hyper::upgrade::on(request),
};
Ok((response, stream))
}
/// Check if a request is a websocket connect request.
pub fn is_upgrade2_request<B>(request: &hyper::Request<B>) -> bool {
request.method() == Method::CONNECT
&& request
.extensions()
.get::<Protocol>()
.is_some_and(|protocol| protocol.as_str() == "websocket")
}
impl std::future::Future for HyperWebsocket2 {
type Output = Result<WebSocketStream<hyper::upgrade::Upgraded>, WSError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
let this = self.project();
let upgraded = match this.inner.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(x) => x,
};
let upgraded =
upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?;
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None);
tokio::pin!(stream);
// The future returned by `from_raw_socket` is always ready.
// Not sure why it is a future in the first place.
match stream.as_mut().poll(cx) {
Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
Poll::Ready(x) => Poll::Ready(Ok(x)),
}
}
}

View File

@@ -8,12 +8,17 @@ use crate::{
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use hyper::{ext::Protocol, upgrade::Upgraded, Body, Method, Request, Response};
use pin_project_lite::pin_project;
use tokio_tungstenite::WebSocketStream;
use tungstenite::{
error::{Error as WSError, ProtocolError},
handshake::derive_accept_key,
protocol::{Role, WebSocketConfig},
Message,
};
use std::{
future::IntoFuture,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
@@ -133,9 +138,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
websocket: impl IntoFuture<
Output = Result<WebSocketStream<Upgraded>, hyper_tungstenite::tungstenite::Error>,
>,
websocket: HyperWebsocket,
cancel_map: &CancelMap,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -153,19 +156,202 @@ pub async fn serve_websocket(
Ok(())
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns a HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
pub fn upgrade<B>(
mut request: impl std::borrow::BorrowMut<Request<B>>,
config: Option<WebSocketConfig>,
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
let request = request.borrow_mut();
let key = request
.headers()
.get("Sec-WebSocket-Key")
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request
.headers()
.get("Sec-WebSocket-Version")
.map(|v| v.as_bytes())
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
let response = Response::builder()
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
.header(hyper::header::CONNECTION, "upgrade")
.header(hyper::header::UPGRADE, "websocket")
.header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
.body(Body::from("switching to websocket protocol"))
.expect("bug: failed to build response");
let stream = HyperWebsocket {
inner: hyper::upgrade::on(request),
config,
};
Ok((response, stream))
}
/// Check if a request is a websocket upgrade request.
///
/// If the `Upgrade` header lists multiple protocols,
/// this function returns true if of them are `"websocket"`,
/// If the server supports multiple upgrade protocols,
/// it would be more appropriate to try each listed protocol in order.
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
&& header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
}
/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(
headers: &hyper::HeaderMap,
header: impl hyper::header::AsHeaderName,
value: impl AsRef<[u8]>,
) -> bool {
let value = value.as_ref();
for header in headers.get_all(header) {
if header
.as_bytes()
.split(|&c| c == b',')
.any(|x| trim(x).eq_ignore_ascii_case(value))
{
return true;
}
}
false
}
fn trim(data: &[u8]) -> &[u8] {
trim_end(trim_start(data))
}
fn trim_start(data: &[u8]) -> &[u8] {
if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) {
&data[start..]
} else {
b""
}
}
fn trim_end(data: &[u8]) -> &[u8] {
if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
&data[..last + 1]
} else {
b""
}
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns a HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This functions checks `Sec-WebSocket-Version` header.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade2_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
pub fn connect<B>(
mut request: impl std::borrow::BorrowMut<Request<B>>,
config: Option<WebSocketConfig>,
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
let request = request.borrow_mut();
if request
.headers()
.get("Sec-WebSocket-Version")
.map(|v| v.as_bytes())
!= Some(b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
let response = Response::builder()
.status(hyper::StatusCode::OK)
.body(Body::from("switching to websocket protocol"))
.expect("bug: failed to build response");
let stream = HyperWebsocket {
inner: hyper::upgrade::on(request),
config,
};
Ok((response, stream))
}
/// Check if a request is a websocket connect request.
pub fn is_connect_request<B>(request: &hyper::Request<B>) -> bool {
request.method() == Method::CONNECT
&& request
.extensions()
.get::<Protocol>()
.is_some_and(|protocol| protocol.as_str() == "websocket")
}
pin_project_lite::pin_project! {
/// A future that resolves to a websocket stream when the associated connection completes.
#[derive(Debug)]
pub struct HyperWebsocket {
#[pin]
inner: hyper::upgrade::OnUpgrade,
config: Option<WebSocketConfig>
}
}
impl std::future::Future for HyperWebsocket {
type Output = Result<WebSocketStream<hyper::upgrade::Upgraded>, WSError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
let this = self.project();
let upgraded = match this.inner.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(x) => x,
};
let upgraded =
upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?;
let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None);
tokio::pin!(stream);
// The future returned by `from_raw_socket` is always ready.
// Not sure why it is a future in the first place.
match stream.as_mut().poll(cx) {
Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
Poll::Ready(x) => Poll::Ready(Ok(x)),
}
}
}
#[cfg(test)]
mod tests {
use std::pin::pin;
use futures::{SinkExt, StreamExt};
use hyper_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt},
task::JoinSet,
};
use tokio_tungstenite::WebSocketStream;
use tungstenite::{protocol::Role, Message};
use super::WebSocketRw;