mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: Relax endpoint check (#6503)
## Problem http-over-sql allowes host to be in format api.aws.... however it's not the case for the websocket flow. ## Summary of changes Relax endpoint check for the ws serverless connections.
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
|
||||
EndpointId, RoleName,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
@@ -54,10 +55,10 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_sni<'a>(
|
||||
sni: &'a str,
|
||||
pub fn endpoint_sni(
|
||||
sni: &str,
|
||||
common_names: &HashSet<String>,
|
||||
) -> Result<&'a str, ComputeUserInfoParseError> {
|
||||
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
|
||||
let Some((subdomain, common_name)) = sni.split_once('.') else {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
|
||||
};
|
||||
@@ -66,7 +67,10 @@ pub fn endpoint_sni<'a>(
|
||||
cn: common_name.into(),
|
||||
});
|
||||
}
|
||||
Ok(subdomain)
|
||||
if subdomain == SERVERLESS_DRIVER_SNI {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(EndpointId::from(subdomain)))
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
@@ -85,7 +89,6 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
// record the values if we have them
|
||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||
ctx.set_user(user.clone());
|
||||
ctx.set_endpoint_id(sni.map(EndpointId::from));
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let endpoint_option = params
|
||||
@@ -103,7 +106,7 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
let endpoint_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
Some(EndpointId::from(endpoint_sni(sni_str, cn)?))
|
||||
endpoint_sni(sni_str, cn)?
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -117,12 +120,13 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
Some(Err(InconsistentProjectNames { domain, option }))
|
||||
}
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
|
||||
false => Err(MalformedProjectName(name)),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
.transpose()?;
|
||||
ctx.set_endpoint_id(endpoint.clone());
|
||||
|
||||
info!(%user, project = endpoint.as_deref(), "credentials");
|
||||
if sni.is_some() {
|
||||
|
||||
@@ -41,6 +41,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
pub const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hyper::body::HttpBody;
|
||||
@@ -35,11 +36,11 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::EndpointId;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
use super::SERVERLESS_DRIVER_SNI;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct QueryData {
|
||||
@@ -61,7 +62,6 @@ enum Payload {
|
||||
|
||||
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
|
||||
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
|
||||
const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api";
|
||||
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
@@ -188,9 +188,7 @@ fn get_conn_info(
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
|
||||
|
||||
let endpoint: EndpointId = endpoint.into();
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
|
||||
ctx.set_endpoint_id(Some(endpoint.clone()));
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
@@ -227,8 +225,7 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Err
|
||||
let (_, hostname_rest) = hostname
|
||||
.split_once('.')
|
||||
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
|
||||
Ok(sni_hostname_rest == hostname_rest
|
||||
&& sni_hostname_first == SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART)
|
||||
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
|
||||
}
|
||||
|
||||
// TODO: return different http error codes
|
||||
|
||||
Reference in New Issue
Block a user