review adjustments

This commit is contained in:
Dmitry Rodionov
2021-08-18 14:28:30 +03:00
committed by Dmitry
parent 23b5249512
commit b135723994
4 changed files with 58 additions and 66 deletions

View File

@@ -341,7 +341,7 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
thread::Builder::new()
.name("http_endpoint_thread".into())
.spawn(move || {
let router = http::get_router(conf, cloned);
let router = http::make_router(conf, cloned);
endpoint::serve_thread_main(router, conf.http_endpoint_addr.clone())
})?;

View File

@@ -1,3 +1,3 @@
pub mod models;
pub mod routes;
pub use routes::get_router;
pub use routes::make_router;

View File

@@ -4,13 +4,11 @@ use anyhow::Result;
use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use routerify::Middleware;
use routerify::{ext::RequestExt, RouterBuilder};
use zenith_utils::auth::JwtAuth;
use zenith_utils::http::endpoint::attach_openapi_ui;
use zenith_utils::http::endpoint::auth_middleware;
use zenith_utils::http::endpoint::check_permission;
use zenith_utils::http::endpoint::AuthProvider;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::{
endpoint,
@@ -30,31 +28,34 @@ use crate::{
struct State {
conf: &'static PageServerConf,
auth: Arc<Option<JwtAuth>>,
whitelist_routes: Vec<Uri>,
allowlist_routes: Vec<Uri>,
}
impl State {
fn new(conf: &'static PageServerConf, auth: Arc<Option<JwtAuth>>) -> Self {
let whitelist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
Self {
conf,
auth,
whitelist_routes,
allowlist_routes,
}
}
}
impl AuthProvider for State {
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>> {
if self.whitelist_routes.contains(req.uri()) {
Arc::new(None)
} else {
self.auth.clone()
}
}
#[inline(always)]
fn get_state(request: &Request<Body>) -> &State {
request
.data::<Arc<State>>()
.expect("unknown state type")
.as_ref()
}
#[inline(always)]
fn get_config(request: &Request<Body>) -> &'static PageServerConf {
get_state(request).conf
}
// healthcheck handler
@@ -67,14 +68,13 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
}
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = request.data::<Arc<State>>().unwrap().clone();
let request_data: BranchCreateRequest = json_request(&mut request).await?;
check_permission(&request, Some(request_data.tenant_id))?;
let response_data = tokio::task::spawn_blocking(move || {
branches::create_branch(
state.conf,
get_config(&request),
&request_data.name,
&request_data.start_point,
&request_data.tenant_id,
@@ -99,11 +99,11 @@ async fn branch_list_handler(request: Request<Body>) -> Result<Response<Body>, A
check_permission(&request, Some(tenantid))?;
let state = request.data::<Arc<State>>().unwrap().clone();
let response_data =
tokio::task::spawn_blocking(move || crate::branches::get_branches(state.conf, &tenantid))
.await
.map_err(ApiError::from_err)??;
let response_data = tokio::task::spawn_blocking(move || {
crate::branches::get_branches(get_config(&request), &tenantid)
})
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::OK, response_data)?)
}
@@ -111,9 +111,8 @@ async fn tenant_list_handler(request: Request<Body>) -> Result<Response<Body>, A
// check for management permission
check_permission(&request, None)?;
let state = request.data::<Arc<State>>().unwrap().clone();
let response_data =
tokio::task::spawn_blocking(move || crate::branches::get_tenants(state.conf))
tokio::task::spawn_blocking(move || crate::branches::get_tenants(get_config(&request)))
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::OK, response_data)?)
@@ -123,11 +122,10 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
// check for management permission
check_permission(&request, None)?;
let state = request.data::<Arc<State>>().unwrap().clone();
let request_data: TenantCreateRequest = json_request(&mut request).await?;
let response_data = tokio::task::spawn_blocking(move || {
page_cache::create_repository_for_tenant(state.conf, request_data.tenant_id)
page_cache::create_repository_for_tenant(get_config(&request), request_data.tenant_id)
})
.await
.map_err(ApiError::from_err)??;
@@ -141,18 +139,23 @@ async fn handler_404(_: Request<Body>) -> Result<Response<Body>, ApiError> {
)
}
pub fn get_router(
pub fn make_router(
conf: &'static PageServerConf,
auth: Arc<Option<JwtAuth>>,
) -> RouterBuilder<hyper::Body, ApiError> {
let spec = include_bytes!("openapi_spec.yml");
let mut router = attach_openapi_ui(endpoint::get_router(), spec, "/swagger.yml", "/v1/doc");
if let Some(_) = &auth.as_ref() {
// note that State is used as a type parameteer without an Arc
// this is a simple solution because it is not possible to implement
// AuthProvider for Arc<State> so middleware assumes that state is wrapped in Arc
router = router.middleware(Middleware::pre(auth_middleware::<State>))
let mut router = attach_openapi_ui(endpoint::make_router(), spec, "/swagger.yml", "/v1/doc");
if auth.is_some() {
router = router.middleware(auth_middleware(|request| {
let state = get_state(request);
if state.allowlist_routes.contains(request.uri()) {
None
} else {
Option::as_ref(&state.auth)
}
}))
}
router
.data(Arc::new(State::new(conf, auth)))
.get("/v1/status", status_handler)

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::auth::{self, Claims, JwtAuth};
use crate::http::error;
use crate::zid::ZTenantId;
@@ -45,7 +43,7 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
Ok(response)
}
pub fn get_router() -> RouterBuilder<hyper::Body, ApiError> {
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
@@ -94,10 +92,6 @@ pub fn attach_openapi_ui(
})
}
pub trait AuthProvider {
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>>;
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
// header must be in form Bearer <token>
let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized(
@@ -111,35 +105,30 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
Ok(token)
}
pub async fn auth_middleware<S: AuthProvider + Send + Sync + 'static>(
req: Request<Body>,
) -> Result<Request<Body>, ApiError> {
// unwrap is ok because this is called in auth middleware
// which should be enabled only when auth is some
let state_auth = req
.data::<Arc<S>>()
.expect("state is always in request data")
.provide_auth(&req);
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
if let Some(auth) = provide_auth(&req) {
match req.headers().get(AUTHORIZATION) {
Some(value) => {
let header_value = value.to_str().map_err(|_| {
ApiError::Unauthorized("malformed authorization header".to_string())
})?;
let token = parse_token(header_value)?;
if let Some(auth) = state_auth.as_ref().as_ref() {
match req.headers().get(AUTHORIZATION) {
Some(value) => {
let header_value = value.to_str().map_err(|_| {
ApiError::Unauthorized("malformed authorization header".to_string())
})?;
let token = parse_token(header_value)?;
let data = auth
.decode(token)
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
req.set_context(data.claims);
let data = auth
.decode(token)
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
req.set_context(data.claims);
}
None => Err(ApiError::Unauthorized(
"missing authorization header".to_string(),
))?,
}
None => Err(ApiError::Unauthorized(
"missing authorization header".to_string(),
))?,
}
}
Ok(req)
Ok(req)
})
}
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {