mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
proxy: add flag to reject requests without proxy protocol client ip (#5417)
## Problem We need a flag to require proxy protocol (prerequisite for #5416) ## Summary of changes Add a cli flag to require client IP addresses. Error if IP address is missing when the flag is active.
This commit is contained in:
@@ -83,6 +83,10 @@ struct ProxyCliArgs {
|
||||
/// timeout for http connections
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
sql_over_http_timeout: tokio::time::Duration,
|
||||
|
||||
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
require_client_ip: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -233,6 +237,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
require_client_ip: args.require_client_ip,
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
|
||||
@@ -14,6 +14,7 @@ pub struct ProxyConfig {
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
pub require_client_ip: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::{
|
||||
NUM_CLIENT_CONNECTION_OPENED_COUNTER,
|
||||
},
|
||||
};
|
||||
use anyhow::bail;
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hyper::{
|
||||
@@ -22,7 +23,6 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::ready,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
@@ -280,12 +280,18 @@ pub async fn task_main(
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
|
||||
let (io, tls) = stream.get_ref();
|
||||
let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr());
|
||||
let client_addr = io.client_addr();
|
||||
let remote_addr = io.inner.remote_addr();
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let conn_pool = conn_pool.clone();
|
||||
|
||||
async move {
|
||||
Ok::<_, Infallible>(MetricService::new(hyper::service::service_fn(
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => bail!("missing required client ip"),
|
||||
None => remote_addr,
|
||||
};
|
||||
Ok(MetricService::new(hyper::service::service_fn(
|
||||
move |req: Request<Body>| {
|
||||
let sni_name = sni_name.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
|
||||
@@ -200,6 +200,8 @@ pub async fn task_main(
|
||||
let mut socket = WithClientIp::new(socket);
|
||||
if let Some(ip) = socket.wait_for_addr().await? {
|
||||
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
|
||||
} else if config.require_client_ip {
|
||||
bail!("missing required client IP");
|
||||
}
|
||||
|
||||
socket
|
||||
|
||||
Reference in New Issue
Block a user