mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-06 13:02:55 +00:00
review adjustments
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::{self, Claims, JwtAuth};
|
||||
use crate::http::error;
|
||||
use crate::zid::ZTenantId;
|
||||
@@ -45,7 +43,7 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn get_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
Router::builder()
|
||||
.middleware(Middleware::post_with_info(logger))
|
||||
.get("/metrics", prometheus_metrics_handler)
|
||||
@@ -94,10 +92,6 @@ pub fn attach_openapi_ui(
|
||||
})
|
||||
}
|
||||
|
||||
pub trait AuthProvider {
|
||||
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>>;
|
||||
}
|
||||
|
||||
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
// header must be in form Bearer <token>
|
||||
let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized(
|
||||
@@ -111,35 +105,30 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn auth_middleware<S: AuthProvider + Send + Sync + 'static>(
|
||||
req: Request<Body>,
|
||||
) -> Result<Request<Body>, ApiError> {
|
||||
// unwrap is ok because this is called in auth middleware
|
||||
// which should be enabled only when auth is some
|
||||
let state_auth = req
|
||||
.data::<Arc<S>>()
|
||||
.expect("state is always in request data")
|
||||
.provide_auth(&req);
|
||||
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
if let Some(auth) = provide_auth(&req) {
|
||||
match req.headers().get(AUTHORIZATION) {
|
||||
Some(value) => {
|
||||
let header_value = value.to_str().map_err(|_| {
|
||||
ApiError::Unauthorized("malformed authorization header".to_string())
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
if let Some(auth) = state_auth.as_ref().as_ref() {
|
||||
match req.headers().get(AUTHORIZATION) {
|
||||
Some(value) => {
|
||||
let header_value = value.to_str().map_err(|_| {
|
||||
ApiError::Unauthorized("malformed authorization header".to_string())
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
}
|
||||
None => Err(ApiError::Unauthorized(
|
||||
"missing authorization header".to_string(),
|
||||
))?,
|
||||
}
|
||||
None => Err(ApiError::Unauthorized(
|
||||
"missing authorization header".to_string(),
|
||||
))?,
|
||||
}
|
||||
}
|
||||
Ok(req)
|
||||
Ok(req)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {
|
||||
|
||||
Reference in New Issue
Block a user