mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
review adjustments
This commit is contained in:
@@ -341,7 +341,7 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
|
||||
thread::Builder::new()
|
||||
.name("http_endpoint_thread".into())
|
||||
.spawn(move || {
|
||||
let router = http::get_router(conf, cloned);
|
||||
let router = http::make_router(conf, cloned);
|
||||
endpoint::serve_thread_main(router, conf.http_endpoint_addr.clone())
|
||||
})?;
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pub mod models;
|
||||
pub mod routes;
|
||||
pub use routes::get_router;
|
||||
pub use routes::make_router;
|
||||
|
||||
@@ -4,13 +4,11 @@ use anyhow::Result;
|
||||
use hyper::header;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, Request, Response, Uri};
|
||||
use routerify::Middleware;
|
||||
use routerify::{ext::RequestExt, RouterBuilder};
|
||||
use zenith_utils::auth::JwtAuth;
|
||||
use zenith_utils::http::endpoint::attach_openapi_ui;
|
||||
use zenith_utils::http::endpoint::auth_middleware;
|
||||
use zenith_utils::http::endpoint::check_permission;
|
||||
use zenith_utils::http::endpoint::AuthProvider;
|
||||
use zenith_utils::http::error::ApiError;
|
||||
use zenith_utils::http::{
|
||||
endpoint,
|
||||
@@ -30,31 +28,34 @@ use crate::{
|
||||
struct State {
|
||||
conf: &'static PageServerConf,
|
||||
auth: Arc<Option<JwtAuth>>,
|
||||
whitelist_routes: Vec<Uri>,
|
||||
allowlist_routes: Vec<Uri>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(conf: &'static PageServerConf, auth: Arc<Option<JwtAuth>>) -> Self {
|
||||
let whitelist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
|
||||
let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
|
||||
.iter()
|
||||
.map(|v| v.parse().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
Self {
|
||||
conf,
|
||||
auth,
|
||||
whitelist_routes,
|
||||
allowlist_routes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthProvider for State {
|
||||
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>> {
|
||||
if self.whitelist_routes.contains(req.uri()) {
|
||||
Arc::new(None)
|
||||
} else {
|
||||
self.auth.clone()
|
||||
}
|
||||
}
|
||||
#[inline(always)]
|
||||
fn get_state(request: &Request<Body>) -> &State {
|
||||
request
|
||||
.data::<Arc<State>>()
|
||||
.expect("unknown state type")
|
||||
.as_ref()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_config(request: &Request<Body>) -> &'static PageServerConf {
|
||||
get_state(request).conf
|
||||
}
|
||||
|
||||
// healthcheck handler
|
||||
@@ -67,14 +68,13 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
}
|
||||
|
||||
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let request_data: BranchCreateRequest = json_request(&mut request).await?;
|
||||
|
||||
check_permission(&request, Some(request_data.tenant_id))?;
|
||||
|
||||
let response_data = tokio::task::spawn_blocking(move || {
|
||||
branches::create_branch(
|
||||
state.conf,
|
||||
get_config(&request),
|
||||
&request_data.name,
|
||||
&request_data.start_point,
|
||||
&request_data.tenant_id,
|
||||
@@ -99,11 +99,11 @@ async fn branch_list_handler(request: Request<Body>) -> Result<Response<Body>, A
|
||||
|
||||
check_permission(&request, Some(tenantid))?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let response_data =
|
||||
tokio::task::spawn_blocking(move || crate::branches::get_branches(state.conf, &tenantid))
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
let response_data = tokio::task::spawn_blocking(move || {
|
||||
crate::branches::get_branches(get_config(&request), &tenantid)
|
||||
})
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::OK, response_data)?)
|
||||
}
|
||||
|
||||
@@ -111,9 +111,8 @@ async fn tenant_list_handler(request: Request<Body>) -> Result<Response<Body>, A
|
||||
// check for management permission
|
||||
check_permission(&request, None)?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let response_data =
|
||||
tokio::task::spawn_blocking(move || crate::branches::get_tenants(state.conf))
|
||||
tokio::task::spawn_blocking(move || crate::branches::get_tenants(get_config(&request)))
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::OK, response_data)?)
|
||||
@@ -123,11 +122,10 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
|
||||
// check for management permission
|
||||
check_permission(&request, None)?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let request_data: TenantCreateRequest = json_request(&mut request).await?;
|
||||
|
||||
let response_data = tokio::task::spawn_blocking(move || {
|
||||
page_cache::create_repository_for_tenant(state.conf, request_data.tenant_id)
|
||||
page_cache::create_repository_for_tenant(get_config(&request), request_data.tenant_id)
|
||||
})
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
@@ -141,18 +139,23 @@ async fn handler_404(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_router(
|
||||
pub fn make_router(
|
||||
conf: &'static PageServerConf,
|
||||
auth: Arc<Option<JwtAuth>>,
|
||||
) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let spec = include_bytes!("openapi_spec.yml");
|
||||
let mut router = attach_openapi_ui(endpoint::get_router(), spec, "/swagger.yml", "/v1/doc");
|
||||
if let Some(_) = &auth.as_ref() {
|
||||
// note that State is used as a type parameteer without an Arc
|
||||
// this is a simple solution because it is not possible to implement
|
||||
// AuthProvider for Arc<State> so middleware assumes that state is wrapped in Arc
|
||||
router = router.middleware(Middleware::pre(auth_middleware::<State>))
|
||||
let mut router = attach_openapi_ui(endpoint::make_router(), spec, "/swagger.yml", "/v1/doc");
|
||||
if auth.is_some() {
|
||||
router = router.middleware(auth_middleware(|request| {
|
||||
let state = get_state(request);
|
||||
if state.allowlist_routes.contains(request.uri()) {
|
||||
None
|
||||
} else {
|
||||
Option::as_ref(&state.auth)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
router
|
||||
.data(Arc::new(State::new(conf, auth)))
|
||||
.get("/v1/status", status_handler)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::{self, Claims, JwtAuth};
|
||||
use crate::http::error;
|
||||
use crate::zid::ZTenantId;
|
||||
@@ -45,7 +43,7 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn get_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
Router::builder()
|
||||
.middleware(Middleware::post_with_info(logger))
|
||||
.get("/metrics", prometheus_metrics_handler)
|
||||
@@ -94,10 +92,6 @@ pub fn attach_openapi_ui(
|
||||
})
|
||||
}
|
||||
|
||||
pub trait AuthProvider {
|
||||
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>>;
|
||||
}
|
||||
|
||||
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
// header must be in form Bearer <token>
|
||||
let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized(
|
||||
@@ -111,35 +105,30 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn auth_middleware<S: AuthProvider + Send + Sync + 'static>(
|
||||
req: Request<Body>,
|
||||
) -> Result<Request<Body>, ApiError> {
|
||||
// unwrap is ok because this is called in auth middleware
|
||||
// which should be enabled only when auth is some
|
||||
let state_auth = req
|
||||
.data::<Arc<S>>()
|
||||
.expect("state is always in request data")
|
||||
.provide_auth(&req);
|
||||
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
if let Some(auth) = provide_auth(&req) {
|
||||
match req.headers().get(AUTHORIZATION) {
|
||||
Some(value) => {
|
||||
let header_value = value.to_str().map_err(|_| {
|
||||
ApiError::Unauthorized("malformed authorization header".to_string())
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
if let Some(auth) = state_auth.as_ref().as_ref() {
|
||||
match req.headers().get(AUTHORIZATION) {
|
||||
Some(value) => {
|
||||
let header_value = value.to_str().map_err(|_| {
|
||||
ApiError::Unauthorized("malformed authorization header".to_string())
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
}
|
||||
None => Err(ApiError::Unauthorized(
|
||||
"missing authorization header".to_string(),
|
||||
))?,
|
||||
}
|
||||
None => Err(ApiError::Unauthorized(
|
||||
"missing authorization header".to_string(),
|
||||
))?,
|
||||
}
|
||||
}
|
||||
Ok(req)
|
||||
Ok(req)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {
|
||||
|
||||
Reference in New Issue
Block a user