proxy: Relax endpoint check (#6503)

## Problem

http-over-sql allowes host to be in format api.aws.... however it's not
the case for the websocket flow.

## Summary of changes

Relax endpoint check for the ws serverless connections.
This commit is contained in:
Anna Khanova
2024-01-28 22:27:14 +01:00
committed by GitHub
parent 3a82430432
commit 8253cf1931
3 changed files with 18 additions and 15 deletions

View File

@@ -2,7 +2,8 @@
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -54,10 +55,10 @@ impl ComputeUserInfoMaybeEndpoint {
}
}
pub fn endpoint_sni<'a>(
sni: &'a str,
pub fn endpoint_sni(
sni: &str,
common_names: &HashSet<String>,
) -> Result<&'a str, ComputeUserInfoParseError> {
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
@@ -66,7 +67,10 @@ pub fn endpoint_sni<'a>(
cn: common_name.into(),
});
}
Ok(subdomain)
if subdomain == SERVERLESS_DRIVER_SNI {
return Ok(None);
}
Ok(Some(EndpointId::from(subdomain)))
}
impl ComputeUserInfoMaybeEndpoint {
@@ -85,7 +89,6 @@ impl ComputeUserInfoMaybeEndpoint {
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
ctx.set_endpoint_id(sni.map(EndpointId::from));
// Project name might be passed via PG's command-line options.
let endpoint_option = params
@@ -103,7 +106,7 @@ impl ComputeUserInfoMaybeEndpoint {
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
Some(EndpointId::from(endpoint_sni(sni_str, cn)?))
endpoint_sni(sni_str, cn)?
} else {
None
}
@@ -117,12 +120,13 @@ impl ComputeUserInfoMaybeEndpoint {
Some(Err(InconsistentProjectNames { domain, option }))
}
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| match project_name_valid(&name) {
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
false => Err(MalformedProjectName(name)),
true => Ok(name),
}),
}
.transpose()?;
ctx.set_endpoint_id(endpoint.clone());
info!(%user, project = endpoint.as_deref(), "credentials");
if sni.is_some() {

View File

@@ -41,6 +41,8 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
pub const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -35,11 +36,11 @@ use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::EndpointId;
use crate::RoleName;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
struct QueryData {
@@ -61,7 +62,6 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api";
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -188,9 +188,7 @@ fn get_conn_info(
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
let endpoint: EndpointId = endpoint.into();
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
ctx.set_endpoint_id(Some(endpoint.clone()));
let pairs = connection_url.query_pairs();
@@ -227,8 +225,7 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Err
let (_, hostname_rest) = hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
Ok(sni_hostname_rest == hostname_rest
&& sni_hostname_first == SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART)
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}
// TODO: return different http error codes