From 32126d705b940cee5393eaf81142d3eb2e401a47 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 25 Oct 2023 15:43:03 +0100 Subject: [PATCH] proxy refactor serverless (#4685) ## Problem Our serverless backend was a bit jumbled. As a comment indicated, we were handling SQL-over-HTTP in our `websocket.rs` file. I've extracted out the `conn_pool`, `sql_over_http` and `websocket` files from the `http` module and put them into a new module called `serverless`. ## Summary of changes ```sh mkdir proxy/src/serverless mv proxy/src/http/{conn_pool,sql_over_http,websocket}.rs proxy/src/serverless/ mv proxy/src/http/server.rs proxy/src/http/health_server.rs mv proxy/src/metrics.rs proxy/src/usage_metrics.rs ``` I have also extracted the hyper server and handler from websocket.rs into `serverless.rs`. --- proxy/src/bin/proxy.rs | 21 +- proxy/src/http.rs | 5 +- .../src/http/{server.rs => health_server.rs} | 0 proxy/src/lib.rs | 3 +- proxy/src/proxy.rs | 2 +- .../src/{http/websocket.rs => serverless.rs} | 319 +++++------------- proxy/src/{http => serverless}/conn_pool.rs | 2 +- .../src/{http => serverless}/sql_over_http.rs | 0 proxy/src/serverless/websocket.rs | 146 ++++++++ proxy/src/{metrics.rs => usage_metrics.rs} | 0 10 files changed, 255 insertions(+), 243 deletions(-) rename proxy/src/http/{server.rs => health_server.rs} (100%) rename proxy/src/{http/websocket.rs => serverless.rs} (51%) rename proxy/src/{http => serverless}/conn_pool.rs (99%) rename proxy/src/{http => serverless}/sql_over_http.rs (100%) create mode 100644 proxy/src/serverless/websocket.rs rename proxy/src/{metrics.rs => usage_metrics.rs} (100%) diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 9b54f98402..a9ca308797 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -4,10 +4,11 @@ use proxy::config::AuthenticationConfig; use proxy::config::HttpConfig; use proxy::console; use proxy::http; -use proxy::metrics; +use proxy::usage_metrics; use anyhow::bail; use proxy::config::{self, 
ProxyConfig}; +use proxy::serverless; use std::pin::pin; use std::{borrow::Cow, net::SocketAddr}; use tokio::net::TcpListener; @@ -129,14 +130,16 @@ async fn main() -> anyhow::Result<()> { cancellation_token.clone(), )); - if let Some(wss_address) = args.wss { - let wss_address: SocketAddr = wss_address.parse()?; - info!("Starting wss on {wss_address}"); - let wss_listener = TcpListener::bind(wss_address).await?; + // TODO: rename the argument to something like serverless. + // It now covers more than just websockets, it also covers SQL over HTTP. + if let Some(serverless_address) = args.wss { + let serverless_address: SocketAddr = serverless_address.parse()?; + info!("Starting wss on {serverless_address}"); + let serverless_listener = TcpListener::bind(serverless_address).await?; - client_tasks.spawn(http::websocket::task_main( + client_tasks.spawn(serverless::task_main( config, - wss_listener, + serverless_listener, cancellation_token.clone(), )); } @@ -144,11 +147,11 @@ async fn main() -> anyhow::Result<()> { // maintenance tasks. these never return unless there's an error let mut maintenance_tasks = JoinSet::new(); maintenance_tasks.spawn(proxy::handle_signals(cancellation_token)); - maintenance_tasks.spawn(http::server::task_main(http_listener)); + maintenance_tasks.spawn(http::health_server::task_main(http_listener)); maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener)); if let Some(metrics_config) = &config.metric_collection { - maintenance_tasks.spawn(metrics::task_main(metrics_config)); + maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); } let maintenance = loop { diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 755137807b..14a9072a45 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -2,10 +2,7 @@ //! Other modules should use stuff from this module instead of //! directly relying on deps like `reqwest` (think loose coupling). 
-pub mod conn_pool; -pub mod server; -pub mod sql_over_http; -pub mod websocket; +pub mod health_server; use std::{sync::Arc, time::Duration}; diff --git a/proxy/src/http/server.rs b/proxy/src/http/health_server.rs similarity index 100% rename from proxy/src/http/server.rs rename to proxy/src/http/health_server.rs diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a3d1cdd3c8..803b278482 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -14,14 +14,15 @@ pub mod console; pub mod error; pub mod http; pub mod logging; -pub mod metrics; pub mod parse; pub mod protocol2; pub mod proxy; pub mod sasl; pub mod scram; +pub mod serverless; pub mod stream; pub mod url; +pub mod usage_metrics; pub mod waiters; /// Handle unix signals appropriately. diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ec878b696e..165794414f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -8,9 +8,9 @@ use crate::{ config::{AuthenticationConfig, ProxyConfig, TlsConfig}, console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api}, http::StatusCode, - metrics::{Ids, USAGE_METRICS}, protocol2::WithClientIp, stream::{PqStream, Stream}, + usage_metrics::{Ids, USAGE_METRICS}, }; use anyhow::{bail, Context}; use async_trait::async_trait; diff --git a/proxy/src/http/websocket.rs b/proxy/src/serverless.rs similarity index 51% rename from proxy/src/http/websocket.rs rename to proxy/src/serverless.rs index 689a84969c..23deda3ae6 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/serverless.rs @@ -1,235 +1,36 @@ -use crate::{ - cancellation::CancelMap, - config::ProxyConfig, - error::io_error, - protocol2::{ProxyProtocolAccept, WithClientIp}, - proxy::{ - handle_client, ClientMode, NUM_CLIENT_CONNECTION_CLOSED_COUNTER, - NUM_CLIENT_CONNECTION_OPENED_COUNTER, - }, -}; +//! Routers for our serverless APIs +//! +//! Handles both SQL over HTTP and SQL over Websockets. 
+ +mod conn_pool; +mod sql_over_http; +mod websocket; + use anyhow::bail; -use bytes::{Buf, Bytes}; -use futures::{Sink, Stream, StreamExt}; +use hyper::StatusCode; +pub use reqwest_middleware::{ClientWithMiddleware, Error}; +pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; + +use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; +use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER}; +use crate::{cancellation::CancelMap, config::ProxyConfig}; +use futures::StreamExt; use hyper::{ server::{ accept, conn::{AddrIncoming, AddrStream}, }, - upgrade::Upgraded, - Body, Method, Request, Response, StatusCode, + Body, Method, Request, Response, }; -use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; -use pin_project_lite::pin_project; -use std::{ - future::ready, - pin::Pin, - sync::Arc, - task::{ready, Context, Poll}, -}; +use std::task::Poll; +use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; -use tokio::{ - io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}, - net::TcpListener, -}; +use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; use utils::http::{error::ApiError, json::json_response}; -// TODO: use `std::sync::Exclusive` once it's stabilized. -// Tracking issue: https://github.com/rust-lang/rust/issues/98407. -use sync_wrapper::SyncWrapper; - -use super::{conn_pool::GlobalConnPool, sql_over_http}; - -pin_project! { - /// This is a wrapper around a [`WebSocketStream`] that - /// implements [`AsyncRead`] and [`AsyncWrite`]. 
- pub struct WebSocketRw { - #[pin] - stream: SyncWrapper>, - bytes: Bytes, - } -} - -impl WebSocketRw { - pub fn new(stream: WebSocketStream) -> Self { - Self { - stream: stream.into(), - bytes: Bytes::new(), - } - } -} - -impl AsyncWrite for WebSocketRw { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut stream = self.project().stream.get_pin_mut(); - - ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?; - match stream.as_mut().start_send(Message::Binary(buf.into())) { - Ok(()) => Poll::Ready(Ok(buf.len())), - Err(e) => Poll::Ready(Err(io_error(e))), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let stream = self.project().stream.get_pin_mut(); - stream.poll_flush(cx).map_err(io_error) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let stream = self.project().stream.get_pin_mut(); - stream.poll_close(cx).map_err(io_error) - } -} - -impl AsyncRead for WebSocketRw { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if buf.remaining() > 0 { - let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; - let len = std::cmp::min(bytes.len(), buf.remaining()); - buf.put_slice(&bytes[..len]); - self.consume(len); - } - - Poll::Ready(Ok(())) - } -} - -impl AsyncBufRead for WebSocketRw { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Please refer to poll_fill_buf's documentation. - const EOF: Poll> = Poll::Ready(Ok(&[])); - - let mut this = self.project(); - loop { - if !this.bytes.chunk().is_empty() { - let chunk = (*this.bytes).chunk(); - return Poll::Ready(Ok(chunk)); - } - - let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx)); - match res.transpose().map_err(io_error)? { - Some(message) => match message { - Message::Ping(_) => {} - Message::Pong(_) => {} - Message::Text(text) => { - // We expect to see only binary messages. 
- let error = "unexpected text message in the websocket"; - warn!(length = text.len(), error); - return Poll::Ready(Err(io_error(error))); - } - Message::Frame(_) => { - // This case is impossible according to Frame's doc. - panic!("unexpected raw frame in the websocket"); - } - Message::Binary(chunk) => { - assert!(this.bytes.is_empty()); - *this.bytes = Bytes::from(chunk); - } - Message::Close(_) => return EOF, - }, - None => return EOF, - } - } - } - - fn consume(self: Pin<&mut Self>, amount: usize) { - self.project().bytes.advance(amount); - } -} - -async fn serve_websocket( - websocket: HyperWebsocket, - config: &'static ProxyConfig, - cancel_map: &CancelMap, - session_id: uuid::Uuid, - hostname: Option, -) -> anyhow::Result<()> { - let websocket = websocket.await?; - handle_client( - config, - cancel_map, - session_id, - WebSocketRw::new(websocket), - ClientMode::Websockets { hostname }, - ) - .await?; - Ok(()) -} - -async fn ws_handler( - mut request: Request, - config: &'static ProxyConfig, - conn_pool: Arc, - cancel_map: Arc, - session_id: uuid::Uuid, - sni_hostname: Option, -) -> Result, ApiError> { - let host = request - .headers() - .get("host") - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.split(':').next()) - .map(|s| s.to_string()); - - // Check if the request is a websocket upgrade request. - if hyper_tungstenite::is_upgrade_request(&request) { - info!(session_id = ?session_id, "performing websocket upgrade"); - - let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) - .map_err(|e| ApiError::BadRequest(e.into()))?; - - tokio::spawn( - async move { - if let Err(e) = - serve_websocket(websocket, config, &cancel_map, session_id, host).await - { - error!(session_id = ?session_id, "error in websocket connection: {e:#}"); - } - } - .in_current_span(), - ); - - // Return the response so the spawned future can continue. 
- Ok(response) - // TODO: that deserves a refactor as now this function also handles http json client besides websockets. - // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead. - } else if request.uri().path() == "/sql" && request.method() == Method::POST { - sql_over_http::handle( - request, - sni_hostname, - conn_pool, - session_id, - &config.http_config, - ) - .await - } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { - Response::builder() - .header("Allow", "OPTIONS, POST") - .header("Access-Control-Allow-Origin", "*") - .header( - "Access-Control-Allow-Headers", - "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In", - ) - .header("Access-Control-Max-Age", "86400" /* 24 hours */) - .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code - .body(Body::empty()) - .map_err(|e| ApiError::InternalServerError(e.into())) - } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") - } -} - pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, @@ -239,7 +40,7 @@ pub async fn task_main( info!("websocket server has shut down"); } - let conn_pool: Arc = GlobalConnPool::new(config); + let conn_pool = conn_pool::GlobalConnPool::new(config); // shutdown the connection pool tokio::spawn({ @@ -300,13 +101,15 @@ pub async fn task_main( let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); - ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name) - .instrument(info_span!( - "ws-client", - session = %session_id, - %peer_addr, - )) - .await + request_handler( + req, config, conn_pool, cancel_map, session_id, sni_name, + ) + .instrument(info_span!( + "serverless", + session = %session_id, + %peer_addr, + )) + .await } }, ))) @@ -359,3 +162,65 @@ where self.inner.call(req) } } + +async fn 
request_handler( + mut request: Request, + config: &'static ProxyConfig, + conn_pool: Arc, + cancel_map: Arc, + session_id: uuid::Uuid, + sni_hostname: Option, +) -> Result, ApiError> { + let host = request + .headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.split(':').next()) + .map(|s| s.to_string()); + + // Check if the request is a websocket upgrade request. + if hyper_tungstenite::is_upgrade_request(&request) { + info!(session_id = ?session_id, "performing websocket upgrade"); + + let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) + .map_err(|e| ApiError::BadRequest(e.into()))?; + + tokio::spawn( + async move { + if let Err(e) = + websocket::serve_websocket(websocket, config, &cancel_map, session_id, host) + .await + { + error!(session_id = ?session_id, "error in websocket connection: {e:#}"); + } + } + .in_current_span(), + ); + + // Return the response so the spawned future can continue. + Ok(response) + } else if request.uri().path() == "/sql" && request.method() == Method::POST { + sql_over_http::handle( + request, + sni_hostname, + conn_pool, + session_id, + &config.http_config, + ) + .await + } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { + Response::builder() + .header("Allow", "OPTIONS, POST") + .header("Access-Control-Allow-Origin", "*") + .header( + "Access-Control-Allow-Headers", + "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In", + ) + .header("Access-Control-Max-Age", "86400" /* 24 hours */) + .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code + .body(Body::empty()) + .map_err(|e| ApiError::InternalServerError(e.into())) + } else { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } +} diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/serverless/conn_pool.rs similarity index 99% rename from proxy/src/http/conn_pool.rs rename to 
proxy/src/serverless/conn_pool.rs index 5218a44479..c5bfc32568 100644 --- a/proxy/src/http/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -22,8 +22,8 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ auth, console, - metrics::{Ids, MetricCounter, USAGE_METRICS}, proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER}, + usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, }; use crate::{compute, config}; diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs similarity index 100% rename from proxy/src/http/sql_over_http.rs rename to proxy/src/serverless/sql_over_http.rs diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs new file mode 100644 index 0000000000..86141ab64f --- /dev/null +++ b/proxy/src/serverless/websocket.rs @@ -0,0 +1,146 @@ +use crate::{ + cancellation::CancelMap, + config::ProxyConfig, + error::io_error, + proxy::{handle_client, ClientMode}, +}; +use bytes::{Buf, Bytes}; +use futures::{Sink, Stream}; +use hyper::upgrade::Upgraded; +use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; +use pin_project_lite::pin_project; + +use std::{ + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use tracing::warn; + +// TODO: use `std::sync::Exclusive` once it's stabilized. +// Tracking issue: https://github.com/rust-lang/rust/issues/98407. +use sync_wrapper::SyncWrapper; + +pin_project! { + /// This is a wrapper around a [`WebSocketStream`] that + /// implements [`AsyncRead`] and [`AsyncWrite`]. 
+ pub struct WebSocketRw { + #[pin] + stream: SyncWrapper>, + bytes: Bytes, + } +} + +impl WebSocketRw { + pub fn new(stream: WebSocketStream) -> Self { + Self { + stream: stream.into(), + bytes: Bytes::new(), + } + } +} + +impl AsyncWrite for WebSocketRw { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut stream = self.project().stream.get_pin_mut(); + + ready!(stream.as_mut().poll_ready(cx).map_err(io_error))?; + match stream.as_mut().start_send(Message::Binary(buf.into())) { + Ok(()) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(Err(io_error(e))), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let stream = self.project().stream.get_pin_mut(); + stream.poll_flush(cx).map_err(io_error) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let stream = self.project().stream.get_pin_mut(); + stream.poll_close(cx).map_err(io_error) + } +} + +impl AsyncRead for WebSocketRw { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() > 0 { + let bytes = ready!(self.as_mut().poll_fill_buf(cx))?; + let len = std::cmp::min(bytes.len(), buf.remaining()); + buf.put_slice(&bytes[..len]); + self.consume(len); + } + + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for WebSocketRw { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Please refer to poll_fill_buf's documentation. + const EOF: Poll> = Poll::Ready(Ok(&[])); + + let mut this = self.project(); + loop { + if !this.bytes.chunk().is_empty() { + let chunk = (*this.bytes).chunk(); + return Poll::Ready(Ok(chunk)); + } + + let res = ready!(this.stream.as_mut().get_pin_mut().poll_next(cx)); + match res.transpose().map_err(io_error)? { + Some(message) => match message { + Message::Ping(_) => {} + Message::Pong(_) => {} + Message::Text(text) => { + // We expect to see only binary messages. 
+ let error = "unexpected text message in the websocket"; + warn!(length = text.len(), error); + return Poll::Ready(Err(io_error(error))); + } + Message::Frame(_) => { + // This case is impossible according to Frame's doc. + panic!("unexpected raw frame in the websocket"); + } + Message::Binary(chunk) => { + assert!(this.bytes.is_empty()); + *this.bytes = Bytes::from(chunk); + } + Message::Close(_) => return EOF, + }, + None => return EOF, + } + } + } + + fn consume(self: Pin<&mut Self>, amount: usize) { + self.project().bytes.advance(amount); + } +} + +pub async fn serve_websocket( + websocket: HyperWebsocket, + config: &'static ProxyConfig, + cancel_map: &CancelMap, + session_id: uuid::Uuid, + hostname: Option, +) -> anyhow::Result<()> { + let websocket = websocket.await?; + handle_client( + config, + cancel_map, + session_id, + WebSocketRw::new(websocket), + ClientMode::Websockets { hostname }, + ) + .await?; + Ok(()) +} diff --git a/proxy/src/metrics.rs b/proxy/src/usage_metrics.rs similarity index 100% rename from proxy/src/metrics.rs rename to proxy/src/usage_metrics.rs