From d1bd8d377c93975e8ea2ff94882a5a33fe20f47b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 10 Mar 2024 08:36:29 +0000 Subject: [PATCH] remove readversion state --- proxy/Cargo.toml | 4 +- proxy/src/serverless.rs | 64 ++++---- proxy/src/serverless/http_auto.rs | 234 ++++-------------------------- 3 files changed, 72 insertions(+), 230 deletions(-) diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index d6c356bf26..b001c2947f 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -30,8 +30,8 @@ hostname.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true -hyper1 = { package = "hyper", version = "1.2", features = ["server"] } -hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } +hyper1 = { package = "hyper", version = "1.2", features = ["server", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["tokio"] } http1 = { package = "http", version = "1" } http-body-util = { version = "0.1" } ipnet.workspace = true diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 71386e7233..fb70f800bf 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -31,6 +31,7 @@ use crate::protocol2::WithClientIp; use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; +use crate::serverless::http_auto::Rewind; use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use std::convert::Infallible; @@ -163,35 +164,46 @@ pub async fn task_main( } }; - let service = hyper1::service::service_fn(move |req: hyper1::Request| { - let backend = backend.clone(); - let ws_connections = ws_connections.clone(); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellation_handler = cancellation_handler.clone(); - - async move { - Ok::<_, Infallible>( - request_handler( - req, - config, - backend, - ws_connections, - cancellation_handler, - peer_addr, - endpoint_rate_limiter, - ) - .await - .map_or_else(api_error_into_response, |r| r), - ) + let (version, conn) =match conn.get_ref().1.alpn_protocol() { + Some(b"http/1.1") => (http_auto::Version::H1, Rewind::new(hyper_util::rt::TokioIo::new(conn))), + Some(b"h2") => (http_auto::Version::H2, Rewind::new(hyper_util::rt::TokioIo::new(conn))), + _ => match http_auto::read_version(hyper_util::rt::TokioIo::new(conn)).await { + Ok(v) => v, + Err(e) => { + tracing::warn!("HTTP connection error {e}"); + return; + }, } - }); - - let conn = match conn.get_ref().1.alpn_protocol() { - Some(b"http/1.1") => server.serve_http1_connection_with_upgrades(hyper_util::rt::TokioIo::new(conn), service), - Some(b"h2") => server.serve_http2_connection(hyper_util::rt::TokioIo::new(conn), service), - _ => server.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(conn), service) }; + let conn = server.serve_connection_with_upgrades( + conn, + version, + hyper1::service::service_fn(move |req: hyper1::Request| { + let backend = backend.clone(); + let ws_connections = ws_connections.clone(); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); + + async move { + Ok::<_, Infallible>( + request_handler( + req, + config, + backend, + ws_connections, + cancellation_handler, + peer_addr, + endpoint_rate_limiter, + ) + .await + .map_or_else(api_error_into_response, |r| r), + ) + } + }) + ); + + let cancel = pin!(cancellation_token.cancelled()); let conn = pin!(conn); let res = match select(cancel, conn).await { diff --git a/proxy/src/serverless/http_auto.rs b/proxy/src/serverless/http_auto.rs index 0a5a5b8547..caf942f888 100644 --- a/proxy/src/serverless/http_auto.rs +++ b/proxy/src/serverless/http_auto.rs @@ -32,11 +32,6 @@ type Result = std::result::Result; const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; -/// Exactly equivalent to [`Http2ServerConnExec`]. -pub trait HttpServerConnExec: Http2ServerConnExec {} - -impl> HttpServerConnExec for T {} - /// Http1 or Http2 connection builder. #[derive(Clone, Debug)] pub struct Builder { @@ -46,20 +41,6 @@ pub struct Builder { impl Builder { /// Create a new auto connection builder. - /// - /// `executor` parameter should be a type that implements - /// [`Executor`](hyper::rt::Executor) trait. - /// - /// # Example - /// - /// ``` - /// use hyper_util::{ - /// rt::TokioExecutor, - /// server::conn::auto, - /// }; - /// - /// auto::Builder::new(TokioExecutor::new()); - /// ``` pub fn new() -> Self { let mut builder = Self { http1: http1::Builder::new(), @@ -77,9 +58,10 @@ impl Builder { /// `Send`. pub fn serve_connection_with_upgrades( &self, - io: I, + io: Rewind, + version: Version, service: S, - ) -> UpgradeableConnection<'_, I, S> + ) -> UpgradeableConnection where S: Service, Response = Response>, S::Future: 'static, @@ -87,70 +69,32 @@ impl Builder { B: Body + 'static, B::Error: Into>, I: Read + Write + Unpin + Send + 'static, - TokioExecutor: HttpServerConnExec, + TokioExecutor: Http2ServerConnExec, { - UpgradeableConnection { - state: UpgradeableConnState::ReadVersion { - read_version: read_version(io), - builder: self, - service: Some(service), - }, - } - } - - /// Bind a HTTP2 connection together with a [`Service`]. This requires that the IO object implements `Send`. - pub fn serve_http2_connection( - &self, - io: I, - service: S, - ) -> UpgradeableConnection<'_, I, S> - where - S: Service, Response = Response>, - S::Future: 'static, - S::Error: Into>, - B: Body + 'static, - B::Error: Into>, - I: Read + Write + Unpin + Send + 'static, - TokioExecutor: HttpServerConnExec, - { - let conn = self.http2.serve_connection(Rewind::new(io), service); - UpgradeableConnection { - state: UpgradeableConnState::H2 { conn }, - } - } - - /// Bind a HTTP2 connection together with a [`Service`]. This requires that the IO object implements `Send`. - pub fn serve_http1_connection_with_upgrades( - &self, - io: I, - service: S, - ) -> UpgradeableConnection<'_, I, S> - where - S: Service, Response = Response>, - S::Future: 'static, - S::Error: Into>, - B: Body + 'static, - B::Error: Into>, - I: Read + Write + Unpin + Send + 'static, - TokioExecutor: HttpServerConnExec, - { - let conn = self - .http1 - .serve_connection(Rewind::new(io), service) - .with_upgrades(); - UpgradeableConnection { - state: UpgradeableConnState::H1 { conn }, + match version { + Version::H1 => { + let conn = self.http1.serve_connection(io, service).with_upgrades(); + UpgradeableConnection { + state: UpgradeableConnState::H1 { conn }, + } + } + Version::H2 => { + let conn = self.http2.serve_connection(io, service); + UpgradeableConnection { + state: UpgradeableConnState::H2 { conn }, + } + } } } } #[derive(Copy, Clone)] -enum Version { +pub(crate) enum Version { H1, H2, } -fn read_version(io: I) -> ReadVersion +pub(crate) fn read_version(io: I) -> ReadVersion where I: Read + Unpin, { @@ -164,7 +108,7 @@ where } pin_project! { - struct ReadVersion { + pub(crate) struct ReadVersion { io: Option, buf: [MaybeUninit; 24], // the amount of `buf` thats been filled @@ -218,112 +162,24 @@ where pin_project! { /// Connection future. - pub struct Connection<'a, I, S> + pub struct UpgradeableConnection where S: HttpService, { #[pin] - state: ConnState<'a, I, S>, - } -} - -type Http1Connection = hyper1::server::conn::http1::Connection, S>; -type Http2Connection = hyper1::server::conn::http2::Connection, S, TokioExecutor>; - -pin_project! { - #[project = ConnStateProj] - enum ConnState<'a, I, S> - where - S: HttpService, - { - ReadVersion { - #[pin] - read_version: ReadVersion, - builder: &'a Builder, - service: Option, - }, - H1 { - #[pin] - conn: Http1Connection, - }, - H2 { - #[pin] - conn: Http2Connection, - }, - } -} - -impl Future for Connection<'_, I, S> -where - S: Service, Response = Response>, - S::Future: 'static, - S::Error: Into>, - B: Body + 'static, - B::Error: Into>, - I: Read + Write + Unpin + 'static, - TokioExecutor: HttpServerConnExec, -{ - type Output = Result<()>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let mut this = self.as_mut().project(); - - match this.state.as_mut().project() { - ConnStateProj::ReadVersion { - read_version, - builder, - service, - } => { - let (version, io) = ready!(read_version.poll(cx))?; - let service = service.take().unwrap(); - match version { - Version::H1 => { - let conn = builder.http1.serve_connection(io, service); - this.state.set(ConnState::H1 { conn }); - } - Version::H2 => { - let conn = builder.http2.serve_connection(io, service); - this.state.set(ConnState::H2 { conn }); - } - } - } - ConnStateProj::H1 { conn } => { - return conn.poll(cx).map_err(Into::into); - } - ConnStateProj::H2 { conn } => { - return conn.poll(cx).map_err(Into::into); - } - } - } - } -} - -pin_project! { - /// Connection future. - pub struct UpgradeableConnection<'a, I, S> - where - S: HttpService, - { - #[pin] - state: UpgradeableConnState<'a, I, S>, + state: UpgradeableConnState, } } type Http1UpgradeableConnection = hyper1::server::conn::http1::UpgradeableConnection; +type Http2Connection = hyper1::server::conn::http2::Connection, S, TokioExecutor>; pin_project! { #[project = UpgradeableConnStateProj] - enum UpgradeableConnState<'a, I, S> + enum UpgradeableConnState where S: HttpService, { - ReadVersion { - #[pin] - read_version: ReadVersion, - builder: &'a Builder, - service: Option, - }, H1 { #[pin] conn: Http1UpgradeableConnection, S>, @@ -335,14 +191,14 @@ pin_project! { } } -impl UpgradeableConnection<'_, I, S> +impl UpgradeableConnection where S: HttpService, S::Error: Into>, I: Read + Write + Unpin, B: Body + 'static, B::Error: Into>, - TokioExecutor: HttpServerConnExec, + TokioExecutor: Http2ServerConnExec, { /// Start a graceful shutdown process for this connection. /// @@ -354,14 +210,13 @@ where /// called after `UpgradeableConnection::poll` has resolved, this does nothing. pub fn graceful_shutdown(self: Pin<&mut Self>) { match self.project().state.project() { - UpgradeableConnStateProj::ReadVersion { .. } => {} UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(), UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(), } } } -impl Future for UpgradeableConnection<'_, I, S> +impl Future for UpgradeableConnection where S: Service, Response = Response>, S::Future: 'static, @@ -369,40 +224,15 @@ where B: Body + 'static, B::Error: Into>, I: Read + Write + Unpin + Send + 'static, - TokioExecutor: HttpServerConnExec, + TokioExecutor: Http2ServerConnExec, { type Output = Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let mut this = self.as_mut().project(); - - match this.state.as_mut().project() { - UpgradeableConnStateProj::ReadVersion { - read_version, - builder, - service, - } => { - let (version, io) = ready!(read_version.poll(cx))?; - let service = service.take().unwrap(); - match version { - Version::H1 => { - let conn = builder.http1.serve_connection(io, service).with_upgrades(); - this.state.set(UpgradeableConnState::H1 { conn }); - } - Version::H2 => { - let conn = builder.http2.serve_connection(io, service); - this.state.set(UpgradeableConnState::H2 { conn }); - } - } - } - UpgradeableConnStateProj::H1 { conn } => { - return conn.poll(cx).map_err(Into::into); - } - UpgradeableConnStateProj::H2 { conn } => { - return conn.poll(cx).map_err(Into::into); - } - } + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + UpgradeableConnStateProj::H1 { conn } => conn.poll(cx).map_err(Into::into), + UpgradeableConnStateProj::H2 { conn } => conn.poll(cx).map_err(Into::into), } } }