diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index dfea84953b..58dceb3bb6 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -7,6 +7,7 @@ mod credentials; pub use credentials::ClientCredentials; mod password_hack; +pub use password_hack::parse_endpoint_param; use password_hack::PasswordHackPayload; mod flow; @@ -44,10 +45,10 @@ pub enum AuthErrorImpl { #[error( "Endpoint ID is not specified. \ Either please upgrade the postgres client library (libpq) for SNI support \ - or pass the endpoint ID (first part of the domain name) as a parameter: '?options=project%3D'. \ + or pass the endpoint ID (first part of the domain name) as a parameter: '?options=endpoint%3D'. \ See more at https://neon.tech/sni" )] - MissingProjectName, + MissingEndpointName, #[error("password authentication failed for user '{0}'")] AuthFailed(Box), @@ -88,7 +89,7 @@ impl UserFacingError for AuthError { AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), MalformedPassword(_) => self.to_string(), - MissingProjectName => self.to_string(), + MissingEndpointName => self.to_string(), Io(_) => "Internal error".to_string(), } } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index d45806461e..dcc93ec04c 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -52,8 +52,8 @@ pub async fn password_hack( .authenticate() .await?; - info!(project = &payload.project, "received missing parameter"); - creds.project = Some(payload.project); + info!(project = &payload.endpoint, "received missing parameter"); + creds.project = Some(payload.endpoint); let mut node = api.wake_compute(extra, creds).await?; node.config.password(payload.password); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index b21cd79ddf..6787d82b71 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,6 +1,7 @@ //! User credentials used in authentication. -use crate::error::UserFacingError; +use crate::{auth::password_hack::parse_endpoint_param, error::UserFacingError}; +use itertools::Itertools; use pq_proto::StartupMessageParams; use std::collections::HashSet; use thiserror::Error; @@ -61,7 +62,15 @@ impl<'a> ClientCredentials<'a> { // Project name might be passed via PG's command-line options. let project_option = params .options_raw() - .and_then(|mut options| options.find_map(|opt| opt.strip_prefix("project="))) + .and_then(|options| { + // We support both `project` (deprecated) and `endpoint` options for backward compatibility. + // However, if both are present, we don't exactly know which one to use. + // Therefore we require that only one of them is present. + options + .filter_map(parse_endpoint_param) + .at_most_one() + .ok()? + }) .map(|name| name.to_string()); let project_from_domain = if let Some(sni_str) = sni { @@ -177,6 +186,51 @@ mod tests { Ok(()) } + #[test] + fn parse_endpoint_from_options() -> anyhow::Result<()> { + let options = StartupMessageParams::new([ + ("user", "john_doe"), + ("options", "-ckey=1 endpoint=bar -c geqo=off"), + ]); + + let creds = ClientCredentials::parse(&options, None, None)?; + assert_eq!(creds.user, "john_doe"); + assert_eq!(creds.project.as_deref(), Some("bar")); + + Ok(()) + } + + #[test] + fn parse_three_endpoints_from_options() -> anyhow::Result<()> { + let options = StartupMessageParams::new([ + ("user", "john_doe"), + ( + "options", + "-ckey=1 endpoint=one endpoint=two endpoint=three -c geqo=off", + ), + ]); + + let creds = ClientCredentials::parse(&options, None, None)?; + assert_eq!(creds.user, "john_doe"); + assert!(creds.project.is_none()); + + Ok(()) + } + + #[test] + fn parse_when_endpoint_and_project_are_in_options() -> anyhow::Result<()> { + let options = StartupMessageParams::new([ + ("user", "john_doe"), + ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"), + ]); + + let creds = ClientCredentials::parse(&options, None, None)?; + assert_eq!(creds.user, "john_doe"); + assert!(creds.project.is_none()); + + Ok(()) + } + #[test] fn parse_projects_identical() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]); diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 4b982c0c5e..190abc9b2e 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -91,7 +91,7 @@ impl AuthFlow<'_, S, PasswordHack> { // the user neither enabled SNI nor resorted to any other method // for passing the project name we rely on. We should show them // the most helpful error message and point to the documentation. - .ok_or(AuthErrorImpl::MissingProjectName)?; + .ok_or(AuthErrorImpl::MissingEndpointName)?; Ok(payload) } diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index 639809e18a..33441e8c88 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -6,27 +6,55 @@ use bstr::ByteSlice; pub struct PasswordHackPayload { - pub project: String, + pub endpoint: String, pub password: Vec, } impl PasswordHackPayload { pub fn parse(bytes: &[u8]) -> Option { // The format is `project=;`. - let mut iter = bytes.strip_prefix(b"project=")?.splitn_str(2, ";"); - let project = iter.next()?.to_str().ok()?.to_owned(); + let mut iter = bytes.splitn_str(2, ";"); + let endpoint = iter.next()?.to_str().ok()?; + let endpoint = parse_endpoint_param(endpoint)?.to_owned(); let password = iter.next()?.to_owned(); - Some(Self { project, password }) + Some(Self { endpoint, password }) } } +pub fn parse_endpoint_param(bytes: &str) -> Option<&str> { + bytes + .strip_prefix("project=") + .or_else(|| bytes.strip_prefix("endpoint=")) +} + #[cfg(test)] mod tests { use super::*; #[test] - fn parse_password_hack_payload() { + fn parse_endpoint_param_fn() { + let input = ""; + assert!(parse_endpoint_param(input).is_none()); + + let input = "project="; + assert_eq!(parse_endpoint_param(input), Some("")); + + let input = "project=foobar"; + assert_eq!(parse_endpoint_param(input), Some("foobar")); + + let input = "endpoint="; + assert_eq!(parse_endpoint_param(input), Some("")); + + let input = "endpoint=foobar"; + assert_eq!(parse_endpoint_param(input), Some("foobar")); + + let input = "other_option=foobar"; + assert!(parse_endpoint_param(input).is_none()); + } + + #[test] + fn parse_password_hack_payload_project() { let bytes = b""; assert!(PasswordHackPayload::parse(bytes).is_none()); @@ -34,13 +62,33 @@ mod tests { assert!(PasswordHackPayload::parse(bytes).is_none()); let bytes = b"project=;"; - let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); - assert_eq!(payload.project, ""); + let payload: PasswordHackPayload = + PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, ""); assert_eq!(payload.password, b""); let bytes = b"project=foobar;pass;word"; let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); - assert_eq!(payload.project, "foobar"); + assert_eq!(payload.endpoint, "foobar"); + assert_eq!(payload.password, b"pass;word"); + } + + #[test] + fn parse_password_hack_payload_endpoint() { + let bytes = b""; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint="; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint=;"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, ""); + assert_eq!(payload.password, b""); + + let bytes = b"endpoint=foobar;pass;word"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, "foobar"); assert_eq!(payload.password, b"pass;word"); } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index d277940a12..480acb88d9 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,4 +1,4 @@ -use crate::{cancellation::CancelClosure, error::UserFacingError}; +use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -279,7 +279,7 @@ fn filtered_options(params: &StartupMessageParams) -> Option { #[allow(unstable_name_collisions)] let options: String = params .options_raw()? - .filter(|opt| !opt.starts_with("project=")) + .filter(|opt| parse_endpoint_param(opt).is_none()) .intersperse(" ") // TODO: use impl from std once it's stabilized .collect(); diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index ee6349436b..ae914e384e 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -5,16 +5,18 @@ import pytest from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres -def test_proxy_select_1(static_proxy: NeonProxy): +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) +def test_proxy_select_1(static_proxy: NeonProxy, option_name: str): """ A simplest smoke test: check proxy against a local postgres instance. """ - out = static_proxy.safe_psql("select 1", options="project=generic-project-name") + out = static_proxy.safe_psql("select 1", options=f"{option_name}=generic-project-name") assert out[0][0] == 1 -def test_password_hack(static_proxy: NeonProxy): +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) +def test_password_hack(static_proxy: NeonProxy, option_name: str): """ Check the PasswordHack auth flow: an alternative to SCRAM auth for clients which can't provide the project/endpoint name via SNI or `options`. @@ -23,11 +25,12 @@ def test_password_hack(static_proxy: NeonProxy): user = "borat" password = "password" static_proxy.safe_psql( - f"create role {user} with login password '{password}'", options="project=irrelevant" + f"create role {user} with login password '{password}'", + options=f"{option_name}=irrelevant", ) # Note the format of `magic`! - magic = f"project=irrelevant;{password}" + magic = f"{option_name}=irrelevant;{password}" static_proxy.safe_psql("select 1", sslsni=0, user=user, password=magic) # Must also check that invalid magic won't be accepted. @@ -56,55 +59,62 @@ async def test_link_auth(vanilla_pg: VanillaPostgres, link_proxy: NeonProxy): assert out == "42" -def test_proxy_options(static_proxy: NeonProxy): +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) +def test_proxy_options(static_proxy: NeonProxy, option_name: str): """ Check that we pass extra `options` to the PostgreSQL server: - * `project=...` shouldn't be passed at all (otherwise postgres will raise an error). + * `project=...` and `endpoint=...` shouldn't be passed at all + * (otherwise postgres will raise an error). * everything else should be passed as-is. """ - options = "project=irrelevant -cproxytest.option=value" + options = f"{option_name}=irrelevant -cproxytest.option=value" out = static_proxy.safe_psql("show proxytest.option", options=options) assert out[0][0] == "value" - options = "-c proxytest.foo=\\ str project=irrelevant" + options = f"-c proxytest.foo=\\ str {option_name}=irrelevant" out = static_proxy.safe_psql("show proxytest.foo", options=options) assert out[0][0] == " str" -def test_auth_errors(static_proxy: NeonProxy): +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) +def test_auth_errors(static_proxy: NeonProxy, option_name: str): """ Check that we throw very specific errors in some unsuccessful auth scenarios. """ # User does not exist with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", options="project=irrelevant") + static_proxy.connect(user="pinocchio", options=f"{option_name}=irrelevant") text = str(exprinfo.value).strip() assert text.endswith("password authentication failed for user 'pinocchio'") static_proxy.safe_psql( - "create role pinocchio with login password 'magic'", options="project=irrelevant" + "create role pinocchio with login password 'magic'", + options=f"{option_name}=irrelevant", ) # User exists, but password is missing with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", password=None, options="project=irrelevant") + static_proxy.connect(user="pinocchio", password=None, options=f"{option_name}=irrelevant") text = str(exprinfo.value).strip() assert text.endswith("password authentication failed for user 'pinocchio'") # User exists, but password is wrong with pytest.raises(psycopg2.Error) as exprinfo: - static_proxy.connect(user="pinocchio", password="bad", options="project=irrelevant") + static_proxy.connect(user="pinocchio", password="bad", options=f"{option_name}=irrelevant") text = str(exprinfo.value).strip() assert text.endswith("password authentication failed for user 'pinocchio'") # Finally, check that the user can connect - with static_proxy.connect(user="pinocchio", password="magic", options="project=irrelevant"): + with static_proxy.connect( + user="pinocchio", password="magic", options=f"{option_name}=irrelevant" + ): pass -def test_forward_params_to_client(static_proxy: NeonProxy): +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) +def test_forward_params_to_client(static_proxy: NeonProxy, option_name: str): """ Check that we forward all necessary PostgreSQL server params to client. """ @@ -130,7 +140,7 @@ def test_forward_params_to_client(static_proxy: NeonProxy): where name = any(%s) """ - with static_proxy.connect(options="project=irrelevant") as conn: + with static_proxy.connect(options=f"{option_name}=irrelevant") as conn: with conn.cursor() as cur: cur.execute(query, (reported_params_subset,)) for name, value in cur.fetchall(): @@ -138,17 +148,18 @@ def test_forward_params_to_client(static_proxy: NeonProxy): assert conn.get_parameter_status(name) == value +@pytest.mark.parametrize("option_name", ["project", "endpoint"]) @pytest.mark.timeout(5) -def test_close_on_connections_exit(static_proxy: NeonProxy): +def test_close_on_connections_exit(static_proxy: NeonProxy, option_name: str): # Open two connections, send SIGTERM, then ensure that proxy doesn't exit # until after connections close. - with static_proxy.connect(options="project=irrelevant"), static_proxy.connect( - options="project=irrelevant" + with static_proxy.connect(options=f"{option_name}=irrelevant"), static_proxy.connect( + options=f"{option_name}=irrelevant" ): static_proxy.terminate() with pytest.raises(subprocess.TimeoutExpired): static_proxy.wait_for_exit(timeout=2) # Ensure we don't accept any more connections with pytest.raises(psycopg2.OperationalError): - static_proxy.connect(options="project=irrelevant") + static_proxy.connect(options=f"{option_name}=irrelevant") static_proxy.wait_for_exit()