From c624317b0e2ff73c6c01904e4d883c256c078f22 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 13 Aug 2024 15:34:10 -0500 Subject: [PATCH] Decode the database name in SQL/HTTP connections A url::Url does not hand you back a URL decoded value for path values, so we must decode them ourselves. Link: https://docs.rs/url/2.5.2/url/struct.Url.html#method.path Link: https://docs.rs/url/2.5.2/url/struct.Url.html#method.path_segments Signed-off-by: Tristan Partin --- proxy/src/serverless/sql_over_http.rs | 4 +++- test_runner/regress/test_proxy.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index e5b6536328..c41df07a4d 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -34,6 +34,7 @@ use tracing::error; use tracing::info; use typed_json::json; use url::Url; +use urlencoding; use utils::http::error::ApiError; use crate::auth::backend::ComputeUserInfo; @@ -168,7 +169,8 @@ fn get_conn_info( .path_segments() .ok_or(ConnInfoError::MissingDbName)?; - let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into(); + let dbname: DbName = + urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into(); ctx.set_dbname(dbname.clone()); let username = RoleName::from(urlencoding::decode(connection_url.username())?); diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index f446f4f200..d2b8c2ed8b 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -2,6 +2,7 @@ import asyncio import json import subprocess import time +import urllib.parse from typing import Any, List, Optional, Tuple import psycopg2 @@ -275,6 +276,31 @@ def test_sql_over_http(static_proxy: NeonProxy): assert res["rowCount"] is None +def test_sql_over_http_db_name_with_space(static_proxy: NeonProxy): + db = "db with spaces" + static_proxy.safe_psql_many( + ( + f'create database "{db}"', + "create role http with login password 'http' superuser", + ) + ) + + def q(sql: str, params: Optional[List[Any]] = None) -> Any: + params = params or [] + connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/{urllib.parse.quote(db)}" + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps({"query": sql, "params": params}), + headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr}, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200, response.text + return response.json() + + rows = q("select 42 as answer")["rows"] + assert rows == [{"answer": 42}] + + def test_sql_over_http_output_options(static_proxy: NeonProxy): static_proxy.safe_psql("create role http2 with login password 'http2' superuser")