translate pageserver api to http

This commit is contained in:
Dmitry Rodionov
2021-08-10 01:13:37 +03:00
committed by Dmitry
parent 41fa02f82b
commit 23b5249512
28 changed files with 1062 additions and 236 deletions

18
Cargo.lock generated
View File

@@ -352,6 +352,7 @@ dependencies = [
"postgres_ffi",
"rand",
"regex",
"reqwest",
"serde",
"serde_json",
"tar",
@@ -1250,6 +1251,7 @@ dependencies = [
"futures",
"hex",
"humantime",
"hyper",
"lazy_static",
"log",
"postgres",
@@ -1259,6 +1261,7 @@ dependencies = [
"rand",
"regex",
"rocksdb",
"routerify",
"rust-s3",
"scopeguard",
"serde",
@@ -1690,6 +1693,19 @@ dependencies = [
"librocksdb-sys",
]
[[package]]
name = "routerify"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9"
dependencies = [
"http",
"hyper",
"lazy_static",
"percent-encoding",
"regex",
]
[[package]]
name = "rust-ini"
version = "0.16.1"
@@ -2668,7 +2684,9 @@ dependencies = [
"log",
"postgres",
"rand",
"routerify",
"serde",
"serde_json",
"thiserror",
"tokio",
"workspace_hack",

View File

@@ -20,6 +20,7 @@ bytes = "1.0.1"
nix = "0.20"
url = "2.2.2"
hex = { version = "0.4.3", features = ["serde"] }
reqwest = { version = "0.11", features = ["blocking", "json"] }
pageserver = { path = "../pageserver" }
walkeeper = { path = "../walkeeper" }

View File

@@ -332,7 +332,7 @@ impl PostgresNode {
};
// Configure that node to take pages from pageserver
let (host, port) = connection_host_port(&self.pageserver.connection_config);
let (host, port) = connection_host_port(&self.pageserver.pg_connection_config);
self.append_conf(
"postgresql.conf",
format!(

View File

@@ -5,10 +5,13 @@ use std::process::Command;
use std::thread;
use std::time::Duration;
use anyhow::{anyhow, bail, Result};
use anyhow::{anyhow, bail, ensure, Result};
use nix::sys::signal::{kill, Signal};
use nix::unistd::Pid;
use pageserver::http::models::{BranchCreateRequest, TenantCreateRequest};
use postgres::{Config, NoTls};
use reqwest::blocking::{Client, RequestBuilder};
use reqwest::{IntoUrl, Method, StatusCode};
use zenith_utils::postgres_backend::AuthType;
use zenith_utils::zid::ZTenantId;
@@ -17,6 +20,8 @@ use crate::read_pidfile;
use pageserver::branches::BranchInfo;
use zenith_utils::connstring::connection_address;
const HTTP_BASE_URL: &str = "http://127.0.0.1:9898/v1";
//
// Control routines for pageserver.
//
@@ -25,13 +30,15 @@ use zenith_utils::connstring::connection_address;
#[derive(Debug)]
pub struct PageServerNode {
pub kill_on_exit: bool,
pub connection_config: Config,
pub pg_connection_config: Config,
pub env: LocalEnv,
pub http_client: Client,
pub http_base_url: String,
}
impl PageServerNode {
pub fn from_env(env: &LocalEnv) -> PageServerNode {
let password = if matches!(env.auth_type, AuthType::ZenithJWT) {
let password = if env.auth_type == AuthType::ZenithJWT {
&env.auth_token
} else {
""
@@ -39,8 +46,10 @@ impl PageServerNode {
PageServerNode {
kill_on_exit: false,
connection_config: Self::default_config(password), // default
pg_connection_config: Self::default_config(password), // default
env: env.clone(),
http_client: Client::new(),
http_base_url: HTTP_BASE_URL.to_owned(),
}
}
@@ -100,7 +109,7 @@ impl PageServerNode {
pub fn start(&self) -> Result<()> {
println!(
"Starting pageserver at '{}' in {}",
connection_address(&self.connection_config),
connection_address(&self.pg_connection_config),
self.repo_path().display()
);
@@ -120,7 +129,7 @@ impl PageServerNode {
// It takes a while for the page server to start up. Wait until it is
// open for business.
for retries in 1..15 {
match self.page_server_psql_client() {
match self.check_status() {
Ok(_) => {
println!("Pageserver started");
return Ok(());
@@ -145,7 +154,7 @@ impl PageServerNode {
}
// wait for pageserver stop
let address = connection_address(&self.connection_config);
let address = connection_address(&self.pg_connection_config);
for _ in 0..5 {
let stream = TcpStream::connect(&address);
thread::sleep(Duration::from_secs(1));
@@ -160,50 +169,64 @@ impl PageServerNode {
}
pub fn page_server_psql(&self, sql: &str) -> Vec<postgres::SimpleQueryMessage> {
let mut client = self.connection_config.connect(NoTls).unwrap();
let mut client = self.pg_connection_config.connect(NoTls).unwrap();
println!("Pageserver query: '{}'", sql);
client.simple_query(sql).unwrap()
}
pub fn page_server_psql_client(&self) -> Result<postgres::Client, postgres::Error> {
self.connection_config.connect(NoTls)
self.pg_connection_config.connect(NoTls)
}
pub fn tenants_list(&self) -> Result<Vec<String>> {
let mut client = self.page_server_psql_client()?;
let query_result = client.simple_query("tenant_list")?;
let tenants_json = query_result
.first()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => row.get(0),
_ => None,
})
.flatten()
.ok_or_else(|| anyhow!("missing tenants"))?;
Ok(serde_json::from_str(tenants_json)?)
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let mut builder = self.http_client.request(method, url);
if self.env.auth_type == AuthType::ZenithJWT {
builder = builder.bearer_auth(&self.env.auth_token)
}
builder
}
pub fn tenant_create(&self, tenantid: &ZTenantId) -> Result<()> {
let mut client = self.page_server_psql_client()?;
client.simple_query(format!("tenant_create {}", tenantid).as_str())?;
pub fn check_status(&self) -> Result<()> {
let status = self
.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()?
.status();
ensure!(
status == StatusCode::OK,
format!("got unexpected response status {}", status)
);
Ok(())
}
pub fn branches_list(&self, tenantid: &ZTenantId) -> Result<Vec<BranchInfo>> {
let mut client = self.page_server_psql_client()?;
let query_result = client.simple_query(&format!("branch_list {}", tenantid))?;
let branches_json = query_result
.first()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => row.get(0),
_ => None,
})
.flatten()
.ok_or_else(|| anyhow!("missing branches"))?;
pub fn tenant_list(&self) -> Result<Vec<String>> {
Ok(self
.http_request(Method::GET, format!("{}/{}", self.http_base_url, "tenant"))
.send()?
.error_for_status()?
.json()?)
}
Ok(serde_json::from_str(branches_json)?)
pub fn tenant_create(&self, tenantid: ZTenantId) -> Result<()> {
Ok(self
.http_request(Method::POST, format!("{}/{}", self.http_base_url, "tenant"))
.json(&TenantCreateRequest {
tenant_id: tenantid,
})
.send()?
.error_for_status()?
.json()?)
}
pub fn branch_list(&self, tenantid: &ZTenantId) -> Result<Vec<BranchInfo>> {
Ok(self
.http_request(
Method::GET,
format!("{}/branch/{}", self.http_base_url, tenantid),
)
.send()?
.error_for_status()?
.json()?)
}
pub fn branch_create(
@@ -212,29 +235,16 @@ impl PageServerNode {
startpoint: &str,
tenantid: &ZTenantId,
) -> Result<BranchInfo> {
let mut client = self.page_server_psql_client()?;
let query_result = client.simple_query(
format!("branch_create {} {} {}", tenantid, branch_name, startpoint).as_str(),
)?;
let branch_json = query_result
.first()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => row.get(0),
_ => None,
Ok(self
.http_request(Method::POST, format!("{}/{}", self.http_base_url, "branch"))
.json(&BranchCreateRequest {
tenant_id: tenantid.to_owned(),
name: branch_name.to_owned(),
start_point: startpoint.to_owned(),
})
.flatten()
.ok_or_else(|| anyhow!("missing branch"))?;
let res: BranchInfo = serde_json::from_str(branch_json).map_err(|e| {
anyhow!(
"failed to parse branch_create response: {}: {}",
branch_json,
e
)
})?;
Ok(res)
.send()?
.error_for_status()?
.json()?)
}
// TODO: make this a separate request type and avoid loading all the branches
@@ -243,14 +253,14 @@ impl PageServerNode {
tenantid: &ZTenantId,
branch_name: &str,
) -> Result<BranchInfo> {
let branch_infos = self.branches_list(tenantid)?;
let branche_by_name: Result<HashMap<String, BranchInfo>> = branch_infos
let branch_infos = self.branch_list(tenantid)?;
let branch_by_name: Result<HashMap<String, BranchInfo>> = branch_infos
.into_iter()
.map(|branch_info| Ok((branch_info.name.clone(), branch_info)))
.collect();
let branche_by_name = branche_by_name?;
let branch_by_name = branch_by_name?;
let branch = branche_by_name
let branch = branch_by_name
.get(branch_name)
.ok_or_else(|| anyhow!("Branch {} not found", branch_name))?;

View File

@@ -14,6 +14,7 @@ regex = "1.4.5"
bytes = { version = "1.0.1", features = ['serde'] }
byteorder = "1.4.3"
futures = "0.3.13"
hyper = "0.14"
lazy_static = "1.4.0"
slog-stdlog = "4.1.0"
slog-async = "2.6.0"
@@ -31,11 +32,12 @@ postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
# by default rust-rocksdb tries to build a lot of compression algos. Use lz4 only for now as it is simplest dependency.
rocksdb = { version = "0.16.0", features = ["lz4"], default-features = false }
routerify = "2"
anyhow = "1.0"
crc32c = "0.6.0"
walkdir = "2"
thiserror = "1.0"
hex = "0.4.3"
hex = { version = "0.4.3", features = ["serde"] }
tar = "0.4.33"
humantime = "2.1.0"
serde = { version = "1.0", features = ["derive"] }

View File

@@ -20,8 +20,10 @@ use anyhow::{ensure, Result};
use clap::{App, Arg, ArgMatches};
use daemonize::Daemonize;
use pageserver::{branches, logger, page_cache, page_service, PageServerConf, RepositoryFormat};
use zenith_utils::http_endpoint;
use pageserver::{
branches, http, logger, page_cache, page_service, PageServerConf, RepositoryFormat,
};
use zenith_utils::http::endpoint;
const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:64000";
const DEFAULT_HTTP_ENDPOINT_ADDR: &str = "127.0.0.1:9898";
@@ -323,19 +325,6 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
}
}
// Spawn a new thread for the http endpoint
thread::Builder::new()
.name("Metrics thread".into())
.spawn(move || http_endpoint::thread_main(conf.http_endpoint_addr.clone()))?;
// Check that we can bind to address before starting threads to simplify shutdown
// sequence if port is occupied.
info!("Starting pageserver on {}", conf.listen_addr);
let pageserver_listener = TcpListener::bind(conf.listen_addr.clone())?;
// Initialize page cache, this will spawn walredo_thread
page_cache::init(conf);
// initialize authentication for incoming connections
let auth = match &conf.auth_type {
AuthType::Trust | AuthType::MD5 => Arc::new(None),
@@ -346,6 +335,24 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
}
};
info!("Using auth: {:#?}", conf.auth_type);
// Spawn a new thread for the http endpoint
let cloned = Arc::clone(&auth);
thread::Builder::new()
.name("http_endpoint_thread".into())
.spawn(move || {
let router = http::get_router(conf, cloned);
endpoint::serve_thread_main(router, conf.http_endpoint_addr.clone())
})?;
// Check that we can bind to address before starting threads to simplify shutdown
// sequence if port is occupied.
info!("Starting pageserver on {}", conf.listen_addr);
let pageserver_listener = TcpListener::bind(conf.listen_addr.clone())?;
// Initialize page cache, this will spawn walredo_thread
page_cache::init(conf);
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
let page_service_thread = thread::Builder::new()

View File

@@ -29,6 +29,7 @@ use crate::{repository::Repository, PageServerConf, RepositoryFormat};
#[derive(Serialize, Deserialize, Clone)]
pub struct BranchInfo {
pub name: String,
#[serde(with = "hex")]
pub timeline_id: ZTimelineId,
pub latest_valid_lsn: Option<Lsn>,
pub ancestor_id: Option<String>,
@@ -283,7 +284,7 @@ pub(crate) fn create_branch(
let end_of_wal = repo
.get_timeline(startpoint.timelineid)?
.get_last_record_lsn();
println!("branching at end of WAL: {}", end_of_wal);
info!("branching at end of WAL: {}", end_of_wal);
startpoint.lsn = end_of_wal;
}

View File

@@ -0,0 +1,3 @@
pub mod models;
pub mod routes;
pub use routes::get_router;

View File

@@ -0,0 +1,17 @@
use serde::{Deserialize, Serialize};
use crate::ZTenantId;
#[derive(Serialize, Deserialize)]
pub struct BranchCreateRequest {
#[serde(with = "hex")]
pub tenant_id: ZTenantId,
pub name: String,
pub start_point: String,
}
#[derive(Serialize, Deserialize)]
pub struct TenantCreateRequest {
#[serde(with = "hex")]
pub tenant_id: ZTenantId,
}

View File

@@ -0,0 +1,239 @@
openapi: "3.0.2"
info:
title: Page Server API
version: "1.0"
servers:
- url: ""
paths:
/v1/status:
description: Healthcheck endpoint
get:
description: Healthcheck
security: []
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
/v1/branch/{tenant_id}:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
get:
description: Get branches for tenant
responses:
"200":
description: BranchInfo
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/BranchInfo"
"400":
description: Error when no tenant id found in path
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/branch/:
post:
description: Create branch
requestBody:
content:
application/json:
schema:
type: object
required:
- "tenant_id"
- "name"
- "start_point"
properties:
tenant_id:
type: string
format: hex
name:
type: string
start_point:
type: string
responses:
"201":
description: BranchInfo
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/BranchInfo"
"400":
description: Malformed branch create request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/:
get:
description: Get tenants list
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
type: string
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
post:
description: Create tenant
requestBody:
content:
application/json:
schema:
type: object
required:
- "tenant_id"
properties:
tenant_id:
type: string
format: hex
responses:
"201":
description: CREATED
content:
application/json:
schema:
type: array
items:
type: string
"400":
description: Malformed tenant create request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components:
securitySchemes:
JWT:
type: http
scheme: bearer
bearerFormat: JWT
schemas:
BranchInfo:
type: object
required:
- name
- timeline_id
properties:
name:
type: string
timeline_id:
type: string
format: hex
ancestor_id:
type: string
ancestor_lsn:
type: string
Error:
type: object
required:
- msg
properties:
msg:
type: string
UnauthorizedError:
type: object
required:
- msg
properties:
msg:
type: string
ForbiddenError:
type: object
required:
- msg
properties:
msg:
type: string
security:
- JWT: []

View File

@@ -0,0 +1,164 @@
use std::sync::Arc;
use anyhow::Result;
use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use routerify::Middleware;
use routerify::{ext::RequestExt, RouterBuilder};
use zenith_utils::auth::JwtAuth;
use zenith_utils::http::endpoint::attach_openapi_ui;
use zenith_utils::http::endpoint::auth_middleware;
use zenith_utils::http::endpoint::check_permission;
use zenith_utils::http::endpoint::AuthProvider;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::{
endpoint,
error::HttpErrorBody,
json::{json_request, json_response},
};
use super::models::BranchCreateRequest;
use super::models::TenantCreateRequest;
use crate::page_cache;
use crate::{
branches::{self},
PageServerConf, ZTenantId,
};
#[derive(Debug)]
struct State {
conf: &'static PageServerConf,
auth: Arc<Option<JwtAuth>>,
whitelist_routes: Vec<Uri>,
}
impl State {
fn new(conf: &'static PageServerConf, auth: Arc<Option<JwtAuth>>) -> Self {
let whitelist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
Self {
conf,
auth,
whitelist_routes,
}
}
}
impl AuthProvider for State {
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>> {
if self.whitelist_routes.contains(req.uri()) {
Arc::new(None)
} else {
self.auth.clone()
}
}
}
// healthcheck handler
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from("{}"))
.map_err(ApiError::from_err)?)
}
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = request.data::<Arc<State>>().unwrap().clone();
let request_data: BranchCreateRequest = json_request(&mut request).await?;
check_permission(&request, Some(request_data.tenant_id))?;
let response_data = tokio::task::spawn_blocking(move || {
branches::create_branch(
state.conf,
&request_data.name,
&request_data.start_point,
&request_data.tenant_id,
)
})
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::CREATED, response_data)?)
}
async fn branch_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenantid: ZTenantId = match request.param("tenant_id") {
Some(arg) => arg
.parse()
.map_err(|_| ApiError::BadRequest("failed to parse tenant id".to_string()))?,
None => {
return Err(ApiError::BadRequest(
"no tenant id specified in path param".to_string(),
))
}
};
check_permission(&request, Some(tenantid))?;
let state = request.data::<Arc<State>>().unwrap().clone();
let response_data =
tokio::task::spawn_blocking(move || crate::branches::get_branches(state.conf, &tenantid))
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::OK, response_data)?)
}
async fn tenant_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
// check for management permission
check_permission(&request, None)?;
let state = request.data::<Arc<State>>().unwrap().clone();
let response_data =
tokio::task::spawn_blocking(move || crate::branches::get_tenants(state.conf))
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::OK, response_data)?)
}
async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
// check for management permission
check_permission(&request, None)?;
let state = request.data::<Arc<State>>().unwrap().clone();
let request_data: TenantCreateRequest = json_request(&mut request).await?;
let response_data = tokio::task::spawn_blocking(move || {
page_cache::create_repository_for_tenant(state.conf, request_data.tenant_id)
})
.await
.map_err(ApiError::from_err)??;
Ok(json_response(StatusCode::CREATED, response_data)?)
}
async fn handler_404(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(
StatusCode::NOT_FOUND,
HttpErrorBody::from_msg("page not found".to_owned()),
)
}
pub fn get_router(
conf: &'static PageServerConf,
auth: Arc<Option<JwtAuth>>,
) -> RouterBuilder<hyper::Body, ApiError> {
let spec = include_bytes!("openapi_spec.yml");
let mut router = attach_openapi_ui(endpoint::get_router(), spec, "/swagger.yml", "/v1/doc");
if let Some(_) = &auth.as_ref() {
// note that State is used as a type parameteer without an Arc
// this is a simple solution because it is not possible to implement
// AuthProvider for Arc<State> so middleware assumes that state is wrapped in Arc
router = router.middleware(Middleware::pre(auth_middleware::<State>))
}
router
.data(Arc::new(State::new(conf, auth)))
.get("/v1/status", status_handler)
.get("/v1/branch/:tenant_id", branch_list_handler)
.post("/v1/branch", branch_create_handler)
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.any(handler_404)
}

View File

@@ -9,6 +9,7 @@ use zenith_metrics::{register_int_gauge_vec, IntGaugeVec};
pub mod basebackup;
pub mod branches;
pub mod http;
pub mod layered_repository;
pub mod logger;
pub mod object_key;

View File

@@ -23,7 +23,7 @@ use std::sync::Arc;
use std::thread;
use std::{io, net::TcpStream};
use zenith_metrics::{register_histogram_vec, HistogramVec};
use zenith_utils::auth::JwtAuth;
use zenith_utils::auth::{self, JwtAuth};
use zenith_utils::auth::{Claims, Scope};
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{self, AuthType};
@@ -389,19 +389,7 @@ impl PageServerHandler {
.claims
.as_ref()
.expect("claims presence already checked");
match (&claims.scope, tenantid) {
(Scope::Tenant, None) => {
bail!("Attempt to access management api with tenant scope. Permission denied")
}
(Scope::Tenant, Some(tenantid)) => {
if claims.tenant_id.unwrap() != tenantid {
bail!("Tenant id mismatch. Permission denied")
}
Ok(())
}
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
}
Ok(auth::check_permission(claims, tenantid)?)
}
}

View File

@@ -8,6 +8,7 @@ pytest = ">=6.0.0"
psycopg2 = "*"
typing-extensions = "*"
pyjwt = {extras = ["crypto"], version = "*"}
requests = "*"
[dev-packages]
yapf = "*"

View File

@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "f60a966726bcc19670402ad3fa57396b5dacf0a027544418ceb7cc0d42d94a52"
"sha256": "b666740289d9c82797e5c39b2a7f0074c865c9183ee878ce4fa5cda7928506ea"
},
"pipfile-spec": 6,
"requires": {
@@ -24,6 +24,13 @@
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
"version": "==21.2.0"
},
"certifi": {
"hashes": [
"sha256:2bbf76fd432960138b3ef6dda3dde0544f27cbf8546c458e60baf371917ba9ee",
"sha256:50b1e4f8446b06f41be7dd6338db18e0990601dce795c2b1686458aa7e8fa7d8"
],
"version": "==2021.5.30"
},
"cffi": {
"hashes": [
"sha256:06c54a68935738d206570b20da5ef2b6b6d92b38ef3ec45c5422c0ebaf338d4d",
@@ -74,6 +81,14 @@
],
"version": "==1.14.6"
},
"charset-normalizer": {
"hashes": [
"sha256:0c8911edd15d19223366a194a513099a302055a962bca2cec0f54b8b63175d8b",
"sha256:f23667ebe1084be45f6ae0538e4a5a865206544097e4e8bbcacf42cd02a348f3"
],
"markers": "python_version >= '3'",
"version": "==2.0.4"
},
"cryptography": {
"hashes": [
"sha256:0f1212a66329c80d68aeeb39b8a16d54ef57071bf22ff4e521657b27372e327d",
@@ -85,12 +100,22 @@
"sha256:3d8427734c781ea5f1b41d6589c293089704d4759e34597dce91014ac125aad1",
"sha256:7ec5d3b029f5fa2b179325908b9cd93db28ab7b85bb6c1db56b10e0b54235177",
"sha256:8e56e16617872b0957d1c9742a3f94b43533447fd78321514abbe7db216aa250",
"sha256:b01fd6f2737816cb1e08ed4807ae194404790eac7ad030b34f2ce72b332f5586",
"sha256:bf40af59ca2465b24e54f671b2de2c59257ddc4f7e5706dbd6930e26823668d3",
"sha256:de4e5f7f68220d92b7637fc99847475b59154b7a1b3868fb7385337af54ac9ca",
"sha256:eb8cc2afe8b05acbd84a43905832ec78e7b3873fb124ca190f574dca7389a87d",
"sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9"
],
"version": "==3.4.7"
},
"idna": {
"hashes": [
"sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a",
"sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"
],
"markers": "python_version >= '3'",
"version": "==3.2"
},
"iniconfig": {
"hashes": [
"sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3",
@@ -172,6 +197,14 @@
"index": "pypi",
"version": "==6.2.4"
},
"requests": {
"hashes": [
"sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24",
"sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"
],
"index": "pypi",
"version": "==2.26.0"
},
"toml": {
"hashes": [
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
@@ -188,6 +221,14 @@
],
"index": "pypi",
"version": "==3.10.0.0"
},
"urllib3": {
"hashes": [
"sha256:39fb8672126159acb139a7718dd10806104dec1e2f0f6c88aab05d17df10c8d4",
"sha256:f57b4c16c62fa2760b7e3d97c35b255512fb6b59a259730f36ba32ce9f8e342f"
],
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_version < '4'",
"version": "==1.26.6"
}
},
"develop": {

View File

@@ -1,61 +1,17 @@
from contextlib import closing
from pathlib import Path
from uuid import uuid4
from dataclasses import dataclass
import jwt
import psycopg2
from fixtures.zenith_fixtures import Postgres, ZenithCli, ZenithPageserver
import pytest
@pytest.fixture
def pageserver_auth_enabled(zenith_cli: ZenithCli):
with ZenithPageserver(zenith_cli).init(enable_auth=True).start() as ps:
# For convenience in tests, create a branch from the freshly-initialized cluster.
zenith_cli.run(["branch", "empty", "main"])
yield ps
@dataclass
class AuthKeys:
pub: bytes
priv: bytes
def generate_management_token(self):
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
# hinting or something (not sure what). If it returned 'bytes', convert it to 'str'
# explicitly.
if isinstance(token, bytes):
token = token.decode()
return token
def generate_tenant_token(self, tenant_id):
token = jwt.encode({"scope": "tenant", "tenant_id": tenant_id}, self.priv, algorithm="RS256")
if isinstance(token, bytes):
token = token.decode()
return token
@pytest.fixture
def auth_keys(repo_dir: str) -> AuthKeys:
# TODO probably this should be specified in cli config and used in tests for single source of truth
pub = (Path(repo_dir) / 'auth_public_key.pem').read_bytes()
priv = (Path(repo_dir) / 'auth_private_key.pem').read_bytes()
return AuthKeys(pub=pub, priv=priv)
def test_pageserver_auth(pageserver_auth_enabled: ZenithPageserver, auth_keys: AuthKeys):
def test_pageserver_auth(pageserver_auth_enabled: ZenithPageserver):
ps = pageserver_auth_enabled
tenant_token = auth_keys.generate_tenant_token(ps.initial_tenant)
invalid_tenant_token = auth_keys.generate_tenant_token(uuid4().hex)
management_token = auth_keys.generate_management_token()
tenant_token = ps.auth_keys.generate_tenant_token(ps.initial_tenant)
invalid_tenant_token = ps.auth_keys.generate_tenant_token(uuid4().hex)
management_token = ps.auth_keys.generate_management_token()
# this does not invoke auth check and only decodes jwt and checks it for validity
# check both tokens
@@ -86,12 +42,11 @@ def test_compute_auth_to_pageserver(
pageserver_auth_enabled: ZenithPageserver,
repo_dir: str,
with_wal_acceptors: bool,
auth_keys: AuthKeys,
):
ps = pageserver_auth_enabled
# since we are in progress of refactoring protocols between compute safekeeper and page server
# use hardcoded management token in safekeeper
management_token = auth_keys.generate_management_token()
management_token = ps.auth_keys.generate_management_token()
branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
zenith_cli.run(["branch", branch, "empty"])

View File

@@ -1,19 +1,20 @@
import json
import uuid
from uuid import uuid4
import pytest
import psycopg2
from fixtures.zenith_fixtures import ZenithPageserver
import requests
from fixtures.zenith_fixtures import ZenithPageserver, ZenithPageserverHttpClient
pytest_plugins = ("fixtures.zenith_fixtures")
def test_status(pageserver):
def test_status_psql(pageserver):
assert pageserver.safe_psql('status') == [
('hello world', ),
]
def test_branch_list(pageserver: ZenithPageserver, zenith_cli):
def test_branch_list_psql(pageserver: ZenithPageserver, zenith_cli):
# Create a branch for us
zenith_cli.run(["branch", "test_branch_list_main", "empty"])
@@ -52,7 +53,7 @@ def test_branch_list(pageserver: ZenithPageserver, zenith_cli):
conn.close()
def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
def test_tenant_list_psql(pageserver: ZenithPageserver, zenith_cli):
res = zenith_cli.run(["tenant", "list"])
res.check_returncode()
tenants = res.stdout.splitlines()
@@ -66,7 +67,7 @@ def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
cur.execute(f'tenant_create {pageserver.initial_tenant}')
# create one more tenant
tenant1 = uuid.uuid4().hex
tenant1 = uuid4().hex
cur.execute(f'tenant_create {tenant1}')
cur.execute('tenant_list')
@@ -74,3 +75,32 @@ def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
# compare tenants list
new_tenants = sorted(json.loads(cur.fetchone()[0]))
assert sorted([pageserver.initial_tenant, tenant1]) == new_tenants
def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
client.check_status()
# check initial tenant is there
assert initial_tenant in set(client.tenant_list())
# create new tenant and check it is also there
tenant_id = uuid4()
client.tenant_create(tenant_id)
assert tenant_id.hex in set(client.tenant_list())
# create branch
branch_name = uuid4().hex
client.branch_create(tenant_id, branch_name, "main")
# check it is there
assert branch_name in {b['name'] for b in client.branch_list(tenant_id)}
def test_pageserver_http_api_client(pageserver: ZenithPageserver):
client = pageserver.http_client()
check_client(client, pageserver.initial_tenant)
def test_pageserver_http_api_client_auth_enabled(pageserver_auth_enabled: ZenithPageserver):
client = pageserver_auth_enabled.http_client(auth_token=pageserver_auth_enabled.auth_keys.generate_management_token())
check_client(client, pageserver_auth_enabled.initial_tenant)

View File

@@ -1,6 +1,9 @@
from dataclasses import dataclass
from functools import cached_property
import os
import pathlib
import uuid
import jwt
import psycopg2
import pytest
import shutil
@@ -17,6 +20,8 @@ from psycopg2.extensions import connection as PgConnection
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast
from typing_extensions import Literal
import requests
from .utils import (get_self_dir, mkdir_if_needed, subprocess_capture)
"""
This file contains pytest fixtures. A fixture is a test resource that can be
@@ -42,7 +47,8 @@ Fn = TypeVar('Fn', bound=Callable[..., Any])
DEFAULT_OUTPUT_DIR = 'test_output'
DEFAULT_POSTGRES_DIR = 'tmp_install'
DEFAULT_PAGESERVER_PORT = 64000
DEFAULT_PAGESERVER_PG_PORT = 64000
DEFAULT_PAGESERVER_HTTP_PORT = 9898
def determine_scope(fixture_name: str, config: Any) -> str:
@@ -175,13 +181,87 @@ def zenith_cli(zenith_binpath: str, repo_dir: str, pg_distrib_dir: str) -> Zenit
return ZenithCli(zenith_binpath, repo_dir, pg_distrib_dir)
class ZenithPageserverHttpClient(requests.Session):
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
super().__init__()
self.port = port
self.auth_token = auth_token
if auth_token is not None:
self.headers['Authorization'] = f'Bearer {auth_token}'
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def branch_list(self, tenant_id: uuid.UUID) -> List[Dict]:
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
res.raise_for_status()
return res.json()
def branch_create(self, tenant_id: uuid.UUID, name: str, start_point: str) -> Dict:
res = self.post(
f"http://localhost:{self.port}/v1/branch",
json={
'tenant_id': tenant_id.hex,
'name': name,
'start_point': start_point,
}
)
res.raise_for_status()
return res.json()
def tenant_list(self) -> List[str]:
res = self.get(f"http://localhost:{self.port}/v1/tenant")
res.raise_for_status()
return res.json()
def tenant_create(self, tenant_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/tenant",
json={
'tenant_id': tenant_id.hex,
},
)
res.raise_for_status()
return res.json()
@dataclass
class AuthKeys:
pub: bytes
priv: bytes
def generate_management_token(self):
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
# hinting or something (not sure what). If it returned 'bytes', convert it to 'str'
# explicitly.
if isinstance(token, bytes):
token = token.decode()
return token
def generate_tenant_token(self, tenant_id):
token = jwt.encode({"scope": "tenant", "tenant_id": tenant_id}, self.priv, algorithm="RS256")
if isinstance(token, bytes):
token = token.decode()
return token
class ZenithPageserver(PgProtocol):
""" An object representing a running pageserver. """
def __init__(self, zenith_cli: ZenithCli):
super().__init__(host='localhost', port=DEFAULT_PAGESERVER_PORT)
def __init__(self, zenith_cli: ZenithCli, repo_dir: str):
super().__init__(host='localhost', port=DEFAULT_PAGESERVER_PG_PORT)
self.zenith_cli = zenith_cli
self.running = False
self.initial_tenant = None
self.repo_dir = repo_dir
def init(self, enable_auth: bool = False) -> 'ZenithPageserver':
"""
@@ -224,9 +304,21 @@ class ZenithPageserver(PgProtocol):
def __exit__(self, exc_type, exc, tb):
self.stop()
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
priv = (Path(self.repo_dir) / 'auth_private_key.pem').read_bytes()
return AuthKeys(pub=pub, priv=priv)
def http_client(self, auth_token: Optional[str] = None):
return ZenithPageserverHttpClient(
port=DEFAULT_PAGESERVER_HTTP_PORT,
auth_token=auth_token,
)
@zenfixture
def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
def pageserver(zenith_cli: ZenithCli, repo_dir: str) -> Iterator[ZenithPageserver]:
"""
The 'pageserver' fixture provides a Page Server that's up and running.
@@ -239,7 +331,7 @@ def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
test called 'test_foo' would create and use branches with the 'test_foo' prefix.
"""
ps = ZenithPageserver(zenith_cli).init().start()
ps = ZenithPageserver(zenith_cli, repo_dir).init().start()
# For convenience in tests, create a branch from the freshly-initialized cluster.
zenith_cli.run(["branch", "empty", "main"])
@@ -250,6 +342,14 @@ def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
ps.stop()
@pytest.fixture
def pageserver_auth_enabled(zenith_cli: ZenithCli, repo_dir: str):
with ZenithPageserver(zenith_cli, repo_dir).init(enable_auth=True).start() as ps:
# For convenience in tests, create a branch from the freshly-initialized cluster.
zenith_cli.run(["branch", "empty", "main"])
yield ps
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, zenith_cli: ZenithCli, repo_dir: str, tenant_id: str, port: int):
@@ -598,7 +698,7 @@ class WalAcceptor:
cmd.append("--daemonize")
cmd.append("--no-sync")
# Tell page server it can receive WAL from this WAL safekeeper
cmd.extend(["--pageserver", "localhost:{}".format(DEFAULT_PAGESERVER_PORT)])
cmd.extend(["--pageserver", "localhost:{}".format(DEFAULT_PAGESERVER_PG_PORT)])
cmd.extend(["--recall", "1 second"])
print('Running command "{}"'.format(' '.join(cmd)))
env = {'PAGESERVER_AUTH_TOKEN': self.auth_token} if self.auth_token else None

View File

@@ -224,7 +224,7 @@ fn main() -> Result<()> {
("pg", Some(pg_match)) => {
if let Err(e) = handle_pg(pg_match, &env) {
eprintln!("pg operation failed: {}", e);
eprintln!("pg operation failed: {:?}", e);
exit(1);
}
}
@@ -371,7 +371,7 @@ fn get_branch_infos(
tenantid: &ZTenantId,
) -> Result<HashMap<ZTimelineId, BranchInfo>> {
let page_server = PageServerNode::from_env(env);
let branch_infos: Vec<BranchInfo> = page_server.branches_list(tenantid)?;
let branch_infos: Vec<BranchInfo> = page_server.branch_list(tenantid)?;
let branch_infos: HashMap<ZTimelineId, BranchInfo> = branch_infos
.into_iter()
.map(|branch_info| (branch_info.timeline_id, branch_info))
@@ -384,7 +384,7 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
let pageserver = PageServerNode::from_env(&env);
match tenant_match.subcommand() {
("list", Some(_)) => {
for tenant in pageserver.tenants_list()? {
for tenant in pageserver.tenant_list()? {
println!("{}", tenant);
}
}
@@ -394,7 +394,7 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
None => ZTenantId::generate(),
};
println!("using tenant id {}", tenantid);
pageserver.tenant_create(&tenantid)?;
pageserver.tenant_create(tenantid)?;
println!("tenant successfully created on the pageserver");
}
_ => {}
@@ -424,7 +424,7 @@ fn handle_branch(branch_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
.value_of("tenantid")
.map_or(Ok(env.tenantid), |value| value.parse())?;
// No arguments, list branches for tenant
let branches = pageserver.branches_list(&tenantid)?;
let branches = pageserver.branch_list(&tenantid)?;
print_branches_tree(branches)?;
}

View File

@@ -13,7 +13,9 @@ hyper = { version = "0.14.7", features = ["full"] }
lazy_static = "1.4.0"
log = "0.4.14"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
routerify = "2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
thiserror = "1.0"
tokio = { version = "1.5.0", features = ["full"] }

View File

@@ -10,7 +10,7 @@ use serde::de::Error;
use serde::{self, Deserializer, Serializer};
use std::{fs, path::PathBuf};
use anyhow::Result;
use anyhow::{bail, Result};
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
};
@@ -20,7 +20,7 @@ use crate::zid::ZTenantId;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
Tenant,
@@ -48,7 +48,7 @@ where
}
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
// this custom serialize/deserialize_with is needed because Option is not transparent to serde
// so clearest option is serde(with = "hex") but it is not working, for details see https://github.com/serde-rs/serde/issues/1301
@@ -68,6 +68,22 @@ impl Claims {
}
}
pub fn check_permission(claims: &Claims, tenantid: Option<ZTenantId>) -> Result<()> {
match (&claims.scope, tenantid) {
(Scope::Tenant, None) => {
bail!("Attempt to access management api with tenant scope. Permission denied")
}
(Scope::Tenant, Some(tenantid)) => {
if claims.tenant_id.unwrap() != tenantid {
bail!("Tenant id mismatch. Permission denied")
}
Ok(())
}
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
}
}
#[derive(Debug)]
pub struct JwtAuth {
decoding_key: DecodingKey<'static>,

View File

@@ -0,0 +1,175 @@
use std::sync::Arc;
use crate::auth::{self, Claims, JwtAuth};
use crate::http::error;
use crate::zid::ZTenantId;
use anyhow::anyhow;
use hyper::header::AUTHORIZATION;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use lazy_static::lazy_static;
use routerify::ext::RequestExt;
use routerify::RequestInfo;
use routerify::{Middleware, Router, RouterBuilder, RouterService};
use zenith_metrics::{register_int_counter, IntCounter};
use zenith_metrics::{Encoder, TextEncoder};
use super::error::ApiError;
lazy_static! {
static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!(
"pageserver_serve_metrics_count",
"Number of metric requests made"
)
.expect("failed to define a metric");
}
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
log::info!("{} {} {}", info.method(), info.uri().path(), res.status(),);
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
SERVE_METRICS_COUNT.inc();
let mut buffer = vec![];
let encoder = TextEncoder::new();
let metrics = zenith_metrics::gather();
encoder.encode(&metrics, &mut buffer).unwrap();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap();
Ok(response)
}
pub fn get_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
pub fn attach_openapi_ui(
router_builder: RouterBuilder<hyper::Body, ApiError>,
spec: &'static [u8],
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
pub trait AuthProvider {
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>>;
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
// header must be in form Bearer <token>
let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized(
"malformed authorization header".to_string(),
))?;
if prefix != "Bearer" {
Err(ApiError::Unauthorized(
"malformed authorization header".to_string(),
))?
}
Ok(token)
}
pub async fn auth_middleware<S: AuthProvider + Send + Sync + 'static>(
req: Request<Body>,
) -> Result<Request<Body>, ApiError> {
// unwrap is ok because this is called in auth middleware
// which should be enabled only when auth is some
let state_auth = req
.data::<Arc<S>>()
.expect("state is always in request data")
.provide_auth(&req);
if let Some(auth) = state_auth.as_ref().as_ref() {
match req.headers().get(AUTHORIZATION) {
Some(value) => {
let header_value = value.to_str().map_err(|_| {
ApiError::Unauthorized("malformed authorization header".to_string())
})?;
let token = parse_token(header_value)?;
let data = auth
.decode(token)
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
req.set_context(data.claims);
}
None => Err(ApiError::Unauthorized(
"missing authorization header".to_string(),
))?,
}
}
Ok(req)
}
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {
match req.context::<Claims>() {
Some(claims) => Ok(auth::check_permission(&claims, tenantid)
.map_err(|err| ApiError::Forbidden(err.to_string()))?),
None => Ok(()), // claims is None because auth is disabled
}
}
pub fn serve_thread_main(
router_builder: RouterBuilder<hyper::Body, ApiError>,
addr: String,
) -> anyhow::Result<()> {
let addr = addr.parse()?;
log::info!("Starting a http endoint at {}", addr);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
// Enter a single-threaded tokio runtime bound to the current thread
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = runtime.enter();
let server = Server::bind(&addr).serve(service);
runtime.block_on(server)?;
Ok(())
}

View File

@@ -0,0 +1,77 @@
use anyhow::anyhow;
use hyper::{header, Body, Response, StatusCode};
use serde::Serialize;
use serde_json;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error(transparent)]
InternalServerError(#[from] anyhow::Error),
}
impl ApiError {
pub fn from_err<E: Into<anyhow::Error>>(err: E) -> Self {
Self::InternalServerError(anyhow!(err))
}
pub fn into_response(self) -> Response<Body> {
match self {
ApiError::BadRequest(_) => HttpErrorBody::response_from_msg_and_status(
self.to_string(),
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(self.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => HttpErrorBody::response_from_msg_and_status(
self.to_string(),
StatusCode::UNAUTHORIZED,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
}
#[derive(Serialize)]
pub struct HttpErrorBody {
pub msg: String,
}
impl HttpErrorBody {
pub fn from_msg(msg: String) -> Self {
HttpErrorBody { msg }
}
pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Body> {
HttpErrorBody { msg }.into_response(status)
}
pub fn into_response(&self, status: StatusCode) -> Response<Body> {
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Body::from(serde_json::to_string(self).unwrap()))
.unwrap()
}
}
pub async fn handler(err: routerify::RouteError) -> Response<Body> {
log::error!("{}", err);
err.downcast::<ApiError>()
.expect("handler should always return api error")
.into_response()
}

View File

@@ -0,0 +1,29 @@
use bytes::Buf;
use hyper::{header, Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json;
use super::error::ApiError;
pub async fn json_request<T: for<'de> Deserialize<'de>>(
request: &mut Request<Body>,
) -> Result<T, ApiError> {
let whole_body = hyper::body::aggregate(request.body_mut())
.await
.map_err(ApiError::from_err)?;
Ok(serde_json::from_reader(whole_body.reader())
.map_err(|err| ApiError::BadRequest(format!("Failed to parse json request {}", err)))?)
}
pub fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Body>, ApiError> {
let json = serde_json::to_string(&data).map_err(ApiError::from_err)?;
let response = Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(json))
.map_err(ApiError::from_err)?;
Ok(response)
}

View File

@@ -0,0 +1,3 @@
pub mod endpoint;
pub mod error;
pub mod json;

View File

@@ -1,53 +0,0 @@
use hyper::{
header::CONTENT_TYPE,
service::{make_service_fn, service_fn},
Body, Request, Response, Server,
};
use lazy_static::lazy_static;
use zenith_metrics::{register_int_counter, IntCounter};
use zenith_metrics::{Encoder, TextEncoder};
lazy_static! {
static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!(
"pageserver_serve_metrics_count",
"Number of metric requests made"
)
.expect("failed to define a metric");
}
async fn serve_prometheus_metrics(_req: Request<Body>) -> anyhow::Result<Response<Body>> {
SERVE_METRICS_COUNT.inc();
let mut buffer = vec![];
let encoder = TextEncoder::new();
let metrics = zenith_metrics::gather();
encoder.encode(&metrics, &mut buffer).unwrap();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap();
Ok(response)
}
pub fn thread_main(addr: String) -> anyhow::Result<()> {
let addr = addr.parse()?;
log::info!("Starting a prometheus endoint at {}", addr);
// Enter a single-threaded tokio runtime bound to the current thread
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let _guard = runtime.enter();
// TODO: use hyper_router/routerify/etc when we have more methods
let server = Server::bind(&addr).serve(make_service_fn(|_| async {
Ok::<_, anyhow::Error>(service_fn(serve_prometheus_metrics))
}));
runtime.block_on(server)?;
Ok(())
}

View File

@@ -10,7 +10,6 @@ pub mod seqwait;
// pub mod seqwait_async;
pub mod bin_ser;
pub mod http_endpoint;
pub mod postgres_backend;
pub mod pq_proto;
@@ -22,3 +21,5 @@ pub mod auth;
// utility functions and helper traits for unified unique id generation/serialization etc.
pub mod zid;
// http endpoint utils
pub mod http;

View File

@@ -96,6 +96,20 @@ macro_rules! zid_newtype {
}
}
impl FromHex for $t {
type Error = hex::FromHexError;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
Ok($t(ZId::from_hex(hex)?))
}
}
impl AsRef<[u8]> for $t {
fn as_ref(&self) -> &[u8] {
&self.0 .0
}
}
impl fmt::Display for $t {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
@@ -137,19 +151,3 @@ zid_newtype!(ZTimelineId);
pub struct ZTenantId(ZId);
zid_newtype!(ZTenantId);
// for now the following impls are used only with ZTenantId,
// if this impls become useful in other newtypes they can be moved under zid_newtype macro too
impl FromHex for ZTenantId {
type Error = hex::FromHexError;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
Ok(ZTenantId(ZId::from_hex(hex)?))
}
}
impl AsRef<[u8]> for ZTenantId {
fn as_ref(&self) -> &[u8] {
&self.0 .0
}
}