mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 05:22:56 +00:00
proxy: hyper1 for only proxy (#7073)
## Problem
hyper1 offers control over the HTTP connection that hyper0_14 does not.
We're blocked on switching all services to hyper1 because of how we use
tonic, but no reason we can't switch proxy over.
## Summary of changes
1. hyper0.14 -> hyper1
1. self managed server
2. Remove the `WithConnectionGuard` wrapper from `protocol2`
2. Remove TLS listener as it's no longer necessary
3. include first session ID in connection startup logs
This commit is contained in:
214
Cargo.lock
generated
214
Cargo.lock
generated
@@ -270,6 +270,12 @@ dependencies = [
|
||||
"critical-section",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-take"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
@@ -298,7 +304,7 @@ dependencies = [
|
||||
"fastrand 2.0.0",
|
||||
"hex",
|
||||
"http 0.2.9",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"ring 0.17.6",
|
||||
"time",
|
||||
"tokio",
|
||||
@@ -335,7 +341,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"fastrand 2.0.0",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"tracing",
|
||||
@@ -386,7 +392,7 @@ dependencies = [
|
||||
"aws-types",
|
||||
"bytes",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"regex-lite",
|
||||
@@ -514,7 +520,7 @@ dependencies = [
|
||||
"crc32fast",
|
||||
"hex",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"md-5",
|
||||
"pin-project-lite",
|
||||
"sha1",
|
||||
@@ -546,7 +552,7 @@ dependencies = [
|
||||
"bytes-utils",
|
||||
"futures-core",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
@@ -585,10 +591,10 @@ dependencies = [
|
||||
"aws-smithy-types",
|
||||
"bytes",
|
||||
"fastrand 2.0.0",
|
||||
"h2",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"hyper-rustls",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
@@ -626,7 +632,7 @@ dependencies = [
|
||||
"bytes-utils",
|
||||
"futures-core",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"itoa",
|
||||
"num-integer",
|
||||
"pin-project-lite",
|
||||
@@ -675,8 +681,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
@@ -691,7 +697,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-tungstenite 0.20.0",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -707,7 +713,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
@@ -1196,7 +1202,7 @@ dependencies = [
|
||||
"compute_api",
|
||||
"flate2",
|
||||
"futures",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"nix 0.27.1",
|
||||
"notify",
|
||||
"num_cpus",
|
||||
@@ -1313,7 +1319,7 @@ dependencies = [
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"nix 0.27.1",
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
@@ -2199,6 +2205,25 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"indexmap 2.0.1",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.2"
|
||||
@@ -2370,6 +2395,29 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http 1.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"http-body 1.0.0",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-types"
|
||||
version = "2.12.0"
|
||||
@@ -2428,9 +2476,9 @@ dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"http-body 0.4.5",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
@@ -2442,6 +2490,26 @@ dependencies = [
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"h2 0.4.4",
|
||||
"http 1.1.0",
|
||||
"http-body 1.0.0",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.24.0"
|
||||
@@ -2449,7 +2517,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
|
||||
dependencies = [
|
||||
"http 0.2.9",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"log",
|
||||
"rustls 0.21.9",
|
||||
"rustls-native-certs 0.6.2",
|
||||
@@ -2463,7 +2531,7 @@ version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-io-timeout",
|
||||
@@ -2476,7 +2544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
@@ -2484,15 +2552,33 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.11.1"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
|
||||
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"http-body-util",
|
||||
"hyper 1.2.0",
|
||||
"hyper-util",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
"tokio-tungstenite 0.21.0",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"http-body 1.0.0",
|
||||
"hyper 1.2.0",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.5",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3523,7 +3609,7 @@ dependencies = [
|
||||
"hex-literal",
|
||||
"humantime",
|
||||
"humantime-serde",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"itertools",
|
||||
"leaky-bucket",
|
||||
"md5",
|
||||
@@ -4202,6 +4288,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-compression",
|
||||
"async-trait",
|
||||
"atomic-take",
|
||||
"aws-config",
|
||||
"aws-sdk-iam",
|
||||
"aws-sigv4",
|
||||
@@ -4225,9 +4312,12 @@ dependencies = [
|
||||
"hmac",
|
||||
"hostname",
|
||||
"http 1.1.0",
|
||||
"http-body-util",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"hyper 1.2.0",
|
||||
"hyper-tungstenite",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"itertools",
|
||||
"lasso",
|
||||
@@ -4560,7 +4650,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"http-types",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"itertools",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
@@ -4590,10 +4680,10 @@ dependencies = [
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"ipnet",
|
||||
@@ -4651,7 +4741,7 @@ dependencies = [
|
||||
"futures",
|
||||
"getrandom 0.2.11",
|
||||
"http 0.2.9",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"parking_lot 0.11.2",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
@@ -4738,7 +4828,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
|
||||
dependencies = [
|
||||
"http 0.2.9",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"regex",
|
||||
@@ -5043,7 +5133,7 @@ dependencies = [
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
@@ -5528,9 +5618,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.11.0"
|
||||
version = "1.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
@@ -5622,7 +5712,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"git-version",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
@@ -5653,7 +5743,7 @@ dependencies = [
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"itertools",
|
||||
"lasso",
|
||||
"measured",
|
||||
@@ -5682,7 +5772,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"comfy-table",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"pageserver_api",
|
||||
"pageserver_client",
|
||||
"reqwest",
|
||||
@@ -6165,7 +6255,19 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
"tungstenite 0.20.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6232,10 +6334,10 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"h2 0.3.26",
|
||||
"http 0.2.9",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
@@ -6421,7 +6523,7 @@ dependencies = [
|
||||
name = "tracing-utils"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-semantic-conventions",
|
||||
@@ -6458,6 +6560,25 @@ dependencies = [
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.1.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
@@ -6623,7 +6744,7 @@ dependencies = [
|
||||
"hex",
|
||||
"hex-literal",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"jsonwebtoken",
|
||||
"leaky-bucket",
|
||||
"metrics",
|
||||
@@ -7214,7 +7335,7 @@ dependencies = [
|
||||
"hashbrown 0.14.0",
|
||||
"hex",
|
||||
"hmac",
|
||||
"hyper",
|
||||
"hyper 0.14.26",
|
||||
"indexmap 1.9.3",
|
||||
"itertools",
|
||||
"libc",
|
||||
@@ -7252,7 +7373,6 @@ dependencies = [
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tungstenite",
|
||||
"url",
|
||||
"uuid",
|
||||
"zeroize",
|
||||
|
||||
@@ -44,6 +44,7 @@ license = "Apache-2.0"
|
||||
anyhow = { version = "1.0", features = ["backtrace"] }
|
||||
arc-swap = "1.6"
|
||||
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
|
||||
atomic-take = "1.1.0"
|
||||
azure_core = "0.18"
|
||||
azure_identity = "0.18"
|
||||
azure_storage = "0.18"
|
||||
@@ -97,7 +98,7 @@ http-types = { version = "2", default-features = false }
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.11"
|
||||
hyper-tungstenite = "0.13.0"
|
||||
inotify = "0.10.2"
|
||||
ipnet = "2.9.0"
|
||||
itertools = "0.10"
|
||||
|
||||
@@ -12,6 +12,7 @@ testing = []
|
||||
anyhow.workspace = true
|
||||
async-compression.workspace = true
|
||||
async-trait.workspace = true
|
||||
atomic-take.workspace = true
|
||||
aws-config.workspace = true
|
||||
aws-sdk-iam.workspace = true
|
||||
aws-sigv4.workspace = true
|
||||
@@ -36,6 +37,9 @@ http.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
|
||||
hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
|
||||
http-body-util = { version = "0.1" }
|
||||
ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
lasso = { workspace = true, features = ["multi-threaded"] }
|
||||
|
||||
@@ -5,19 +5,13 @@ use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
pin::{pin, Pin},
|
||||
sync::Mutex,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use hyper::server::accept::Accept;
|
||||
use hyper::server::conn::{AddrIncoming, AddrStream};
|
||||
use metrics::IntCounterPairGuard;
|
||||
use hyper::server::conn::AddrIncoming;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
|
||||
pub struct ProxyProtocolAccept {
|
||||
pub incoming: AddrIncoming,
|
||||
@@ -331,103 +325,6 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Accept for ProxyProtocolAccept {
|
||||
type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
let span = tracing::info_span!("http_conn", ?conn_id);
|
||||
{
|
||||
let _enter = span.enter();
|
||||
tracing::info!("accepted new TCP connection");
|
||||
}
|
||||
|
||||
let Some(conn) = conn else {
|
||||
return Poll::Ready(None);
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(WithConnectionGuard {
|
||||
inner: WithClientIp::new(conn),
|
||||
connection_id: Uuid::new_v4(),
|
||||
gauge: Mutex::new(Some(
|
||||
NUM_CLIENT_CONNECTION_GAUGE
|
||||
.with_label_values(&[self.protocol])
|
||||
.guard(),
|
||||
)),
|
||||
span,
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct WithConnectionGuard<T> {
|
||||
#[pin]
|
||||
pub inner: T,
|
||||
pub connection_id: Uuid,
|
||||
pub gauge: Mutex<Option<IntCounterPairGuard>>,
|
||||
pub span: tracing::Span,
|
||||
}
|
||||
|
||||
impl<S> PinnedDrop for WithConnectionGuard<S> {
|
||||
fn drop(this: Pin<&mut Self>) {
|
||||
let _enter = this.span.enter();
|
||||
tracing::info!("HTTP connection closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
@@ -4,42 +4,48 @@
|
||||
|
||||
mod backend;
|
||||
mod conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod sql_over_http;
|
||||
pub mod tls_listener;
|
||||
mod websocket;
|
||||
|
||||
use atomic_take::AtomicTake;
|
||||
use bytes::Bytes;
|
||||
pub use conn_pool::GlobalConnPoolOptions;
|
||||
|
||||
use anyhow::bail;
|
||||
use hyper::StatusCode;
|
||||
use metrics::IntCounterPairGuard;
|
||||
use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing::instrument::Instrumented;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
|
||||
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
|
||||
use crate::protocol2::WithClientIp;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use hyper::{
|
||||
server::conn::{AddrIncoming, AddrStream},
|
||||
Body, Method, Request, Response,
|
||||
};
|
||||
use crate::serverless::http_util::{api_error_into_response, json_response};
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use tls_listener::TlsListener;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::{CancellationToken, DropGuard};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
pub const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
|
||||
@@ -91,161 +97,174 @@ pub async fn task_main(
|
||||
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
let addr_incoming = ProxyProtocolAccept {
|
||||
incoming: addr_incoming,
|
||||
protocol: "http",
|
||||
};
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
tracing::error!("could not set nodelay: {e}");
|
||||
continue;
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<
|
||||
WithConnectionGuard<WithClientIp<AddrStream>>,
|
||||
>| {
|
||||
let (conn, _) = stream.get_ref();
|
||||
connections.spawn(
|
||||
connection_handler(
|
||||
config,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellation_token.clone(),
|
||||
server.clone(),
|
||||
tls_acceptor.clone(),
|
||||
conn,
|
||||
peer_addr,
|
||||
)
|
||||
.instrument(http_conn_span),
|
||||
);
|
||||
}
|
||||
|
||||
// this is jank. should dissapear with hyper 1.0 migration.
|
||||
let gauge = conn
|
||||
.gauge
|
||||
.lock()
|
||||
.expect("lock should not be poisoned")
|
||||
.take()
|
||||
.expect("gauge should be set on connection start");
|
||||
|
||||
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let span = conn.span.clone();
|
||||
let client_addr = conn.inner.client_addr();
|
||||
let remote_addr = conn.inner.inner.remote_addr();
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => bail!("missing required client ip"),
|
||||
None => remote_addr,
|
||||
};
|
||||
Ok(MetricService::new(
|
||||
hyper::service::service_fn(move |req: Request<Body>| {
|
||||
let backend = backend.clone();
|
||||
let ws_connections2 = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let http_cancellation_token = http_cancellation_token.child_token();
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
// Cancel the current inflight HTTP request if the requets stream is closed.
|
||||
// This is slightly different to `_cancel_connection` in that
|
||||
// h2 can cancel individual requests with a `RST_STREAM`.
|
||||
let _cancel_session = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let res = request_handler(
|
||||
req,
|
||||
config,
|
||||
backend,
|
||||
ws_connections2,
|
||||
cancellation_handler,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
http_cancellation_token,
|
||||
)
|
||||
.await
|
||||
.map_or_else(|e| e.into_response(), |r| r);
|
||||
|
||||
_cancel_session.disarm();
|
||||
|
||||
res
|
||||
}
|
||||
.in_current_span(),
|
||||
)
|
||||
}),
|
||||
gauge,
|
||||
cancel_connection,
|
||||
span,
|
||||
))
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
hyper::Server::builder(tls_listener)
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
|
||||
// await websocket connections
|
||||
ws_connections.wait().await;
|
||||
connections.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct MetricService<S> {
|
||||
inner: S,
|
||||
_gauge: IntCounterPairGuard,
|
||||
_cancel: DropGuard,
|
||||
span: tracing::Span,
|
||||
}
|
||||
/// Handles the TCP lifecycle.
|
||||
///
|
||||
/// 1. Parses PROXY protocol V2
|
||||
/// 2. Handles TLS handshake
|
||||
/// 3. Handles HTTP connection
|
||||
/// 1. With graceful shutdowns
|
||||
/// 2. With graceful request cancellation with connection failure
|
||||
/// 3. With websocket upgrade support.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
server: Builder<TokioExecutor>,
|
||||
tls_acceptor: TlsAcceptor,
|
||||
conn: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) {
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
impl<S> MetricService<S> {
|
||||
fn new(
|
||||
inner: S,
|
||||
_gauge: IntCounterPairGuard,
|
||||
_cancel: DropGuard,
|
||||
span: tracing::Span,
|
||||
) -> MetricService<S> {
|
||||
MetricService {
|
||||
inner,
|
||||
_gauge,
|
||||
_cancel,
|
||||
span,
|
||||
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
.guard();
|
||||
|
||||
// handle PROXY protocol
|
||||
let mut conn = WithClientIp::new(conn);
|
||||
let peer = match conn.wait_for_addr().await {
|
||||
Ok(peer) => peer,
|
||||
Err(e) => {
|
||||
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
|
||||
where
|
||||
S: hyper::service::Service<Request<ReqBody>>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = Instrumented<S::Future>;
|
||||
let peer_addr = peer.unwrap_or(peer_addr).ip();
|
||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
// try upgrade to TLS, but with a timeout.
|
||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed
|
||||
Ok(Err(e)) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
return;
|
||||
}
|
||||
// The handshake timed out
|
||||
Err(e) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
self.span
|
||||
.in_scope(|| self.inner.call(req))
|
||||
.instrument(self.span.clone())
|
||||
let session_id = AtomicTake::new(session_id);
|
||||
|
||||
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
|
||||
// First HTTP request shares the same session ID
|
||||
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
||||
|
||||
// Cancel the current inflight HTTP request if the requets stream is closed.
|
||||
// This is slightly different to `_cancel_connection` in that
|
||||
// h2 can cancel individual requests with a `RST_STREAM`.
|
||||
let http_request_token = http_cancellation_token.child_token();
|
||||
let cancel_request = http_request_token.clone().drop_guard();
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
peer_addr,
|
||||
endpoint_rate_limiter.clone(),
|
||||
http_request_token,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
);
|
||||
|
||||
async move {
|
||||
let res = handler.await;
|
||||
cancel_request.disarm();
|
||||
res
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// On cancellation, trigger the HTTP connection handler to shut down.
|
||||
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
|
||||
Either::Left((_cancelled, mut conn)) => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
Either::Right((res, _)) => res,
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: Request<Body>,
|
||||
mut request: hyper1::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -282,14 +301,14 @@ async fn request_handler(
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
||||
let span = ctx.span.clone();
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
.header("Allow", "OPTIONS, POST")
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
@@ -299,7 +318,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Body::empty())
|
||||
.body(Full::new(Bytes::new()))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
92
proxy/src/serverless/http_util.rs
Normal file
92
proxy/src/serverless/http_util.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
|
||||
//! Will merge back in at some point in the future.
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
StatusCode::BAD_REQUEST,
|
||||
),
|
||||
ApiError::Forbidden(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
|
||||
}
|
||||
ApiError::Unauthorized(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
ApiError::NotFound(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
|
||||
}
|
||||
ApiError::Conflict(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
|
||||
}
|
||||
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
|
||||
this.to_string(),
|
||||
StatusCode::PRECONDITION_FAILED,
|
||||
),
|
||||
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
|
||||
"Shutting down".to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::REQUEST_TIMEOUT,
|
||||
),
|
||||
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody`]
|
||||
#[derive(Serialize)]
|
||||
struct HttpErrorBody {
|
||||
pub msg: String,
|
||||
}
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::json::json_response`]
|
||||
pub fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
@@ -1,18 +1,22 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::select;
|
||||
use futures::future::try_join;
|
||||
use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::header;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
use hyper1::body::Incoming;
|
||||
use hyper1::header;
|
||||
use hyper1::http::HeaderName;
|
||||
use hyper1::http::HeaderValue;
|
||||
use hyper1::Response;
|
||||
use hyper1::StatusCode;
|
||||
use hyper1::{HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
@@ -29,7 +33,6 @@ use tracing::error;
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
use utils::http::error::ApiError;
|
||||
use utils::http::json::json_response;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
@@ -52,6 +55,7 @@ use crate::RoleName;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::Client;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::http_util::json_response;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
@@ -218,10 +222,10 @@ fn get_conn_info(
|
||||
pub async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
@@ -332,10 +336,9 @@ pub async fn handle(
|
||||
}
|
||||
};
|
||||
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -396,7 +399,7 @@ impl UserFacingError for SqlOverHttpError {
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ReadPayloadError {
|
||||
#[error("could not read the HTTP request body: {0}")]
|
||||
Read(#[from] hyper::Error),
|
||||
Read(#[from] hyper1::Error),
|
||||
#[error("could not parse the HTTP request body: {0}")]
|
||||
Parse(#[from] serde_json::Error),
|
||||
}
|
||||
@@ -437,7 +440,7 @@ struct HttpHeaders {
|
||||
}
|
||||
|
||||
impl HttpHeaders {
|
||||
fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
|
||||
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
|
||||
// Determine the output options. Default behaviour is 'false'. Anything that is not
|
||||
// strictly 'true' assumed to be false.
|
||||
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
|
||||
@@ -488,9 +491,9 @@ async fn handle_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Body>, SqlOverHttpError> {
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard();
|
||||
@@ -528,7 +531,7 @@ async fn handle_inner(
|
||||
}
|
||||
|
||||
let fetch_and_process_request = async {
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let body = request.into_body().collect().await?.to_bytes();
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
@@ -596,7 +599,7 @@ async fn handle_inner(
|
||||
let body = serde_json::to_string(&result).expect("json serialization should not fail");
|
||||
let len = body.len();
|
||||
let response = response
|
||||
.body(Body::from(body))
|
||||
.body(Full::new(Bytes::from(body)))
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -639,6 +642,7 @@ impl QueryData {
|
||||
}
|
||||
// The query was cancelled.
|
||||
Either::Right((_cancelled, query)) => {
|
||||
tracing::info!("cancelling query");
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::error!(?err, "could not cancel query");
|
||||
}
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use hyper::server::{accept::Accept, conn::AddrStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
use tokio_rustls::{server::TlsStream, TlsAcceptor};
|
||||
use tracing::{info, warn, Instrument};
|
||||
|
||||
use crate::{
|
||||
metrics::TLS_HANDSHAKE_FAILURES,
|
||||
protocol2::{WithClientIp, WithConnectionGuard},
|
||||
};
|
||||
|
||||
pin_project! {
|
||||
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
|
||||
/// encrypted using TLS.
|
||||
pub(crate) struct TlsListener<A: Accept> {
|
||||
#[pin]
|
||||
listener: A,
|
||||
tls: TlsAcceptor,
|
||||
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
|
||||
timeout: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Accept> TlsListener<A> {
|
||||
/// Create a `TlsListener` with default options.
|
||||
pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
|
||||
TlsListener {
|
||||
listener,
|
||||
tls,
|
||||
waiting: JoinSet::new(),
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Accept for TlsListener<A>
|
||||
where
|
||||
A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
|
||||
A::Error: std::error::Error,
|
||||
A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Conn = TlsStream<A::Conn>;
|
||||
|
||||
type Error = Infallible;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
match this.listener.as_mut().poll_accept(cx) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Some(Ok(mut conn))) => {
|
||||
let t = *this.timeout;
|
||||
let tls = this.tls.clone();
|
||||
let span = conn.span.clone();
|
||||
this.waiting.spawn(async move {
|
||||
let peer_addr = match conn.inner.wait_for_addr().await {
|
||||
Ok(Some(addr)) => addr,
|
||||
Err(e) => {
|
||||
tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return None;
|
||||
}
|
||||
Ok(None) => conn.inner.inner.remote_addr()
|
||||
};
|
||||
|
||||
let accept = tls.accept(conn);
|
||||
match timeout(t, accept).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(%peer_addr, "accepted new TLS connection");
|
||||
Some(conn)
|
||||
},
|
||||
// The handshake failed, try getting another connection from the queue
|
||||
Ok(Err(e)) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
None
|
||||
}
|
||||
// The handshake timed out, try getting another connection from the queue
|
||||
Err(_) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, "failed to accept TLS connection: timeout");
|
||||
None
|
||||
}
|
||||
}
|
||||
}.instrument(span));
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
tracing::error!("error accepting TCP connection: {e}");
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
return match this.waiting.poll_join_next(cx) {
|
||||
Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
|
||||
// The handshake failed to complete, try getting another connection from the queue
|
||||
Poll::Ready(Some(Ok(None))) => continue,
|
||||
// The handshake panicked or was cancelled. ignore and get another connection
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
tracing::warn!("handshake aborted: {e}");
|
||||
continue;
|
||||
}
|
||||
_ => Poll::Pending,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,7 +63,7 @@ scopeguard = { version = "1" }
|
||||
serde = { version = "1", features = ["alloc", "derive"] }
|
||||
serde_json = { version = "1", features = ["raw_value"] }
|
||||
sha2 = { version = "0.10", features = ["asm"] }
|
||||
smallvec = { version = "1", default-features = false, features = ["write"] }
|
||||
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
|
||||
subtle = { version = "2" }
|
||||
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
|
||||
@@ -75,7 +75,6 @@ tonic = { version = "0.9", features = ["tls-roots"] }
|
||||
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
|
||||
tracing = { version = "0.1", features = ["log"] }
|
||||
tracing-core = { version = "0.1" }
|
||||
tungstenite = { version = "0.20" }
|
||||
url = { version = "2", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["serde", "v4", "v7"] }
|
||||
zeroize = { version = "1", features = ["derive"] }
|
||||
|
||||
Reference in New Issue
Block a user