From b13572399402146b88fab862f4a046cc20d36217 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Wed, 18 Aug 2021 14:28:30 +0300 Subject: [PATCH] review adjustments --- pageserver/src/bin/pageserver.rs | 2 +- pageserver/src/http/mod.rs | 2 +- pageserver/src/http/routes.rs | 65 ++++++++++++++++--------------- zenith_utils/src/http/endpoint.rs | 55 +++++++++++--------------- 4 files changed, 58 insertions(+), 66 deletions(-) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 975551374c..2c2955ab7d 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -341,7 +341,7 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> { thread::Builder::new() .name("http_endpoint_thread".into()) .spawn(move || { - let router = http::get_router(conf, cloned); + let router = http::make_router(conf, cloned); endpoint::serve_thread_main(router, conf.http_endpoint_addr.clone()) })?; diff --git a/pageserver/src/http/mod.rs b/pageserver/src/http/mod.rs index ef282c43ac..4c0be17ecd 100644 --- a/pageserver/src/http/mod.rs +++ b/pageserver/src/http/mod.rs @@ -1,3 +1,3 @@ pub mod models; pub mod routes; -pub use routes::get_router; +pub use routes::make_router; diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 6204721f47..257d24a2e5 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -4,13 +4,11 @@ use anyhow::Result; use hyper::header; use hyper::StatusCode; use hyper::{Body, Request, Response, Uri}; -use routerify::Middleware; use routerify::{ext::RequestExt, RouterBuilder}; use zenith_utils::auth::JwtAuth; use zenith_utils::http::endpoint::attach_openapi_ui; use zenith_utils::http::endpoint::auth_middleware; use zenith_utils::http::endpoint::check_permission; -use zenith_utils::http::endpoint::AuthProvider; use zenith_utils::http::error::ApiError; use zenith_utils::http::{ endpoint, @@ -30,31 +28,34 @@ use crate::{ struct State { conf: &'static PageServerConf, auth: Arc>, - whitelist_routes: Vec, + allowlist_routes: Vec, } impl State { fn new(conf: &'static PageServerConf, auth: Arc>) -> Self { - let whitelist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"] + let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"] .iter() .map(|v| v.parse().unwrap()) .collect::>(); Self { conf, auth, - whitelist_routes, + allowlist_routes, } } } -impl AuthProvider for State { - fn provide_auth(&self, req: &Request) -> Arc> { - if self.whitelist_routes.contains(req.uri()) { - Arc::new(None) - } else { - self.auth.clone() - } - } +#[inline(always)] +fn get_state(request: &Request) -> &State { + request + .data::>() + .expect("unknown state type") + .as_ref() +} + +#[inline(always)] +fn get_config(request: &Request) -> &'static PageServerConf { + get_state(request).conf } // healthcheck handler @@ -67,14 +68,13 @@ async fn status_handler(_: Request) -> Result, ApiError> { } async fn branch_create_handler(mut request: Request) -> Result, ApiError> { - let state = request.data::>().unwrap().clone(); let request_data: BranchCreateRequest = json_request(&mut request).await?; check_permission(&request, Some(request_data.tenant_id))?; let response_data = tokio::task::spawn_blocking(move || { branches::create_branch( - state.conf, + get_config(&request), &request_data.name, &request_data.start_point, &request_data.tenant_id, @@ -99,11 +99,11 @@ async fn branch_list_handler(request: Request) -> Result, A check_permission(&request, Some(tenantid))?; - let state = request.data::>().unwrap().clone(); - let response_data = - tokio::task::spawn_blocking(move || crate::branches::get_branches(state.conf, &tenantid)) - .await - .map_err(ApiError::from_err)??; + let response_data = tokio::task::spawn_blocking(move || { + crate::branches::get_branches(get_config(&request), &tenantid) + }) + .await + .map_err(ApiError::from_err)??; Ok(json_response(StatusCode::OK, response_data)?) } @@ -111,9 +111,8 @@ async fn tenant_list_handler(request: Request) -> Result, A // check for management permission check_permission(&request, None)?; - let state = request.data::>().unwrap().clone(); let response_data = - tokio::task::spawn_blocking(move || crate::branches::get_tenants(state.conf)) + tokio::task::spawn_blocking(move || crate::branches::get_tenants(get_config(&request))) .await .map_err(ApiError::from_err)??; Ok(json_response(StatusCode::OK, response_data)?) @@ -123,11 +122,10 @@ async fn tenant_create_handler(mut request: Request) -> Result>().unwrap().clone(); let request_data: TenantCreateRequest = json_request(&mut request).await?; let response_data = tokio::task::spawn_blocking(move || { - page_cache::create_repository_for_tenant(state.conf, request_data.tenant_id) + page_cache::create_repository_for_tenant(get_config(&request), request_data.tenant_id) }) .await .map_err(ApiError::from_err)??; @@ -141,18 +139,23 @@ async fn handler_404(_: Request) -> Result, ApiError> { ) } -pub fn get_router( +pub fn make_router( conf: &'static PageServerConf, auth: Arc>, ) -> RouterBuilder { let spec = include_bytes!("openapi_spec.yml"); - let mut router = attach_openapi_ui(endpoint::get_router(), spec, "/swagger.yml", "/v1/doc"); - if let Some(_) = &auth.as_ref() { - // note that State is used as a type parameteer without an Arc - // this is a simple solution because it is not possible to implement - // AuthProvider for Arc so middleware assumes that state is wrapped in Arc - router = router.middleware(Middleware::pre(auth_middleware::)) + let mut router = attach_openapi_ui(endpoint::make_router(), spec, "/swagger.yml", "/v1/doc"); + if auth.is_some() { + router = router.middleware(auth_middleware(|request| { + let state = get_state(request); + if state.allowlist_routes.contains(request.uri()) { + None + } else { + Option::as_ref(&state.auth) + } + })) } + router .data(Arc::new(State::new(conf, auth))) .get("/v1/status", status_handler) diff --git a/zenith_utils/src/http/endpoint.rs b/zenith_utils/src/http/endpoint.rs index 46b1acc932..e6239aaa5d 100644 --- a/zenith_utils/src/http/endpoint.rs +++ b/zenith_utils/src/http/endpoint.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::auth::{self, Claims, JwtAuth}; use crate::http::error; use crate::zid::ZTenantId; @@ -45,7 +43,7 @@ async fn prometheus_metrics_handler(_req: Request) -> Result RouterBuilder { +pub fn make_router() -> RouterBuilder { Router::builder() .middleware(Middleware::post_with_info(logger)) .get("/metrics", prometheus_metrics_handler) @@ -94,10 +92,6 @@ pub fn attach_openapi_ui( }) } -pub trait AuthProvider { - fn provide_auth(&self, req: &Request) -> Arc>; -} - fn parse_token(header_value: &str) -> Result<&str, ApiError> { // header must be in form Bearer let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized( @@ -111,35 +105,30 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> { Ok(token) } -pub async fn auth_middleware( - req: Request, -) -> Result, ApiError> { - // unwrap is ok because this is called in auth middleware - // which should be enabled only when auth is some - let state_auth = req - .data::>() - .expect("state is always in request data") - .provide_auth(&req); +pub fn auth_middleware( + provide_auth: fn(&Request) -> Option<&JwtAuth>, +) -> Middleware { + Middleware::pre(move |req| async move { + if let Some(auth) = provide_auth(&req) { + match req.headers().get(AUTHORIZATION) { + Some(value) => { + let header_value = value.to_str().map_err(|_| { + ApiError::Unauthorized("malformed authorization header".to_string()) + })?; + let token = parse_token(header_value)?; - if let Some(auth) = state_auth.as_ref().as_ref() { - match req.headers().get(AUTHORIZATION) { - Some(value) => { - let header_value = value.to_str().map_err(|_| { - ApiError::Unauthorized("malformed authorization header".to_string()) - })?; - let token = parse_token(header_value)?; - - let data = auth - .decode(token) - .map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?; - req.set_context(data.claims); + let data = auth + .decode(token) + .map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?; + req.set_context(data.claims); + } + None => Err(ApiError::Unauthorized( + "missing authorization header".to_string(), + ))?, } - None => Err(ApiError::Unauthorized( - "missing authorization header".to_string(), - ))?, } - } - Ok(req) + Ok(req) + }) } pub fn check_permission(req: &Request, tenantid: Option) -> Result<(), ApiError> {