From 8e7d2aab76590ff82f56d485df9eaf528a15c1a0 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 18 Sep 2024 14:28:59 +0100 Subject: [PATCH] put it all together --- proxy/src/auth/backend.rs | 1 + proxy/src/bin/local_proxy.rs | 1 + proxy/src/bin/proxy.rs | 1 + proxy/src/config.rs | 1 + proxy/src/serverless.rs | 9 +- proxy/src/serverless/backend.rs | 13 +-- proxy/src/serverless/http_conn_pool.rs | 48 +++++----- proxy/src/serverless/http_util.rs | 21 +++-- proxy/src/serverless/sql_over_http.rs | 117 +++++++++++++++++-------- 9 files changed, 128 insertions(+), 84 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index afa0af336b..8e09c002df 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -620,6 +620,7 @@ mod tests { rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, + is_auth_broker: false, }); async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 06505d29d6..44d264caaf 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -277,6 +277,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig rate_limiter: BucketRateLimiter::new(vec![]), rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, + is_auth_broker: false, }, require_client_ip: false, handshake_timeout: Duration::from_secs(10), diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3cca754290..748266edc9 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -697,6 +697,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, + is_auth_broker: true, }; let config = Box::leak(Box::new(ProxyConfig { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 3b044e2f5d..207e4fdc18 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -71,6 +71,7 @@ pub struct AuthenticationConfig { pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub jwks_cache: JwkCache, + pub is_auth_broker: bool, } impl TlsConfig { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 3278d9a658..a7e3fa709b 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -20,7 +20,8 @@ use anyhow::Context; use futures::future::{select, Either}; use futures::TryFutureExt; use http::{Method, Response, StatusCode}; -use http_body_util::Full; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Empty}; use hyper1::body::Incoming; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; @@ -364,7 +365,7 @@ async fn request_handler( // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let host = request .headers() .get("host") @@ -408,7 +409,7 @@ async fn request_handler( ); // Return the response so the spawned future can continue. - Ok(response.map(|_: http_body_util::Empty| Full::new(Bytes::new()))) + Ok(response.map(|b| b.map_err(|x| match x {}).boxed())) } else if request.uri().path() == "/sql" && *request.method() == Method::POST { let ctx = RequestMonitoring::new( session_id, @@ -431,7 +432,7 @@ async fn request_handler( ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code - .body(Full::new(Bytes::new())) + .body(Empty::new().map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into())) } else { json_response(StatusCode::BAD_REQUEST, "query is not supported") diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 3931dbb797..9c80f452cc 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,9 +1,6 @@ use std::{io, sync::Arc, time::Duration}; use async_trait::async_trait; -use bytes::Bytes; -use http_body_util::Full; -use hyper1::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use tokio::net::{lookup_host, TcpStream}; use tracing::{field::display, info}; @@ -396,7 +393,7 @@ impl ConnectMechanism for HyperMechanism { Ok(poll_http2_client( self.pool.clone(), ctx, - self.conn_info.clone(), + &self.conn_info, client, connection, self.conn_id, @@ -411,13 +408,7 @@ async fn connect_http2( host: &str, port: u16, timeout: Duration, -) -> Result< - ( - http2::SendRequest>, - http2::Connection, Full, TokioExecutor>, - ), - HttpConnError, -> { +) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), HttpConnError> { let mut addrs = lookup_host((host, port)).await?; let mut last_err = None; diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 1c92f86dc9..7fc4cce91a 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -1,11 +1,10 @@ -use bytes::Bytes; use dashmap::DashMap; -use http_body_util::Full; use hyper1::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use rand::Rng; use std::collections::VecDeque; +use std::ops::DerefMut; use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, @@ -18,14 +17,18 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{context::RequestMonitoring, EndpointCacheKey}; -use tracing::{debug, error, Span}; +use tracing::{debug, error}; use tracing::{info, info_span, Instrument}; use super::conn_pool::ConnInfo; +pub(crate) type Send = http2::SendRequest; +pub(crate) type Connect = + http2::Connection, hyper1::body::Incoming, TokioExecutor>; + #[derive(Clone)] struct ConnPoolEntry { - conn: http2::SendRequest>, + conn: Send, conn_id: uuid::Uuid, aux: MetricsAuxInfo, } @@ -215,7 +218,7 @@ impl GlobalConnPool { ); ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); ctx.success(); - Some(Client::new(client.conn, conn_info.clone(), client.aux)) + Some(Client::new(client.conn, client.aux)) } fn get_or_create_endpoint_pool( @@ -263,9 +266,9 @@ impl GlobalConnPool { pub(crate) fn poll_http2_client( global_pool: Arc, ctx: &RequestMonitoring, - conn_info: ConnInfo, - client: http2::SendRequest>, - connection: http2::Connection, Full, TokioExecutor>, + conn_info: &ConnInfo, + client: Send, + connection: Connect, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { @@ -314,7 +317,7 @@ pub(crate) fn poll_http2_client( .instrument(span), ); - Client::new(client, conn_info, aux) + Client::new(client, aux) } impl Client { @@ -327,32 +330,23 @@ impl Client { } pub(crate) struct Client { - span: Span, - inner: http2::SendRequest>, + inner: Send, aux: MetricsAuxInfo, - conn_info: ConnInfo, } impl Client { - pub(self) fn new( - inner: http2::SendRequest>, - conn_info: ConnInfo, - aux: MetricsAuxInfo, - ) -> Self { - Self { - inner, - span: Span::current(), - conn_info, - aux, - } - } - pub(crate) fn inner(&mut self) -> &mut http2::SendRequest> { - &mut self.inner + pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self { + Self { inner, aux } } } +impl DerefMut for Client { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} impl Deref for Client { - type Target = http2::SendRequest>; + type Target = Send; fn deref(&self) -> &Self::Target { &self.inner diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index abf0ffe290..d766a46577 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -5,13 +5,13 @@ use bytes::Bytes; use anyhow::Context; use http::{Response, StatusCode}; -use http_body_util::Full; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; use serde::Serialize; use utils::http::error::ApiError; /// Like [`ApiError::into_response`] -pub(crate) fn api_error_into_response(this: ApiError) -> Response> { +pub(crate) fn api_error_into_response(this: ApiError) -> Response> { match this { ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( format!("{err:#?}"), // use debug printing so that we give the cause @@ -64,17 +64,24 @@ struct HttpErrorBody { impl HttpErrorBody { /// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`] - fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response> { + fn response_from_msg_and_status( + msg: String, + status: StatusCode, + ) -> Response> { HttpErrorBody { msg }.to_response(status) } /// Same as [`utils::http::error::HttpErrorBody::to_response`] - fn to_response(&self, status: StatusCode) -> Response> { + fn to_response(&self, status: StatusCode) -> Response> { Response::builder() .status(status) .header(http::header::CONTENT_TYPE, "application/json") // we do not have nested maps with non string keys so serialization shouldn't fail - .body(Full::new(Bytes::from(serde_json::to_string(self).unwrap()))) + .body( + Full::new(Bytes::from(serde_json::to_string(self).unwrap())) + .map_err(|x| match x {}) + .boxed(), + ) .unwrap() } } @@ -83,14 +90,14 @@ impl HttpErrorBody { pub(crate) fn json_response( status: StatusCode, data: T, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let json = serde_json::to_string(&data) .context("Failed to serialize JSON response") .map_err(ApiError::InternalServerError)?; let response = Response::builder() .status(status) .header(http::header::CONTENT_TYPE, "application/json") - .body(Full::new(Bytes::from(json))) + .body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into()))?; Ok(response) } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index dbe1df8bb0..bbafe25705 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -8,6 +8,7 @@ use futures::future::Either; use futures::StreamExt; use futures::TryFutureExt; use http::header::AUTHORIZATION; +use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use http_body_util::Full; use hyper1::body::Body; @@ -247,7 +248,7 @@ pub(crate) async fn handle( request: Request, backend: Arc, cancel: CancellationToken, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let result = handle_inner(cancel, config, &ctx, request, backend).await; let mut response = match result { @@ -504,7 +505,7 @@ async fn handle_inner( ctx: &RequestMonitoring, request: Request, backend: Arc, -) -> Result>, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { let _requeset_gauge = Metrics::get() .proxy .connection_requests @@ -521,35 +522,20 @@ async fn handle_inner( ); match conn_info.auth { - AuthData::Password(pw) => { - let res = handle_db_inner( + AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { + handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await + } + auth => { + handle_db_inner( cancel, config, ctx, request, conn_info.conn_info, - &pw, + auth, backend, ) - .await?; - Ok(res) - } - AuthData::Jwt(jwt) => { - let keys = backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.conn_info.user_info, - jwt, - ) - .await - .map_err(HttpConnError::from)?; - - let _client = backend - .connect_to_local_proxy(ctx, conn_info.conn_info, keys) - .await?; - - todo!() + .await } } } @@ -560,9 +546,9 @@ async fn handle_db_inner( ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, - password: &[u8], + auth: AuthData, backend: Arc, -) -> Result>, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { // // Determine the destination and connection params // @@ -605,14 +591,28 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { - let keys = backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.user_info, - password, - ) - .await?; + let keys = match auth { + AuthData::Password(pw) => { + backend + .authenticate_with_password( + ctx, + &config.authentication_config, + &conn_info.user_info, + &pw, + ) + .await? + } + AuthData::Jwt(jwt) => { + backend + .authenticate_with_jwt( + ctx, + &config.authentication_config, + &conn_info.user_info, + jwt, + ) + .await? + } + }; let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) @@ -673,7 +673,11 @@ async fn handle_db_inner( let len = json_output.len(); let response = response - .body(Full::new(Bytes::from(json_output))) + .body( + Full::new(Bytes::from(json_output)) + .map_err(|x| match x {}) + .boxed(), + ) // only fails if invalid status code or invalid header/values are given. // these are not user configurable so it cannot fail dynamically .expect("building response payload should not fail"); @@ -689,6 +693,49 @@ async fn handle_db_inner( Ok(response) } +async fn handle_auth_broker_inner( + config: &'static ProxyConfig, + ctx: &RequestMonitoring, + request: Request, + conn_info: ConnInfo, + jwt: String, + backend: Arc, +) -> Result>, SqlOverHttpError> { + let keys = backend + .authenticate_with_jwt( + ctx, + &config.authentication_config, + &conn_info.user_info, + jwt, + ) + .await + .map_err(HttpConnError::from)?; + + let mut client = backend.connect_to_local_proxy(ctx, conn_info, keys).await?; + + // always completes instantly in http2 mode + // but good just in case + client.ready().await.map_err(HttpConnError::from)?; + + let (parts, body) = request.into_parts(); + let mut req = Request::builder() + .method("POST") + .uri("http://proxy.local/sql"); + + *req.headers_mut().unwrap() = parts.headers; + + let req = req.body(body).unwrap(); + + // todo: map body to count egress + let _metrics = client.metrics(); + + Ok(client + .send_request(req) + .await + .map_err(HttpConnError::from)? + .map(|b| b.boxed())) +} + impl QueryData { async fn process( self,