Compare commits

...

9 Commits

Author          SHA1        Message                                                         Date
Conrad Ludgate  3c0eb1bf71  add timeout for read_version                                   2024-03-10 09:32:26 +00:00
Conrad Ludgate  ec7c878364  remove unsafe                                                  2024-03-10 09:03:00 +00:00
Conrad Ludgate  5d799f0a25  remove io fluff                                                2024-03-10 08:53:37 +00:00
Conrad Ludgate  d1bd8d377c  remove readversion state                                       2024-03-10 08:36:29 +00:00
Conrad Ludgate  71fda96c21  remove dead code; add support for pre-determined http version  2024-03-10 08:13:33 +00:00
Conrad Ludgate  7afa5b3f35  vendor hyper_util::server::conn::auto                          2024-03-10 07:50:57 +00:00
Conrad Ludgate  2fc4e3df84  update logging                                                 2024-03-09 12:18:22 +00:00
Conrad Ludgate  d91ff747bb  remove tls listener file                                       2024-03-09 12:17:17 +00:00
Conrad Ludgate  375dfd661c  proxy: hyper1 for only proxy                                   2024-03-09 12:17:09 +00:00
12 changed files with 740 additions and 566 deletions

206
Cargo.lock generated
View File

@@ -285,7 +285,7 @@ dependencies = [
"futures",
"git-version",
"humantime",
"hyper",
"hyper 0.14.26",
"metrics",
"once_cell",
"pageserver_api",
@@ -331,7 +331,7 @@ dependencies = [
"fastrand 2.0.0",
"hex",
"http 0.2.9",
"hyper",
"hyper 0.14.26",
"ring 0.17.6",
"time",
"tokio",
@@ -368,7 +368,7 @@ dependencies = [
"bytes",
"fastrand 2.0.0",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"percent-encoding",
"pin-project-lite",
"tracing",
@@ -396,7 +396,7 @@ dependencies = [
"aws-types",
"bytes",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"once_cell",
"percent-encoding",
"regex-lite",
@@ -547,7 +547,7 @@ dependencies = [
"crc32fast",
"hex",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"md-5",
"pin-project-lite",
"sha1",
@@ -579,7 +579,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"once_cell",
"percent-encoding",
"pin-project-lite",
@@ -618,10 +618,10 @@ dependencies = [
"aws-smithy-types",
"bytes",
"fastrand 2.0.0",
"h2",
"h2 0.3.24",
"http 0.2.9",
"http-body",
"hyper",
"http-body 0.4.5",
"hyper 0.14.26",
"hyper-rustls",
"once_cell",
"pin-project-lite",
@@ -658,7 +658,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"itoa",
"num-integer",
"pin-project-lite",
@@ -707,8 +707,8 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body",
"hyper",
"http-body 0.4.5",
"hyper 0.14.26",
"itoa",
"matchit",
"memchr",
@@ -723,7 +723,7 @@ dependencies = [
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",
"tokio-tungstenite 0.20.0",
"tower",
"tower-layer",
"tower-service",
@@ -739,7 +739,7 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"mime",
"rustversion",
"tower-layer",
@@ -1228,7 +1228,7 @@ dependencies = [
"compute_api",
"flate2",
"futures",
"hyper",
"hyper 0.14.26",
"nix 0.27.1",
"notify",
"num_cpus",
@@ -1344,7 +1344,7 @@ dependencies = [
"futures",
"git-version",
"hex",
"hyper",
"hyper 0.14.26",
"nix 0.27.1",
"once_cell",
"pageserver_api",
@@ -2244,6 +2244,25 @@ dependencies = [
"tracing",
]
[[package]]
name = "h2"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 1.0.0",
"indexmap 2.0.1",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "1.8.2"
@@ -2409,6 +2428,29 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.0.0",
]
[[package]]
name = "http-body-util"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"pin-project-lite",
]
[[package]]
name = "http-types"
version = "2.12.0"
@@ -2467,9 +2509,9 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2",
"h2 0.3.24",
"http 0.2.9",
"http-body",
"http-body 0.4.5",
"httparse",
"httpdate",
"itoa",
@@ -2481,6 +2523,26 @@ dependencies = [
"want",
]
[[package]]
name = "hyper"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.2",
"http 1.0.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-rustls"
version = "0.24.0"
@@ -2488,7 +2550,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
dependencies = [
"http 0.2.9",
"hyper",
"hyper 0.14.26",
"log",
"rustls 0.21.9",
"rustls-native-certs",
@@ -2502,7 +2564,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
"hyper",
"hyper 0.14.26",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
@@ -2515,7 +2577,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper",
"hyper 0.14.26",
"native-tls",
"tokio",
"tokio-native-tls",
@@ -2523,15 +2585,33 @@ dependencies = [
[[package]]
name = "hyper-tungstenite"
version = "0.11.1"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
dependencies = [
"hyper",
"http-body-util",
"hyper 1.2.0",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
"tokio-tungstenite 0.21.0",
"tungstenite 0.21.0",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"hyper 1.2.0",
"pin-project-lite",
"socket2 0.5.5",
"tokio",
]
[[package]]
@@ -3508,7 +3588,7 @@ dependencies = [
"hex-literal",
"humantime",
"humantime-serde",
"hyper",
"hyper 0.14.26",
"itertools",
"leaky-bucket",
"md5",
@@ -4178,9 +4258,13 @@ dependencies = [
"hex",
"hmac",
"hostname",
"http 1.0.0",
"http-body-util",
"humantime",
"hyper",
"hyper 0.14.26",
"hyper 1.2.0",
"hyper-tungstenite",
"hyper-util",
"ipnet",
"itertools",
"lasso",
@@ -4512,7 +4596,7 @@ dependencies = [
"futures-util",
"http-types",
"humantime",
"hyper",
"hyper 0.14.26",
"itertools",
"metrics",
"once_cell",
@@ -4542,10 +4626,10 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"h2 0.3.24",
"http 0.2.9",
"http-body",
"hyper",
"http-body 0.4.5",
"hyper 0.14.26",
"hyper-rustls",
"hyper-tls",
"ipnet",
@@ -4603,7 +4687,7 @@ dependencies = [
"futures",
"getrandom 0.2.11",
"http 0.2.9",
"hyper",
"hyper 0.14.26",
"parking_lot 0.11.2",
"reqwest",
"reqwest-middleware",
@@ -4690,7 +4774,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
dependencies = [
"http 0.2.9",
"hyper",
"hyper 0.14.26",
"lazy_static",
"percent-encoding",
"regex",
@@ -4969,7 +5053,7 @@ dependencies = [
"git-version",
"hex",
"humantime",
"hyper",
"hyper 0.14.26",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -5444,9 +5528,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.11.0"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
[[package]]
name = "smol_str"
@@ -5538,7 +5622,7 @@ dependencies = [
"futures-util",
"git-version",
"humantime",
"hyper",
"hyper 0.14.26",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -6022,7 +6106,19 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
"tungstenite 0.20.1",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.21.0",
]
[[package]]
@@ -6089,10 +6185,10 @@ dependencies = [
"bytes",
"futures-core",
"futures-util",
"h2",
"h2 0.3.24",
"http 0.2.9",
"http-body",
"hyper",
"http-body 0.4.5",
"hyper 0.14.26",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -6278,7 +6374,7 @@ dependencies = [
name = "tracing-utils"
version = "0.1.0"
dependencies = [
"hyper",
"hyper 0.14.26",
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
@@ -6315,6 +6411,25 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.0.0",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "twox-hash"
version = "1.6.3"
@@ -6478,7 +6593,7 @@ dependencies = [
"heapless",
"hex",
"hex-literal",
"hyper",
"hyper 0.14.26",
"jsonwebtoken",
"leaky-bucket",
"metrics",
@@ -7003,7 +7118,7 @@ dependencies = [
"hashbrown 0.14.0",
"hex",
"hmac",
"hyper",
"hyper 0.14.26",
"indexmap 1.9.3",
"itertools",
"libc",
@@ -7040,7 +7155,6 @@ dependencies = [
"tower",
"tracing",
"tracing-core",
"tungstenite",
"url",
"uuid",
"zeroize",

View File

@@ -92,7 +92,7 @@ http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.11"
hyper-tungstenite = "0.13.0"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"

View File

@@ -30,6 +30,10 @@ hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
hyper1 = { package = "hyper", version = "1.2", features = ["server", "http1", "http2"] }
hyper-util = { version = "0.1", features = ["tokio"] }
http1 = { package = "http", version = "1" }
http-body-util = { version = "0.1" }
ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }
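
For orientation (not part of the diff): the `package = "hyper"` and `package = "http"` renames above let the proxy depend on hyper 0.14 (the plain workspace `hyper`) and hyper 1.2 (`hyper1`) at the same time during the migration. A minimal, hypothetical sketch of how the aliased crates are referenced in code; the hunks below use the same names:

    use bytes::Bytes;
    use http1::{Response, StatusCode};
    use http_body_util::Full;
    use hyper1::body::Incoming;

    // hyper 1.2 types, reached through the `hyper1` alias declared above.
    async fn ok_handler(
        _req: hyper1::Request<Incoming>,
    ) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
        Ok(Response::builder()
            .status(StatusCode::OK)
            .body(Full::new(Bytes::from_static(b"ok")))
            .expect("a static response should always build"))
    }

    // hyper 0.14 stays importable under its original name via `hyper.workspace = true`.
    fn legacy_ok() -> hyper::StatusCode {
        hyper::StatusCode::OK
    }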

View File

@@ -175,7 +175,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
let ctx = RequestMonitoring::new(session_id, peer_addr, "sni_router", "sni");
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {

View File

@@ -3,7 +3,7 @@
use chrono::Utc;
use once_cell::sync::OnceCell;
use smol_str::SmolStr;
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use tokio::sync::mpsc;
use tracing::{field::display, info_span, Span};
use uuid::Uuid;
@@ -62,7 +62,7 @@ pub enum AuthMethod {
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
peer_addr: IpAddr,
peer_addr: SocketAddr,
protocol: &'static str,
region: &'static str,
) -> Self {
@@ -75,7 +75,7 @@ impl RequestMonitoring {
);
Self {
peer_addr,
peer_addr: peer_addr.ip(),
session_id,
protocol,
first_packet: Utc::now(),
@@ -100,7 +100,12 @@ impl RequestMonitoring {
#[cfg(test)]
pub fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
RequestMonitoring::new(
Uuid::now_v7(),
([127, 0, 0, 1], 5432).into(),
"test",
"test",
)
}
pub fn console_application_name(&self) -> String {

View File

@@ -5,19 +5,13 @@ use std::{
io,
net::SocketAddr,
pin::{pin, Pin},
sync::Mutex,
task::{ready, Context, Poll},
};
use bytes::{Buf, BytesMut};
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use metrics::IntCounterPairGuard;
use hyper::server::conn::AddrIncoming;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;
use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
@@ -331,87 +325,6 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
}
impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
tracing::info!(protocol = self.protocol, "accepted new TCP connection");
let Some(conn) = conn else {
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(WithConnectionGuard {
inner: WithClientIp::new(conn),
connection_id: Uuid::new_v4(),
gauge: Mutex::new(Some(
NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&[self.protocol])
.guard(),
)),
})))
}
}
pin_project! {
pub struct WithConnectionGuard<T> {
#[pin]
pub inner: T,
pub connection_id: Uuid,
pub gauge: Mutex<Option<IntCounterPairGuard>>,
}
}
impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_shutdown(cx)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
#[cfg(test)]
mod tests {
use std::pin::pin;

View File

@@ -91,9 +91,8 @@ pub async fn task_main(
connections.spawn(async move {
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr.ip();
match socket.wait_for_addr().await {
Ok(Some(addr)) => peer_addr = addr.ip(),
let peer_addr = match socket.wait_for_addr().await {
Ok(Some(addr)) => addr,
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
@@ -102,8 +101,8 @@ pub async fn task_main(
error!("missing required client IP");
return;
}
Ok(None) => {}
}
Ok(None) => peer_addr
};
match socket.inner.set_nodelay(true) {
Ok(()) => {},

View File

@@ -4,46 +4,45 @@
mod backend;
mod conn_pool;
mod http_auto;
mod json;
mod sql_over_http;
pub mod tls_listener;
mod websocket;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::bail;
use hyper::StatusCode;
use metrics::IntCounterPairGuard;
use anyhow::Context;
use futures::future::{select, Either};
use http1::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Serialize;
use tokio::time::timeout;
use tokio_util::task::TaskTracker;
use crate::context::RequestMonitoring;
use crate::metrics::TLS_HANDSHAKE_FAILURES;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_auto::Rewind;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
Body, Method, Request, Response,
};
use std::convert::Infallible;
use std::net::IpAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
use utils::http::error::ApiError;
pub const SERVERLESS_DRIVER_SNI: &str = "api";
@@ -95,134 +94,221 @@ pub async fn task_main(
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
let addr_incoming = ProxyProtocolAccept {
incoming: addr_incoming,
protocol: "http",
};
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait()` to complete
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!(
protocol = "http",
"failed to accept TLS connection: {err:?}"
);
TLS_HANDSHAKE_FAILURES.inc();
ready(false)
} else {
info!(protocol = "http", "accepted new TLS connection");
ready(true)
let http_connections = tokio_util::task::task_tracker::TaskTracker::new();
http_connections.close();
let server = http_auto::Builder::new();
loop {
let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else {
break;
};
let (conn, mut peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodolay: {e}");
continue;
}
});
let cancellation_token = cancellation_token.child_token();
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
WithConnectionGuard<WithClientIp<AddrStream>>,
>| {
let (conn, _) = stream.get_ref();
let tls = tls_acceptor.clone();
// this is jank. should disappear with hyper 1.0 migration.
let gauge = conn
.gauge
.lock()
.expect("lock should not be poisoned")
.take()
.expect("gauge should be set on connection start");
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
let server = server.clone();
let client_addr = conn.inner.client_addr();
let remote_addr = conn.inner.inner.remote_addr();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(
hyper::service::service_fn(move |req: Request<Body>| {
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
http_connections.spawn(async move {
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["http"])
.guard();
async move {
Ok::<_, Infallible>(
request_handler(
req,
config,
backend,
ws_connections,
cancellation_handler,
peer_addr.ip(),
endpoint_rate_limiter,
)
.await
.map_or_else(|e| e.into_response(), |r| r),
)
}
}),
gauge,
))
// handle PROXY protocol
let mut conn = WithClientIp::new(conn);
let peer = match conn.wait_for_addr().await {
Ok(peer) => peer,
Err(e) => {
tracing::error!(
"failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"
);
return;
}
};
if let Some(peer) = peer {
peer_addr = peer;
}
},
);
info!(%peer_addr, protocol = "http", "accepted new TCP connection");
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
let accept = tls.accept(conn);
let conn = match timeout(Duration::from_secs(10), accept).await {
Ok(Ok(conn)) => {
info!(%peer_addr, protocol = "http", "accepted new TLS connection");
conn
}
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: {e:?}");
return;
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: timeout");
return;
}
};
let (version, conn) = match conn.get_ref().1.alpn_protocol() {
Some(b"http/1.1") => (http_auto::Version::H1, Rewind::new(conn)),
Some(b"h2") => (http_auto::Version::H2, Rewind::new(conn)),
_ => {
tracing::debug!("HTTP: no ALPN negotiated");
let conn = timeout(Duration::from_secs(10), http_auto::read_version(conn)).await;
match conn {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
tracing::warn!("HTTP connection error: {e}");
return;
},
Err(_) => {
tracing::warn!("HTTP connection error: timeout determining http version");
return;
}
}
}
};
let conn = server.serve_connection_with_upgrades(
conn,
version,
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
Ok::<_, Infallible>(
request_handler(
req,
config,
backend,
ws_connections,
cancellation_handler,
peer_addr,
endpoint_rate_limiter,
)
.await
.map_or_else(api_error_into_response, |r| r),
)
}
})
);
let cancel = pin!(cancellation_token.cancelled());
let conn = pin!(conn);
let res = match select(cancel, conn).await {
Either::Left((_cancelled, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
Either::Right((res, _)) => res,
};
match res {
Ok(()) => {}
Err(e) => {
tracing::warn!("HTTP connection error {e}")
}
}
});
}
// await the remaining http and websocket connections
http_connections.wait().await;
ws_connections.wait().await;
Ok(())
}
struct MetricService<S> {
inner: S,
_gauge: IntCounterPairGuard,
}
impl<S> MetricService<S> {
fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
MetricService { inner, _gauge }
fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
}
ApiError::NotFound(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
}
ApiError::Conflict(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
}
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
this.to_string(),
StatusCode::PRECONDITION_FAILED,
),
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
"Shutting down".to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
S: hyper::service::Service<Request<ReqBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[derive(Serialize)]
struct HttpErrorBody {
pub msg: String,
}
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
impl HttpErrorBody {
pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.inner.call(req)
pub fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http1::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: Request<Body>,
mut request: hyper1::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
peer_addr: IpAddr,
peer_addr: SocketAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let session_id = uuid::Uuid::new_v4();
let host = request
@@ -261,14 +347,14 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();
sql_over_http::handle(config, ctx, request, backend)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
@@ -278,9 +364,24 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Body::empty())
.body(Full::new(Bytes::new()))
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http1::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -0,0 +1,316 @@
//! [`hyper-util`] offers an 'auto' connection to detect whether the connection should be HTTP1 or HTTP2.
//! There's a bug in this implementation where graceful shutdowns are not properly respected.
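//!
//! Notes on this vendored copy: version detection is split into the standalone `read_version`
//! future so the caller can skip it when ALPN has already fixed the version and can bound it with
//! a timeout, and `UpgradeableConnection::graceful_shutdown` forwards the shutdown to the inner
//! HTTP1/HTTP2 connection so it is respected.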
use futures::ready;
use hyper1::body::Body;
use hyper1::service::HttpService;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, marker::Unpin};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use ::http1::{Request, Response};
use bytes::Bytes;
use hyper1::{body::Incoming, service::Service};
use hyper1::server::conn::http1;
use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2};
use pin_project_lite::pin_project;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, Error>;
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder {
http1: http1::Builder,
http2: http2::Builder<TokioExecutor>,
}
impl Builder {
/// Create a new auto connection builder.
pub fn new() -> Self {
let mut builder = Self {
http1: http1::Builder::new(),
http2: http2::Builder::new(TokioExecutor::new()),
};
builder.http1.timer(TokioTimer::new());
builder.http2.timer(TokioTimer::new());
builder
}
/// Bind a connection together with a [`Service`], with the ability to
/// handle HTTP upgrades. This requires that the IO object implements
/// `Send`.
pub fn serve_connection_with_upgrades<I, S, B>(
&self,
io: Rewind<I>,
version: Version,
service: S,
) -> UpgradeableConnection<I, S>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
match version {
Version::H1 => {
let conn = self
.http1
.serve_connection(TokioIo::new(io), service)
.with_upgrades();
UpgradeableConnection {
state: UpgradeableConnState::H1 { conn },
}
}
Version::H2 => {
let conn = self.http2.serve_connection(TokioIo::new(io), service);
UpgradeableConnection {
state: UpgradeableConnState::H2 { conn },
}
}
}
}
}
#[derive(Copy, Clone)]
pub(crate) enum Version {
H1,
H2,
}
pub(crate) fn read_version<I>(io: I) -> ReadVersion<I>
where
I: AsyncRead + Unpin,
{
ReadVersion {
io: Some(io),
buf: [0; 24],
filled: 0,
version: Version::H2,
_pin: PhantomPinned,
}
}
pin_project! {
pub(crate) struct ReadVersion<I> {
io: Option<I>,
buf: [u8; 24],
// the amount of `buf` that's been filled
filled: usize,
version: Version,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
_pin: PhantomPinned,
}
}
impl<I> Future for ReadVersion<I>
where
I: AsyncRead + Unpin,
{
type Output = io::Result<(Version, Rewind<I>)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut buf = ReadBuf::new(&mut *this.buf);
buf.set_filled(*this.filled);
// We start as H2 and switch to H1 as soon as we don't have the preface.
while buf.filled().len() < H2_PREFACE.len() {
let len = buf.filled().len();
ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?;
*this.filled = buf.filled().len();
// We start as H2 and switch to H1 when we don't get the preface.
if buf.filled().len() == len
|| buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
{
*this.version = Version::H1;
break;
}
}
let io = this.io.take().unwrap();
let buf = buf.filled().to_vec();
Poll::Ready(Ok((
*this.version,
Rewind::new_buffered(io, Bytes::from(buf)),
)))
}
}
pin_project! {
/// Connection future.
pub struct UpgradeableConnection<I, S>
where
S: HttpService<Incoming>,
{
#[pin]
state: UpgradeableConnState<I, S>,
}
}
type Http1UpgradeableConnection<I, S> =
hyper1::server::conn::http1::UpgradeableConnection<TokioIo<Rewind<I>>, S>;
type Http2Connection<I, S> =
hyper1::server::conn::http2::Connection<TokioIo<Rewind<I>>, S, TokioExecutor>;
pin_project! {
#[project = UpgradeableConnStateProj]
enum UpgradeableConnState<I, S>
where
S: HttpService<Incoming>,
{
H1 {
#[pin]
conn: Http1UpgradeableConnection<I, S>,
},
H2 {
#[pin]
conn: Http2Connection<I, S>,
},
}
}
impl<I, S, B> UpgradeableConnection<I, S>
where
S: HttpService<Incoming, ResBody = B>,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
/// Start a graceful shutdown process for this connection.
///
/// This `UpgradeableConnection` should continue to be polled until shutdown can finish.
///
/// # Note
///
/// This should only be called while the `Connection` future is still pending. If
/// called after `UpgradeableConnection::poll` has resolved, this does nothing.
pub fn graceful_shutdown(self: Pin<&mut Self>) {
match self.project().state.project() {
UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
}
}
}
impl<I, S, B> Future for UpgradeableConnection<I, S>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
match this.state.as_mut().project() {
UpgradeableConnStateProj::H1 { conn } => conn.poll(cx).map_err(Into::into),
UpgradeableConnStateProj::H2 { conn } => conn.poll(cx).map_err(Into::into),
}
}
}
/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}
impl<T> Rewind<T> {
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}
}
impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = std::cmp::min(prefix.len(), buf.remaining());
// Hand back the buffered bytes first; split them off so they are not replayed again.
buf.put_slice(&prefix.split_to(copy_len));
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}
return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
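
A condensed, hypothetical helper showing how the pieces above are meant to be combined; it mirrors the `task_main` changes earlier in this diff (prefer the ALPN result, otherwise sniff the stream with `read_version`, bounded by a timeout):

    use std::time::Duration;
    use tokio::io::AsyncRead;
    use tokio::time::timeout;

    // Crate-internal sketch; `Version`, `Rewind` and `read_version` are the items defined above.
    async fn detect_version<I>(io: I, alpn: Option<&[u8]>) -> std::io::Result<(Version, Rewind<I>)>
    where
        I: AsyncRead + Unpin,
    {
        match alpn {
            Some(b"h2") => Ok((Version::H2, Rewind::new(io))),
            Some(b"http/1.1") => Ok((Version::H1, Rewind::new(io))),
            // No ALPN: peek at the first bytes, but never wait forever for a silent client.
            _ => match timeout(Duration::from_secs(10), read_version(io)).await {
                Ok(res) => res,
                Err(_elapsed) => Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "timed out determining http version",
                )),
            },
        }
    }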

View File

@@ -1,14 +1,19 @@
use std::sync::Arc;
use super::json_response;
use anyhow::bail;
use bytes::Bytes;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::try_join;
@@ -22,7 +27,6 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
@@ -191,9 +195,9 @@ fn get_conn_info(
pub async fn handle(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = tokio::time::timeout(
config.http_config.request_timeout,
handle_inner(config, &mut ctx, request, backend),
@@ -300,19 +304,18 @@ pub async fn handle(
}
};
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
Ok(response)
}
async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Body>> {
) -> anyhow::Result<Response<Full<Bytes>>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
@@ -369,9 +372,12 @@ async fn handle_inner(
}
let fetch_and_process_request = async {
let body = hyper::body::to_bytes(request.into_body())
let body = request
.into_body()
.collect()
.await
.map_err(anyhow::Error::from)?;
.map_err(anyhow::Error::from)?
.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
@@ -490,7 +496,7 @@ async fn handle_inner(
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Body::from(body))
.body(Full::new(Bytes::from(body)))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");

View File

@@ -1,283 +0,0 @@
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::{Future, Stream, StreamExt};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::JoinSet,
time::timeout,
};
/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Trait for TLS implementation.
///
/// Implementations are provided by the rustls and native-tls features.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
/// The type of the TLS stream created from the underlying stream.
type Stream: Send + 'static;
/// Error type for completing the TLS handshake
type Error: std::error::Error + Send + 'static;
/// Type of the Future for the TLS stream that is accepted.
type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;
/// Accept a TLS connection on an underlying stream
fn accept(&self, stream: C) -> Self::AcceptFuture;
}
/// Asynchronously accept connections.
pub trait AsyncAccept {
/// The type of the connection that is accepted.
type Connection: AsyncRead + AsyncWrite;
/// The type of error that may be returned.
type Error;
/// Poll to accept the next connection.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
///
/// Useful for graceful shutdown.
///
/// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
/// for example of how to use.
fn until<F: Future>(self, ender: F) -> Until<Self, F>
where
Self: Sized,
{
Until {
acceptor: self,
ender,
}
}
}
pin_project! {
///
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
///
/// It is similar to:
///
/// ```ignore
/// tcpListener.and_then(|s| tlsAcceptor.accept(s))
/// ```
///
/// except that it has the ability to accept multiple transport-level connections
/// simultaneously while the TLS handshake is pending for other connections.
///
/// By default, if a client fails the TLS handshake, that is treated as an error, and the
/// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
/// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
/// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
///
/// Note that if the maximum number of pending connections is greater than 1, the resulting
/// [`T::Stream`][4] connections may come in a different order than the connections produced by the
/// underlying listener.
///
/// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
/// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
/// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
/// [4]: AsyncTls::Stream
///
#[allow(clippy::type_complexity)]
pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
#[pin]
listener: A,
tls: T,
waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
timeout: Duration,
}
}
/// Builder for `TlsListener`.
#[derive(Clone)]
pub struct Builder<T> {
tls: T,
handshake_timeout: Duration,
}
/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
/// An error that arose from the listener ([AsyncAccept::Error])
#[error("{0}")]
ListenerError(#[source] LE),
/// An error that occurred during the TLS accept handshake
#[error("{0}")]
TlsAcceptError(#[source] TE),
}
impl<A: AsyncAccept, T> TlsListener<A, T>
where
T: AsyncTls<A::Connection>,
{
/// Create a `TlsListener` with default options.
pub fn new(tls: T, listener: A) -> Self {
builder(tls).listen(listener)
}
}
impl<A, T> TlsListener<A, T>
where
A: AsyncAccept,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
/// Accept the next connection
///
/// This is essentially an alias to `self.next()` with a more domain-appropriate name.
pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
where
Self: Unpin,
{
self.next().await
}
/// Replaces the Tls Acceptor configuration, which will be used for new connections.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor(&mut self, acceptor: T) {
self.tls = acceptor;
}
/// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
///
/// This is useful if your listener is `!Unpin`.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
*self.project().tls = acceptor;
}
}
impl<A, T> Stream for TlsListener<A, T>
where
A: AsyncAccept,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
loop {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(conn))) => {
this.waiting
.spawn(timeout(*this.timeout, this.tls.accept(conn)));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Ok(conn)))) => {
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
}
// The handshake timed out, try getting another connection from the queue
Poll::Ready(Some(Ok(Err(_)))) => continue,
// The handshake panicked
Poll::Ready(Some(Err(e))) if e.is_panic() => {
std::panic::resume_unwind(e.into_panic())
}
// The handshake was externally aborted
Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
_ => Poll::Pending,
};
}
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<C>;
type Error = std::io::Error;
type AcceptFuture = tokio_rustls::Accept<C>;
fn accept(&self, conn: C) -> Self::AcceptFuture {
tokio_rustls::TlsAcceptor::accept(self, conn)
}
}
impl<T> Builder<T> {
/// Set the timeout for handshakes.
///
/// If a timeout takes longer than `timeout`, then the handshake will be
/// aborted and the underlying connection will be dropped.
///
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
self.handshake_timeout = timeout;
self
}
/// Create a `TlsListener` from the builder
///
/// Actually build the `TlsListener`. The `listener` argument should be
/// an implementation of the `AsyncAccept` trait that accepts new connections
/// that the `TlsListener` will encrypt using TLS.
pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
where
T: AsyncTls<A::Connection>,
{
TlsListener {
listener,
tls: self.tls.clone(),
waiting: JoinSet::new(),
timeout: self.handshake_timeout,
}
}
}
/// Create a new Builder for a TlsListener
///
/// `server_config` will be used to configure the TLS sessions.
pub fn builder<T>(tls: T) -> Builder<T> {
Builder {
tls,
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
}
}
pin_project! {
/// See [`AsyncAccept::until`]
pub struct Until<A, E> {
#[pin]
acceptor: A,
#[pin]
ender: E,
}
}
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let this = self.project();
match this.ender.poll(cx) {
Poll::Pending => this.acceptor.poll_accept(cx),
Poll::Ready(_) => Poll::Ready(None),
}
}
}

View File

@@ -64,7 +64,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
subtle = { version = "2" }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
@@ -76,7 +76,6 @@ tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
tungstenite = { version = "0.20" }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive"] }