From 7afa5b3f35e0b0f923d3d25d25ee8786bac7a406 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 10 Mar 2024 07:50:57 +0000 Subject: [PATCH] vendor hyper_util::server::conn::auto --- proxy/src/serverless.rs | 3 +- proxy/src/serverless/http_auto.rs | 819 ++++++++++++++++++++++++++++++ 2 files changed, 821 insertions(+), 1 deletion(-) create mode 100644 proxy/src/serverless/http_auto.rs diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 4dbfc4b6a9..046f81d1e4 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -4,6 +4,7 @@ mod backend; mod conn_pool; +mod http_auto; mod json; mod sql_over_http; mod websocket; @@ -98,7 +99,7 @@ pub async fn task_main( let http_connections = tokio_util::task::task_tracker::TaskTracker::new(); http_connections.close(); - let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let server = http_auto::Builder::new(); loop { let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else { diff --git a/proxy/src/serverless/http_auto.rs b/proxy/src/serverless/http_auto.rs new file mode 100644 index 0000000000..831effe4ac --- /dev/null +++ b/proxy/src/serverless/http_auto.rs @@ -0,0 +1,819 @@ +//! [`hyper-util`] offers an 'auto' connection to detect whether the connection should be HTTP1 or HTTP2. +//! There's a bug in this implementation where graceful shutdowns are not properly respected. + +use futures::ready; +use hyper1::body::Body; +use hyper1::rt::ReadBufCursor; +use hyper1::service::HttpService; +use hyper_util::rt::TokioExecutor; +use std::future::Future; +use std::marker::PhantomPinned; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{error::Error as StdError, io, marker::Unpin, time::Duration}; + +use ::http1::{Request, Response}; +use bytes::{Buf, Bytes}; +use hyper1::{ + body::Incoming, + rt::{Read, ReadBuf, Timer, Write}, + service::Service, +}; + +use hyper1::server::conn::http1; +use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2}; + +use pin_project_lite::pin_project; + +type Error = Box; + +type Result = std::result::Result; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// Exactly equivalent to [`Http2ServerConnExec`]. +pub trait HttpServerConnExec: Http2ServerConnExec {} + +impl> HttpServerConnExec for T {} + +/// Http1 or Http2 connection builder. +#[derive(Clone, Debug)] +pub struct Builder { + http1: http1::Builder, + http2: http2::Builder, +} + +impl Builder { + /// Create a new auto connection builder. + /// + /// `executor` parameter should be a type that implements + /// [`Executor`](hyper::rt::Executor) trait. + /// + /// # Example + /// + /// ``` + /// use hyper_util::{ + /// rt::TokioExecutor, + /// server::conn::auto, + /// }; + /// + /// auto::Builder::new(TokioExecutor::new()); + /// ``` + pub fn new() -> Self { + Self { + http1: http1::Builder::new(), + http2: http2::Builder::new(TokioExecutor::new()), + } + } + + /// Http1 configuration. + pub fn http1(&mut self) -> Http1Builder<'_> { + Http1Builder { inner: self } + } + + /// Http2 configuration. + pub fn http2(&mut self) -> Http2Builder<'_> { + Http2Builder { inner: self } + } + + /// Bind a connection together with a [`Service`]. + pub fn serve_connection(&self, io: I, service: S) -> Connection<'_, I, S> + where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + 'static, + TokioExecutor: HttpServerConnExec, + { + Connection { + state: ConnState::ReadVersion { + read_version: read_version(io), + builder: self, + service: Some(service), + }, + } + } + + /// Bind a connection together with a [`Service`], with the ability to + /// handle HTTP upgrades. This requires that the IO object implements + /// `Send`. + pub fn serve_connection_with_upgrades( + &self, + io: I, + service: S, + ) -> UpgradeableConnection<'_, I, S> + where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + Send + 'static, + TokioExecutor: HttpServerConnExec, + { + UpgradeableConnection { + state: UpgradeableConnState::ReadVersion { + read_version: read_version(io), + builder: self, + service: Some(service), + }, + } + } +} + +#[derive(Copy, Clone)] +enum Version { + H1, + H2, +} + +fn read_version(io: I) -> ReadVersion +where + I: Read + Unpin, +{ + ReadVersion { + io: Some(io), + buf: [MaybeUninit::uninit(); 24], + filled: 0, + version: Version::H2, + _pin: PhantomPinned, + } +} + +pin_project! { + struct ReadVersion { + io: Option, + buf: [MaybeUninit; 24], + // the amount of `buf` thats been filled + filled: usize, + version: Version, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} + +impl Future for ReadVersion +where + I: Read + Unpin, +{ + type Output = io::Result<(Version, Rewind)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + let mut buf = ReadBuf::uninit(&mut *this.buf); + // SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and + // we're only advancing by that many. + unsafe { + buf.unfilled().advance(*this.filled); + }; + + // We start as H2 and switch to H1 as soon as we don't have the preface. + while buf.filled().len() < H2_PREFACE.len() { + let len = buf.filled().len(); + ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?; + *this.filled = buf.filled().len(); + + // We starts as H2 and switch to H1 when we don't get the preface. + if buf.filled().len() == len + || &buf.filled()[len..] != &H2_PREFACE[len..buf.filled().len()] + { + *this.version = Version::H1; + break; + } + } + + let io = this.io.take().unwrap(); + let buf = buf.filled().to_vec(); + Poll::Ready(Ok(( + *this.version, + Rewind::new_buffered(io, Bytes::from(buf)), + ))) + } +} + +pin_project! { + /// Connection future. + pub struct Connection<'a, I, S> + where + S: HttpService, + { + #[pin] + state: ConnState<'a, I, S>, + } +} + +type Http1Connection = hyper1::server::conn::http1::Connection, S>; +type Http2Connection = hyper1::server::conn::http2::Connection, S, TokioExecutor>; + +pin_project! { + #[project = ConnStateProj] + enum ConnState<'a, I, S> + where + S: HttpService, + { + ReadVersion { + #[pin] + read_version: ReadVersion, + builder: &'a Builder, + service: Option, + }, + H1 { + #[pin] + conn: Http1Connection, + }, + H2 { + #[pin] + conn: Http2Connection, + }, + } +} + +impl Connection<'_, I, S> +where + S: HttpService, + S::Error: Into>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into>, + TokioExecutor: HttpServerConnExec, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still pending. If called after + /// `Connection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + ConnStateProj::ReadVersion { .. } => {} + ConnStateProj::H1 { conn } => conn.graceful_shutdown(), + ConnStateProj::H2 { conn } => conn.graceful_shutdown(), + } + } +} + +impl Future for Connection<'_, I, S> +where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + 'static, + TokioExecutor: HttpServerConnExec, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + ConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + Version::H1 => { + let conn = builder.http1.serve_connection(io, service); + this.state.set(ConnState::H1 { conn }); + } + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(ConnState::H2 { conn }); + } + } + } + ConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + ConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + } + } + } +} + +pin_project! { + /// Connection future. + pub struct UpgradeableConnection<'a, I, S> + where + S: HttpService, + { + #[pin] + state: UpgradeableConnState<'a, I, S>, + } +} + +type Http1UpgradeableConnection = hyper1::server::conn::http1::UpgradeableConnection; + +pin_project! { + #[project = UpgradeableConnStateProj] + enum UpgradeableConnState<'a, I, S> + where + S: HttpService, + { + ReadVersion { + #[pin] + read_version: ReadVersion, + builder: &'a Builder, + service: Option, + }, + H1 { + #[pin] + conn: Http1UpgradeableConnection, S>, + }, + H2 { + #[pin] + conn: Http2Connection, + }, + } +} + +impl UpgradeableConnection<'_, I, S> +where + S: HttpService, + S::Error: Into>, + I: Read + Write + Unpin, + B: Body + 'static, + B::Error: Into>, + TokioExecutor: HttpServerConnExec, +{ + /// Start a graceful shutdown process for this connection. + /// + /// This `UpgradeableConnection` should continue to be polled until shutdown can finish. + /// + /// # Note + /// + /// This should only be called while the `Connection` future is still nothing. pending. If + /// called after `UpgradeableConnection::poll` has resolved, this does nothing. + pub fn graceful_shutdown(self: Pin<&mut Self>) { + match self.project().state.project() { + UpgradeableConnStateProj::ReadVersion { .. } => {} + UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(), + UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(), + } + } +} + +impl Future for UpgradeableConnection<'_, I, S> +where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + Send + 'static, + TokioExecutor: HttpServerConnExec, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + UpgradeableConnStateProj::ReadVersion { + read_version, + builder, + service, + } => { + let (version, io) = ready!(read_version.poll(cx))?; + let service = service.take().unwrap(); + match version { + Version::H1 => { + let conn = builder.http1.serve_connection(io, service).with_upgrades(); + this.state.set(UpgradeableConnState::H1 { conn }); + } + Version::H2 => { + let conn = builder.http2.serve_connection(io, service); + this.state.set(UpgradeableConnState::H2 { conn }); + } + } + } + UpgradeableConnStateProj::H1 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + UpgradeableConnStateProj::H2 { conn } => { + return conn.poll(cx).map_err(Into::into); + } + } + } + } +} + +/// Http1 part of builder. +pub struct Http1Builder<'a> { + inner: &'a mut Builder, +} + +impl Http1Builder<'_> { + /// Http2 configuration. + pub fn http2(&mut self) -> Http2Builder<'_> { + Http2Builder { inner: self.inner } + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + pub fn half_close(&mut self, val: bool) -> &mut Self { + self.inner.http1.half_close(val); + self + } + + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + self.inner.http1.keep_alive(val); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.title_case_headers(enabled); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Request`. It will also look for and use + /// such an extension in any provided `Response`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.preserve_header_case(enabled); + self + } + + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + pub fn header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self { + self.inner.http1.header_read_timeout(read_timeout); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + pub fn writev(&mut self, val: bool) -> &mut Self { + self.inner.http1.writev(val); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http1.max_buf_size(max); + self + } + + /// Aggregates flushes to better support pipelined responses. + /// + /// Experimental, may have bugs. + /// + /// Default is false. + pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.pipeline_flush(enabled); + self + } + + /// Set the timer used in background tasks. + pub fn timer(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http1.timer(timer); + self + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection(&self, io: I, service: S) -> Result<()> + where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + 'static, + TokioExecutor: HttpServerConnExec, + { + self.inner.serve_connection(io, service).await + } +} + +/// Http2 part of builder. +pub struct Http2Builder<'a> { + inner: &'a mut Builder, +} + +impl Http2Builder<'_> { + /// Http1 configuration. + pub fn http1(&mut self) -> Http1Builder<'_> { + Http1Builder { inner: self.inner } + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.initial_stream_window_size(sz); + self + } + + /// Sets the max connection-level flow control for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn initial_connection_window_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.initial_connection_window_size(sz); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.inner.http2.adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn max_frame_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.max_frame_size(sz); + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is 200. Passing `None` will remove any limit. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + pub fn max_concurrent_streams(&mut self, max: impl Into>) -> &mut Self { + self.inner.http2.max_concurrent_streams(max); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + pub fn keep_alive_interval(&mut self, interval: impl Into>) -> &mut Self { + self.inner.http2.keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.inner.http2.keep_alive_timeout(timeout); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http2.max_send_buf_size(max); + self + } + + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.inner.http2.enable_connect_protocol(); + self + } + + /// Sets the max size of received header frames. + /// + /// Default is currently ~16MB, but may change. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.inner.http2.max_header_list_size(max); + self + } + + /// Set the timer used in background tasks. + pub fn timer(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http2.timer(timer); + self + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection(&self, io: I, service: S) -> Result<()> + where + S: Service, Response = Response>, + S::Future: 'static, + S::Error: Into>, + B: Body + 'static, + B::Error: Into>, + I: Read + Write + Unpin + 'static, + TokioExecutor: HttpServerConnExec, + { + self.inner.serve_connection(io, service).await + } +} + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub(crate) struct Rewind { + pre: Option, + inner: T, +} + +impl Rewind { + #[cfg(test)] + pub(crate) fn new(io: T) -> Self { + Rewind { + pre: None, + inner: io, + } + } + + #[allow(dead_code)] + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } + + #[cfg(test)] + pub(crate) fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } + + // pub(crate) fn into_inner(self) -> (T, Bytes) { + // (self.inner, self.pre.unwrap_or_else(Bytes::new)) + // } + + // pub(crate) fn get_mut(&mut self) -> &mut T { + // &mut self.inner + // } +} + +impl Read for Rewind +where + T: Read + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: ReadBufCursor<'_>, + ) -> Poll> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = std::cmp::min(prefix.len(), remaining(&mut buf)); + // TODO: There should be a way to do following two lines cleaner... + put_slice(&mut buf, &prefix[..copy_len]); + prefix.advance(copy_len); + // Put back what's left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(())); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +fn remaining(cursor: &mut ReadBufCursor<'_>) -> usize { + // SAFETY: + // We do not uninitialize any set bytes. + unsafe { cursor.as_mut().len() } +} + +// Copied from `ReadBufCursor::put_slice`. +// If that becomes public, we could ditch this. +fn put_slice(cursor: &mut ReadBufCursor<'_>, slice: &[u8]) { + assert!( + remaining(cursor) >= slice.len(), + "buf.len() must fit in remaining()" + ); + + let amt = slice.len(); + + // SAFETY: + // the length is asserted above + unsafe { + cursor.as_mut()[..amt] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(slice.as_ptr(), amt); + cursor.advance(amt); + } +} + +impl Write for Rewind +where + T: Write + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +}