mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
put it all together
This commit is contained in:
@@ -620,6 +620,7 @@ mod tests {
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
|
||||
@@ -277,6 +277,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
},
|
||||
require_client_ip: false,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
|
||||
@@ -697,6 +697,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: true,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
|
||||
@@ -71,6 +71,7 @@ pub struct AuthenticationConfig {
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
|
||||
@@ -20,7 +20,8 @@ use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
@@ -364,7 +365,7 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -408,7 +409,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
@@ -431,7 +432,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Full::new(Bytes::new()))
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
use std::{io, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use hyper1::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
@@ -396,7 +393,7 @@ impl ConnectMechanism for HyperMechanism {
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
self.conn_info.clone(),
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
@@ -411,13 +408,7 @@ async fn connect_http2(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
) -> Result<
|
||||
(
|
||||
http2::SendRequest<Full<Bytes>>,
|
||||
http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
|
||||
),
|
||||
HttpConnError,
|
||||
> {
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), HttpConnError> {
|
||||
let mut addrs = lookup_host((host, port)).await?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use http_body_util::Full;
|
||||
use hyper1::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use std::collections::VecDeque;
|
||||
use std::ops::DerefMut;
|
||||
use std::{
|
||||
ops::Deref,
|
||||
sync::atomic::{self, AtomicUsize},
|
||||
@@ -18,14 +17,18 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||
|
||||
use tracing::{debug, error, Span};
|
||||
use tracing::{debug, error};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnPoolEntry {
|
||||
conn: http2::SendRequest<Full<Bytes>>,
|
||||
conn: Send,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
@@ -215,7 +218,7 @@ impl GlobalConnPool {
|
||||
);
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
Some(Client::new(client.conn, conn_info.clone(), client.aux))
|
||||
Some(Client::new(client.conn, client.aux))
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
@@ -263,9 +266,9 @@ impl GlobalConnPool {
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: http2::SendRequest<Full<Bytes>>,
|
||||
connection: http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client {
|
||||
@@ -314,7 +317,7 @@ pub(crate) fn poll_http2_client(
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
Client::new(client, conn_info, aux)
|
||||
Client::new(client, aux)
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@@ -327,32 +330,23 @@ impl Client {
|
||||
}
|
||||
|
||||
pub(crate) struct Client {
|
||||
span: Span,
|
||||
inner: http2::SendRequest<Full<Bytes>>,
|
||||
inner: Send,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_info: ConnInfo,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(self) fn new(
|
||||
inner: http2::SendRequest<Full<Bytes>>,
|
||||
conn_info: ConnInfo,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
span: Span::current(),
|
||||
conn_info,
|
||||
aux,
|
||||
}
|
||||
}
|
||||
pub(crate) fn inner(&mut self) -> &mut http2::SendRequest<Full<Bytes>> {
|
||||
&mut self.inner
|
||||
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
|
||||
Self { inner, aux }
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Client {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
impl Deref for Client {
|
||||
type Target = http2::SendRequest<Full<Bytes>>;
|
||||
type Target = Send;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
|
||||
@@ -5,13 +5,13 @@ use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
@@ -64,17 +64,24 @@ struct HttpErrorBody {
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn response_from_msg_and_status(
|
||||
msg: String,
|
||||
status: StatusCode,
|
||||
) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.body(
|
||||
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -83,14 +90,14 @@ impl HttpErrorBody {
|
||||
pub(crate) fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
@@ -247,7 +248,7 @@ pub(crate) async fn handle(
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
@@ -504,7 +505,7 @@ async fn handle_inner(
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
@@ -521,35 +522,20 @@ async fn handle_inner(
|
||||
);
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Password(pw) => {
|
||||
let res = handle_db_inner(
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
cancel,
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
&pw,
|
||||
auth,
|
||||
backend,
|
||||
)
|
||||
.await?;
|
||||
Ok(res)
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
let keys = backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let _client = backend
|
||||
.connect_to_local_proxy(ctx, conn_info.conn_info, keys)
|
||||
.await?;
|
||||
|
||||
todo!()
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -560,9 +546,9 @@ async fn handle_db_inner(
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
password: &[u8],
|
||||
auth: AuthData,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
@@ -605,14 +591,28 @@ async fn handle_db_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
password,
|
||||
)
|
||||
.await?;
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
&pw,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
@@ -673,7 +673,11 @@ async fn handle_db_inner(
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
.body(
|
||||
Full::new(Bytes::from(json_output))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -689,6 +693,49 @@ async fn handle_db_inner(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
let keys = backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info, keys).await?;
|
||||
|
||||
// always completes instantly in http2 mode
|
||||
// but good just in case
|
||||
client.ready().await.map_err(HttpConnError::from)?;
|
||||
|
||||
let (parts, body) = request.into_parts();
|
||||
let mut req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("http://proxy.local/sql");
|
||||
|
||||
*req.headers_mut().unwrap() = parts.headers;
|
||||
|
||||
let req = req.body(body).unwrap();
|
||||
|
||||
// todo: map body to count egress
|
||||
let _metrics = client.metrics();
|
||||
|
||||
Ok(client
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
async fn process(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user