proxy: small changes (#8752)

## Problem

#8736 is getting too big. splitting off some simple changes here

## Summary of changes

Local proxy wont always be using tls, so make it optional. Local proxy
wont be using ws for now, so make it optional. Remove a dead config var.
This commit is contained in:
Conrad Ludgate
2024-08-20 14:16:27 +01:00
committed by GitHub
parent 1c96957e85
commit 0170611a97
5 changed files with 65 additions and 26 deletions

View File

@@ -10,6 +10,7 @@ mod json;
mod sql_over_http;
mod websocket;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
@@ -26,8 +27,9 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::cancellation::CancellationHandlerMain;
@@ -41,7 +43,7 @@ use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
@@ -86,18 +88,18 @@ pub async fn task_main(
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
Some(config) => {
let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(tls_server_config) as Arc<_>
}
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
warn!("TLS config is missing");
Arc::new(NoTls) as Arc<_>
}
};
let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
@@ -176,16 +178,41 @@ pub async fn task_main(
Ok(())
}
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
pub type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
}
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
}
}
struct NoTls;
#[async_trait]
impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(conn))
}
}
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: TlsAcceptor,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -241,7 +268,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: TlsStream<ChainRW<TcpStream>>,
conn: AsyncRW,
peer_addr: IpAddr,
session_id: uuid::Uuid,
) {
@@ -326,7 +353,9 @@ async fn request_handler(
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if framed_websockets::upgrade::is_upgrade_request(&request) {
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,