put it all together

This commit is contained in:
Conrad Ludgate
2024-09-18 14:28:59 +01:00
parent 2703abccc7
commit 8e7d2aab76
9 changed files with 128 additions and 84 deletions

View File

@@ -620,6 +620,7 @@ mod tests {
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -277,6 +277,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),

View File

@@ -697,6 +697,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: true,
};
let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -71,6 +71,7 @@ pub struct AuthenticationConfig {
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
}
impl TlsConfig {

View File

@@ -20,7 +20,8 @@ use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
@@ -364,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -408,7 +409,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -431,7 +432,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new()))
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,9 +1,6 @@
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
@@ -396,7 +393,7 @@ impl ConnectMechanism for HyperMechanism {
Ok(poll_http2_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
&self.conn_info,
client,
connection,
self.conn_id,
@@ -411,13 +408,7 @@ async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<
(
http2::SendRequest<Full<Bytes>>,
http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
),
HttpConnError,
> {
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), HttpConnError> {
let mut addrs = lookup_host((host, port)).await?;
let mut last_err = None;

View File

@@ -1,11 +1,10 @@
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::Full;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::ops::DerefMut;
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
@@ -18,14 +17,18 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error, Span};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: http2::SendRequest<Full<Bytes>>,
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
@@ -215,7 +218,7 @@ impl GlobalConnPool {
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, conn_info.clone(), client.aux))
Some(Client::new(client.conn, client.aux))
}
fn get_or_create_endpoint_pool(
@@ -263,9 +266,9 @@ impl GlobalConnPool {
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: http2::SendRequest<Full<Bytes>>,
connection: http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
@@ -314,7 +317,7 @@ pub(crate) fn poll_http2_client(
.instrument(span),
);
Client::new(client, conn_info, aux)
Client::new(client, aux)
}
impl Client {
@@ -327,32 +330,23 @@ impl Client {
}
pub(crate) struct Client {
span: Span,
inner: http2::SendRequest<Full<Bytes>>,
inner: Send,
aux: MetricsAuxInfo,
conn_info: ConnInfo,
}
impl Client {
pub(self) fn new(
inner: http2::SendRequest<Full<Bytes>>,
conn_info: ConnInfo,
aux: MetricsAuxInfo,
) -> Self {
Self {
inner,
span: Span::current(),
conn_info,
aux,
}
}
pub(crate) fn inner(&mut self) -> &mut http2::SendRequest<Full<Bytes>> {
&mut self.inner
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
}
impl DerefMut for Client {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl Deref for Client {
type Target = http2::SendRequest<Full<Bytes>>;
type Target = Send;
fn deref(&self) -> &Self::Target {
&self.inner

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::Full;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,17 +64,24 @@ struct HttpErrorBody {
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.unwrap()
}
}
@@ -83,14 +90,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -8,6 +8,7 @@ use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -247,7 +248,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -504,7 +505,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -521,35 +522,20 @@ async fn handle_inner(
);
match conn_info.auth {
AuthData::Password(pw) => {
let res = handle_db_inner(
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
&pw,
auth,
backend,
)
.await?;
Ok(res)
}
AuthData::Jwt(jwt) => {
let keys = backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let _client = backend
.connect_to_local_proxy(ctx, conn_info.conn_info, keys)
.await?;
todo!()
.await
}
}
}
@@ -560,9 +546,9 @@ async fn handle_db_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
password: &[u8],
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
@@ -605,14 +591,28 @@ async fn handle_db_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
password,
)
.await?;
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
@@ -673,7 +673,11 @@ async fn handle_db_inner(
let len = json_output.len();
let response = response
.body(Full::new(Bytes::from(json_output)))
.body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -689,6 +693,49 @@ async fn handle_db_inner(
Ok(response)
}
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let keys = backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info, keys).await?;
// always completes instantly in http2 mode
// but good just in case
client.ready().await.map_err(HttpConnError::from)?;
let (parts, body) = request.into_parts();
let mut req = Request::builder()
.method("POST")
.uri("http://proxy.local/sql");
*req.headers_mut().unwrap() = parts.headers;
let req = req.body(body).unwrap();
// todo: map body to count egress
let _metrics = client.metrics();
Ok(client
.send_request(req)
.await
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
}
impl QueryData {
async fn process(
self,