diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index a12dff7b72..391fb95e9e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -60,6 +60,7 @@ enum Payload { const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB +const SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART: &str = "api"; static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); @@ -177,10 +178,11 @@ fn get_conn_info( .and_then(|h| h.to_str().ok()) .and_then(|h| h.split(':').next()); - if hostname != sni_hostname { + // sni_hostname has to be either the same as hostname or the one used in serverless driver. + if !check_matches(&sni_hostname, hostname)? { return Err(anyhow::anyhow!("mismatched SNI hostname and hostname")); } else if let Some(h) = host_header { - if h != hostname { + if h != sni_hostname { return Err(anyhow::anyhow!("mismatched host header and hostname")); } } @@ -214,6 +216,20 @@ fn get_conn_info( }) } +fn check_matches(sni_hostname: &str, hostname: &str) -> Result { + if sni_hostname == hostname { + return Ok(true); + } + let (sni_hostname_first, sni_hostname_rest) = sni_hostname + .split_once('.') + .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?; + let (_, hostname_rest) = hostname + .split_once('.') + .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?; + Ok(sni_hostname_rest == hostname_rest + && sni_hostname_first == SERVERLESS_DRIVER_SNI_HOSTNAME_FIRST_PART) +} + // TODO: return different http error codes pub async fn handle( tls: &'static TlsConfig, diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 0f2cd9768f..1d62f09840 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -203,6 +203,21 @@ def test_close_on_connections_exit(static_proxy: NeonProxy): static_proxy.wait_for_exit() +def test_sql_over_http_serverless_driver(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + response = requests.post( + f"https://api.localtest.me:{static_proxy.external_http_port}/sql", + data=json.dumps({"query": "select 42 as answer", "params": []}), + headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr}, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200, response.text + rows = response.json()["rows"] + assert rows == [{"answer": 42}] + + def test_sql_over_http(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser")