diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index bdb79f2517..5bf7667a1f 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -2,7 +2,8 @@ use crate::{ auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI, + EndpointId, RoleName, }; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -54,10 +55,10 @@ impl ComputeUserInfoMaybeEndpoint { } } -pub fn endpoint_sni<'a>( - sni: &'a str, +pub fn endpoint_sni( + sni: &str, common_names: &HashSet, -) -> Result<&'a str, ComputeUserInfoParseError> { +) -> Result, ComputeUserInfoParseError> { let Some((subdomain, common_name)) = sni.split_once('.') else { return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() }); }; @@ -66,7 +67,10 @@ pub fn endpoint_sni<'a>( cn: common_name.into(), }); } - Ok(subdomain) + if subdomain == SERVERLESS_DRIVER_SNI { + return Ok(None); + } + Ok(Some(EndpointId::from(subdomain))) } impl ComputeUserInfoMaybeEndpoint { @@ -85,7 +89,6 @@ impl ComputeUserInfoMaybeEndpoint { // record the values if we have them ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_user(user.clone()); - ctx.set_endpoint_id(sni.map(EndpointId::from)); // Project name might be passed via PG's command-line options. let endpoint_option = params @@ -103,7 +106,7 @@ impl ComputeUserInfoMaybeEndpoint { let endpoint_from_domain = if let Some(sni_str) = sni { if let Some(cn) = common_names { - Some(EndpointId::from(endpoint_sni(sni_str, cn)?)) + endpoint_sni(sni_str, cn)? } else { None } @@ -117,12 +120,13 @@ impl ComputeUserInfoMaybeEndpoint { Some(Err(InconsistentProjectNames { domain, option })) } // Invariant: project name may not contain certain characters. - (a, b) => a.or(b).map(|name| match project_name_valid(&name) { + (a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) { false => Err(MalformedProjectName(name)), true => Ok(name), }), } .transpose()?; + ctx.set_endpoint_id(endpoint.clone()); info!(%user, project = endpoint.as_deref(), "credentials"); if sni.is_some() { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 8af008394a..dfef4ccdfa 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -41,6 +41,8 @@ use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; use utils::http::{error::ApiError, json::json_response}; +pub const SERVERLESS_DRIVER_SNI: &str = "api"; + pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f108ab34ab..1e2ddaa2ff 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::bail; +use anyhow::Context; use futures::pin_mut; use futures::StreamExt; use hyper::body::HttpBody; @@ -35,11 +36,11 @@ use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; -use crate::EndpointId; use crate::RoleName; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; +use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] struct QueryData { @@ -61,7 +62,6 @@ enum Payload { const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB -const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api"; static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); @@ -188,9 +188,7 @@ fn get_conn_info( } } - let endpoint = endpoint_sni(hostname, &tls.common_names)?; - - let endpoint: EndpointId = endpoint.into(); + let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?; ctx.set_endpoint_id(Some(endpoint.clone())); let pairs = connection_url.query_pairs(); @@ -227,8 +225,7 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result