From 88d1a78260d19e743534f96ed05e80aa568b9879 Mon Sep 17 00:00:00 2001 From: Ruslan Talpa Date: Mon, 30 Jun 2025 12:30:33 +0300 Subject: [PATCH] cleanup the rest path code --- Cargo.lock | 65 ++-- proxy/src/serverless/http_util.rs | 158 +++++++- proxy/src/serverless/rest.rs | 499 ++++++-------------------- proxy/src/serverless/sql_over_http.rs | 171 +-------- 4 files changed, 322 insertions(+), 571 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb1be2e603..14baeb11ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -496,7 +496,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "once_cell", "p256 0.11.1", "percent-encoding", @@ -637,7 +637,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -655,7 +655,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.5", "http-body 1.0.0", "http-body-util", @@ -704,7 +704,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -738,7 +738,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -762,7 +762,7 @@ dependencies = [ "form_urlencoded", "futures-util", "headers", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -1310,7 +1310,7 @@ dependencies = [ "fail", "flate2", "futures", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "itertools 0.10.5", "jsonwebtoken", @@ -2642,7 +2642,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -2724,7 +2724,7 @@ dependencies = [ "base64 0.21.7", "bytes", "headers-core", - "http 1.1.0", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2736,7 +2736,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2814,9 +2814,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2841,7 +2841,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2852,7 +2852,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "pin-project-lite", ] @@ -2976,7 +2976,7 @@ dependencies = [ "futures-channel", "futures-util", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "httparse", "httpdate", @@ -3009,7 +3009,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "hyper-util", "rustls 0.22.4", @@ -3041,7 +3041,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "hyper 1.4.1", "pin-project-lite", @@ -4147,7 +4147,7 @@ checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" dependencies = [ "async-trait", "bytes", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "reqwest", ] @@ -4160,7 +4160,7 @@ checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" dependencies = [ "async-trait", "futures-core", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -4388,7 +4388,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "humantime-serde", @@ -5329,7 +5329,7 @@ dependencies = [ "hex", "hmac", "hostname", - "http 1.1.0", + "http 1.3.1", "http-body-util", "http-utils", "humantime", @@ -5771,7 +5771,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -5813,7 +5813,7 @@ checksum = "d1ccd3b55e711f91a9885a2fa6fbbb2e39db1776420b062efc058c6410f7e5e3" dependencies = [ "anyhow", "async-trait", - "http 1.1.0", + "http 1.3.1", "reqwest", "serde", "thiserror 1.0.69", @@ -5830,7 +5830,7 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "parking_lot 0.11.2", "reqwest", @@ -5851,7 +5851,7 @@ dependencies = [ "anyhow", "async-trait", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "matchit", "opentelemetry", "reqwest", @@ -6210,7 +6210,7 @@ dependencies = [ "fail", "futures", "hex", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "hyper 0.14.30", @@ -7096,6 +7096,7 @@ dependencies = [ "base64 0.22.1", "csv", "getrandom 0.2.11", + "http 1.3.1", "itertools 0.13.0", "lazy_static", "log", @@ -7717,7 +7718,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "percent-encoding", @@ -7741,7 +7742,7 @@ dependencies = [ "bytes", "flate2", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -7832,7 +7833,7 @@ dependencies = [ "base64 0.22.1", "bitflags 2.8.0", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "mime", "pin-project-lite", @@ -7853,7 +7854,7 @@ name = "tower-otel" version = "0.2.0" source = "git+https://github.com/mattiapenati/tower-otel?rev=56a7321053bcb72443888257b622ba0d43a11fcd#56a7321053bcb72443888257b622ba0d43a11fcd" dependencies = [ - "http 1.1.0", + "http 1.3.1", "opentelemetry", "pin-project", "tower-layer", @@ -8034,7 +8035,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8053,7 +8054,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index 2b0955f857..2fd8d5cee5 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -3,13 +3,28 @@ use anyhow::Context; use bytes::Bytes; -use http::{Response, StatusCode, HeaderName, HeaderValue}; +use http::{Response, StatusCode, HeaderName, HeaderValue, HeaderMap}; +use http::header::AUTHORIZATION; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; use http_utils::error::ApiError; use serde::Serialize; +use url::Url; use uuid::Uuid; +use crate::context::RequestContext; +use crate::config::{AuthenticationConfig, TlsConfig}; +use crate::auth::backend::{ComputeUserInfo}; +use crate::auth::{endpoint_sni, }; +use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::NeonOptions; +use super::conn_pool::ConnInfoWithAuth; +use super::error::{ConnInfoError, Credentials}; +use crate::types::{DbName, RoleName}; +use super::conn_pool::{AuthData}; +use super::conn_pool_lib::{ConnInfo}; + // Common header names used across serverless modules pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id"); pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); @@ -124,3 +139,144 @@ pub(crate) fn json_response( .map_err(|e| ApiError::InternalServerError(e.into()))?; Ok(response) } + + +pub(crate) fn get_conn_info( + config: &'static AuthenticationConfig, + ctx: &RequestContext, + connection_string: Option<&str>, + headers: &HeaderMap, + tls: Option<&TlsConfig>, +) -> Result { + let connection_url = match connection_string { + Some(connection_string) => Url::parse(connection_string)?, + None => { + let connection_string = headers + .get(&CONN_STRING) + .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; + Url::parse(connection_string)? + } + }; + + let protocol = connection_url.scheme(); + if protocol != "postgres" && protocol != "postgresql" { + return Err(ConnInfoError::IncorrectScheme); + } + + let mut url_path = connection_url + .path_segments() + .ok_or(ConnInfoError::MissingDbName)?; + + let dbname: DbName = + urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); + ctx.set_dbname(dbname.clone()); + + let username = RoleName::from(urlencoding::decode(connection_url.username())?); + if username.is_empty() { + return Err(ConnInfoError::MissingUsername); + } + ctx.set_user(username.clone()); + // TODO: make sure this is right in the context of rest broker + let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { + if !config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::Password)); + } + + let auth = auth + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; + AuthData::Jwt( + auth.strip_prefix("Bearer ") + .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? + .into(), + ) + } else if let Some(pass) = connection_url.password() { + // wrong credentials provided + if config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); + } + + AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { + std::borrow::Cow::Borrowed(b) => b.into(), + std::borrow::Cow::Owned(b) => b.into(), + }) + } else if config.accept_jwts { + return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); + } else { + return Err(ConnInfoError::MissingCredentials(Credentials::Password)); + }; + let endpoint = match connection_url.host() { + Some(url::Host::Domain(hostname)) => { + if let Some(tls) = tls { + endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)? + } else { + hostname + .split_once('.') + .map_or(hostname, |(prefix, _)| prefix) + .into() + } + } + Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { + return Err(ConnInfoError::MissingHostname); + } + }; + ctx.set_endpoint_id(endpoint.clone()); + + let pairs = connection_url.query_pairs(); + + let mut options = Option::None; + + let mut params = StartupMessageParams::default(); + params.insert("user", &username); + params.insert("database", &dbname); + for (key, value) in pairs { + params.insert(&key, &value); + if key == "options" { + options = Some(NeonOptions::parse_options_raw(&value)); + } + } + + // check the URL that was used, for metrics + { + let host_endpoint = headers + // get the host header + .get("host") + // extract the domain + .and_then(|h| { + let (host, _port) = h.to_str().ok()?.split_once(':')?; + Some(host) + }) + // get the endpoint prefix + .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix)); + + let kind = if host_endpoint == Some(&*endpoint) { + SniKind::Sni + } else { + SniKind::NoSni + }; + + let protocol = ctx.protocol(); + Metrics::get() + .proxy + .accepted_connections_by_sni + .inc(SniGroup { protocol, kind }); + } + + ctx.set_user_agent( + headers + .get(hyper::header::USER_AGENT) + .and_then(|h| h.to_str().ok()) + .map(Into::into), + ); + + let user_info = ComputeUserInfo { + endpoint, + user: username, + options: options.unwrap_or_default(), + }; + + let conn_info = ConnInfo { user_info, dbname }; + Ok(ConnInfoWithAuth { conn_info, auth }) +} \ No newline at end of file diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 5410df17c2..1cbd499b85 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -2,13 +2,9 @@ use std::sync::Arc; use bytes::Bytes; use http::Method; use http::header::AUTHORIZATION; -use http_body_util::combinators::BoxBody; -use http_body_util::Full; -use http_body_util::{BodyExt}; +use http_body_util::{combinators::BoxBody, Full, BodyExt}; use http_utils::error::ApiError; -use hyper::body::Incoming; -use hyper::http::{HeaderName, HeaderValue}; -use hyper::{HeaderMap, Request, Response, StatusCode}; +use hyper::{body::Incoming, http::{HeaderName, HeaderValue}, Request, Response, StatusCode}; use indexmap::IndexMap; use serde::{Deserialize, Deserializer}; use super::http_conn_pool::{self, Send,}; @@ -16,24 +12,22 @@ use serde_json::{value::RawValue, Value as JsonValue}; use tokio_util::sync::CancellationToken; use tracing::{error, info}; use typed_json::json; -use url::Url; use super::backend::{LocalProxyConnError, PoolingBackend}; -use super::conn_pool::{AuthData, ConnInfoWithAuth}; +use super::conn_pool::{AuthData}; use super::conn_pool_lib::{ConnInfo}; use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError}; -use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY}; +use super::http_util::{ + json_response, uuid_to_header_value, get_conn_info, + NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY +}; use super::json::{JsonConversionError}; -use crate::auth::backend::{ComputeUserInfo, ComputeCredentialKeys}; -use crate::auth::{endpoint_sni, }; -use crate::config::{AuthenticationConfig, ProxyConfig, TlsConfig}; +use crate::auth::backend::{ComputeCredentialKeys}; +use crate::config::{ProxyConfig, }; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{read_body_with_limit}; -use crate::metrics::{Metrics, SniGroup, SniKind}; -use crate::pqproto::StartupMessageParams; -use crate::proxy::NeonOptions; -use crate::serverless::backend::HttpConnError; -use crate::types::{DbName, RoleName}; +use crate::metrics::{Metrics, }; +use super::backend::HttpConnError; use crate::cache::{TimedLru}; use crate::types::{EndpointCacheKey}; use ouroboros::self_referencing; @@ -52,6 +46,7 @@ use subzero_core::{ config::{db_schemas, db_allowed_select_functions, role_claim_key, /*to_tuple*/}, parser::postgrest::parse, permissions::{check_safe_functions}, + content_range_header, content_range_status }; static MAX_SCHEMA_SIZE: usize = 1024 * 1024 * 5; // 5MB @@ -59,9 +54,6 @@ static MAX_HTTP_BODY_SIZE: usize = 10 * 1024 * 1024; // 10MB limit static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL; const CONFIGURATION_SQL: &str = POSTGRESQL_CONFIGURATION_SQL; -static HEADERS_TO_FORWARD: &[&HeaderName] = &[ - &AUTHORIZATION, -]; // A wrapper around the DbSchema that allows for self-referencing #[self_referencing] @@ -101,7 +93,7 @@ pub struct ApiConfig { // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint pub(crate) type DbSchemaCache = TimedLru>; impl DbSchemaCache { - pub async fn get_local_or_remote(&self, + pub async fn get_cached_or_remote(&self, endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, connection_string: &str, @@ -153,61 +145,29 @@ impl DbSchemaCache { ctx: &RequestContext, ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { - let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); - let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); - req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())); - req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()); - req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()); - req = req.header(AUTHORIZATION, auth_header); - //req = req.header(&ARRAY_MODE, HeaderValue::from_str("true").unwrap()); - req = req.header(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()); + let headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + (&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()), + (&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()), + (&AUTHORIZATION, auth_header.clone()), + (&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()), + ]; - let body = json!({"query": CONFIGURATION_SQL}).to_string(); - let body_boxed = Full::new(Bytes::from(body)) - .map_err(|never| match never {}) // Convert Infallible to hyper::Error - .boxed(); - let req = req - .body(body_boxed) - .map_err(|_| RestError::SubzeroCore(InternalError { - message: "Failed to build request".to_string() - }))?; + let body = serde_json::json!({"query": CONFIGURATION_SQL}); + let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?; - // send the request to the local proxy - let response = client - .inner - .inner - .send_request(req) - .await - .map_err(LocalProxyConnError::from) - .map_err(HttpConnError::from)?; - - let response_status = response.status(); - if response_status != StatusCode::OK { return Err(RestError::SubzeroCore(InternalError { - message: "Failed to get configuration data".to_string() + message: "Failed to get endpoint configuration".to_string() })); } - // Capture the response body - let response_body = response - .collect() - .await - .map_err(ReadPayloadError::from)? - .to_bytes(); - - //println!("response_body: {:?}", response_body); - - let mut response_json: serde_json::Value = serde_json::from_slice(&response_body) - .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; - - //println!("response_json: {:?}", response_json); let rows = response_json["rows"].as_array_mut() .ok_or_else(|| RestError::SubzeroCore(InternalError { message: "Missing 'rows' array in second result".to_string() }))?; - //println!("rows: {:?}", rows); + if rows.is_empty() { return Err(RestError::SubzeroCore(InternalError { message: "No rows in second result".to_string() @@ -217,74 +177,41 @@ impl DbSchemaCache { // Extract columns from the first (and only) row let mut row = &mut rows[0]; let config_string = extract_string(&mut row, "config").unwrap_or_default(); - //println!("config_string: {:?}", config_string); - // Parse the JSON response and extract the body content efficiently + + // Parse the configuration response let api_config: ApiConfig = serde_json::from_str(&config_string) .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; - //println!("api_config: {:?}", api_config); - // now that we have the api_config let's run the second INTROSPECTION_SQL query - let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); - let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); - req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())); - req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()); - req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()); - req = req.header(AUTHORIZATION, auth_header); - //req = req.header(&ARRAY_MODE, HeaderValue::from_str("true").unwrap()); - req = req.header(&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()); - let body = json!({ + let headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + (&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()), + (&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()), + (&AUTHORIZATION, auth_header.clone()), + (&RAW_TEXT_OUTPUT, HeaderValue::from_str("true").unwrap()), + ]; + + let body = serde_json::json!({ "query": INTROSPECTION_SQL, "params": [ &api_config.db_schemas, false, // include_roles_with_login false, // use_internal_permissions ] - }).to_string(); - let body_boxed = Full::new(Bytes::from(body)) - .map_err(|never| match never {}) // Convert Infallible to hyper::Error - .boxed(); - let req = req - .body(body_boxed) - .map_err(|_| RestError::SubzeroCore(InternalError { - message: "Failed to build request".to_string() - }))?; - - // send the request to the local proxy - let response = client - .inner - .inner - .send_request(req) - .await - .map_err(LocalProxyConnError::from) - .map_err(HttpConnError::from)?; - - let response_status = response.status(); + }); + let (response_status, mut response_json) = make_local_proxy_request(client, headers, body).await?; if response_status != StatusCode::OK { return Err(RestError::SubzeroCore(InternalError { - message: "Failed to get introspection data".to_string() + message: "Failed to get endpoint schema".to_string() })); } - let response_body = response - .collect() - .await - .map_err(ReadPayloadError::from)? - .to_bytes(); - - //println!("second response_body: {:?}", response_body); - - let mut response_json: serde_json::Value = serde_json::from_slice(&response_body) - .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; - - //println!("response_json: {:?}", response_json); let rows = response_json["rows"].as_array_mut() .ok_or_else(|| RestError::SubzeroCore(InternalError { message: "Missing 'rows' array in second result".to_string() }))?; - //println!("rows: {:?}", rows); if rows.is_empty() { return Err(RestError::SubzeroCore(InternalError { message: "No rows in second result".to_string() @@ -297,21 +224,14 @@ impl DbSchemaCache { let string_size = json_schema.len(); if string_size > MAX_SCHEMA_SIZE { - // return Err(RestError::SubzeroCore(InternalError { - // message: format!("Schema is too large: {} bytes, max is {} bytes", string_size, MAX_SCHEMA_SIZE) - // })); return Err(RestError::SchemaTooLarge(MAX_SCHEMA_SIZE, string_size)); } - - let schema_owned = DbSchemaOwned::new(json_schema, |s| { serde_json::from_str::(s.as_str()) .map_err(|e| JsonDeserialize { source: e }) }); - //let schema = schema_owned.borrow_schema().as_ref().unwrap(); - //println!("schema!!!!!: {:?}", schema); // check if schema is an ok result let schema = schema_owned.borrow_schema(); if schema.is_ok() { @@ -323,143 +243,8 @@ impl DbSchemaCache { } } -fn get_conn_info( - config: &'static AuthenticationConfig, - ctx: &RequestContext, - connection_string: &str, - headers: &HeaderMap, - tls: Option<&TlsConfig>, -) -> Result { - // let connection_string = headers - // .get(&CONN_STRING) - // .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? - // .to_str() - // .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; - let connection_url = Url::parse(connection_string)?; - - let protocol = connection_url.scheme(); - if protocol != "postgres" && protocol != "postgresql" { - return Err(ConnInfoError::IncorrectScheme); - } - - let mut url_path = connection_url - .path_segments() - .ok_or(ConnInfoError::MissingDbName)?; - - let dbname: DbName = - urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); - ctx.set_dbname(dbname.clone()); - - let username = RoleName::from(urlencoding::decode(connection_url.username())?); - if username.is_empty() { - return Err(ConnInfoError::MissingUsername); - } - ctx.set_user(username.clone()); - // TODO: make sure this is right in the context of rest broker - let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { - if !config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - } - - let auth = auth - .to_str() - .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; - AuthData::Jwt( - auth.strip_prefix("Bearer ") - .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? - .into(), - ) - } else if let Some(pass) = connection_url.password() { - // wrong credentials provided - if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } - - AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { - std::borrow::Cow::Borrowed(b) => b.into(), - std::borrow::Cow::Owned(b) => b.into(), - }) - } else if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } else { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - }; - let endpoint = match connection_url.host() { - Some(url::Host::Domain(hostname)) => { - if let Some(tls) = tls { - endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)? - } else { - hostname - .split_once('.') - .map_or(hostname, |(prefix, _)| prefix) - .into() - } - } - Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { - return Err(ConnInfoError::MissingHostname); - } - }; - ctx.set_endpoint_id(endpoint.clone()); - - let pairs = connection_url.query_pairs(); - - let mut options = Option::None; - - let mut params = StartupMessageParams::default(); - params.insert("user", &username); - params.insert("database", &dbname); - for (key, value) in pairs { - params.insert(&key, &value); - if key == "options" { - options = Some(NeonOptions::parse_options_raw(&value)); - } - } - - // check the URL that was used, for metrics - { - let host_endpoint = headers - // get the host header - .get("host") - // extract the domain - .and_then(|h| { - let (host, _port) = h.to_str().ok()?.split_once(':')?; - Some(host) - }) - // get the endpoint prefix - .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix)); - - let kind = if host_endpoint == Some(&*endpoint) { - SniKind::Sni - } else { - SniKind::NoSni - }; - - let protocol = ctx.protocol(); - Metrics::get() - .proxy - .accepted_connections_by_sni - .inc(SniGroup { protocol, kind }); - } - - ctx.set_user_agent( - headers - .get(hyper::header::USER_AGENT) - .and_then(|h| h.to_str().ok()) - .map(Into::into), - ); - - let user_info = ComputeUserInfo { - endpoint, - user: username, - options: options.unwrap_or_default(), - }; - - let conn_info = ConnInfo { user_info, dbname }; - Ok(ConnInfoWithAuth { conn_info, auth }) -} - -// A type to represent a postgres errors +// A type to represent a postgresql errors // we use our own type (instead of postgres_client::Error) because we get the error from the json response #[derive(Debug, thiserror::Error)] pub(crate) struct PostgresError { @@ -568,33 +353,14 @@ impl HttpCodeError for RestError { } } } - -// Helper functions for the rest broker - -fn content_range_header(lower: i64, upper: i64, total: Option) -> String { - //debug!("content_range_header: lower: {}, upper: {}, total: {:?}", lower, upper, total); - let range_string = if total != Some(0) && lower <= upper { - format!("{lower}-{upper}") - } else { - "*".to_string() - }; - let total_string = match total { - Some(t) => format!("{t}"), - None => "*".to_string(), - }; - format!("{range_string}/{total_string}") -} - -fn content_range_status(lower: i64, upper: i64, total: Option) -> u16 { - //debug!("content_range_status: lower: {}, upper: {}, total: {:?}", lower, upper, total); - match (lower, upper, total) { - //(_, _, None) => 200, - (l, _, Some(t)) if l > t => 406, - (l, u, Some(t)) if (1 + u - l) < t => 206, - _ => 200, +impl From for RestError { + fn from(e: SubzeroCoreError) -> Self { + RestError::SubzeroCore(e) } } +// Helper functions for the rest broker + fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> { "select " + if env.is_empty() { @@ -606,52 +372,6 @@ fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> { } } -fn current_schema(db_schemas: &Vec, method: &Method, headers: &HeaderMap) -> Result { - match (db_schemas.len() > 1, method, headers.get("accept-profile"), headers.get("content-profile")) { - (false, ..) => Ok(db_schemas.first().unwrap_or(&"_inexistent_".to_string()).clone()), - (_, &Method::DELETE, _, Some(content_profile_header)) - | (_, &Method::POST, _, Some(content_profile_header)) - | (_, &Method::PATCH, _, Some(content_profile_header)) - | (_, &Method::PUT, _, Some(content_profile_header)) => { - match content_profile_header.to_str() { - Ok(content_profile_str) => { - let content_profile = String::from(content_profile_str); - if db_schemas.contains(&content_profile) { - Ok(content_profile) - } else { - Err(SubzeroCoreError::UnacceptableSchema { - schemas: db_schemas.clone(), - }) - } - } - Err(_) => Err(SubzeroCoreError::UnacceptableSchema { - schemas: db_schemas.clone(), - }) - } - } - (_, _, Some(accept_profile_header), _) => { - match accept_profile_header.to_str() { - Ok(accept_profile_str) => { - let accept_profile = String::from(accept_profile_str); - if db_schemas.contains(&accept_profile) { - Ok(accept_profile) - } else { - Err(SubzeroCoreError::UnacceptableSchema { - schemas: db_schemas.clone(), - }) - } - } - Err(_) => Err(SubzeroCoreError::UnacceptableSchema { - schemas: db_schemas.clone(), - }) - } - } - _ => Ok(db_schemas.first().unwrap_or(&"_inexistent_".to_string()).clone()), - } -} - -fn to_core_error(e: SubzeroCoreError) -> RestError { RestError::SubzeroCore(e) } - // TODO: see about removing the need for cloning the values (inner things are &Cow already) fn to_sql_param(p: &Param) -> JsonValue { match p { @@ -690,6 +410,55 @@ fn extract_string(json: &mut serde_json::Value, key: &str) -> Option { } } +async fn make_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: Vec<(&HeaderName, HeaderValue)>, + body: JsonValue, +) -> Result<(StatusCode, JsonValue), RestError> { + let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); + let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); + let req_headers = req.headers_mut().unwrap(); + // Add all provided headers to the request + for (header_name, header_value) in headers.into_iter() { + req_headers.insert(header_name, header_value); + } + + let body_string = body.to_string(); + let body_boxed = Full::new(Bytes::from(body_string)) + .map_err(|never| match never {}) // Convert Infallible to hyper::Error + .boxed(); + + let req = req + .body(body_boxed) + .map_err(|_| RestError::SubzeroCore(InternalError { + message: "Failed to build request".to_string() + }))?; + + // Send the request to the local proxy + let response = client + .inner + .inner + .send_request(req) + .await + .map_err(LocalProxyConnError::from) + .map_err(HttpConnError::from)?; + + let response_status = response.status(); + + // Capture the response body + let response_body = response + .collect() + .await + .map_err(ReadPayloadError::from)? + .to_bytes(); + + // Parse the JSON response + let response_json: serde_json::Value = serde_json::from_slice(&response_body) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + Ok((response_status, response_json)) +} + pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestContext, @@ -856,7 +625,7 @@ async fn handle_inner( let conn_info = get_conn_info( &config.authentication_config, ctx, - &connection_string, + Some(&connection_string), request.headers(), // todo: race condition? // we're unlikely to change the common names. @@ -910,7 +679,7 @@ async fn handle_rest_inner( let (parts, originial_body) = request.into_parts(); let headers_map = parts.headers; let auth_header = headers_map.get(AUTHORIZATION).unwrap(); - let entry = db_schema_cache.get_local_or_remote(&endpoint_cache_key, auth_header, &connection_string, &mut client, &ctx).await?; + let entry = db_schema_cache.get_cached_or_remote(&endpoint_cache_key, auth_header, &connection_string, &mut client, &ctx).await?; let (api_config, db_schema_owned) = entry.as_ref(); let db_schema = db_schema_owned.borrow_schema().as_ref().map_err(|_| RestError::SubzeroCore(InternalError { message: "Failed to get schema".to_string() }))?; @@ -927,8 +696,7 @@ async fn handle_rest_inner( let db_allowed_select_functions = api_config.db_allowed_select_functions.iter().map(|s| s.as_str()).collect::>(); // end hardcoded values - - + // extract the jwt claims (we'll need them later to set the role and env) let jwt_claims = match jwt_parsed.keys { ComputeCredentialKeys::JwtPayload(payload_bytes) => { @@ -978,12 +746,13 @@ async fn handle_rest_inner( }?; // pick the current schema from the headers (or the first one from config) - let schema_name = ¤t_schema(db_schemas, &method, &headers_map).map_err(RestError::SubzeroCore)?; + //let schema_name = ¤t_schema(db_schemas, &method, &headers_map).map_err(RestError::SubzeroCore)?; + let schema_name = db_schema.pick_current_schema(&method_str, &headers_map).map_err(RestError::SubzeroCore)?; // add the content-profile header to the response let mut response_headers = vec![]; if db_schemas.len() > 1 { - response_headers.push(("Content-Profile".to_string(), schema_name.clone())); + response_headers.push(("Content-Profile".to_string(), schema_name.to_string())); } // parse the query string into a Vec<(&str, &str)> @@ -994,17 +763,6 @@ async fn handle_rest_inner( let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect(); - let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); - let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); - - // todo(conradludgate): maybe auth-broker should parse these and re-serialize - // these instead just to ensure they remain normalised. - for &h in HEADERS_TO_FORWARD { - if let Some(hv) = headers_map.get(h) { - req = req.header(h, hv); - } - } - // convert the headers map to a HashMap<&str, &str> let headers: HashMap<&str, &str> = headers_map.iter() .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__"))) @@ -1028,7 +786,7 @@ async fn handle_rest_inner( // replace "*" with the list of columns the user has access to // so that he does not encounter permission errors - // replace_select_star(db_schema, schema_name, role, &mut api_request.query).map_err(to_core_error)?; + // replace_select_star(db_schema, schema_name, role, &mut api_request.query)?; let role_str = match role { Some(r) => r, @@ -1036,14 +794,14 @@ async fn handle_rest_inner( }; // this is not relevant when acting as PostgREST // if !disable_internal_permissions { - // check_privileges(db_schema, schema_name, role_str, &api_request).map_err(to_core_error)?; + // check_privileges(db_schema, schema_name, role_str, &api_request)?; // } - check_safe_functions(&api_request, &db_allowed_select_functions).map_err(to_core_error)?; + check_safe_functions(&api_request, &db_allowed_select_functions)?; // this is not relevant when acting as PostgREST // if !disable_internal_permissions { - // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query).map_err(to_core_error)?; + // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?; // } // when using internal privileges not switch "current_role" @@ -1089,20 +847,24 @@ async fn handle_rest_inner( } // generate the sql statements let (env_statement, env_parameters, _) = generate(fmt_env_query(&env)); - let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env).map_err(to_core_error)?); + let (main_statement, main_parameters, _) = generate(fmt_main_query(db_schema, api_request.schema_name, &api_request, &env)?); - req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())); - req = req.header(&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()); - req = req.header(&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()); - req = req.header(&ALLOW_POOL, HeaderValue::from_str("true").unwrap()); + let mut headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + (&CONN_STRING, HeaderValue::from_str(connection_string).unwrap()), + (&AUTHORIZATION, auth_header.clone()), + (&TXN_ISOLATION_LEVEL, HeaderValue::from_str("ReadCommitted").unwrap()), + (&ALLOW_POOL, HeaderValue::from_str("true").unwrap()), + ]; + if api_request.read_only { - req = req.header(&TXN_READ_ONLY, HeaderValue::from_str("true").unwrap()); + headers.push((&TXN_READ_ONLY, HeaderValue::from_str("true").unwrap())); } // convert the parameters from subzero core representation to a Vec let env_parameters_json = env_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::>(); let main_parameters_json = main_parameters.iter().map(|p| to_sql_param(&p.to_param())).collect::>(); - let body: String = json!({ + let body = serde_json::json!({ "queries": [ { "query": env_statement, @@ -1113,42 +875,13 @@ async fn handle_rest_inner( "params": main_parameters_json, } ] - }).to_string(); - - let body_boxed = Full::new(Bytes::from(body)) - .map_err(|never| match never {}) // Convert Infallible to hyper::Error - .boxed(); - - let req = req - .body(body_boxed) - .map_err(|_| RestError::SubzeroCore(InternalError { - message: "Failed to build request".to_string() - }))?; + }); // todo: map body to count egress let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? // send the request to the local proxy - let response = client - .inner - .inner - .send_request(req) - .await - .map_err(LocalProxyConnError::from) - .map_err(HttpConnError::from)?; - - let response_status = response.status(); - - // Capture the response body - let response_body = response - .collect() - .await - .map_err(ReadPayloadError::from)? - .to_bytes(); - - // Parse the JSON response and extract the body content efficiently - let mut response_json: serde_json::Value = serde_json::from_slice(&response_body) - .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + let (response_status, mut response_json) = make_local_proxy_request(&mut client, headers, body).await?; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index d464451a72..09a7394f50 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -3,44 +3,40 @@ use std::sync::Arc; use bytes::Bytes; use futures::future::{Either, select, try_join}; use futures::{StreamExt, TryFutureExt}; -use http::Method; -use http::header::AUTHORIZATION; -use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http::{Method, header::AUTHORIZATION}; +use http_body_util::{combinators::BoxBody, Full, BodyExt}; use http_utils::error::ApiError; use hyper::body::Incoming; -use hyper::http::{HeaderName, HeaderValue}; -use hyper::{HeaderMap, Request, Response, StatusCode, header}; +use hyper::{http::{HeaderName, HeaderValue}, Request, Response, StatusCode, header}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; use serde::Serialize; -use serde_json::Value; -use serde_json::value::RawValue; +use serde_json::{Value, value::RawValue}; use tokio::time::{self, Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; use typed_json::json; -use url::Url; + use super::backend::{LocalProxyConnError, PoolingBackend}; -use super::conn_pool::{AuthData, ConnInfoWithAuth}; +use super::conn_pool::{AuthData,}; use super::conn_pool_lib::{self, ConnInfo}; -use super::error::{HttpCodeError, ConnInfoError, Credentials, ReadPayloadError}; -use super::http_util::{json_response, uuid_to_header_value, NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE}; +use super::error::{HttpCodeError, ConnInfoError, ReadPayloadError}; +use super::http_util::{ + json_response, uuid_to_header_value, get_conn_info, + NEON_REQUEST_ID, CONN_STRING, RAW_TEXT_OUTPUT, ARRAY_MODE, ALLOW_POOL, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, TXN_DEFERRABLE +}; use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::auth::{endpoint_sni}; -use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig}; +use crate::auth::backend::{ComputeCredentialKeys,}; + +use crate::config::{HttpConfig, ProxyConfig,}; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{read_body_with_limit}; -use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; -use crate::pqproto::StartupMessageParams; -use crate::proxy::NeonOptions; +use crate::metrics::{HttpDirection, Metrics, }; use crate::serverless::backend::HttpConnError; -use crate::types::{DbName, RoleName}; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; use crate::util::run_until_cancelled; @@ -78,142 +74,6 @@ where Ok(json_to_pg_text(json)) } -fn get_conn_info( - config: &'static AuthenticationConfig, - ctx: &RequestContext, - headers: &HeaderMap, - tls: Option<&TlsConfig>, -) -> Result { - let connection_string = headers - .get(&CONN_STRING) - .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? - .to_str() - .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; - - let connection_url = Url::parse(connection_string)?; - - let protocol = connection_url.scheme(); - if protocol != "postgres" && protocol != "postgresql" { - return Err(ConnInfoError::IncorrectScheme); - } - - let mut url_path = connection_url - .path_segments() - .ok_or(ConnInfoError::MissingDbName)?; - - let dbname: DbName = - urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); - ctx.set_dbname(dbname.clone()); - - let username = RoleName::from(urlencoding::decode(connection_url.username())?); - if username.is_empty() { - return Err(ConnInfoError::MissingUsername); - } - ctx.set_user(username.clone()); - - let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { - if !config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - } - - let auth = auth - .to_str() - .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; - AuthData::Jwt( - auth.strip_prefix("Bearer ") - .ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))? - .into(), - ) - } else if let Some(pass) = connection_url.password() { - // wrong credentials provided - if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } - - AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { - std::borrow::Cow::Borrowed(b) => b.into(), - std::borrow::Cow::Owned(b) => b.into(), - }) - } else if config.accept_jwts { - return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt)); - } else { - return Err(ConnInfoError::MissingCredentials(Credentials::Password)); - }; - - let endpoint = match connection_url.host() { - Some(url::Host::Domain(hostname)) => { - if let Some(tls) = tls { - endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)? - } else { - hostname - .split_once('.') - .map_or(hostname, |(prefix, _)| prefix) - .into() - } - } - Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { - return Err(ConnInfoError::MissingHostname); - } - }; - ctx.set_endpoint_id(endpoint.clone()); - - let pairs = connection_url.query_pairs(); - - let mut options = Option::None; - - let mut params = StartupMessageParams::default(); - params.insert("user", &username); - params.insert("database", &dbname); - for (key, value) in pairs { - params.insert(&key, &value); - if key == "options" { - options = Some(NeonOptions::parse_options_raw(&value)); - } - } - - // check the URL that was used, for metrics - { - let host_endpoint = headers - // get the host header - .get("host") - // extract the domain - .and_then(|h| { - let (host, _port) = h.to_str().ok()?.split_once(':')?; - Some(host) - }) - // get the endpoint prefix - .map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix)); - - let kind = if host_endpoint == Some(&*endpoint) { - SniKind::Sni - } else { - SniKind::NoSni - }; - - let protocol = ctx.protocol(); - Metrics::get() - .proxy - .accepted_connections_by_sni - .inc(SniGroup { protocol, kind }); - } - - ctx.set_user_agent( - headers - .get(hyper::header::USER_AGENT) - .and_then(|h| h.to_str().ok()) - .map(Into::into), - ); - - let user_info = ComputeUserInfo { - endpoint, - user: username, - options: options.unwrap_or_default(), - }; - - let conn_info = ConnInfo { user_info, dbname }; - Ok(ConnInfoWithAuth { conn_info, auth }) -} - pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestContext, @@ -544,6 +404,7 @@ async fn handle_inner( let conn_info = get_conn_info( &config.authentication_config, ctx, + None, request.headers(), // todo: race condition? // we're unlikely to change the common names.