From ec7c8783642a7e87ade87f8b68a8b8b84ca58f51 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 10 Mar 2024 09:03:00 +0000 Subject: [PATCH] remove unsafe --- proxy/src/serverless/http_auto.rs | 105 ++++++++++-------------------- 1 file changed, 36 insertions(+), 69 deletions(-) diff --git a/proxy/src/serverless/http_auto.rs b/proxy/src/serverless/http_auto.rs index d0c4d14d34..3b547e9a9d 100644 --- a/proxy/src/serverless/http_auto.rs +++ b/proxy/src/serverless/http_auto.rs @@ -3,23 +3,18 @@ use futures::ready; use hyper1::body::Body; -use hyper1::rt::ReadBufCursor; use hyper1::service::HttpService; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use std::future::Future; use std::marker::PhantomPinned; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{error::Error as StdError, io, marker::Unpin}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use ::http1::{Request, Response}; -use bytes::{Buf, Bytes}; -use hyper1::{ - body::Incoming, - rt::{Read, ReadBuf, Write}, - service::Service, -}; +use bytes::Bytes; +use hyper1::{body::Incoming, service::Service}; use hyper1::server::conn::http1; use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2}; @@ -68,18 +63,21 @@ impl Builder { S::Error: Into>, B: Body + 'static, B::Error: Into>, - I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, TokioExecutor: Http2ServerConnExec, { match version { Version::H1 => { - let conn = self.http1.serve_connection(io, service).with_upgrades(); + let conn = self + .http1 + .serve_connection(TokioIo::new(io), service) + .with_upgrades(); UpgradeableConnection { state: UpgradeableConnState::H1 { conn }, } } Version::H2 => { - let conn = self.http2.serve_connection(io, service); + let conn = self.http2.serve_connection(TokioIo::new(io), service); UpgradeableConnection { state: UpgradeableConnState::H2 { conn }, } @@ -96,11 +94,11 @@ pub(crate) enum Version { pub(crate) fn read_version(io: I) -> ReadVersion where - I: tokio::io::AsyncRead + Unpin, + I: AsyncRead + Unpin, { ReadVersion { - io: Some(TokioIo::new(io)), - buf: [MaybeUninit::uninit(); 24], + io: Some(io), + buf: [0; 24], filled: 0, version: Version::H2, _pin: PhantomPinned, @@ -109,8 +107,8 @@ where pin_project! { pub(crate) struct ReadVersion { - io: Option>, - buf: [MaybeUninit; 24], + io: Option, + buf: [u8; 24], // the amount of `buf` thats been filled filled: usize, version: Version, @@ -122,24 +120,20 @@ pin_project! { impl Future for ReadVersion where - I: tokio::io::AsyncRead + Unpin, + I: AsyncRead + Unpin, { type Output = io::Result<(Version, Rewind)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let mut buf = ReadBuf::uninit(&mut *this.buf); - // SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and - // we're only advancing by that many. - unsafe { - buf.unfilled().advance(*this.filled); - }; + let mut buf = ReadBuf::new(&mut *this.buf); + buf.set_filled(*this.filled); // We start as H2 and switch to H1 as soon as we don't have the preface. while buf.filled().len() < H2_PREFACE.len() { let len = buf.filled().len(); - ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?; + ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?; *this.filled = buf.filled().len(); // We starts as H2 and switch to H1 when we don't get the preface. @@ -171,8 +165,10 @@ pin_project! { } } -type Http1UpgradeableConnection = hyper1::server::conn::http1::UpgradeableConnection; -type Http2Connection = hyper1::server::conn::http2::Connection, S, TokioExecutor>; +type Http1UpgradeableConnection = + hyper1::server::conn::http1::UpgradeableConnection>, S>; +type Http2Connection = + hyper1::server::conn::http2::Connection>, S, TokioExecutor>; pin_project! { #[project = UpgradeableConnStateProj] @@ -182,7 +178,7 @@ pin_project! { { H1 { #[pin] - conn: Http1UpgradeableConnection, S>, + conn: Http1UpgradeableConnection, }, H2 { #[pin] @@ -195,7 +191,7 @@ impl UpgradeableConnection where S: HttpService, S::Error: Into>, - I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + I: AsyncRead + AsyncWrite + Unpin, B: Body + 'static, B::Error: Into>, TokioExecutor: Http2ServerConnExec, @@ -223,7 +219,7 @@ where S::Error: Into>, B: Body + 'static, B::Error: Into>, - I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, TokioExecutor: Http2ServerConnExec, { type Output = Result<()>; @@ -241,18 +237,18 @@ where #[derive(Debug)] pub(crate) struct Rewind { pre: Option, - inner: TokioIo, + inner: T, } impl Rewind { pub(crate) fn new(io: T) -> Self { Rewind { pre: None, - inner: TokioIo::new(io), + inner: io, } } - pub(crate) fn new_buffered(io: TokioIo, buf: Bytes) -> Self { + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { Rewind { pre: Some(buf), inner: io, @@ -260,22 +256,20 @@ impl Rewind { } } -impl Read for Rewind +impl AsyncRead for Rewind where - T: tokio::io::AsyncRead + Unpin, + T: AsyncRead + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - mut buf: ReadBufCursor<'_>, + buf: &mut ReadBuf<'_>, ) -> Poll> { - if let Some(mut prefix) = self.pre.take() { + if let Some(prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. if !prefix.is_empty() { - let copy_len = std::cmp::min(prefix.len(), remaining(&mut buf)); - // TODO: There should be a way to do following two lines cleaner... - put_slice(&mut buf, &prefix[..copy_len]); - prefix.advance(copy_len); + let copy_len = std::cmp::min(prefix.len(), buf.remaining()); + buf.put_slice(&prefix[..copy_len]); // Put back what's left if !prefix.is_empty() { self.pre = Some(prefix); @@ -288,36 +282,9 @@ where } } -fn remaining(cursor: &mut ReadBufCursor<'_>) -> usize { - // SAFETY: - // We do not uninitialize any set bytes. - unsafe { cursor.as_mut().len() } -} - -// Copied from `ReadBufCursor::put_slice`. -// If that becomes public, we could ditch this. -fn put_slice(cursor: &mut ReadBufCursor<'_>, slice: &[u8]) { - assert!( - remaining(cursor) >= slice.len(), - "buf.len() must fit in remaining()" - ); - - let amt = slice.len(); - - // SAFETY: - // the length is asserted above - unsafe { - cursor.as_mut()[..amt] - .as_mut_ptr() - .cast::() - .copy_from_nonoverlapping(slice.as_ptr(), amt); - cursor.advance(amt); - } -} - -impl Write for Rewind +impl AsyncWrite for Rewind where - T: tokio::io::AsyncWrite + Unpin, + T: AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>,