attachment_service: JWT auth enforcement (#6897)

## Problem
Attachment service does not do auth based on JWT scopes.

## Summary of changes
Do JWT-based permission checking for requests coming into the attachment
service.

Requests into the attachment service must use different tokens based on
the endpoint:
* `/control` and `/debug` require `admin` scope
* `/upcall` requires `generations_api` scope
* `/v1/...` requires `pageserverapi` scope

Requests into the pageserver from the attachment service must use
`pageserverapi` scope.
This commit is contained in:
Vlad Lazar
2024-02-26 18:17:06 +00:00
committed by GitHub
parent 0881d4f9e3
commit 5accf6e24a
12 changed files with 268 additions and 73 deletions

View File

@@ -0,0 +1,9 @@
use utils::auth::{AuthError, Claims, Scope};
/// Verifies that a token's scope is exactly the scope an endpoint demands.
///
/// Returns `Ok(())` when `claims.scope` equals `required_scope`; any other
/// scope yields an `AuthError`. The comparison is a plain equality check, so
/// no scope implies another.
pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> {
    if claims.scope == required_scope {
        Ok(())
    } else {
        Err(AuthError("Scope mismatch. Permission denied".into()))
    }
}

View File

@@ -10,8 +10,8 @@ use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api; use pageserver_client::mgmt_api;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use utils::auth::SwappableJwtAuth; use utils::auth::{Scope, SwappableJwtAuth};
use utils::http::endpoint::{auth_middleware, request_span}; use utils::http::endpoint::{auth_middleware, check_permission_with, request_span};
use utils::http::request::{must_get_query_param, parse_request_param}; use utils::http::request::{must_get_query_param, parse_request_param};
use utils::id::{TenantId, TimelineId}; use utils::id::{TenantId, TimelineId};
@@ -64,6 +64,8 @@ fn get_state(request: &Request<Body>) -> &HttpState {
/// Pageserver calls into this on startup, to learn which tenants it should attach /// Pageserver calls into this on startup, to learn which tenants it should attach
async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::GenerationsApi)?;
let reattach_req = json_request::<ReAttachRequest>(&mut req).await?; let reattach_req = json_request::<ReAttachRequest>(&mut req).await?;
let state = get_state(&req); let state = get_state(&req);
json_response(StatusCode::OK, state.service.re_attach(reattach_req).await?) json_response(StatusCode::OK, state.service.re_attach(reattach_req).await?)
@@ -72,6 +74,8 @@ async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiE
/// Pageserver calls into this before doing deletions, to confirm that it still /// Pageserver calls into this before doing deletions, to confirm that it still
/// holds the latest generation for the tenants with deletions enqueued /// holds the latest generation for the tenants with deletions enqueued
async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::GenerationsApi)?;
let validate_req = json_request::<ValidateRequest>(&mut req).await?; let validate_req = json_request::<ValidateRequest>(&mut req).await?;
let state = get_state(&req); let state = get_state(&req);
json_response(StatusCode::OK, state.service.validate(validate_req)) json_response(StatusCode::OK, state.service.validate(validate_req))
@@ -81,6 +85,8 @@ async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiEr
/// (in the real control plane this is unnecessary, because the same program is managing /// (in the real control plane this is unnecessary, because the same program is managing
/// generation numbers and doing attachments). /// generation numbers and doing attachments).
async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let attach_req = json_request::<AttachHookRequest>(&mut req).await?; let attach_req = json_request::<AttachHookRequest>(&mut req).await?;
let state = get_state(&req); let state = get_state(&req);
@@ -95,6 +101,8 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
} }
async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let inspect_req = json_request::<InspectRequest>(&mut req).await?; let inspect_req = json_request::<InspectRequest>(&mut req).await?;
let state = get_state(&req); let state = get_state(&req);
@@ -106,6 +114,8 @@ async fn handle_tenant_create(
service: Arc<Service>, service: Arc<Service>,
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::PageServerApi)?;
let create_req = json_request::<TenantCreateRequest>(&mut req).await?; let create_req = json_request::<TenantCreateRequest>(&mut req).await?;
json_response( json_response(
StatusCode::CREATED, StatusCode::CREATED,
@@ -164,6 +174,8 @@ async fn handle_tenant_location_config(
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let config_req = json_request::<TenantLocationConfigRequest>(&mut req).await?; let config_req = json_request::<TenantLocationConfigRequest>(&mut req).await?;
json_response( json_response(
StatusCode::OK, StatusCode::OK,
@@ -178,6 +190,8 @@ async fn handle_tenant_time_travel_remote_storage(
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let time_travel_req = json_request::<TenantTimeTravelRequest>(&mut req).await?; let time_travel_req = json_request::<TenantTimeTravelRequest>(&mut req).await?;
let timestamp_raw = must_get_query_param(&req, "travel_to")?; let timestamp_raw = must_get_query_param(&req, "travel_to")?;
@@ -211,6 +225,7 @@ async fn handle_tenant_delete(
req: Request<Body>, req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
deletion_wrapper(service, move |service| async move { deletion_wrapper(service, move |service| async move {
service.tenant_delete(tenant_id).await service.tenant_delete(tenant_id).await
@@ -223,6 +238,8 @@ async fn handle_tenant_timeline_create(
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let create_req = json_request::<TimelineCreateRequest>(&mut req).await?; let create_req = json_request::<TimelineCreateRequest>(&mut req).await?;
json_response( json_response(
StatusCode::CREATED, StatusCode::CREATED,
@@ -237,6 +254,8 @@ async fn handle_tenant_timeline_delete(
req: Request<Body>, req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?; let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?;
deletion_wrapper(service, move |service| async move { deletion_wrapper(service, move |service| async move {
@@ -250,6 +269,7 @@ async fn handle_tenant_timeline_passthrough(
req: Request<Body>, req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let Some(path) = req.uri().path_and_query() else { let Some(path) = req.uri().path_and_query() else {
// This should never happen, our request router only calls us if there is a path // This should never happen, our request router only calls us if there is a path
@@ -293,11 +313,15 @@ async fn handle_tenant_locate(
service: Arc<Service>, service: Arc<Service>,
req: Request<Body>, req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
json_response(StatusCode::OK, service.tenant_locate(tenant_id)?) json_response(StatusCode::OK, service.tenant_locate(tenant_id)?)
} }
async fn handle_node_register(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_node_register(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let register_req = json_request::<NodeRegisterRequest>(&mut req).await?; let register_req = json_request::<NodeRegisterRequest>(&mut req).await?;
let state = get_state(&req); let state = get_state(&req);
state.service.node_register(register_req).await?; state.service.node_register(register_req).await?;
@@ -305,17 +329,23 @@ async fn handle_node_register(mut req: Request<Body>) -> Result<Response<Body>,
} }
async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_node_list(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let state = get_state(&req); let state = get_state(&req);
json_response(StatusCode::OK, state.service.node_list().await?) json_response(StatusCode::OK, state.service.node_list().await?)
} }
async fn handle_node_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_node_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let state = get_state(&req); let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?; let node_id: NodeId = parse_request_param(&req, "node_id")?;
json_response(StatusCode::OK, state.service.node_drop(node_id).await?) json_response(StatusCode::OK, state.service.node_drop(node_id).await?)
} }
async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let node_id: NodeId = parse_request_param(&req, "node_id")?; let node_id: NodeId = parse_request_param(&req, "node_id")?;
let config_req = json_request::<NodeConfigureRequest>(&mut req).await?; let config_req = json_request::<NodeConfigureRequest>(&mut req).await?;
if node_id != config_req.node_id { if node_id != config_req.node_id {
@@ -335,6 +365,8 @@ async fn handle_tenant_shard_split(
service: Arc<Service>, service: Arc<Service>,
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let split_req = json_request::<TenantShardSplitRequest>(&mut req).await?; let split_req = json_request::<TenantShardSplitRequest>(&mut req).await?;
@@ -348,6 +380,8 @@ async fn handle_tenant_shard_migrate(
service: Arc<Service>, service: Arc<Service>,
mut req: Request<Body>, mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> { ) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let tenant_shard_id: TenantShardId = parse_request_param(&req, "tenant_shard_id")?; let tenant_shard_id: TenantShardId = parse_request_param(&req, "tenant_shard_id")?;
let migrate_req = json_request::<TenantShardMigrateRequest>(&mut req).await?; let migrate_req = json_request::<TenantShardMigrateRequest>(&mut req).await?;
json_response( json_response(
@@ -360,22 +394,30 @@ async fn handle_tenant_shard_migrate(
async fn handle_tenant_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_tenant_drop(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
let state = get_state(&req); let state = get_state(&req);
json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?) json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?)
} }
async fn handle_tenants_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_tenants_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let state = get_state(&req); let state = get_state(&req);
state.service.tenants_dump() state.service.tenants_dump()
} }
async fn handle_scheduler_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_scheduler_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let state = get_state(&req); let state = get_state(&req);
state.service.scheduler_dump() state.service.scheduler_dump()
} }
async fn handle_consistency_check(req: Request<Body>) -> Result<Response<Body>, ApiError> { async fn handle_consistency_check(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
let state = get_state(&req); let state = get_state(&req);
json_response(StatusCode::OK, state.service.consistency_check().await?) json_response(StatusCode::OK, state.service.consistency_check().await?)
@@ -432,6 +474,12 @@ where
.await .await
} }
/// Checks that `request` may use an endpoint that requires `required_scope`.
///
/// Delegates to `check_permission_with`, handing it a closure that runs
/// `crate::auth::check_permission` on the claims with `required_scope`.
// NOTE(review): how the claims are obtained from the request, and what happens
// when auth is disabled, is decided inside `check_permission_with` — confirm there.
fn check_permissions(request: &Request<Body>, required_scope: Scope) -> Result<(), ApiError> {
check_permission_with(request, |claims| {
crate::auth::check_permission(claims, required_scope)
})
}
pub fn make_router( pub fn make_router(
service: Arc<Service>, service: Arc<Service>,
auth: Option<Arc<SwappableJwtAuth>>, auth: Option<Arc<SwappableJwtAuth>>,

View File

@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use utils::seqwait::MonotonicCounter; use utils::seqwait::MonotonicCounter;
mod auth;
mod compute_hook; mod compute_hook;
pub mod http; pub mod http;
pub mod metrics; pub mod metrics;

View File

@@ -11,12 +11,12 @@ use pageserver_api::{
use pageserver_client::mgmt_api::ResponseErrorMessageExt; use pageserver_client::mgmt_api::ResponseErrorMessageExt;
use postgres_backend::AuthType; use postgres_backend::AuthType;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::str::FromStr; use std::{fs, str::FromStr};
use tokio::process::Command; use tokio::process::Command;
use tracing::instrument; use tracing::instrument;
use url::Url; use url::Url;
use utils::{ use utils::{
auth::{Claims, Scope}, auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId}, id::{NodeId, TenantId},
}; };
@@ -24,7 +24,7 @@ pub struct AttachmentService {
env: LocalEnv, env: LocalEnv,
listen: String, listen: String,
path: Utf8PathBuf, path: Utf8PathBuf,
jwt_token: Option<String>, private_key: Option<Vec<u8>>,
public_key: Option<String>, public_key: Option<String>,
postgres_port: u16, postgres_port: u16,
client: reqwest::Client, client: reqwest::Client,
@@ -204,12 +204,11 @@ impl AttachmentService {
.pageservers .pageservers
.first() .first()
.expect("Config is validated to contain at least one pageserver"); .expect("Config is validated to contain at least one pageserver");
let (jwt_token, public_key) = match ps_conf.http_auth_type { let (private_key, public_key) = match ps_conf.http_auth_type {
AuthType::Trust => (None, None), AuthType::Trust => (None, None),
AuthType::NeonJWT => { AuthType::NeonJWT => {
let jwt_token = env let private_key_path = env.get_private_key_path();
.generate_auth_token(&Claims::new(None, Scope::PageServerApi)) let private_key = fs::read(private_key_path).expect("failed to read private key");
.unwrap();
// If pageserver auth is enabled, this implicitly enables auth for this service, // If pageserver auth is enabled, this implicitly enables auth for this service,
// using the same credentials. // using the same credentials.
@@ -235,7 +234,7 @@ impl AttachmentService {
} else { } else {
std::fs::read_to_string(&public_key_path).expect("Can't read public key") std::fs::read_to_string(&public_key_path).expect("Can't read public key")
}; };
(Some(jwt_token), Some(public_key)) (Some(private_key), Some(public_key))
} }
}; };
@@ -243,7 +242,7 @@ impl AttachmentService {
env: env.clone(), env: env.clone(),
path, path,
listen, listen,
jwt_token, private_key,
public_key, public_key,
postgres_port, postgres_port,
client: reqwest::ClientBuilder::new() client: reqwest::ClientBuilder::new()
@@ -397,7 +396,10 @@ impl AttachmentService {
.into_iter() .into_iter()
.map(|s| s.to_string()) .map(|s| s.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if let Some(jwt_token) = &self.jwt_token { if let Some(private_key) = &self.private_key {
let claims = Claims::new(None, Scope::PageServerApi);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
args.push(format!("--jwt-token={jwt_token}")); args.push(format!("--jwt-token={jwt_token}"));
} }
@@ -468,6 +470,20 @@ impl AttachmentService {
Ok(()) Ok(())
} }
/// Picks the JWT claims to send for a request to the given `path`.
///
/// Only the text before the first `/` is inspected:
/// * `status`, `ready` -> `Ok(None)`, no claims are produced
/// * `control`, `debug` -> claims with `Scope::Admin`
/// * `v1` -> claims with `Scope::PageServerApi`
///
/// Any other first segment is an error. That includes a path starting with
/// `/`, whose first segment is the empty string.
fn get_claims_for_path(path: &str) -> anyhow::Result<Option<Claims>> {
    // `split` always yields at least one item, so the fallback is never taken.
    let first_segment = path.split('/').next().unwrap_or(path);

    let scope = match first_segment {
        "status" | "ready" => return Ok(None),
        "control" | "debug" => Scope::Admin,
        "v1" => Scope::PageServerApi,
        _ => return Err(anyhow::anyhow!("Failed to determine claims for {}", path)),
    };

    Ok(Some(Claims::new(None, scope)))
}
/// Simple HTTP request wrapper for calling into attachment service /// Simple HTTP request wrapper for calling into attachment service
async fn dispatch<RQ, RS>( async fn dispatch<RQ, RS>(
&self, &self,
@@ -493,11 +509,16 @@ impl AttachmentService {
if let Some(body) = body { if let Some(body) = body {
builder = builder.json(&body) builder = builder.json(&body)
} }
if let Some(jwt_token) = &self.jwt_token { if let Some(private_key) = &self.private_key {
builder = builder.header( println!("Getting claims for path {}", path);
reqwest::header::AUTHORIZATION, if let Some(required_claims) = Self::get_claims_for_path(&path)? {
format!("Bearer {jwt_token}"), println!("Got claims {:?} for path {}", required_claims, path);
); let jwt_token = encode_from_key_file(&required_claims, private_key)?;
builder = builder.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {jwt_token}"),
);
}
} }
let response = builder.send().await?; let response = builder.send().await?;

View File

@@ -412,14 +412,17 @@ impl LocalEnv {
// this function is used only for testing purposes in CLI e g generate tokens during init // this function is used only for testing purposes in CLI e g generate tokens during init
pub fn generate_auth_token(&self, claims: &Claims) -> anyhow::Result<String> { pub fn generate_auth_token(&self, claims: &Claims) -> anyhow::Result<String> {
let private_key_path = if self.private_key_path.is_absolute() { let private_key_path = self.get_private_key_path();
let key_data = fs::read(private_key_path)?;
encode_from_key_file(claims, &key_data)
}
pub fn get_private_key_path(&self) -> PathBuf {
if self.private_key_path.is_absolute() {
self.private_key_path.to_path_buf() self.private_key_path.to_path_buf()
} else { } else {
self.base_data_dir.join(&self.private_key_path) self.base_data_dir.join(&self.private_key_path)
}; }
let key_data = fs::read(private_key_path)?;
encode_from_key_file(claims, &key_data)
} }
// //

View File

@@ -115,7 +115,7 @@ impl PageServerNode {
if matches!(self.conf.http_auth_type, AuthType::NeonJWT) { if matches!(self.conf.http_auth_type, AuthType::NeonJWT) {
let jwt_token = self let jwt_token = self
.env .env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi)) .generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
.unwrap(); .unwrap();
overrides.push(format!("control_plane_api_token='{}'", jwt_token)); overrides.push(format!("control_plane_api_token='{}'", jwt_token));
} }

View File

@@ -70,6 +70,9 @@ Should only be used e.g. for status check/tenant creation/list.
Should only be used e.g. for status check. Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper. Currently also used for connection from any pageserver to any safekeeper.
"generations_api": Provides access to the upcall APIs served by the attachment service or the control plane.
"admin": Provides access to the control plane and admin APIs of the attachment service.
### CLI ### CLI
CLI generates a key pair during call to `neon_local init` with the following commands: CLI generates a key pair during call to `neon_local init` with the following commands:

View File

@@ -32,6 +32,8 @@ pub enum Scope {
// The scope used by pageservers in upcalls to storage controller and cloud control plane // The scope used by pageservers in upcalls to storage controller and cloud control plane
#[serde(rename = "generations_api")] #[serde(rename = "generations_api")]
GenerationsApi, GenerationsApi,
// Allows access to control plane management API and some storage controller endpoints.
Admin,
} }
/// JWT payload. See docs/authentication.md for the format /// JWT payload. See docs/authentication.md for the format

View File

@@ -14,7 +14,7 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
} }
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope (Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope (Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
(Scope::SafekeeperData | Scope::GenerationsApi, _) => Err(AuthError( (Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi, _) => Err(AuthError(
format!( format!(
"JWT scope '{:?}' is ineligible for Pageserver auth", "JWT scope '{:?}' is ineligible for Pageserver auth",
claims.scope claims.scope

View File

@@ -12,7 +12,7 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
} }
Ok(()) Ok(())
} }
(Scope::PageServerApi | Scope::GenerationsApi, _) => Err(AuthError( (Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi, _) => Err(AuthError(
format!( format!(
"JWT scope '{:?}' is ineligible for Safekeeper auth", "JWT scope '{:?}' is ineligible for Safekeeper auth",
claims.scope claims.scope

View File

@@ -17,6 +17,7 @@ import uuid
from contextlib import closing, contextmanager from contextlib import closing, contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum
from fcntl import LOCK_EX, LOCK_UN, flock from fcntl import LOCK_EX, LOCK_UN, flock
from functools import cached_property from functools import cached_property
from itertools import chain, product from itertools import chain, product
@@ -388,7 +389,8 @@ class PgProtocol:
class AuthKeys: class AuthKeys:
priv: str priv: str
def generate_token(self, *, scope: str, **token_data: str) -> str: def generate_token(self, *, scope: TokenScope, **token_data: Any) -> str:
token_data = {key: str(val) for key, val in token_data.items()}
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm="EdDSA") token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm="EdDSA")
# cast(Any, self.priv) # cast(Any, self.priv)
@@ -401,14 +403,23 @@ class AuthKeys:
return token return token
def generate_pageserver_token(self) -> str: def generate_pageserver_token(self) -> str:
return self.generate_token(scope="pageserverapi") return self.generate_token(scope=TokenScope.PAGE_SERVER_API)
def generate_safekeeper_token(self) -> str: def generate_safekeeper_token(self) -> str:
return self.generate_token(scope="safekeeperdata") return self.generate_token(scope=TokenScope.SAFEKEEPER_DATA)
# generate token giving access to only one tenant # generate token giving access to only one tenant
def generate_tenant_token(self, tenant_id: TenantId) -> str: def generate_tenant_token(self, tenant_id: TenantId) -> str:
return self.generate_token(scope="tenant", tenant_id=str(tenant_id)) return self.generate_token(scope=TokenScope.TENANT, tenant_id=str(tenant_id))
# TODO: Replace with `StrEnum` when we upgrade to python 3.11
class TokenScope(str, Enum):
    """JWT scope names accepted by `AuthKeys.generate_token`.

    The `str` mixin makes every member a string equal to its value.
    """

    # Required by the attachment service's `/control` and `/debug` endpoints.
    ADMIN = "admin"
    # Required by the attachment service's `/v1/...` endpoints.
    PAGE_SERVER_API = "pageserverapi"
    # Required by the attachment service's `/upcall` endpoints.
    GENERATIONS_API = "generations_api"
    SAFEKEEPER_DATA = "safekeeperdata"
    TENANT = "tenant"
class NeonEnvBuilder: class NeonEnvBuilder:
@@ -1922,6 +1933,13 @@ class Pagectl(AbstractNeonCli):
return IndexPartDump.from_json(parsed) return IndexPartDump.from_json(parsed)
class AttachmentServiceApiException(Exception):
    """Error carrying a message together with a status code.

    `message` is also passed to `Exception.__init__`, so `str(exc)` yields it.
    """

    def __init__(self, message, status_code: int):
        # Store the extra fields first, then initialise the base exception.
        self.status_code = status_code
        self.message = message
        super().__init__(message)
class NeonAttachmentService(MetricsGetter): class NeonAttachmentService(MetricsGetter):
def __init__(self, env: NeonEnv, auth_enabled: bool): def __init__(self, env: NeonEnv, auth_enabled: bool):
self.env = env self.env = env
@@ -1940,39 +1958,60 @@ class NeonAttachmentService(MetricsGetter):
self.running = False self.running = False
return self return self
@staticmethod
def raise_api_exception(res: requests.Response):
    """Raise `AttachmentServiceApiException` if `res` is an error response.

    Does nothing when `res.raise_for_status()` succeeds. Otherwise the
    exception message is taken from the `"msg"` field of the JSON body,
    falling back to an empty string when that cannot be read, and the
    original `requests` error is chained as the cause.

    :param res: response to inspect
    :raises AttachmentServiceApiException: with `res.status_code` attached
    """
    try:
        res.raise_for_status()
    except requests.RequestException as e:
        try:
            msg = res.json()["msg"]
        except Exception:
            # Best effort: the body may not be JSON or may lack "msg".
            # Catch `Exception` rather than everything, so that
            # KeyboardInterrupt/SystemExit are not swallowed here.
            msg = ""
        raise AttachmentServiceApiException(msg, res.status_code) from e
def pageserver_api(self) -> PageserverHttpClient: def pageserver_api(self) -> PageserverHttpClient:
""" """
The attachment service implements a subset of the pageserver REST API, for mapping The attachment service implements a subset of the pageserver REST API, for mapping
per-tenant actions into per-shard actions (e.g. timeline creation). Tests should invoke those per-tenant actions into per-shard actions (e.g. timeline creation). Tests should invoke those
functions via the HttpClient, as an implicit check that these APIs remain compatible. functions via the HttpClient, as an implicit check that these APIs remain compatible.
""" """
return PageserverHttpClient(self.env.attachment_service_port, lambda: True) auth_token = None
if self.auth_enabled:
auth_token = self.env.auth_keys.generate_token(scope=TokenScope.PAGE_SERVER_API)
return PageserverHttpClient(self.env.attachment_service_port, lambda: True, auth_token)
def request(self, method, *args, **kwargs) -> requests.Response: def request(self, method, *args, **kwargs) -> requests.Response:
kwargs["headers"] = self.headers() resp = requests.request(method, *args, **kwargs)
return requests.request(method, *args, **kwargs) NeonAttachmentService.raise_api_exception(resp)
def headers(self) -> Dict[str, str]: return resp
def headers(self, scope: Optional[TokenScope]) -> Dict[str, str]:
headers = {} headers = {}
if self.auth_enabled: if self.auth_enabled and scope is not None:
jwt_token = self.env.auth_keys.generate_pageserver_token() jwt_token = self.env.auth_keys.generate_token(scope=scope)
headers["Authorization"] = f"Bearer {jwt_token}" headers["Authorization"] = f"Bearer {jwt_token}"
return headers return headers
def get_metrics(self) -> Metrics: def get_metrics(self) -> Metrics:
res = self.request("GET", f"{self.env.attachment_service_api}/metrics") res = self.request("GET", f"{self.env.attachment_service_api}/metrics")
res.raise_for_status()
return parse_metrics(res.text) return parse_metrics(res.text)
def ready(self) -> bool: def ready(self) -> bool:
resp = self.request("GET", f"{self.env.attachment_service_api}/ready") status = None
if resp.status_code == 503: try:
resp = self.request("GET", f"{self.env.attachment_service_api}/ready")
status = resp.status_code
except AttachmentServiceApiException as e:
status = e.status_code
if status == 503:
return False return False
elif resp.status_code == 200: elif status == 200:
return True return True
else: else:
raise RuntimeError(f"Unexpected status {resp.status_code} from readiness endpoint") raise RuntimeError(f"Unexpected status {status} from readiness endpoint")
def attach_hook_issue( def attach_hook_issue(
self, tenant_shard_id: Union[TenantId, TenantShardId], pageserver_id: int self, tenant_shard_id: Union[TenantId, TenantShardId], pageserver_id: int
@@ -1981,21 +2020,19 @@ class NeonAttachmentService(MetricsGetter):
"POST", "POST",
f"{self.env.attachment_service_api}/debug/v1/attach-hook", f"{self.env.attachment_service_api}/debug/v1/attach-hook",
json={"tenant_shard_id": str(tenant_shard_id), "node_id": pageserver_id}, json={"tenant_shard_id": str(tenant_shard_id), "node_id": pageserver_id},
headers=self.headers(), headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
gen = response.json()["gen"] gen = response.json()["gen"]
assert isinstance(gen, int) assert isinstance(gen, int)
return gen return gen
def attach_hook_drop(self, tenant_shard_id: Union[TenantId, TenantShardId]): def attach_hook_drop(self, tenant_shard_id: Union[TenantId, TenantShardId]):
response = self.request( self.request(
"POST", "POST",
f"{self.env.attachment_service_api}/debug/v1/attach-hook", f"{self.env.attachment_service_api}/debug/v1/attach-hook",
json={"tenant_shard_id": str(tenant_shard_id), "node_id": None}, json={"tenant_shard_id": str(tenant_shard_id), "node_id": None},
headers=self.headers(), headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
def inspect(self, tenant_shard_id: Union[TenantId, TenantShardId]) -> Optional[tuple[int, int]]: def inspect(self, tenant_shard_id: Union[TenantId, TenantShardId]) -> Optional[tuple[int, int]]:
""" """
@@ -2005,9 +2042,8 @@ class NeonAttachmentService(MetricsGetter):
"POST", "POST",
f"{self.env.attachment_service_api}/debug/v1/inspect", f"{self.env.attachment_service_api}/debug/v1/inspect",
json={"tenant_shard_id": str(tenant_shard_id)}, json={"tenant_shard_id": str(tenant_shard_id)},
headers=self.headers(), headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
json = response.json() json = response.json()
log.info(f"Response: {json}") log.info(f"Response: {json}")
if json["attachment"]: if json["attachment"]:
@@ -2027,14 +2063,15 @@ class NeonAttachmentService(MetricsGetter):
"POST", "POST",
f"{self.env.attachment_service_api}/control/v1/node", f"{self.env.attachment_service_api}/control/v1/node",
json=body, json=body,
headers=self.headers(), headers=self.headers(TokenScope.ADMIN),
).raise_for_status() )
def node_list(self): def node_list(self):
response = self.request( response = self.request(
"GET", f"{self.env.attachment_service_api}/control/v1/node", headers=self.headers() "GET",
f"{self.env.attachment_service_api}/control/v1/node",
headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
return response.json() return response.json()
def node_configure(self, node_id, body: dict[str, Any]): def node_configure(self, node_id, body: dict[str, Any]):
@@ -2044,8 +2081,8 @@ class NeonAttachmentService(MetricsGetter):
"PUT", "PUT",
f"{self.env.attachment_service_api}/control/v1/node/{node_id}/config", f"{self.env.attachment_service_api}/control/v1/node/{node_id}/config",
json=body, json=body,
headers=self.headers(), headers=self.headers(TokenScope.ADMIN),
).raise_for_status() )
def tenant_create( def tenant_create(
self, self,
@@ -2070,8 +2107,12 @@ class NeonAttachmentService(MetricsGetter):
for k, v in tenant_config.items(): for k, v in tenant_config.items():
body[k] = v body[k] = v
response = self.request("POST", f"{self.env.attachment_service_api}/v1/tenant", json=body) response = self.request(
response.raise_for_status() "POST",
f"{self.env.attachment_service_api}/v1/tenant",
json=body,
headers=self.headers(TokenScope.PAGE_SERVER_API),
)
log.info(f"tenant_create success: {response.json()}") log.info(f"tenant_create success: {response.json()}")
def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]: def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]:
@@ -2079,9 +2120,10 @@ class NeonAttachmentService(MetricsGetter):
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int} :return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int}
""" """
response = self.request( response = self.request(
"GET", f"{self.env.attachment_service_api}/control/v1/tenant/{tenant_id}/locate" "GET",
f"{self.env.attachment_service_api}/control/v1/tenant/{tenant_id}/locate",
headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
body = response.json() body = response.json()
shards: list[dict[str, Any]] = body["shards"] shards: list[dict[str, Any]] = body["shards"]
return shards return shards
@@ -2091,20 +2133,20 @@ class NeonAttachmentService(MetricsGetter):
"PUT", "PUT",
f"{self.env.attachment_service_api}/control/v1/tenant/{tenant_id}/shard_split", f"{self.env.attachment_service_api}/control/v1/tenant/{tenant_id}/shard_split",
json={"new_shard_count": shard_count}, json={"new_shard_count": shard_count},
headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
body = response.json() body = response.json()
log.info(f"tenant_shard_split success: {body}") log.info(f"tenant_shard_split success: {body}")
shards: list[TenantShardId] = body["new_shards"] shards: list[TenantShardId] = body["new_shards"]
return shards return shards
def tenant_shard_migrate(self, tenant_shard_id: TenantShardId, dest_ps_id: int):
    """
    Ask the attachment service to migrate a tenant shard to another pageserver.

    After the request returns, asserts that the environment reports the shard
    as living on ``dest_ps_id``.

    :param tenant_shard_id: shard to migrate
    :param dest_ps_id: id of the destination pageserver
    """
    # /control endpoints require an admin-scoped token.
    self.request(
        "PUT",
        f"{self.env.attachment_service_api}/control/v1/tenant/{tenant_shard_id}/migrate",
        json={"tenant_shard_id": str(tenant_shard_id), "node_id": dest_ps_id},
        headers=self.headers(TokenScope.ADMIN),
    )
    log.info(f"Migrated tenant {tenant_shard_id} to pageserver {dest_ps_id}")

    assert self.env.get_tenant_pageserver(tenant_shard_id).id == dest_ps_id
@@ -2112,11 +2154,11 @@ class NeonAttachmentService(MetricsGetter):
""" """
Throw an exception if the service finds any inconsistencies in its state Throw an exception if the service finds any inconsistencies in its state
""" """
response = self.request( self.request(
"POST", "POST",
f"{self.env.attachment_service_api}/debug/v1/consistency_check", f"{self.env.attachment_service_api}/debug/v1/consistency_check",
headers=self.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
log.info("Attachment service passed consistency check") log.info("Attachment service passed consistency check")
def __enter__(self) -> "NeonAttachmentService": def __enter__(self) -> "NeonAttachmentService":
@@ -2894,7 +2936,6 @@ class NeonProxy(PgProtocol):
def get_metrics(self) -> str:
    """
    Fetch the raw text served at this proxy's ``/metrics`` HTTP endpoint.

    :return: the response body as text
    :raises requests.HTTPError: if the endpoint answers with a 4xx/5xx status
    """
    request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
    # This goes through plain requests.get, which does not raise on error statuses;
    # without this check an error page would be returned as if it were metrics text.
    request_result.raise_for_status()
    return request_result.text
@staticmethod @staticmethod

View File

@@ -1,13 +1,16 @@
import time import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List from typing import Any, Dict, List
import pytest
from fixtures.log_helper import log from fixtures.log_helper import log
from fixtures.neon_fixtures import ( from fixtures.neon_fixtures import (
AttachmentServiceApiException,
NeonEnv, NeonEnv,
NeonEnvBuilder, NeonEnvBuilder,
PgBin, PgBin,
TokenScope,
) )
from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.http import PageserverHttpClient
from fixtures.pageserver.utils import ( from fixtures.pageserver.utils import (
@@ -457,37 +460,40 @@ def test_sharding_service_debug_apis(neon_env_builder: NeonEnvBuilder):
# Initial tenant (1 shard) and the one we just created (2 shards) should be visible # Initial tenant (1 shard) and the one we just created (2 shards) should be visible
response = env.attachment_service.request( response = env.attachment_service.request(
"GET", f"{env.attachment_service_api}/debug/v1/tenant" "GET",
f"{env.attachment_service_api}/debug/v1/tenant",
headers=env.attachment_service.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
assert len(response.json()) == 3 assert len(response.json()) == 3
# Scheduler should report the expected nodes and shard counts # Scheduler should report the expected nodes and shard counts
# /debug endpoints require an admin-scoped token, same as the other debug calls in this test
response = env.attachment_service.request(
    "GET",
    f"{env.attachment_service_api}/debug/v1/scheduler",
    headers=env.attachment_service.headers(TokenScope.ADMIN),
)
# Two nodes, in a dict of node_id->node # Two nodes, in a dict of node_id->node
assert len(response.json()["nodes"]) == 2 assert len(response.json()["nodes"]) == 2
assert sum(v["shard_count"] for v in response.json()["nodes"].values()) == 3 assert sum(v["shard_count"] for v in response.json()["nodes"].values()) == 3
assert all(v["may_schedule"] for v in response.json()["nodes"].values()) assert all(v["may_schedule"] for v in response.json()["nodes"].values())
response = env.attachment_service.request( response = env.attachment_service.request(
"POST", f"{env.attachment_service_api}/debug/v1/node/{env.pageservers[1].id}/drop" "POST",
f"{env.attachment_service_api}/debug/v1/node/{env.pageservers[1].id}/drop",
headers=env.attachment_service.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
assert len(env.attachment_service.node_list()) == 1 assert len(env.attachment_service.node_list()) == 1
response = env.attachment_service.request( response = env.attachment_service.request(
"POST", f"{env.attachment_service_api}/debug/v1/tenant/{tenant_id}/drop" "POST",
f"{env.attachment_service_api}/debug/v1/tenant/{tenant_id}/drop",
headers=env.attachment_service.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
# Tenant drop should be reflected in dump output # Tenant drop should be reflected in dump output
response = env.attachment_service.request( response = env.attachment_service.request(
"GET", f"{env.attachment_service_api}/debug/v1/tenant" "GET",
f"{env.attachment_service_api}/debug/v1/tenant",
headers=env.attachment_service.headers(TokenScope.ADMIN),
) )
response.raise_for_status()
assert len(response.json()) == 1 assert len(response.json()) == 1
# Check that the 'drop' APIs didn't leave things in a state that would fail a consistency check: they're # Check that the 'drop' APIs didn't leave things in a state that would fail a consistency check: they're
@@ -603,3 +609,64 @@ def test_sharding_service_s3_time_travel_recovery(
endpoint.safe_psql("SELECT * FROM created_foo;") endpoint.safe_psql("SELECT * FROM created_foo;")
env.attachment_service.consistency_check() env.attachment_service.consistency_check()
def test_sharding_service_auth(neon_env_builder: NeonEnvBuilder):
    """
    Check that the attachment service enforces JWT scopes per endpoint prefix.

    With auth enabled, each of `/v1`, `/debug` and `/upcall` is probed without a
    token (expecting an "Unauthorized" error) and with a token of the wrong scope
    (expecting a "Forbidden" error). `/v1/tenant` is additionally exercised with
    the correct `PAGE_SERVER_API` scope, which must not raise.
    """
    neon_env_builder.auth_enabled = True
    env = neon_env_builder.init_start()

    svc = env.attachment_service
    api = env.attachment_service_api

    tenant_id = TenantId.generate()
    body: Dict[str, Any] = {"new_tenant_id": str(tenant_id)}

    # No token
    with pytest.raises(
        AttachmentServiceApiException,
        match="Unauthorized: missing authorization header",
    ):
        svc.request("POST", f"{api}/v1/tenant", json=body)

    # Token with incorrect scope
    with pytest.raises(
        AttachmentServiceApiException,
        match="Forbidden: JWT authentication error",
    ):
        svc.request("POST", f"{api}/v1/tenant", json=body, headers=svc.headers(TokenScope.ADMIN))

    # Token with correct scope
    svc.request(
        "POST", f"{api}/v1/tenant", json=body, headers=svc.headers(TokenScope.PAGE_SERVER_API)
    )

    # No token
    with pytest.raises(
        AttachmentServiceApiException,
        match="Unauthorized: missing authorization header",
    ):
        svc.request("GET", f"{api}/debug/v1/tenant")

    # Token with incorrect scope
    with pytest.raises(
        AttachmentServiceApiException,
        match="Forbidden: JWT authentication error",
    ):
        svc.request(
            "GET", f"{api}/debug/v1/tenant", headers=svc.headers(TokenScope.GENERATIONS_API)
        )

    # No token
    with pytest.raises(
        AttachmentServiceApiException,
        match="Unauthorized: missing authorization header",
    ):
        svc.request("POST", f"{api}/upcall/v1/re-attach")

    # Token with incorrect scope
    with pytest.raises(
        AttachmentServiceApiException,
        match="Forbidden: JWT authentication error",
    ):
        svc.request(
            "POST", f"{api}/upcall/v1/re-attach", headers=svc.headers(TokenScope.PAGE_SERVER_API)
        )