subzero integration WIP2

This commit is contained in:
Ruslan Talpa
2025-06-23 10:11:06 +03:00
parent e121da4bfc
commit 09e62e9b98

View File

@@ -45,6 +45,15 @@ use crate::types::{DbName, RoleName};
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
#[derive(Debug, thiserror::Error)]
@@ -96,6 +105,7 @@ impl UserFacingError for ConnInfoError {
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: &str,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
@@ -105,7 +115,6 @@ fn get_conn_info(
// .to_str()
// .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_string = "postgresql://authenticated@foo.local.neon.build/database";
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
@@ -126,7 +135,7 @@ fn get_conn_info(
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
@@ -155,7 +164,6 @@ fn get_conn_info(
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
info!("auth passed !!!!!!!!!!!: {auth:?}");
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
@@ -237,7 +245,6 @@ pub(crate) async fn handle(
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
info!("entered rest:handle!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -531,7 +538,6 @@ async fn handle_inner(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
info!("entered rest:handle_inner!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -541,9 +547,16 @@ async fn handle_inner(
"handling interactive connection from client"
);
let host = request.uri().host().unwrap_or("").split('.').next().unwrap_or("");
let connection_string = format!("postgresql://authenticated@{}.local.neon.build/database", host);
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
&connection_string,
request.headers(),
// todo: race condition?
// we're unlikely to change the common names.
@@ -556,7 +569,7 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_rest_inner(ctx, request, conn_info.conn_info, jwt, backend).await
handle_rest_inner(ctx, request, &connection_string, conn_info.conn_info, jwt, backend).await
}
_ => {
Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(Credentials::Password)))
@@ -571,9 +584,20 @@ pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
.expect("uuid hyphenated format should be all valid header characters")
}
static HEADERS_TO_STRIP: &[&HeaderName] = &[
//AUTHORIZATION,
&NEON_REQUEST_ID,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
async fn handle_rest_inner(
ctx: &RequestContext,
request: Request<Incoming>,
connection_string: &str,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
@@ -597,8 +621,21 @@ async fn handle_rest_inner(
// req = req.header(h, hv);
// }
// }
// forward all headers except the ones in HEADERS_TO_STRIP
for (h, v) in parts.headers.iter() {
if !HEADERS_TO_STRIP.contains(&h) {
req = req.header(h, v);
}
}
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap());
// let new_body: String = json!({
// "query": "select 1 as one",
// "params": [],
// }).to_string();
let req = req
.body(body)
.expect("all headers and params received via hyper should be valid for request");
@@ -606,7 +643,6 @@ async fn handle_rest_inner(
// todo: map body to count egress
let _metrics = client.metrics(ctx);
info!("sending request to local proxy !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
Ok(client
.inner
.inner