proxy: Relax endpoint check (#6503)

## Problem

http-over-sql allowes host to be in format api.aws.... however it's not
the case for the websocket flow.

## Summary of changes

Relax endpoint check for the ws serverless connections.
This commit is contained in:
Anna Khanova
2024-01-28 22:27:14 +01:00
committed by GitHub
parent 3a82430432
commit 8253cf1931
3 changed files with 18 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
@@ -35,11 +36,11 @@ use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::EndpointId;
use crate::RoleName;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
struct QueryData {
@@ -61,7 +62,6 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api";
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -188,9 +188,7 @@ fn get_conn_info(
}
}
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
let endpoint: EndpointId = endpoint.into();
let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?;
ctx.set_endpoint_id(Some(endpoint.clone()));
let pairs = connection_url.query_pairs();
@@ -227,8 +225,7 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result<bool, anyhow::Err
let (_, hostname_rest) = hostname
.split_once('.')
.ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?;
Ok(sni_hostname_rest == hostname_rest
&& sni_hostname_first == SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART)
Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI)
}
// TODO: return different http error codes