diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index b44e0ddd2f..d83a1f3bcf 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -173,9 +173,6 @@ struct ProxyCliArgs { /// cache for `role_secret` (use `size=0` to disable) #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] role_secret_cache: String, - /// disable ip check for http requests. If it is too time consuming, it could be turned off. - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - disable_ip_check_for_http: bool, /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) #[clap(long)] redis_notifications: Option, @@ -661,6 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { )?; let http_config = HttpConfig { + accept_websockets: true, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 36d04924f2..a280aa88ce 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -52,6 +52,7 @@ pub struct TlsConfig { } pub struct HttpConfig { + pub accept_websockets: bool, pub pool_options: GlobalConnPoolOptions, pub cancel_set: CancelSet, pub client_conn_threshold: u64, diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 5416d63b5b..b2bf93dc6d 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -10,6 +10,7 @@ mod json; mod sql_over_http; mod websocket; +use async_trait::async_trait; use atomic_take::AtomicTake; use bytes::Bytes; pub use conn_pool::GlobalConnPoolOptions; @@ -26,8 +27,9 @@ use rand::rngs::StdRng; use rand::SeedableRng; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::timeout; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; +use tokio_rustls::TlsAcceptor; use tokio_util::task::TaskTracker; use crate::cancellation::CancellationHandlerMain; @@ -41,7 +43,7 @@ use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; use std::net::{IpAddr, SocketAddr}; -use std::pin::pin; +use std::pin::{pin, Pin}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; @@ -86,18 +88,18 @@ pub async fn task_main( config, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); - - let tls_config = match config.tls_config.as_ref() { - Some(config) => config, + let tls_acceptor: Arc = match config.tls_config.as_ref() { + Some(config) => { + let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config()); + // prefer http2, but support http/1.1 + tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Arc::new(tls_server_config) as Arc<_> + } None => { - warn!("TLS config is missing, WebSocket Secure server will not be started"); - return Ok(()); + warn!("TLS config is missing"); + Arc::new(NoTls) as Arc<_> } }; - let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config()); - // prefer http2, but support http/1.1 - tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into(); let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` @@ -176,16 +178,41 @@ pub async fn task_main( Ok(()) } +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {} +impl AsyncReadWrite for T {} +pub type AsyncRW = Pin>; + +#[async_trait] +trait MaybeTlsAcceptor: Send + Sync + 'static { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result; +} + +#[async_trait] +impl MaybeTlsAcceptor for rustls::ServerConfig { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { + Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?)) + } +} + +struct NoTls; + +#[async_trait] +impl MaybeTlsAcceptor for NoTls { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { + Ok(Box::pin(conn)) + } +} + /// Handles the TCP startup lifecycle. /// 1. Parses PROXY protocol V2 /// 2. Handles TLS handshake async fn connection_startup( config: &ProxyConfig, - tls_acceptor: TlsAcceptor, + tls_acceptor: Arc, session_id: uuid::Uuid, conn: TcpStream, peer_addr: SocketAddr, -) -> Option<(TlsStream>, IpAddr)> { +) -> Option<(AsyncRW, IpAddr)> { // handle PROXY protocol let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, @@ -241,7 +268,7 @@ async fn connection_handler( cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, - conn: TlsStream>, + conn: AsyncRW, peer_addr: IpAddr, session_id: uuid::Uuid, ) { @@ -326,7 +353,9 @@ async fn request_handler( .map(|s| s.to_string()); // Check if the request is a websocket upgrade request. - if framed_websockets::upgrade::is_upgrade_request(&request) { + if config.http_config.accept_websockets + && framed_websockets::upgrade::is_upgrade_request(&request) + { let ctx = RequestMonitoring::new( session_id, peer_addr, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 9ede659cc4..3478787995 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -758,6 +758,7 @@ mod tests { async fn test_pool() { let _ = env_logger::try_init(); let config = Box::leak(Box::new(crate::config::HttpConfig { + accept_websockets: false, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: 2, gc_epoch: Duration::from_secs(1), diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index c41df07a4d..bbfed90f39 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -147,7 +147,7 @@ impl UserFacingError for ConnInfoError { fn get_conn_info( ctx: &RequestMonitoring, headers: &HeaderMap, - tls: &TlsConfig, + tls: Option<&TlsConfig>, ) -> Result { // HTTP only uses cleartext (for now and likely always) ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -184,12 +184,22 @@ fn get_conn_info( .ok_or(ConnInfoError::MissingPassword)?; let password = urlencoding::decode_binary(password.as_bytes()); - let hostname = connection_url - .host_str() - .ok_or(ConnInfoError::MissingHostname)?; - - let endpoint = - endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?; + let endpoint = match connection_url.host() { + Some(url::Host::Domain(hostname)) => { + if let Some(tls) = tls { + endpoint_sni(hostname, &tls.common_names)? + .ok_or(ConnInfoError::MalformedEndpoint)? + } else { + hostname + .split_once(".") + .map_or(hostname, |(prefix, _)| prefix) + .into() + } + } + Some(url::Host::Ipv4(_)) | Some(url::Host::Ipv6(_)) | None => { + return Err(ConnInfoError::MissingHostname) + } + }; ctx.set_endpoint_id(endpoint.clone()); let pairs = connection_url.query_pairs(); @@ -502,7 +512,7 @@ async fn handle_inner( let headers = request.headers(); // TLS config should be there. - let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?; + let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?; info!(user = conn_info.user_info.user.as_str(), "credentials"); // Allow connection pooling only if explicitly requested