proxy: small changes (#8752)

## Problem

#8736 is getting too big. splitting off some simple changes here

## Summary of changes

Local proxy wont always be using tls, so make it optional. Local proxy
wont be using ws for now, so make it optional. Remove a dead config var.
This commit is contained in:
Conrad Ludgate
2024-08-20 14:16:27 +01:00
committed by GitHub
parent 1c96957e85
commit 0170611a97
5 changed files with 65 additions and 26 deletions

View File

@@ -173,9 +173,6 @@ struct ProxyCliArgs {
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
@@ -661,6 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: true,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,

View File

@@ -52,6 +52,7 @@ pub struct TlsConfig {
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
pub cancel_set: CancelSet,
pub client_conn_threshold: u64,

View File

@@ -10,6 +10,7 @@ mod json;
mod sql_over_http;
mod websocket;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
@@ -26,8 +27,9 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::cancellation::CancellationHandlerMain;
@@ -41,7 +43,7 @@ use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
@@ -86,18 +88,18 @@ pub async fn task_main(
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
Some(config) => {
let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(tls_server_config) as Arc<_>
}
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
warn!("TLS config is missing");
Arc::new(NoTls) as Arc<_>
}
};
let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
@@ -176,16 +178,41 @@ pub async fn task_main(
Ok(())
}
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
pub type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
}
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
}
}
struct NoTls;
#[async_trait]
impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(conn))
}
}
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: TlsAcceptor,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -241,7 +268,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: TlsStream<ChainRW<TcpStream>>,
conn: AsyncRW,
peer_addr: IpAddr,
session_id: uuid::Uuid,
) {
@@ -326,7 +353,9 @@ async fn request_handler(
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if framed_websockets::upgrade::is_upgrade_request(&request) {
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,

View File

@@ -758,6 +758,7 @@ mod tests {
async fn test_pool() {
let _ = env_logger::try_init();
let config = Box::leak(Box::new(crate::config::HttpConfig {
accept_websockets: false,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: 2,
gc_epoch: Duration::from_secs(1),

View File

@@ -147,7 +147,7 @@ impl UserFacingError for ConnInfoError {
fn get_conn_info(
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: &TlsConfig,
tls: Option<&TlsConfig>,
) -> Result<ConnInfo, ConnInfoError> {
// HTTP only uses cleartext (for now and likely always)
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -184,12 +184,22 @@ fn get_conn_info(
.ok_or(ConnInfoError::MissingPassword)?;
let password = urlencoding::decode_binary(password.as_bytes());
let hostname = connection_url
.host_str()
.ok_or(ConnInfoError::MissingHostname)?;
let endpoint =
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names)?
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once(".")
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_)) | Some(url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname)
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
@@ -502,7 +512,7 @@ async fn handle_inner(
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
info!(user = conn_info.user_info.user.as_str(), "credentials");
// Allow connection pooling only if explicitly requested