From 4d37f891899add370f81d3238f7e5ac971fca004 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 11 Jan 2024 13:20:57 +0000 Subject: [PATCH] proxy http: remove need for exact endpoint match --- proxy/src/auth/credentials.rs | 6 +++--- proxy/src/serverless/sql_over_http.rs | 7 +++---- test_runner/regress/test_proxy.py | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index ada7f3614c..546eb516d4 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -56,7 +56,7 @@ impl ComputeUserInfoMaybeEndpoint { pub fn endpoint_sni<'a>( sni: &'a str, common_names: &HashSet, -) -> Result<&'a str, ComputeUserInfoParseError> { +) -> Result<(&'a str, &'a str), ComputeUserInfoParseError> { let Some((subdomain, common_name)) = sni.split_once('.') else { return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() }); }; @@ -65,7 +65,7 @@ pub fn endpoint_sni<'a>( cn: common_name.into(), }); } - Ok(subdomain) + Ok((subdomain, common_name)) } impl ComputeUserInfoMaybeEndpoint { @@ -102,7 +102,7 @@ impl ComputeUserInfoMaybeEndpoint { let project_from_domain = if let Some(sni_str) = sni { if let Some(cn) = common_names { - Some(SmolStr::from(endpoint_sni(sni_str, cn)?)) + Some(SmolStr::from(endpoint_sni(sni_str, cn)?.0)) } else { None } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 719559ed48..622f1844b7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -170,22 +170,21 @@ fn get_conn_info( let hostname = connection_url .host_str() .ok_or(anyhow::anyhow!("no host"))?; + let (endpoint, common_name) = endpoint_sni(hostname, &tls.common_names)?; let host_header = headers .get("host") .and_then(|h| h.to_str().ok()) .and_then(|h| h.split(':').next()); - if hostname != sni_hostname { + if !sni_hostname.ends_with(common_name) { return Err(anyhow::anyhow!("mismatched SNI hostname and hostname")); } else if let Some(h) = host_header { - if h != hostname { + if !h.ends_with(common_name) { return Err(anyhow::anyhow!("mismatched host header and hostname")); } } - let endpoint = endpoint_sni(hostname, &tls.common_names)?; - let endpoint: SmolStr = endpoint.into(); ctx.set_endpoint_id(Some(endpoint.clone())); diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 0f2cd9768f..7735f00e05 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -500,3 +500,21 @@ def test_sql_over_http_pool_custom_types(static_proxy: NeonProxy): "select array['foo'::foo, 'bar'::foo, 'baz'::foo] as data", ) assert response["rows"][0]["data"] == ["foo", "bar", "baz"] + +def test_sql_over_http_different_endpoint(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + def q(sql: str, params: Optional[List[Any]] = None) -> Any: + params = params or [] + connstr = f"postgresql://http:http@my-endpoint.{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps({"query": sql, "params": params}), + headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr}, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200, response.text + return response.json() + + rows = q("select 42 as answer")["rows"] + assert rows == [{"answer": 42}]