mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-18 02:42:56 +00:00
Compare commits
3 Commits
split-prox
...
hyper-1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a65a5c372b | ||
|
|
2cf85471f5 | ||
|
|
665f4ff4b5 |
498
Cargo.lock
generated
498
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
20
Cargo.toml
20
Cargo.toml
@@ -51,7 +51,7 @@ async-trait = "0.1"
|
||||
aws-config = { version = "1.0", default-features = false, features=["rustls"] }
|
||||
aws-sdk-s3 = "1.0"
|
||||
aws-smithy-async = { version = "1.0", default-features = false, features=["rt-tokio"] }
|
||||
aws-smithy-types = "1.0"
|
||||
aws-smithy-types = { version = "1.1.2", features = ["http-body-1-x"] }
|
||||
aws-credential-types = "1.0"
|
||||
axum = { version = "0.6.20", features = ["ws"] }
|
||||
base64 = "0.13.0"
|
||||
@@ -89,8 +89,12 @@ hostname = "0.3.1"
|
||||
http-types = { version = "2", default-features = false }
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.11"
|
||||
hyper = "1.0.0"
|
||||
hyper-util = "0.1.0"
|
||||
http = "1"
|
||||
http-body = "1"
|
||||
http-body-util = "0.1"
|
||||
hyper-tungstenite = "0.13.0"
|
||||
inotify = "0.10.2"
|
||||
ipnet = "2.9.0"
|
||||
itertools = "0.10"
|
||||
@@ -113,7 +117,7 @@ parquet_derive = "49.0.0"
|
||||
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
|
||||
pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
prost = "0.11"
|
||||
prost = "0.12"
|
||||
rand = "0.8"
|
||||
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
|
||||
regex = "1.10.2"
|
||||
@@ -121,7 +125,7 @@ reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"
|
||||
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }
|
||||
reqwest-middleware = "0.2.0"
|
||||
reqwest-retry = "0.2.2"
|
||||
routerify = "3"
|
||||
routerify = { git = "https://github.com/conradludgate/routerify", branch = "hyper1" }
|
||||
rpds = "0.13"
|
||||
rustc-hash = "1.1.0"
|
||||
rustls = "0.21"
|
||||
@@ -149,7 +153,7 @@ tar = "0.4"
|
||||
task-local-extensions = "0.1.4"
|
||||
test-context = "0.1"
|
||||
thiserror = "1.0"
|
||||
tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] }
|
||||
tls-listener = { version = "0.9", features = ["rustls", "tokio-net"] }
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-io-timeout = "1.2.0"
|
||||
tokio-postgres-rustls = "0.10.0"
|
||||
@@ -159,7 +163,7 @@ tokio-tar = "0.3"
|
||||
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
|
||||
toml = "0.7"
|
||||
toml_edit = "0.19"
|
||||
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
|
||||
tonic = {version = "0.10", features = ["tls", "tls-roots"]}
|
||||
tracing = "0.1"
|
||||
tracing-error = "0.2.0"
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
@@ -211,7 +215,7 @@ criterion = "0.5.1"
|
||||
rcgen = "0.11"
|
||||
rstest = "0.18"
|
||||
camino-tempfile = "1.0.2"
|
||||
tonic-build = "0.9"
|
||||
tonic-build = "0.10.2"
|
||||
|
||||
[patch.crates-io]
|
||||
|
||||
|
||||
@@ -12,7 +12,11 @@ cfg-if.workspace = true
|
||||
clap.workspace = true
|
||||
flate2.workspace = true
|
||||
futures.workspace = true
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
hyper-util = { workspace = true, features = ["tokio", "server", "server-auto"] }
|
||||
http = { workspace = true, features = [] }
|
||||
http-body = { workspace = true, features = [] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
nix.workspace = true
|
||||
notify.workspace = true
|
||||
num_cpus.workspace = true
|
||||
|
||||
@@ -6,14 +6,22 @@ use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
|
||||
use bytes::Bytes;
|
||||
use compute_api::requests::ConfigurationRequest;
|
||||
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
|
||||
|
||||
use anyhow::Result;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use hyper_util::server::conn;
|
||||
use num_cpus;
|
||||
use serde_json;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task;
|
||||
use tracing::{error, info, warn};
|
||||
use tracing_utils::http::OtelName;
|
||||
@@ -36,7 +44,7 @@ fn status_response_from_state(state: &ComputeState) -> ComputeStatusResponse {
|
||||
}
|
||||
|
||||
// Service function to handle all available routes.
|
||||
async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body> {
|
||||
async fn routes(req: Request<Incoming>, compute: &Arc<ComputeNode>) -> Response<Full<Bytes>> {
|
||||
//
|
||||
// NOTE: The URI path is currently included in traces. That's OK because
|
||||
// it doesn't contain any variable parts or sensitive information. But
|
||||
@@ -48,7 +56,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
info!("serving /status GET request");
|
||||
let state = compute.state.lock().unwrap();
|
||||
let status_response = status_response_from_state(&state);
|
||||
Response::new(Body::from(serde_json::to_string(&status_response).unwrap()))
|
||||
Response::new(Full::from(serde_json::to_string(&status_response).unwrap()))
|
||||
}
|
||||
|
||||
// Startup metrics in JSON format. Keep /metrics reserved for a possible
|
||||
@@ -56,7 +64,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
(&Method::GET, "/metrics.json") => {
|
||||
info!("serving /metrics.json GET request");
|
||||
let metrics = compute.state.lock().unwrap().metrics.clone();
|
||||
Response::new(Body::from(serde_json::to_string(&metrics).unwrap()))
|
||||
Response::new(Full::from(serde_json::to_string(&metrics).unwrap()))
|
||||
}
|
||||
|
||||
// Collect Postgres current usage insights
|
||||
@@ -66,11 +74,11 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!("compute is not running, current status: {:?}", status);
|
||||
error!(msg);
|
||||
return Response::new(Body::from(msg));
|
||||
return Response::new(Full::from(msg));
|
||||
}
|
||||
|
||||
let insights = compute.collect_insights().await;
|
||||
Response::new(Body::from(insights))
|
||||
Response::new(Full::from(insights))
|
||||
}
|
||||
|
||||
(&Method::POST, "/check_writability") => {
|
||||
@@ -82,15 +90,15 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
return Response::new(Body::from(msg));
|
||||
return Response::new(Full::from(msg));
|
||||
}
|
||||
|
||||
let res = crate::checker::check_writability(compute).await;
|
||||
match res {
|
||||
Ok(_) => Response::new(Body::from("true")),
|
||||
Ok(_) => Response::new(Full::from("true")),
|
||||
Err(e) => {
|
||||
error!("check_writability failed: {}", e);
|
||||
Response::new(Body::from(e.to_string()))
|
||||
Response::new(Full::from(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -98,7 +106,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
(&Method::GET, "/info") => {
|
||||
let num_cpus = num_cpus::get_physical();
|
||||
info!("serving /info GET request. num_cpus: {}", num_cpus);
|
||||
Response::new(Body::from(
|
||||
Response::new(Full::from(
|
||||
serde_json::json!({
|
||||
"num_cpus": num_cpus,
|
||||
})
|
||||
@@ -115,7 +123,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
(&Method::POST, "/configure") => {
|
||||
info!("serving /configure POST request");
|
||||
match handle_configure_request(req, compute).await {
|
||||
Ok(msg) => Response::new(Body::from(msg)),
|
||||
Ok(msg) => Response::new(Full::from(msg)),
|
||||
Err((msg, code)) => {
|
||||
error!("error handling /configure request: {msg}");
|
||||
render_json_error(&msg, code)
|
||||
@@ -132,7 +140,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
// if no remote storage is configured
|
||||
if compute.ext_remote_storage.is_none() {
|
||||
info!("no extensions remote storage configured");
|
||||
let mut resp = Response::new(Body::from("no remote storage configured"));
|
||||
let mut resp = Response::new(Full::from("no remote storage configured"));
|
||||
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return resp;
|
||||
}
|
||||
@@ -143,7 +151,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
if params == "is_library=true" {
|
||||
is_library = true;
|
||||
} else {
|
||||
let mut resp = Response::new(Body::from("Wrong request parameters"));
|
||||
let mut resp = Response::new(Full::from("Wrong request parameters"));
|
||||
*resp.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return resp;
|
||||
}
|
||||
@@ -165,7 +173,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
Some(r) => r,
|
||||
None => {
|
||||
info!("no remote extensions spec was provided");
|
||||
let mut resp = Response::new(Body::from("no remote storage configured"));
|
||||
let mut resp = Response::new(Full::from("no remote storage configured"));
|
||||
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return resp;
|
||||
}
|
||||
@@ -182,10 +190,10 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
match ext {
|
||||
Ok((ext_name, ext_path)) => {
|
||||
match compute.download_extension(ext_name, ext_path).await {
|
||||
Ok(_) => Response::new(Body::from("OK")),
|
||||
Ok(_) => Response::new(Full::from("OK")),
|
||||
Err(e) => {
|
||||
error!("extension download failed: {}", e);
|
||||
let mut resp = Response::new(Body::from(e.to_string()));
|
||||
let mut resp = Response::new(Full::from(e.to_string()));
|
||||
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
resp
|
||||
}
|
||||
@@ -193,7 +201,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("extension download failed to find extension: {}", e);
|
||||
let mut resp = Response::new(Body::from("failed to find file"));
|
||||
let mut resp = Response::new(Full::from("failed to find file"));
|
||||
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
resp
|
||||
}
|
||||
@@ -202,7 +210,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
|
||||
// Return the `404 Not Found` for any other routes.
|
||||
_ => {
|
||||
let mut not_found = Response::new(Body::from("404 Not Found"));
|
||||
let mut not_found = Response::new(Full::from("404 Not Found"));
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
not_found
|
||||
}
|
||||
@@ -210,7 +218,7 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
}
|
||||
|
||||
async fn handle_configure_request(
|
||||
req: Request<Body>,
|
||||
req: Request<Incoming>,
|
||||
compute: &Arc<ComputeNode>,
|
||||
) -> Result<String, (String, StatusCode)> {
|
||||
if !compute.live_config_allowed {
|
||||
@@ -220,7 +228,7 @@ async fn handle_configure_request(
|
||||
));
|
||||
}
|
||||
|
||||
let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
|
||||
let body_bytes = req.into_body().collect().await.unwrap().to_bytes();
|
||||
let spec_raw = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
if let Ok(request) = serde_json::from_str::<ConfigurationRequest>(&spec_raw) {
|
||||
let spec = request.spec;
|
||||
@@ -287,13 +295,13 @@ async fn handle_configure_request(
|
||||
}
|
||||
}
|
||||
|
||||
fn render_json_error(e: &str, status: StatusCode) -> Response<Body> {
|
||||
fn render_json_error(e: &str, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
let error = GenericAPIError {
|
||||
error: e.to_string(),
|
||||
};
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.body(Body::from(serde_json::to_string(&error).unwrap()))
|
||||
.body(Full::from(serde_json::to_string(&error).unwrap()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -304,35 +312,43 @@ async fn serve(port: u16, state: Arc<ComputeNode>) {
|
||||
// see e.g. https://github.com/rust-lang/rust/pull/34440
|
||||
let addr = SocketAddr::new(IpAddr::from(Ipv6Addr::UNSPECIFIED), port);
|
||||
|
||||
let make_service = make_service_fn(move |_conn| {
|
||||
let service = service_fn(move |req: Request<Incoming>| {
|
||||
let state = state.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
|
||||
let state = state.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(
|
||||
// NOTE: We include the URI path in the string. It
|
||||
// doesn't contain any variable parts or sensitive
|
||||
// information in this API.
|
||||
tracing_utils::http::tracing_handler(
|
||||
req,
|
||||
|req| routes(req, &state),
|
||||
OtelName::UriPath,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
}))
|
||||
Ok::<_, Infallible>(
|
||||
// NOTE: We include the URI path in the string. It
|
||||
// doesn't contain any variable parts or sensitive
|
||||
// information in this API.
|
||||
tracing_utils::http::tracing_handler(
|
||||
req,
|
||||
|req| routes(req, &state),
|
||||
OtelName::UriPath,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
info!("starting HTTP server on {}", addr);
|
||||
|
||||
let server = Server::bind(&addr).serve(make_service);
|
||||
|
||||
// Run this server forever
|
||||
if let Err(e) = server.await {
|
||||
error!("server error: {}", e);
|
||||
let listener = TcpListener::bind(addr).await.unwrap();
|
||||
loop {
|
||||
let (stream, _) = match listener.accept().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("server error: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let io = TokioIo::new(stream);
|
||||
let service = service.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let res = builder.serve_connection(io, service).await;
|
||||
if let Err(err) = res {
|
||||
println!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::{background_process, local_env::LocalEnv};
|
||||
use camino::Utf8PathBuf;
|
||||
use hyper::Method;
|
||||
use pageserver_api::{
|
||||
models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo},
|
||||
shard::TenantShardId,
|
||||
@@ -8,6 +7,7 @@ use pageserver_api::{
|
||||
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::parse_host_port;
|
||||
use reqwest::Method;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::{path::PathBuf, process::Child, str::FromStr};
|
||||
use tracing::instrument;
|
||||
@@ -278,7 +278,7 @@ impl AttachmentService {
|
||||
/// Simple HTTP request wrapper for calling into attachment service
|
||||
async fn dispatch<RQ, RS>(
|
||||
&self,
|
||||
method: hyper::Method,
|
||||
method: reqwest::Method,
|
||||
path: String,
|
||||
body: Option<RQ>,
|
||||
) -> anyhow::Result<RS>
|
||||
|
||||
@@ -15,7 +15,11 @@ aws-sdk-s3.workspace = true
|
||||
aws-credential-types.workspace = true
|
||||
bytes.workspace = true
|
||||
camino.workspace = true
|
||||
hyper = { workspace = true, features = ["stream"] }
|
||||
hyper = { workspace = true, features = [] }
|
||||
hyper-util = { workspace = true, features = [] }
|
||||
http = { workspace = true, features = [] }
|
||||
http-body = { workspace = true, features = [] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
futures.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -36,7 +36,9 @@ use aws_smithy_types::body::SdkBody;
|
||||
use aws_smithy_types::byte_stream::ByteStream;
|
||||
use bytes::Bytes;
|
||||
use futures::stream::Stream;
|
||||
use hyper::Body;
|
||||
use futures_util::TryStreamExt;
|
||||
use http_body::Frame;
|
||||
use http_body_util::StreamBody;
|
||||
use scopeguard::ScopeGuard;
|
||||
|
||||
use super::StorageMetadata;
|
||||
@@ -469,8 +471,8 @@ impl RemoteStorage for S3Bucket {
|
||||
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let body = Body::wrap_stream(from);
|
||||
let bytes_stream = ByteStream::new(SdkBody::from_body_0_4(body));
|
||||
let body = StreamBody::new(from.map_ok(Frame::data));
|
||||
let bytes_stream = ByteStream::new(SdkBody::from_body_1_x(body));
|
||||
|
||||
let res = self
|
||||
.client
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
//! Tracing wrapper for Hyper HTTP server
|
||||
|
||||
use hyper::body::Body;
|
||||
use hyper::HeaderMap;
|
||||
use hyper::{Body, Request, Response};
|
||||
use hyper::{Request, Response};
|
||||
use std::future::Future;
|
||||
use tracing::Instrument;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
@@ -35,14 +36,14 @@ pub enum OtelName<'a> {
|
||||
/// instrumentation libraries at:
|
||||
/// <https://opentelemetry.io/registry/?language=rust&component=instrumentation>
|
||||
/// If a Hyper crate appears, consider switching to that.
|
||||
pub async fn tracing_handler<F, R>(
|
||||
req: Request<Body>,
|
||||
pub async fn tracing_handler<B1: Body, B2: Body, F, R>(
|
||||
req: Request<B1>,
|
||||
handler: F,
|
||||
otel_name: OtelName<'_>,
|
||||
) -> Response<Body>
|
||||
) -> Response<B2>
|
||||
where
|
||||
F: Fn(Request<Body>) -> R,
|
||||
R: Future<Output = Response<Body>>,
|
||||
F: Fn(Request<B1>) -> R,
|
||||
R: Future<Output = Response<B2>>,
|
||||
{
|
||||
// Create a tracing span, with context propagated from the incoming
|
||||
// request if any.
|
||||
|
||||
@@ -22,6 +22,7 @@ chrono.workspace = true
|
||||
heapless.workspace = true
|
||||
hex = { workspace = true, features = ["serde"] }
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
fail.workspace = true
|
||||
futures = { workspace = true}
|
||||
jsonwebtoken.workspace = true
|
||||
|
||||
@@ -4,7 +4,10 @@ use crate::http::{
|
||||
error::ApiError,
|
||||
json::{json_request, json_response},
|
||||
};
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use routerify::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::*;
|
||||
@@ -151,7 +154,7 @@ pub struct FailpointConfig {
|
||||
pub async fn failpoints_handler(
|
||||
mut request: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
if !fail::has_failpoints() {
|
||||
return Err(ApiError::BadRequest(anyhow::anyhow!(
|
||||
"Cannot manage failpoints because storage was compiled without failpoints support"
|
||||
|
||||
@@ -4,11 +4,11 @@ use anyhow::Context;
|
||||
use hyper::header::{HeaderName, AUTHORIZATION};
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Method;
|
||||
use hyper::{header::CONTENT_TYPE, Body, Request, Response};
|
||||
use hyper::{header::CONTENT_TYPE, Request, Response};
|
||||
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
|
||||
use once_cell::sync::Lazy;
|
||||
use routerify::ext::RequestExt;
|
||||
use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
|
||||
use routerify::{Body, Middleware, RequestInfo, Router, RouterBuilder};
|
||||
use tracing::{self, debug, info, info_span, warn, Instrument};
|
||||
|
||||
use std::future::Future;
|
||||
@@ -238,7 +238,7 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
|
||||
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
|
||||
let body = Body::wrap_stream(ReceiverStream::new(rx));
|
||||
let body = Body::from_stream(ReceiverStream::new(rx));
|
||||
|
||||
let mut writer = ChannelWriter::new(128 * 1024, tx);
|
||||
|
||||
@@ -284,7 +284,7 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
pub fn add_request_id_middleware<B: hyper::body::Body + Send + Sync + 'static>(
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
|
||||
@@ -317,7 +317,7 @@ async fn add_request_id_header_to_response(
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
pub fn make_router() -> RouterBuilder<routerify::Body, ApiError> {
|
||||
Router::builder()
|
||||
.middleware(add_request_id_middleware())
|
||||
.middleware(Middleware::post_with_info(
|
||||
@@ -328,11 +328,11 @@ pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
}
|
||||
|
||||
pub fn attach_openapi_ui(
|
||||
router_builder: RouterBuilder<hyper::Body, ApiError>,
|
||||
router_builder: RouterBuilder<routerify::Body, ApiError>,
|
||||
spec: &'static [u8],
|
||||
spec_mount_path: &'static str,
|
||||
ui_mount_path: &'static str,
|
||||
) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
) -> RouterBuilder<routerify::Body, ApiError> {
|
||||
router_builder
|
||||
.get(spec_mount_path,
|
||||
move |r| request_span(r, move |_| async move {
|
||||
@@ -388,7 +388,7 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
pub fn auth_middleware<B: hyper::body::Body + Send + Sync + 'static>(
|
||||
provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
@@ -423,7 +423,7 @@ pub fn add_response_header_middleware<B>(
|
||||
value: &str,
|
||||
) -> anyhow::Result<Middleware<B, ApiError>>
|
||||
where
|
||||
B: hyper::body::HttpBody + Send + Sync + 'static,
|
||||
B: hyper::body::Body + Send + Sync + 'static,
|
||||
{
|
||||
let name =
|
||||
HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
|
||||
@@ -464,7 +464,6 @@ pub fn check_permission_with(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::future::poll_fn;
|
||||
use hyper::service::Service;
|
||||
use routerify::RequestServiceBuilder;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
@@ -473,16 +472,13 @@ mod tests {
|
||||
async fn test_request_id_returned() {
|
||||
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
|
||||
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
|
||||
let mut service = builder.build(remote_addr);
|
||||
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
|
||||
panic!("request service is not ready: {:?}", e);
|
||||
}
|
||||
let service = builder.build(remote_addr);
|
||||
|
||||
let mut req: Request<Body> = Request::default();
|
||||
req.headers_mut()
|
||||
.append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
|
||||
|
||||
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
|
||||
let resp: Response<Body> = service.call(req).await.unwrap();
|
||||
|
||||
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
|
||||
|
||||
@@ -493,13 +489,10 @@ mod tests {
|
||||
async fn test_request_id_empty() {
|
||||
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
|
||||
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
|
||||
let mut service = builder.build(remote_addr);
|
||||
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
|
||||
panic!("request service is not ready: {:?}", e);
|
||||
}
|
||||
let service = builder.build(remote_addr);
|
||||
|
||||
let req: Request<Body> = Request::default();
|
||||
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
|
||||
let resp: Response<Body> = service.call(req).await.unwrap();
|
||||
|
||||
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use hyper::{header, Body, Response, StatusCode};
|
||||
use hyper::{header, Response, StatusCode};
|
||||
use routerify::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::error::Error as StdError;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use anyhow::Context;
|
||||
use bytes::Buf;
|
||||
use hyper::{header, Body, Request, Response, StatusCode};
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::{Buf, Bytes};
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::{header, Request, Response, StatusCode};
|
||||
use routerify::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::error::ApiError;
|
||||
@@ -18,10 +20,14 @@ pub async fn json_request<T: for<'de> Deserialize<'de>>(
|
||||
pub async fn json_request_or_empty_body<T: for<'de> Deserialize<'de>>(
|
||||
request: &mut Request<Body>,
|
||||
) -> Result<Option<T>, ApiError> {
|
||||
let body = hyper::body::aggregate(request.body_mut())
|
||||
let body = request
|
||||
.body_mut()
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| anyhow!(e))
|
||||
.context("Failed to read request body")
|
||||
.map_err(ApiError::BadRequest)?;
|
||||
.map_err(ApiError::BadRequest)?
|
||||
.aggregate();
|
||||
if body.remaining() == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
@@ -35,17 +41,24 @@ pub async fn json_request_or_empty_body<T: for<'de> Deserialize<'de>>(
|
||||
.map_err(ApiError::BadRequest)
|
||||
}
|
||||
|
||||
pub fn json_response<T: Serialize>(
|
||||
pub fn json_response_body<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
json_response(status, data).map(|r| r.map(Body::new))
|
||||
}
|
||||
|
||||
pub fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(json))
|
||||
.body(Full::from(json))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -5,4 +5,4 @@ pub mod request;
|
||||
|
||||
/// Current fast way to apply simple http routing in various Neon binaries.
|
||||
/// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed.
|
||||
pub use routerify::{ext::RequestExt, RouterBuilder, RouterService};
|
||||
pub use routerify::{ext::RequestExt, Body, RequestServiceBuilder, RouterBuilder};
|
||||
|
||||
@@ -3,8 +3,9 @@ use std::{borrow::Cow, str::FromStr};
|
||||
|
||||
use super::error::ApiError;
|
||||
use anyhow::anyhow;
|
||||
use hyper::{body::HttpBody, Body, Request};
|
||||
use routerify::ext::RequestExt;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::Request;
|
||||
use routerify::{ext::RequestExt, Body};
|
||||
|
||||
pub fn get_request_param<'a>(
|
||||
request: &'a Request<Body>,
|
||||
@@ -75,7 +76,7 @@ pub fn parse_query_param<E: fmt::Display, T: FromStr<Err = E>>(
|
||||
}
|
||||
|
||||
pub async fn ensure_no_body(request: &mut Request<Body>) -> Result<(), ApiError> {
|
||||
match request.body_mut().data().await {
|
||||
match request.body_mut().frame().await {
|
||||
Some(_) => Err(ApiError::BadRequest(anyhow!("Unexpected request body"))),
|
||||
None => Ok(()),
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use futures::TryFutureExt;
|
||||
use humantime::format_rfc3339;
|
||||
use hyper::header;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, Request, Response, Uri};
|
||||
use hyper::{Request, Response, Uri};
|
||||
use metrics::launch_timestamp::LaunchTimestamp;
|
||||
use pageserver_api::models::LocationConfigListResponse;
|
||||
use pageserver_api::models::ShardParameters;
|
||||
@@ -32,6 +32,7 @@ use utils::failpoint_support::failpoints_handler;
|
||||
use utils::http::endpoint::request_span;
|
||||
use utils::http::json::json_request_or_empty_body;
|
||||
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
|
||||
use utils::http::Body;
|
||||
|
||||
use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::deletion_queue::DeletionQueueClient;
|
||||
@@ -64,7 +65,7 @@ use utils::{
|
||||
http::{
|
||||
endpoint::{self, attach_openapi_ui, auth_middleware, check_permission_with},
|
||||
error::{ApiError, HttpErrorBody},
|
||||
json::{json_request, json_response},
|
||||
json::{json_request, json_response_body as json_response},
|
||||
request::parse_request_param,
|
||||
RequestExt, RouterBuilder,
|
||||
},
|
||||
@@ -1571,7 +1572,7 @@ async fn getpage_at_lsn_handler(
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.body(hyper::Body::from(page))
|
||||
.body(Body::from(page))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
@@ -1868,7 +1869,7 @@ pub fn make_router(
|
||||
state: Arc<State>,
|
||||
launch_ts: &'static LaunchTimestamp,
|
||||
auth: Option<Arc<SwappableJwtAuth>>,
|
||||
) -> anyhow::Result<RouterBuilder<hyper::Body, ApiError>> {
|
||||
) -> anyhow::Result<RouterBuilder<Body, ApiError>> {
|
||||
let spec = include_bytes!("openapi_spec.yml");
|
||||
let mut router = attach_openapi_ui(endpoint::make_router(), spec, "/swagger.yml", "/v1/doc");
|
||||
if auth.is_some() {
|
||||
|
||||
@@ -28,7 +28,11 @@ hmac.workspace = true
|
||||
hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
hyper-util = { workspace = true, features = ["tokio", "server", "server-auto"] }
|
||||
http = { workspace = true, features = [] }
|
||||
http-body = { workspace = true, features = [] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
md5.workspace = true
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
pub use reqwest::{Request, Response, StatusCode};
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time::Instant;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl};
|
||||
use reqwest_middleware::RequestBuilder;
|
||||
@@ -21,8 +19,6 @@ use reqwest_middleware::RequestBuilder;
|
||||
/// We deliberately don't want to replace this with a public static.
|
||||
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.build()
|
||||
.expect("Failed to create http client");
|
||||
|
||||
@@ -34,8 +30,6 @@ pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> Clien
|
||||
|
||||
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
|
||||
let timeout_client = reqwest::ClientBuilder::new()
|
||||
.dns_resolver(Arc::new(GaiResolver::default()))
|
||||
.connection_verbose(true)
|
||||
.timeout(default_timout)
|
||||
.build()
|
||||
.expect("Failed to create http client with timeout");
|
||||
@@ -100,37 +94,6 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
/// https://docs.rs/reqwest/0.11.18/src/reqwest/dns/gai.rs.html
|
||||
use hyper::{
|
||||
client::connect::dns::{GaiResolver as HyperGaiResolver, Name},
|
||||
service::Service,
|
||||
};
|
||||
use reqwest::dns::{Addrs, Resolve, Resolving};
|
||||
#[derive(Debug)]
|
||||
pub struct GaiResolver(HyperGaiResolver);
|
||||
|
||||
impl Default for GaiResolver {
|
||||
fn default() -> Self {
|
||||
Self(HyperGaiResolver::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for GaiResolver {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let this = &mut self.0.clone();
|
||||
let start = Instant::now();
|
||||
Box::pin(
|
||||
Service::<Name>::call(this, name.clone()).map(move |result| {
|
||||
let resolve_duration = start.elapsed();
|
||||
trace!(duration = ?resolve_duration, addr = %name, "resolve host complete");
|
||||
result
|
||||
.map(|addrs| -> Addrs { Box::new(addrs) })
|
||||
.map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { Box::new(err) })
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
use anyhow::{anyhow, bail};
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use anyhow::anyhow;
|
||||
use http::{Request, Response};
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server::conn,
|
||||
};
|
||||
use std::{convert::Infallible, net::TcpListener};
|
||||
use tracing::info;
|
||||
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
|
||||
use utils::http::{
|
||||
endpoint, error::ApiError, json::json_response, Body, RequestServiceBuilder, RouterBuilder,
|
||||
};
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(StatusCode::OK, "")
|
||||
json_response(StatusCode::OK, "").map(|req| req.map(Body::new))
|
||||
}
|
||||
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
fn make_router() -> RouterBuilder<Body, ApiError> {
|
||||
endpoint::make_router().get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
@@ -17,11 +24,20 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible>
|
||||
info!("http has shut down");
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router().build()?);
|
||||
let router = make_router().build().map_err(|e| anyhow!(e))?;
|
||||
let builder = RequestServiceBuilder::new(router).map_err(|e| anyhow!(e))?;
|
||||
let listener = tokio::net::TcpListener::from_std(http_listener)?;
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
|
||||
bail!("hyper server without shutdown handling cannot shutdown successfully");
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await.unwrap();
|
||||
let io = TokioIo::new(stream);
|
||||
let service = builder.build(remote_addr);
|
||||
tokio::task::spawn(async move {
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let res = builder.serve_connection(io, service).await;
|
||||
if let Err(err) = res {
|
||||
println!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,13 +10,15 @@ use std::{
|
||||
};
|
||||
|
||||
use bytes::{Buf, BytesMut};
|
||||
use hyper::server::conn::{AddrIncoming, AddrStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use tls_listener::AsyncAccept;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf},
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
|
||||
pub struct ProxyProtocolAccept {
|
||||
pub incoming: AddrIncoming,
|
||||
pub incoming: TcpListener,
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
@@ -327,20 +329,18 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
|
||||
}
|
||||
|
||||
impl AsyncAccept for ProxyProtocolAccept {
|
||||
type Connection = WithClientIp<AddrStream>;
|
||||
type Connection = WithClientIp<TcpStream>;
|
||||
|
||||
type Address = SocketAddr;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
|
||||
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
|
||||
let Some(conn) = conn else {
|
||||
return Poll::Ready(None);
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(WithClientIp::new(conn))))
|
||||
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
|
||||
Pin::new(&mut self.incoming)
|
||||
.poll_accept(cx)
|
||||
.map_ok(|(c, a)| (WithClientIp::new(c), a))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ use crate::{
|
||||
proxy::retry::{retry_after, ShouldRetry},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use hyper::StatusCode;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use reqwest::StatusCode;
|
||||
use std::ops::ControlFlow;
|
||||
use tokio::time;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
@@ -6,41 +6,54 @@ mod conn_pool;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
use bytes::Bytes;
|
||||
pub use conn_pool::GlobalConnPoolOptions;
|
||||
|
||||
use anyhow::bail;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::StatusCode;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server::conn;
|
||||
use metrics::IntCounterPairGuard;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::select;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::protocol2::ProxyProtocolAccept;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
server::{
|
||||
accept,
|
||||
conn::{AddrIncoming, AddrStream},
|
||||
},
|
||||
Body, Method, Request, Response,
|
||||
};
|
||||
use hyper::{Method, Request, Response};
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::task::Poll;
|
||||
use std::{future::ready, sync::Arc};
|
||||
use tls_listener::TlsListener;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tls_listener::{AsyncTls, TlsListener};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Tls(TlsAcceptor);
|
||||
|
||||
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncTls<C> for Tls {
|
||||
type Stream = tokio_rustls::server::TlsStream<C>;
|
||||
type Error = std::io::Error;
|
||||
type AcceptFuture = tokio_rustls::Accept<C>;
|
||||
|
||||
fn accept(&self, conn: C) -> Self::AcceptFuture {
|
||||
tokio_rustls::TlsAcceptor::accept(&self.0, conn)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
@@ -79,42 +92,52 @@ pub async fn task_main(
|
||||
};
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
// let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
// let _ = addr_incoming.set_nodelay(true);
|
||||
let addr_incoming = ProxyProtocolAccept {
|
||||
incoming: addr_incoming,
|
||||
incoming: ws_listener,
|
||||
};
|
||||
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let ws_connections2 = ws_connections.clone();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {err:?}");
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
let mut tls_listener = TlsListener::new(Tls(tls_acceptor), addr_incoming);
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, remote_addr) = select! {
|
||||
res = tls_listener.accept() => {
|
||||
match res {
|
||||
Err(err) =>
|
||||
{error!("failed to accept TLS connection for websockets: {err:?}"); continue},
|
||||
Ok(s) => s,
|
||||
}
|
||||
}
|
||||
_ = cancellation_token.cancelled() => break,
|
||||
};
|
||||
let (io, tls) = stream.get_ref();
|
||||
let client_addr = io.client_addr();
|
||||
let remote_addr = io.inner.remote_addr();
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let ws_connections = ws_connections2.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => bail!("missing required client ip"),
|
||||
None => remote_addr,
|
||||
};
|
||||
Ok(MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Body>| {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => {
|
||||
tracing::error!("Error serving connection: missing required client ip");
|
||||
continue;
|
||||
}
|
||||
None => remote_addr,
|
||||
};
|
||||
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let service = MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Incoming>| {
|
||||
let sni_name = sni_name.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
@@ -144,15 +167,22 @@ pub async fn task_main(
|
||||
.await
|
||||
}
|
||||
},
|
||||
)))
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
));
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let mut conn = pin!(builder.serve_connection(io, service));
|
||||
let res = select! {
|
||||
_ = cancellation_token.cancelled() => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
res = conn.as_mut() => res,
|
||||
};
|
||||
if let Err(err) = res {
|
||||
tracing::error!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// await websocket connections
|
||||
ws_connections.wait().await;
|
||||
@@ -184,18 +214,14 @@ where
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
fn call(&self, req: Request<ReqBody>) -> Self::Future {
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: Request<Body>,
|
||||
mut request: Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
tls: &'static TlsConfig,
|
||||
conn_pool: Arc<conn_pool::GlobalConnPool>,
|
||||
@@ -205,7 +231,7 @@ async fn request_handler(
|
||||
sni_hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -264,7 +290,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Body::empty())
|
||||
.body(Full::new(Bytes::new()))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::Buf;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
use http_body::Body;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Response;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use hyper::{HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
@@ -235,10 +240,10 @@ pub async fn handle(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.request_timeout,
|
||||
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
|
||||
@@ -347,10 +352,10 @@ async fn handle_inner(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
request: Request<Incoming>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
) -> anyhow::Result<Response<Full<Bytes>>> {
|
||||
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
.guard();
|
||||
@@ -406,8 +411,8 @@ async fn handle_inner(
|
||||
//
|
||||
// Read the query and query params from the request body
|
||||
//
|
||||
let body = hyper::body::to_bytes(request.into_body()).await?;
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
let body = request.into_body().collect().await?.aggregate().reader();
|
||||
let payload: Payload = serde_json::from_reader(body)?;
|
||||
|
||||
let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
|
||||
|
||||
@@ -504,7 +509,7 @@ async fn handle_inner(
|
||||
let body = serde_json::to_string(&result).expect("json serialization should not fail");
|
||||
let len = body.len();
|
||||
let response = response
|
||||
.body(Body::from(body))
|
||||
.body(Full::from(body))
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
|
||||
@@ -235,18 +235,19 @@ async fn collect_metrics_iteration(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
net::TcpListener,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use anyhow::Error;
|
||||
use bytes::{Buf, Bytes};
|
||||
use chrono::Utc;
|
||||
use consumption_metrics::{Event, EventChunk};
|
||||
use hyper::{
|
||||
service::{make_service_fn, service_fn},
|
||||
Body, Response,
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper::{body::Incoming, service::service_fn, Response};
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server::conn,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use url::Url;
|
||||
|
||||
use super::{collect_metrics_iteration, Ids, Metrics};
|
||||
@@ -254,30 +255,43 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn metrics() {
|
||||
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
|
||||
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let reports = Arc::new(Mutex::new(vec![]));
|
||||
let reports2 = reports.clone();
|
||||
|
||||
let server = hyper::server::Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(make_service_fn(move |_| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
Ok::<_, Error>(service_fn(move |req| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
let bytes = hyper::body::to_bytes(req.into_body()).await?;
|
||||
let events: EventChunk<'static, Event<Ids, String>> =
|
||||
serde_json::from_slice(&bytes)?;
|
||||
reports.lock().unwrap().push(events);
|
||||
Ok::<_, Error>(Response::new(Body::from(vec![])))
|
||||
}
|
||||
}))
|
||||
}
|
||||
}));
|
||||
let addr = server.local_addr();
|
||||
tokio::spawn(server);
|
||||
let service = service_fn(move |req: hyper::Request<Incoming>| {
|
||||
let reports = reports.clone();
|
||||
async move {
|
||||
let bytes = req
|
||||
.into_body()
|
||||
.collect()
|
||||
.await
|
||||
.unwrap()
|
||||
.aggregate()
|
||||
.reader();
|
||||
let events: EventChunk<'static, Event<Ids, String>> =
|
||||
serde_json::from_reader(bytes)?;
|
||||
reports.lock().unwrap().push(events);
|
||||
Ok::<_, Error>(Response::new(Empty::<Bytes>::new()))
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let io = TokioIo::new(stream);
|
||||
let service = service.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new());
|
||||
let res = builder.serve_connection(io, service).await;
|
||||
if let Err(err) = res {
|
||||
println!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let metrics = Metrics::default();
|
||||
let client = http::new_client(RateLimiterConfig::default());
|
||||
|
||||
@@ -18,7 +18,11 @@ futures-core.workspace = true
|
||||
futures-util.workspace = true
|
||||
git-version.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
hyper-util = { workspace = true, features = ["tokio", "server", "server-auto"] }
|
||||
http = { workspace = true, features = [] }
|
||||
http-body = { workspace = true, features = [] }
|
||||
http-body-util = { workspace = true, features = [] }
|
||||
once_cell.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prost.workspace = true
|
||||
@@ -29,6 +33,9 @@ tracing.workspace = true
|
||||
metrics.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
# needed for tonic
|
||||
http0_2 = { package = "http", version = "0.2" }
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
|
||||
@@ -13,10 +13,14 @@
|
||||
use clap::{command, Parser};
|
||||
use futures_core::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use http::Request;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::CONTENT_TYPE;
|
||||
use hyper::server::conn::AddrStream;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, StatusCode};
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server::conn;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
@@ -24,6 +28,7 @@ use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::broadcast::error::RecvError;
|
||||
use tokio::time;
|
||||
@@ -596,9 +601,7 @@ impl BrokerService for Broker {
|
||||
}
|
||||
|
||||
// We serve only metrics and healthcheck through http1.
|
||||
async fn http1_handler(
|
||||
req: hyper::Request<hyper::body::Body>,
|
||||
) -> Result<hyper::Response<Body>, Infallible> {
|
||||
async fn http1_handler(req: hyper::Request<Body>) -> Result<hyper::Response<Body>, Infallible> {
|
||||
let resp = match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/metrics") => {
|
||||
let mut buffer = vec![];
|
||||
@@ -662,16 +665,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let storage_broker_server = BrokerServiceServer::new(storage_broker_impl);
|
||||
|
||||
info!("listening on {}", &args.listen_addr);
|
||||
let listener = TcpListener::bind(args.listen_addr).await?;
|
||||
|
||||
// grpc is served along with http1 for metrics on a single port, hence we
|
||||
// don't use tonic's Server.
|
||||
hyper::Server::bind(&args.listen_addr)
|
||||
.http2_keep_alive_interval(Some(args.http2_keepalive_interval))
|
||||
.serve(make_service_fn(move |conn: &AddrStream| {
|
||||
loop {
|
||||
let (stream, remote_addr) = listener.accept().await?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let storage_broker_server_cloned = storage_broker_server.clone();
|
||||
let connect_info = conn.connect_info();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |mut req| {
|
||||
let service = async move {
|
||||
Ok::<_, Infallible>(service_fn(move |mut req: Request<Incoming>| {
|
||||
// That's what tonic's MakeSvc.call does to pass conninfo to
|
||||
// the request handler (and where its request.remote_addr()
|
||||
// expects it to find).
|
||||
@@ -690,6 +696,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
if req.headers().get("content-type").map(|x| x.as_bytes())
|
||||
== Some(b"application/grpc")
|
||||
{
|
||||
// TODO: this doesn't work :(
|
||||
let (parts, body) = req.into_parts();
|
||||
let req = http0_2::Request::from_parts(parts, body);
|
||||
let res_resp = storage_broker_server_svc.call(req).await;
|
||||
// Grpc and http1 handlers have slightly different
|
||||
// Response types: it is UnsyncBoxBody for the
|
||||
@@ -703,10 +712,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
}
|
||||
}))
|
||||
};
|
||||
|
||||
let builder = conn::auto::Builder::new(TokioExecutor::new())
|
||||
.http2()
|
||||
.keep_alive_interval(Some(args.http2_keepalive_interval));
|
||||
|
||||
if let Err(err) = builder.serve_connection(io, service).await {
|
||||
tracing::error!("Error serving connection: {:?}", err);
|
||||
}
|
||||
}))
|
||||
.await?;
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use hyper::body::HttpBody;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tonic::codegen::StdError;
|
||||
pub use tonic::transport::Uri;
|
||||
use tonic::transport::{ClientTlsConfig, Endpoint};
|
||||
use tonic::{transport::Channel, Status};
|
||||
use utils::id::{TenantId, TenantTimelineId, TimelineId};
|
||||
@@ -27,8 +25,6 @@ pub use tonic::Code;
|
||||
pub use tonic::Request;
|
||||
pub use tonic::Streaming;
|
||||
|
||||
pub use hyper::Uri;
|
||||
|
||||
pub const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:50051";
|
||||
pub const DEFAULT_ENDPOINT: &str = const_format::formatcp!("http://{DEFAULT_LISTEN_ADDR}");
|
||||
|
||||
@@ -99,50 +95,7 @@ pub fn parse_proto_ttid(proto_ttid: &ProtoTenantTimelineId) -> Result<TenantTime
|
||||
// well.
|
||||
type AnyError = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
|
||||
// Provides impl HttpBody for two different types implementing it. Inspired by
|
||||
// https://github.com/hyperium/tonic/blob/master/examples/src/hyper_warp/server.rs
|
||||
pub enum EitherBody<A, B> {
|
||||
Left(A),
|
||||
Right(B),
|
||||
}
|
||||
|
||||
impl<A, B> HttpBody for EitherBody<A, B>
|
||||
where
|
||||
A: HttpBody + Send + Unpin,
|
||||
B: HttpBody<Data = A::Data> + Send + Unpin,
|
||||
A::Error: Into<AnyError>,
|
||||
B::Error: Into<AnyError>,
|
||||
{
|
||||
type Data = A::Data;
|
||||
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
match self {
|
||||
EitherBody::Left(b) => b.is_end_stream(),
|
||||
EitherBody::Right(b) => b.is_end_stream(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_data(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
match self.get_mut() {
|
||||
EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err),
|
||||
EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
|
||||
match self.get_mut() {
|
||||
EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
|
||||
EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type EitherBody<L, R> = http_body_util::Either<L, R>;
|
||||
|
||||
fn map_option_err<T, U: Into<AnyError>>(err: Option<Result<T, U>>) -> Option<Result<T, AnyError>> {
|
||||
err.map(|e| e.map_err(Into::into))
|
||||
|
||||
Reference in New Issue
Block a user