Mirror of https://github.com/neondatabase/neon.git (synced 2026-01-15 01:12:56 +00:00)
Merge branch 'main' into bojan-get-page-tests
@@ -63,21 +63,18 @@
  tags:
    - pageserver

# It seems that currently S3 integration does not play well
# even with fresh pageserver without a burden of old data.
# TODO: turn this back on once the issue is solved.
# - name: update remote storage (s3) config
#   lineinfile:
#     path: /storage/pageserver/data/pageserver.toml
#     line: "{{ item }}"
#   loop:
#     - "[remote_storage]"
#     - "bucket_name = '{{ bucket_name }}'"
#     - "bucket_region = '{{ bucket_region }}'"
#     - "prefix_in_bucket = '{{ inventory_hostname }}'"
#   become: true
#   tags:
#     - pageserver
- name: update remote storage (s3) config
  lineinfile:
    path: /storage/pageserver/data/pageserver.toml
    line: "{{ item }}"
  loop:
    - "[remote_storage]"
    - "bucket_name = '{{ bucket_name }}'"
    - "bucket_region = '{{ bucket_region }}'"
    - "prefix_in_bucket = '{{ inventory_hostname }}'"
  become: true
  tags:
    - pageserver

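For reference, after this task runs, pageserver.toml should end with a section along these lines (a sketch only; the concrete values come from the play's bucket_name, bucket_region, and inventory_hostname variables, shown here as placeholders):

```toml
[remote_storage]
bucket_name = '<bucket_name>'
bucket_region = '<bucket_region>'
prefix_in_bucket = '<inventory_hostname>'
```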
- name: upload systemd service definition
  ansible.builtin.template:
@@ -5,7 +5,6 @@ zenith-us-stage-ps-2 console_region_id=27

[safekeepers]
zenith-us-stage-sk-1 console_region_id=27
zenith-us-stage-sk-2 console_region_id=27
zenith-us-stage-sk-3 console_region_id=27
zenith-us-stage-sk-4 console_region_id=27

[storage:children]
@@ -405,7 +405,7 @@ jobs:
      - run:
          name: Build coverage report
          command: |
            COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1
            COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1

            scripts/coverage \
              --dir=/tmp/zenith/coverage report \
@@ -416,8 +416,8 @@ jobs:
          name: Upload coverage report
          command: |
            LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME
            REPORT_URL=https://zenithdb.github.io/zenith-coverage-data/$CIRCLE_SHA1
            COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1
            REPORT_URL=https://neondatabase.github.io/zenith-coverage-data/$CIRCLE_SHA1
            COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1

            scripts/git-upload \
              --repo=https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-coverage-data.git \
@@ -593,7 +593,7 @@ jobs:
          name: Setup helm v3
          command: |
            curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash
            helm repo add zenithdb https://zenithdb.github.io/helm-charts
            helm repo add zenithdb https://neondatabase.github.io/helm-charts
      - run:
          name: Re-deploy proxy
          command: |
@@ -643,7 +643,7 @@ jobs:
          name: Setup helm v3
          command: |
            curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash
            helm repo add zenithdb https://zenithdb.github.io/helm-charts
            helm repo add zenithdb https://neondatabase.github.io/helm-charts
      - run:
          name: Re-deploy proxy
          command: |
@@ -672,7 +672,7 @@ jobs:
            --data \
            "{
              \"state\": \"pending\",
              \"context\": \"zenith-remote-ci\",
              \"context\": \"neon-cloud-e2e\",
              \"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
            }"
      - run:
@@ -688,7 +688,7 @@ jobs:
            "{
              \"ref\": \"main\",
              \"inputs\": {
                \"ci_job_name\": \"zenith-remote-ci\",
                \"ci_job_name\": \"neon-cloud-e2e\",
                \"commit_hash\": \"$CIRCLE_SHA1\",
                \"remote_repo\": \"$LOCAL_REPO\"
              }
@@ -828,11 +828,11 @@ workflows:
      - remote-ci-trigger:
          # Context passes credentials for gh api
          context: CI_ACCESS_TOKEN
          remote_repo: "zenithdb/console"
          remote_repo: "neondatabase/cloud"
          requires:
            # XXX: Successful build doesn't mean everything is OK, but
            # the job to be triggered takes so much time to complete (~22 min)
            # that it's better not to wait for the commented-out steps
            - build-zenith-debug
            - build-zenith-release
            # - pg_regress-tests-release
            # - other-tests-release
13 .github/workflows/benchmarking.yml vendored
@@ -26,7 +26,7 @@ jobs:
    runs-on: [self-hosted, zenith-benchmarker]

    env:
      PG_BIN: "/usr/pgsql-13/bin"
      POSTGRES_DISTRIB_DIR: "/usr/pgsql-13"

    steps:
      - name: Checkout zenith repo
@@ -51,7 +51,7 @@ jobs:
          echo Poetry
          poetry --version
          echo Pgbench
          $PG_BIN/pgbench --version
          $POSTGRES_DISTRIB_DIR/bin/pgbench --version

      # FIXME cluster setup is skipped due to various changes in console API
      # for now pre created cluster is used. When API gain some stability
@@ -66,7 +66,7 @@ jobs:

          echo "Starting cluster"
          # wake up the cluster
          $PG_BIN/psql $BENCHMARK_CONNSTR -c "SELECT 1"
          $POSTGRES_DISTRIB_DIR/bin/psql $BENCHMARK_CONNSTR -c "SELECT 1"

      - name: Run benchmark
        # pgbench is installed system wide from official repo
@@ -83,8 +83,11 @@ jobs:
        # sudo yum install postgresql13-contrib
        # actual binaries are located in /usr/pgsql-13/bin/
        env:
          TEST_PG_BENCH_TRANSACTIONS_MATRIX: "5000,10000,20000"
          TEST_PG_BENCH_SCALES_MATRIX: "10,15"
          # The pgbench test runs two tests of given duration against each scale.
          # So the total runtime with these parameters is 2 * 2 * 300 = 1200, or 20 minutes.
          # Plus time needed to initialize the test databases.
          TEST_PG_BENCH_DURATIONS_MATRIX: "300"
          TEST_PG_BENCH_SCALES_MATRIX: "10,100"
          PLATFORM: "zenith-staging"
          BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}"
          REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally
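As a rough sketch of what one cell of this matrix amounts to (the Python test harness is what actually drives pgbench; this is not its literal invocation), a scale-10, 300-second run would look roughly like:

```sh
# initialize at scale factor 10, then run for 300 seconds against the staging cluster
$POSTGRES_DISTRIB_DIR/bin/pgbench -i -s 10 "$BENCHMARK_CONNSTR"
$POSTGRES_DISTRIB_DIR/bin/pgbench -T 300 "$BENCHMARK_CONNSTR"
```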
44 Cargo.lock generated
@@ -361,6 +361,7 @@ dependencies = [
 "serde_json",
 "tar",
 "tokio",
 "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
 "workspace_hack",
]

@@ -1571,7 +1572,6 @@ dependencies = [
 "tokio-util 0.7.0",
 "toml_edit",
 "tracing",
 "tracing-futures",
 "url",
 "workspace_hack",
 "zenith_metrics",
@@ -1952,12 +1952,15 @@ name = "proxy"
version = "0.1.0"
dependencies = [
 "anyhow",
 "async-trait",
 "base64 0.13.0",
 "bytes",
 "clap 3.1.8",
 "fail",
 "futures",
 "hashbrown",
 "hex",
 "hmac 0.10.1",
 "hyper",
 "lazy_static",
 "md5",
@@ -1966,10 +1969,13 @@ dependencies = [
 "rand",
 "rcgen",
 "reqwest",
 "routerify 2.2.0",
 "rstest",
 "rustls 0.19.1",
 "scopeguard",
 "serde",
 "serde_json",
 "sha2",
 "socket2",
 "thiserror",
 "tokio",
@@ -2151,7 +2157,6 @@ dependencies = [
 "serde_urlencoded",
 "tokio",
 "tokio-rustls 0.23.2",
 "tokio-util 0.6.9",
 "url",
 "wasm-bindgen",
 "wasm-bindgen-futures",
@@ -2175,6 +2180,19 @@ dependencies = [
 "winapi",
]

[[package]]
name = "routerify"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9"
dependencies = [
 "http",
 "hyper",
 "lazy_static",
 "percent-encoding",
 "regex",
]

[[package]]
name = "routerify"
version = "3.0.0"
@@ -2188,6 +2206,19 @@ dependencies = [
 "regex",
]

[[package]]
name = "rstest"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f"
dependencies = [
 "cfg-if",
 "proc-macro2",
 "quote",
 "rustc_version",
 "syn",
]

[[package]]
name = "rusoto_core"
version = "0.47.0"
@@ -3403,19 +3434,20 @@ dependencies = [
 "anyhow",
 "bytes",
 "cc",
 "chrono",
 "clap 2.34.0",
 "either",
 "hashbrown",
 "indexmap",
 "libc",
 "log",
 "memchr",
 "num-integer",
 "num-traits",
 "proc-macro2",
 "quote",
 "prost",
 "rand",
 "regex",
 "regex-syntax",
 "reqwest",
 "scopeguard",
 "serde",
 "syn",
@@ -3495,7 +3527,7 @@ dependencies = [
 "postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
 "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
 "rand",
 "routerify",
 "routerify 3.0.0",
 "rustls 0.19.1",
 "rustls-split",
 "serde",
29 README.md
@@ -1,19 +1,22 @@
# Zenith
# Neon

Zenith is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes.
Neon is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes.

The project used to be called "Zenith". Many of the commands and code comments
still refer to "zenith", but we are in the process of renaming things.

## Architecture overview

A Zenith installation consists of compute nodes and Zenith storage engine.
A Neon installation consists of compute nodes and Neon storage engine.

Compute nodes are stateless PostgreSQL nodes, backed by Zenith storage engine.
Compute nodes are stateless PostgreSQL nodes, backed by Neon storage engine.

Zenith storage engine consists of two major components:
Neon storage engine consists of two major components:
- Pageserver. Scalable storage backend for compute nodes.
- WAL service. The service that receives WAL from compute node and ensures that it is stored durably.

Pageserver consists of:
- Repository - Zenith storage implementation.
- Repository - Neon storage implementation.
- WAL receiver - service that receives WAL from WAL service and stores it in the repository.
- Page service - service that communicates with compute nodes and responds with pages from the repository.
- WAL redo - service that builds pages from base images and WAL records on Page service request.
@@ -35,10 +38,10 @@ To run the `psql` client, install the `postgresql-client` package or modify `PAT
To run the integration tests or Python scripts (not required to use the code), install
Python (3.7 or higher), and install python3 packages using `./scripts/pysync` (requires poetry) in the project directory.

2. Build zenith and patched postgres
2. Build neon and patched postgres
```sh
git clone --recursive https://github.com/zenithdb/zenith.git
cd zenith
git clone --recursive https://github.com/neondatabase/neon.git
cd neon
make -j5
```

@@ -126,7 +129,7 @@ INSERT 0 1
## Running tests

```sh
git clone --recursive https://github.com/zenithdb/zenith.git
git clone --recursive https://github.com/neondatabase/neon.git
make # builds also postgres and installs it to ./tmp_install
./scripts/pytest
```
@@ -141,14 +144,14 @@ To view your `rustdoc` documentation in a browser, try running `cargo doc --no-d

### Postgres-specific terms

Due to Zenith's very close relation with PostgreSQL internals, there are numerous specific terms used.
Due to Neon's very close relation with PostgreSQL internals, there are numerous specific terms used.
Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, while MiB would be technically more correct, it's inconsistent with what PostgreSQL code and its documentation use.

To get more familiar with this aspect, refer to:

- [Zenith glossary](/docs/glossary.md)
- [Neon glossary](/docs/glossary.md)
- [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html)
- Other PostgreSQL documentation and sources (Zenith fork sources can be found [here](https://github.com/zenithdb/postgres))
- Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres))

## Join the development
@@ -17,4 +17,5 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
tar = "0.4"
tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

@@ -38,6 +38,7 @@ use clap::Arg;
use log::info;
use postgres::{Client, NoTls};

use compute_tools::checker::create_writablity_check_data;
use compute_tools::config;
use compute_tools::http_api::launch_http_server;
use compute_tools::logger::*;
@@ -128,6 +129,7 @@ fn run_compute(state: &Arc<RwLock<ComputeState>>) -> Result<ExitStatus> {

    handle_roles(&read_state.spec, &mut client)?;
    handle_databases(&read_state.spec, &mut client)?;
    create_writablity_check_data(&mut client)?;

    // 'Close' connection
    drop(client);
46 compute_tools/src/checker.rs Normal file
@@ -0,0 +1,46 @@
use std::sync::{Arc, RwLock};

use anyhow::{anyhow, Result};
use log::error;
use postgres::Client;
use tokio_postgres::NoTls;

use crate::zenith::ComputeState;

pub fn create_writablity_check_data(client: &mut Client) -> Result<()> {
    let query = "
        CREATE TABLE IF NOT EXISTS health_check (
            id serial primary key,
            updated_at timestamptz default now()
        );
        INSERT INTO health_check VALUES (1, now())
            ON CONFLICT (id) DO UPDATE
            SET updated_at = now();";
    let result = client.simple_query(query)?;
    if result.len() < 2 {
        return Err(anyhow::format_err!("executed {} queries", result.len()));
    }
    Ok(())
}

pub async fn check_writability(state: &Arc<RwLock<ComputeState>>) -> Result<()> {
    let connstr = state.read().unwrap().connstr.clone();
    let (client, connection) = tokio_postgres::connect(&connstr, NoTls).await?;
    if client.is_closed() {
        return Err(anyhow!("connection to postgres closed"));
    }
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            error!("connection error: {}", e);
        }
    });

    let result = client
        .simple_query("UPDATE health_check SET updated_at = now() WHERE id = 1;")
        .await?;

    if result.len() != 1 {
        return Err(anyhow!("statement can't be executed"));
    }
    Ok(())
}

@@ -11,7 +11,7 @@ use log::{error, info};
use crate::zenith::*;

// Service function to handle all available routes.
fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body> {
async fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body> {
    match (req.method(), req.uri().path()) {
        // Timestamp of the last Postgres activity in the plain text.
        (&Method::GET, "/last_activity") => {
@@ -29,6 +29,15 @@ fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body
            Response::new(Body::from(format!("{}", state.ready)))
        }

        (&Method::GET, "/check_writability") => {
            info!("serving /check_writability GET request");
            let res = crate::checker::check_writability(&state).await;
            match res {
                Ok(_) => Response::new(Body::from("true")),
                Err(e) => Response::new(Body::from(e.to_string())),
            }
        }

        // Return the `404 Not Found` for any other routes.
        _ => {
            let mut not_found = Response::new(Body::from("404 Not Found"));
@@ -48,7 +57,7 @@ async fn serve(state: Arc<RwLock<ComputeState>>) {
        async move {
            Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
                let state = state.clone();
                async move { Ok::<_, Infallible>(routes(req, state)) }
                async move { Ok::<_, Infallible>(routes(req, state).await) }
            }))
        }
    });
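With the route in place, the health check can be exercised against a running compute node over its HTTP API; a sketch, with the port left as a placeholder for whatever launch_http_server binds in a given setup:

```sh
curl http://localhost:<compute-http-port>/check_writability
# prints "true" if the UPDATE on health_check succeeded, or the error text otherwise
```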
@@ -2,6 +2,7 @@
//! Various tools and helpers to handle cluster / compute node (Postgres)
//! configuration.
//!
pub mod checker;
pub mod config;
pub mod http_api;
#[macro_use]

@@ -37,7 +37,6 @@ toml_edit = { version = "0.13", features = ["easy"] }
scopeguard = "1.1.0"
const_format = "0.2.21"
tracing = "0.1.27"
tracing-futures = "0.2"
signal-hook = "0.3.10"
url = "2"
nix = "0.23"
@@ -36,8 +36,8 @@ pub mod defaults {
    // Target file size, when creating image and delta layers.
    // This parameter determines L1 layer file size.
    pub const DEFAULT_COMPACTION_TARGET_SIZE: u64 = 128 * 1024 * 1024;

    pub const DEFAULT_COMPACTION_PERIOD: &str = "1 s";
    pub const DEFAULT_COMPACTION_THRESHOLD: usize = 10;

    pub const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024;
    pub const DEFAULT_GC_PERIOD: &str = "100 s";
@@ -65,6 +65,7 @@ pub mod defaults {
#checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes
#compaction_target_size = {DEFAULT_COMPACTION_TARGET_SIZE} # in bytes
#compaction_period = '{DEFAULT_COMPACTION_PERIOD}'
#compaction_threshold = '{DEFAULT_COMPACTION_THRESHOLD}'

#gc_period = '{DEFAULT_GC_PERIOD}'
#gc_horizon = {DEFAULT_GC_HORIZON}
@@ -107,6 +108,9 @@ pub struct PageServerConf {
    // How often to check if there's compaction work to be done.
    pub compaction_period: Duration,

    // Level0 delta layer threshold for compaction.
    pub compaction_threshold: usize,

    pub gc_horizon: u64,
    pub gc_period: Duration,

@@ -164,6 +168,7 @@ struct PageServerConfigBuilder {

    compaction_target_size: BuilderValue<u64>,
    compaction_period: BuilderValue<Duration>,
    compaction_threshold: BuilderValue<usize>,

    gc_horizon: BuilderValue<u64>,
    gc_period: BuilderValue<Duration>,
@@ -202,6 +207,7 @@ impl Default for PageServerConfigBuilder {
            compaction_target_size: Set(DEFAULT_COMPACTION_TARGET_SIZE),
            compaction_period: Set(humantime::parse_duration(DEFAULT_COMPACTION_PERIOD)
                .expect("cannot parse default compaction period")),
            compaction_threshold: Set(DEFAULT_COMPACTION_THRESHOLD),
            gc_horizon: Set(DEFAULT_GC_HORIZON),
            gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD)
                .expect("cannot parse default gc period")),
@@ -246,6 +252,10 @@ impl PageServerConfigBuilder {
        self.compaction_period = BuilderValue::Set(compaction_period)
    }

    pub fn compaction_threshold(&mut self, compaction_threshold: usize) {
        self.compaction_threshold = BuilderValue::Set(compaction_threshold)
    }

    pub fn gc_horizon(&mut self, gc_horizon: u64) {
        self.gc_horizon = BuilderValue::Set(gc_horizon)
    }
@@ -322,6 +332,9 @@ impl PageServerConfigBuilder {
            compaction_period: self
                .compaction_period
                .ok_or(anyhow::anyhow!("missing compaction_period"))?,
            compaction_threshold: self
                .compaction_threshold
                .ok_or(anyhow::anyhow!("missing compaction_threshold"))?,
            gc_horizon: self
                .gc_horizon
                .ok_or(anyhow::anyhow!("missing gc_horizon"))?,
@@ -465,6 +478,9 @@ impl PageServerConf {
                builder.compaction_target_size(parse_toml_u64(key, item)?)
            }
            "compaction_period" => builder.compaction_period(parse_toml_duration(key, item)?),
            "compaction_threshold" => {
                builder.compaction_threshold(parse_toml_u64(key, item)? as usize)
            }
            "gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?),
            "gc_period" => builder.gc_period(parse_toml_duration(key, item)?),
            "wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
@@ -603,6 +619,7 @@ impl PageServerConf {
            checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
            compaction_target_size: 4 * 1024 * 1024,
            compaction_period: Duration::from_secs(10),
            compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD,
            gc_horizon: defaults::DEFAULT_GC_HORIZON,
            gc_period: Duration::from_secs(10),
            wait_lsn_timeout: Duration::from_secs(60),
@@ -676,6 +693,7 @@ checkpoint_distance = 111 # in bytes

compaction_target_size = 111 # in bytes
compaction_period = '111 s'
compaction_threshold = 2

gc_period = '222 s'
gc_horizon = 222
@@ -714,6 +732,7 @@ id = 10
            checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
            compaction_target_size: defaults::DEFAULT_COMPACTION_TARGET_SIZE,
            compaction_period: humantime::parse_duration(defaults::DEFAULT_COMPACTION_PERIOD)?,
            compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD,
            gc_horizon: defaults::DEFAULT_GC_HORIZON,
            gc_period: humantime::parse_duration(defaults::DEFAULT_GC_PERIOD)?,
            wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
@@ -760,6 +779,7 @@ id = 10
            checkpoint_distance: 111,
            compaction_target_size: 111,
            compaction_period: Duration::from_secs(111),
            compaction_threshold: 2,
            gc_horizon: 222,
            gc_period: Duration::from_secs(222),
            wait_lsn_timeout: Duration::from_secs(111),
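For context, the new knob sits next to the existing compaction and GC settings in pageserver.toml; a sketch of an override section using the default values shown above:

```toml
compaction_target_size = 134217728   # 128 * 1024 * 1024 bytes
compaction_period = '1 s'
compaction_threshold = 10            # level-0 delta layers accumulated before compaction runs
gc_horizon = 67108864                # 64 * 1024 * 1024 bytes
gc_period = '100 s'
```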
@@ -49,7 +49,8 @@ use crate::CheckpointConfig;
|
||||
use crate::{ZTenantId, ZTimelineId};
|
||||
|
||||
use zenith_metrics::{
|
||||
register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec, IntGauge, IntGaugeVec,
|
||||
register_histogram_vec, register_int_counter, register_int_gauge_vec, Histogram, HistogramVec,
|
||||
IntCounter, IntGauge, IntGaugeVec,
|
||||
};
|
||||
use zenith_utils::crashsafe_dir;
|
||||
use zenith_utils::lsn::{AtomicLsn, Lsn, RecordLsn};
|
||||
@@ -109,6 +110,21 @@ lazy_static! {
|
||||
.expect("failed to define a metric");
|
||||
}
|
||||
|
||||
// Metrics for cloud upload. These metrics reflect data uploaded to cloud storage,
|
||||
// or in testing they estimate how much we would upload if we did.
|
||||
lazy_static! {
|
||||
static ref NUM_PERSISTENT_FILES_CREATED: IntCounter = register_int_counter!(
|
||||
"pageserver_num_persistent_files_created",
|
||||
"Number of files created that are meant to be uploaded to cloud storage",
|
||||
)
|
||||
.expect("failed to define a metric");
|
||||
static ref PERSISTENT_BYTES_WRITTEN: IntCounter = register_int_counter!(
|
||||
"pageserver_persistent_bytes_written",
|
||||
"Total bytes written that are meant to be uploaded to cloud storage",
|
||||
)
|
||||
.expect("failed to define a metric");
|
||||
}
|
||||
|
||||
/// Parts of the `.zenith/tenants/<tenantid>/timelines/<timelineid>` directory prefix.
|
||||
pub const TIMELINES_SEGMENT_NAME: &str = "timelines";
|
||||
|
||||
@@ -193,7 +209,7 @@ impl Repository for LayeredRepository {
|
||||
Arc::clone(&self.walredo_mgr),
|
||||
self.upload_layers,
|
||||
);
|
||||
timeline.layers.lock().unwrap().next_open_layer_at = Some(initdb_lsn);
|
||||
timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn);
|
||||
|
||||
let timeline = Arc::new(timeline);
|
||||
let r = timelines.insert(
|
||||
@@ -725,7 +741,7 @@ pub struct LayeredTimeline {
|
||||
tenantid: ZTenantId,
|
||||
timelineid: ZTimelineId,
|
||||
|
||||
layers: Mutex<LayerMap>,
|
||||
layers: RwLock<LayerMap>,
|
||||
|
||||
last_freeze_at: AtomicLsn,
|
||||
|
||||
@@ -997,7 +1013,7 @@ impl LayeredTimeline {
|
||||
conf,
|
||||
timelineid,
|
||||
tenantid,
|
||||
layers: Mutex::new(LayerMap::default()),
|
||||
layers: RwLock::new(LayerMap::default()),
|
||||
|
||||
walredo_mgr,
|
||||
|
||||
@@ -1040,7 +1056,7 @@ impl LayeredTimeline {
|
||||
/// Returns all timeline-related files that were found and loaded.
|
||||
///
|
||||
fn load_layer_map(&self, disk_consistent_lsn: Lsn) -> anyhow::Result<()> {
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
let mut num_layers = 0;
|
||||
|
||||
// Scan timeline directory and create ImageFileName and DeltaFilename
|
||||
@@ -1194,7 +1210,7 @@ impl LayeredTimeline {
|
||||
continue;
|
||||
}
|
||||
|
||||
let layers = timeline.layers.lock().unwrap();
|
||||
let layers = timeline.layers.read().unwrap();
|
||||
|
||||
// Check the open and frozen in-memory layers first
|
||||
if let Some(open_layer) = &layers.open_layer {
|
||||
@@ -1276,7 +1292,7 @@ impl LayeredTimeline {
|
||||
/// Get a handle to the latest layer for appending.
|
||||
///
|
||||
fn get_layer_for_write(&self, lsn: Lsn) -> anyhow::Result<Arc<InMemoryLayer>> {
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
|
||||
ensure!(lsn.is_aligned());
|
||||
|
||||
@@ -1347,7 +1363,7 @@ impl LayeredTimeline {
|
||||
} else {
|
||||
Some(self.write_lock.lock().unwrap())
|
||||
};
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
if let Some(open_layer) = &layers.open_layer {
|
||||
let open_layer_rc = Arc::clone(open_layer);
|
||||
// Does this layer need freezing?
|
||||
@@ -1412,7 +1428,7 @@ impl LayeredTimeline {
|
||||
let timer = self.flush_time_histo.start_timer();
|
||||
|
||||
loop {
|
||||
let layers = self.layers.lock().unwrap();
|
||||
let layers = self.layers.read().unwrap();
|
||||
if let Some(frozen_layer) = layers.frozen_layers.front() {
|
||||
let frozen_layer = Arc::clone(frozen_layer);
|
||||
drop(layers); // to allow concurrent reads and writes
|
||||
@@ -1456,7 +1472,7 @@ impl LayeredTimeline {
|
||||
|
||||
// Finally, replace the frozen in-memory layer with the new on-disk layers
|
||||
{
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
let l = layers.frozen_layers.pop_front();
|
||||
|
||||
// Only one thread may call this function at a time (for this
|
||||
@@ -1524,6 +1540,10 @@ impl LayeredTimeline {
|
||||
&metadata,
|
||||
false,
|
||||
)?;
|
||||
|
||||
NUM_PERSISTENT_FILES_CREATED.inc_by(1);
|
||||
PERSISTENT_BYTES_WRITTEN.inc_by(new_delta_path.metadata()?.len());
|
||||
|
||||
if self.upload_layers.load(atomic::Ordering::Relaxed) {
|
||||
schedule_timeline_checkpoint_upload(
|
||||
self.tenantid,
|
||||
@@ -1612,7 +1632,7 @@ impl LayeredTimeline {
|
||||
lsn: Lsn,
|
||||
threshold: usize,
|
||||
) -> Result<bool> {
|
||||
let layers = self.layers.lock().unwrap();
|
||||
let layers = self.layers.read().unwrap();
|
||||
|
||||
for part_range in &partition.ranges {
|
||||
let image_coverage = layers.image_coverage(part_range, lsn)?;
|
||||
@@ -1670,7 +1690,7 @@ impl LayeredTimeline {
|
||||
|
||||
// FIXME: Do we need to do something to upload it to remote storage here?
|
||||
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
layers.insert_historic(Arc::new(image_layer));
|
||||
drop(layers);
|
||||
|
||||
@@ -1678,15 +1698,13 @@ impl LayeredTimeline {
|
||||
}
|
||||
|
||||
fn compact_level0(&self, target_file_size: u64) -> Result<()> {
|
||||
let layers = self.layers.lock().unwrap();
|
||||
|
||||
// We compact or "shuffle" the level-0 delta layers when 10 have
|
||||
// accumulated.
|
||||
static COMPACT_THRESHOLD: usize = 10;
|
||||
let layers = self.layers.read().unwrap();
|
||||
|
||||
let level0_deltas = layers.get_level0_deltas()?;
|
||||
|
||||
if level0_deltas.len() < COMPACT_THRESHOLD {
|
||||
// We compact or "shuffle" the level-0 delta layers when they've
|
||||
// accumulated over the compaction threshold.
|
||||
if level0_deltas.len() < self.conf.compaction_threshold {
|
||||
return Ok(());
|
||||
}
|
||||
drop(layers);
|
||||
@@ -1770,7 +1788,7 @@ impl LayeredTimeline {
|
||||
layer_paths.pop().unwrap();
|
||||
}
|
||||
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
for l in new_layers {
|
||||
layers.insert_historic(Arc::new(l));
|
||||
}
|
||||
@@ -1852,7 +1870,7 @@ impl LayeredTimeline {
|
||||
// 2. it doesn't need to be retained for 'retain_lsns';
|
||||
// 3. newer on-disk image layers cover the layer's whole key range
|
||||
//
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
let mut layers = self.layers.write().unwrap();
|
||||
'outer: for l in layers.iter_historic_layers() {
|
||||
// This layer is in the process of being flushed to disk.
|
||||
// It will be swapped out of the layer map, replaced with
|
||||
|
||||
@@ -198,7 +198,6 @@ impl BlockWriter for BlockBuf {
|
||||
assert!(buf.len() == PAGE_SZ);
|
||||
let blknum = self.blocks.len();
|
||||
self.blocks.push(buf);
|
||||
tracing::info!("buffered block {}", blknum);
|
||||
Ok(blknum as u32)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,6 +713,26 @@ impl postgres_backend::Handler for PageServerHandler {
|
||||
Some(result.elapsed.as_millis().to_string().as_bytes()),
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("compact ") {
|
||||
// Run compaction immediately on given timeline.
|
||||
// FIXME This is just for tests. Don't expect this to be exposed to
|
||||
// the users or the api.
|
||||
|
||||
// compact <tenant_id> <timeline_id>
|
||||
let re = Regex::new(r"^compact ([[:xdigit:]]+)\s([[:xdigit:]]+)($|\s)?").unwrap();
|
||||
|
||||
let caps = re
|
||||
.captures(query_string)
|
||||
.with_context(|| format!("Invalid compact: '{}'", query_string))?;
|
||||
|
||||
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
|
||||
let timelineid = ZTimelineId::from_str(caps.get(2).unwrap().as_str())?;
|
||||
let timeline = tenant_mgr::get_timeline_for_tenant_load(tenantid, timelineid)
|
||||
.context("Couldn't load timeline")?;
|
||||
timeline.tline.compact()?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("checkpoint ") {
|
||||
// Run checkpoint immediately on given timeline.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Timeline synchrnonization logic to put files from archives on remote storage into pageserver's local directory.
|
||||
|
||||
use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc};
|
||||
use std::{collections::BTreeSet, path::PathBuf, sync::Arc};
|
||||
|
||||
use anyhow::{ensure, Context};
|
||||
use tokio::fs;
|
||||
@@ -64,11 +64,16 @@ pub(super) async fn download_timeline<
|
||||
let remote_timeline = match index_read.timeline_entry(&sync_id) {
|
||||
None => {
|
||||
error!("Cannot download: no timeline is present in the index for given id");
|
||||
drop(index_read);
|
||||
return DownloadedTimeline::Abort;
|
||||
}
|
||||
|
||||
Some(index_entry) => match index_entry.inner() {
|
||||
TimelineIndexEntryInner::Full(remote_timeline) => Cow::Borrowed(remote_timeline),
|
||||
TimelineIndexEntryInner::Full(remote_timeline) => {
|
||||
let cloned = remote_timeline.clone();
|
||||
drop(index_read);
|
||||
cloned
|
||||
}
|
||||
TimelineIndexEntryInner::Description(_) => {
|
||||
// we do not check here for awaits_download because it is ok
|
||||
// to call this function while the download is in progress
|
||||
@@ -84,7 +89,7 @@ pub(super) async fn download_timeline<
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(remote_timeline) => Cow::Owned(remote_timeline),
|
||||
Ok(remote_timeline) => remote_timeline,
|
||||
Err(e) => {
|
||||
error!("Failed to download full timeline index: {:?}", e);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Timeline synchronization logic to compress and upload to the remote storage all new timeline files from the checkpoints.
|
||||
|
||||
use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc};
|
||||
use std::{collections::BTreeSet, path::PathBuf, sync::Arc};
|
||||
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
@@ -46,13 +46,21 @@ pub(super) async fn upload_timeline_checkpoint<
|
||||
|
||||
let index_read = index.read().await;
|
||||
let remote_timeline = match index_read.timeline_entry(&sync_id) {
|
||||
None => None,
|
||||
None => {
|
||||
drop(index_read);
|
||||
None
|
||||
}
|
||||
Some(entry) => match entry.inner() {
|
||||
TimelineIndexEntryInner::Full(remote_timeline) => Some(Cow::Borrowed(remote_timeline)),
|
||||
TimelineIndexEntryInner::Full(remote_timeline) => {
|
||||
let r = Some(remote_timeline.clone());
|
||||
drop(index_read);
|
||||
r
|
||||
}
|
||||
TimelineIndexEntryInner::Description(_) => {
|
||||
drop(index_read);
|
||||
debug!("Found timeline description for the given ids, downloading the full index");
|
||||
match fetch_full_index(remote_assets.as_ref(), &timeline_dir, sync_id).await {
|
||||
Ok(remote_timeline) => Some(Cow::Owned(remote_timeline)),
|
||||
Ok(remote_timeline) => Some(remote_timeline),
|
||||
Err(e) => {
|
||||
error!("Failed to download full timeline index: {:?}", e);
|
||||
sync_queue::push(SyncTask::new(
|
||||
@@ -82,7 +90,6 @@ pub(super) async fn upload_timeline_checkpoint<
|
||||
let already_uploaded_files = remote_timeline
|
||||
.map(|timeline| timeline.stored_files(&timeline_dir))
|
||||
.unwrap_or_default();
|
||||
drop(index_read);
|
||||
|
||||
match try_upload_checkpoint(
|
||||
config,
|
||||
|
||||
@@ -252,8 +252,10 @@ pub trait Repository: Send + Sync {
        checkpoint_before_gc: bool,
    ) -> Result<GcResult>;

    /// perform one compaction iteration.
    /// this function is periodically called by compactor thread.
    /// Perform one compaction iteration.
    /// This function is periodically called by compactor thread.
    /// Also it can be explicitly requested per timeline through page server
    /// api's 'compact' command.
    fn compaction_iteration(&self) -> Result<()>;

    /// detaches locally available timeline by stopping all threads and removing all the data.
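Since an iteration can now also be requested explicitly, one way to trigger it is to issue the page service's 'compact' command over the pageserver's libpq endpoint; a hedged sketch with the connection details and ids left as placeholders:

```sh
psql "host=<pageserver-host> port=<pageserver-pg-port>" \
     -c "compact <tenant_id_hex> <timeline_id_hex>"
```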
@@ -5,12 +5,14 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
base64 = "0.13.0"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
clap = "3.0"
|
||||
fail = "0.5.0"
|
||||
futures = "0.3.13"
|
||||
hashbrown = "0.11.2"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.10.1"
|
||||
hyper = "0.14"
|
||||
lazy_static = "1.4.0"
|
||||
md5 = "0.7.0"
|
||||
@@ -18,12 +20,14 @@ parking_lot = "0.11.2"
|
||||
pin-project-lite = "0.2.7"
|
||||
rand = "0.8.3"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
routerify = "2"
|
||||
rustls = "0.19.1"
|
||||
scopeguard = "1.1.0"
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
sha2 = "0.9.8"
|
||||
socket2 = "0.4.4"
|
||||
thiserror = "1.0"
|
||||
thiserror = "1.0.30"
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
tokio-rustls = "0.22.0"
|
||||
@@ -33,5 +37,7 @@ zenith_metrics = { path = "../zenith_metrics" }
|
||||
workspace_hack = { version = "0.1", path = "../workspace_hack" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-postgres-rustls = "0.8.0"
|
||||
async-trait = "0.1"
|
||||
rcgen = "0.8.14"
|
||||
rstest = "0.12"
|
||||
tokio-postgres-rustls = "0.8.0"
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
mod credentials;
|
||||
|
||||
#[cfg(test)]
|
||||
mod flow;
|
||||
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::cplane_api::{self, CPlaneApi};
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use crate::waiters;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use flow::*;
|
||||
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
@@ -16,13 +26,17 @@ pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
Console(#[from] cplane_api::AuthError),
|
||||
|
||||
#[cfg(test)]
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
/// For passwords that couldn't be processed by [`parse_password`].
|
||||
#[error("Malformed password message")]
|
||||
MalformedPassword,
|
||||
|
||||
/// Errors produced by [`PqStream`].
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl AuthErrorImpl {
|
||||
@@ -67,70 +81,6 @@ impl UserFacingError for AuthError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet")]
|
||||
MissingKey(&'static str),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
fail::fail_point!("proxy-authenticate", |_| {
|
||||
Err(AuthError::auth_failed("failpoint triggered"))
|
||||
});
|
||||
|
||||
use crate::config::ClientAuthMethod::*;
|
||||
use crate::config::RouterConfig::*;
|
||||
match &config.router_config {
|
||||
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
|
||||
Dynamic(Mixed) => {
|
||||
if self.user.ends_with("@zenith") {
|
||||
handle_existing_user(config, client, self).await
|
||||
} else {
|
||||
handle_new_user(config, client).await
|
||||
}
|
||||
}
|
||||
Dynamic(Password) => handle_existing_user(config, client, self).await,
|
||||
Dynamic(Link) => handle_new_user(config, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
async fn handle_static(
|
||||
host: String,
|
||||
port: u16,
|
||||
@@ -169,7 +119,7 @@ async fn handle_existing_user(
|
||||
let md5_salt = rand::random();
|
||||
|
||||
client
|
||||
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
|
||||
.write_message(&Be::AuthenticationMD5Password(md5_salt))
|
||||
.await?;
|
||||
|
||||
// Read client's password hash
|
||||
@@ -213,6 +163,10 @@ async fn handle_new_user(
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
fn parse_password(bytes: &[u8]) -> Option<&str> {
|
||||
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
|
||||
}
|
||||
|
||||
70 proxy/src/auth/credentials.rs Normal file
@@ -0,0 +1,70 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use super::AuthError;
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet")]
|
||||
MissingKey(&'static str),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<DatabaseInfo, AuthError> {
|
||||
fail::fail_point!("proxy-authenticate", |_| {
|
||||
Err(AuthError::auth_failed("failpoint triggered"))
|
||||
});
|
||||
|
||||
use crate::config::ClientAuthMethod::*;
|
||||
use crate::config::RouterConfig::*;
|
||||
match &config.router_config {
|
||||
Static { host, port } => super::handle_static(host.clone(), *port, client, self).await,
|
||||
Dynamic(Mixed) => {
|
||||
if self.user.ends_with("@zenith") {
|
||||
super::handle_existing_user(config, client, self).await
|
||||
} else {
|
||||
super::handle_new_user(config, client).await
|
||||
}
|
||||
}
|
||||
Dynamic(Password) => super::handle_existing_user(config, client, self).await,
|
||||
Dynamic(Link) => super::handle_new_user(config, client).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
102 proxy/src/auth/flow.rs Normal file
@@ -0,0 +1,102 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthError, AuthErrorImpl};
|
||||
use crate::stream::PqStream;
|
||||
use crate::{sasl, scram};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
}
|
||||
}
|
||||
|
||||
/// Use password-based auth in [`AuthFlow`].
|
||||
pub struct Md5(
|
||||
/// Salt for client.
|
||||
pub [u8; 4],
|
||||
);
|
||||
|
||||
impl AuthMethod for Md5 {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationMD5Password(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
/// State might contain ancillary data (see [`AuthFlow::begin`]).
|
||||
state: State,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub fn new(stream: &'a mut PqStream<S>) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream.write_message(&method.first_message()).await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling simple MD5 password auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Md5> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
#[allow(unused)]
|
||||
pub async fn authenticate(self) -> Result<(), AuthError> {
|
||||
unimplemented!("MD5 auth flow is yet to be implemented");
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> Result<(), AuthError> {
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(AuthErrorImpl::auth_failed("method not supported").into());
|
||||
}
|
||||
|
||||
let secret = self.state.0;
|
||||
sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(secret, rand::random, None))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,8 @@
|
||||
///
|
||||
/// Postgres protocol proxy/router.
|
||||
///
|
||||
/// This service listens psql port and can check auth via external service
|
||||
/// (control plane API in our case) and can create new databases and accounts
|
||||
/// in somewhat transparent manner (again via communication with control plane API).
|
||||
///
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{Arg, Command};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
use crate::config::{ClientAuthMethod, RouterConfig};
|
||||
//! Postgres protocol proxy/router.
|
||||
//!
|
||||
//! This service listens psql port and can check auth via external service
|
||||
//! (control plane API in our case) and can create new databases and accounts
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
|
||||
mod auth;
|
||||
mod cancellation;
|
||||
@@ -27,6 +16,24 @@ mod proxy;
|
||||
mod stream;
|
||||
mod waiters;
|
||||
|
||||
// Currently SCRAM is only used in tests
|
||||
#[cfg(test)]
|
||||
mod parse;
|
||||
#[cfg(test)]
|
||||
mod sasl;
|
||||
#[cfg(test)]
|
||||
mod scram;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{Arg, Command};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
use crate::config::{ClientAuthMethod, RouterConfig};
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
async fn flatten_err(
|
||||
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
|
||||
|
||||
18 proxy/src/parse.rs Normal file
@@ -0,0 +1,18 @@
//! Small parsing helpers.

use std::convert::TryInto;
use std::ffi::CStr;

pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
    let pos = bytes.iter().position(|&x| x == 0)?;
    let (cstr, other) = bytes.split_at(pos + 1);
    // SAFETY: we've already checked that there's a terminator
    Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other))
}

pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
    (bytes.len() >= N).then(|| {
        let (head, tail) = bytes.split_at(N);
        (head.try_into().unwrap(), tail)
    })
}
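A minimal illustration of how these two helpers behave, written as a hypothetical test module next to them (the expected values follow directly from the definitions above):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_helpers() {
        // split_cstr cuts at the first NUL, keeping it inside the CStr half.
        let (cstr, rest) = split_cstr(b"user\0database\0").unwrap();
        assert_eq!(cstr.to_str().unwrap(), "user");
        assert_eq!(rest, b"database\0");

        // split_at_const peels off a fixed-size prefix when enough bytes exist.
        let (head, tail) = split_at_const::<4>(b"abcdef").unwrap();
        assert_eq!(head, b"abcd");
        assert_eq!(tail, b"ef");
        assert!(split_at_const::<8>(b"short").is_none());
    }
}
```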
@@ -119,7 +119,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
@@ -219,32 +218,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use tokio::io::DuplexStream;
|
||||
use crate::{auth, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tls: Option<TlsConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
// TODO: add some infra + tests for credentials
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("no stream")?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
@@ -262,19 +243,115 @@ mod tests {
|
||||
))
|
||||
}
|
||||
|
||||
struct ClientConfig<'a> {
|
||||
config: rustls::ClientConfig,
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
self,
|
||||
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate TLS certificates and build rustls configs for client and server.
|
||||
fn generate_tls_config(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(ClientConfig<'_>, Arc<rustls::ServerConfig>)> {
|
||||
let (ca, cert, key) = generate_certs(hostname)?;
|
||||
|
||||
let server_config = {
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config.into()
|
||||
};
|
||||
|
||||
let client_config = {
|
||||
let mut config = rustls::ClientConfig::new();
|
||||
config.root_store.add(&ca)?;
|
||||
ClientConfig { config, hostname }
|
||||
};
|
||||
|
||||
Ok((client_config, server_config))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
_stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct NoAuth;
|
||||
impl TestAuth for NoAuth {}
|
||||
|
||||
struct Scram(scram::ServerSecret);
|
||||
|
||||
impl Scram {
|
||||
fn new(password: &str) -> anyhow::Result<Self> {
|
||||
let salt = rand::random::<[u8; 16]>();
|
||||
let secret = scram::ServerSecret::build(password, &salt, 256)
|
||||
.context("failed to generate scram secret")?;
|
||||
Ok(Scram(secret))
|
||||
}
|
||||
|
||||
fn mock(user: &str) -> Self {
|
||||
let salt = rand::random::<[u8; 32]>();
|
||||
Scram(scram::ServerSecret::mock(user, &salt))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TestAuth for Scram {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A dummy proxy impl which performs a handshake and reports auth success.
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
tls: Option<TlsConfig>,
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("handshake failed")?;
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let server_config = {
|
||||
let (_ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
let (_, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
@@ -301,30 +378,14 @@ mod tests {
|
||||
async fn handshake_tls() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let server_config = {
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
|
||||
let client_config = {
|
||||
let mut config = rustls::ClientConfig::new();
|
||||
config.root_store.add(&ca)?;
|
||||
config
|
||||
};
|
||||
|
||||
let mut mk = MakeRustlsConnect::new(client_config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, tls)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
@@ -334,7 +395,7 @@ mod tests {
|
||||
async fn handshake_raw() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None));
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
@@ -350,7 +411,7 @@ mod tests {
|
||||
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None));
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.ssl_mode(SslMode::Disable)
|
||||
@@ -391,4 +452,66 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case("password_foo")]
|
||||
#[case("pwd-bar")]
|
||||
#[case("")]
|
||||
#[tokio::test]
|
||||
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::new(password)?,
|
||||
));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::mock("user"),
|
||||
));
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
let password: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(rand::random::<u8>() as usize)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, client_config.make_tls_connect()?)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
let _server_err = proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
47
proxy/src/sasl.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Simple Authentication and Security Layer.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
|
||||
|
||||
mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use channel_binding::ChannelBinding;
|
||||
pub use messages::FirstMessage;
|
||||
pub use stream::SaslStream;
|
||||
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Failed to authenticate client: {0}")]
|
||||
AuthenticationFailed(&'static str),
|
||||
|
||||
#[error("Channel binding failed: {0}")]
|
||||
ChannelBindingFailed(&'static str),
|
||||
|
||||
#[error("Unsupported channel binding method: {0}")]
|
||||
ChannelBindingBadMethod(Box<str>),
|
||||
|
||||
#[error("Bad client message")]
|
||||
BadClientMessage,
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
pub trait Mechanism: Sized {
    /// Produce a server challenge to be sent to the client.
    /// The method name mirrors the one PostgreSQL uses (`libpq/sasl.h`).
    fn exchange(self, input: &str) -> Result<(Option<Self>, String)>;
}
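
// A minimal sketch of how a caller can drive any `Mechanism` to completion;
// `SaslStream::authenticate` below does the same thing over a real connection.
// The `read_line`/`write_line` callbacks are hypothetical stand-ins for the
// transport and are not part of this crate.
#[allow(dead_code)]
fn drive_mechanism<M: Mechanism>(
    mut mechanism: M,
    mut read_line: impl FnMut() -> String,
    mut write_line: impl FnMut(&str),
) -> Result<()> {
    loop {
        let input = read_line();
        let (next, reply) = mechanism.exchange(&input)?;
        write_line(&reply);
        match next {
            // More steps are expected, keep going with the updated state.
            Some(next) => mechanism = next,
            // The final reply has been sent, the exchange is over.
            None => return Ok(()),
        }
    }
}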
|
||||
85
proxy/src/sasl/channel_binding.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Definition and parser for channel binding flag (a part of the `GS2` header).
|
||||
|
||||
/// Channel binding flag (possibly with params).
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ChannelBinding<T> {
|
||||
/// Client doesn't support channel binding.
|
||||
NotSupportedClient,
|
||||
/// Client thinks server doesn't support channel binding.
|
||||
NotSupportedServer,
|
||||
/// Client wants to use this type of channel binding.
|
||||
Required(T),
|
||||
}
|
||||
|
||||
impl<T> ChannelBinding<T> {
|
||||
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => NotSupportedClient,
|
||||
NotSupportedServer => NotSupportedServer,
|
||||
Required(x) => Required(f(x)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ChannelBinding<&'a str> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
use ChannelBinding::*;
|
||||
Some(match input {
|
||||
"n" => NotSupportedClient,
|
||||
"y" => NotSupportedServer,
|
||||
other => Required(other.strip_prefix("p=")?),
|
||||
})
|
||||
}
|
||||
}
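
// A small sketch of the three GS2 prefixes `parse` accepts, matching the
// flags a libpq client can send ("n", "y" or "p=<method>").
#[cfg(test)]
#[test]
fn channel_binding_parse_sketch() {
    use ChannelBinding::*;
    assert_eq!(ChannelBinding::parse("n"), Some(NotSupportedClient));
    assert_eq!(ChannelBinding::parse("y"), Some(NotSupportedServer));
    assert_eq!(
        ChannelBinding::parse("p=tls-server-end-point"),
        Some(Required("tls-server-end-point"))
    );
    // Anything else is a malformed flag.
    assert_eq!(ChannelBinding::parse("bogus"), None);
}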
|
||||
|
||||
impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
/// Encode channel binding data as base64 for subsequent checks.
|
||||
pub fn encode<E>(
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => {
|
||||
// base64::encode("n,,")
|
||||
"biws".into()
|
||||
}
|
||||
NotSupportedServer => {
|
||||
// base64::encode("y,,")
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
let msg = format!(
|
||||
"p={mode},,{data}",
|
||||
mode = mode,
|
||||
data = get_cbind_data(mode)?
|
||||
);
|
||||
base64::encode(msg).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn channel_binding_encode() -> anyhow::Result<()> {
|
||||
use ChannelBinding::*;
|
||||
|
||||
let cases = [
|
||||
(NotSupportedClient, base64::encode("n,,")),
|
||||
(NotSupportedServer, base64::encode("y,,")),
|
||||
(Required("foo"), base64::encode("p=foo,,bar")),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
67
proxy/src/sasl/messages.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Definitions for SASL messages.
|
||||
|
||||
use crate::parse::{split_at_const, split_cstr};
|
||||
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage).
|
||||
#[derive(Debug)]
|
||||
pub struct FirstMessage<'a> {
|
||||
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
|
||||
pub method: &'a str,
|
||||
/// Initial client message.
|
||||
pub message: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> FirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
|
||||
let (method_cstr, tail) = split_cstr(bytes)?;
|
||||
let method = method_cstr.to_str().ok()?;
|
||||
|
||||
let (len_bytes, bytes) = split_at_const(tail)?;
|
||||
let len = u32::from_be_bytes(*len_bytes) as usize;
|
||||
if len != bytes.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = std::str::from_utf8(bytes).ok()?;
|
||||
Some(Self { method, message })
|
||||
}
|
||||
}
|
||||
|
||||
/// A single SASL message.
|
||||
/// This struct is deliberately decoupled from lower-level
|
||||
/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage).
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ServerMessage<T> {
|
||||
/// We expect to see more steps.
|
||||
Continue(T),
|
||||
/// This is the final step.
|
||||
Final(T),
|
||||
}
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_sasl_first_message() {
|
||||
let proto = "SCRAM-SHA-256";
|
||||
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
|
||||
let sasl_len = (sasl.len() as u32).to_be_bytes();
|
||||
let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
|
||||
|
||||
let password = FirstMessage::parse(&bytes).unwrap();
|
||||
assert_eq!(password.method, proto);
|
||||
assert_eq!(password.message, sasl);
|
||||
}
|
||||
}
|
||||
70
proxy/src/sasl/stream.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
//! Abstraction for the string-oriented SASL protocols.
|
||||
|
||||
use super::{messages::ServerMessage, Mechanism};
|
||||
use crate::stream::PqStream;
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq protocol.
|
||||
pub struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let (moved, reply) = mechanism.exchange(input)?;
|
||||
match moved {
|
||||
Some(moved) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved;
|
||||
}
|
||||
None => {
|
||||
self.send(&ServerMessage::Final(&reply)).await?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
59
proxy/src/scram.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
//! Salted Challenge Response Authentication Mechanism.
|
||||
//!
|
||||
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
|
||||
//!
|
||||
//! Reference implementation:
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod exchange;
|
||||
mod key;
|
||||
mod messages;
|
||||
mod password;
|
||||
mod secret;
|
||||
mod signature;
|
||||
|
||||
pub use secret::*;
|
||||
|
||||
pub use exchange::Exchange;
|
||||
pub use secret::ServerSecret;
|
||||
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// TODO: add SCRAM-SHA-256-PLUS
|
||||
/// A list of supported SCRAM methods.
|
||||
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
|
||||
|
||||
/// Decode base64 into array without any heap allocations
|
||||
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||
let mut bytes = [0u8; N];
|
||||
|
||||
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
|
||||
if size != N {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(bytes)
|
||||
}
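
// A tiny sketch of the length check above: 4 base64 characters decode to
// exactly 3 bytes, so only the matching array size succeeds.
#[cfg(test)]
#[test]
fn base64_decode_array_sketch() {
    assert_eq!(base64_decode_array::<3>("AAAA"), Some([0u8; 3]));
    assert_eq!(base64_decode_array::<4>("AAAA"), None);
}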
|
||||
|
||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut mac = Hmac::<Sha256>::new_varkey(key).expect("bad key size");
|
||||
parts.into_iter().for_each(|s| mac.update(s));
|
||||
|
||||
// TODO: maybe newer `hmac` et al already migrated to regular arrays?
|
||||
let mut result = [0u8; 32];
|
||||
result.copy_from_slice(mac.finalize().into_bytes().as_slice());
|
||||
result
|
||||
}
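
// A quick sketch of why callers may pass the message in pieces: the MAC is
// updated incrementally, so hashing parts is equivalent to hashing their
// concatenation.
#[cfg(test)]
#[test]
fn hmac_sha256_parts_sketch() {
    let whole = hmac_sha256(b"key", [b"foobar".as_ref()]);
    let split = hmac_sha256(b"key", [b"foo".as_ref(), b"bar".as_ref()]);
    assert_eq!(whole, split);
}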
|
||||
|
||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
parts.into_iter().for_each(|s| hasher.update(s));
|
||||
|
||||
let mut result = [0u8; 32];
|
||||
result.copy_from_slice(hasher.finalize().as_slice());
|
||||
result
|
||||
}
|
||||
134
proxy/src/scram/exchange.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
//! Implementation of the SCRAM authentication algorithm.
|
||||
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
struct TlsServerEndPoint;
|
||||
|
||||
impl std::fmt::Display for TlsServerEndPoint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "tls-server-end-point")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TlsServerEndPoint {
|
||||
type Err = sasl::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"tls-server-end-point" => Ok(TlsServerEndPoint),
|
||||
_ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ExchangeState {
|
||||
/// Waiting for [`ClientFirstMessage`].
|
||||
Initial,
|
||||
/// Waiting for [`ClientFinalMessage`].
|
||||
SaltSent {
|
||||
cbind_flag: ChannelBinding<TlsServerEndPoint>,
|
||||
client_first_message_bare: String,
|
||||
server_first_message: OwnedServerFirstMessage,
|
||||
},
|
||||
}
|
||||
|
||||
/// Server's side of SCRAM auth algorithm.
|
||||
#[derive(Debug)]
|
||||
pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
cert_digest: Option<&'a [u8]>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial,
|
||||
secret,
|
||||
nonce,
|
||||
cert_digest,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl sasl::Mechanism for Exchange<'_> {
|
||||
fn exchange(mut self, input: &str) -> sasl::Result<(Option<Self>, String)> {
|
||||
use ExchangeState::*;
|
||||
match &self.state {
|
||||
Initial => {
|
||||
let client_first_message =
|
||||
ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&self.secret.salt_base64,
|
||||
self.secret.iterations,
|
||||
);
|
||||
let msg = server_first_message.as_str().to_owned();
|
||||
|
||||
self.state = SaltSent {
|
||||
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
|
||||
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||
server_first_message,
|
||||
};
|
||||
|
||||
Ok((Some(self), msg))
|
||||
}
|
||||
SaltSent {
|
||||
cbind_flag,
|
||||
client_first_message_bare,
|
||||
server_first_message,
|
||||
} => {
|
||||
let client_final_message =
|
||||
ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| {
|
||||
self.cert_digest
|
||||
.map(base64::encode)
|
||||
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
if client_final_message.channel_binding != channel_binding {
|
||||
return Err(SaslError::ChannelBindingFailed("data mismatch"));
|
||||
}
|
||||
|
||||
if client_final_message.nonce != server_first_message.nonce() {
|
||||
return Err(SaslError::AuthenticationFailed("bad nonce"));
|
||||
}
|
||||
|
||||
let signature_builder = SignatureBuilder {
|
||||
client_first_message_bare,
|
||||
server_first_message: server_first_message.as_str(),
|
||||
client_final_message_without_proof: client_final_message.without_proof,
|
||||
};
|
||||
|
||||
let client_key = signature_builder
|
||||
.build(&self.secret.stored_key)
|
||||
.derive_client_key(&client_final_message.proof);
|
||||
|
||||
if client_key.sha256() != self.secret.stored_key {
|
||||
return Err(SaslError::AuthenticationFailed("keys don't match"));
|
||||
}
|
||||
|
||||
let msg = client_final_message
|
||||
.build_server_final_message(signature_builder, &self.secret.server_key);
|
||||
|
||||
Ok((None, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
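
// A minimal sketch of wiring up an `Exchange` for a test: a deterministic
// nonce and no certificate digest (so only the "n"/"y" channel-binding flags
// can pass the final check). The secret itself would come from
// `ServerSecret::build` or `ServerSecret::mock`.
#[cfg(test)]
#[allow(dead_code)]
fn exchange_for_tests(secret: &ServerSecret) -> Exchange<'_> {
    fn fixed_nonce() -> [u8; SCRAM_RAW_NONCE_LEN] {
        [0; SCRAM_RAW_NONCE_LEN]
    }
    Exchange::new(secret, fixed_nonce, None)
}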
|
||||
33
proxy/src/scram/key.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Tools for client/server/stored key management.
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_KEY_LEN: usize = 32;
|
||||
|
||||
/// One of the keys derived from the [password](super::password::SaltedPassword).
|
||||
/// We use the same structure for all keys, i.e.
|
||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||
#[derive(Default, Debug, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
pub struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl ScramKey {
|
||||
pub fn sha256(&self) -> Self {
|
||||
super::sha256([self.as_ref()]).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for ScramKey {
|
||||
#[inline(always)]
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
232
proxy/src/scram/messages.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
//! Definitions for SCRAM messages.
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
use super::signature::SignatureBuilder;
|
||||
use crate::sasl::ChannelBinding;
|
||||
use std::fmt;
|
||||
use std::ops::Range;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
|
||||
|
||||
/// Although we ignore all extensions, we still have to validate the message.
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
    for mut chars in parts.map(|s| s.chars()) {
        let attr = chars.next()?;
        // Attribute names are restricted to ASCII letters (note the inclusive ranges).
        if !('a'..='z').contains(&attr) && !('A'..='Z').contains(&attr) {
            return None;
        }
        let eq = chars.next()?;
        if eq != '=' {
            return None;
        }
    }

    Some(())
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFirstMessage<'a> {
|
||||
/// `client-first-message-bare`.
|
||||
pub bare: &'a str,
|
||||
/// Channel binding mode.
|
||||
pub cbind_flag: ChannelBinding<&'a str>,
|
||||
/// [Client username](https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13).
|
||||
pub username: &'a str,
|
||||
/// Client nonce.
|
||||
pub nonce: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> ClientFirstMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let mut parts = input.split(',');
|
||||
|
||||
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
|
||||
|
||||
// PG doesn't support authorization identity,
|
||||
// so we don't bother defining GS2 header type
|
||||
let authzid = parts.next()?;
|
||||
if !authzid.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Unfortunately, `parts.as_str()` is unstable
|
||||
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
|
||||
let (_, bare) = input.split_at(pos);
|
||||
|
||||
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
|
||||
let username = parts.next()?.strip_prefix("n=")?;
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
Some(Self {
|
||||
bare,
|
||||
cbind_flag,
|
||||
username,
|
||||
nonce,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFirstMessage`].
|
||||
pub fn build_server_first_message(
|
||||
&self,
|
||||
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
|
||||
salt_base64: &str,
|
||||
iterations: u32,
|
||||
) -> OwnedServerFirstMessage {
|
||||
use std::fmt::Write;
|
||||
|
||||
let mut message = String::new();
|
||||
write!(&mut message, "r={}", self.nonce).unwrap();
|
||||
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
|
||||
let combined_nonce = 2..message.len();
|
||||
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
|
||||
|
||||
// This design guarantees that it's impossible to create a
|
||||
// server-first-message without receiving a client-first-message
|
||||
OwnedServerFirstMessage {
|
||||
message,
|
||||
nonce: combined_nonce,
|
||||
}
|
||||
}
|
||||
}
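
// A sketch of the wire format produced above, using the client-first-message
// from the tests below, an all-zero server nonce and a made-up salt.
#[cfg(test)]
#[test]
fn build_server_first_message_sketch() {
    let client = ClientFirstMessage::parse("n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju").unwrap();
    let msg = client.build_server_first_message(&[0; SCRAM_RAW_NONCE_LEN], "c2FsdA==", 4096);
    assert_eq!(
        msg.as_str(),
        "r=t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA,s=c2FsdA==,i=4096"
    );
    // The combined nonce is the client nonce followed by base64(server nonce).
    assert_eq!(msg.nonce(), "t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA");
}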
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientFinalMessage<'a> {
|
||||
/// `client-final-message-without-proof`.
|
||||
pub without_proof: &'a str,
|
||||
/// Channel binding data (base64).
|
||||
pub channel_binding: &'a str,
|
||||
/// Combined client & server nonce.
|
||||
pub nonce: &'a str,
|
||||
/// Client auth proof.
|
||||
pub proof: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl<'a> ClientFinalMessage<'a> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
let (without_proof, proof) = input.rsplit_once(',')?;
|
||||
|
||||
let mut parts = without_proof.split(',');
|
||||
let channel_binding = parts.next()?.strip_prefix("c=")?;
|
||||
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||
|
||||
// Validate but ignore auth extensions
|
||||
validate_sasl_extensions(parts)?;
|
||||
|
||||
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
|
||||
|
||||
Some(Self {
|
||||
without_proof,
|
||||
channel_binding,
|
||||
nonce,
|
||||
proof,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a response to [`ClientFinalMessage`].
|
||||
pub fn build_server_final_message(
|
||||
&self,
|
||||
signature_builder: SignatureBuilder,
|
||||
server_key: &ScramKey,
|
||||
) -> String {
|
||||
let mut buf = String::from("v=");
|
||||
base64::encode_config_buf(
|
||||
signature_builder.build(server_key),
|
||||
base64::STANDARD,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
/// We need to keep a convenient representation of this
|
||||
/// message for the next authentication step.
|
||||
pub struct OwnedServerFirstMessage {
|
||||
/// Owned `server-first-message`.
|
||||
message: String,
|
||||
/// Slice into `message`.
|
||||
nonce: Range<usize>,
|
||||
}
|
||||
|
||||
impl OwnedServerFirstMessage {
|
||||
/// Extract combined nonce from the message.
|
||||
#[inline(always)]
|
||||
pub fn nonce(&self) -> &str {
|
||||
&self.message[self.nonce.clone()]
|
||||
}
|
||||
|
||||
/// Get reference to a text representation of the message.
|
||||
#[inline(always)]
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for OwnedServerFirstMessage {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ServerFirstMessage")
|
||||
.field("message", &self.as_str())
|
||||
.field("nonce", &self.nonce())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message() {
|
||||
use ChannelBinding::*;
|
||||
|
||||
// (Almost) real strings captured during debug sessions
|
||||
let cases = [
|
||||
(NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||
(
|
||||
Required("tls-server-end-point"),
|
||||
"p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
|
||||
),
|
||||
];
|
||||
|
||||
for (cb, input) in cases {
|
||||
let msg = ClientFirstMessage::parse(input).unwrap();
|
||||
|
||||
assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.username, "pepe");
|
||||
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
|
||||
assert_eq!(msg.cbind_flag, cb);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_client_final_message() {
|
||||
let input = [
|
||||
"c=eSws",
|
||||
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
|
||||
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
|
||||
]
|
||||
.join(",");
|
||||
|
||||
let msg = ClientFinalMessage::parse(&input).unwrap();
|
||||
assert_eq!(
|
||||
msg.without_proof,
|
||||
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
msg.nonce,
|
||||
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(msg.proof),
|
||||
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
|
||||
);
|
||||
}
|
||||
}
|
||||
48
proxy/src/scram/password.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! Password hashing routines.
|
||||
|
||||
use super::key::ScramKey;
|
||||
|
||||
pub const SALTED_PASSWORD_LEN: usize = 32;
|
||||
|
||||
/// Salted hashed password is essential for [key](super::key) derivation.
|
||||
#[repr(transparent)]
|
||||
pub struct SaltedPassword {
|
||||
bytes: [u8; SALTED_PASSWORD_LEN],
|
||||
}
|
||||
|
||||
impl SaltedPassword {
|
||||
/// See `scram-common.c : scram_SaltedPassword` for details.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
|
||||
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
||||
let one = 1_u32.to_be_bytes(); // magic
|
||||
|
||||
let mut current = super::hmac_sha256(password, [salt, &one]);
|
||||
let mut result = current;
|
||||
for _ in 1..iterations {
|
||||
current = super::hmac_sha256(password, [current.as_ref()]);
|
||||
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
|
||||
for (i, x) in current.iter().enumerate() {
|
||||
result[i] ^= x;
|
||||
}
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
|
||||
/// Derive `ClientKey` from a salted hashed password.
|
||||
pub fn client_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
|
||||
}
|
||||
|
||||
/// Derive `ServerKey` from a salted hashed password.
|
||||
pub fn server_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
|
||||
}
|
||||
}
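
// A short sketch of how these keys feed into a stored verifier, mirroring
// what `ServerSecret::build` does in secret.rs: StoredKey = sha256(ClientKey),
// while ServerKey is kept as-is.
#[cfg(test)]
#[test]
fn salted_password_key_derivation_sketch() {
    let salted = SaltedPassword::new(b"password", b"salt", 4096);
    let stored_key = salted.client_key().sha256();
    let server_key = salted.server_key();
    // The two keys use different HMAC labels, so they come out different.
    assert_ne!(stored_key, server_key);
}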
|
||||
|
||||
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
116
proxy/src/scram/secret.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
//! Tools for SCRAM server secret management.
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
|
||||
/// Server secret is produced from [password](super::password::SaltedPassword)
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Debug)]
|
||||
pub struct ServerSecret {
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub iterations: u32,
|
||||
/// Salt used to hash user's password.
|
||||
pub salt_base64: String,
|
||||
/// Hashed `ClientKey`.
|
||||
pub stored_key: ScramKey,
|
||||
/// Used by client to verify server's signature.
|
||||
pub server_key: ScramKey,
|
||||
}
|
||||
|
||||
impl ServerSecret {
|
||||
pub fn parse(input: &str) -> Option<Self> {
|
||||
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
|
||||
let s = input.strip_prefix("SCRAM-SHA-256$")?;
|
||||
let (params, keys) = s.split_once('$')?;
|
||||
|
||||
let ((iterations, salt), (stored_key, server_key)) =
|
||||
params.split_once(':').zip(keys.split_once(':'))?;
|
||||
|
||||
let secret = ServerSecret {
|
||||
iterations: iterations.parse().ok()?,
|
||||
salt_base64: salt.to_owned(),
|
||||
stored_key: base64_decode_array(stored_key)?.into(),
|
||||
server_key: base64_decode_array(server_key)?.into(),
|
||||
};
|
||||
|
||||
Some(secret)
|
||||
}
|
||||
|
||||
/// To avoid revealing information to an attacker, we use a
|
||||
/// mocked server secret even if the user doesn't exist.
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
|
||||
// Refer to `auth-scram.c : scram_mock_salt`.
|
||||
let mocked_salt = super::sha256([user.as_bytes(), nonce]);
|
||||
|
||||
Self {
|
||||
iterations: 4096,
|
||||
salt_base64: base64::encode(&mocked_salt),
|
||||
stored_key: ScramKey::default(),
|
||||
server_key: ScramKey::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a new server secret from the prerequisites.
|
||||
/// XXX: We only use this function in tests.
|
||||
#[cfg(test)]
|
||||
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
|
||||
// TODO: implement proper password normalization required by the RFC
|
||||
if !password.is_ascii() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
|
||||
|
||||
Some(Self {
|
||||
iterations,
|
||||
salt_base64: base64::encode(&salt),
|
||||
stored_key: password.client_key().sha256(),
|
||||
server_key: password.server_key(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_scram_secret() {
|
||||
let iterations = 4096;
|
||||
let salt = "+/tQQax7twvwTj64mjBsxQ==";
|
||||
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
|
||||
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
|
||||
|
||||
let secret = format!(
|
||||
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
|
||||
iterations = iterations,
|
||||
salt = salt,
|
||||
stored_key = stored_key,
|
||||
server_key = server_key,
|
||||
);
|
||||
|
||||
let parsed = ServerSecret::parse(&secret).unwrap();
|
||||
assert_eq!(parsed.iterations, iterations);
|
||||
assert_eq!(parsed.salt_base64, salt);
|
||||
|
||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_scram_secret() {
|
||||
let salt = b"salt";
|
||||
let secret = ServerSecret::build("password", salt, 4096).unwrap();
|
||||
assert_eq!(secret.iterations, 4096);
|
||||
assert_eq!(secret.salt_base64, base64::encode(salt));
|
||||
assert_eq!(
|
||||
base64::encode(secret.stored_key.as_ref()),
|
||||
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(secret.server_key.as_ref()),
|
||||
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
|
||||
);
|
||||
}
|
||||
}
|
||||
66
proxy/src/scram/signature.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Tools for client/server signature management.
|
||||
|
||||
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||
|
||||
/// A collection of message parts needed to derive the client's signature.
|
||||
#[derive(Debug)]
|
||||
pub struct SignatureBuilder<'a> {
|
||||
pub client_first_message_bare: &'a str,
|
||||
pub server_first_message: &'a str,
|
||||
pub client_final_message_without_proof: &'a str,
|
||||
}
|
||||
|
||||
impl SignatureBuilder<'_> {
|
||||
pub fn build(&self, key: &ScramKey) -> Signature {
|
||||
let parts = [
|
||||
self.client_first_message_bare.as_bytes(),
|
||||
b",",
|
||||
self.server_first_message.as_bytes(),
|
||||
b",",
|
||||
self.client_final_message_without_proof.as_bytes(),
|
||||
];
|
||||
|
||||
super::hmac_sha256(key.as_ref(), parts).into()
|
||||
}
|
||||
}
|
||||
|
||||
/// A computed value which, when xored with `ClientProof`,
|
||||
/// produces `ClientKey` that we need for authentication.
|
||||
#[derive(Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct Signature {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl Signature {
|
||||
/// Derive `ClientKey` from client's signature and proof.
|
||||
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
|
||||
// This is how the proof is calculated:
|
||||
//
|
||||
// 1. sha256(ClientKey) -> StoredKey
|
||||
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
|
||||
// 3. ClientKey ^ ClientSignature -> ClientProof
|
||||
//
|
||||
// Step 3 implies that we can restore ClientKey from the proof
|
||||
// by xoring the latter with the ClientSignature. Afterwards we
|
||||
// can check that the presumed ClientKey meets our expectations.
|
||||
let mut signature = self.bytes;
|
||||
for (i, x) in proof.iter().enumerate() {
|
||||
signature[i] ^= x;
|
||||
}
|
||||
|
||||
signature.into()
|
||||
}
|
||||
}
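
// A sketch of the recovery property described above: if the client computes
// proof = ClientKey ^ ClientSignature, then signature ^ proof yields ClientKey
// back. The byte values here are arbitrary.
#[cfg(test)]
#[test]
fn derive_client_key_roundtrip_sketch() {
    let client_key = [0x11u8; SCRAM_KEY_LEN];
    let signature = Signature::from([0x5au8; SCRAM_KEY_LEN]);

    let mut proof = client_key;
    for (p, s) in proof.iter_mut().zip(signature.as_ref()) {
        *p ^= s;
    }

    assert_eq!(signature.derive_client_key(&proof), ScramKey::from(client_key));
}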
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
|
||||
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Signature {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
&self.bytes
|
||||
}
|
||||
}
|
||||
@@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv):
|
||||
pg2 = env.postgres.create_start('test_createuser2')
|
||||
|
||||
# Test that you can connect to new branch as a new user
|
||||
assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )]
|
||||
assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )]
|
||||
|
||||
@@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str)
|
||||
copy_input = repeat_bytes(buf.read(), 5000)
|
||||
|
||||
pg_conn = await pg.connect_async()
|
||||
|
||||
# PgProtocol.connect_async sets statement_timeout to 2 minutes.
|
||||
# That's not enough for this test, on a slow system in debug mode.
|
||||
await pg_conn.execute("SET statement_timeout='300s'")
|
||||
|
||||
await pg_conn.copy_to_table(table_name, source=copy_input)
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
|
||||
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli.create_branch("test_pgbench", "empty")
|
||||
pg = env.postgres.create_start('test_pgbench')
|
||||
log.info("postgres is running on 'test_pgbench' branch")
|
||||
|
||||
connstr = pg.connstr()
|
||||
|
||||
pg_bin.run_capture(['pgbench', '-i', connstr])
|
||||
pg_bin.run_capture(['pgbench'] + '-c 10 -T 5 -P 1 -M prepared'.split() + [connstr])
|
||||
@@ -5,11 +5,14 @@ def test_proxy_select_1(static_proxy):
|
||||
static_proxy.safe_psql("select 1;")
|
||||
|
||||
|
||||
@pytest.mark.xfail # Proxy eats the extra connection options
|
||||
# Pass extra options to the server.
|
||||
#
|
||||
# Currently, proxy eats the extra connection options, so this fails.
|
||||
# See https://github.com/neondatabase/neon/issues/1287
|
||||
@pytest.mark.xfail
|
||||
def test_proxy_options(static_proxy):
|
||||
schema_name = "tmp_schema_1"
|
||||
with static_proxy.connect(schema=schema_name) as conn:
|
||||
with static_proxy.connect(options="-cproxytest.option=value") as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW search_path;")
|
||||
search_path = cur.fetchall()[0][0]
|
||||
assert schema_name == search_path
|
||||
cur.execute("SHOW proxytest.option;")
|
||||
value = cur.fetchall()[0][0]
|
||||
assert value == 'value'
|
||||
|
||||
@@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol):
|
||||
tenant_id: uuid.UUID,
|
||||
listen_addr: str,
|
||||
port: int):
|
||||
super().__init__(host=listen_addr, port=port, username='zenith_admin')
|
||||
super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres')
|
||||
|
||||
self.pgdata_dir: str = pgdata_dir
|
||||
self.pg_bin: PgBin = pg_bin
|
||||
|
||||
@@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys
|
||||
]
|
||||
|
||||
env_vars = {
|
||||
'PGPORT': str(pg.port),
|
||||
'PGUSER': pg.username,
|
||||
'PGHOST': pg.host,
|
||||
'PGPORT': str(pg.default_options['port']),
|
||||
'PGUSER': pg.default_options['user'],
|
||||
'PGHOST': pg.default_options['host'],
|
||||
}
|
||||
|
||||
# Run the command.
|
||||
|
||||
@@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin,
|
||||
]
|
||||
|
||||
env_vars = {
|
||||
'PGPORT': str(pg.port),
|
||||
'PGUSER': pg.username,
|
||||
'PGHOST': pg.host,
|
||||
'PGPORT': str(pg.default_options['port']),
|
||||
'PGUSER': pg.default_options['user'],
|
||||
'PGHOST': pg.default_options['host'],
|
||||
}
|
||||
|
||||
# Run the command.
|
||||
|
||||
@@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c
|
||||
|
||||
log.info(pg_regress_command)
|
||||
env_vars = {
|
||||
'PGPORT': str(pg.port),
|
||||
'PGUSER': pg.username,
|
||||
'PGHOST': pg.host,
|
||||
'PGPORT': str(pg.default_options['port']),
|
||||
'PGUSER': pg.default_options['user'],
|
||||
'PGHOST': pg.default_options['host'],
|
||||
}
|
||||
|
||||
# Run the command.
|
||||
|
||||
@@ -17,7 +17,7 @@ import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
# Type-related stuff
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Optional
|
||||
"""
|
||||
This file contains fixtures for micro-benchmarks.
|
||||
|
||||
@@ -51,17 +51,12 @@ in the test initialization, or measure disk usage after the test query.
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PgBenchRunResult:
|
||||
scale: int
|
||||
number_of_clients: int
|
||||
number_of_threads: int
|
||||
number_of_transactions_actually_processed: int
|
||||
latency_average: float
|
||||
latency_stddev: float
|
||||
tps_including_connection_time: float
|
||||
tps_excluding_connection_time: float
|
||||
init_duration: float
|
||||
init_start_timestamp: int
|
||||
init_end_timestamp: int
|
||||
latency_stddev: Optional[float]
|
||||
tps: float
|
||||
run_duration: float
|
||||
run_start_timestamp: int
|
||||
run_end_timestamp: int
|
||||
@@ -69,56 +64,67 @@ class PgBenchRunResult:
|
||||
# TODO progress
|
||||
|
||||
@classmethod
|
||||
def parse_from_output(
|
||||
def parse_from_stdout(
|
||||
cls,
|
||||
out: 'subprocess.CompletedProcess[str]',
|
||||
init_duration: float,
|
||||
init_start_timestamp: int,
|
||||
init_end_timestamp: int,
|
||||
stdout: str,
|
||||
run_duration: float,
|
||||
run_start_timestamp: int,
|
||||
run_end_timestamp: int,
|
||||
):
|
||||
stdout_lines = out.stdout.splitlines()
|
||||
stdout_lines = stdout.splitlines()
|
||||
|
||||
latency_stddev = None
|
||||
|
||||
# we know significant parts of these values from test input
|
||||
# but to be precise take them from output
|
||||
# scaling factor: 5
|
||||
assert "scaling factor" in stdout_lines[1]
|
||||
scale = int(stdout_lines[1].split()[-1])
|
||||
# number of clients: 1
|
||||
assert "number of clients" in stdout_lines[3]
|
||||
number_of_clients = int(stdout_lines[3].split()[-1])
|
||||
# number of threads: 1
|
||||
assert "number of threads" in stdout_lines[4]
|
||||
number_of_threads = int(stdout_lines[4].split()[-1])
|
||||
# number of transactions actually processed: 1000/1000
|
||||
assert "number of transactions actually processed" in stdout_lines[6]
|
||||
number_of_transactions_actually_processed = int(stdout_lines[6].split("/")[1])
|
||||
# latency average = 19.894 ms
|
||||
assert "latency average" in stdout_lines[7]
|
||||
latency_average = stdout_lines[7].split()[-2]
|
||||
# latency stddev = 3.387 ms
|
||||
assert "latency stddev" in stdout_lines[8]
|
||||
latency_stddev = stdout_lines[8].split()[-2]
|
||||
# tps = 50.219689 (including connections establishing)
|
||||
assert "(including connections establishing)" in stdout_lines[9]
|
||||
tps_including_connection_time = stdout_lines[9].split()[2]
|
||||
# tps = 50.264435 (excluding connections establishing)
|
||||
assert "(excluding connections establishing)" in stdout_lines[10]
|
||||
tps_excluding_connection_time = stdout_lines[10].split()[2]
|
||||
for line in stdout.splitlines():
|
||||
# scaling factor: 5
|
||||
if line.startswith("scaling factor:"):
|
||||
scale = int(line.split()[-1])
|
||||
# number of clients: 1
|
||||
if line.startswith("number of clients: "):
|
||||
number_of_clients = int(line.split()[-1])
|
||||
# number of threads: 1
|
||||
if line.startswith("number of threads: "):
|
||||
number_of_threads = int(line.split()[-1])
|
||||
# number of transactions actually processed: 1000/1000
|
||||
# OR
|
||||
# number of transactions actually processed: 1000
|
||||
if line.startswith("number of transactions actually processed"):
|
||||
if "/" in line:
|
||||
number_of_transactions_actually_processed = int(line.split("/")[1])
|
||||
else:
|
||||
number_of_transactions_actually_processed = int(line.split()[-1])
|
||||
# latency average = 19.894 ms
|
||||
if line.startswith("latency average"):
|
||||
latency_average = float(line.split()[-2])
|
||||
# latency stddev = 3.387 ms
|
||||
# (only printed with some options)
|
||||
if line.startswith("latency stddev"):
|
||||
latency_stddev = float(line.split()[-2])
|
||||
|
||||
# Get the TPS without initial connection time. The format
|
||||
# of the tps lines changed in pgbench v14, but we accept
|
||||
# either format:
|
||||
#
|
||||
# pgbench v13 and below:
|
||||
# tps = 50.219689 (including connections establishing)
|
||||
# tps = 50.264435 (excluding connections establishing)
|
||||
#
|
||||
# pgbench v14:
|
||||
# initial connection time = 3.858 ms
|
||||
# tps = 309.281539 (without initial connection time)
|
||||
if (line.startswith("tps = ") and ("(excluding connections establishing)" in line
                                   or "(without initial connection time)" in line)):
|
||||
tps = float(line.split()[2])
|
||||
|
||||
return cls(
|
||||
scale=scale,
|
||||
number_of_clients=number_of_clients,
|
||||
number_of_threads=number_of_threads,
|
||||
number_of_transactions_actually_processed=number_of_transactions_actually_processed,
|
||||
latency_average=float(latency_average),
|
||||
latency_stddev=float(latency_stddev),
|
||||
tps_including_connection_time=float(tps_including_connection_time),
|
||||
tps_excluding_connection_time=float(tps_excluding_connection_time),
|
||||
init_duration=init_duration,
|
||||
init_start_timestamp=init_start_timestamp,
|
||||
init_end_timestamp=init_end_timestamp,
|
||||
latency_average=latency_average,
|
||||
latency_stddev=latency_stddev,
|
||||
tps=tps,
|
||||
run_duration=run_duration,
|
||||
run_start_timestamp=run_start_timestamp,
|
||||
run_end_timestamp=run_end_timestamp,
|
||||
@@ -187,60 +193,41 @@ class ZenithBenchmarker:
|
||||
report=MetricReport.LOWER_IS_BETTER,
|
||||
)
|
||||
|
||||
def record_pg_bench_result(self, pg_bench_result: PgBenchRunResult):
|
||||
self.record("scale", pg_bench_result.scale, '', MetricReport.TEST_PARAM)
|
||||
self.record("number_of_clients",
|
||||
def record_pg_bench_result(self, prefix: str, pg_bench_result: PgBenchRunResult):
|
||||
self.record(f"{prefix}.number_of_clients",
|
||||
pg_bench_result.number_of_clients,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
self.record("number_of_threads",
|
||||
self.record(f"{prefix}.number_of_threads",
|
||||
pg_bench_result.number_of_threads,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
self.record(
|
||||
"number_of_transactions_actually_processed",
|
||||
f"{prefix}.number_of_transactions_actually_processed",
|
||||
pg_bench_result.number_of_transactions_actually_processed,
|
||||
'',
|
||||
# that's because this is predefined by the test matrix and doesn't change across runs
|
||||
report=MetricReport.TEST_PARAM,
|
||||
)
|
||||
self.record("latency_average",
|
||||
self.record(f"{prefix}.latency_average",
|
||||
pg_bench_result.latency_average,
|
||||
unit="ms",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.record("latency_stddev",
|
||||
pg_bench_result.latency_stddev,
|
||||
unit="ms",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.record("tps_including_connection_time",
|
||||
pg_bench_result.tps_including_connection_time,
|
||||
'',
|
||||
report=MetricReport.HIGHER_IS_BETTER)
|
||||
self.record("tps_excluding_connection_time",
|
||||
pg_bench_result.tps_excluding_connection_time,
|
||||
'',
|
||||
report=MetricReport.HIGHER_IS_BETTER)
|
||||
self.record("init_duration",
|
||||
pg_bench_result.init_duration,
|
||||
unit="s",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.record("init_start_timestamp",
|
||||
pg_bench_result.init_start_timestamp,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
self.record("init_end_timestamp",
|
||||
pg_bench_result.init_end_timestamp,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
self.record("run_duration",
|
||||
if pg_bench_result.latency_stddev is not None:
|
||||
self.record(f"{prefix}.latency_stddev",
|
||||
pg_bench_result.latency_stddev,
|
||||
unit="ms",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.record(f"{prefix}.tps", pg_bench_result.tps, '', report=MetricReport.HIGHER_IS_BETTER)
|
||||
self.record(f"{prefix}.run_duration",
|
||||
pg_bench_result.run_duration,
|
||||
unit="s",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.record("run_start_timestamp",
|
||||
self.record(f"{prefix}.run_start_timestamp",
|
||||
pg_bench_result.run_start_timestamp,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
self.record("run_end_timestamp",
|
||||
self.record(f"{prefix}.run_end_timestamp",
|
||||
pg_bench_result.run_end_timestamp,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
@@ -259,10 +246,18 @@ class ZenithBenchmarker:
|
||||
"""
|
||||
Fetch the "cumulative # of bytes written" metric from the pageserver
|
||||
"""
|
||||
# Fetch all the exposed prometheus metrics from page server
|
||||
all_metrics = pageserver.http_client().get_metrics()
|
||||
# Use a regular expression to extract the one we're interested in
|
||||
#
|
||||
metric_name = r'pageserver_disk_io_bytes{io_operation="write"}'
|
||||
return self.get_int_counter_value(pageserver, metric_name)
|
||||
|
||||
def get_peak_mem(self, pageserver) -> int:
|
||||
"""
|
||||
Fetch the "maxrss" metric from the pageserver
|
||||
"""
|
||||
metric_name = r'pageserver_maxrss_kb'
|
||||
return self.get_int_counter_value(pageserver, metric_name)
|
||||
|
||||
def get_int_counter_value(self, pageserver, metric_name) -> int:
|
||||
"""Fetch the value of given int counter from pageserver metrics."""
|
||||
# TODO: If we start to collect more of the prometheus metrics in the
|
||||
# performance test suite like this, we should refactor this to load and
|
||||
# parse all the metrics into a more convenient structure in one go.
|
||||
@@ -270,20 +265,8 @@ class ZenithBenchmarker:
|
||||
# The metric should be an integer, as it's a number of bytes. But in general
|
||||
# all prometheus metrics are floats. So to be pedantic, read it as a float
|
||||
# and round to integer.
|
||||
matches = re.search(r'^pageserver_disk_io_bytes{io_operation="write"} (\S+)$',
|
||||
all_metrics,
|
||||
re.MULTILINE)
|
||||
assert matches
|
||||
return int(round(float(matches.group(1))))
|
||||
|
||||
def get_peak_mem(self, pageserver) -> int:
|
||||
"""
|
||||
Fetch the "maxrss" metric from the pageserver
|
||||
"""
|
||||
# Fetch all the exposed prometheus metrics from page server
|
||||
all_metrics = pageserver.http_client().get_metrics()
|
||||
# See comment in get_io_writes()
|
||||
matches = re.search(r'^pageserver_maxrss_kb (\S+)$', all_metrics, re.MULTILINE)
|
||||
matches = re.search(fr'^{metric_name} (\S+)$', all_metrics, re.MULTILINE)
|
||||
assert matches
|
||||
return int(round(float(matches.group(1))))
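# A minimal illustration (with a hypothetical metrics payload) of the
# prometheus text format the regex above extracts a counter from:
#
#   sample = 'pageserver_maxrss_kb 123456.0\n'
#   m = re.search(r'^pageserver_maxrss_kb (\S+)$', sample, re.MULTILINE)
#   int(round(float(m.group(1))))  # -> 123456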
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
from contextlib import contextmanager
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv
|
||||
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, RemotePostgres, ZenithEnv
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
|
||||
# Type-related stuff
|
||||
@@ -87,6 +87,9 @@ class ZenithCompare(PgCompare):
|
||||
def flush(self):
|
||||
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
|
||||
|
||||
def compact(self):
|
||||
self.pscur.execute(f"compact {self.env.initial_tenant.hex} {self.timeline}")
|
||||
|
||||
def report_peak_memory_use(self) -> None:
|
||||
self.zenbenchmark.record("peak_mem",
|
||||
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
|
||||
@@ -102,6 +105,19 @@ class ZenithCompare(PgCompare):
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
total_files = self.zenbenchmark.get_int_counter_value(
|
||||
self.env.pageserver, "pageserver_num_persistent_files_created")
|
||||
total_bytes = self.zenbenchmark.get_int_counter_value(
|
||||
self.env.pageserver, "pageserver_persistent_bytes_written")
|
||||
self.zenbenchmark.record("data_uploaded",
|
||||
total_bytes / (1024 * 1024),
|
||||
"MB",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
self.zenbenchmark.record("num_files_uploaded",
|
||||
total_files,
|
||||
"",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
def record_pageserver_writes(self, out_name):
|
||||
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
|
||||
|
||||
@@ -159,6 +175,48 @@ class VanillaCompare(PgCompare):
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
class RemoteCompare(PgCompare):
|
||||
"""PgCompare interface for a remote postgres instance."""
|
||||
def __init__(self, zenbenchmark, remote_pg: RemotePostgres):
|
||||
self._pg = remote_pg
|
||||
self._zenbenchmark = zenbenchmark
|
||||
|
||||
# Long-lived cursor, useful for flushing
|
||||
self.conn = self.pg.connect()
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
def pg(self):
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
def zenbenchmark(self):
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
def pg_bin(self):
|
||||
return self._pg.pg_bin
|
||||
|
||||
def flush(self):
|
||||
# TODO: flush the remote pageserver
|
||||
pass
|
||||
|
||||
def report_peak_memory_use(self) -> None:
|
||||
# TODO: get memory usage from remote pageserver
|
||||
pass
|
||||
|
||||
def report_size(self) -> None:
|
||||
# TODO: get storage size from remote pageserver
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def record_pageserver_writes(self, out_name):
|
||||
yield # Do nothing
|
||||
|
||||
def record_duration(self, out_name):
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare:
|
||||
branch_name = request.node.name
|
||||
@@ -170,6 +228,11 @@ def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
|
||||
return VanillaCompare(zenbenchmark, vanilla_pg)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare:
|
||||
return RemoteCompare(zenbenchmark, remote_pg)
|
||||
|
||||
|
||||
@pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"])
|
||||
def zenith_with_baseline(request) -> PgCompare:
|
||||
"""Parameterized fixture that helps compare zenith against vanilla postgres.
|
||||
|
||||
@@ -27,6 +27,7 @@ from dataclasses import dataclass
|
||||
|
||||
# Type-related stuff
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
|
||||
from typing_extensions import Literal
|
||||
|
||||
@@ -122,6 +123,22 @@ def pytest_configure(config):
|
||||
top_output_dir = os.path.join(base_dir, DEFAULT_OUTPUT_DIR)
|
||||
mkdir_if_needed(top_output_dir)
|
||||
|
||||
# Find the postgres installation.
|
||||
global pg_distrib_dir
|
||||
env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR')
|
||||
if env_postgres_bin:
|
||||
pg_distrib_dir = env_postgres_bin
|
||||
else:
|
||||
pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR))
|
||||
log.info(f'pg_distrib_dir is {pg_distrib_dir}')
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
# When testing against a remote server, we only need the client binary.
|
||||
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/psql')):
|
||||
raise Exception('psql not found at "{}"'.format(pg_distrib_dir))
|
||||
else:
|
||||
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')):
|
||||
raise Exception('postgres not found at "{}"'.format(pg_distrib_dir))
|
||||
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
# we are in a remote env and do not have the zenith binaries locally
|
||||
# this is the case for benchmarks run on self-hosted runner
|
||||
@@ -137,17 +154,6 @@ def pytest_configure(config):
|
||||
if not os.path.exists(os.path.join(zenith_binpath, 'pageserver')):
|
||||
raise Exception('zenith binaries not found at "{}"'.format(zenith_binpath))
|
||||
|
||||
# Find the postgres installation.
|
||||
global pg_distrib_dir
|
||||
env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR')
|
||||
if env_postgres_bin:
|
||||
pg_distrib_dir = env_postgres_bin
|
||||
else:
|
||||
pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR))
|
||||
log.info(f'pg_distrib_dir is {pg_distrib_dir}')
|
||||
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')):
|
||||
raise Exception('postgres not found at "{}"'.format(pg_distrib_dir))
|
||||
|
||||
|
||||
def zenfixture(func: Fn) -> Fn:
|
||||
"""
|
||||
@@ -238,98 +244,69 @@ def port_distributor(worker_base_port):
|
||||
|
||||
class PgProtocol:
|
||||
""" Reusable connection logic """
|
||||
def __init__(self,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
dbname: Optional[str] = None,
|
||||
schema: Optional[str] = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.dbname = dbname
|
||||
self.schema = schema
|
||||
def __init__(self, **kwargs):
|
||||
self.default_options = kwargs
|
||||
|
||||
def connstr(self,
|
||||
*,
|
||||
dbname: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
statement_timeout_ms: Optional[int] = None) -> str:
|
||||
def connstr(self, **kwargs) -> str:
|
||||
"""
|
||||
Build a libpq connection string for the Postgres instance.
|
||||
"""
|
||||
return str(make_dsn(**self.conn_options(**kwargs)))
|
||||
|
||||
username = username or self.username
|
||||
password = password or self.password
|
||||
dbname = dbname or self.dbname or "postgres"
|
||||
schema = schema or self.schema
|
||||
res = f'host={self.host} port={self.port} dbname={dbname}'
|
||||
def conn_options(self, **kwargs):
|
||||
conn_options = self.default_options.copy()
|
||||
if 'dsn' in kwargs:
|
||||
conn_options.update(parse_dsn(kwargs['dsn']))
|
||||
conn_options.update(kwargs)
|
||||
|
||||
if username:
|
||||
res = f'{res} user={username}'
|
||||
|
||||
if password:
|
||||
res = f'{res} password={password}'
|
||||
|
||||
if schema:
|
||||
res = f"{res} options='-c search_path={schema}'"
|
||||
|
||||
if statement_timeout_ms:
|
||||
res = f"{res} options='-c statement_timeout={statement_timeout_ms}'"
|
||||
|
||||
return res
|
||||
# Individual statement timeout in seconds. 2 minutes should be
|
||||
# enough for our tests, but if you need a longer one, you can
|
||||
# change it by calling "SET statement_timeout" after
|
||||
# connecting.
|
||||
if 'options' in conn_options:
|
||||
conn_options['options'] = "-cstatement_timeout=120s " + conn_options['options']
|
||||
else:
|
||||
conn_options['options'] = "-cstatement_timeout=120s"
|
||||
return conn_options
|
||||
|
||||
# autocommit=True here by default because that's what we need most of the time
|
||||
def connect(
|
||||
self,
|
||||
*,
|
||||
autocommit=True,
|
||||
dbname: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
# individual statement timeout in seconds, 2 minutes should be enough for our tests
|
||||
statement_timeout: Optional[int] = 120
|
||||
) -> PgConnection:
|
||||
def connect(self, autocommit=True, **kwargs) -> PgConnection:
|
||||
"""
|
||||
Connect to the node.
|
||||
Returns psycopg2's connection object.
|
||||
This method passes all extra params to connstr.
|
||||
"""
|
||||
conn = psycopg2.connect(**self.conn_options(**kwargs))
|
||||
|
||||
conn = psycopg2.connect(
|
||||
self.connstr(dbname=dbname,
|
||||
schema=schema,
|
||||
username=username,
|
||||
password=password,
|
||||
statement_timeout_ms=statement_timeout *
|
||||
1000 if statement_timeout else None))
|
||||
# WARNING: this setting affects *all* tests!
|
||||
conn.autocommit = autocommit
|
||||
return conn
|
||||
|
||||
async def connect_async(self,
|
||||
*,
|
||||
dbname: str = 'postgres',
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None) -> asyncpg.Connection:
|
||||
async def connect_async(self, **kwargs) -> asyncpg.Connection:
|
||||
"""
|
||||
Connect to the node from async python.
|
||||
Returns asyncpg's connection object.
|
||||
"""
|
||||
|
||||
conn = await asyncpg.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
database=dbname,
|
||||
user=username or self.username,
|
||||
password=password,
|
||||
)
|
||||
return conn
|
||||
# asyncpg takes slightly different options than psycopg2. Try
|
||||
# to convert the defaults from the psycopg2 format.
|
||||
|
||||
# The psycopg2 option 'dbname' is called 'database' in asyncpg
|
||||
conn_options = self.conn_options(**kwargs)
|
||||
if 'dbname' in conn_options:
|
||||
conn_options['database'] = conn_options.pop('dbname')
|
||||
|
||||
# Convert options='-c<key>=<val>' to server_settings
|
||||
if 'options' in conn_options:
|
||||
options = conn_options.pop('options')
|
||||
for match in re.finditer(r'-c(\w*)=(\w*)', options):
|
||||
key = match.group(1)
|
||||
val = match.group(2)
|
||||
if 'server_settings' in conn_options:
|
||||
conn_options['server_settings'].update({key: val})
|
||||
else:
|
||||
conn_options['server_settings'] = {key: val}
|
||||
return await asyncpg.connect(**conn_options)
|
||||
|
||||
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
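The hunk above replaces PgProtocol's fixed keyword arguments with a single default_options dict: conn_options() merges per-call overrides (and an optional dsn) over the defaults and injects a default statement_timeout, while connstr() and connect() just forward to it. Below is a minimal, self-contained sketch of how that composition behaves; the class name and example values are hypothetical and not part of the diff.

# Minimal sketch (hypothetical class and values, not the fixture itself) of how
# the kwargs-based connection options compose.
from psycopg2.extensions import make_dsn, parse_dsn

class MiniPgProtocol:
    def __init__(self, **kwargs):
        self.default_options = kwargs  # e.g. host, port, user, dbname

    def conn_options(self, **kwargs):
        options = self.default_options.copy()
        if 'dsn' in kwargs:
            options.update(parse_dsn(kwargs.pop('dsn')))
        options.update(kwargs)
        # Every connection gets a default statement timeout unless the caller set one.
        if 'options' in options:
            options['options'] = "-cstatement_timeout=120s " + options['options']
        else:
            options['options'] = "-cstatement_timeout=120s"
        return options

    def connstr(self, **kwargs) -> str:
        return str(make_dsn(**self.conn_options(**kwargs)))

pg = MiniPgProtocol(host='localhost', port=5432, user='zenith_admin', dbname='postgres')
print(pg.connstr())                     # defaults only
print(pg.connstr(dbname='regression'))  # a per-call keyword overrides the default

The same dict is fed to psycopg2.connect(**conn_options); for asyncpg, the hunk additionally renames dbname to database and converts the '-c<key>=<val>' options string into server_settings.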
@@ -1149,10 +1126,10 @@ class ZenithPageserver(PgProtocol):
|
||||
port: PageserverPort,
|
||||
remote_storage: Optional[RemoteStorage] = None,
|
||||
config_override: Optional[str] = None):
|
||||
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
|
||||
super().__init__(host='localhost', port=port.pg, user='zenith_admin')
|
||||
self.env = env
|
||||
self.running = False
|
||||
self.service_port = port # do not shadow PgProtocol.port which is just int
|
||||
self.service_port = port
|
||||
self.remote_storage = remote_storage
|
||||
self.config_override = config_override
|
||||
|
||||
@@ -1314,7 +1291,7 @@ def psbench_bin(test_output_dir):
|
||||
|
||||
class VanillaPostgres(PgProtocol):
|
||||
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
|
||||
super().__init__(host='localhost', port=port)
|
||||
super().__init__(host='localhost', port=port, dbname='postgres')
|
||||
self.pgdatadir = pgdatadir
|
||||
self.pg_bin = pg_bin
|
||||
self.running = False
|
||||
@@ -1356,10 +1333,57 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
|
||||
yield vanilla_pg
|
||||
|
||||
|
||||
class RemotePostgres(PgProtocol):
|
||||
def __init__(self, pg_bin: PgBin, remote_connstr: str):
|
||||
super().__init__(**parse_dsn(remote_connstr))
|
||||
self.pg_bin = pg_bin
|
||||
# The remote server is assumed to be running already
|
||||
self.running = True
|
||||
|
||||
def configure(self, options: List[str]):
|
||||
raise Exception('cannot change configuration of remote Postgres instance')
|
||||
|
||||
def start(self):
|
||||
raise Exception('cannot start a remote Postgres instance')
|
||||
|
||||
def stop(self):
|
||||
raise Exception('cannot stop a remote Postgres instance')
|
||||
|
||||
def get_subdir_size(self, subdir) -> int:
|
||||
# TODO: Could use the server's Generic File Access functions if superuser.
|
||||
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
|
||||
raise Exception('cannot get size of a Postgres instance')
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
# do nothing
|
||||
pass
|
||||
|
||||
|
||||
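The TODO in RemotePostgres.get_subdir_size() above points at Postgres' generic file access functions. A hedged sketch of what that could look like, assuming the remote role is superuser (or has pg_read_server_files); the helper name is hypothetical and it only sums one directory level, so it is a starting point rather than a drop-in implementation.

# Hypothetical helper for the TODO above: server-side size of a data-directory
# subdir via pg_ls_dir()/pg_stat_file(). Requires superuser or pg_read_server_files,
# and does not recurse into nested directories.
import psycopg2

def remote_subdir_size(connstr: str, subdir: str) -> int:
    query = """
        SELECT coalesce(sum((pg_stat_file(%(dir)s || '/' || name)).size), 0)
        FROM pg_ls_dir(%(dir)s) AS name
        WHERE NOT (pg_stat_file(%(dir)s || '/' || name)).isdir
    """
    with psycopg2.connect(connstr) as conn, conn.cursor() as cur:
        cur.execute(query, {'dir': subdir})
        return int(cur.fetchone()[0])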
@pytest.fixture(scope='function')
|
||||
def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]:
|
||||
pg_bin = PgBin(test_output_dir)
|
||||
|
||||
connstr = os.getenv("BENCHMARK_CONNSTR")
|
||||
if connstr is None:
|
||||
raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable")
|
||||
|
||||
with RemotePostgres(pg_bin, connstr) as remote_pg:
|
||||
yield remote_pg
|
||||
|
||||
|
||||
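For reference, a hypothetical test that consumes the remote_pg fixture above (not part of this diff); it assumes BENCHMARK_CONNSTR points at a reachable cluster and is selected via the remote_cluster marker.

# Hypothetical usage of the remote_pg fixture; run with e.g.
#   BENCHMARK_CONNSTR=postgres://user:pass@host:5432/postgres pytest -m remote_cluster
import pytest

@pytest.mark.remote_cluster
def test_remote_select_one(remote_pg):
    conn = remote_pg.connect()  # autocommit psycopg2 connection
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        assert cur.fetchone() == (1, )
    conn.close()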
class ZenithProxy(PgProtocol):
|
||||
def __init__(self, port: int):
|
||||
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
|
||||
super().__init__(host="127.0.0.1",
|
||||
user="pytest",
|
||||
password="pytest",
|
||||
port=port,
|
||||
dbname='postgres')
|
||||
self.http_port = 7001
|
||||
self.host = "127.0.0.1"
|
||||
self.port = port
|
||||
self._popen: Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
def start_static(self, addr="127.0.0.1:5432") -> None:
|
||||
@@ -1403,13 +1427,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
|
||||
class Postgres(PgProtocol):
|
||||
""" An object representing a running postgres daemon. """
|
||||
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
|
||||
super().__init__(host='localhost', port=port, username='zenith_admin')
|
||||
|
||||
super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres')
|
||||
self.env = env
|
||||
self.running = False
|
||||
self.node_name: Optional[str] = None # dubious, see asserts below
|
||||
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
|
||||
self.tenant_id = tenant_id
|
||||
self.port = port
|
||||
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
|
||||
|
||||
def create(
|
||||
|
||||
@@ -2,29 +2,113 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import PgBin, VanillaPostgres, ZenithEnv
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
from fixtures.benchmark_fixture import PgBenchRunResult, MetricReport, ZenithBenchmarker
|
||||
from fixtures.log_helper import log
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
import calendar
|
||||
import os
|
||||
import timeit
|
||||
|
||||
|
||||
def utc_now_timestamp() -> int:
|
||||
return calendar.timegm(datetime.utcnow().utctimetuple())
|
||||
|
||||
|
||||
def init_pgbench(env: PgCompare, cmdline):
|
||||
# calculate timestamps and durations separately
|
||||
# timestamp is intended to be used for linking to grafana and logs
|
||||
# duration is actually a metric and uses float instead of int for timestamp
|
||||
init_start_timestamp = utc_now_timestamp()
|
||||
t0 = timeit.default_timer()
|
||||
with env.record_pageserver_writes('init.pageserver_writes'):
|
||||
env.pg_bin.run_capture(cmdline)
|
||||
env.flush()
|
||||
init_duration = timeit.default_timer() - t0
|
||||
init_end_timestamp = utc_now_timestamp()
|
||||
|
||||
env.zenbenchmark.record("init.duration",
|
||||
init_duration,
|
||||
unit="s",
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
env.zenbenchmark.record("init.start_timestamp",
|
||||
init_start_timestamp,
|
||||
'',
|
||||
MetricReport.TEST_PARAM)
|
||||
env.zenbenchmark.record("init.end_timestamp", init_end_timestamp, '', MetricReport.TEST_PARAM)
|
||||
|
||||
|
||||
def run_pgbench(env: PgCompare, prefix: str, cmdline):
|
||||
with env.record_pageserver_writes(f'{prefix}.pageserver_writes'):
|
||||
run_start_timestamp = utc_now_timestamp()
|
||||
t0 = timeit.default_timer()
|
||||
out = env.pg_bin.run_capture(cmdline)
|
||||
run_duration = timeit.default_timer() - t0
|
||||
run_end_timestamp = utc_now_timestamp()
|
||||
env.flush()
|
||||
|
||||
stdout = Path(f"{out}.stdout").read_text()
|
||||
|
||||
res = PgBenchRunResult.parse_from_stdout(
|
||||
stdout=stdout,
|
||||
run_duration=run_duration,
|
||||
run_start_timestamp=run_start_timestamp,
|
||||
run_end_timestamp=run_end_timestamp,
|
||||
)
|
||||
env.zenbenchmark.record_pg_bench_result(prefix, res)
|
||||
|
||||
|
||||
#
|
||||
# Run a very short pgbench test.
|
||||
# Initialize a pgbench database, and run pgbench against it.
|
||||
#
|
||||
# Collects three metrics:
|
||||
# This runs two different pgbench workloads against the same
|
||||
# initialized database, and 'duration' is the time of each run. So
|
||||
# the total runtime is 2 * duration, plus time needed to initialize
|
||||
# the test database.
|
||||
#
|
||||
# 1. Time to initialize the pgbench database (pgbench -s5 -i)
|
||||
# 2. Time to run 5000 pgbench transactions
|
||||
# 3. Disk space used
|
||||
#
|
||||
def test_pgbench(zenith_with_baseline: PgCompare):
|
||||
env = zenith_with_baseline
|
||||
# Currently, the # of connections is hardcoded at 4
|
||||
def run_test_pgbench(env: PgCompare, scale: int, duration: int):
|
||||
|
||||
with env.record_pageserver_writes('pageserver_writes'):
|
||||
with env.record_duration('init'):
|
||||
env.pg_bin.run_capture(['pgbench', '-s5', '-i', env.pg.connstr()])
|
||||
env.flush()
|
||||
# Record the scale and initialize
|
||||
env.zenbenchmark.record("scale", scale, '', MetricReport.TEST_PARAM)
|
||||
init_pgbench(env, ['pgbench', f'-s{scale}', '-i', env.pg.connstr()])
|
||||
|
||||
with env.record_duration('5000_xacts'):
|
||||
env.pg_bin.run_capture(['pgbench', '-c1', '-t5000', env.pg.connstr()])
|
||||
env.flush()
|
||||
# Run simple-update workload
|
||||
run_pgbench(env,
|
||||
"simple-update",
|
||||
['pgbench', '-n', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()])
|
||||
|
||||
# Run SELECT workload
|
||||
run_pgbench(env,
|
||||
"select-only",
|
||||
['pgbench', '-S', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()])
|
||||
|
||||
env.report_size()
|
||||
|
||||
|
||||
def get_durations_matrix():
|
||||
durations = os.getenv("TEST_PG_BENCH_DURATIONS_MATRIX", default="45")
|
||||
return list(map(int, durations.split(",")))
|
||||
|
||||
|
||||
def get_scales_matrix():
|
||||
scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX", default="10")
|
||||
return list(map(int, scales.split(",")))
|
||||
|
||||
|
||||
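The two helpers above expand comma-separated environment variables into the scale/duration matrix that the parametrized tests below iterate over; a small illustration with hypothetical values:

# Illustration only (hypothetical values): how the env vars expand into the
# cross product that @pytest.mark.parametrize runs below.
import os

os.environ["TEST_PG_BENCH_SCALES_MATRIX"] = "10,100"
os.environ["TEST_PG_BENCH_DURATIONS_MATRIX"] = "45,300"

scales = list(map(int, os.environ["TEST_PG_BENCH_SCALES_MATRIX"].split(",")))
durations = list(map(int, os.environ["TEST_PG_BENCH_DURATIONS_MATRIX"].split(",")))
print([(s, d) for s in scales for d in durations])
# -> [(10, 45), (10, 300), (100, 45), (100, 300)]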
# Run the pgbench tests against vanilla Postgres and zenith
|
||||
@pytest.mark.parametrize("scale", get_scales_matrix())
|
||||
@pytest.mark.parametrize("duration", get_durations_matrix())
|
||||
def test_pgbench(zenith_with_baseline: PgCompare, scale: int, duration: int):
|
||||
run_test_pgbench(zenith_with_baseline, scale, duration)
|
||||
|
||||
|
||||
# Run the pgbench tests against an existing Postgres cluster
|
||||
@pytest.mark.parametrize("scale", get_scales_matrix())
|
||||
@pytest.mark.parametrize("duration", get_durations_matrix())
|
||||
@pytest.mark.remote_cluster
|
||||
def test_pgbench_remote(remote_compare: PgCompare, scale: int, duration: int):
|
||||
run_test_pgbench(remote_compare, scale, duration)
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List
|
||||
from fixtures.benchmark_fixture import PgBenchRunResult, ZenithBenchmarker
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
import calendar
|
||||
import timeit
|
||||
import os
|
||||
|
||||
|
||||
def utc_now_timestamp() -> int:
|
||||
return calendar.timegm(datetime.utcnow().utctimetuple())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PgBenchRunner:
|
||||
connstr: str
|
||||
scale: int
|
||||
transactions: int
|
||||
pgbench_bin_path: str = "pgbench"
|
||||
|
||||
def invoke(self, args: List[str]) -> 'subprocess.CompletedProcess[str]':
|
||||
res = subprocess.run([self.pgbench_bin_path, *args], text=True, capture_output=True)
|
||||
|
||||
if res.returncode != 0:
|
||||
raise RuntimeError(f"pgbench failed. stdout: {res.stdout} stderr: {res.stderr}")
|
||||
return res
|
||||
|
||||
def init(self, vacuum: bool = True) -> 'subprocess.CompletedProcess[str]':
|
||||
args = []
|
||||
if not vacuum:
|
||||
args.append("--no-vacuum")
|
||||
args.extend([f"--scale={self.scale}", "--initialize", self.connstr])
|
||||
return self.invoke(args)
|
||||
|
||||
def run(self, jobs: int = 1, clients: int = 1):
|
||||
return self.invoke([
|
||||
f"--transactions={self.transactions}",
|
||||
f"--jobs={jobs}",
|
||||
f"--client={clients}",
|
||||
"--progress=2", # print progress every two seconds
|
||||
self.connstr,
|
||||
])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connstr():
|
||||
res = os.getenv("BENCHMARK_CONNSTR")
|
||||
if res is None:
|
||||
raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable")
|
||||
return res
|
||||
|
||||
|
||||
def get_transactions_matrix():
|
||||
transactions = os.getenv("TEST_PG_BENCH_TRANSACTIONS_MATRIX")
|
||||
if transactions is None:
|
||||
return [10**4, 10**5]
|
||||
return list(map(int, transactions.split(",")))
|
||||
|
||||
|
||||
def get_scales_matrix():
|
||||
scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX")
|
||||
if scales is None:
|
||||
return [10, 20]
|
||||
return list(map(int, scales.split(",")))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scale", get_scales_matrix())
|
||||
@pytest.mark.parametrize("transactions", get_transactions_matrix())
|
||||
@pytest.mark.remote_cluster
|
||||
def test_pg_bench_remote_cluster(zenbenchmark: ZenithBenchmarker,
|
||||
connstr: str,
|
||||
scale: int,
|
||||
transactions: int):
|
||||
"""
|
||||
Ideally we would run the same pack of tests both against local zenith and
against staging, but currently the local tests depend heavily on things that
are only available locally, e.g. zenith binaries, the pageserver API, etc.
A separate test also lets us run the pgbench workload against vanilla postgres
or other systems that speak the postgres protocol.

For now this is more of a liveness test: it stresses pageserver internals,
so we can clearly see what goes wrong in a more "real" environment.
|
||||
"""
|
||||
pg_bin = os.getenv("PG_BIN")
|
||||
if pg_bin is not None:
|
||||
pgbench_bin_path = os.path.join(pg_bin, "pgbench")
|
||||
else:
|
||||
pgbench_bin_path = "pgbench"
|
||||
|
||||
runner = PgBenchRunner(
|
||||
connstr=connstr,
|
||||
scale=scale,
|
||||
transactions=transactions,
|
||||
pgbench_bin_path=pgbench_bin_path,
|
||||
)
|
||||
# calculate timestamps and durations separately
|
||||
# timestamp is intended to be used for linking to grafana and logs
|
||||
# duration is actually a metric and uses float instead of int for timestamp
|
||||
init_start_timestamp = utc_now_timestamp()
|
||||
t0 = timeit.default_timer()
|
||||
runner.init()
|
||||
init_duration = timeit.default_timer() - t0
|
||||
init_end_timestamp = utc_now_timestamp()
|
||||
|
||||
run_start_timestamp = utc_now_timestamp()
|
||||
t0 = timeit.default_timer()
|
||||
out = runner.run() # TODO handle failures
|
||||
run_duration = timeit.default_timer() - t0
|
||||
run_end_timestamp = utc_now_timestamp()
|
||||
|
||||
res = PgBenchRunResult.parse_from_output(
|
||||
out=out,
|
||||
init_duration=init_duration,
|
||||
init_start_timestamp=init_start_timestamp,
|
||||
init_end_timestamp=init_end_timestamp,
|
||||
run_duration=run_duration,
|
||||
run_start_timestamp=run_start_timestamp,
|
||||
run_end_timestamp=run_end_timestamp,
|
||||
)
|
||||
|
||||
zenbenchmark.record_pg_bench_result(res)
|
||||
@@ -6,6 +6,7 @@ use lazy_static::lazy_static;
|
||||
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use tracing::*;
|
||||
@@ -37,8 +38,10 @@ lazy_static! {
|
||||
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
|
||||
}
|
||||
|
||||
pub trait Storage {
|
||||
/// Persist safekeeper state on disk.
|
||||
/// Storage should keep the actual state inside of it. It should implement the Deref
/// trait to access the state fields and provide a persist method for updating that state.
|
||||
pub trait Storage: Deref<Target = SafeKeeperState> {
|
||||
/// Persist safekeeper state on disk and update internal state.
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
|
||||
}
|
||||
|
||||
@@ -48,19 +51,47 @@ pub struct FileStorage {
|
||||
timeline_dir: PathBuf,
|
||||
conf: SafeKeeperConf,
|
||||
persist_control_file_seconds: Histogram,
|
||||
|
||||
/// Last state persisted to disk.
|
||||
state: SafeKeeperState,
|
||||
}
|
||||
|
||||
impl FileStorage {
|
||||
pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
|
||||
pub fn restore_new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> Result<FileStorage> {
|
||||
let timeline_dir = conf.timeline_dir(zttid);
|
||||
let tenant_id = zttid.tenant_id.to_string();
|
||||
let timeline_id = zttid.timeline_id.to_string();
|
||||
FileStorage {
|
||||
|
||||
let state = Self::load_control_file_conf(conf, zttid)?;
|
||||
|
||||
Ok(FileStorage {
|
||||
timeline_dir,
|
||||
conf: conf.clone(),
|
||||
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
|
||||
.with_label_values(&[&tenant_id, &timeline_id]),
|
||||
}
|
||||
state,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_new(
|
||||
zttid: &ZTenantTimelineId,
|
||||
conf: &SafeKeeperConf,
|
||||
state: SafeKeeperState,
|
||||
) -> Result<FileStorage> {
|
||||
let timeline_dir = conf.timeline_dir(zttid);
|
||||
let tenant_id = zttid.tenant_id.to_string();
|
||||
let timeline_id = zttid.timeline_id.to_string();
|
||||
|
||||
let mut store = FileStorage {
|
||||
timeline_dir,
|
||||
conf: conf.clone(),
|
||||
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
|
||||
.with_label_values(&[&tenant_id, &timeline_id]),
|
||||
state: state.clone(),
|
||||
};
|
||||
|
||||
store.persist(&state)?;
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
// Check the magic/version in the on-disk data and deserialize it, if possible.
|
||||
@@ -141,6 +172,14 @@ impl FileStorage {
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for FileStorage {
|
||||
type Target = SafeKeeperState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl Storage for FileStorage {
|
||||
// persists state durably to underlying storage
|
||||
// for description see https://lwn.net/Articles/457667/
|
||||
@@ -201,6 +240,9 @@ impl Storage for FileStorage {
|
||||
.and_then(|f| f.sync_all())
|
||||
.context("failed to sync control file directory")?;
|
||||
}
|
||||
|
||||
// update internal state
|
||||
self.state = s.clone();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -228,7 +270,7 @@ mod test {
|
||||
) -> Result<(FileStorage, SafeKeeperState)> {
|
||||
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
|
||||
Ok((
|
||||
FileStorage::new(zttid, conf),
|
||||
FileStorage::restore_new(zttid, conf)?,
|
||||
FileStorage::load_control_file_conf(conf, zttid)?,
|
||||
))
|
||||
}
|
||||
@@ -239,8 +281,7 @@ mod test {
|
||||
) -> Result<(FileStorage, SafeKeeperState)> {
|
||||
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
|
||||
let state = SafeKeeperState::empty();
|
||||
let mut storage = FileStorage::new(zttid, conf);
|
||||
storage.persist(&state)?;
|
||||
let storage = FileStorage::create_new(zttid, conf, state.clone())?;
|
||||
Ok((storage, state))
|
||||
}
|
||||
|
||||
|
||||
@@ -210,6 +210,7 @@ pub struct SafekeeperMemState {
|
||||
pub s3_wal_lsn: Lsn, // TODO: keep only persistent version
|
||||
pub peer_horizon_lsn: Lsn,
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
pub proposer_uuid: PgUuid,
|
||||
}
|
||||
|
||||
impl SafeKeeperState {
|
||||
@@ -502,9 +503,8 @@ pub struct SafeKeeper<CTRL: control_file::Storage, WAL: wal_storage::Storage> {
|
||||
epoch_start_lsn: Lsn,
|
||||
|
||||
pub inmem: SafekeeperMemState, // in memory part
|
||||
pub s: SafeKeeperState, // persistent part
|
||||
pub state: CTRL, // persistent state storage
|
||||
|
||||
pub control_store: CTRL,
|
||||
pub wal_store: WAL,
|
||||
}
|
||||
|
||||
@@ -516,14 +516,14 @@ where
|
||||
// constructor
|
||||
pub fn new(
|
||||
ztli: ZTimelineId,
|
||||
control_store: CTRL,
|
||||
state: CTRL,
|
||||
mut wal_store: WAL,
|
||||
state: SafeKeeperState,
|
||||
) -> Result<SafeKeeper<CTRL, WAL>> {
|
||||
if state.timeline_id != ZTimelineId::from([0u8; 16]) && ztli != state.timeline_id {
|
||||
bail!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.timeline_id);
|
||||
}
|
||||
|
||||
// initialize wal_store, if state is already initialized
|
||||
wal_store.init_storage(&state)?;
|
||||
|
||||
Ok(SafeKeeper {
|
||||
@@ -535,23 +535,25 @@ where
|
||||
s3_wal_lsn: state.s3_wal_lsn,
|
||||
peer_horizon_lsn: state.peer_horizon_lsn,
|
||||
remote_consistent_lsn: state.remote_consistent_lsn,
|
||||
proposer_uuid: state.proposer_uuid,
|
||||
},
|
||||
s: state,
|
||||
control_store,
|
||||
state,
|
||||
wal_store,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get history of term switches for the available WAL
|
||||
fn get_term_history(&self) -> TermHistory {
|
||||
self.s
|
||||
self.state
|
||||
.acceptor_state
|
||||
.term_history
|
||||
.up_to(self.wal_store.flush_lsn())
|
||||
}
|
||||
|
||||
pub fn get_epoch(&self) -> Term {
|
||||
self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn())
|
||||
self.state
|
||||
.acceptor_state
|
||||
.get_epoch(self.wal_store.flush_lsn())
|
||||
}
|
||||
|
||||
/// Process message from proposer and possibly form reply. Concurrent
|
||||
@@ -587,46 +589,47 @@ where
|
||||
);
|
||||
}
|
||||
/* Postgres upgrade is not treated as fatal error */
|
||||
if msg.pg_version != self.s.server.pg_version
|
||||
&& self.s.server.pg_version != UNKNOWN_SERVER_VERSION
|
||||
if msg.pg_version != self.state.server.pg_version
|
||||
&& self.state.server.pg_version != UNKNOWN_SERVER_VERSION
|
||||
{
|
||||
info!(
|
||||
"incompatible server version {}, expected {}",
|
||||
msg.pg_version, self.s.server.pg_version
|
||||
msg.pg_version, self.state.server.pg_version
|
||||
);
|
||||
}
|
||||
if msg.tenant_id != self.s.tenant_id {
|
||||
if msg.tenant_id != self.state.tenant_id {
|
||||
bail!(
|
||||
"invalid tenant ID, got {}, expected {}",
|
||||
msg.tenant_id,
|
||||
self.s.tenant_id
|
||||
self.state.tenant_id
|
||||
);
|
||||
}
|
||||
if msg.ztli != self.s.timeline_id {
|
||||
if msg.ztli != self.state.timeline_id {
|
||||
bail!(
|
||||
"invalid timeline ID, got {}, expected {}",
|
||||
msg.ztli,
|
||||
self.s.timeline_id
|
||||
self.state.timeline_id
|
||||
);
|
||||
}
|
||||
|
||||
// set basic info about server, if not yet
|
||||
// TODO: verify that it doesn't change afterwards
|
||||
self.s.server.system_id = msg.system_id;
|
||||
self.s.server.wal_seg_size = msg.wal_seg_size;
|
||||
self.control_store
|
||||
.persist(&self.s)
|
||||
.context("failed to persist shared state")?;
|
||||
{
|
||||
let mut state = self.state.clone();
|
||||
state.server.system_id = msg.system_id;
|
||||
state.server.wal_seg_size = msg.wal_seg_size;
|
||||
self.state.persist(&state)?;
|
||||
}
|
||||
|
||||
// pass wal_seg_size to read WAL and find flush_lsn
|
||||
self.wal_store.init_storage(&self.s)?;
|
||||
self.wal_store.init_storage(&self.state)?;
|
||||
|
||||
info!(
|
||||
"processed greeting from proposer {:?}, sending term {:?}",
|
||||
msg.proposer_id, self.s.acceptor_state.term
|
||||
msg.proposer_id, self.state.acceptor_state.term
|
||||
);
|
||||
Ok(Some(AcceptorProposerMessage::Greeting(AcceptorGreeting {
|
||||
term: self.s.acceptor_state.term,
|
||||
term: self.state.acceptor_state.term,
|
||||
})))
|
||||
}
|
||||
|
||||
@@ -637,17 +640,19 @@ where
|
||||
) -> Result<Option<AcceptorProposerMessage>> {
|
||||
// initialize with refusal
|
||||
let mut resp = VoteResponse {
|
||||
term: self.s.acceptor_state.term,
|
||||
term: self.state.acceptor_state.term,
|
||||
vote_given: false as u64,
|
||||
flush_lsn: self.wal_store.flush_lsn(),
|
||||
truncate_lsn: self.s.peer_horizon_lsn,
|
||||
truncate_lsn: self.state.peer_horizon_lsn,
|
||||
term_history: self.get_term_history(),
|
||||
};
|
||||
if self.s.acceptor_state.term < msg.term {
|
||||
self.s.acceptor_state.term = msg.term;
|
||||
if self.state.acceptor_state.term < msg.term {
|
||||
let mut state = self.state.clone();
|
||||
state.acceptor_state.term = msg.term;
|
||||
// persist vote before sending it out
|
||||
self.control_store.persist(&self.s)?;
|
||||
resp.term = self.s.acceptor_state.term;
|
||||
self.state.persist(&state)?;
|
||||
|
||||
resp.term = self.state.acceptor_state.term;
|
||||
resp.vote_given = true as u64;
|
||||
}
|
||||
info!("processed VoteRequest for term {}: {:?}", msg.term, &resp);
|
||||
@@ -656,9 +661,10 @@ where
|
||||
|
||||
/// Bump our term if received a note from elected proposer with higher one
|
||||
fn bump_if_higher(&mut self, term: Term) -> Result<()> {
|
||||
if self.s.acceptor_state.term < term {
|
||||
self.s.acceptor_state.term = term;
|
||||
self.control_store.persist(&self.s)?;
|
||||
if self.state.acceptor_state.term < term {
|
||||
let mut state = self.state.clone();
|
||||
state.acceptor_state.term = term;
|
||||
self.state.persist(&state)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -666,9 +672,9 @@ where
|
||||
/// Form AppendResponse from current state.
|
||||
fn append_response(&self) -> AppendResponse {
|
||||
let ar = AppendResponse {
|
||||
term: self.s.acceptor_state.term,
|
||||
term: self.state.acceptor_state.term,
|
||||
flush_lsn: self.wal_store.flush_lsn(),
|
||||
commit_lsn: self.s.commit_lsn,
|
||||
commit_lsn: self.state.commit_lsn,
|
||||
// will be filled by the upper code to avoid bothering safekeeper
|
||||
hs_feedback: HotStandbyFeedback::empty(),
|
||||
zenith_feedback: ZenithFeedback::empty(),
|
||||
@@ -681,7 +687,7 @@ where
|
||||
info!("received ProposerElected {:?}", msg);
|
||||
self.bump_if_higher(msg.term)?;
|
||||
// If our term is higher, ignore the message (next feedback will inform the compute)
|
||||
if self.s.acceptor_state.term > msg.term {
|
||||
if self.state.acceptor_state.term > msg.term {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
@@ -692,8 +698,11 @@ where
|
||||
self.wal_store.truncate_wal(msg.start_streaming_at)?;
|
||||
|
||||
// and now adopt term history from proposer
|
||||
self.s.acceptor_state.term_history = msg.term_history.clone();
|
||||
self.control_store.persist(&self.s)?;
|
||||
{
|
||||
let mut state = self.state.clone();
|
||||
state.acceptor_state.term_history = msg.term_history.clone();
|
||||
self.state.persist(&state)?;
|
||||
}
|
||||
|
||||
info!("start receiving WAL since {:?}", msg.start_streaming_at);
|
||||
|
||||
@@ -715,13 +724,13 @@ where
|
||||
// Also note that commit_lsn can reach epoch_start_lsn earlier
|
||||
// than we receive the new epoch_start_lsn, and we still need to sync
|
||||
// control file in this case.
|
||||
if commit_lsn == self.epoch_start_lsn && self.s.commit_lsn != commit_lsn {
|
||||
if commit_lsn == self.epoch_start_lsn && self.state.commit_lsn != commit_lsn {
|
||||
self.persist_control_file()?;
|
||||
}
|
||||
|
||||
// We got our first commit_lsn, which means we should sync
|
||||
// everything to disk, to initialize the state.
|
||||
if self.s.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) {
|
||||
if self.state.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) {
|
||||
self.wal_store.flush_wal()?;
|
||||
self.persist_control_file()?;
|
||||
}
|
||||
@@ -731,10 +740,12 @@ where
|
||||
|
||||
/// Persist in-memory state to the disk.
|
||||
fn persist_control_file(&mut self) -> Result<()> {
|
||||
self.s.commit_lsn = self.inmem.commit_lsn;
|
||||
self.s.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
|
||||
let mut state = self.state.clone();
|
||||
|
||||
self.control_store.persist(&self.s)
|
||||
state.commit_lsn = self.inmem.commit_lsn;
|
||||
state.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
|
||||
state.proposer_uuid = self.inmem.proposer_uuid;
|
||||
self.state.persist(&state)
|
||||
}
|
||||
|
||||
/// Handle request to append WAL.
|
||||
@@ -744,13 +755,13 @@ where
|
||||
msg: &AppendRequest,
|
||||
require_flush: bool,
|
||||
) -> Result<Option<AcceptorProposerMessage>> {
|
||||
if self.s.acceptor_state.term < msg.h.term {
|
||||
if self.state.acceptor_state.term < msg.h.term {
|
||||
bail!("got AppendRequest before ProposerElected");
|
||||
}
|
||||
|
||||
// If our term is higher, immediately refuse the message.
|
||||
if self.s.acceptor_state.term > msg.h.term {
|
||||
let resp = AppendResponse::term_only(self.s.acceptor_state.term);
|
||||
if self.state.acceptor_state.term > msg.h.term {
|
||||
let resp = AppendResponse::term_only(self.state.acceptor_state.term);
|
||||
return Ok(Some(AcceptorProposerMessage::AppendResponse(resp)));
|
||||
}
|
||||
|
||||
@@ -758,8 +769,7 @@ where
|
||||
// processing the message.
|
||||
|
||||
self.epoch_start_lsn = msg.h.epoch_start_lsn;
|
||||
// TODO: don't update state without persisting to disk
|
||||
self.s.proposer_uuid = msg.h.proposer_uuid;
|
||||
self.inmem.proposer_uuid = msg.h.proposer_uuid;
|
||||
|
||||
// do the job
|
||||
if !msg.wal_data.is_empty() {
|
||||
@@ -790,7 +800,7 @@ where
|
||||
// Update truncate and commit LSN in control file.
|
||||
// To avoid negative impact on performance of extra fsync, do it only
|
||||
// when truncate_lsn delta exceeds WAL segment size.
|
||||
if self.s.peer_horizon_lsn + (self.s.server.wal_seg_size as u64)
|
||||
if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64)
|
||||
< self.inmem.peer_horizon_lsn
|
||||
{
|
||||
self.persist_control_file()?;
|
||||
@@ -829,6 +839,8 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Deref;
|
||||
|
||||
use super::*;
|
||||
use crate::wal_storage::Storage;
|
||||
|
||||
@@ -844,6 +856,14 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for InMemoryState {
|
||||
type Target = SafeKeeperState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.persisted_state
|
||||
}
|
||||
}
|
||||
|
||||
struct DummyWalStore {
|
||||
lsn: Lsn,
|
||||
}
|
||||
@@ -879,7 +899,7 @@ mod tests {
|
||||
};
|
||||
let wal_store = DummyWalStore { lsn: Lsn(0) };
|
||||
let ztli = ZTimelineId::from([0u8; 16]);
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap();
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap();
|
||||
|
||||
// check voting for 1 is ok
|
||||
let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 });
|
||||
@@ -890,11 +910,11 @@ mod tests {
|
||||
}
|
||||
|
||||
// reboot...
|
||||
let state = sk.control_store.persisted_state.clone();
|
||||
let state = sk.state.persisted_state.clone();
|
||||
let storage = InMemoryState {
|
||||
persisted_state: state.clone(),
|
||||
persisted_state: state,
|
||||
};
|
||||
sk = SafeKeeper::new(ztli, storage, sk.wal_store, state).unwrap();
|
||||
sk = SafeKeeper::new(ztli, storage, sk.wal_store).unwrap();
|
||||
|
||||
// and ensure voting second time for 1 is not ok
|
||||
vote_resp = sk.process_msg(&vote_request);
|
||||
@@ -911,7 +931,7 @@ mod tests {
|
||||
};
|
||||
let wal_store = DummyWalStore { lsn: Lsn(0) };
|
||||
let ztli = ZTimelineId::from([0u8; 16]);
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap();
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap();
|
||||
|
||||
let mut ar_hdr = AppendRequestHeader {
|
||||
term: 1,
|
||||
|
||||
@@ -21,7 +21,6 @@ use crate::broker::SafekeeperInfo;
|
||||
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
|
||||
|
||||
use crate::control_file;
|
||||
use crate::control_file::Storage as cf_storage;
|
||||
use crate::safekeeper::{
|
||||
AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState,
|
||||
SafekeeperMemState,
|
||||
@@ -98,10 +97,9 @@ impl SharedState {
|
||||
peer_ids: Vec<ZNodeId>,
|
||||
) -> Result<Self> {
|
||||
let state = SafeKeeperState::new(zttid, peer_ids);
|
||||
let control_store = control_file::FileStorage::new(zttid, conf);
|
||||
let control_store = control_file::FileStorage::create_new(zttid, conf, state)?;
|
||||
let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);
|
||||
let mut sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?;
|
||||
sk.control_store.persist(&sk.s)?;
|
||||
let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?;
|
||||
|
||||
Ok(Self {
|
||||
notified_commit_lsn: Lsn(0),
|
||||
@@ -116,18 +114,14 @@ impl SharedState {
|
||||
/// Restore SharedState from control file.
|
||||
/// If file doesn't exist, bails out.
|
||||
fn restore(conf: &SafeKeeperConf, zttid: &ZTenantTimelineId) -> Result<Self> {
|
||||
let state = control_file::FileStorage::load_control_file_conf(conf, zttid)
|
||||
.context("failed to load from control file")?;
|
||||
|
||||
let control_store = control_file::FileStorage::new(zttid, conf);
|
||||
|
||||
let control_store = control_file::FileStorage::restore_new(zttid, conf)?;
|
||||
let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);
|
||||
|
||||
info!("timeline {} restored", zttid.timeline_id);
|
||||
|
||||
Ok(Self {
|
||||
notified_commit_lsn: Lsn(0),
|
||||
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?,
|
||||
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?,
|
||||
replicas: Vec::new(),
|
||||
active: false,
|
||||
num_computes: 0,
|
||||
@@ -419,7 +413,7 @@ impl Timeline {
|
||||
|
||||
pub fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) {
|
||||
let shared_state = self.mutex.lock().unwrap();
|
||||
(shared_state.sk.inmem.clone(), shared_state.sk.s.clone())
|
||||
(shared_state.sk.inmem.clone(), shared_state.sk.state.clone())
|
||||
}
|
||||
|
||||
/// Prepare public safekeeper info for reporting.
|
||||
|
||||
workspace_hack/.gitattributes (new file, vendored, +4 lines)
@@ -0,0 +1,4 @@
|
||||
# Avoid putting conflict markers in the generated Cargo.toml file, since their presence breaks
|
||||
# Cargo.
|
||||
# Also do not check out the file as CRLF on Windows, as that's what hakari needs.
|
||||
Cargo.toml merge=binary -crlf
|
||||
@@ -16,32 +16,38 @@ publish = false
|
||||
[dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace", "std"] }
|
||||
bytes = { version = "1", features = ["serde", "std"] }
|
||||
chrono = { version = "0.4", features = ["clock", "libc", "oldtime", "serde", "std", "time", "winapi"] }
|
||||
clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] }
|
||||
either = { version = "1", features = ["use_std"] }
|
||||
hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
libc = { version = "0.2", features = ["extra_traits", "std"] }
|
||||
log = { version = "0.4", default-features = false, features = ["serde", "std"] }
|
||||
memchr = { version = "2", features = ["std", "use_std"] }
|
||||
num-integer = { version = "0.1", default-features = false, features = ["std"] }
|
||||
num-traits = { version = "0.2", features = ["std"] }
|
||||
prost = { version = "0.9", features = ["prost-derive", "std"] }
|
||||
rand = { version = "0.8", features = ["alloc", "getrandom", "libc", "rand_chacha", "rand_hc", "small_rng", "std", "std_rng"] }
|
||||
regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
|
||||
regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
|
||||
reqwest = { version = "0.11", default-features = false, features = ["__rustls", "__tls", "blocking", "hyper-rustls", "json", "rustls", "rustls-pemfile", "rustls-tls", "rustls-tls-webpki-roots", "serde_json", "stream", "tokio-rustls", "tokio-util", "webpki-roots"] }
|
||||
scopeguard = { version = "1", features = ["use_std"] }
|
||||
serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] }
|
||||
tokio = { version = "1", features = ["bytes", "fs", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "sync", "time", "tokio-macros"] }
|
||||
tracing = { version = "0.1", features = ["attributes", "std", "tracing-attributes"] }
|
||||
tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] }
|
||||
tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] }
|
||||
tracing-core = { version = "0.1", features = ["lazy_static", "std"] }
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace", "std"] }
|
||||
bytes = { version = "1", features = ["serde", "std"] }
|
||||
cc = { version = "1", default-features = false, features = ["jobserver", "parallel"] }
|
||||
clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] }
|
||||
either = { version = "1", features = ["use_std"] }
|
||||
hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
libc = { version = "0.2", features = ["extra_traits", "std"] }
|
||||
log = { version = "0.4", default-features = false, features = ["serde", "std"] }
|
||||
memchr = { version = "2", features = ["std", "use_std"] }
|
||||
proc-macro2 = { version = "1", features = ["proc-macro"] }
|
||||
quote = { version = "1", features = ["proc-macro"] }
|
||||
prost = { version = "0.9", features = ["prost-derive", "std"] }
|
||||
regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
|
||||
regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
|
||||
serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] }
|
||||
|
||||
workspace_hack/build.rs (new file, +2 lines)
@@ -0,0 +1,2 @@
|
||||
// A build script is required for cargo to consider build dependencies.
|
||||
fn main() {}
|
||||
@@ -375,9 +375,8 @@ impl PostgresBackend {
|
||||
}
|
||||
AuthType::MD5 => {
|
||||
rand::thread_rng().fill(&mut self.md5_salt);
|
||||
let md5_salt = self.md5_salt;
|
||||
self.write_message(&BeMessage::AuthenticationMD5Password(
|
||||
&md5_salt,
|
||||
self.md5_salt,
|
||||
))?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
|
||||
@@ -401,7 +401,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
#[derive(Debug)]
|
||||
pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationMD5Password(&'a [u8; 4]),
|
||||
AuthenticationMD5Password([u8; 4]),
|
||||
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||
AuthenticationCleartextPassword,
|
||||
BackendKeyData(CancelKeyData),
|
||||
BindComplete,
|
||||
@@ -429,6 +430,13 @@ pub enum BeMessage<'a> {
|
||||
KeepAlive(WalSndKeepAlive),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeAuthenticationSaslMessage<'a> {
|
||||
Methods(&'a [&'a str]),
|
||||
Continue(&'a [u8]),
|
||||
Final(&'a [u8]),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeParameterStatusMessage<'a> {
|
||||
Encoding(&'a str),
|
||||
@@ -611,6 +619,32 @@ impl<'a> BeMessage<'a> {
|
||||
.unwrap(); // write into BytesMut can't fail
|
||||
}
|
||||
|
||||
BeMessage::AuthenticationSasl(msg) => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
match msg {
|
||||
Methods(methods) => {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods.iter() {
|
||||
write_cstr(method.as_bytes(), buf)?;
|
||||
}
|
||||
buf.put_u8(0); // zero terminator for the list
|
||||
}
|
||||
Continue(extra) => {
|
||||
buf.put_i32(11); // Continue SASL auth.
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
Final(extra) => {
|
||||
buf.put_i32(12); // Send final SASL message.
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.put_u8(b'K');
|
||||
write_body(buf, |buf| {
|
||||
|
||||