From 09e62e9b98651b8ce3a5b3884acea7cf6664f49a Mon Sep 17 00:00:00 2001 From: Ruslan Talpa Date: Mon, 23 Jun 2025 10:11:06 +0300 Subject: [PATCH] subzero integration WIP2 --- proxy/src/serverless/rest.rs | 52 ++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 7e1aab6fdd..caf9436656 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -45,6 +45,15 @@ use crate::types::{DbName, RoleName}; pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id"); static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); +static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); +static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); +static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); +static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level"); +static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only"); +static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable"); + +static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); + #[derive(Debug, thiserror::Error)] @@ -96,6 +105,7 @@ impl UserFacingError for ConnInfoError { fn get_conn_info( config: &'static AuthenticationConfig, ctx: &RequestContext, + connection_string: &str, headers: &HeaderMap, tls: Option<&TlsConfig>, ) -> Result { @@ -105,7 +115,6 @@ fn get_conn_info( // .to_str() // .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; - let connection_string = "postgresql://authenticated@foo.local.neon.build/database"; let connection_url = Url::parse(connection_string)?; let protocol = connection_url.scheme(); @@ -126,7 +135,7 @@ fn get_conn_info( return Err(ConnInfoError::MissingUsername); } ctx.set_user(username.clone()); - + // TODO: make sure this is right in the context of rest broker let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { if !config.accept_jwts { return Err(ConnInfoError::MissingCredentials(Credentials::Password)); @@ -155,7 +164,6 @@ fn get_conn_info( } else { return Err(ConnInfoError::MissingCredentials(Credentials::Password)); }; - info!("auth passed !!!!!!!!!!!: {auth:?}"); let endpoint = match connection_url.host() { Some(url::Host::Domain(hostname)) => { if let Some(tls) = tls { @@ -237,7 +245,6 @@ pub(crate) async fn handle( backend: Arc, cancel: CancellationToken, ) -> Result>, ApiError> { - info!("entered rest:handle!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); let result = handle_inner(cancel, config, &ctx, request, backend).await; let mut response = match result { @@ -531,7 +538,6 @@ async fn handle_inner( request: Request, backend: Arc, ) -> Result>, RestError> { - info!("entered rest:handle_inner!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); let _requeset_gauge = Metrics::get() .proxy .connection_requests @@ -541,9 +547,16 @@ async fn handle_inner( "handling interactive connection from client" ); + + let host = request.uri().host().unwrap_or("").split('.').next().unwrap_or(""); + let connection_string = format!("postgresql://authenticated@{}.local.neon.build/database", host); + + + let conn_info = get_conn_info( &config.authentication_config, ctx, + &connection_string, request.headers(), // todo: race condition? // we're unlikely to change the common names. @@ -556,7 +569,7 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - handle_rest_inner(ctx, request, conn_info.conn_info, jwt, backend).await + handle_rest_inner(ctx, request, &connection_string, conn_info.conn_info, jwt, backend).await } _ => { Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(Credentials::Password))) @@ -571,9 +584,20 @@ pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue { .expect("uuid hyphenated format should be all valid header characters") } +static HEADERS_TO_STRIP: &[&HeaderName] = &[ + //AUTHORIZATION, + &NEON_REQUEST_ID, + &CONN_STRING, + &RAW_TEXT_OUTPUT, + &ARRAY_MODE, + &TXN_ISOLATION_LEVEL, + &TXN_READ_ONLY, + &TXN_DEFERRABLE, +]; async fn handle_rest_inner( ctx: &RequestContext, request: Request, + connection_string: &str, conn_info: ConnInfo, jwt: String, backend: Arc, @@ -597,8 +621,21 @@ async fn handle_rest_inner( // req = req.header(h, hv); // } // } + // forward all headers except the ones in HEADERS_TO_STRIP + for (h, v) in parts.headers.iter() { + if !HEADERS_TO_STRIP.contains(&h) { + req = req.header(h, v); + } + } req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())); - + req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()); + + // let new_body: String = json!({ + // "query": "select 1 as one", + // "params": [], + // }).to_string(); + + let req = req .body(body) .expect("all headers and params received via hyper should be valid for request"); @@ -606,7 +643,6 @@ async fn handle_rest_inner( // todo: map body to count egress let _metrics = client.metrics(ctx); - info!("sending request to local proxy !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); Ok(client .inner .inner