mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-20 14:40:37 +00:00
proxy
This commit is contained in:
142
Cargo.lock
generated
142
Cargo.lock
generated
@@ -612,7 +612,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
@@ -712,7 +712,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-tungstenite 0.20.0",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -2470,10 +2470,10 @@ dependencies = [
|
||||
"http 0.2.9",
|
||||
"hyper 0.14.26",
|
||||
"log",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-native-certs",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2503,15 +2503,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.11.1"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
|
||||
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
|
||||
dependencies = [
|
||||
"hyper 0.14.26",
|
||||
"http-body-util",
|
||||
"hyper 1.1.0",
|
||||
"hyper-util",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
"tokio-tungstenite 0.21.0",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3824,14 +3826,14 @@ dependencies = [
|
||||
"futures",
|
||||
"once_cell",
|
||||
"pq_proto",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tracing",
|
||||
"workspace_hack",
|
||||
]
|
||||
@@ -4042,9 +4044,13 @@ dependencies = [
|
||||
"hex",
|
||||
"hmac",
|
||||
"hostname",
|
||||
"http 1.0.0",
|
||||
"http-body 1.0.0",
|
||||
"http-body-util",
|
||||
"humantime",
|
||||
"hyper 1.1.0",
|
||||
"hyper-tungstenite",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"itertools",
|
||||
"md5",
|
||||
@@ -4074,7 +4080,7 @@ dependencies = [
|
||||
"routerify",
|
||||
"rstest",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-pemfile",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
@@ -4089,7 +4095,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
@@ -4240,7 +4246,7 @@ dependencies = [
|
||||
"itoa",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-native-certs",
|
||||
"rustls-pemfile",
|
||||
"rustls-webpki 0.101.7",
|
||||
@@ -4248,7 +4254,7 @@ dependencies = [
|
||||
"sha1_smol",
|
||||
"socket2 0.4.9",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-util",
|
||||
"url",
|
||||
]
|
||||
@@ -4390,14 +4396,14 @@ dependencies = [
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
@@ -4652,6 +4658,20 @@ dependencies = [
|
||||
"sct",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring 0.17.6",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.1",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.6.3"
|
||||
@@ -4673,6 +4693,12 @@ dependencies = [
|
||||
"base64 0.21.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.2"
|
||||
@@ -4693,6 +4719,17 @@ dependencies = [
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b"
|
||||
dependencies = [
|
||||
"ring 0.17.6",
|
||||
"rustls-pki-types",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.12"
|
||||
@@ -4735,7 +4772,7 @@ dependencies = [
|
||||
"serde_with",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
@@ -4901,7 +4938,7 @@ checksum = "2e95efd0cefa32028cdb9766c96de71d96671072f9fb494dc9fb84c0ef93e52b"
|
||||
dependencies = [
|
||||
"httpdate",
|
||||
"reqwest",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"sentry-backtrace",
|
||||
"sentry-contexts",
|
||||
"sentry-core",
|
||||
@@ -5600,16 +5637,15 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tls-listener"
|
||||
version = "0.7.0"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd"
|
||||
checksum = "e516b94ec383622ae4da09b2a9bab471c4c880453c93fd0dd4dff1f6f2ff0849"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"hyper 0.14.26",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.25.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5692,10 +5728,10 @@ checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"ring 0.16.20",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5704,7 +5740,18 @@ version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
|
||||
dependencies = [
|
||||
"rustls 0.22.2",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@@ -5743,7 +5790,19 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
"tungstenite 0.20.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5851,7 +5910,7 @@ dependencies = [
|
||||
"rustls-native-certs",
|
||||
"rustls-pemfile",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-stream",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
@@ -6076,6 +6135,25 @@ dependencies = [
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.0.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
@@ -6173,7 +6251,7 @@ dependencies = [
|
||||
"base64 0.21.1",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"rustls-webpki 0.100.2",
|
||||
"url",
|
||||
"webpki-roots 0.23.1",
|
||||
@@ -6774,7 +6852,7 @@ dependencies = [
|
||||
"regex-syntax 0.8.2",
|
||||
"reqwest",
|
||||
"ring 0.16.20",
|
||||
"rustls",
|
||||
"rustls 0.21.9",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -6785,14 +6863,14 @@ dependencies = [
|
||||
"time",
|
||||
"time-macros",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tokio-util",
|
||||
"toml_datetime",
|
||||
"toml_edit",
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tungstenite",
|
||||
"tungstenite 0.20.1",
|
||||
"url",
|
||||
"uuid",
|
||||
"zstd",
|
||||
|
||||
@@ -94,7 +94,7 @@ hyper-util = "0.1.0"
|
||||
http = "1"
|
||||
http-body = "1"
|
||||
http-body-util = "0.1"
|
||||
hyper-tungstenite = "0.11"
|
||||
hyper-tungstenite = "0.13.0"
|
||||
inotify = "0.10.2"
|
||||
ipnet = "2.9.0"
|
||||
itertools = "0.10"
|
||||
@@ -153,7 +153,7 @@ tar = "0.4"
|
||||
task-local-extensions = "0.1.4"
|
||||
test-context = "0.1"
|
||||
thiserror = "1.0"
|
||||
tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] }
|
||||
tls-listener = { version = "0.9", features = ["rustls", "tokio-net"] }
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-io-timeout = "1.2.0"
|
||||
tokio-postgres-rustls = "0.10.0"
|
||||
|
||||
@@ -4,6 +4,8 @@ use crate::http::{
|
||||
error::ApiError,
|
||||
json::{json_request, json_response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use routerify::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -152,7 +154,7 @@ pub struct FailpointConfig {
|
||||
pub async fn failpoints_handler(
|
||||
mut request: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
if !fail::has_failpoints() {
|
||||
return Err(ApiError::BadRequest(anyhow::anyhow!(
|
||||
"Cannot manage failpoints because storage was compiled without failpoints support"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::Buf;
|
||||
use http_body_util::BodyExt;
|
||||
use bytes::{Buf, Bytes};
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::{header, Request, Response, StatusCode};
|
||||
use routerify::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -44,14 +44,14 @@ pub async fn json_request_or_empty_body<T: for<'de> Deserialize<'de>>(
|
||||
pub fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(json))
|
||||
.body(Full::from(json))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -5,4 +5,4 @@ pub mod request;
|
||||
|
||||
/// Current fast way to apply simple http routing in various Neon binaries.
|
||||
/// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed.
|
||||
pub use routerify::{ext::RequestExt, RouterBuilder};
|
||||
pub use routerify::{ext::RequestExt, Body, RequestServiceBuilder, RouterBuilder};
|
||||
|
||||
@@ -28,7 +28,11 @@ hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
hyper-util = { workspace = true, features = ["tokio", "server", "server-auto"] }
|
||||
http = { workspace = true, features = [] }
|
||||
http-body = { workspace = true, features = [] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
md5.workspace = true
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
pub use reqwest::{Request, Response, StatusCode};
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time::Instant;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl};
|
||||
use reqwest_middleware::RequestBuilder;
|
||||
@@ -21,8 +19,6 @@ use reqwest_middleware::RequestBuilder;
|
||||
/// We deliberately don't want to replace this with a public static.
|
||||
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.build()
|
||||
.expect("Failed to create http client");
|
||||
|
||||
@@ -34,8 +30,6 @@ pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> Clien
|
||||
|
||||
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
|
||||
let timeout_client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.timeout(default_timout)
|
||||
.build()
|
||||
.expect("Failed to create http client with timeout");
|
||||
@@ -100,37 +94,6 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
/// https://docs.rs/reqwest/0.11.18/src/reqwest/dns/gai.rs.html
|
||||
use hyper::{
|
||||
client::connect::dns::{GaiResolver as HyperGaiResolver, Name},
|
||||
service::Service,
|
||||
};
|
||||
use reqwest::dns::{Addrs, Resolve, Resolving};
|
||||
#[derive(Debug)]
|
||||
pub struct GaiResolver(HyperGaiResolver);
|
||||
|
||||
impl Default for GaiResolver {
|
||||
fn default() -> Self {
|
||||
Self(HyperGaiResolver::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for GaiResolver {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let this = &mut self.0.clone();
|
||||
let start = Instant::now();
|
||||
Box::pin(
|
||||
Service::<Name>::call(this, name.clone()).map(move |result| {
|
||||
let resolve_duration = start.elapsed();
|
||||
trace!(duration = ?resolve_duration, addr = %name, "resolve host complete");
|
||||
result
|
||||
.map(|addrs| -> Addrs { Box::new(addrs) })
|
||||
.map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { Box::new(err) })
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
use anyhow::{anyhow, bail};
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use anyhow::anyhow;
|
||||
use http::{Request, Response};
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server::conn,
|
||||
};
|
||||
use std::{convert::Infallible, net::TcpListener};
|
||||
use tracing::info;
|
||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||
use utils::http::{
|
||||
endpoint, error::ApiError, json::json_response, Body, RequestServiceBuilder, RouterBuilder,
|
||||
};
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
json_response(StatusCode::OK, "").map(|req| req.map(Body::new))
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
fn make_router() -> RouterBuilder<Body, ApiError> {
|
||||
endpoint::make_router().get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
@@ -17,11 +24,20 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible>
|
||||
info!("http has shut down");
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router().build()?);
|
||||
let router = make_router().build().map_err(|e| anyhow!(e))?;
|
||||
let builder = RequestServiceBuilder::new(router).map_err(|e| anyhow!(e))?;
|
||||
let listener = tokio::net::TcpListener::from_std(http_listener)?;
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
|
||||
bail!("hyper server without shutdown handling cannot shutdown successfully");
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await.unwrap();
|
||||
let io = TokioIo::new(stream);
|
||||
let service = builder.build(remote_addr);
|
||||
tokio::task::spawn(async move {
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let res = builder.serve_connection(io, service).await;
|
||||
if let Err(err) = res {
|
||||
println!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,13 +10,15 @@ use std::{
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use hyper::server::conn::{AddrIncoming, AddrStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use tls_listener::AsyncAccept;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf},
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
|
||||
pub struct ProxyProtocolAccept {
|
||||
pub incoming: AddrIncoming,
|
||||
pub incoming: TcpListener,
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
@@ -327,20 +329,18 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
|
||||
}
|
||||
|
||||
impl AsyncAccept for ProxyProtocolAccept {
|
||||
type Connection = WithClientIp<AddrStream>;
|
||||
type Connection = WithClientIp<TcpStream>;
|
||||
|
||||
type Address = SocketAddr;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
|
||||
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
|
||||
let Some(conn) = conn else {
|
||||
return Poll::Ready(None);
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(WithClientIp::new(conn))))
|
||||
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
|
||||
Pin::new(&mut self.incoming)
|
||||
.poll_accept(cx)
|
||||
.map_ok(|(c, a)| (WithClientIp::new(c), a))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ use crate::{
|
||||
proxy::retry::{retry_after, ShouldRetry},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use hyper::StatusCode;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use reqwest::StatusCode;
|
||||
use std::ops::ControlFlow;
|
||||
use tokio::time;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
@@ -6,41 +6,54 @@ mod conn_pool;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
use bytes::Bytes;
|
||||
pub use conn_pool::GlobalConnPoolOptions;
|
||||
|
||||
use anyhow::bail;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server::conn;
|
||||
use metrics::IntCounterPairGuard;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::select;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::protocol2::ProxyProtocolAccept;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
server::{
|
||||
accept,
|
||||
conn::{AddrIncoming, AddrStream},
|
||||
},
|
||||
Body, Method, Request, Response,
|
||||
};
|
||||
use hyper::{Method, Request, Response};
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::task::Poll;
|
||||
use std::{future::ready, sync::Arc};
|
||||
use tls_listener::TlsListener;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tls_listener::{AsyncTls, TlsListener};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Tls(TlsAcceptor);
|
||||
|
||||
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncTls<C> for Tls {
|
||||
type Stream = tokio_rustls::server::TlsStream<C>;
|
||||
type Error = std::io::Error;
|
||||
type AcceptFuture = tokio_rustls::Accept<C>;
|
||||
|
||||
fn accept(&self, conn: C) -> Self::AcceptFuture {
|
||||
tokio_rustls::TlsAcceptor::accept(&self.0, conn)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
@@ -79,42 +92,52 @@ pub async fn task_main(
|
||||
};
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
// let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
// let _ = addr_incoming.set_nodelay(true);
|
||||
let addr_incoming = ProxyProtocolAccept {
|
||||
incoming: addr_incoming,
|
||||
incoming: ws_listener,
|
||||
};
|
||||
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let ws_connections2 = ws_connections.clone();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {err:?}");
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
let mut tls_listener = TlsListener::new(Tls(tls_acceptor), addr_incoming);
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, remote_addr) = select! {
|
||||
res = tls_listener.accept() => {
|
||||
match res {
|
||||
Err(err) =>
|
||||
{error!("failed to accept TLS connection for websockets: {err:?}"); continue},
|
||||
Ok(s) => s,
|
||||
}
|
||||
}
|
||||
_ = cancellation_token.cancelled() => break,
|
||||
};
|
||||
let (io, tls) = stream.get_ref();
|
||||
let client_addr = io.client_addr();
|
||||
let remote_addr = io.inner.remote_addr();
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let ws_connections = ws_connections2.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => bail!("missing required client ip"),
|
||||
None => remote_addr,
|
||||
};
|
||||
Ok(MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Body>| {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => {
|
||||
tracing::error!("Error serving connection: missing required client ip");
|
||||
continue;
|
||||
}
|
||||
None => remote_addr,
|
||||
};
|
||||
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let service = MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Incoming>| {
|
||||
let sni_name = sni_name.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
@@ -144,15 +167,22 @@ pub async fn task_main(
|
||||
.await
|
||||
}
|
||||
},
|
||||
)))
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
));
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let mut conn = pin!(builder.serve_connection(io, service));
|
||||
let res = select! {
|
||||
_ = cancellation_token.cancelled() => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
res = conn.as_mut() => res,
|
||||
};
|
||||
if let Err(err) = res {
|
||||
tracing::error!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// await websocket connections
|
||||
ws_connections.wait().await;
|
||||
@@ -184,18 +214,14 @@ where
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
fn call(&self, req: Request<ReqBody>) -> Self::Future {
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: Request<Body>,
|
||||
mut request: Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
tls: &'static TlsConfig,
|
||||
conn_pool: Arc<conn_pool::GlobalConnPool>,
|
||||
@@ -205,7 +231,7 @@ async fn request_handler(
|
||||
sni_hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -264,7 +290,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Body::empty())
|
||||
.body(Full::new(Bytes::new()))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::Buf;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
use http_body::Body;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use hyper::{HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
@@ -235,10 +240,10 @@ pub async fn handle(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.request_timeout,
|
||||
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
|
||||
@@ -347,10 +352,10 @@ async fn handle_inner(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
) -> anyhow::Result<Response<Full<Bytes>>> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
.guard();
|
||||
@@ -406,8 +411,8 @@ async fn handle_inner(
|
||||
//
|
||||
// Read the query and query params from the request body
|
||||
//
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
let body = request.into_body().collect().await?.aggregate().reader();
|
||||
let payload: Payload = serde_json::from_reader(body)?;
|
||||
|
||||
let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
|
||||
|
||||
@@ -504,7 +509,7 @@ async fn handle_inner(
|
||||
let body = serde_json::to_string(&result).expect("json serialization should not fail");
|
||||
let len = body.len();
|
||||
let response = response
|
||||
.body(Body::from(body))
|
||||
.body(Full::from(body))
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
|
||||
@@ -235,18 +235,19 @@ async fn collect_metrics_iteration(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
net::TcpListener,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use anyhow::Error;
|
||||
use bytes::{Buf, Bytes};
|
||||
use chrono::Utc;
|
||||
use consumption_metrics::{Event, EventChunk};
|
||||
use hyper::{
|
||||
service::{make_service_fn, service_fn},
|
||||
Body, Response,
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper::{body::Incoming, service::service_fn, Response};
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server::conn,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use url::Url;
|
||||
|
||||
use super::{collect_metrics_iteration, Ids, Metrics};
|
||||
@@ -254,30 +255,43 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn metrics() {
|
||||
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
|
||||
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let reports = Arc::new(Mutex::new(vec![]));
|
||||
let reports2 = reports.clone();
|
||||
|
||||
let server = hyper::server::Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(make_service_fn(move |_| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
Ok::<_, Error>(service_fn(move |req| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
let bytes = hyper::body::to_bytes(req.into_body()).await?;
|
||||
let events: EventChunk<'static, Event<Ids, String>> =
|
||||
serde_json::from_slice(&bytes)?;
|
||||
reports.lock().unwrap().push(events);
|
||||
Ok::<_, Error>(Response::new(Body::from(vec![])))
|
||||
}
|
||||
}))
|
||||
}
|
||||
}));
|
||||
let addr = server.local_addr();
|
||||
tokio::spawn(server);
|
||||
let service = service_fn(move |req: hyper::Request<Incoming>| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
let bytes = req
|
||||
.into_body()
|
||||
.collect()
|
||||
.await
|
||||
.unwrap()
|
||||
.aggregate()
|
||||
.reader();
|
||||
let events: EventChunk<'static, Event<Ids, String>> =
|
||||
serde_json::from_reader(bytes)?;
|
||||
reports.lock().unwrap().push(events);
|
||||
Ok::<_, Error>(Response::new(Empty::<Bytes>::new()))
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let io = TokioIo::new(stream);
|
||||
let service = service.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let res = builder.serve_connection(io, service).await;
|
||||
if let Err(err) = res {
|
||||
println!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let metrics = Metrics::default();
|
||||
let client = http::new_client(RateLimiterConfig::default());
|
||||
|
||||
Reference in New Issue
Block a user