From f9b3a2e059c4c1fd764da011e7a5364d86e60491 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 6 May 2025 14:51:10 -0500 Subject: [PATCH] Add scoping to compute_ctl JWT claims (#11639) Currently we only have an admin scope which allows a user to bypass the compute_id check. When the admin scope is provided, validate the audience of the JWT to be "compute". Closes: https://github.com/neondatabase/cloud/issues/27614 Signed-off-by: Tristan Partin --- .../src/http/middleware/authorize.rs | 59 +++++++++++--- control_plane/src/bin/neon_local.rs | 20 +++-- control_plane/src/endpoint.rs | 20 +++-- libs/compute_api/src/requests.rs | 41 +++++++++- test_runner/fixtures/endpoint/http.py | 12 +++ test_runner/fixtures/neon_cli.py | 7 +- test_runner/fixtures/neon_fixtures.py | 12 ++- test_runner/regress/test_compute_http.py | 78 +++++++++++++++++++ 8 files changed, 222 insertions(+), 27 deletions(-) create mode 100644 test_runner/regress/test_compute_http.py diff --git a/compute_tools/src/http/middleware/authorize.rs b/compute_tools/src/http/middleware/authorize.rs index 2d0f411d7a..2afc57ad9c 100644 --- a/compute_tools/src/http/middleware/authorize.rs +++ b/compute_tools/src/http/middleware/authorize.rs @@ -1,12 +1,10 @@ -use std::collections::HashSet; - use anyhow::{Result, anyhow}; use axum::{RequestExt, body::Body}; use axum_extra::{ TypedHeader, headers::{Authorization, authorization::Bearer}, }; -use compute_api::requests::ComputeClaims; +use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope}; use futures::future::BoxFuture; use http::{Request, Response, StatusCode}; use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet}; @@ -25,13 +23,14 @@ pub(in crate::http) struct Authorize { impl Authorize { pub fn new(compute_id: String, jwks: JwkSet) -> Self { let mut validation = Validation::new(Algorithm::EdDSA); - // Nothing is currently required - validation.required_spec_claims = HashSet::new(); validation.validate_exp = true; // Unused by the control plane - validation.validate_aud = false; - // Unused by the control plane validation.validate_nbf = false; + // Unused by the control plane + validation.validate_aud = false; + validation.set_audience(&[COMPUTE_AUDIENCE]); + // Nothing is currently required + validation.set_required_spec_claims(&[] as &[&str; 0]); Self { compute_id, @@ -64,11 +63,47 @@ impl AsyncAuthorizeRequest for Authorize { Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)), }; - if data.claims.compute_id != compute_id { - return Err(JsonResponse::error( - StatusCode::UNAUTHORIZED, - "invalid compute ID in authorization token claims", - )); + match data.claims.scope { + // TODO: We should validate audience for every token, but + // instead of this ad-hoc validation, we should turn + // [`Validation::validate_aud`] on. This is merely a stopgap + // while we roll out `aud` deployment. We return a 401 + // Unauthorized because when we eventually do use + // [`Validation`], we will hit the above `Err` match arm which + // returns 401 Unauthorized. + Some(ComputeClaimsScope::Admin) => { + let Some(ref audience) = data.claims.audience else { + return Err(JsonResponse::error( + StatusCode::UNAUTHORIZED, + "missing audience in authorization token claims", + )); + }; + + if audience != COMPUTE_AUDIENCE { + return Err(JsonResponse::error( + StatusCode::UNAUTHORIZED, + "invalid audience in authorization token claims", + )); + } + } + + // If the scope is not [`ComputeClaimsScope::Admin`], then we + // must validate the compute_id + _ => { + let Some(ref claimed_compute_id) = data.claims.compute_id else { + return Err(JsonResponse::error( + StatusCode::FORBIDDEN, + "missing compute_id in authorization token claims", + )); + }; + + if *claimed_compute_id != compute_id { + return Err(JsonResponse::error( + StatusCode::FORBIDDEN, + "invalid compute ID in authorization token claims", + )); + } + } } // Make claims available to any subsequent middleware or request diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 6f55c0310f..44698f7b23 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -16,6 +16,7 @@ use std::time::Duration; use anyhow::{Context, Result, anyhow, bail}; use clap::Parser; +use compute_api::requests::ComputeClaimsScope; use compute_api::spec::ComputeMode; use control_plane::broker::StorageBroker; use control_plane::endpoint::ComputeControlPlane; @@ -705,6 +706,9 @@ struct EndpointStopCmdArgs { struct EndpointGenerateJwtCmdArgs { #[clap(help = "Postgres endpoint id")] endpoint_id: String, + + #[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)] + scope: Option, } #[derive(clap::Subcommand)] @@ -1540,12 +1544,16 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res endpoint.stop(&args.mode, args.destroy)?; } EndpointCmd::GenerateJwt(args) => { - let endpoint_id = &args.endpoint_id; - let endpoint = cplane - .endpoints - .get(endpoint_id) - .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - let jwt = endpoint.generate_jwt()?; + let endpoint = { + let endpoint_id = &args.endpoint_id; + + cplane + .endpoints + .get(endpoint_id) + .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))? + }; + + let jwt = endpoint.generate_jwt(args.scope)?; print!("{jwt}"); } diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 4071b620d6..0b16339a6f 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -45,7 +45,9 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use anyhow::{Context, Result, anyhow, bail}; -use compute_api::requests::{ComputeClaims, ConfigurationRequest}; +use compute_api::requests::{ + COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest, +}; use compute_api::responses::{ ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig, }; @@ -630,9 +632,17 @@ impl Endpoint { } /// Generate a JWT with the correct claims. - pub fn generate_jwt(&self) -> Result { + pub fn generate_jwt(&self, scope: Option) -> Result { self.env.generate_auth_token(&ComputeClaims { - compute_id: self.endpoint_id.clone(), + audience: match scope { + Some(ComputeClaimsScope::Admin) => Some(COMPUTE_AUDIENCE.to_owned()), + _ => Some(self.endpoint_id.clone()), + }, + compute_id: match scope { + Some(ComputeClaimsScope::Admin) => None, + _ => Some(self.endpoint_id.clone()), + }, + scope, }) } @@ -903,7 +913,7 @@ impl Endpoint { self.external_http_address.port() ), ) - .bearer_auth(self.generate_jwt()?) + .bearer_auth(self.generate_jwt(None::)?) .send() .await?; @@ -980,7 +990,7 @@ impl Endpoint { self.external_http_address.port() )) .header(CONTENT_TYPE.as_str(), "application/json") - .bearer_auth(self.generate_jwt()?) + .bearer_auth(self.generate_jwt(None::)?) .body( serde_json::to_string(&ConfigurationRequest { spec, diff --git a/libs/compute_api/src/requests.rs b/libs/compute_api/src/requests.rs index 98f2fc297c..40d34eccea 100644 --- a/libs/compute_api/src/requests.rs +++ b/libs/compute_api/src/requests.rs @@ -1,16 +1,55 @@ //! Structs representing the JSON formats used in the compute_ctl's HTTP API. +use std::str::FromStr; + use serde::{Deserialize, Serialize}; use crate::privilege::Privilege; use crate::responses::ComputeCtlConfig; use crate::spec::{ComputeSpec, ExtVersion, PgIdent}; +/// The value to place in the [`ComputeClaims::audience`] claim. +pub static COMPUTE_AUDIENCE: &str = "compute"; + +#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +/// Available scopes for a compute's JWT. +pub enum ComputeClaimsScope { + /// An admin-scoped token allows access to all of `compute_ctl`'s authorized + /// facilities. + Admin, +} + +impl FromStr for ComputeClaimsScope { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "admin" => Ok(ComputeClaimsScope::Admin), + _ => Err(anyhow::anyhow!("invalid compute claims scope \"{s}\"")), + } + } +} + /// When making requests to the `compute_ctl` external HTTP server, the client /// must specify a set of claims in `Authorization` header JWTs such that /// `compute_ctl` can authorize the request. #[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename = "snake_case")] pub struct ComputeClaims { - pub compute_id: String, + /// The compute ID that will validate the token. The only case in which this + /// can be [`None`] is if [`Self::scope`] is + /// [`ComputeClaimsScope::Admin`]. + pub compute_id: Option, + + /// The scope of what the token authorizes. + pub scope: Option, + + /// The recipient the token is intended for. + /// + /// See [RFC 7519](https://www.rfc-editor.org/rfc/rfc7519#section-4.1.3) for + /// more information. + #[serde(rename = "aud")] + pub audience: Option, } /// Request of the /configure API diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 652c38f5c3..beed1dcd93 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import urllib.parse +from enum import StrEnum from typing import TYPE_CHECKING, final import requests @@ -14,6 +15,17 @@ if TYPE_CHECKING: from requests import PreparedRequest +COMPUTE_AUDIENCE = "compute" +""" +The value to place in the `aud` claim. +""" + + +@final +class ComputeClaimsScope(StrEnum): + ADMIN = "admin" + + @final class BearerAuth(AuthBase): """ diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index b5d69b5ab6..3be78719d7 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: Any, ) + from fixtures.endpoint.http import ComputeClaimsScope from fixtures.pg_version import PgVersion @@ -535,12 +536,16 @@ class NeonLocalCli(AbstractNeonCli): res.check_returncode() return res - def endpoint_generate_jwt(self, endpoint_id: str) -> str: + def endpoint_generate_jwt( + self, endpoint_id: str, scope: ComputeClaimsScope | None = None + ) -> str: """ Generate a JWT for making requests to the endpoint's external HTTP server. """ args = ["endpoint", "generate-jwt", endpoint_id] + if scope: + args += ["--scope", str(scope)] cmd = self.raw_cli(args) cmd.check_returncode() diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 47d1228c61..133be5c045 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -51,7 +51,7 @@ from fixtures.common_types import ( TimelineId, ) from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS -from fixtures.endpoint.http import EndpointHttpClient +from fixtures.endpoint.http import ComputeClaimsScope, EndpointHttpClient from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.neon_cli import NeonLocalCli, Pagectl @@ -4218,7 +4218,7 @@ class Endpoint(PgProtocol, LogUtils): self.config(config_lines) - self.__jwt = self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id) + self.__jwt = self.generate_jwt() return self @@ -4265,6 +4265,14 @@ class Endpoint(PgProtocol, LogUtils): return self + def generate_jwt(self, scope: ComputeClaimsScope | None = None) -> str: + """ + Generate a JWT for making requests to the endpoint's external HTTP + server. + """ + assert self.endpoint_id is not None + return self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id, scope) + def endpoint_path(self) -> Path: """Path to endpoint directory""" assert self.endpoint_id diff --git a/test_runner/regress/test_compute_http.py b/test_runner/regress/test_compute_http.py new file mode 100644 index 0000000000..ce31ff0fe6 --- /dev/null +++ b/test_runner/regress/test_compute_http.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from http.client import FORBIDDEN, UNAUTHORIZED +from typing import TYPE_CHECKING + +import jwt +import pytest +from fixtures.endpoint.http import COMPUTE_AUDIENCE, ComputeClaimsScope, EndpointHttpClient +from fixtures.utils import run_only_on_default_postgres +from requests import RequestException + +if TYPE_CHECKING: + from fixtures.neon_fixtures import NeonEnv + + +@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version") +def test_compute_no_scope_claim(neon_simple_env: NeonEnv): + """ + Test that if the JWT scope is not admin and no compute_id is specified, + the external HTTP server returns a 403 Forbidden error. + """ + env = neon_simple_env + + endpoint = env.endpoints.create_start("main") + + # Encode nothing in the token + token = jwt.encode({}, env.auth_keys.priv, algorithm="EdDSA") + + # Create an admin-scoped HTTP client + client = EndpointHttpClient( + external_port=endpoint.external_http_port, + internal_port=endpoint.internal_http_port, + jwt=token, + ) + + try: + client.status() + pytest.fail("Exception should have been raised") + except RequestException as e: + assert e.response is not None + assert e.response.status_code == FORBIDDEN + + +@pytest.mark.parametrize( + "audience", + (COMPUTE_AUDIENCE, "invalid", None), + ids=["with_audience", "with_invalid_audience", "without_audience"], +) +@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version") +def test_compute_admin_scope_claim(neon_simple_env: NeonEnv, audience: str | None): + """ + Test that an admin-scoped JWT can access the compute's external HTTP server + without the compute_id being specified in the claims. + """ + env = neon_simple_env + + endpoint = env.endpoints.create_start("main") + + data = {"scope": str(ComputeClaimsScope.ADMIN)} + if audience: + data["aud"] = audience + + token = jwt.encode(data, env.auth_keys.priv, algorithm="EdDSA") + + # Create an admin-scoped HTTP client + client = EndpointHttpClient( + external_port=endpoint.external_http_port, + internal_port=endpoint.internal_http_port, + jwt=token, + ) + + try: + client.status() + if audience != COMPUTE_AUDIENCE: + pytest.fail("Exception should have been raised") + except RequestException as e: + assert e.response is not None + assert e.response.status_code == UNAUTHORIZED