mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-01 04:20:39 +00:00
translate pageserver api to http
This commit is contained in:
18
Cargo.lock
generated
18
Cargo.lock
generated
@@ -352,6 +352,7 @@ dependencies = [
|
||||
"postgres_ffi",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tar",
|
||||
@@ -1250,6 +1251,7 @@ dependencies = [
|
||||
"futures",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"postgres",
|
||||
@@ -1259,6 +1261,7 @@ dependencies = [
|
||||
"rand",
|
||||
"regex",
|
||||
"rocksdb",
|
||||
"routerify",
|
||||
"rust-s3",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
@@ -1690,6 +1693,19 @@ dependencies = [
|
||||
"librocksdb-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "routerify"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9"
|
||||
dependencies = [
|
||||
"http",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-ini"
|
||||
version = "0.16.1"
|
||||
@@ -2668,7 +2684,9 @@ dependencies = [
|
||||
"log",
|
||||
"postgres",
|
||||
"rand",
|
||||
"routerify",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"workspace_hack",
|
||||
|
||||
@@ -20,6 +20,7 @@ bytes = "1.0.1"
|
||||
nix = "0.20"
|
||||
url = "2.2.2"
|
||||
hex = { version = "0.4.3", features = ["serde"] }
|
||||
reqwest = { version = "0.11", features = ["blocking", "json"] }
|
||||
|
||||
pageserver = { path = "../pageserver" }
|
||||
walkeeper = { path = "../walkeeper" }
|
||||
|
||||
@@ -332,7 +332,7 @@ impl PostgresNode {
|
||||
};
|
||||
|
||||
// Configure that node to take pages from pageserver
|
||||
let (host, port) = connection_host_port(&self.pageserver.connection_config);
|
||||
let (host, port) = connection_host_port(&self.pageserver.pg_connection_config);
|
||||
self.append_conf(
|
||||
"postgresql.conf",
|
||||
format!(
|
||||
|
||||
@@ -5,10 +5,13 @@ use std::process::Command;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use anyhow::{anyhow, bail, ensure, Result};
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use pageserver::http::models::{BranchCreateRequest, TenantCreateRequest};
|
||||
use postgres::{Config, NoTls};
|
||||
use reqwest::blocking::{Client, RequestBuilder};
|
||||
use reqwest::{IntoUrl, Method, StatusCode};
|
||||
use zenith_utils::postgres_backend::AuthType;
|
||||
use zenith_utils::zid::ZTenantId;
|
||||
|
||||
@@ -17,6 +20,8 @@ use crate::read_pidfile;
|
||||
use pageserver::branches::BranchInfo;
|
||||
use zenith_utils::connstring::connection_address;
|
||||
|
||||
const HTTP_BASE_URL: &str = "http://127.0.0.1:9898/v1";
|
||||
|
||||
//
|
||||
// Control routines for pageserver.
|
||||
//
|
||||
@@ -25,13 +30,15 @@ use zenith_utils::connstring::connection_address;
|
||||
#[derive(Debug)]
|
||||
pub struct PageServerNode {
|
||||
pub kill_on_exit: bool,
|
||||
pub connection_config: Config,
|
||||
pub pg_connection_config: Config,
|
||||
pub env: LocalEnv,
|
||||
pub http_client: Client,
|
||||
pub http_base_url: String,
|
||||
}
|
||||
|
||||
impl PageServerNode {
|
||||
pub fn from_env(env: &LocalEnv) -> PageServerNode {
|
||||
let password = if matches!(env.auth_type, AuthType::ZenithJWT) {
|
||||
let password = if env.auth_type == AuthType::ZenithJWT {
|
||||
&env.auth_token
|
||||
} else {
|
||||
""
|
||||
@@ -39,8 +46,10 @@ impl PageServerNode {
|
||||
|
||||
PageServerNode {
|
||||
kill_on_exit: false,
|
||||
connection_config: Self::default_config(password), // default
|
||||
pg_connection_config: Self::default_config(password), // default
|
||||
env: env.clone(),
|
||||
http_client: Client::new(),
|
||||
http_base_url: HTTP_BASE_URL.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +109,7 @@ impl PageServerNode {
|
||||
pub fn start(&self) -> Result<()> {
|
||||
println!(
|
||||
"Starting pageserver at '{}' in {}",
|
||||
connection_address(&self.connection_config),
|
||||
connection_address(&self.pg_connection_config),
|
||||
self.repo_path().display()
|
||||
);
|
||||
|
||||
@@ -120,7 +129,7 @@ impl PageServerNode {
|
||||
// It takes a while for the page server to start up. Wait until it is
|
||||
// open for business.
|
||||
for retries in 1..15 {
|
||||
match self.page_server_psql_client() {
|
||||
match self.check_status() {
|
||||
Ok(_) => {
|
||||
println!("Pageserver started");
|
||||
return Ok(());
|
||||
@@ -145,7 +154,7 @@ impl PageServerNode {
|
||||
}
|
||||
|
||||
// wait for pageserver stop
|
||||
let address = connection_address(&self.connection_config);
|
||||
let address = connection_address(&self.pg_connection_config);
|
||||
for _ in 0..5 {
|
||||
let stream = TcpStream::connect(&address);
|
||||
thread::sleep(Duration::from_secs(1));
|
||||
@@ -160,50 +169,64 @@ impl PageServerNode {
|
||||
}
|
||||
|
||||
pub fn page_server_psql(&self, sql: &str) -> Vec<postgres::SimpleQueryMessage> {
|
||||
let mut client = self.connection_config.connect(NoTls).unwrap();
|
||||
let mut client = self.pg_connection_config.connect(NoTls).unwrap();
|
||||
|
||||
println!("Pageserver query: '{}'", sql);
|
||||
client.simple_query(sql).unwrap()
|
||||
}
|
||||
|
||||
pub fn page_server_psql_client(&self) -> Result<postgres::Client, postgres::Error> {
|
||||
self.connection_config.connect(NoTls)
|
||||
self.pg_connection_config.connect(NoTls)
|
||||
}
|
||||
|
||||
pub fn tenants_list(&self) -> Result<Vec<String>> {
|
||||
let mut client = self.page_server_psql_client()?;
|
||||
let query_result = client.simple_query("tenant_list")?;
|
||||
let tenants_json = query_result
|
||||
.first()
|
||||
.map(|msg| match msg {
|
||||
postgres::SimpleQueryMessage::Row(row) => row.get(0),
|
||||
_ => None,
|
||||
})
|
||||
.flatten()
|
||||
.ok_or_else(|| anyhow!("missing tenants"))?;
|
||||
|
||||
Ok(serde_json::from_str(tenants_json)?)
|
||||
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
|
||||
let mut builder = self.http_client.request(method, url);
|
||||
if self.env.auth_type == AuthType::ZenithJWT {
|
||||
builder = builder.bearer_auth(&self.env.auth_token)
|
||||
}
|
||||
builder
|
||||
}
|
||||
|
||||
pub fn tenant_create(&self, tenantid: &ZTenantId) -> Result<()> {
|
||||
let mut client = self.page_server_psql_client()?;
|
||||
client.simple_query(format!("tenant_create {}", tenantid).as_str())?;
|
||||
pub fn check_status(&self) -> Result<()> {
|
||||
let status = self
|
||||
.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
|
||||
.send()?
|
||||
.status();
|
||||
ensure!(
|
||||
status == StatusCode::OK,
|
||||
format!("got unexpected response status {}", status)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn branches_list(&self, tenantid: &ZTenantId) -> Result<Vec<BranchInfo>> {
|
||||
let mut client = self.page_server_psql_client()?;
|
||||
let query_result = client.simple_query(&format!("branch_list {}", tenantid))?;
|
||||
let branches_json = query_result
|
||||
.first()
|
||||
.map(|msg| match msg {
|
||||
postgres::SimpleQueryMessage::Row(row) => row.get(0),
|
||||
_ => None,
|
||||
})
|
||||
.flatten()
|
||||
.ok_or_else(|| anyhow!("missing branches"))?;
|
||||
pub fn tenant_list(&self) -> Result<Vec<String>> {
|
||||
Ok(self
|
||||
.http_request(Method::GET, format!("{}/{}", self.http_base_url, "tenant"))
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json()?)
|
||||
}
|
||||
|
||||
Ok(serde_json::from_str(branches_json)?)
|
||||
pub fn tenant_create(&self, tenantid: ZTenantId) -> Result<()> {
|
||||
Ok(self
|
||||
.http_request(Method::POST, format!("{}/{}", self.http_base_url, "tenant"))
|
||||
.json(&TenantCreateRequest {
|
||||
tenant_id: tenantid,
|
||||
})
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json()?)
|
||||
}
|
||||
|
||||
pub fn branch_list(&self, tenantid: &ZTenantId) -> Result<Vec<BranchInfo>> {
|
||||
Ok(self
|
||||
.http_request(
|
||||
Method::GET,
|
||||
format!("{}/branch/{}", self.http_base_url, tenantid),
|
||||
)
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json()?)
|
||||
}
|
||||
|
||||
pub fn branch_create(
|
||||
@@ -212,29 +235,16 @@ impl PageServerNode {
|
||||
startpoint: &str,
|
||||
tenantid: &ZTenantId,
|
||||
) -> Result<BranchInfo> {
|
||||
let mut client = self.page_server_psql_client()?;
|
||||
let query_result = client.simple_query(
|
||||
format!("branch_create {} {} {}", tenantid, branch_name, startpoint).as_str(),
|
||||
)?;
|
||||
|
||||
let branch_json = query_result
|
||||
.first()
|
||||
.map(|msg| match msg {
|
||||
postgres::SimpleQueryMessage::Row(row) => row.get(0),
|
||||
_ => None,
|
||||
Ok(self
|
||||
.http_request(Method::POST, format!("{}/{}", self.http_base_url, "branch"))
|
||||
.json(&BranchCreateRequest {
|
||||
tenant_id: tenantid.to_owned(),
|
||||
name: branch_name.to_owned(),
|
||||
start_point: startpoint.to_owned(),
|
||||
})
|
||||
.flatten()
|
||||
.ok_or_else(|| anyhow!("missing branch"))?;
|
||||
|
||||
let res: BranchInfo = serde_json::from_str(branch_json).map_err(|e| {
|
||||
anyhow!(
|
||||
"failed to parse branch_create response: {}: {}",
|
||||
branch_json,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(res)
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json()?)
|
||||
}
|
||||
|
||||
// TODO: make this a separate request type and avoid loading all the branches
|
||||
@@ -243,14 +253,14 @@ impl PageServerNode {
|
||||
tenantid: &ZTenantId,
|
||||
branch_name: &str,
|
||||
) -> Result<BranchInfo> {
|
||||
let branch_infos = self.branches_list(tenantid)?;
|
||||
let branche_by_name: Result<HashMap<String, BranchInfo>> = branch_infos
|
||||
let branch_infos = self.branch_list(tenantid)?;
|
||||
let branch_by_name: Result<HashMap<String, BranchInfo>> = branch_infos
|
||||
.into_iter()
|
||||
.map(|branch_info| Ok((branch_info.name.clone(), branch_info)))
|
||||
.collect();
|
||||
let branche_by_name = branche_by_name?;
|
||||
let branch_by_name = branch_by_name?;
|
||||
|
||||
let branch = branche_by_name
|
||||
let branch = branch_by_name
|
||||
.get(branch_name)
|
||||
.ok_or_else(|| anyhow!("Branch {} not found", branch_name))?;
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ regex = "1.4.5"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
byteorder = "1.4.3"
|
||||
futures = "0.3.13"
|
||||
hyper = "0.14"
|
||||
lazy_static = "1.4.0"
|
||||
slog-stdlog = "4.1.0"
|
||||
slog-async = "2.6.0"
|
||||
@@ -31,11 +32,12 @@ postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
# by default rust-rocksdb tries to build a lot of compression algos. Use lz4 only for now as it is simplest dependency.
|
||||
rocksdb = { version = "0.16.0", features = ["lz4"], default-features = false }
|
||||
routerify = "2"
|
||||
anyhow = "1.0"
|
||||
crc32c = "0.6.0"
|
||||
walkdir = "2"
|
||||
thiserror = "1.0"
|
||||
hex = "0.4.3"
|
||||
hex = { version = "0.4.3", features = ["serde"] }
|
||||
tar = "0.4.33"
|
||||
humantime = "2.1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
||||
@@ -20,8 +20,10 @@ use anyhow::{ensure, Result};
|
||||
use clap::{App, Arg, ArgMatches};
|
||||
use daemonize::Daemonize;
|
||||
|
||||
use pageserver::{branches, logger, page_cache, page_service, PageServerConf, RepositoryFormat};
|
||||
use zenith_utils::http_endpoint;
|
||||
use pageserver::{
|
||||
branches, http, logger, page_cache, page_service, PageServerConf, RepositoryFormat,
|
||||
};
|
||||
use zenith_utils::http::endpoint;
|
||||
|
||||
const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:64000";
|
||||
const DEFAULT_HTTP_ENDPOINT_ADDR: &str = "127.0.0.1:9898";
|
||||
@@ -323,19 +325,6 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn a new thread for the http endpoint
|
||||
thread::Builder::new()
|
||||
.name("Metrics thread".into())
|
||||
.spawn(move || http_endpoint::thread_main(conf.http_endpoint_addr.clone()))?;
|
||||
|
||||
// Check that we can bind to address before starting threads to simplify shutdown
|
||||
// sequence if port is occupied.
|
||||
info!("Starting pageserver on {}", conf.listen_addr);
|
||||
let pageserver_listener = TcpListener::bind(conf.listen_addr.clone())?;
|
||||
|
||||
// Initialize page cache, this will spawn walredo_thread
|
||||
page_cache::init(conf);
|
||||
|
||||
// initialize authentication for incoming connections
|
||||
let auth = match &conf.auth_type {
|
||||
AuthType::Trust | AuthType::MD5 => Arc::new(None),
|
||||
@@ -346,6 +335,24 @@ fn start_pageserver(conf: &'static PageServerConf) -> Result<()> {
|
||||
}
|
||||
};
|
||||
info!("Using auth: {:#?}", conf.auth_type);
|
||||
|
||||
// Spawn a new thread for the http endpoint
|
||||
let cloned = Arc::clone(&auth);
|
||||
thread::Builder::new()
|
||||
.name("http_endpoint_thread".into())
|
||||
.spawn(move || {
|
||||
let router = http::get_router(conf, cloned);
|
||||
endpoint::serve_thread_main(router, conf.http_endpoint_addr.clone())
|
||||
})?;
|
||||
|
||||
// Check that we can bind to address before starting threads to simplify shutdown
|
||||
// sequence if port is occupied.
|
||||
info!("Starting pageserver on {}", conf.listen_addr);
|
||||
let pageserver_listener = TcpListener::bind(conf.listen_addr.clone())?;
|
||||
|
||||
// Initialize page cache, this will spawn walredo_thread
|
||||
page_cache::init(conf);
|
||||
|
||||
// Spawn a thread to listen for connections. It will spawn further threads
|
||||
// for each connection.
|
||||
let page_service_thread = thread::Builder::new()
|
||||
|
||||
@@ -29,6 +29,7 @@ use crate::{repository::Repository, PageServerConf, RepositoryFormat};
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct BranchInfo {
|
||||
pub name: String,
|
||||
#[serde(with = "hex")]
|
||||
pub timeline_id: ZTimelineId,
|
||||
pub latest_valid_lsn: Option<Lsn>,
|
||||
pub ancestor_id: Option<String>,
|
||||
@@ -283,7 +284,7 @@ pub(crate) fn create_branch(
|
||||
let end_of_wal = repo
|
||||
.get_timeline(startpoint.timelineid)?
|
||||
.get_last_record_lsn();
|
||||
println!("branching at end of WAL: {}", end_of_wal);
|
||||
info!("branching at end of WAL: {}", end_of_wal);
|
||||
startpoint.lsn = end_of_wal;
|
||||
}
|
||||
|
||||
|
||||
3
pageserver/src/http/mod.rs
Normal file
3
pageserver/src/http/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod models;
|
||||
pub mod routes;
|
||||
pub use routes::get_router;
|
||||
17
pageserver/src/http/models.rs
Normal file
17
pageserver/src/http/models.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::ZTenantId;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct BranchCreateRequest {
|
||||
#[serde(with = "hex")]
|
||||
pub tenant_id: ZTenantId,
|
||||
pub name: String,
|
||||
pub start_point: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TenantCreateRequest {
|
||||
#[serde(with = "hex")]
|
||||
pub tenant_id: ZTenantId,
|
||||
}
|
||||
239
pageserver/src/http/openapi_spec.yml
Normal file
239
pageserver/src/http/openapi_spec.yml
Normal file
@@ -0,0 +1,239 @@
|
||||
openapi: "3.0.2"
|
||||
info:
|
||||
title: Page Server API
|
||||
version: "1.0"
|
||||
servers:
|
||||
- url: ""
|
||||
paths:
|
||||
/v1/status:
|
||||
description: Healthcheck endpoint
|
||||
get:
|
||||
description: Healthcheck
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
/v1/branch/{tenant_id}:
|
||||
parameters:
|
||||
- name: tenant_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
format: hex
|
||||
get:
|
||||
description: Get branches for tenant
|
||||
responses:
|
||||
"200":
|
||||
description: BranchInfo
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/BranchInfo"
|
||||
"400":
|
||||
description: Error when no tenant id found in path
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"401":
|
||||
description: Unauthorized Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UnauthorizedError"
|
||||
"403":
|
||||
description: Forbidden Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ForbiddenError"
|
||||
|
||||
"500":
|
||||
description: Generic operation error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
/v1/branch/:
|
||||
post:
|
||||
description: Create branch
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- "tenant_id"
|
||||
- "name"
|
||||
- "start_point"
|
||||
properties:
|
||||
tenant_id:
|
||||
type: string
|
||||
format: hex
|
||||
name:
|
||||
type: string
|
||||
start_point:
|
||||
type: string
|
||||
responses:
|
||||
"201":
|
||||
description: BranchInfo
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/BranchInfo"
|
||||
"400":
|
||||
description: Malformed branch create request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"401":
|
||||
description: Unauthorized Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UnauthorizedError"
|
||||
"403":
|
||||
description: Forbidden Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ForbiddenError"
|
||||
"500":
|
||||
description: Generic operation error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
/v1/tenant/:
|
||||
get:
|
||||
description: Get tenants list
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
"401":
|
||||
description: Unauthorized Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UnauthorizedError"
|
||||
"403":
|
||||
description: Forbidden Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ForbiddenError"
|
||||
"500":
|
||||
description: Generic operation error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
post:
|
||||
description: Create tenant
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- "tenant_id"
|
||||
properties:
|
||||
tenant_id:
|
||||
type: string
|
||||
format: hex
|
||||
responses:
|
||||
"201":
|
||||
description: CREATED
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
"400":
|
||||
description: Malformed tenant create request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"401":
|
||||
description: Unauthorized Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UnauthorizedError"
|
||||
"403":
|
||||
description: Forbidden Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ForbiddenError"
|
||||
"500":
|
||||
description: Generic operation error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
JWT:
|
||||
type: http
|
||||
scheme: bearer
|
||||
bearerFormat: JWT
|
||||
schemas:
|
||||
BranchInfo:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- timeline_id
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
timeline_id:
|
||||
type: string
|
||||
format: hex
|
||||
ancestor_id:
|
||||
type: string
|
||||
ancestor_lsn:
|
||||
type: string
|
||||
Error:
|
||||
type: object
|
||||
required:
|
||||
- msg
|
||||
properties:
|
||||
msg:
|
||||
type: string
|
||||
UnauthorizedError:
|
||||
type: object
|
||||
required:
|
||||
- msg
|
||||
properties:
|
||||
msg:
|
||||
type: string
|
||||
ForbiddenError:
|
||||
type: object
|
||||
required:
|
||||
- msg
|
||||
properties:
|
||||
msg:
|
||||
type: string
|
||||
|
||||
security:
|
||||
- JWT: []
|
||||
164
pageserver/src/http/routes.rs
Normal file
164
pageserver/src/http/routes.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use hyper::header;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, Request, Response, Uri};
|
||||
use routerify::Middleware;
|
||||
use routerify::{ext::RequestExt, RouterBuilder};
|
||||
use zenith_utils::auth::JwtAuth;
|
||||
use zenith_utils::http::endpoint::attach_openapi_ui;
|
||||
use zenith_utils::http::endpoint::auth_middleware;
|
||||
use zenith_utils::http::endpoint::check_permission;
|
||||
use zenith_utils::http::endpoint::AuthProvider;
|
||||
use zenith_utils::http::error::ApiError;
|
||||
use zenith_utils::http::{
|
||||
endpoint,
|
||||
error::HttpErrorBody,
|
||||
json::{json_request, json_response},
|
||||
};
|
||||
|
||||
use super::models::BranchCreateRequest;
|
||||
use super::models::TenantCreateRequest;
|
||||
use crate::page_cache;
|
||||
use crate::{
|
||||
branches::{self},
|
||||
PageServerConf, ZTenantId,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
conf: &'static PageServerConf,
|
||||
auth: Arc<Option<JwtAuth>>,
|
||||
whitelist_routes: Vec<Uri>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(conf: &'static PageServerConf, auth: Arc<Option<JwtAuth>>) -> Self {
|
||||
let whitelist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
|
||||
.iter()
|
||||
.map(|v| v.parse().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
Self {
|
||||
conf,
|
||||
auth,
|
||||
whitelist_routes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthProvider for State {
|
||||
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>> {
|
||||
if self.whitelist_routes.contains(req.uri()) {
|
||||
Arc::new(None)
|
||||
} else {
|
||||
self.auth.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthcheck handler
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from("{}"))
|
||||
.map_err(ApiError::from_err)?)
|
||||
}
|
||||
|
||||
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let request_data: BranchCreateRequest = json_request(&mut request).await?;
|
||||
|
||||
check_permission(&request, Some(request_data.tenant_id))?;
|
||||
|
||||
let response_data = tokio::task::spawn_blocking(move || {
|
||||
branches::create_branch(
|
||||
state.conf,
|
||||
&request_data.name,
|
||||
&request_data.start_point,
|
||||
&request_data.tenant_id,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::CREATED, response_data)?)
|
||||
}
|
||||
|
||||
async fn branch_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let tenantid: ZTenantId = match request.param("tenant_id") {
|
||||
Some(arg) => arg
|
||||
.parse()
|
||||
.map_err(|_| ApiError::BadRequest("failed to parse tenant id".to_string()))?,
|
||||
None => {
|
||||
return Err(ApiError::BadRequest(
|
||||
"no tenant id specified in path param".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
check_permission(&request, Some(tenantid))?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let response_data =
|
||||
tokio::task::spawn_blocking(move || crate::branches::get_branches(state.conf, &tenantid))
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::OK, response_data)?)
|
||||
}
|
||||
|
||||
async fn tenant_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
// check for management permission
|
||||
check_permission(&request, None)?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let response_data =
|
||||
tokio::task::spawn_blocking(move || crate::branches::get_tenants(state.conf))
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::OK, response_data)?)
|
||||
}
|
||||
|
||||
async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
// check for management permission
|
||||
check_permission(&request, None)?;
|
||||
|
||||
let state = request.data::<Arc<State>>().unwrap().clone();
|
||||
let request_data: TenantCreateRequest = json_request(&mut request).await?;
|
||||
|
||||
let response_data = tokio::task::spawn_blocking(move || {
|
||||
page_cache::create_repository_for_tenant(state.conf, request_data.tenant_id)
|
||||
})
|
||||
.await
|
||||
.map_err(ApiError::from_err)??;
|
||||
Ok(json_response(StatusCode::CREATED, response_data)?)
|
||||
}
|
||||
|
||||
async fn handler_404(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
json_response(
|
||||
StatusCode::NOT_FOUND,
|
||||
HttpErrorBody::from_msg("page not found".to_owned()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_router(
|
||||
conf: &'static PageServerConf,
|
||||
auth: Arc<Option<JwtAuth>>,
|
||||
) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let spec = include_bytes!("openapi_spec.yml");
|
||||
let mut router = attach_openapi_ui(endpoint::get_router(), spec, "/swagger.yml", "/v1/doc");
|
||||
if let Some(_) = &auth.as_ref() {
|
||||
// note that State is used as a type parameteer without an Arc
|
||||
// this is a simple solution because it is not possible to implement
|
||||
// AuthProvider for Arc<State> so middleware assumes that state is wrapped in Arc
|
||||
router = router.middleware(Middleware::pre(auth_middleware::<State>))
|
||||
}
|
||||
router
|
||||
.data(Arc::new(State::new(conf, auth)))
|
||||
.get("/v1/status", status_handler)
|
||||
.get("/v1/branch/:tenant_id", branch_list_handler)
|
||||
.post("/v1/branch", branch_create_handler)
|
||||
.get("/v1/tenant", tenant_list_handler)
|
||||
.post("/v1/tenant", tenant_create_handler)
|
||||
.any(handler_404)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ use zenith_metrics::{register_int_gauge_vec, IntGaugeVec};
|
||||
|
||||
pub mod basebackup;
|
||||
pub mod branches;
|
||||
pub mod http;
|
||||
pub mod layered_repository;
|
||||
pub mod logger;
|
||||
pub mod object_key;
|
||||
|
||||
@@ -23,7 +23,7 @@ use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::{io, net::TcpStream};
|
||||
use zenith_metrics::{register_histogram_vec, HistogramVec};
|
||||
use zenith_utils::auth::JwtAuth;
|
||||
use zenith_utils::auth::{self, JwtAuth};
|
||||
use zenith_utils::auth::{Claims, Scope};
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::postgres_backend::{self, AuthType};
|
||||
@@ -389,19 +389,7 @@ impl PageServerHandler {
|
||||
.claims
|
||||
.as_ref()
|
||||
.expect("claims presence already checked");
|
||||
match (&claims.scope, tenantid) {
|
||||
(Scope::Tenant, None) => {
|
||||
bail!("Attempt to access management api with tenant scope. Permission denied")
|
||||
}
|
||||
(Scope::Tenant, Some(tenantid)) => {
|
||||
if claims.tenant_id.unwrap() != tenantid {
|
||||
bail!("Tenant id mismatch. Permission denied")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
|
||||
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
|
||||
}
|
||||
Ok(auth::check_permission(claims, tenantid)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ pytest = ">=6.0.0"
|
||||
psycopg2 = "*"
|
||||
typing-extensions = "*"
|
||||
pyjwt = {extras = ["crypto"], version = "*"}
|
||||
requests = "*"
|
||||
|
||||
[dev-packages]
|
||||
yapf = "*"
|
||||
|
||||
43
test_runner/Pipfile.lock
generated
43
test_runner/Pipfile.lock
generated
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "f60a966726bcc19670402ad3fa57396b5dacf0a027544418ceb7cc0d42d94a52"
|
||||
"sha256": "b666740289d9c82797e5c39b2a7f0074c865c9183ee878ce4fa5cda7928506ea"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@@ -24,6 +24,13 @@
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
|
||||
"version": "==21.2.0"
|
||||
},
|
||||
"certifi": {
|
||||
"hashes": [
|
||||
"sha256:2bbf76fd432960138b3ef6dda3dde0544f27cbf8546c458e60baf371917ba9ee",
|
||||
"sha256:50b1e4f8446b06f41be7dd6338db18e0990601dce795c2b1686458aa7e8fa7d8"
|
||||
],
|
||||
"version": "==2021.5.30"
|
||||
},
|
||||
"cffi": {
|
||||
"hashes": [
|
||||
"sha256:06c54a68935738d206570b20da5ef2b6b6d92b38ef3ec45c5422c0ebaf338d4d",
|
||||
@@ -74,6 +81,14 @@
|
||||
],
|
||||
"version": "==1.14.6"
|
||||
},
|
||||
"charset-normalizer": {
|
||||
"hashes": [
|
||||
"sha256:0c8911edd15d19223366a194a513099a302055a962bca2cec0f54b8b63175d8b",
|
||||
"sha256:f23667ebe1084be45f6ae0538e4a5a865206544097e4e8bbcacf42cd02a348f3"
|
||||
],
|
||||
"markers": "python_version >= '3'",
|
||||
"version": "==2.0.4"
|
||||
},
|
||||
"cryptography": {
|
||||
"hashes": [
|
||||
"sha256:0f1212a66329c80d68aeeb39b8a16d54ef57071bf22ff4e521657b27372e327d",
|
||||
@@ -85,12 +100,22 @@
|
||||
"sha256:3d8427734c781ea5f1b41d6589c293089704d4759e34597dce91014ac125aad1",
|
||||
"sha256:7ec5d3b029f5fa2b179325908b9cd93db28ab7b85bb6c1db56b10e0b54235177",
|
||||
"sha256:8e56e16617872b0957d1c9742a3f94b43533447fd78321514abbe7db216aa250",
|
||||
"sha256:b01fd6f2737816cb1e08ed4807ae194404790eac7ad030b34f2ce72b332f5586",
|
||||
"sha256:bf40af59ca2465b24e54f671b2de2c59257ddc4f7e5706dbd6930e26823668d3",
|
||||
"sha256:de4e5f7f68220d92b7637fc99847475b59154b7a1b3868fb7385337af54ac9ca",
|
||||
"sha256:eb8cc2afe8b05acbd84a43905832ec78e7b3873fb124ca190f574dca7389a87d",
|
||||
"sha256:ee77aa129f481be46f8d92a1a7db57269a2f23052d5f2433b4621bb457081cc9"
|
||||
],
|
||||
"version": "==3.4.7"
|
||||
},
|
||||
"idna": {
|
||||
"hashes": [
|
||||
"sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a",
|
||||
"sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"
|
||||
],
|
||||
"markers": "python_version >= '3'",
|
||||
"version": "==3.2"
|
||||
},
|
||||
"iniconfig": {
|
||||
"hashes": [
|
||||
"sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3",
|
||||
@@ -172,6 +197,14 @@
|
||||
"index": "pypi",
|
||||
"version": "==6.2.4"
|
||||
},
|
||||
"requests": {
|
||||
"hashes": [
|
||||
"sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24",
|
||||
"sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==2.26.0"
|
||||
},
|
||||
"toml": {
|
||||
"hashes": [
|
||||
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
|
||||
@@ -188,6 +221,14 @@
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==3.10.0.0"
|
||||
},
|
||||
"urllib3": {
|
||||
"hashes": [
|
||||
"sha256:39fb8672126159acb139a7718dd10806104dec1e2f0f6c88aab05d17df10c8d4",
|
||||
"sha256:f57b4c16c62fa2760b7e3d97c35b255512fb6b59a259730f36ba32ce9f8e342f"
|
||||
],
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' and python_version < '4'",
|
||||
"version": "==1.26.6"
|
||||
}
|
||||
},
|
||||
"develop": {
|
||||
|
||||
@@ -1,61 +1,17 @@
|
||||
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
from dataclasses import dataclass
|
||||
import jwt
|
||||
import psycopg2
|
||||
from fixtures.zenith_fixtures import Postgres, ZenithCli, ZenithPageserver
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pageserver_auth_enabled(zenith_cli: ZenithCli):
|
||||
with ZenithPageserver(zenith_cli).init(enable_auth=True).start() as ps:
|
||||
# For convenience in tests, create a branch from the freshly-initialized cluster.
|
||||
zenith_cli.run(["branch", "empty", "main"])
|
||||
yield ps
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthKeys:
|
||||
pub: bytes
|
||||
priv: bytes
|
||||
|
||||
def generate_management_token(self):
|
||||
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")
|
||||
|
||||
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
|
||||
# hinting or something (not sure what). If it returned 'bytes', convert it to 'str'
|
||||
# explicitly.
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode()
|
||||
|
||||
return token
|
||||
|
||||
def generate_tenant_token(self, tenant_id):
|
||||
token = jwt.encode({"scope": "tenant", "tenant_id": tenant_id}, self.priv, algorithm="RS256")
|
||||
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode()
|
||||
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_keys(repo_dir: str) -> AuthKeys:
|
||||
# TODO probably this should be specified in cli config and used in tests for single source of truth
|
||||
pub = (Path(repo_dir) / 'auth_public_key.pem').read_bytes()
|
||||
priv = (Path(repo_dir) / 'auth_private_key.pem').read_bytes()
|
||||
return AuthKeys(pub=pub, priv=priv)
|
||||
|
||||
|
||||
def test_pageserver_auth(pageserver_auth_enabled: ZenithPageserver, auth_keys: AuthKeys):
|
||||
def test_pageserver_auth(pageserver_auth_enabled: ZenithPageserver):
|
||||
ps = pageserver_auth_enabled
|
||||
|
||||
tenant_token = auth_keys.generate_tenant_token(ps.initial_tenant)
|
||||
invalid_tenant_token = auth_keys.generate_tenant_token(uuid4().hex)
|
||||
management_token = auth_keys.generate_management_token()
|
||||
tenant_token = ps.auth_keys.generate_tenant_token(ps.initial_tenant)
|
||||
invalid_tenant_token = ps.auth_keys.generate_tenant_token(uuid4().hex)
|
||||
management_token = ps.auth_keys.generate_management_token()
|
||||
|
||||
# this does not invoke auth check and only decodes jwt and checks it for validity
|
||||
# check both tokens
|
||||
@@ -86,12 +42,11 @@ def test_compute_auth_to_pageserver(
|
||||
pageserver_auth_enabled: ZenithPageserver,
|
||||
repo_dir: str,
|
||||
with_wal_acceptors: bool,
|
||||
auth_keys: AuthKeys,
|
||||
):
|
||||
ps = pageserver_auth_enabled
|
||||
# since we are in progress of refactoring protocols between compute safekeeper and page server
|
||||
# use hardcoded management token in safekeeper
|
||||
management_token = auth_keys.generate_management_token()
|
||||
management_token = ps.auth_keys.generate_management_token()
|
||||
|
||||
branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
|
||||
zenith_cli.run(["branch", branch, "empty"])
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
import json
|
||||
import uuid
|
||||
from uuid import uuid4
|
||||
import pytest
|
||||
import psycopg2
|
||||
from fixtures.zenith_fixtures import ZenithPageserver
|
||||
import requests
|
||||
from fixtures.zenith_fixtures import ZenithPageserver, ZenithPageserverHttpClient
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_status(pageserver):
|
||||
def test_status_psql(pageserver):
|
||||
assert pageserver.safe_psql('status') == [
|
||||
('hello world', ),
|
||||
]
|
||||
|
||||
|
||||
def test_branch_list(pageserver: ZenithPageserver, zenith_cli):
|
||||
def test_branch_list_psql(pageserver: ZenithPageserver, zenith_cli):
|
||||
# Create a branch for us
|
||||
zenith_cli.run(["branch", "test_branch_list_main", "empty"])
|
||||
|
||||
@@ -52,7 +53,7 @@ def test_branch_list(pageserver: ZenithPageserver, zenith_cli):
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
|
||||
def test_tenant_list_psql(pageserver: ZenithPageserver, zenith_cli):
|
||||
res = zenith_cli.run(["tenant", "list"])
|
||||
res.check_returncode()
|
||||
tenants = res.stdout.splitlines()
|
||||
@@ -66,7 +67,7 @@ def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
|
||||
cur.execute(f'tenant_create {pageserver.initial_tenant}')
|
||||
|
||||
# create one more tenant
|
||||
tenant1 = uuid.uuid4().hex
|
||||
tenant1 = uuid4().hex
|
||||
cur.execute(f'tenant_create {tenant1}')
|
||||
|
||||
cur.execute('tenant_list')
|
||||
@@ -74,3 +75,32 @@ def test_tenant_list(pageserver: ZenithPageserver, zenith_cli):
|
||||
# compare tenants list
|
||||
new_tenants = sorted(json.loads(cur.fetchone()[0]))
|
||||
assert sorted([pageserver.initial_tenant, tenant1]) == new_tenants
|
||||
|
||||
|
||||
def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
|
||||
client.check_status()
|
||||
|
||||
# check initial tenant is there
|
||||
assert initial_tenant in set(client.tenant_list())
|
||||
|
||||
# create new tenant and check it is also there
|
||||
tenant_id = uuid4()
|
||||
client.tenant_create(tenant_id)
|
||||
assert tenant_id.hex in set(client.tenant_list())
|
||||
|
||||
# create branch
|
||||
branch_name = uuid4().hex
|
||||
client.branch_create(tenant_id, branch_name, "main")
|
||||
|
||||
# check it is there
|
||||
assert branch_name in {b['name'] for b in client.branch_list(tenant_id)}
|
||||
|
||||
|
||||
def test_pageserver_http_api_client(pageserver: ZenithPageserver):
|
||||
client = pageserver.http_client()
|
||||
check_client(client, pageserver.initial_tenant)
|
||||
|
||||
|
||||
def test_pageserver_http_api_client_auth_enabled(pageserver_auth_enabled: ZenithPageserver):
|
||||
client = pageserver_auth_enabled.http_client(auth_token=pageserver_auth_enabled.auth_keys.generate_management_token())
|
||||
check_client(client, pageserver_auth_enabled.initial_tenant)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
import os
|
||||
import pathlib
|
||||
import uuid
|
||||
import jwt
|
||||
import psycopg2
|
||||
import pytest
|
||||
import shutil
|
||||
@@ -17,6 +20,8 @@ from psycopg2.extensions import connection as PgConnection
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast
|
||||
from typing_extensions import Literal
|
||||
|
||||
import requests
|
||||
|
||||
from .utils import (get_self_dir, mkdir_if_needed, subprocess_capture)
|
||||
"""
|
||||
This file contains pytest fixtures. A fixture is a test resource that can be
|
||||
@@ -42,7 +47,8 @@ Fn = TypeVar('Fn', bound=Callable[..., Any])
|
||||
DEFAULT_OUTPUT_DIR = 'test_output'
|
||||
DEFAULT_POSTGRES_DIR = 'tmp_install'
|
||||
|
||||
DEFAULT_PAGESERVER_PORT = 64000
|
||||
DEFAULT_PAGESERVER_PG_PORT = 64000
|
||||
DEFAULT_PAGESERVER_HTTP_PORT = 9898
|
||||
|
||||
|
||||
def determine_scope(fixture_name: str, config: Any) -> str:
|
||||
@@ -175,13 +181,87 @@ def zenith_cli(zenith_binpath: str, repo_dir: str, pg_distrib_dir: str) -> Zenit
|
||||
return ZenithCli(zenith_binpath, repo_dir, pg_distrib_dir)
|
||||
|
||||
|
||||
class ZenithPageserverHttpClient(requests.Session):
|
||||
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.port = port
|
||||
self.auth_token = auth_token
|
||||
|
||||
if auth_token is not None:
|
||||
self.headers['Authorization'] = f'Bearer {auth_token}'
|
||||
|
||||
def check_status(self):
|
||||
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
|
||||
|
||||
def branch_list(self, tenant_id: uuid.UUID) -> List[Dict]:
|
||||
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
def branch_create(self, tenant_id: uuid.UUID, name: str, start_point: str) -> Dict:
|
||||
res = self.post(
|
||||
f"http://localhost:{self.port}/v1/branch",
|
||||
json={
|
||||
'tenant_id': tenant_id.hex,
|
||||
'name': name,
|
||||
'start_point': start_point,
|
||||
}
|
||||
)
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
def tenant_list(self) -> List[str]:
|
||||
res = self.get(f"http://localhost:{self.port}/v1/tenant")
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
def tenant_create(self, tenant_id: uuid.UUID):
|
||||
res = self.post(
|
||||
f"http://localhost:{self.port}/v1/tenant",
|
||||
json={
|
||||
'tenant_id': tenant_id.hex,
|
||||
},
|
||||
)
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthKeys:
|
||||
pub: bytes
|
||||
priv: bytes
|
||||
|
||||
def generate_management_token(self):
|
||||
token = jwt.encode({"scope": "pageserverapi"}, self.priv, algorithm="RS256")
|
||||
|
||||
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
|
||||
# hinting or something (not sure what). If it returned 'bytes', convert it to 'str'
|
||||
# explicitly.
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode()
|
||||
|
||||
return token
|
||||
|
||||
def generate_tenant_token(self, tenant_id):
|
||||
token = jwt.encode({"scope": "tenant", "tenant_id": tenant_id}, self.priv, algorithm="RS256")
|
||||
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode()
|
||||
|
||||
return token
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ZenithPageserver(PgProtocol):
|
||||
""" An object representing a running pageserver. """
|
||||
def __init__(self, zenith_cli: ZenithCli):
|
||||
super().__init__(host='localhost', port=DEFAULT_PAGESERVER_PORT)
|
||||
def __init__(self, zenith_cli: ZenithCli, repo_dir: str):
|
||||
super().__init__(host='localhost', port=DEFAULT_PAGESERVER_PG_PORT)
|
||||
self.zenith_cli = zenith_cli
|
||||
self.running = False
|
||||
self.initial_tenant = None
|
||||
self.repo_dir = repo_dir
|
||||
|
||||
def init(self, enable_auth: bool = False) -> 'ZenithPageserver':
|
||||
"""
|
||||
@@ -224,9 +304,21 @@ class ZenithPageserver(PgProtocol):
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.stop()
|
||||
|
||||
@cached_property
|
||||
def auth_keys(self) -> AuthKeys:
|
||||
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
|
||||
priv = (Path(self.repo_dir) / 'auth_private_key.pem').read_bytes()
|
||||
return AuthKeys(pub=pub, priv=priv)
|
||||
|
||||
def http_client(self, auth_token: Optional[str] = None):
|
||||
return ZenithPageserverHttpClient(
|
||||
port=DEFAULT_PAGESERVER_HTTP_PORT,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
|
||||
|
||||
@zenfixture
|
||||
def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
|
||||
def pageserver(zenith_cli: ZenithCli, repo_dir: str) -> Iterator[ZenithPageserver]:
|
||||
"""
|
||||
The 'pageserver' fixture provides a Page Server that's up and running.
|
||||
|
||||
@@ -239,7 +331,7 @@ def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
|
||||
test called 'test_foo' would create and use branches with the 'test_foo' prefix.
|
||||
"""
|
||||
|
||||
ps = ZenithPageserver(zenith_cli).init().start()
|
||||
ps = ZenithPageserver(zenith_cli, repo_dir).init().start()
|
||||
# For convenience in tests, create a branch from the freshly-initialized cluster.
|
||||
zenith_cli.run(["branch", "empty", "main"])
|
||||
|
||||
@@ -250,6 +342,14 @@ def pageserver(zenith_cli: ZenithCli) -> Iterator[ZenithPageserver]:
|
||||
ps.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pageserver_auth_enabled(zenith_cli: ZenithCli, repo_dir: str):
|
||||
with ZenithPageserver(zenith_cli, repo_dir).init(enable_auth=True).start() as ps:
|
||||
# For convenience in tests, create a branch from the freshly-initialized cluster.
|
||||
zenith_cli.run(["branch", "empty", "main"])
|
||||
yield ps
|
||||
|
||||
|
||||
class Postgres(PgProtocol):
|
||||
""" An object representing a running postgres daemon. """
|
||||
def __init__(self, zenith_cli: ZenithCli, repo_dir: str, tenant_id: str, port: int):
|
||||
@@ -598,7 +698,7 @@ class WalAcceptor:
|
||||
cmd.append("--daemonize")
|
||||
cmd.append("--no-sync")
|
||||
# Tell page server it can receive WAL from this WAL safekeeper
|
||||
cmd.extend(["--pageserver", "localhost:{}".format(DEFAULT_PAGESERVER_PORT)])
|
||||
cmd.extend(["--pageserver", "localhost:{}".format(DEFAULT_PAGESERVER_PG_PORT)])
|
||||
cmd.extend(["--recall", "1 second"])
|
||||
print('Running command "{}"'.format(' '.join(cmd)))
|
||||
env = {'PAGESERVER_AUTH_TOKEN': self.auth_token} if self.auth_token else None
|
||||
|
||||
@@ -224,7 +224,7 @@ fn main() -> Result<()> {
|
||||
|
||||
("pg", Some(pg_match)) => {
|
||||
if let Err(e) = handle_pg(pg_match, &env) {
|
||||
eprintln!("pg operation failed: {}", e);
|
||||
eprintln!("pg operation failed: {:?}", e);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
@@ -371,7 +371,7 @@ fn get_branch_infos(
|
||||
tenantid: &ZTenantId,
|
||||
) -> Result<HashMap<ZTimelineId, BranchInfo>> {
|
||||
let page_server = PageServerNode::from_env(env);
|
||||
let branch_infos: Vec<BranchInfo> = page_server.branches_list(tenantid)?;
|
||||
let branch_infos: Vec<BranchInfo> = page_server.branch_list(tenantid)?;
|
||||
let branch_infos: HashMap<ZTimelineId, BranchInfo> = branch_infos
|
||||
.into_iter()
|
||||
.map(|branch_info| (branch_info.timeline_id, branch_info))
|
||||
@@ -384,7 +384,7 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
|
||||
let pageserver = PageServerNode::from_env(&env);
|
||||
match tenant_match.subcommand() {
|
||||
("list", Some(_)) => {
|
||||
for tenant in pageserver.tenants_list()? {
|
||||
for tenant in pageserver.tenant_list()? {
|
||||
println!("{}", tenant);
|
||||
}
|
||||
}
|
||||
@@ -394,7 +394,7 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
|
||||
None => ZTenantId::generate(),
|
||||
};
|
||||
println!("using tenant id {}", tenantid);
|
||||
pageserver.tenant_create(&tenantid)?;
|
||||
pageserver.tenant_create(tenantid)?;
|
||||
println!("tenant successfully created on the pageserver");
|
||||
}
|
||||
_ => {}
|
||||
@@ -424,7 +424,7 @@ fn handle_branch(branch_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
|
||||
.value_of("tenantid")
|
||||
.map_or(Ok(env.tenantid), |value| value.parse())?;
|
||||
// No arguments, list branches for tenant
|
||||
let branches = pageserver.branches_list(&tenantid)?;
|
||||
let branches = pageserver.branch_list(&tenantid)?;
|
||||
print_branches_tree(branches)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ hyper = { version = "0.14.7", features = ["full"] }
|
||||
lazy_static = "1.4.0"
|
||||
log = "0.4.14"
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
routerify = "2"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1.5.0", features = ["full"] }
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ use serde::de::Error;
|
||||
use serde::{self, Deserializer, Serializer};
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::{bail, Result};
|
||||
use jsonwebtoken::{
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
|
||||
};
|
||||
@@ -20,7 +20,7 @@ use crate::zid::ZTenantId;
|
||||
|
||||
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Scope {
|
||||
Tenant,
|
||||
@@ -48,7 +48,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
// this custom serialize/deserialize_with is needed because Option is not transparent to serde
|
||||
// so clearest option is serde(with = "hex") but it is not working, for details see https://github.com/serde-rs/serde/issues/1301
|
||||
@@ -68,6 +68,22 @@ impl Claims {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_permission(claims: &Claims, tenantid: Option<ZTenantId>) -> Result<()> {
|
||||
match (&claims.scope, tenantid) {
|
||||
(Scope::Tenant, None) => {
|
||||
bail!("Attempt to access management api with tenant scope. Permission denied")
|
||||
}
|
||||
(Scope::Tenant, Some(tenantid)) => {
|
||||
if claims.tenant_id.unwrap() != tenantid {
|
||||
bail!("Tenant id mismatch. Permission denied")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
|
||||
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct JwtAuth {
|
||||
decoding_key: DecodingKey<'static>,
|
||||
|
||||
175
zenith_utils/src/http/endpoint.rs
Normal file
175
zenith_utils/src/http/endpoint.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::{self, Claims, JwtAuth};
|
||||
use crate::http::error;
|
||||
use crate::zid::ZTenantId;
|
||||
use anyhow::anyhow;
|
||||
use hyper::header::AUTHORIZATION;
|
||||
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
|
||||
use lazy_static::lazy_static;
|
||||
use routerify::ext::RequestExt;
|
||||
use routerify::RequestInfo;
|
||||
use routerify::{Middleware, Router, RouterBuilder, RouterService};
|
||||
use zenith_metrics::{register_int_counter, IntCounter};
|
||||
use zenith_metrics::{Encoder, TextEncoder};
|
||||
|
||||
use super::error::ApiError;
|
||||
|
||||
lazy_static! {
|
||||
static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!(
|
||||
"pageserver_serve_metrics_count",
|
||||
"Number of metric requests made"
|
||||
)
|
||||
.expect("failed to define a metric");
|
||||
}
|
||||
|
||||
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
|
||||
log::info!("{} {} {}", info.method(), info.uri().path(), res.status(),);
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
SERVE_METRICS_COUNT.inc();
|
||||
|
||||
let mut buffer = vec![];
|
||||
let encoder = TextEncoder::new();
|
||||
let metrics = zenith_metrics::gather();
|
||||
encoder.encode(&metrics, &mut buffer).unwrap();
|
||||
|
||||
let response = Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, encoder.format_type())
|
||||
.body(Body::from(buffer))
|
||||
.unwrap();
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn get_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
Router::builder()
|
||||
.middleware(Middleware::post_with_info(logger))
|
||||
.get("/metrics", prometheus_metrics_handler)
|
||||
.err_handler(error::handler)
|
||||
}
|
||||
|
||||
pub fn attach_openapi_ui(
|
||||
router_builder: RouterBuilder<hyper::Body, ApiError>,
|
||||
spec: &'static [u8],
|
||||
spec_mount_path: &'static str,
|
||||
ui_mount_path: &'static str,
|
||||
) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
router_builder.get(spec_mount_path, move |_| async move {
|
||||
Ok(Response::builder().body(Body::from(spec)).unwrap())
|
||||
}).get(ui_mount_path, move |_| async move {
|
||||
Ok(Response::builder().body(Body::from(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>rweb</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const ui = SwaggerUIBundle({{
|
||||
"dom_id": "\#swagger-ui",
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||
],
|
||||
layout: "BaseLayout",
|
||||
deepLinking: true,
|
||||
showExtensions: true,
|
||||
showCommonExtensions: true,
|
||||
url: "{}",
|
||||
}})
|
||||
window.ui = ui;
|
||||
}};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#, spec_mount_path))).unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
pub trait AuthProvider {
|
||||
fn provide_auth(&self, req: &Request<Body>) -> Arc<Option<JwtAuth>>;
|
||||
}
|
||||
|
||||
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
// header must be in form Bearer <token>
|
||||
let (prefix, token) = header_value.split_once(' ').ok_or(ApiError::Unauthorized(
|
||||
"malformed authorization header".to_string(),
|
||||
))?;
|
||||
if prefix != "Bearer" {
|
||||
Err(ApiError::Unauthorized(
|
||||
"malformed authorization header".to_string(),
|
||||
))?
|
||||
}
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn auth_middleware<S: AuthProvider + Send + Sync + 'static>(
|
||||
req: Request<Body>,
|
||||
) -> Result<Request<Body>, ApiError> {
|
||||
// unwrap is ok because this is called in auth middleware
|
||||
// which should be enabled only when auth is some
|
||||
let state_auth = req
|
||||
.data::<Arc<S>>()
|
||||
.expect("state is always in request data")
|
||||
.provide_auth(&req);
|
||||
|
||||
if let Some(auth) = state_auth.as_ref().as_ref() {
|
||||
match req.headers().get(AUTHORIZATION) {
|
||||
Some(value) => {
|
||||
let header_value = value.to_str().map_err(|_| {
|
||||
ApiError::Unauthorized("malformed authorization header".to_string())
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
}
|
||||
None => Err(ApiError::Unauthorized(
|
||||
"missing authorization header".to_string(),
|
||||
))?,
|
||||
}
|
||||
}
|
||||
Ok(req)
|
||||
}
|
||||
|
||||
pub fn check_permission(req: &Request<Body>, tenantid: Option<ZTenantId>) -> Result<(), ApiError> {
|
||||
match req.context::<Claims>() {
|
||||
Some(claims) => Ok(auth::check_permission(&claims, tenantid)
|
||||
.map_err(|err| ApiError::Forbidden(err.to_string()))?),
|
||||
None => Ok(()), // claims is None because auth is disabled
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serve_thread_main(
|
||||
router_builder: RouterBuilder<hyper::Body, ApiError>,
|
||||
addr: String,
|
||||
) -> anyhow::Result<()> {
|
||||
let addr = addr.parse()?;
|
||||
log::info!("Starting a http endoint at {}", addr);
|
||||
|
||||
// Create a Service from the router above to handle incoming requests.
|
||||
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
|
||||
|
||||
// Enter a single-threaded tokio runtime bound to the current thread
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let _guard = runtime.enter();
|
||||
|
||||
let server = Server::bind(&addr).serve(service);
|
||||
|
||||
runtime.block_on(server)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
77
zenith_utils/src/http/error.rs
Normal file
77
zenith_utils/src/http/error.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use anyhow::anyhow;
|
||||
use hyper::{header, Body, Response, StatusCode};
|
||||
use serde::Serialize;
|
||||
use serde_json;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error("Bad request: {0}")]
|
||||
BadRequest(String),
|
||||
|
||||
#[error("Forbidden: {0}")]
|
||||
Forbidden(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error(transparent)]
|
||||
InternalServerError(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
pub fn from_err<E: Into<anyhow::Error>>(err: E) -> Self {
|
||||
Self::InternalServerError(anyhow!(err))
|
||||
}
|
||||
|
||||
pub fn into_response(self) -> Response<Body> {
|
||||
match self {
|
||||
ApiError::BadRequest(_) => HttpErrorBody::response_from_msg_and_status(
|
||||
self.to_string(),
|
||||
StatusCode::BAD_REQUEST,
|
||||
),
|
||||
ApiError::Forbidden(_) => {
|
||||
HttpErrorBody::response_from_msg_and_status(self.to_string(), StatusCode::FORBIDDEN)
|
||||
}
|
||||
ApiError::Unauthorized(_) => HttpErrorBody::response_from_msg_and_status(
|
||||
self.to_string(),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
),
|
||||
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct HttpErrorBody {
|
||||
pub msg: String,
|
||||
}
|
||||
|
||||
impl HttpErrorBody {
|
||||
pub fn from_msg(msg: String) -> Self {
|
||||
HttpErrorBody { msg }
|
||||
}
|
||||
|
||||
pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Body> {
|
||||
HttpErrorBody { msg }.into_response(status)
|
||||
}
|
||||
|
||||
pub fn into_response(&self, status: StatusCode) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Body::from(serde_json::to_string(self).unwrap()))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handler(err: routerify::RouteError) -> Response<Body> {
|
||||
log::error!("{}", err);
|
||||
err.downcast::<ApiError>()
|
||||
.expect("handler should always return api error")
|
||||
.into_response()
|
||||
}
|
||||
29
zenith_utils/src/http/json.rs
Normal file
29
zenith_utils/src/http/json.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use bytes::Buf;
|
||||
use hyper::{header, Body, Request, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
use super::error::ApiError;
|
||||
|
||||
pub async fn json_request<T: for<'de> Deserialize<'de>>(
|
||||
request: &mut Request<Body>,
|
||||
) -> Result<T, ApiError> {
|
||||
let whole_body = hyper::body::aggregate(request.body_mut())
|
||||
.await
|
||||
.map_err(ApiError::from_err)?;
|
||||
Ok(serde_json::from_reader(whole_body.reader())
|
||||
.map_err(|err| ApiError::BadRequest(format!("Failed to parse json request {}", err)))?)
|
||||
}
|
||||
|
||||
pub fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let json = serde_json::to_string(&data).map_err(ApiError::from_err)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(json))
|
||||
.map_err(ApiError::from_err)?;
|
||||
Ok(response)
|
||||
}
|
||||
3
zenith_utils/src/http/mod.rs
Normal file
3
zenith_utils/src/http/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod endpoint;
|
||||
pub mod error;
|
||||
pub mod json;
|
||||
@@ -1,53 +0,0 @@
|
||||
use hyper::{
|
||||
header::CONTENT_TYPE,
|
||||
service::{make_service_fn, service_fn},
|
||||
Body, Request, Response, Server,
|
||||
};
|
||||
use lazy_static::lazy_static;
|
||||
use zenith_metrics::{register_int_counter, IntCounter};
|
||||
use zenith_metrics::{Encoder, TextEncoder};
|
||||
|
||||
lazy_static! {
|
||||
static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!(
|
||||
"pageserver_serve_metrics_count",
|
||||
"Number of metric requests made"
|
||||
)
|
||||
.expect("failed to define a metric");
|
||||
}
|
||||
|
||||
async fn serve_prometheus_metrics(_req: Request<Body>) -> anyhow::Result<Response<Body>> {
|
||||
SERVE_METRICS_COUNT.inc();
|
||||
|
||||
let mut buffer = vec![];
|
||||
let encoder = TextEncoder::new();
|
||||
let metrics = zenith_metrics::gather();
|
||||
encoder.encode(&metrics, &mut buffer).unwrap();
|
||||
|
||||
let response = Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, encoder.format_type())
|
||||
.body(Body::from(buffer))
|
||||
.unwrap();
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn thread_main(addr: String) -> anyhow::Result<()> {
|
||||
let addr = addr.parse()?;
|
||||
log::info!("Starting a prometheus endoint at {}", addr);
|
||||
|
||||
// Enter a single-threaded tokio runtime bound to the current thread
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let _guard = runtime.enter();
|
||||
|
||||
// TODO: use hyper_router/routerify/etc when we have more methods
|
||||
let server = Server::bind(&addr).serve(make_service_fn(|_| async {
|
||||
Ok::<_, anyhow::Error>(service_fn(serve_prometheus_metrics))
|
||||
}));
|
||||
|
||||
runtime.block_on(server)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -10,7 +10,6 @@ pub mod seqwait;
|
||||
// pub mod seqwait_async;
|
||||
|
||||
pub mod bin_ser;
|
||||
pub mod http_endpoint;
|
||||
pub mod postgres_backend;
|
||||
pub mod pq_proto;
|
||||
|
||||
@@ -22,3 +21,5 @@ pub mod auth;
|
||||
|
||||
// utility functions and helper traits for unified unique id generation/serialization etc.
|
||||
pub mod zid;
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
@@ -96,6 +96,20 @@ macro_rules! zid_newtype {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromHex for $t {
|
||||
type Error = hex::FromHexError;
|
||||
|
||||
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
|
||||
Ok($t(ZId::from_hex(hex)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for $t {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.0 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for $t {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
@@ -137,19 +151,3 @@ zid_newtype!(ZTimelineId);
|
||||
pub struct ZTenantId(ZId);
|
||||
|
||||
zid_newtype!(ZTenantId);
|
||||
|
||||
// for now the following impls are used only with ZTenantId,
|
||||
// if this impls become useful in other newtypes they can be moved under zid_newtype macro too
|
||||
impl FromHex for ZTenantId {
|
||||
type Error = hex::FromHexError;
|
||||
|
||||
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
|
||||
Ok(ZTenantId(ZId::from_hex(hex)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for ZTenantId {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.0 .0
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user