diff --git a/compute_tools/src/http/middleware/authorize.rs b/compute_tools/src/http/middleware/authorize.rs
index 2d0f411d7a..2afc57ad9c 100644
--- a/compute_tools/src/http/middleware/authorize.rs
+++ b/compute_tools/src/http/middleware/authorize.rs
@@ -1,12 +1,10 @@
-use std::collections::HashSet;
-
use anyhow::{Result, anyhow};
use axum::{RequestExt, body::Body};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
-use compute_api::requests::ComputeClaims;
+use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope};
use futures::future::BoxFuture;
use http::{Request, Response, StatusCode};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
@@ -25,13 +23,14 @@ pub(in crate::http) struct Authorize {
impl Authorize {
pub fn new(compute_id: String, jwks: JwkSet) -> Self {
let mut validation = Validation::new(Algorithm::EdDSA);
- // Nothing is currently required
- validation.required_spec_claims = HashSet::new();
validation.validate_exp = true;
// Unused by the control plane
- validation.validate_aud = false;
- // Unused by the control plane
validation.validate_nbf = false;
+ // Unused by the control plane
+ validation.validate_aud = false;
+ validation.set_audience(&[COMPUTE_AUDIENCE]);
+ // Nothing is currently required
+ validation.set_required_spec_claims(&[] as &[&str; 0]);
Self {
compute_id,
@@ -64,11 +63,47 @@ impl AsyncAuthorizeRequest
for Authorize {
Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
};
- if data.claims.compute_id != compute_id {
- return Err(JsonResponse::error(
- StatusCode::UNAUTHORIZED,
- "invalid compute ID in authorization token claims",
- ));
+ match data.claims.scope {
+ // TODO: We should validate audience for every token, but
+ // instead of this ad-hoc validation, we should turn
+ // [`Validation::validate_aud`] on. This is merely a stopgap
+ // while we roll out `aud` deployment. We return a 401
+ // Unauthorized because when we eventually do use
+ // [`Validation`], we will hit the above `Err` match arm which
+ // returns 401 Unauthorized.
+ Some(ComputeClaimsScope::Admin) => {
+ let Some(ref audience) = data.claims.audience else {
+ return Err(JsonResponse::error(
+ StatusCode::UNAUTHORIZED,
+ "missing audience in authorization token claims",
+ ));
+ };
+
+ if audience != COMPUTE_AUDIENCE {
+ return Err(JsonResponse::error(
+ StatusCode::UNAUTHORIZED,
+ "invalid audience in authorization token claims",
+ ));
+ }
+ }
+
+ // If the scope is not [`ComputeClaimsScope::Admin`], then we
+ // must validate the compute_id
+ _ => {
+ let Some(ref claimed_compute_id) = data.claims.compute_id else {
+ return Err(JsonResponse::error(
+ StatusCode::FORBIDDEN,
+ "missing compute_id in authorization token claims",
+ ));
+ };
+
+ if *claimed_compute_id != compute_id {
+ return Err(JsonResponse::error(
+ StatusCode::FORBIDDEN,
+ "invalid compute ID in authorization token claims",
+ ));
+ }
+ }
}
// Make claims available to any subsequent middleware or request
diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs
index 6f55c0310f..44698f7b23 100644
--- a/control_plane/src/bin/neon_local.rs
+++ b/control_plane/src/bin/neon_local.rs
@@ -16,6 +16,7 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
+use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use control_plane::broker::StorageBroker;
use control_plane::endpoint::ComputeControlPlane;
@@ -705,6 +706,9 @@ struct EndpointStopCmdArgs {
struct EndpointGenerateJwtCmdArgs {
#[clap(help = "Postgres endpoint id")]
endpoint_id: String,
+
+ #[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
+ scope: Option,
}
#[derive(clap::Subcommand)]
@@ -1540,12 +1544,16 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
endpoint.stop(&args.mode, args.destroy)?;
}
EndpointCmd::GenerateJwt(args) => {
- let endpoint_id = &args.endpoint_id;
- let endpoint = cplane
- .endpoints
- .get(endpoint_id)
- .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
- let jwt = endpoint.generate_jwt()?;
+ let endpoint = {
+ let endpoint_id = &args.endpoint_id;
+
+ cplane
+ .endpoints
+ .get(endpoint_id)
+ .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?
+ };
+
+ let jwt = endpoint.generate_jwt(args.scope)?;
print!("{jwt}");
}
diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs
index 4071b620d6..0b16339a6f 100644
--- a/control_plane/src/endpoint.rs
+++ b/control_plane/src/endpoint.rs
@@ -45,7 +45,9 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
-use compute_api::requests::{ComputeClaims, ConfigurationRequest};
+use compute_api::requests::{
+ COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
+};
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig,
};
@@ -630,9 +632,17 @@ impl Endpoint {
}
/// Generate a JWT with the correct claims.
- pub fn generate_jwt(&self) -> Result {
+ pub fn generate_jwt(&self, scope: Option) -> Result {
self.env.generate_auth_token(&ComputeClaims {
- compute_id: self.endpoint_id.clone(),
+ audience: match scope {
+ Some(ComputeClaimsScope::Admin) => Some(COMPUTE_AUDIENCE.to_owned()),
+ _ => Some(self.endpoint_id.clone()),
+ },
+ compute_id: match scope {
+ Some(ComputeClaimsScope::Admin) => None,
+ _ => Some(self.endpoint_id.clone()),
+ },
+ scope,
})
}
@@ -903,7 +913,7 @@ impl Endpoint {
self.external_http_address.port()
),
)
- .bearer_auth(self.generate_jwt()?)
+ .bearer_auth(self.generate_jwt(None::)?)
.send()
.await?;
@@ -980,7 +990,7 @@ impl Endpoint {
self.external_http_address.port()
))
.header(CONTENT_TYPE.as_str(), "application/json")
- .bearer_auth(self.generate_jwt()?)
+ .bearer_auth(self.generate_jwt(None::)?)
.body(
serde_json::to_string(&ConfigurationRequest {
spec,
diff --git a/libs/compute_api/src/requests.rs b/libs/compute_api/src/requests.rs
index 98f2fc297c..40d34eccea 100644
--- a/libs/compute_api/src/requests.rs
+++ b/libs/compute_api/src/requests.rs
@@ -1,16 +1,55 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
+use std::str::FromStr;
+
use serde::{Deserialize, Serialize};
use crate::privilege::Privilege;
use crate::responses::ComputeCtlConfig;
use crate::spec::{ComputeSpec, ExtVersion, PgIdent};
+/// The value to place in the [`ComputeClaims::audience`] claim.
+pub static COMPUTE_AUDIENCE: &str = "compute";
+
+#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
+#[serde(rename_all = "snake_case")]
+/// Available scopes for a compute's JWT.
+pub enum ComputeClaimsScope {
+ /// An admin-scoped token allows access to all of `compute_ctl`'s authorized
+ /// facilities.
+ Admin,
+}
+
+impl FromStr for ComputeClaimsScope {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result {
+ match s {
+ "admin" => Ok(ComputeClaimsScope::Admin),
+ _ => Err(anyhow::anyhow!("invalid compute claims scope \"{s}\"")),
+ }
+ }
+}
+
/// When making requests to the `compute_ctl` external HTTP server, the client
/// must specify a set of claims in `Authorization` header JWTs such that
/// `compute_ctl` can authorize the request.
#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(rename = "snake_case")]
pub struct ComputeClaims {
- pub compute_id: String,
+ /// The compute ID that will validate the token. The only case in which this
+ /// can be [`None`] is if [`Self::scope`] is
+ /// [`ComputeClaimsScope::Admin`].
+ pub compute_id: Option,
+
+ /// The scope of what the token authorizes.
+ pub scope: Option,
+
+ /// The recipient the token is intended for.
+ ///
+ /// See [RFC 7519](https://www.rfc-editor.org/rfc/rfc7519#section-4.1.3) for
+ /// more information.
+ #[serde(rename = "aud")]
+ pub audience: Option,
}
/// Request of the /configure API
diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py
index 652c38f5c3..beed1dcd93 100644
--- a/test_runner/fixtures/endpoint/http.py
+++ b/test_runner/fixtures/endpoint/http.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import urllib.parse
+from enum import StrEnum
from typing import TYPE_CHECKING, final
import requests
@@ -14,6 +15,17 @@ if TYPE_CHECKING:
from requests import PreparedRequest
+COMPUTE_AUDIENCE = "compute"
+"""
+The value to place in the `aud` claim.
+"""
+
+
+@final
+class ComputeClaimsScope(StrEnum):
+ ADMIN = "admin"
+
+
@final
class BearerAuth(AuthBase):
"""
diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py
index b5d69b5ab6..3be78719d7 100644
--- a/test_runner/fixtures/neon_cli.py
+++ b/test_runner/fixtures/neon_cli.py
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
Any,
)
+ from fixtures.endpoint.http import ComputeClaimsScope
from fixtures.pg_version import PgVersion
@@ -535,12 +536,16 @@ class NeonLocalCli(AbstractNeonCli):
res.check_returncode()
return res
- def endpoint_generate_jwt(self, endpoint_id: str) -> str:
+ def endpoint_generate_jwt(
+ self, endpoint_id: str, scope: ComputeClaimsScope | None = None
+ ) -> str:
"""
Generate a JWT for making requests to the endpoint's external HTTP
server.
"""
args = ["endpoint", "generate-jwt", endpoint_id]
+ if scope:
+ args += ["--scope", str(scope)]
cmd = self.raw_cli(args)
cmd.check_returncode()
diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index 47d1228c61..133be5c045 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -51,7 +51,7 @@ from fixtures.common_types import (
TimelineId,
)
from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS
-from fixtures.endpoint.http import EndpointHttpClient
+from fixtures.endpoint.http import ComputeClaimsScope, EndpointHttpClient
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.neon_cli import NeonLocalCli, Pagectl
@@ -4218,7 +4218,7 @@ class Endpoint(PgProtocol, LogUtils):
self.config(config_lines)
- self.__jwt = self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id)
+ self.__jwt = self.generate_jwt()
return self
@@ -4265,6 +4265,14 @@ class Endpoint(PgProtocol, LogUtils):
return self
+ def generate_jwt(self, scope: ComputeClaimsScope | None = None) -> str:
+ """
+ Generate a JWT for making requests to the endpoint's external HTTP
+ server.
+ """
+ assert self.endpoint_id is not None
+ return self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id, scope)
+
def endpoint_path(self) -> Path:
"""Path to endpoint directory"""
assert self.endpoint_id
diff --git a/test_runner/regress/test_compute_http.py b/test_runner/regress/test_compute_http.py
new file mode 100644
index 0000000000..ce31ff0fe6
--- /dev/null
+++ b/test_runner/regress/test_compute_http.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from http.client import FORBIDDEN, UNAUTHORIZED
+from typing import TYPE_CHECKING
+
+import jwt
+import pytest
+from fixtures.endpoint.http import COMPUTE_AUDIENCE, ComputeClaimsScope, EndpointHttpClient
+from fixtures.utils import run_only_on_default_postgres
+from requests import RequestException
+
+if TYPE_CHECKING:
+ from fixtures.neon_fixtures import NeonEnv
+
+
+@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version")
+def test_compute_no_scope_claim(neon_simple_env: NeonEnv):
+ """
+ Test that if the JWT scope is not admin and no compute_id is specified,
+ the external HTTP server returns a 403 Forbidden error.
+ """
+ env = neon_simple_env
+
+ endpoint = env.endpoints.create_start("main")
+
+ # Encode nothing in the token
+ token = jwt.encode({}, env.auth_keys.priv, algorithm="EdDSA")
+
+ # Create an admin-scoped HTTP client
+ client = EndpointHttpClient(
+ external_port=endpoint.external_http_port,
+ internal_port=endpoint.internal_http_port,
+ jwt=token,
+ )
+
+ try:
+ client.status()
+ pytest.fail("Exception should have been raised")
+ except RequestException as e:
+ assert e.response is not None
+ assert e.response.status_code == FORBIDDEN
+
+
+@pytest.mark.parametrize(
+ "audience",
+ (COMPUTE_AUDIENCE, "invalid", None),
+ ids=["with_audience", "with_invalid_audience", "without_audience"],
+)
+@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version")
+def test_compute_admin_scope_claim(neon_simple_env: NeonEnv, audience: str | None):
+ """
+ Test that an admin-scoped JWT can access the compute's external HTTP server
+ without the compute_id being specified in the claims.
+ """
+ env = neon_simple_env
+
+ endpoint = env.endpoints.create_start("main")
+
+ data = {"scope": str(ComputeClaimsScope.ADMIN)}
+ if audience:
+ data["aud"] = audience
+
+ token = jwt.encode(data, env.auth_keys.priv, algorithm="EdDSA")
+
+ # Create an admin-scoped HTTP client
+ client = EndpointHttpClient(
+ external_port=endpoint.external_http_port,
+ internal_port=endpoint.internal_http_port,
+ jwt=token,
+ )
+
+ try:
+ client.status()
+ if audience != COMPUTE_AUDIENCE:
+ pytest.fail("Exception should have been raised")
+ except RequestException as e:
+ assert e.response is not None
+ assert e.response.status_code == UNAUTHORIZED