Merge branch 'main' into bojan-get-page-tests

Bojan Serafimov
2022-04-14 13:59:59 -04:00
58 changed files with 2180 additions and 656 deletions


@@ -63,21 +63,18 @@
tags:
- pageserver
# It seems that currently S3 integration does not play well
# even with fresh pageserver without a burden of old data.
# TODO: turn this back on once the issue is solved.
# - name: update remote storage (s3) config
# lineinfile:
# path: /storage/pageserver/data/pageserver.toml
# line: "{{ item }}"
# loop:
# - "[remote_storage]"
# - "bucket_name = '{{ bucket_name }}'"
# - "bucket_region = '{{ bucket_region }}'"
# - "prefix_in_bucket = '{{ inventory_hostname }}'"
# become: true
# tags:
# - pageserver
- name: update remote storage (s3) config
lineinfile:
path: /storage/pageserver/data/pageserver.toml
line: "{{ item }}"
loop:
- "[remote_storage]"
- "bucket_name = '{{ bucket_name }}'"
- "bucket_region = '{{ bucket_region }}'"
- "prefix_in_bucket = '{{ inventory_hostname }}'"
become: true
tags:
- pageserver
- name: upload systemd service definition
ansible.builtin.template:


@@ -5,7 +5,6 @@ zenith-us-stage-ps-2 console_region_id=27
[safekeepers]
zenith-us-stage-sk-1 console_region_id=27
zenith-us-stage-sk-2 console_region_id=27
zenith-us-stage-sk-3 console_region_id=27
zenith-us-stage-sk-4 console_region_id=27
[storage:children]


@@ -405,7 +405,7 @@ jobs:
- run:
name: Build coverage report
command: |
COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1
COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1
scripts/coverage \
--dir=/tmp/zenith/coverage report \
@@ -416,8 +416,8 @@ jobs:
name: Upload coverage report
command: |
LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME
REPORT_URL=https://zenithdb.github.io/zenith-coverage-data/$CIRCLE_SHA1
COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1
REPORT_URL=https://neondatabase.github.io/zenith-coverage-data/$CIRCLE_SHA1
COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1
scripts/git-upload \
--repo=https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-coverage-data.git \
@@ -593,7 +593,7 @@ jobs:
name: Setup helm v3
command: |
curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash
helm repo add zenithdb https://zenithdb.github.io/helm-charts
helm repo add zenithdb https://neondatabase.github.io/helm-charts
- run:
name: Re-deploy proxy
command: |
@@ -643,7 +643,7 @@ jobs:
name: Setup helm v3
command: |
curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash
helm repo add zenithdb https://zenithdb.github.io/helm-charts
helm repo add zenithdb https://neondatabase.github.io/helm-charts
- run:
name: Re-deploy proxy
command: |
@@ -672,7 +672,7 @@ jobs:
--data \
"{
\"state\": \"pending\",
\"context\": \"zenith-remote-ci\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
- run:
@@ -688,7 +688,7 @@ jobs:
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"zenith-remote-ci\",
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$CIRCLE_SHA1\",
\"remote_repo\": \"$LOCAL_REPO\"
}
@@ -828,11 +828,11 @@ workflows:
- remote-ci-trigger:
# Context passes credentials for gh api
context: CI_ACCESS_TOKEN
remote_repo: "zenithdb/console"
remote_repo: "neondatabase/cloud"
requires:
# XXX: Successful build doesn't mean everything is OK, but
# the job to be triggered takes so much time to complete (~22 min)
# that it's better not to wait for the commented-out steps
- build-zenith-debug
- build-zenith-release
# - pg_regress-tests-release
# - other-tests-release


@@ -26,7 +26,7 @@ jobs:
runs-on: [self-hosted, zenith-benchmarker]
env:
PG_BIN: "/usr/pgsql-13/bin"
POSTGRES_DISTRIB_DIR: "/usr/pgsql-13"
steps:
- name: Checkout zenith repo
@@ -51,7 +51,7 @@ jobs:
echo Poetry
poetry --version
echo Pgbench
$PG_BIN/pgbench --version
$POSTGRES_DISTRIB_DIR/bin/pgbench --version
# FIXME cluster setup is skipped due to various changes in console API
# for now pre created cluster is used. When API gain some stability
@@ -66,7 +66,7 @@ jobs:
echo "Starting cluster"
# wake up the cluster
$PG_BIN/psql $BENCHMARK_CONNSTR -c "SELECT 1"
$POSTGRES_DISTRIB_DIR/bin/psql $BENCHMARK_CONNSTR -c "SELECT 1"
- name: Run benchmark
# pgbench is installed system wide from official repo
@@ -83,8 +83,11 @@ jobs:
# sudo yum install postgresql13-contrib
# actual binaries are located in /usr/pgsql-13/bin/
env:
TEST_PG_BENCH_TRANSACTIONS_MATRIX: "5000,10000,20000"
TEST_PG_BENCH_SCALES_MATRIX: "10,15"
# The pgbench test runs two tests of given duration against each scale.
# So the total runtime with these parameters is 2 * 2 * 300 = 1200, or 20 minutes.
# Plus time needed to initialize the test databases.
TEST_PG_BENCH_DURATIONS_MATRIX: "300"
TEST_PG_BENCH_SCALES_MATRIX: "10,100"
PLATFORM: "zenith-staging"
BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}"
REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally
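The runtime comment above can be sanity-checked with a few lines. The sketch below is an illustration only, not part of the workflow; it assumes (as the comment states) that the harness runs two pgbench tests of the given duration against each scale.

```rust
// Rough arithmetic behind the comment above: two pgbench runs per scale,
// for every (duration, scale) combination in the matrices.
fn total_runtime_secs(durations: &[u64], scales: &[u64], runs_per_scale: u64) -> u64 {
    durations
        .iter()
        .map(|d| d * scales.len() as u64 * runs_per_scale)
        .sum()
}

fn main() {
    // TEST_PG_BENCH_DURATIONS_MATRIX="300", TEST_PG_BENCH_SCALES_MATRIX="10,100"
    // => 2 runs * 2 scales * 300 s = 1200 s, i.e. 20 minutes plus init time.
    assert_eq!(total_runtime_secs(&[300], &[10, 100], 2), 1200);
}
```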

Cargo.lock

@@ -361,6 +361,7 @@ dependencies = [
"serde_json",
"tar",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"workspace_hack",
]
@@ -1571,7 +1572,6 @@ dependencies = [
"tokio-util 0.7.0",
"toml_edit",
"tracing",
"tracing-futures",
"url",
"workspace_hack",
"zenith_metrics",
@@ -1952,12 +1952,15 @@ name = "proxy"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"base64 0.13.0",
"bytes",
"clap 3.1.8",
"fail",
"futures",
"hashbrown",
"hex",
"hmac 0.10.1",
"hyper",
"lazy_static",
"md5",
@@ -1966,10 +1969,13 @@ dependencies = [
"rand",
"rcgen",
"reqwest",
"routerify 2.2.0",
"rstest",
"rustls 0.19.1",
"scopeguard",
"serde",
"serde_json",
"sha2",
"socket2",
"thiserror",
"tokio",
@@ -2151,7 +2157,6 @@ dependencies = [
"serde_urlencoded",
"tokio",
"tokio-rustls 0.23.2",
"tokio-util 0.6.9",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
@@ -2175,6 +2180,19 @@ dependencies = [
"winapi",
]
[[package]]
name = "routerify"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9"
dependencies = [
"http",
"hyper",
"lazy_static",
"percent-encoding",
"regex",
]
[[package]]
name = "routerify"
version = "3.0.0"
@@ -2188,6 +2206,19 @@ dependencies = [
"regex",
]
[[package]]
name = "rstest"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f"
dependencies = [
"cfg-if",
"proc-macro2",
"quote",
"rustc_version",
"syn",
]
[[package]]
name = "rusoto_core"
version = "0.47.0"
@@ -3403,19 +3434,20 @@ dependencies = [
"anyhow",
"bytes",
"cc",
"chrono",
"clap 2.34.0",
"either",
"hashbrown",
"indexmap",
"libc",
"log",
"memchr",
"num-integer",
"num-traits",
"proc-macro2",
"quote",
"prost",
"rand",
"regex",
"regex-syntax",
"reqwest",
"scopeguard",
"serde",
"syn",
@@ -3495,7 +3527,7 @@ dependencies = [
"postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"rand",
"routerify",
"routerify 3.0.0",
"rustls 0.19.1",
"rustls-split",
"serde",


@@ -1,19 +1,22 @@
# Zenith
# Neon
Zenith is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes.
Neon is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes.
The project used to be called "Zenith". Many of the commands and code comments
still refer to "zenith", but we are in the process of renaming things.
## Architecture overview
A Zenith installation consists of compute nodes and Zenith storage engine.
A Neon installation consists of compute nodes and Neon storage engine.
Compute nodes are stateless PostgreSQL nodes, backed by Zenith storage engine.
Compute nodes are stateless PostgreSQL nodes, backed by Neon storage engine.
Zenith storage engine consists of two major components:
Neon storage engine consists of two major components:
- Pageserver. Scalable storage backend for compute nodes.
- WAL service. The service that receives WAL from compute node and ensures that it is stored durably.
Pageserver consists of:
- Repository - Zenith storage implementation.
- Repository - Neon storage implementation.
- WAL receiver - service that receives WAL from WAL service and stores it in the repository.
- Page service - service that communicates with compute nodes and responds with pages from the repository.
- WAL redo - service that builds pages from base images and WAL records on Page service request.
@@ -35,10 +38,10 @@ To run the `psql` client, install the `postgresql-client` package or modify `PAT
To run the integration tests or Python scripts (not required to use the code), install
Python (3.7 or higher), and install python3 packages using `./scripts/pysync` (requires poetry) in the project directory.
2. Build zenith and patched postgres
2. Build neon and patched postgres
```sh
git clone --recursive https://github.com/zenithdb/zenith.git
cd zenith
git clone --recursive https://github.com/neondatabase/neon.git
cd neon
make -j5
```
@@ -126,7 +129,7 @@ INSERT 0 1
## Running tests
```sh
git clone --recursive https://github.com/zenithdb/zenith.git
git clone --recursive https://github.com/neondatabase/neon.git
make # builds also postgres and installs it to ./tmp_install
./scripts/pytest
```
@@ -141,14 +144,14 @@ To view your `rustdoc` documentation in a browser, try running `cargo doc --no-d
### Postgres-specific terms
Due to Zenith's very close relation with PostgreSQL internals, there are numerous specific terms used.
Due to Neon's very close relation with PostgreSQL internals, there are numerous specific terms used.
Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, while MiB would be technically more correct, it's inconsistent with what PostgreSQL code and its documentation use.
To get more familiar with this aspect, refer to:
- [Zenith glossary](/docs/glossary.md)
- [Neon glossary](/docs/glossary.md)
- [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html)
- Other PostgreSQL documentation and sources (Zenith fork sources can be found [here](https://github.com/zenithdb/postgres))
- Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres))
## Join the development


@@ -17,4 +17,5 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
tar = "0.4"
tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }


@@ -38,6 +38,7 @@ use clap::Arg;
use log::info;
use postgres::{Client, NoTls};
use compute_tools::checker::create_writablity_check_data;
use compute_tools::config;
use compute_tools::http_api::launch_http_server;
use compute_tools::logger::*;
@@ -128,6 +129,7 @@ fn run_compute(state: &Arc<RwLock<ComputeState>>) -> Result<ExitStatus> {
handle_roles(&read_state.spec, &mut client)?;
handle_databases(&read_state.spec, &mut client)?;
create_writablity_check_data(&mut client)?;
// 'Close' connection
drop(client);


@@ -0,0 +1,46 @@
use std::sync::{Arc, RwLock};
use anyhow::{anyhow, Result};
use log::error;
use postgres::Client;
use tokio_postgres::NoTls;
use crate::zenith::ComputeState;
pub fn create_writablity_check_data(client: &mut Client) -> Result<()> {
let query = "
CREATE TABLE IF NOT EXISTS health_check (
id serial primary key,
updated_at timestamptz default now()
);
INSERT INTO health_check VALUES (1, now())
ON CONFLICT (id) DO UPDATE
SET updated_at = now();";
let result = client.simple_query(query)?;
if result.len() < 2 {
return Err(anyhow::format_err!("executed {} queries", result.len()));
}
Ok(())
}
pub async fn check_writability(state: &Arc<RwLock<ComputeState>>) -> Result<()> {
let connstr = state.read().unwrap().connstr.clone();
let (client, connection) = tokio_postgres::connect(&connstr, NoTls).await?;
if client.is_closed() {
return Err(anyhow!("connection to postgres closed"));
}
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("connection error: {}", e);
}
});
let result = client
.simple_query("UPDATE health_check SET updated_at = now() WHERE id = 1;")
.await?;
if result.len() != 1 {
return Err(anyhow!("statement can't be executed"));
}
Ok(())
}


@@ -11,7 +11,7 @@ use log::{error, info};
use crate::zenith::*;
// Service function to handle all available routes.
fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body> {
async fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body> {
match (req.method(), req.uri().path()) {
// Timestamp of the last Postgres activity in the plain text.
(&Method::GET, "/last_activity") => {
@@ -29,6 +29,15 @@ fn routes(req: Request<Body>, state: Arc<RwLock<ComputeState>>) -> Response<Body
Response::new(Body::from(format!("{}", state.ready)))
}
(&Method::GET, "/check_writability") => {
info!("serving /check_writability GET request");
let res = crate::checker::check_writability(&state).await;
match res {
Ok(_) => Response::new(Body::from("true")),
Err(e) => Response::new(Body::from(e.to_string())),
}
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));
@@ -48,7 +57,7 @@ async fn serve(state: Arc<RwLock<ComputeState>>) {
async move {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
let state = state.clone();
async move { Ok::<_, Infallible>(routes(req, state)) }
async move { Ok::<_, Infallible>(routes(req, state).await) }
}))
}
});
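With the route above in place, the new endpoint can be exercised with any HTTP client. Below is a minimal sketch (not part of compute_tools) using `reqwest` with its `blocking` feature; the address is a placeholder for wherever the compute node's HTTP server actually listens.

```rust
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder address: substitute the compute node's real HTTP endpoint.
    let body = reqwest::blocking::get("http://127.0.0.1:3080/check_writability")?.text()?;
    // The handler above returns "true" on success, or the error text otherwise.
    println!("writability check: {}", body);
    Ok(())
}
```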


@@ -2,6 +2,7 @@
//! Various tools and helpers to handle cluster / compute node (Postgres)
//! configuration.
//!
pub mod checker;
pub mod config;
pub mod http_api;
#[macro_use]


@@ -37,7 +37,6 @@ toml_edit = { version = "0.13", features = ["easy"] }
scopeguard = "1.1.0"
const_format = "0.2.21"
tracing = "0.1.27"
tracing-futures = "0.2"
signal-hook = "0.3.10"
url = "2"
nix = "0.23"


@@ -36,8 +36,8 @@ pub mod defaults {
// Target file size, when creating image and delta layers.
// This parameter determines L1 layer file size.
pub const DEFAULT_COMPACTION_TARGET_SIZE: u64 = 128 * 1024 * 1024;
pub const DEFAULT_COMPACTION_PERIOD: &str = "1 s";
pub const DEFAULT_COMPACTION_THRESHOLD: usize = 10;
pub const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024;
pub const DEFAULT_GC_PERIOD: &str = "100 s";
@@ -65,6 +65,7 @@ pub mod defaults {
#checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes
#compaction_target_size = {DEFAULT_COMPACTION_TARGET_SIZE} # in bytes
#compaction_period = '{DEFAULT_COMPACTION_PERIOD}'
#compaction_threshold = '{DEFAULT_COMPACTION_THRESHOLD}'
#gc_period = '{DEFAULT_GC_PERIOD}'
#gc_horizon = {DEFAULT_GC_HORIZON}
@@ -107,6 +108,9 @@ pub struct PageServerConf {
// How often to check if there's compaction work to be done.
pub compaction_period: Duration,
// Level0 delta layer threshold for compaction.
pub compaction_threshold: usize,
pub gc_horizon: u64,
pub gc_period: Duration,
@@ -164,6 +168,7 @@ struct PageServerConfigBuilder {
compaction_target_size: BuilderValue<u64>,
compaction_period: BuilderValue<Duration>,
compaction_threshold: BuilderValue<usize>,
gc_horizon: BuilderValue<u64>,
gc_period: BuilderValue<Duration>,
@@ -202,6 +207,7 @@ impl Default for PageServerConfigBuilder {
compaction_target_size: Set(DEFAULT_COMPACTION_TARGET_SIZE),
compaction_period: Set(humantime::parse_duration(DEFAULT_COMPACTION_PERIOD)
.expect("cannot parse default compaction period")),
compaction_threshold: Set(DEFAULT_COMPACTION_THRESHOLD),
gc_horizon: Set(DEFAULT_GC_HORIZON),
gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD)
.expect("cannot parse default gc period")),
@@ -246,6 +252,10 @@ impl PageServerConfigBuilder {
self.compaction_period = BuilderValue::Set(compaction_period)
}
pub fn compaction_threshold(&mut self, compaction_threshold: usize) {
self.compaction_threshold = BuilderValue::Set(compaction_threshold)
}
pub fn gc_horizon(&mut self, gc_horizon: u64) {
self.gc_horizon = BuilderValue::Set(gc_horizon)
}
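The new `compaction_threshold` knob follows the same builder pattern as the other settings: every field starts out as `Set(default)` and an explicit key in `pageserver.toml` overrides it via a setter (see the parsing hunk below). A stripped-down sketch of that pattern, with simplified names rather than the actual types:

```rust
// Each setting is pre-populated with its default; parsing the TOML file
// calls the corresponding setter to override it.
enum BuilderValue<T> {
    Set(T),
    #[allow(dead_code)]
    NotSet,
}

impl<T> BuilderValue<T> {
    fn ok_or<E>(self, err: E) -> Result<T, E> {
        match self {
            BuilderValue::Set(v) => Ok(v),
            BuilderValue::NotSet => Err(err),
        }
    }
}

struct Builder {
    compaction_threshold: BuilderValue<usize>,
}

impl Builder {
    fn new() -> Self {
        // mirrors DEFAULT_COMPACTION_THRESHOLD = 10
        Builder { compaction_threshold: BuilderValue::Set(10) }
    }
    fn compaction_threshold(&mut self, v: usize) {
        self.compaction_threshold = BuilderValue::Set(v);
    }
    fn build(self) -> Result<usize, &'static str> {
        self.compaction_threshold.ok_or("missing compaction_threshold")
    }
}

fn main() {
    let mut b = Builder::new();
    b.compaction_threshold(20); // e.g. `compaction_threshold = 20` in the TOML
    assert_eq!(b.build(), Ok(20));
}
```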
@@ -322,6 +332,9 @@ impl PageServerConfigBuilder {
compaction_period: self
.compaction_period
.ok_or(anyhow::anyhow!("missing compaction_period"))?,
compaction_threshold: self
.compaction_threshold
.ok_or(anyhow::anyhow!("missing compaction_threshold"))?,
gc_horizon: self
.gc_horizon
.ok_or(anyhow::anyhow!("missing gc_horizon"))?,
@@ -465,6 +478,9 @@ impl PageServerConf {
builder.compaction_target_size(parse_toml_u64(key, item)?)
}
"compaction_period" => builder.compaction_period(parse_toml_duration(key, item)?),
"compaction_threshold" => {
builder.compaction_threshold(parse_toml_u64(key, item)? as usize)
}
"gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?),
"gc_period" => builder.gc_period(parse_toml_duration(key, item)?),
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
@@ -603,6 +619,7 @@ impl PageServerConf {
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
compaction_target_size: 4 * 1024 * 1024,
compaction_period: Duration::from_secs(10),
compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD,
gc_horizon: defaults::DEFAULT_GC_HORIZON,
gc_period: Duration::from_secs(10),
wait_lsn_timeout: Duration::from_secs(60),
@@ -676,6 +693,7 @@ checkpoint_distance = 111 # in bytes
compaction_target_size = 111 # in bytes
compaction_period = '111 s'
compaction_threshold = 2
gc_period = '222 s'
gc_horizon = 222
@@ -714,6 +732,7 @@ id = 10
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
compaction_target_size: defaults::DEFAULT_COMPACTION_TARGET_SIZE,
compaction_period: humantime::parse_duration(defaults::DEFAULT_COMPACTION_PERIOD)?,
compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD,
gc_horizon: defaults::DEFAULT_GC_HORIZON,
gc_period: humantime::parse_duration(defaults::DEFAULT_GC_PERIOD)?,
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
@@ -760,6 +779,7 @@ id = 10
checkpoint_distance: 111,
compaction_target_size: 111,
compaction_period: Duration::from_secs(111),
compaction_threshold: 2,
gc_horizon: 222,
gc_period: Duration::from_secs(222),
wait_lsn_timeout: Duration::from_secs(111),


@@ -49,7 +49,8 @@ use crate::CheckpointConfig;
use crate::{ZTenantId, ZTimelineId};
use zenith_metrics::{
register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec, IntGauge, IntGaugeVec,
register_histogram_vec, register_int_counter, register_int_gauge_vec, Histogram, HistogramVec,
IntCounter, IntGauge, IntGaugeVec,
};
use zenith_utils::crashsafe_dir;
use zenith_utils::lsn::{AtomicLsn, Lsn, RecordLsn};
@@ -109,6 +110,21 @@ lazy_static! {
.expect("failed to define a metric");
}
// Metrics for cloud upload. These metrics reflect data uploaded to cloud storage,
// or in testing they estimate how much we would upload if we did.
lazy_static! {
static ref NUM_PERSISTENT_FILES_CREATED: IntCounter = register_int_counter!(
"pageserver_num_persistent_files_created",
"Number of files created that are meant to be uploaded to cloud storage",
)
.expect("failed to define a metric");
static ref PERSISTENT_BYTES_WRITTEN: IntCounter = register_int_counter!(
"pageserver_persistent_bytes_written",
"Total bytes written that are meant to be uploaded to cloud storage",
)
.expect("failed to define a metric");
}
/// Parts of the `.zenith/tenants/<tenantid>/timelines/<timelineid>` directory prefix.
pub const TIMELINES_SEGMENT_NAME: &str = "timelines";
@@ -193,7 +209,7 @@ impl Repository for LayeredRepository {
Arc::clone(&self.walredo_mgr),
self.upload_layers,
);
timeline.layers.lock().unwrap().next_open_layer_at = Some(initdb_lsn);
timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn);
let timeline = Arc::new(timeline);
let r = timelines.insert(
@@ -725,7 +741,7 @@ pub struct LayeredTimeline {
tenantid: ZTenantId,
timelineid: ZTimelineId,
layers: Mutex<LayerMap>,
layers: RwLock<LayerMap>,
last_freeze_at: AtomicLsn,
@@ -997,7 +1013,7 @@ impl LayeredTimeline {
conf,
timelineid,
tenantid,
layers: Mutex::new(LayerMap::default()),
layers: RwLock::new(LayerMap::default()),
walredo_mgr,
@@ -1040,7 +1056,7 @@ impl LayeredTimeline {
/// Returns all timeline-related files that were found and loaded.
///
fn load_layer_map(&self, disk_consistent_lsn: Lsn) -> anyhow::Result<()> {
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
let mut num_layers = 0;
// Scan timeline directory and create ImageFileName and DeltaFilename
@@ -1194,7 +1210,7 @@ impl LayeredTimeline {
continue;
}
let layers = timeline.layers.lock().unwrap();
let layers = timeline.layers.read().unwrap();
// Check the open and frozen in-memory layers first
if let Some(open_layer) = &layers.open_layer {
@@ -1276,7 +1292,7 @@ impl LayeredTimeline {
/// Get a handle to the latest layer for appending.
///
fn get_layer_for_write(&self, lsn: Lsn) -> anyhow::Result<Arc<InMemoryLayer>> {
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
ensure!(lsn.is_aligned());
@@ -1347,7 +1363,7 @@ impl LayeredTimeline {
} else {
Some(self.write_lock.lock().unwrap())
};
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
if let Some(open_layer) = &layers.open_layer {
let open_layer_rc = Arc::clone(open_layer);
// Does this layer need freezing?
@@ -1412,7 +1428,7 @@ impl LayeredTimeline {
let timer = self.flush_time_histo.start_timer();
loop {
let layers = self.layers.lock().unwrap();
let layers = self.layers.read().unwrap();
if let Some(frozen_layer) = layers.frozen_layers.front() {
let frozen_layer = Arc::clone(frozen_layer);
drop(layers); // to allow concurrent reads and writes
@@ -1456,7 +1472,7 @@ impl LayeredTimeline {
// Finally, replace the frozen in-memory layer with the new on-disk layers
{
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
let l = layers.frozen_layers.pop_front();
// Only one thread may call this function at a time (for this
@@ -1524,6 +1540,10 @@ impl LayeredTimeline {
&metadata,
false,
)?;
NUM_PERSISTENT_FILES_CREATED.inc_by(1);
PERSISTENT_BYTES_WRITTEN.inc_by(new_delta_path.metadata()?.len());
if self.upload_layers.load(atomic::Ordering::Relaxed) {
schedule_timeline_checkpoint_upload(
self.tenantid,
@@ -1612,7 +1632,7 @@ impl LayeredTimeline {
lsn: Lsn,
threshold: usize,
) -> Result<bool> {
let layers = self.layers.lock().unwrap();
let layers = self.layers.read().unwrap();
for part_range in &partition.ranges {
let image_coverage = layers.image_coverage(part_range, lsn)?;
@@ -1670,7 +1690,7 @@ impl LayeredTimeline {
// FIXME: Do we need to do something to upload it to remote storage here?
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
layers.insert_historic(Arc::new(image_layer));
drop(layers);
@@ -1678,15 +1698,13 @@ impl LayeredTimeline {
}
fn compact_level0(&self, target_file_size: u64) -> Result<()> {
let layers = self.layers.lock().unwrap();
// We compact or "shuffle" the level-0 delta layers when 10 have
// accumulated.
static COMPACT_THRESHOLD: usize = 10;
let layers = self.layers.read().unwrap();
let level0_deltas = layers.get_level0_deltas()?;
if level0_deltas.len() < COMPACT_THRESHOLD {
// We compact or "shuffle" the level-0 delta layers when they've
// accumulated over the compaction threshold.
if level0_deltas.len() < self.conf.compaction_threshold {
return Ok(());
}
drop(layers);
@@ -1770,7 +1788,7 @@ impl LayeredTimeline {
layer_paths.pop().unwrap();
}
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
for l in new_layers {
layers.insert_historic(Arc::new(l));
}
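A recurring change in this file is switching the timeline's layer map from `Mutex<LayerMap>` to `RwLock<LayerMap>`, so read-mostly paths (page lookups, image-layer checks, compaction scans) no longer serialize behind each other, while paths that mutate the map take the exclusive `write()` lock. A minimal sketch of the idiom, independent of the real `LayerMap` type:

```rust
use std::sync::RwLock;

struct Timeline {
    // Previously Mutex<...>: every reader excluded every other reader.
    layers: RwLock<Vec<String>>,
}

impl Timeline {
    fn count_layers(&self) -> usize {
        // Shared access: many readers may hold this guard at once.
        self.layers.read().unwrap().len()
    }

    fn insert_layer(&self, name: String) {
        // Exclusive access, as in get_layer_for_write() and checkpointing above.
        self.layers.write().unwrap().push(name);
    }
}

fn main() {
    let tl = Timeline { layers: RwLock::new(Vec::new()) };
    tl.insert_layer("000000-delta".to_owned());
    assert_eq!(tl.count_layers(), 1);
}
```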
@@ -1852,7 +1870,7 @@ impl LayeredTimeline {
// 2. it doesn't need to be retained for 'retain_lsns';
// 3. newer on-disk image layers cover the layer's whole key range
//
let mut layers = self.layers.lock().unwrap();
let mut layers = self.layers.write().unwrap();
'outer: for l in layers.iter_historic_layers() {
// This layer is in the process of being flushed to disk.
// It will be swapped out of the layer map, replaced with


@@ -198,7 +198,6 @@ impl BlockWriter for BlockBuf {
assert!(buf.len() == PAGE_SZ);
let blknum = self.blocks.len();
self.blocks.push(buf);
tracing::info!("buffered block {}", blknum);
Ok(blknum as u32)
}
}


@@ -713,6 +713,26 @@ impl postgres_backend::Handler for PageServerHandler {
Some(result.elapsed.as_millis().to_string().as_bytes()),
]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("compact ") {
// Run compaction immediately on given timeline.
// FIXME This is just for tests. Don't expect this to be exposed to
// the users or the api.
// compact <tenant_id> <timeline_id>
let re = Regex::new(r"^compact ([[:xdigit:]]+)\s([[:xdigit:]]+)($|\s)?").unwrap();
let caps = re
.captures(query_string)
.with_context(|| format!("Invalid compact: '{}'", query_string))?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
let timelineid = ZTimelineId::from_str(caps.get(2).unwrap().as_str())?;
let timeline = tenant_mgr::get_timeline_for_tenant_load(tenantid, timelineid)
.context("Couldn't load timeline")?;
timeline.tline.compact()?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("checkpoint ") {
// Run checkpoint immediately on given timeline.


@@ -1,6 +1,6 @@
//! Timeline synchrnonization logic to put files from archives on remote storage into pageserver's local directory.
use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc};
use std::{collections::BTreeSet, path::PathBuf, sync::Arc};
use anyhow::{ensure, Context};
use tokio::fs;
@@ -64,11 +64,16 @@ pub(super) async fn download_timeline<
let remote_timeline = match index_read.timeline_entry(&sync_id) {
None => {
error!("Cannot download: no timeline is present in the index for given id");
drop(index_read);
return DownloadedTimeline::Abort;
}
Some(index_entry) => match index_entry.inner() {
TimelineIndexEntryInner::Full(remote_timeline) => Cow::Borrowed(remote_timeline),
TimelineIndexEntryInner::Full(remote_timeline) => {
let cloned = remote_timeline.clone();
drop(index_read);
cloned
}
TimelineIndexEntryInner::Description(_) => {
// we do not check here for awaits_download because it is ok
// to call this function while the download is in progress
@@ -84,7 +89,7 @@ pub(super) async fn download_timeline<
)
.await
{
Ok(remote_timeline) => Cow::Owned(remote_timeline),
Ok(remote_timeline) => remote_timeline,
Err(e) => {
error!("Failed to download full timeline index: {:?}", e);


@@ -1,6 +1,6 @@
//! Timeline synchronization logic to compress and upload to the remote storage all new timeline files from the checkpoints.
use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc};
use std::{collections::BTreeSet, path::PathBuf, sync::Arc};
use tracing::{debug, error, warn};
@@ -46,13 +46,21 @@ pub(super) async fn upload_timeline_checkpoint<
let index_read = index.read().await;
let remote_timeline = match index_read.timeline_entry(&sync_id) {
None => None,
None => {
drop(index_read);
None
}
Some(entry) => match entry.inner() {
TimelineIndexEntryInner::Full(remote_timeline) => Some(Cow::Borrowed(remote_timeline)),
TimelineIndexEntryInner::Full(remote_timeline) => {
let r = Some(remote_timeline.clone());
drop(index_read);
r
}
TimelineIndexEntryInner::Description(_) => {
drop(index_read);
debug!("Found timeline description for the given ids, downloading the full index");
match fetch_full_index(remote_assets.as_ref(), &timeline_dir, sync_id).await {
Ok(remote_timeline) => Some(Cow::Owned(remote_timeline)),
Ok(remote_timeline) => Some(remote_timeline),
Err(e) => {
error!("Failed to download full timeline index: {:?}", e);
sync_queue::push(SyncTask::new(
@@ -82,7 +90,6 @@ pub(super) async fn upload_timeline_checkpoint<
let already_uploaded_files = remote_timeline
.map(|timeline| timeline.stored_files(&timeline_dir))
.unwrap_or_default();
drop(index_read);
match try_upload_checkpoint(
config,


@@ -252,8 +252,10 @@ pub trait Repository: Send + Sync {
checkpoint_before_gc: bool,
) -> Result<GcResult>;
/// perform one compaction iteration.
/// this function is periodically called by compactor thread.
/// Perform one compaction iteration.
/// This function is periodically called by compactor thread.
/// Also it can be explicitly requested per timeline through page server
/// api's 'compact' command.
fn compaction_iteration(&self) -> Result<()>;
/// detaches locally available timeline by stopping all threads and removing all the data.
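As the updated doc comment notes, compaction can now also be requested explicitly through the page server's libpq-style query interface (the `compact <tenant_id> <timeline_id>` handler added earlier in this diff). A rough sketch of issuing it from a test using `tokio-postgres`; the connection string and the tenant/timeline ids are placeholders:

```rust
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Placeholder connection string: point it at the pageserver's libpq listen address.
    let connstr = "host=127.0.0.1 port=64000 user=zenith_admin dbname=postgres";
    let (client, connection) = tokio_postgres::connect(connstr, NoTls).await?;
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    // Hypothetical tenant/timeline ids; the handler parses two hex strings.
    client
        .simple_query("compact 11111111111111111111111111111111 22222222222222222222222222222222")
        .await?;
    Ok(())
}
```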


@@ -5,12 +5,14 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
base64 = "0.13.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
fail = "0.5.0"
futures = "0.3.13"
hashbrown = "0.11.2"
hex = "0.4.3"
hmac = "0.10.1"
hyper = "0.14"
lazy_static = "1.4.0"
md5 = "0.7.0"
@@ -18,12 +20,14 @@ parking_lot = "0.11.2"
pin-project-lite = "0.2.7"
rand = "0.8.3"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
routerify = "2"
rustls = "0.19.1"
scopeguard = "1.1.0"
serde = "1"
serde_json = "1"
sha2 = "0.9.8"
socket2 = "0.4.4"
thiserror = "1.0"
thiserror = "1.0.30"
tokio = { version = "1.17", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-rustls = "0.22.0"
@@ -33,5 +37,7 @@ zenith_metrics = { path = "../zenith_metrics" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
[dev-dependencies]
tokio-postgres-rustls = "0.8.0"
async-trait = "0.1"
rcgen = "0.8.14"
rstest = "0.12"
tokio-postgres-rustls = "0.8.0"


@@ -1,14 +1,24 @@
mod credentials;
#[cfg(test)]
mod flow;
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::error::UserFacingError;
use crate::stream::PqStream;
use crate::waiters;
use std::collections::HashMap;
use std::io;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub use credentials::ClientCredentials;
#[cfg(test)]
pub use flow::*;
/// Common authentication error.
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
@@ -16,13 +26,17 @@ pub enum AuthErrorImpl {
#[error(transparent)]
Console(#[from] cplane_api::AuthError),
#[cfg(test)]
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
/// For passwords that couldn't be processed by [`parse_password`].
#[error("Malformed password message")]
MalformedPassword,
/// Errors produced by [`PqStream`].
#[error(transparent)]
Io(#[from] std::io::Error),
Io(#[from] io::Error),
}
impl AuthErrorImpl {
@@ -67,70 +81,6 @@ impl UserFacingError for AuthError {
}
}
#[derive(Debug, Error)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet")]
MissingKey(&'static str),
}
impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = ClientCredsParseError;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.ok_or(ClientCredsParseError::MissingKey(key))
};
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
}
}
impl ClientCredentials {
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<DatabaseInfo, AuthError> {
fail::fail_point!("proxy-authenticate", |_| {
Err(AuthError::auth_failed("failpoint triggered"))
});
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
match &config.router_config {
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
handle_existing_user(config, client, self).await
} else {
handle_new_user(config, client).await
}
}
Dynamic(Password) => handle_existing_user(config, client, self).await,
Dynamic(Link) => handle_new_user(config, client).await,
}
}
}
fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
async fn handle_static(
host: String,
port: u16,
@@ -169,7 +119,7 @@ async fn handle_existing_user(
let md5_salt = rand::random();
client
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
.write_message(&Be::AuthenticationMD5Password(md5_salt))
.await?;
// Read client's password hash
@@ -213,6 +163,10 @@ async fn handle_new_user(
Ok(db_info)
}
fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
fn parse_password(bytes: &[u8]) -> Option<&str> {
std::str::from_utf8(bytes).ok()?.strip_suffix('\0')
}


@@ -0,0 +1,70 @@
//! User credentials used in authentication.
use super::AuthError;
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::error::UserFacingError;
use crate::stream::PqStream;
use std::collections::HashMap;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(Debug, Error)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet")]
MissingKey(&'static str),
}
impl UserFacingError for ClientCredsParseError {}
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = ClientCredsParseError;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.ok_or(ClientCredsParseError::MissingKey(key))
};
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
}
}
impl ClientCredentials {
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<DatabaseInfo, AuthError> {
fail::fail_point!("proxy-authenticate", |_| {
Err(AuthError::auth_failed("failpoint triggered"))
});
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
match &config.router_config {
Static { host, port } => super::handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
super::handle_existing_user(config, client, self).await
} else {
super::handle_new_user(config, client).await
}
}
Dynamic(Password) => super::handle_existing_user(config, client, self).await,
Dynamic(Link) => super::handle_new_user(config, client).await,
}
}
}
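For reference, the extracted `ClientCredentials` type is driven by converting the startup packet's parameter map via `TryFrom`, failing with `MissingKey` when `user` or `database` is absent. A small sketch, assuming the definitions above are in scope:

```rust
use std::collections::HashMap;

fn demo() -> Result<(), ClientCredsParseError> {
    let mut params = HashMap::new();
    params.insert("user".to_owned(), "alice".to_owned());
    params.insert("database".to_owned(), "main".to_owned());

    let creds = ClientCredentials::try_from(params)?;
    assert_eq!(creds.user, "alice");
    assert_eq!(creds.dbname, "main");

    // A missing `database` parameter yields MissingKey("database").
    let incomplete: HashMap<String, String> =
        [("user".to_owned(), "alice".to_owned())].into_iter().collect();
    assert!(ClientCredentials::try_from(incomplete).is_err());
    Ok(())
}
```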

proxy/src/auth/flow.rs

@@ -0,0 +1,102 @@
//! Main authentication flow.
use super::{AuthError, AuthErrorImpl};
use crate::stream::PqStream;
use crate::{sasl, scram};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
/// Every authentication selector is supposed to implement this trait.
pub trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
}
}
/// Use password-based auth in [`AuthFlow`].
pub struct Md5(
/// Salt for client.
pub [u8; 4],
);
impl AuthMethod for Md5 {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationMD5Password(self.0)
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream>,
/// State might contain ancillary data (see [`AuthFlow::begin`]).
state: State,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<S>) -> Self {
Self {
stream,
state: Begin,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream.write_message(&method.first_message()).await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
})
}
}
/// Stream wrapper for handling simple MD5 password auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Md5> {
/// Perform user authentication. Raise an error in case authentication failed.
#[allow(unused)]
pub async fn authenticate(self) -> Result<(), AuthError> {
unimplemented!("MD5 auth flow is yet to be implemented");
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> Result<(), AuthError> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(AuthErrorImpl::auth_failed("method not supported").into());
}
let secret = self.state.0;
sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(secret, rand::random, None))
.await?;
Ok(())
}
}
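`AuthFlow` uses a small type-state pattern: `begin()` consumes the `Begin` state and returns a flow parameterized by the chosen method, so `authenticate()` only exists once the method's first message has been sent. A simplified, self-contained illustration of the idea (not the real types):

```rust
struct Flow<State> {
    state: State,
}

struct Begin;
struct Md5([u8; 4]);

impl Flow<Begin> {
    fn new() -> Self {
        Flow { state: Begin }
    }
    // Moving to the next state consumes the old flow, as AuthFlow::begin does.
    fn begin(self, salt: [u8; 4]) -> Flow<Md5> {
        Flow { state: Md5(salt) }
    }
}

impl Flow<Md5> {
    // Only callable after begin(): the compiler enforces the protocol order.
    fn authenticate(self) -> Result<(), &'static str> {
        let Md5(_salt) = self.state;
        Ok(())
    }
}

fn main() {
    let flow = Flow::new().begin([0, 1, 2, 3]);
    assert!(flow.authenticate().is_ok());
}
```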


@@ -1,19 +1,8 @@
///
/// Postgres protocol proxy/router.
///
/// This service listens psql port and can check auth via external service
/// (control plane API in our case) and can create new databases and accounts
/// in somewhat transparent manner (again via communication with control plane API).
///
use anyhow::{bail, Context};
use clap::{Arg, Command};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use crate::config::{ClientAuthMethod, RouterConfig};
//! Postgres protocol proxy/router.
//!
//! This service listens psql port and can check auth via external service
//! (control plane API in our case) and can create new databases and accounts
//! in somewhat transparent manner (again via communication with control plane API).
mod auth;
mod cancellation;
@@ -27,6 +16,24 @@ mod proxy;
mod stream;
mod waiters;
// Currently SCRAM is only used in tests
#[cfg(test)]
mod parse;
#[cfg(test)]
mod sasl;
#[cfg(test)]
mod scram;
use anyhow::{bail, Context};
use clap::{Arg, Command};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use crate::config::{ClientAuthMethod, RouterConfig};
/// Flattens `Result<Result<T>>` into `Result<T>`.
async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,

proxy/src/parse.rs

@@ -0,0 +1,18 @@
//! Small parsing helpers.
use std::convert::TryInto;
use std::ffi::CStr;
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
let pos = bytes.iter().position(|&x| x == 0)?;
let (cstr, other) = bytes.split_at(pos + 1);
// SAFETY: we've already checked that there's a terminator
Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other))
}
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
(bytes.len() >= N).then(|| {
let (head, tail) = bytes.split_at(N);
(head.try_into().unwrap(), tail)
})
}
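These helpers are used below to pick apart the SASL `PasswordMessage` payload (a NUL-terminated method name followed by a length-prefixed message). A quick usage sketch, assuming `split_cstr` and `split_at_const` from above are in scope:

```rust
use std::ffi::CStr;

fn demo() {
    // "SCRAM-SHA-256\0" followed by the payload.
    let (method, rest) = split_cstr(b"SCRAM-SHA-256\0payload").unwrap();
    assert_eq!(method, CStr::from_bytes_with_nul(b"SCRAM-SHA-256\0").unwrap());
    assert_eq!(rest, b"payload".as_ref());

    // 4-byte big-endian length prefix followed by the body.
    let (len_bytes, body) = split_at_const::<4>(&[0, 0, 0, 7, b'x']).unwrap();
    assert_eq!(u32::from_be_bytes(*len_bytes), 7);
    assert_eq!(body, &[b'x'][..]);
}
```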


@@ -119,7 +119,6 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
@@ -219,32 +218,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<S> {
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::DuplexStream;
use crate::{auth, scram};
use async_trait::async_trait;
use rstest::rstest;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::MakeRustlsConnect;
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<TlsConfig>,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("no stream")?;
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)
.await?;
Ok(())
}
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
hostname: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
@@ -262,19 +243,115 @@ mod tests {
))
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
hostname: &'a str,
}
impl ClientConfig<'_> {
fn make_tls_connect<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
self,
) -> anyhow::Result<impl tokio_postgres::tls::TlsConnect<S>> {
let mut mk = MakeRustlsConnect::new(self.config);
let tls = MakeTlsConnect::<S>::make_tls_connect(&mut mk, self.hostname)?;
Ok(tls)
}
}
/// Generate TLS certificates and build rustls configs for client and server.
fn generate_tls_config(
hostname: &str,
) -> anyhow::Result<(ClientConfig<'_>, Arc<rustls::ServerConfig>)> {
let (ca, cert, key) = generate_certs(hostname)?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config.into()
};
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
ClientConfig { config, hostname }
};
Ok((client_config, server_config))
}
#[async_trait]
trait TestAuth: Sized {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
self,
_stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
Ok(())
}
}
struct NoAuth;
impl TestAuth for NoAuth {}
struct Scram(scram::ServerSecret);
impl Scram {
fn new(password: &str) -> anyhow::Result<Self> {
let salt = rand::random::<[u8; 16]>();
let secret = scram::ServerSecret::build(password, &salt, 256)
.context("failed to generate scram secret")?;
Ok(Scram(secret))
}
fn mock(user: &str) -> Self {
let salt = rand::random::<[u8; 32]>();
Scram(scram::ServerSecret::mock(user, &salt))
}
}
#[async_trait]
impl TestAuth for Scram {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0))
.await?
.authenticate()
.await?;
Ok(())
}
}
/// A dummy proxy impl which performs a handshake and reports auth success.
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin + Send,
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("handshake failed")?;
auth.authenticate(&mut stream).await?;
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)
.await?;
Ok(())
}
#[tokio::test]
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let server_config = {
let (_ca, cert, key) = generate_certs("localhost")?;
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
let (_, server_config) = generate_tls_config("localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let client_err = tokio_postgres::Config::new()
.user("john_doe")
@@ -301,30 +378,14 @@ mod tests {
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (ca, cert, key) = generate_certs("localhost")?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
config
};
let mut mk = MakeRustlsConnect::new(client_config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
let (client_config, server_config) = generate_tls_config("localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
.connect_raw(server, tls)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
@@ -334,7 +395,7 @@ mod tests {
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -350,7 +411,7 @@ mod tests {
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let client_err = tokio_postgres::Config::new()
.ssl_mode(SslMode::Disable)
@@ -391,4 +452,66 @@ mod tests {
Ok(())
}
#[rstest]
#[case("password_foo")]
#[case("pwd-bar")]
#[case("")]
#[tokio::test]
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) = generate_tls_config("localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::new(password)?,
));
let (_client, _conn) = tokio_postgres::Config::new()
.user("user")
.dbname("db")
.password(password)
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await?;
proxy.await?
}
#[tokio::test]
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) = generate_tls_config("localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::mock("user"),
));
use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(rand::random::<u8>() as usize)
.map(char::from)
.collect();
let _client_err = tokio_postgres::Config::new()
.user("user")
.dbname("db")
.password(&password) // no password will match the mocked secret
.ssl_mode(SslMode::Require)
.connect_raw(server, client_config.make_tls_connect()?)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
let _server_err = proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
Ok(())
}
}

proxy/src/sasl.rs

@@ -0,0 +1,47 @@
//! Simple Authentication and Security Layer.
//!
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
//!
//! Reference implementation:
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
mod channel_binding;
mod messages;
mod stream;
use std::io;
use thiserror::Error;
pub use channel_binding::ChannelBinding;
pub use messages::FirstMessage;
pub use stream::SaslStream;
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub enum Error {
#[error("Failed to authenticate client: {0}")]
AuthenticationFailed(&'static str),
#[error("Channel binding failed: {0}")]
ChannelBindingFailed(&'static str),
#[error("Unsupported channel binding method: {0}")]
ChannelBindingBadMethod(Box<str>),
#[error("Bad client message")]
BadClientMessage,
#[error(transparent)]
Io(#[from] io::Error),
}
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
pub trait Mechanism: Sized {
/// Produce a server challenge to be sent to the client.
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
fn exchange(self, input: &str) -> Result<(Option<Self>, String)>;
}
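To make the contract concrete, here is a toy `Mechanism` (purely illustrative, assuming the trait and the `Result` alias above are in scope): returning `Some(self)` means more round trips are expected, while `None` marks the reply as final.

```rust
struct AcceptAll;

impl Mechanism for AcceptAll {
    fn exchange(self, _input: &str) -> Result<(Option<Self>, String)> {
        // One-step mechanism: the first client message ends the exchange.
        Ok((None, "ok".to_owned()))
    }
}
```

`SaslStream::authenticate` in `stream.rs` below drives such a mechanism until it returns `None`.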


@@ -0,0 +1,85 @@
//! Definition and parser for channel binding flag (a part of the `GS2` header).
/// Channel binding flag (possibly with params).
#[derive(Debug, PartialEq, Eq)]
pub enum ChannelBinding<T> {
/// Client doesn't support channel binding.
NotSupportedClient,
/// Client thinks server doesn't support channel binding.
NotSupportedServer,
/// Client wants to use this type of channel binding.
Required(T),
}
impl<T> ChannelBinding<T> {
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => NotSupportedClient,
NotSupportedServer => NotSupportedServer,
Required(x) => Required(f(x)?),
})
}
}
impl<'a> ChannelBinding<&'a str> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
use ChannelBinding::*;
Some(match input {
"n" => NotSupportedClient,
"y" => NotSupportedServer,
other => Required(other.strip_prefix("p=")?),
})
}
}
impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<String, E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => {
// base64::encode("n,,")
"biws".into()
}
NotSupportedServer => {
// base64::encode("y,,")
"eSws".into()
}
Required(mode) => {
let msg = format!(
"p={mode},,{data}",
mode = mode,
data = get_cbind_data(mode)?
);
base64::encode(msg).into()
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn channel_binding_encode() -> anyhow::Result<()> {
use ChannelBinding::*;
let cases = [
(NotSupportedClient, base64::encode("n,,")),
(NotSupportedServer, base64::encode("y,,")),
(Required("foo"), base64::encode("p=foo,,bar")),
];
for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input);
}
Ok(())
}
}


@@ -0,0 +1,67 @@
//! Definitions for SASL messages.
use crate::parse::{split_at_const, split_cstr};
use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage};
/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]
pub struct FirstMessage<'a> {
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
pub method: &'a str,
/// Initial client message.
pub message: &'a str,
}
impl<'a> FirstMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
let (method_cstr, tail) = split_cstr(bytes)?;
let method = method_cstr.to_str().ok()?;
let (len_bytes, bytes) = split_at_const(tail)?;
let len = u32::from_be_bytes(*len_bytes) as usize;
if len != bytes.len() {
return None;
}
let message = std::str::from_utf8(bytes).ok()?;
Some(Self { method, message })
}
}
/// A single SASL message.
/// This struct is deliberately decoupled from lower-level
/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage).
#[derive(Debug)]
pub(super) enum ServerMessage<T> {
/// We expect to see more steps.
Continue(T),
/// This is the final step.
Final(T),
}
impl<'a> ServerMessage<&'a str> {
pub(super) fn to_reply(&self) -> BeMessage<'a> {
use BeAuthenticationSaslMessage::*;
BeMessage::AuthenticationSasl(match self {
ServerMessage::Continue(s) => Continue(s.as_bytes()),
ServerMessage::Final(s) => Final(s.as_bytes()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_sasl_first_message() {
let proto = "SCRAM-SHA-256";
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
let sasl_len = (sasl.len() as u32).to_be_bytes();
let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
let password = FirstMessage::parse(&bytes).unwrap();
assert_eq!(password.method, proto);
assert_eq!(password.message, sasl);
}
}

proxy/src/sasl/stream.rs

@@ -0,0 +1,70 @@
//! Abstraction for the string-oriented SASL protocols.
use super::{messages::ServerMessage, Mechanism};
use crate::stream::PqStream;
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
/// Abstracts away all peculiarities of the libpq's protocol.
pub struct SaslStream<'a, S> {
/// The underlying stream.
stream: &'a mut PqStream<S>,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a, S> SaslStream<'a, S> {
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> {
loop {
let input = self.recv().await?;
let (moved, reply) = mechanism.exchange(input)?;
match moved {
Some(moved) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved;
}
None => {
self.send(&ServerMessage::Final(&reply)).await?;
return Ok(());
}
}
}
}
}

proxy/src/scram.rs

@@ -0,0 +1,59 @@
//! Salted Challenge Response Authentication Mechanism.
//!
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
//!
//! Reference implementation:
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod exchange;
mod key;
mod messages;
mod password;
mod secret;
mod signature;
pub use secret::*;
pub use exchange::Exchange;
pub use secret::ServerSecret;
use hmac::{Hmac, Mac, NewMac};
use sha2::{Digest, Sha256};
// TODO: add SCRAM-SHA-256-PLUS
/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &["SCRAM-SHA-256"];
/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
let mut bytes = [0u8; N];
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
if size != N {
return None;
}
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_varkey(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
// TODO: maybe newer `hmac` et al already migrated to regular arrays?
let mut result = [0u8; 32];
result.copy_from_slice(mac.finalize().into_bytes().as_slice());
result
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
let mut result = [0u8; 32];
result.copy_from_slice(hasher.finalize().as_slice());
result
}
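
Both helpers take the message as an iterator of byte slices, so callers can hash a concatenation without building an intermediate buffer. A hypothetical test sketch of that property (not part of this change):

#[cfg(test)]
mod hash_helper_sketch {
    use super::{hmac_sha256, sha256};

    #[test]
    fn parts_are_concatenated() {
        // Hashing ["ab"] and ["a", "b"] must produce the same digest.
        assert_eq!(sha256([b"ab".as_ref()]), sha256([b"a".as_ref(), b"b".as_ref()]));
        assert_eq!(
            hmac_sha256(b"key", [b"ab".as_ref()]),
            hmac_sha256(b"key", [b"a".as_ref(), b"b".as_ref()])
        );
    }
}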

134
proxy/src/scram/exchange.rs Normal file
View File

@@ -0,0 +1,134 @@
//! Implementation of the SCRAM authentication algorithm.
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;
impl std::fmt::Display for TlsServerEndPoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "tls-server-end-point")
}
}
impl std::str::FromStr for TlsServerEndPoint {
type Err = sasl::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"tls-server-end-point" => Ok(TlsServerEndPoint),
_ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
}
}
}
#[derive(Debug)]
enum ExchangeState {
/// Waiting for [`ClientFirstMessage`].
Initial,
/// Waiting for [`ClientFinalMessage`].
SaltSent {
cbind_flag: ChannelBinding<TlsServerEndPoint>,
client_first_message_bare: String,
server_first_message: OwnedServerFirstMessage,
},
}
/// Server's side of SCRAM auth algorithm.
#[derive(Debug)]
pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
cert_digest: Option<&'a [u8]>,
) -> Self {
Self {
state: ExchangeState::Initial,
secret,
nonce,
cert_digest,
}
}
}
impl sasl::Mechanism for Exchange<'_> {
fn exchange(mut self, input: &str) -> sasl::Result<(Option<Self>, String)> {
use ExchangeState::*;
match &self.state {
Initial => {
let client_first_message =
ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
self.secret.iterations,
);
let msg = server_first_message.as_str().to_owned();
self.state = SaltSent {
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
client_first_message_bare: client_first_message.bare.to_owned(),
server_first_message,
};
Ok((Some(self), msg))
}
SaltSent {
cbind_flag,
client_first_message_bare,
server_first_message,
} => {
let client_final_message =
ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
let channel_binding = cbind_flag.encode(|_| {
self.cert_digest
.map(base64::encode)
.ok_or(SaslError::ChannelBindingFailed("no cert digest provided"))
})?;
// This might've been caused by a MITM attack
if client_final_message.channel_binding != channel_binding {
return Err(SaslError::ChannelBindingFailed("data mismatch"));
}
if client_final_message.nonce != server_first_message.nonce() {
return Err(SaslError::AuthenticationFailed("bad nonce"));
}
let signature_builder = SignatureBuilder {
client_first_message_bare,
server_first_message: server_first_message.as_str(),
client_final_message_without_proof: client_final_message.without_proof,
};
let client_key = signature_builder
.build(&self.secret.stored_key)
.derive_client_key(&client_final_message.proof);
if client_key.sha256() != self.secret.stored_key {
return Err(SaslError::AuthenticationFailed("keys don't match"));
}
let msg = client_final_message
.build_server_final_message(signature_builder, &self.secret.server_key);
Ok((None, msg))
}
}
}
}
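
To make the state machine concrete, here is a rough sketch of driving just the first step by hand (normally this runs through `SaslStream::authenticate`; the fixed all-zero nonce, the example names, and the exact `use` paths are assumptions for illustration only):

#[cfg(test)]
mod exchange_sketch {
    use super::*;
    use crate::sasl::Mechanism;

    #[test]
    fn first_step() {
        // Mocked secret: enough to produce a server-first-message.
        let secret = ServerSecret::mock("user", &[0u8; 32]);
        // Fixed nonce (SCRAM_RAW_NONCE_LEN = 18) and no TLS cert digest.
        let exchange = Exchange::new(&secret, || [0u8; 18], None);

        // "n,," means the client neither supports nor requests channel binding.
        let (next, server_first) = exchange
            .exchange("n,,n=user,r=clientnonce")
            .expect("client-first-message should parse");

        assert!(next.is_some(), "now waiting for the client-final-message");
        assert!(server_first.starts_with("r=clientnonce"), "combined nonce echoes the client's");
    }
}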

33
proxy/src/scram/key.rs Normal file
View File

@@ -0,0 +1,33 @@
//! Tools for client/server/stored key management.
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the [password](super::password::SaltedPassword).
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Default, Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
#[inline(always)]
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
Self { bytes }
}
}
impl AsRef<[u8]> for ScramKey {
#[inline(always)]
fn as_ref(&self) -> &[u8] {
&self.bytes
}
}

232
proxy/src/scram/messages.rs Normal file
View File

@@ -0,0 +1,232 @@
//! Definitions for SCRAM messages.
use super::base64_decode_array;
use super::key::{ScramKey, SCRAM_KEY_LEN};
use super::signature::SignatureBuilder;
use crate::sasl::ChannelBinding;
use std::fmt;
use std::ops::Range;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
/// Although we ignore all extensions, we still have to validate the message.
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
for mut chars in parts.map(|s| s.chars()) {
let attr = chars.next()?;
if !('a'..='z').contains(&attr) && !('A'..='Z').contains(&attr) {
return None;
}
let eq = chars.next()?;
if eq != '=' {
return None;
}
}
Some(())
}
#[derive(Debug)]
pub struct ClientFirstMessage<'a> {
/// `client-first-message-bare`.
pub bare: &'a str,
/// Channel binding mode.
pub cbind_flag: ChannelBinding<&'a str>,
/// [Client username](https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13).
pub username: &'a str,
/// Client nonce.
pub nonce: &'a str,
}
impl<'a> ClientFirstMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
let mut parts = input.split(',');
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
// PG doesn't support authorization identity,
// so we don't bother defining GS2 header type
let authzid = parts.next()?;
if !authzid.is_empty() {
return None;
}
// Unfortunately, `parts.as_str()` is unstable
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
let (_, bare) = input.split_at(pos);
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
let username = parts.next()?.strip_prefix("n=")?;
let nonce = parts.next()?.strip_prefix("r=")?;
// Validate but ignore auth extensions
validate_sasl_extensions(parts)?;
Some(Self {
bare,
cbind_flag,
username,
nonce,
})
}
/// Build a response to [`ClientFirstMessage`].
pub fn build_server_first_message(
&self,
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message
OwnedServerFirstMessage {
message,
nonce: combined_nonce,
}
}
}
#[derive(Debug)]
pub struct ClientFinalMessage<'a> {
/// `client-final-message-without-proof`.
pub without_proof: &'a str,
/// Channel binding data (base64).
pub channel_binding: &'a str,
/// Combined client & server nonce.
pub nonce: &'a str,
/// Client auth proof.
pub proof: [u8; SCRAM_KEY_LEN],
}
impl<'a> ClientFinalMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
let (without_proof, proof) = input.rsplit_once(',')?;
let mut parts = without_proof.split(',');
let channel_binding = parts.next()?.strip_prefix("c=")?;
let nonce = parts.next()?.strip_prefix("r=")?;
// Validate but ignore auth extensions
validate_sasl_extensions(parts)?;
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
Some(Self {
without_proof,
channel_binding,
nonce,
proof,
})
}
/// Build a response to [`ClientFinalMessage`].
pub fn build_server_final_message(
&self,
signature_builder: SignatureBuilder,
server_key: &ScramKey,
) -> String {
let mut buf = String::from("v=");
base64::encode_config_buf(
signature_builder.build(server_key),
base64::STANDARD,
&mut buf,
);
buf
}
}
/// We need to keep a convenient representation of this
/// message for the next authentication step.
pub struct OwnedServerFirstMessage {
/// Owned `server-first-message`.
message: String,
/// Slice into `message`.
nonce: Range<usize>,
}
impl OwnedServerFirstMessage {
/// Extract combined nonce from the message.
#[inline(always)]
pub fn nonce(&self) -> &str {
&self.message[self.nonce.clone()]
}
/// Get reference to a text representation of the message.
#[inline(always)]
pub fn as_str(&self) -> &str {
&self.message
}
}
impl fmt::Debug for OwnedServerFirstMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServerFirstMessage")
.field("message", &self.as_str())
.field("nonce", &self.nonce())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_client_first_message() {
use ChannelBinding::*;
// (Almost) real strings captured during debug sessions
let cases = [
(NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
(NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
(
Required("tls-server-end-point"),
"p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
),
];
for (cb, input) in cases {
let msg = ClientFirstMessage::parse(input).unwrap();
assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.username, "pepe");
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.cbind_flag, cb);
}
}
#[test]
fn parse_client_final_message() {
let input = [
"c=eSws",
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
]
.join(",");
let msg = ClientFinalMessage::parse(&input).unwrap();
assert_eq!(
msg.without_proof,
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
msg.nonce,
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
base64::encode(msg.proof),
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}
}

48
proxy/src/scram/password.rs Normal file
View File

@@ -0,0 +1,48 @@
//! Password hashing routines.
use super::key::ScramKey;
pub const SALTED_PASSWORD_LEN: usize = 32;
/// Salted hashed password is essential for [key](super::key) derivation.
#[repr(transparent)]
pub struct SaltedPassword {
bytes: [u8; SALTED_PASSWORD_LEN],
}
impl SaltedPassword {
/// See `scram-common.c : scram_SaltedPassword` for details.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
let one = 1_u32.to_be_bytes(); // big-endian block index INT(1), per PBKDF2 (RFC 2898)
let mut current = super::hmac_sha256(password, [salt, &one]);
let mut result = current;
for _ in 1..iterations {
current = super::hmac_sha256(password, [current.as_ref()]);
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
for (i, x) in current.iter().enumerate() {
result[i] ^= x;
}
}
result.into()
}
/// Derive `ClientKey` from a salted hashed password.
pub fn client_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
}
/// Derive `ServerKey` from a salted hashed password.
pub fn server_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
}
}
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
#[inline(always)]
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
Self { bytes }
}
}
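
A hypothetical test sketch of the full key chain built from this module (the password and salt values are illustrative; the derivation order matches RFC 5802 and `ServerSecret::build` below):

#[cfg(test)]
mod key_chain_sketch {
    use super::SaltedPassword;

    #[test]
    fn derive_key_chain() {
        // SaltedPassword = PBKDF2-HMAC-SHA-256(password, salt, iterations)
        let salted = SaltedPassword::new(b"password", b"salt", 4096);
        // ClientKey = HMAC(SaltedPassword, "Client Key"),
        // StoredKey = SHA-256(ClientKey) -- this is what the server stores.
        let stored_key = salted.client_key().sha256();
        // ServerKey = HMAC(SaltedPassword, "Server Key") -- signs server-final.
        let server_key = salted.server_key();
        assert_ne!(stored_key, server_key);
    }
}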

116
proxy/src/scram/secret.rs Normal file
View File

@@ -0,0 +1,116 @@
//! Tools for SCRAM server secret management.
use super::base64_decode_array;
use super::key::ScramKey;
/// Server secret is produced from [password](super::password::SaltedPassword)
/// and is used throughout the authentication process.
#[derive(Debug)]
pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub iterations: u32,
/// Salt used to hash user's password.
pub salt_base64: String,
/// Hashed `ClientKey`.
pub stored_key: ScramKey,
/// Used by client to verify server's signature.
pub server_key: ScramKey,
}
impl ServerSecret {
pub fn parse(input: &str) -> Option<Self> {
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
let s = input.strip_prefix("SCRAM-SHA-256$")?;
let (params, keys) = s.split_once('$')?;
let ((iterations, salt), (stored_key, server_key)) =
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
};
Some(secret)
}
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.
pub fn mock(user: &str, nonce: &[u8; 32]) -> Self {
// Refer to `auth-scram.c : scram_mock_salt`.
let mocked_salt = super::sha256([user.as_bytes(), nonce]);
Self {
iterations: 4096,
salt_base64: base64::encode(&mocked_salt),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
}
}
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
// TODO: implement proper password normalization required by the RFC
if !password.is_ascii() {
return None;
}
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
Some(Self {
iterations,
salt_base64: base64::encode(&salt),
stored_key: password.client_key().sha256(),
server_key: password.server_key(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_scram_secret() {
let iterations = 4096;
let salt = "+/tQQax7twvwTj64mjBsxQ==";
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
let secret = format!(
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
iterations = iterations,
salt = salt,
stored_key = stored_key,
server_key = server_key,
);
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
}
#[test]
fn build_scram_secret() {
let salt = b"salt";
let secret = ServerSecret::build("password", salt, 4096).unwrap();
assert_eq!(secret.iterations, 4096);
assert_eq!(secret.salt_base64, base64::encode(salt));
assert_eq!(
base64::encode(secret.stored_key.as_ref()),
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
);
assert_eq!(
base64::encode(secret.server_key.as_ref()),
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
);
}
}

66
proxy/src/scram/signature.rs Normal file
View File

@@ -0,0 +1,66 @@
//! Tools for client/server signature management.
use super::key::{ScramKey, SCRAM_KEY_LEN};
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
pub struct SignatureBuilder<'a> {
pub client_first_message_bare: &'a str,
pub server_first_message: &'a str,
pub client_final_message_without_proof: &'a str,
}
impl SignatureBuilder<'_> {
pub fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
super::hmac_sha256(key.as_ref(), parts).into()
}
}
/// A computed value which, when xored with `ClientProof`,
/// produces `ClientKey` that we need for authentication.
#[derive(Debug)]
#[repr(transparent)]
pub struct Signature {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Signature {
/// Derive `ClientKey` from client's signature and proof.
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
// This is how the proof is calculated:
//
// 1. sha256(ClientKey) -> StoredKey
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
// 3. ClientKey ^ ClientSignature -> ClientProof
//
// Step 3 implies that we can restore ClientKey from the proof
// by xoring the latter with the ClientSignature. Afterwards we
// can check that the presumed ClientKey meets our expectations.
let mut signature = self.bytes;
for (i, x) in proof.iter().enumerate() {
signature[i] ^= x;
}
signature.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
Self { bytes }
}
}
impl AsRef<[u8]> for Signature {
fn as_ref(&self) -> &[u8] {
&self.bytes
}
}
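
The XOR relationship that `derive_client_key` relies on can be spelled out as a hypothetical property check (a sketch only; the key and signature bytes are made up):

#[cfg(test)]
mod xor_roundtrip_sketch {
    use super::*;

    #[test]
    fn proof_roundtrip() {
        let client_key: [u8; SCRAM_KEY_LEN] = [0xAB; SCRAM_KEY_LEN];
        let signature = Signature::from([0x5C; SCRAM_KEY_LEN]);

        // A client builds ClientProof = ClientKey ^ ClientSignature...
        let mut proof = client_key;
        for (i, x) in signature.as_ref().iter().enumerate() {
            proof[i] ^= x;
        }

        // ...and the server recovers ClientKey by xoring the proof back.
        let recovered = signature.derive_client_key(&proof);
        assert_eq!(recovered, ScramKey::from(client_key));
    }
}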

View File

@@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv):
pg2 = env.postgres.create_start('test_createuser2')
# Test that you can connect to new branch as a new user
assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )]
assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )]

View File

@@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str)
copy_input = repeat_bytes(buf.read(), 5000)
pg_conn = await pg.connect_async()
# PgProtocol.connect_async sets statement_timeout to 2 minutes.
# That's not enough for this test, on a slow system in debug mode.
await pg_conn.execute("SET statement_timeout='300s'")
await pg_conn.copy_to_table(table_name, source=copy_input)

View File

@@ -1,14 +0,0 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pgbench", "empty")
pg = env.postgres.create_start('test_pgbench')
log.info("postgres is running on 'test_pgbench' branch")
connstr = pg.connstr()
pg_bin.run_capture(['pgbench', '-i', connstr])
pg_bin.run_capture(['pgbench'] + '-c 10 -T 5 -P 1 -M prepared'.split() + [connstr])

View File

@@ -5,11 +5,14 @@ def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
@pytest.mark.xfail # Proxy eats the extra connection options
# Pass extra options to the server.
#
# Currently, proxy eats the extra connection options, so this fails.
# See https://github.com/neondatabase/neon/issues/1287
@pytest.mark.xfail
def test_proxy_options(static_proxy):
schema_name = "tmp_schema_1"
with static_proxy.connect(schema=schema_name) as conn:
with static_proxy.connect(options="-cproxytest.option=value") as conn:
with conn.cursor() as cur:
cur.execute("SHOW search_path;")
search_path = cur.fetchall()[0][0]
assert schema_name == search_path
cur.execute("SHOW proxytest.option;")
value = cur.fetchall()[0][0]
assert value == 'value'

View File

@@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol):
tenant_id: uuid.UUID,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')
super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres')
self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin

View File

@@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys
]
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin,
]
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c
log.info(pg_regress_command)
env_vars = {
'PGPORT': str(pg.port),
'PGUSER': pg.username,
'PGHOST': pg.host,
'PGPORT': str(pg.default_options['port']),
'PGUSER': pg.default_options['user'],
'PGHOST': pg.default_options['host'],
}
# Run the command.

View File

@@ -17,7 +17,7 @@ import warnings
from contextlib import contextmanager
# Type-related stuff
from typing import Iterator
from typing import Iterator, Optional
"""
This file contains fixtures for micro-benchmarks.
@@ -51,17 +51,12 @@ in the test initialization, or measure disk usage after the test query.
@dataclasses.dataclass
class PgBenchRunResult:
scale: int
number_of_clients: int
number_of_threads: int
number_of_transactions_actually_processed: int
latency_average: float
latency_stddev: float
tps_including_connection_time: float
tps_excluding_connection_time: float
init_duration: float
init_start_timestamp: int
init_end_timestamp: int
latency_stddev: Optional[float]
tps: float
run_duration: float
run_start_timestamp: int
run_end_timestamp: int
@@ -69,56 +64,67 @@ class PgBenchRunResult:
# TODO progress
@classmethod
def parse_from_output(
def parse_from_stdout(
cls,
out: 'subprocess.CompletedProcess[str]',
init_duration: float,
init_start_timestamp: int,
init_end_timestamp: int,
stdout: str,
run_duration: float,
run_start_timestamp: int,
run_end_timestamp: int,
):
stdout_lines = out.stdout.splitlines()
stdout_lines = stdout.splitlines()
latency_stddev = None
# we know significant parts of these values from test input
# but to be precise take them from output
# scaling factor: 5
assert "scaling factor" in stdout_lines[1]
scale = int(stdout_lines[1].split()[-1])
# number of clients: 1
assert "number of clients" in stdout_lines[3]
number_of_clients = int(stdout_lines[3].split()[-1])
# number of threads: 1
assert "number of threads" in stdout_lines[4]
number_of_threads = int(stdout_lines[4].split()[-1])
# number of transactions actually processed: 1000/1000
assert "number of transactions actually processed" in stdout_lines[6]
number_of_transactions_actually_processed = int(stdout_lines[6].split("/")[1])
# latency average = 19.894 ms
assert "latency average" in stdout_lines[7]
latency_average = stdout_lines[7].split()[-2]
# latency stddev = 3.387 ms
assert "latency stddev" in stdout_lines[8]
latency_stddev = stdout_lines[8].split()[-2]
# tps = 50.219689 (including connections establishing)
assert "(including connections establishing)" in stdout_lines[9]
tps_including_connection_time = stdout_lines[9].split()[2]
# tps = 50.264435 (excluding connections establishing)
assert "(excluding connections establishing)" in stdout_lines[10]
tps_excluding_connection_time = stdout_lines[10].split()[2]
for line in stdout.splitlines():
# scaling factor: 5
if line.startswith("scaling factor:"):
scale = int(line.split()[-1])
# number of clients: 1
if line.startswith("number of clients: "):
number_of_clients = int(line.split()[-1])
# number of threads: 1
if line.startswith("number of threads: "):
number_of_threads = int(line.split()[-1])
# number of transactions actually processed: 1000/1000
# OR
# number of transactions actually processed: 1000
if line.startswith("number of transactions actually processed"):
if "/" in line:
number_of_transactions_actually_processed = int(line.split("/")[1])
else:
number_of_transactions_actually_processed = int(line.split()[-1])
# latency average = 19.894 ms
if line.startswith("latency average"):
latency_average = float(line.split()[-2])
# latency stddev = 3.387 ms
# (only printed with some options)
if line.startswith("latency stddev"):
latency_stddev = float(line.split()[-2])
# Get the TPS without initial connection time. The format
# of the tps lines changed in pgbench v14, but we accept
# either format:
#
# pgbench v13 and below:
# tps = 50.219689 (including connections establishing)
# tps = 50.264435 (excluding connections establishing)
#
# pgbench v14:
# initial connection time = 3.858 ms
# tps = 309.281539 (without initial connection time)
if (line.startswith("tps = ") and ("(excluding connections establishing)" in line
or "(without initial connection time)")):
tps = float(line.split()[2])
return cls(
scale=scale,
number_of_clients=number_of_clients,
number_of_threads=number_of_threads,
number_of_transactions_actually_processed=number_of_transactions_actually_processed,
latency_average=float(latency_average),
latency_stddev=float(latency_stddev),
tps_including_connection_time=float(tps_including_connection_time),
tps_excluding_connection_time=float(tps_excluding_connection_time),
init_duration=init_duration,
init_start_timestamp=init_start_timestamp,
init_end_timestamp=init_end_timestamp,
latency_average=latency_average,
latency_stddev=latency_stddev,
tps=tps,
run_duration=run_duration,
run_start_timestamp=run_start_timestamp,
run_end_timestamp=run_end_timestamp,
@@ -187,60 +193,41 @@ class ZenithBenchmarker:
report=MetricReport.LOWER_IS_BETTER,
)
def record_pg_bench_result(self, pg_bench_result: PgBenchRunResult):
self.record("scale", pg_bench_result.scale, '', MetricReport.TEST_PARAM)
self.record("number_of_clients",
def record_pg_bench_result(self, prefix: str, pg_bench_result: PgBenchRunResult):
self.record(f"{prefix}.number_of_clients",
pg_bench_result.number_of_clients,
'',
MetricReport.TEST_PARAM)
self.record("number_of_threads",
self.record(f"{prefix}.number_of_threads",
pg_bench_result.number_of_threads,
'',
MetricReport.TEST_PARAM)
self.record(
"number_of_transactions_actually_processed",
f"{prefix}.number_of_transactions_actually_processed",
pg_bench_result.number_of_transactions_actually_processed,
'',
# that's because this is predefined by the test matrix and doesn't change across runs
report=MetricReport.TEST_PARAM,
)
self.record("latency_average",
self.record(f"{prefix}.latency_average",
pg_bench_result.latency_average,
unit="ms",
report=MetricReport.LOWER_IS_BETTER)
self.record("latency_stddev",
pg_bench_result.latency_stddev,
unit="ms",
report=MetricReport.LOWER_IS_BETTER)
self.record("tps_including_connection_time",
pg_bench_result.tps_including_connection_time,
'',
report=MetricReport.HIGHER_IS_BETTER)
self.record("tps_excluding_connection_time",
pg_bench_result.tps_excluding_connection_time,
'',
report=MetricReport.HIGHER_IS_BETTER)
self.record("init_duration",
pg_bench_result.init_duration,
unit="s",
report=MetricReport.LOWER_IS_BETTER)
self.record("init_start_timestamp",
pg_bench_result.init_start_timestamp,
'',
MetricReport.TEST_PARAM)
self.record("init_end_timestamp",
pg_bench_result.init_end_timestamp,
'',
MetricReport.TEST_PARAM)
self.record("run_duration",
if pg_bench_result.latency_stddev is not None:
self.record(f"{prefix}.latency_stddev",
pg_bench_result.latency_stddev,
unit="ms",
report=MetricReport.LOWER_IS_BETTER)
self.record(f"{prefix}.tps", pg_bench_result.tps, '', report=MetricReport.HIGHER_IS_BETTER)
self.record(f"{prefix}.run_duration",
pg_bench_result.run_duration,
unit="s",
report=MetricReport.LOWER_IS_BETTER)
self.record("run_start_timestamp",
self.record(f"{prefix}.run_start_timestamp",
pg_bench_result.run_start_timestamp,
'',
MetricReport.TEST_PARAM)
self.record("run_end_timestamp",
self.record(f"{prefix}.run_end_timestamp",
pg_bench_result.run_end_timestamp,
'',
MetricReport.TEST_PARAM)
@@ -259,10 +246,18 @@ class ZenithBenchmarker:
"""
Fetch the "cumulative # of bytes written" metric from the pageserver
"""
# Fetch all the exposed prometheus metrics from page server
all_metrics = pageserver.http_client().get_metrics()
# Use a regular expression to extract the one we're interested in
#
metric_name = r'pageserver_disk_io_bytes{io_operation="write"}'
return self.get_int_counter_value(pageserver, metric_name)
def get_peak_mem(self, pageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
metric_name = r'pageserver_maxrss_kb'
return self.get_int_counter_value(pageserver, metric_name)
def get_int_counter_value(self, pageserver, metric_name) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
# TODO: If we start to collect more of the prometheus metrics in the
# performance test suite like this, we should refactor this to load and
# parse all the metrics into a more convenient structure in one go.
@@ -270,20 +265,8 @@ class ZenithBenchmarker:
# The metric should be an integer, as it's a number of bytes. But in general
# all prometheus metrics are floats. So to be pedantic, read it as a float
# and round to integer.
matches = re.search(r'^pageserver_disk_io_bytes{io_operation="write"} (\S+)$',
all_metrics,
re.MULTILINE)
assert matches
return int(round(float(matches.group(1))))
def get_peak_mem(self, pageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
# Fetch all the exposed prometheus metrics from page server
all_metrics = pageserver.http_client().get_metrics()
# See comment in get_io_writes()
matches = re.search(r'^pageserver_maxrss_kb (\S+)$', all_metrics, re.MULTILINE)
matches = re.search(fr'^{metric_name} (\S+)$', all_metrics, re.MULTILINE)
assert matches
return int(round(float(matches.group(1))))

View File

@@ -2,7 +2,7 @@ import pytest
from contextlib import contextmanager
from abc import ABC, abstractmethod
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, RemotePostgres, ZenithEnv
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
# Type-related stuff
@@ -87,6 +87,9 @@ class ZenithCompare(PgCompare):
def flush(self):
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
def compact(self):
self.pscur.execute(f"compact {self.env.initial_tenant.hex} {self.timeline}")
def report_peak_memory_use(self) -> None:
self.zenbenchmark.record("peak_mem",
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
@@ -102,6 +105,19 @@ class ZenithCompare(PgCompare):
'MB',
report=MetricReport.LOWER_IS_BETTER)
total_files = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_num_persistent_files_created")
total_bytes = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_persistent_bytes_written")
self.zenbenchmark.record("data_uploaded",
total_bytes / (1024 * 1024),
"MB",
report=MetricReport.LOWER_IS_BETTER)
self.zenbenchmark.record("num_files_uploaded",
total_files,
"",
report=MetricReport.LOWER_IS_BETTER)
def record_pageserver_writes(self, out_name):
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
@@ -159,6 +175,48 @@ class VanillaCompare(PgCompare):
return self.zenbenchmark.record_duration(out_name)
class RemoteCompare(PgCompare):
"""PgCompare interface for a remote postgres instance."""
def __init__(self, zenbenchmark, remote_pg: RemotePostgres):
self._pg = remote_pg
self._zenbenchmark = zenbenchmark
# Long-lived cursor, useful for flushing
self.conn = self.pg.connect()
self.cur = self.conn.cursor()
@property
def pg(self):
return self._pg
@property
def zenbenchmark(self):
return self._zenbenchmark
@property
def pg_bin(self):
return self._pg.pg_bin
def flush(self):
# TODO: flush the remote pageserver
pass
def report_peak_memory_use(self) -> None:
# TODO: get memory usage from remote pageserver
pass
def report_size(self) -> None:
# TODO: get storage size from remote pageserver
pass
@contextmanager
def record_pageserver_writes(self, out_name):
yield # Do nothing
def record_duration(self, out_name):
return self.zenbenchmark.record_duration(out_name)
@pytest.fixture(scope='function')
def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare:
branch_name = request.node.name
@@ -170,6 +228,11 @@ def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
return VanillaCompare(zenbenchmark, vanilla_pg)
@pytest.fixture(scope='function')
def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare:
return RemoteCompare(zenbenchmark, remote_pg)
@pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"])
def zenith_with_baseline(request) -> PgCompare:
"""Parameterized fixture that helps compare zenith against vanilla postgres.

View File

@@ -27,6 +27,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal
@@ -122,6 +123,22 @@ def pytest_configure(config):
top_output_dir = os.path.join(base_dir, DEFAULT_OUTPUT_DIR)
mkdir_if_needed(top_output_dir)
# Find the postgres installation.
global pg_distrib_dir
env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR')
if env_postgres_bin:
pg_distrib_dir = env_postgres_bin
else:
pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR))
log.info(f'pg_distrib_dir is {pg_distrib_dir}')
if os.getenv("REMOTE_ENV"):
# When testing against a remote server, we only need the client binary.
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/psql')):
raise Exception('psql not found at "{}"'.format(pg_distrib_dir))
else:
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')):
raise Exception('postgres not found at "{}"'.format(pg_distrib_dir))
if os.getenv("REMOTE_ENV"):
# we are in remote env and do not have zenith binaries locally
# this is the case for benchmarks run on self-hosted runner
@@ -137,17 +154,6 @@ def pytest_configure(config):
if not os.path.exists(os.path.join(zenith_binpath, 'pageserver')):
raise Exception('zenith binaries not found at "{}"'.format(zenith_binpath))
# Find the postgres installation.
global pg_distrib_dir
env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR')
if env_postgres_bin:
pg_distrib_dir = env_postgres_bin
else:
pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR))
log.info(f'pg_distrib_dir is {pg_distrib_dir}')
if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')):
raise Exception('postgres not found at "{}"'.format(pg_distrib_dir))
def zenfixture(func: Fn) -> Fn:
"""
@@ -238,98 +244,69 @@ def port_distributor(worker_base_port):
class PgProtocol:
""" Reusable connection logic """
def __init__(self,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
self.dbname = dbname
self.schema = schema
def __init__(self, **kwargs):
self.default_options = kwargs
def connstr(self,
*,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
statement_timeout_ms: Optional[int] = None) -> str:
def connstr(self, **kwargs) -> str:
"""
Build a libpq connection string for the Postgres instance.
"""
return str(make_dsn(**self.conn_options(**kwargs)))
username = username or self.username
password = password or self.password
dbname = dbname or self.dbname or "postgres"
schema = schema or self.schema
res = f'host={self.host} port={self.port} dbname={dbname}'
def conn_options(self, **kwargs):
conn_options = self.default_options.copy()
if 'dsn' in kwargs:
conn_options.update(parse_dsn(kwargs['dsn']))
conn_options.update(kwargs)
if username:
res = f'{res} user={username}'
if password:
res = f'{res} password={password}'
if schema:
res = f"{res} options='-c search_path={schema}'"
if statement_timeout_ms:
res = f"{res} options='-c statement_timeout={statement_timeout_ms}'"
return res
# Individual statement timeout in seconds. 2 minutes should be
# enough for our tests, but if you need a longer one, you can
# change it by calling "SET statement_timeout" after
# connecting.
if 'options' in conn_options:
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
else:
conn_options['options'] = "-cstatement_timeout=120s"
return conn_options
# autocommit=True here by default because that's what we need most of the time
def connect(
self,
*,
autocommit=True,
dbname: Optional[str] = None,
schema: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
# individual statement timeout in seconds, 2 minutes should be enough for our tests
statement_timeout: Optional[int] = 120
) -> PgConnection:
def connect(self, autocommit=True, **kwargs) -> PgConnection:
"""
Connect to the node.
Returns psycopg2's connection object.
This method passes all extra params to connstr.
"""
conn = psycopg2.connect(**self.conn_options(**kwargs))
conn = psycopg2.connect(
self.connstr(dbname=dbname,
schema=schema,
username=username,
password=password,
statement_timeout_ms=statement_timeout *
1000 if statement_timeout else None))
# WARNING: this setting affects *all* tests!
conn.autocommit = autocommit
return conn
async def connect_async(self,
*,
dbname: str = 'postgres',
username: Optional[str] = None,
password: Optional[str] = None) -> asyncpg.Connection:
async def connect_async(self, **kwargs) -> asyncpg.Connection:
"""
Connect to the node from async python.
Returns asyncpg's connection object.
"""
conn = await asyncpg.connect(
host=self.host,
port=self.port,
database=dbname,
user=username or self.username,
password=password,
)
return conn
# asyncpg takes slightly different options than psycopg2. Try
# to convert the defaults from the psycopg2 format.
# The psycopg2 option 'dbname' is called 'database' in asyncpg
conn_options = self.conn_options(**kwargs)
if 'dbname' in conn_options:
conn_options['database'] = conn_options.pop('dbname')
# Convert options='-c<key>=<val>' to server_settings
if 'options' in conn_options:
options = conn_options.pop('options')
for match in re.finditer(r'-c(\w*)=(\w*)', options):
key = match.group(1)
val = match.group(2)
if 'server_settings' in conn_options:
conn_options['server_settings'].update({key: val})
else:
conn_options['server_settings'] = {key: val}
return await asyncpg.connect(**conn_options)
def safe_psql(self, query: str, **kwargs: Any) -> List[Any]:
"""
@@ -1149,10 +1126,10 @@ class ZenithPageserver(PgProtocol):
port: PageserverPort,
remote_storage: Optional[RemoteStorage] = None,
config_override: Optional[str] = None):
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
super().__init__(host='localhost', port=port.pg, user='zenith_admin')
self.env = env
self.running = False
self.service_port = port # do not shadow PgProtocol.port which is just int
self.service_port = port
self.remote_storage = remote_storage
self.config_override = config_override
@@ -1314,7 +1291,7 @@ def psbench_bin(test_output_dir):
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
super().__init__(host='localhost', port=port)
super().__init__(host='localhost', port=port, dbname='postgres')
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
self.running = False
@@ -1356,10 +1333,57 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
yield vanilla_pg
class RemotePostgres(PgProtocol):
def __init__(self, pg_bin: PgBin, remote_connstr: str):
super().__init__(**parse_dsn(remote_connstr))
self.pg_bin = pg_bin
# The remote server is assumed to be running already
self.running = True
def configure(self, options: List[str]):
raise Exception('cannot change configuration of remote Postgres instance')
def start(self):
raise Exception('cannot start a remote Postgres instance')
def stop(self):
raise Exception('cannot stop a remote Postgres instance')
def get_subdir_size(self, subdir) -> int:
# TODO: Could use the server's Generic File Access functions if superuser.
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
raise Exception('cannot get size of a Postgres instance')
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
# do nothing
pass
@pytest.fixture(scope='function')
def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]:
pg_bin = PgBin(test_output_dir)
connstr = os.getenv("BENCHMARK_CONNSTR")
if connstr is None:
raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable")
with RemotePostgres(pg_bin, connstr) as remote_pg:
yield remote_pg
class ZenithProxy(PgProtocol):
def __init__(self, port: int):
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
super().__init__(host="127.0.0.1",
user="pytest",
password="pytest",
port=port,
dbname='postgres')
self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self._popen: Optional[subprocess.Popen[bytes]] = None
def start_static(self, addr="127.0.0.1:5432") -> None:
@@ -1403,13 +1427,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin')
super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres')
self.env = env
self.running = False
self.node_name: Optional[str] = None # dubious, see asserts below
self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA
self.tenant_id = tenant_id
self.port = port
# path to conf is <repo_dir>/pgdatadirs/tenants/<tenant_id>/<node_name>/postgresql.conf
def create(

View File

@@ -2,29 +2,113 @@ from contextlib import closing
from fixtures.zenith_fixtures import PgBin, VanillaPostgres, ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.benchmark_fixture import PgBenchRunResult, MetricReport, ZenithBenchmarker
from fixtures.log_helper import log
from pathlib import Path
import pytest
from datetime import datetime
import calendar
import os
import timeit
def utc_now_timestamp() -> int:
return calendar.timegm(datetime.utcnow().utctimetuple())
def init_pgbench(env: PgCompare, cmdline):
# Calculate timestamps and durations separately.
# The timestamp is meant for linking to Grafana and logs;
# the duration is the actual metric, recorded as a float rather than an integer timestamp.
init_start_timestamp = utc_now_timestamp()
t0 = timeit.default_timer()
with env.record_pageserver_writes('init.pageserver_writes'):
env.pg_bin.run_capture(cmdline)
env.flush()
init_duration = timeit.default_timer() - t0
init_end_timestamp = utc_now_timestamp()
env.zenbenchmark.record("init.duration",
init_duration,
unit="s",
report=MetricReport.LOWER_IS_BETTER)
env.zenbenchmark.record("init.start_timestamp",
init_start_timestamp,
'',
MetricReport.TEST_PARAM)
env.zenbenchmark.record("init.end_timestamp", init_end_timestamp, '', MetricReport.TEST_PARAM)
def run_pgbench(env: PgCompare, prefix: str, cmdline):
with env.record_pageserver_writes(f'{prefix}.pageserver_writes'):
run_start_timestamp = utc_now_timestamp()
t0 = timeit.default_timer()
out = env.pg_bin.run_capture(cmdline)
run_duration = timeit.default_timer() - t0
run_end_timestamp = utc_now_timestamp()
env.flush()
stdout = Path(f"{out}.stdout").read_text()
res = PgBenchRunResult.parse_from_stdout(
stdout=stdout,
run_duration=run_duration,
run_start_timestamp=run_start_timestamp,
run_end_timestamp=run_end_timestamp,
)
env.zenbenchmark.record_pg_bench_result(prefix, res)
#
# Run a very short pgbench test.
# Initialize a pgbench database, and run pgbench against it.
#
# Collects three metrics:
# This runs two different pgbench workloads against the same
# initialized database, and 'duration' is the time of each run. So
# the total runtime is 2 * duration, plus time needed to initialize
# the test database.
#
# 1. Time to initialize the pgbench database (pgbench -s5 -i)
# 2. Time to run 5000 pgbench transactions
# 3. Disk space used
#
def test_pgbench(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
# Currently, the # of connections is hardcoded at 4
def run_test_pgbench(env: PgCompare, scale: int, duration: int):
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('init'):
env.pg_bin.run_capture(['pgbench', '-s5', '-i', env.pg.connstr()])
env.flush()
# Record the scale and initialize
env.zenbenchmark.record("scale", scale, '', MetricReport.TEST_PARAM)
init_pgbench(env, ['pgbench', f'-s{scale}', '-i', env.pg.connstr()])
with env.record_duration('5000_xacts'):
env.pg_bin.run_capture(['pgbench', '-c1', '-t5000', env.pg.connstr()])
env.flush()
# Run simple-update workload
run_pgbench(env,
"simple-update",
['pgbench', '-n', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()])
# Run SELECT workload
run_pgbench(env,
"select-only",
['pgbench', '-S', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()])
env.report_size()
def get_durations_matrix():
durations = os.getenv("TEST_PG_BENCH_DURATIONS_MATRIX", default="45")
return list(map(int, durations.split(",")))
def get_scales_matrix():
scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX", default="10")
return list(map(int, scales.split(",")))
# Run the pgbench tests against vanilla Postgres and zenith
@pytest.mark.parametrize("scale", get_scales_matrix())
@pytest.mark.parametrize("duration", get_durations_matrix())
def test_pgbench(zenith_with_baseline: PgCompare, scale: int, duration: int):
run_test_pgbench(zenith_with_baseline, scale, duration)
# Run the pgbench tests against an existing Postgres cluster
@pytest.mark.parametrize("scale", get_scales_matrix())
@pytest.mark.parametrize("duration", get_durations_matrix())
@pytest.mark.remote_cluster
def test_pgbench_remote(remote_compare: PgCompare, scale: int, duration: int):
run_test_pgbench(remote_compare, scale, duration)

View File

@@ -1,124 +0,0 @@
import dataclasses
import os
import subprocess
from typing import List
from fixtures.benchmark_fixture import PgBenchRunResult, ZenithBenchmarker
import pytest
from datetime import datetime
import calendar
import timeit
import os
def utc_now_timestamp() -> int:
return calendar.timegm(datetime.utcnow().utctimetuple())
@dataclasses.dataclass
class PgBenchRunner:
connstr: str
scale: int
transactions: int
pgbench_bin_path: str = "pgbench"
def invoke(self, args: List[str]) -> 'subprocess.CompletedProcess[str]':
res = subprocess.run([self.pgbench_bin_path, *args], text=True, capture_output=True)
if res.returncode != 0:
raise RuntimeError(f"pgbench failed. stdout: {res.stdout} stderr: {res.stderr}")
return res
def init(self, vacuum: bool = True) -> 'subprocess.CompletedProcess[str]':
args = []
if not vacuum:
args.append("--no-vacuum")
args.extend([f"--scale={self.scale}", "--initialize", self.connstr])
return self.invoke(args)
def run(self, jobs: int = 1, clients: int = 1):
return self.invoke([
f"--transactions={self.transactions}",
f"--jobs={jobs}",
f"--client={clients}",
"--progress=2", # print progress every two seconds
self.connstr,
])
@pytest.fixture
def connstr():
res = os.getenv("BENCHMARK_CONNSTR")
if res is None:
raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable")
return res
def get_transactions_matrix():
transactions = os.getenv("TEST_PG_BENCH_TRANSACTIONS_MATRIX")
if transactions is None:
return [10**4, 10**5]
return list(map(int, transactions.split(",")))
def get_scales_matrix():
scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX")
if scales is None:
return [10, 20]
return list(map(int, scales.split(",")))
@pytest.mark.parametrize("scale", get_scales_matrix())
@pytest.mark.parametrize("transactions", get_transactions_matrix())
@pytest.mark.remote_cluster
def test_pg_bench_remote_cluster(zenbenchmark: ZenithBenchmarker,
connstr: str,
scale: int,
transactions: int):
"""
The best way would be to run the same pack of tests both locally against zenith
and against staging, but currently the local tests heavily depend on
things available only locally, e.g. zenith binaries, the pageserver API, etc.
A separate test also allows running the pgbench workload against vanilla postgres
or other systems that support the postgres protocol.
For now this is more of a liveness test, because it stresses pageserver internals,
so we clearly see what goes wrong in a more "real" environment.
"""
pg_bin = os.getenv("PG_BIN")
if pg_bin is not None:
pgbench_bin_path = os.path.join(pg_bin, "pgbench")
else:
pgbench_bin_path = "pgbench"
runner = PgBenchRunner(
connstr=connstr,
scale=scale,
transactions=transactions,
pgbench_bin_path=pgbench_bin_path,
)
# calculate timestamps and durations separately
# timestamp is intended to be used for linking to grafana and logs
# duration is actually a metric and uses float instead of int for timestamp
init_start_timestamp = utc_now_timestamp()
t0 = timeit.default_timer()
runner.init()
init_duration = timeit.default_timer() - t0
init_end_timestamp = utc_now_timestamp()
run_start_timestamp = utc_now_timestamp()
t0 = timeit.default_timer()
out = runner.run() # TODO handle failures
run_duration = timeit.default_timer() - t0
run_end_timestamp = utc_now_timestamp()
res = PgBenchRunResult.parse_from_output(
out=out,
init_duration=init_duration,
init_start_timestamp=init_start_timestamp,
init_end_timestamp=init_end_timestamp,
run_duration=run_duration,
run_start_timestamp=run_start_timestamp,
run_end_timestamp=run_end_timestamp,
)
zenbenchmark.record_pg_bench_result(res)

View File

@@ -6,6 +6,7 @@ use lazy_static::lazy_static;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use tracing::*;
@@ -37,8 +38,10 @@ lazy_static! {
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
}
pub trait Storage {
/// Persist safekeeper state on disk.
/// Storage should keep the actual state inside of it. It should implement the Deref
/// trait for accessing state fields and provide a persist method for updating that state.
pub trait Storage: Deref<Target = SafeKeeperState> {
/// Persist safekeeper state on disk and update internal state.
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
}
@@ -48,19 +51,47 @@ pub struct FileStorage {
timeline_dir: PathBuf,
conf: SafeKeeperConf,
persist_control_file_seconds: Histogram,
/// Last state persisted to disk.
state: SafeKeeperState,
}
impl FileStorage {
pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
pub fn restore_new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> Result<FileStorage> {
let timeline_dir = conf.timeline_dir(zttid);
let tenant_id = zttid.tenant_id.to_string();
let timeline_id = zttid.timeline_id.to_string();
FileStorage {
let state = Self::load_control_file_conf(conf, zttid)?;
Ok(FileStorage {
timeline_dir,
conf: conf.clone(),
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
.with_label_values(&[&tenant_id, &timeline_id]),
}
state,
})
}
pub fn create_new(
zttid: &ZTenantTimelineId,
conf: &SafeKeeperConf,
state: SafeKeeperState,
) -> Result<FileStorage> {
let timeline_dir = conf.timeline_dir(zttid);
let tenant_id = zttid.tenant_id.to_string();
let timeline_id = zttid.timeline_id.to_string();
let mut store = FileStorage {
timeline_dir,
conf: conf.clone(),
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
.with_label_values(&[&tenant_id, &timeline_id]),
state: state.clone(),
};
store.persist(&state)?;
Ok(store)
}
// Check the magic/version in the on-disk data and deserialize it, if possible.
@@ -141,6 +172,14 @@ impl FileStorage {
}
}
impl Deref for FileStorage {
type Target = SafeKeeperState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl Storage for FileStorage {
// persists state durably to underlying storage
// for description see https://lwn.net/Articles/457667/
@@ -201,6 +240,9 @@ impl Storage for FileStorage {
.and_then(|f| f.sync_all())
.context("failed to sync control file directory")?;
}
// update internal state
self.state = s.clone();
Ok(())
}
}
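
For a feel of the new contract (readers go through Deref to the last persisted state, writers go through persist), here is a hypothetical in-memory implementation of the kind a unit test might plug into SafeKeeper; it is a sketch only and not part of this change:

struct InMemoryStorage {
    state: SafeKeeperState,
}

impl Deref for InMemoryStorage {
    type Target = SafeKeeperState;
    fn deref(&self) -> &Self::Target {
        // Callers always observe the last state that was "persisted".
        &self.state
    }
}

impl Storage for InMemoryStorage {
    fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
        // No durability here: just adopt the new state, like FileStorage
        // does after it has fsynced the control file.
        self.state = s.clone();
        Ok(())
    }
}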
@@ -228,7 +270,7 @@ mod test {
) -> Result<(FileStorage, SafeKeeperState)> {
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
Ok((
FileStorage::new(zttid, conf),
FileStorage::restore_new(zttid, conf)?,
FileStorage::load_control_file_conf(conf, zttid)?,
))
}
@@ -239,8 +281,7 @@ mod test {
) -> Result<(FileStorage, SafeKeeperState)> {
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
let state = SafeKeeperState::empty();
let mut storage = FileStorage::new(zttid, conf);
storage.persist(&state)?;
let storage = FileStorage::create_new(zttid, conf, state.clone())?;
Ok((storage, state))
}

View File

@@ -210,6 +210,7 @@ pub struct SafekeeperMemState {
pub s3_wal_lsn: Lsn, // TODO: keep only persistent version
pub peer_horizon_lsn: Lsn,
pub remote_consistent_lsn: Lsn,
pub proposer_uuid: PgUuid,
}
impl SafeKeeperState {
@@ -502,9 +503,8 @@ pub struct SafeKeeper<CTRL: control_file::Storage, WAL: wal_storage::Storage> {
epoch_start_lsn: Lsn,
pub inmem: SafekeeperMemState, // in memory part
pub s: SafeKeeperState, // persistent part
pub state: CTRL, // persistent state storage
pub control_store: CTRL,
pub wal_store: WAL,
}
@@ -516,14 +516,14 @@ where
// constructor
pub fn new(
ztli: ZTimelineId,
control_store: CTRL,
state: CTRL,
mut wal_store: WAL,
state: SafeKeeperState,
) -> Result<SafeKeeper<CTRL, WAL>> {
if state.timeline_id != ZTimelineId::from([0u8; 16]) && ztli != state.timeline_id {
bail!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.timeline_id);
}
// initialize wal_store, if state is already initialized
wal_store.init_storage(&state)?;
Ok(SafeKeeper {
@@ -535,23 +535,25 @@ where
s3_wal_lsn: state.s3_wal_lsn,
peer_horizon_lsn: state.peer_horizon_lsn,
remote_consistent_lsn: state.remote_consistent_lsn,
proposer_uuid: state.proposer_uuid,
},
s: state,
control_store,
state,
wal_store,
})
}
/// Get history of term switches for the available WAL
fn get_term_history(&self) -> TermHistory {
self.s
self.state
.acceptor_state
.term_history
.up_to(self.wal_store.flush_lsn())
}
pub fn get_epoch(&self) -> Term {
self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn())
self.state
.acceptor_state
.get_epoch(self.wal_store.flush_lsn())
}
/// Process message from proposer and possibly form reply. Concurrent
@@ -587,46 +589,47 @@ where
);
}
/* Postgres upgrade is not treated as fatal error */
if msg.pg_version != self.s.server.pg_version
&& self.s.server.pg_version != UNKNOWN_SERVER_VERSION
if msg.pg_version != self.state.server.pg_version
&& self.state.server.pg_version != UNKNOWN_SERVER_VERSION
{
info!(
"incompatible server version {}, expected {}",
msg.pg_version, self.s.server.pg_version
msg.pg_version, self.state.server.pg_version
);
}
if msg.tenant_id != self.s.tenant_id {
if msg.tenant_id != self.state.tenant_id {
bail!(
"invalid tenant ID, got {}, expected {}",
msg.tenant_id,
self.s.tenant_id
self.state.tenant_id
);
}
if msg.ztli != self.s.timeline_id {
if msg.ztli != self.state.timeline_id {
bail!(
"invalid timeline ID, got {}, expected {}",
msg.ztli,
self.s.timeline_id
self.state.timeline_id
);
}
// set basic info about server, if not yet
// TODO: verify that it doesn't change afterwards
self.s.server.system_id = msg.system_id;
self.s.server.wal_seg_size = msg.wal_seg_size;
self.control_store
.persist(&self.s)
.context("failed to persist shared state")?;
{
let mut state = self.state.clone();
state.server.system_id = msg.system_id;
state.server.wal_seg_size = msg.wal_seg_size;
self.state.persist(&state)?;
}
// pass wal_seg_size to read WAL and find flush_lsn
self.wal_store.init_storage(&self.s)?;
self.wal_store.init_storage(&self.state)?;
info!(
"processed greeting from proposer {:?}, sending term {:?}",
msg.proposer_id, self.s.acceptor_state.term
msg.proposer_id, self.state.acceptor_state.term
);
Ok(Some(AcceptorProposerMessage::Greeting(AcceptorGreeting {
term: self.s.acceptor_state.term,
term: self.state.acceptor_state.term,
})))
}
@@ -637,17 +640,19 @@ where
) -> Result<Option<AcceptorProposerMessage>> {
// initialize with refusal
let mut resp = VoteResponse {
term: self.s.acceptor_state.term,
term: self.state.acceptor_state.term,
vote_given: false as u64,
flush_lsn: self.wal_store.flush_lsn(),
truncate_lsn: self.s.peer_horizon_lsn,
truncate_lsn: self.state.peer_horizon_lsn,
term_history: self.get_term_history(),
};
if self.s.acceptor_state.term < msg.term {
self.s.acceptor_state.term = msg.term;
if self.state.acceptor_state.term < msg.term {
let mut state = self.state.clone();
state.acceptor_state.term = msg.term;
// persist vote before sending it out
self.control_store.persist(&self.s)?;
resp.term = self.s.acceptor_state.term;
self.state.persist(&state)?;
resp.term = self.state.acceptor_state.term;
resp.vote_given = true as u64;
}
info!("processed VoteRequest for term {}: {:?}", msg.term, &resp);
@@ -656,9 +661,10 @@ where
/// Bump our term if received a note from elected proposer with higher one
fn bump_if_higher(&mut self, term: Term) -> Result<()> {
if self.s.acceptor_state.term < term {
self.s.acceptor_state.term = term;
self.control_store.persist(&self.s)?;
if self.state.acceptor_state.term < term {
let mut state = self.state.clone();
state.acceptor_state.term = term;
self.state.persist(&state)?;
}
Ok(())
}
@@ -666,9 +672,9 @@ where
/// Form AppendResponse from current state.
fn append_response(&self) -> AppendResponse {
let ar = AppendResponse {
term: self.s.acceptor_state.term,
term: self.state.acceptor_state.term,
flush_lsn: self.wal_store.flush_lsn(),
commit_lsn: self.s.commit_lsn,
commit_lsn: self.state.commit_lsn,
// will be filled in by higher-level code to avoid bothering the safekeeper
hs_feedback: HotStandbyFeedback::empty(),
zenith_feedback: ZenithFeedback::empty(),
@@ -681,7 +687,7 @@ where
info!("received ProposerElected {:?}", msg);
self.bump_if_higher(msg.term)?;
// If our term is higher, ignore the message (next feedback will inform the compute)
if self.s.acceptor_state.term > msg.term {
if self.state.acceptor_state.term > msg.term {
return Ok(None);
}
@@ -692,8 +698,11 @@ where
self.wal_store.truncate_wal(msg.start_streaming_at)?;
// and now adopt term history from proposer
self.s.acceptor_state.term_history = msg.term_history.clone();
self.control_store.persist(&self.s)?;
{
let mut state = self.state.clone();
state.acceptor_state.term_history = msg.term_history.clone();
self.state.persist(&state)?;
}
info!("start receiving WAL since {:?}", msg.start_streaming_at);
@@ -715,13 +724,13 @@ where
// Also note that commit_lsn can reach epoch_start_lsn earlier
// than we receive the new epoch_start_lsn, and we still need to sync
// the control file in this case.
if commit_lsn == self.epoch_start_lsn && self.s.commit_lsn != commit_lsn {
if commit_lsn == self.epoch_start_lsn && self.state.commit_lsn != commit_lsn {
self.persist_control_file()?;
}
// We got our first commit_lsn, which means we should sync
// everything to disk, to initialize the state.
if self.s.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) {
if self.state.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) {
self.wal_store.flush_wal()?;
self.persist_control_file()?;
}
@@ -731,10 +740,12 @@ where
/// Persist in-memory state to the disk.
fn persist_control_file(&mut self) -> Result<()> {
self.s.commit_lsn = self.inmem.commit_lsn;
self.s.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
let mut state = self.state.clone();
self.control_store.persist(&self.s)
state.commit_lsn = self.inmem.commit_lsn;
state.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
state.proposer_uuid = self.inmem.proposer_uuid;
self.state.persist(&state)
}
/// Handle request to append WAL.
@@ -744,13 +755,13 @@ where
msg: &AppendRequest,
require_flush: bool,
) -> Result<Option<AcceptorProposerMessage>> {
if self.s.acceptor_state.term < msg.h.term {
if self.state.acceptor_state.term < msg.h.term {
bail!("got AppendRequest before ProposerElected");
}
// If our term is higher, immediately refuse the message.
if self.s.acceptor_state.term > msg.h.term {
let resp = AppendResponse::term_only(self.s.acceptor_state.term);
if self.state.acceptor_state.term > msg.h.term {
let resp = AppendResponse::term_only(self.state.acceptor_state.term);
return Ok(Some(AcceptorProposerMessage::AppendResponse(resp)));
}
@@ -758,8 +769,7 @@ where
// processing the message.
self.epoch_start_lsn = msg.h.epoch_start_lsn;
// TODO: don't update state without persisting to disk
self.s.proposer_uuid = msg.h.proposer_uuid;
self.inmem.proposer_uuid = msg.h.proposer_uuid;
// do the job
if !msg.wal_data.is_empty() {
@@ -790,7 +800,7 @@ where
// Update the truncate and commit LSNs in the control file.
// To avoid the performance hit of an extra fsync, do it only
// when the truncate_lsn delta exceeds the WAL segment size.
if self.s.peer_horizon_lsn + (self.s.server.wal_seg_size as u64)
if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64)
< self.inmem.peer_horizon_lsn
{
self.persist_control_file()?;
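The delta check above batches control-file fsyncs: the file is rewritten at most once per WAL segment of peer_horizon_lsn advance. A rough self-contained illustration of the rule (16 MiB is only an example; the real value arrives in the compute's greeting):

// Mirrors: state.peer_horizon_lsn + wal_seg_size < inmem.peer_horizon_lsn
fn needs_control_file_sync(persisted: u64, inmem: u64, wal_seg_size: u64) -> bool {
    persisted + wal_seg_size < inmem
}

fn main() {
    let seg = 16 * 1024 * 1024u64;
    assert!(!needs_control_file_sync(0, seg, seg));    // exactly one segment ahead: skip the fsync
    assert!(needs_control_file_sync(0, seg + 1, seg)); // more than a segment ahead: persist
}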
@@ -829,6 +839,8 @@ where
#[cfg(test)]
mod tests {
use std::ops::Deref;
use super::*;
use crate::wal_storage::Storage;
@@ -844,6 +856,14 @@ mod tests {
}
}
impl Deref for InMemoryState {
type Target = SafeKeeperState;
fn deref(&self) -> &Self::Target {
&self.persisted_state
}
}
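Only the Deref impl is new in this hunk; the persist side of the test double is presumably a plain field assignment, roughly as follows (trait path and signature inferred from the call sites, not copied from the crate):

impl control_file::Storage for InMemoryState {
    fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
        // "durably" storing the state is just overwriting the in-memory copy
        self.persisted_state = s.clone();
        Ok(())
    }
}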
struct DummyWalStore {
lsn: Lsn,
}
@@ -879,7 +899,7 @@ mod tests {
};
let wal_store = DummyWalStore { lsn: Lsn(0) };
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap();
let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap();
// check voting for 1 is ok
let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 });
@@ -890,11 +910,11 @@ mod tests {
}
// reboot...
let state = sk.control_store.persisted_state.clone();
let state = sk.state.persisted_state.clone();
let storage = InMemoryState {
persisted_state: state.clone(),
persisted_state: state,
};
sk = SafeKeeper::new(ztli, storage, sk.wal_store, state).unwrap();
sk = SafeKeeper::new(ztli, storage, sk.wal_store).unwrap();
// and ensure voting second time for 1 is not ok
vote_resp = sk.process_msg(&vote_request);
@@ -911,7 +931,7 @@ mod tests {
};
let wal_store = DummyWalStore { lsn: Lsn(0) };
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap();
let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap();
let mut ar_hdr = AppendRequestHeader {
term: 1,

View File

@@ -21,7 +21,6 @@ use crate::broker::SafekeeperInfo;
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
use crate::control_file;
use crate::control_file::Storage as cf_storage;
use crate::safekeeper::{
AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState,
SafekeeperMemState,
@@ -98,10 +97,9 @@ impl SharedState {
peer_ids: Vec<ZNodeId>,
) -> Result<Self> {
let state = SafeKeeperState::new(zttid, peer_ids);
let control_store = control_file::FileStorage::new(zttid, conf);
let control_store = control_file::FileStorage::create_new(zttid, conf, state)?;
let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);
let mut sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?;
sk.control_store.persist(&sk.s)?;
let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?;
Ok(Self {
notified_commit_lsn: Lsn(0),
@@ -116,18 +114,14 @@ impl SharedState {
/// Restore SharedState from control file.
/// If file doesn't exist, bails out.
fn restore(conf: &SafeKeeperConf, zttid: &ZTenantTimelineId) -> Result<Self> {
let state = control_file::FileStorage::load_control_file_conf(conf, zttid)
.context("failed to load from control file")?;
let control_store = control_file::FileStorage::new(zttid, conf);
let control_store = control_file::FileStorage::restore_new(zttid, conf)?;
let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);
info!("timeline {} restored", zttid.timeline_id);
Ok(Self {
notified_commit_lsn: Lsn(0),
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?,
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?,
replicas: Vec::new(),
active: false,
num_computes: 0,
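Both call sites now obtain the control store through a constructor pair instead of FileStorage::new plus separate load/persist calls. Their shape, reconstructed from these call sites with stub types so the sketch compiles on its own (this is an inference, not the crate's actual code):

struct SafeKeeperConf;
struct ZTenantTimelineId;
#[derive(Clone)]
struct SafeKeeperState;
struct FileStorage { state: SafeKeeperState }

impl FileStorage {
    // Create a control file for a brand-new timeline, seeded with `state`.
    fn create_new(_zttid: &ZTenantTimelineId, _conf: &SafeKeeperConf,
                  state: SafeKeeperState) -> anyhow::Result<Self> {
        // the real implementation persists `state` to disk before returning
        Ok(FileStorage { state })
    }

    // Open the control file of an existing timeline; error out if it is missing.
    fn restore_new(_zttid: &ZTenantTimelineId, _conf: &SafeKeeperConf)
                   -> anyhow::Result<Self> {
        // the real implementation loads and validates the on-disk state here
        anyhow::bail!("control file not found (sketch)")
    }
}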
@@ -419,7 +413,7 @@ impl Timeline {
pub fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) {
let shared_state = self.mutex.lock().unwrap();
(shared_state.sk.inmem.clone(), shared_state.sk.s.clone())
(shared_state.sk.inmem.clone(), shared_state.sk.state.clone())
}
/// Prepare public safekeeper info for reporting.

workspace_hack/.gitattributes vendored Normal file
View File

@@ -0,0 +1,4 @@
# Avoid putting conflict markers in the generated Cargo.toml file, since their presence breaks
# Cargo.
# Also do not check out the file as CRLF on Windows, as that's what hakari needs.
Cargo.toml merge=binary -crlf

View File

@@ -16,32 +16,38 @@ publish = false
[dependencies]
anyhow = { version = "1", features = ["backtrace", "std"] }
bytes = { version = "1", features = ["serde", "std"] }
chrono = { version = "0.4", features = ["clock", "libc", "oldtime", "serde", "std", "time", "winapi"] }
clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] }
either = { version = "1", features = ["use_std"] }
hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
libc = { version = "0.2", features = ["extra_traits", "std"] }
log = { version = "0.4", default-features = false, features = ["serde", "std"] }
memchr = { version = "2", features = ["std", "use_std"] }
num-integer = { version = "0.1", default-features = false, features = ["std"] }
num-traits = { version = "0.2", features = ["std"] }
prost = { version = "0.9", features = ["prost-derive", "std"] }
rand = { version = "0.8", features = ["alloc", "getrandom", "libc", "rand_chacha", "rand_hc", "small_rng", "std", "std_rng"] }
regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
reqwest = { version = "0.11", default-features = false, features = ["__rustls", "__tls", "blocking", "hyper-rustls", "json", "rustls", "rustls-pemfile", "rustls-tls", "rustls-tls-webpki-roots", "serde_json", "stream", "tokio-rustls", "tokio-util", "webpki-roots"] }
scopeguard = { version = "1", features = ["use_std"] }
serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] }
tokio = { version = "1", features = ["bytes", "fs", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "sync", "time", "tokio-macros"] }
tracing = { version = "0.1", features = ["attributes", "std", "tracing-attributes"] }
tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] }
tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] }
tracing-core = { version = "0.1", features = ["lazy_static", "std"] }
[build-dependencies]
anyhow = { version = "1", features = ["backtrace", "std"] }
bytes = { version = "1", features = ["serde", "std"] }
cc = { version = "1", default-features = false, features = ["jobserver", "parallel"] }
clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] }
either = { version = "1", features = ["use_std"] }
hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
libc = { version = "0.2", features = ["extra_traits", "std"] }
log = { version = "0.4", default-features = false, features = ["serde", "std"] }
memchr = { version = "2", features = ["std", "use_std"] }
proc-macro2 = { version = "1", features = ["proc-macro"] }
quote = { version = "1", features = ["proc-macro"] }
prost = { version = "0.9", features = ["prost-derive", "std"] }
regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] }
serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] }

workspace_hack/build.rs Normal file
View File

@@ -0,0 +1,2 @@
// A build script is required for cargo to consider build dependencies.
fn main() {}

View File

@@ -375,9 +375,8 @@ impl PostgresBackend {
}
AuthType::MD5 => {
rand::thread_rng().fill(&mut self.md5_salt);
let md5_salt = self.md5_salt;
self.write_message(&BeMessage::AuthenticationMD5Password(
&md5_salt,
self.md5_salt,
))?;
self.state = ProtoState::Authentication;
}
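The removed md5_salt temporary existed only to end the shared borrow of self before the &mut self call to write_message; since [u8; 4] is Copy, the new by-value variant lets the salt be copied out during argument evaluation, with no borrow overlapping the call. A minimal stand-alone sketch of the same shape, with stand-in types:

#[derive(Debug)]
enum Msg {
    Md5Salt([u8; 4]), // by value: a 4-byte array is Copy
}

struct Backend {
    md5_salt: [u8; 4],
}

impl Backend {
    fn write_message(&mut self, msg: &Msg) {
        println!("{:?}", msg);
    }
    fn send_salt(&mut self) {
        // no `let md5_salt = self.md5_salt;` workaround needed any more
        self.write_message(&Msg::Md5Salt(self.md5_salt));
    }
}

fn main() {
    Backend { md5_salt: [1, 2, 3, 4] }.send_salt();
}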

View File

@@ -401,7 +401,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
#[derive(Debug)]
pub enum BeMessage<'a> {
AuthenticationOk,
AuthenticationMD5Password(&'a [u8; 4]),
AuthenticationMD5Password([u8; 4]),
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
AuthenticationCleartextPassword,
BackendKeyData(CancelKeyData),
BindComplete,
@@ -429,6 +430,13 @@ pub enum BeMessage<'a> {
KeepAlive(WalSndKeepAlive),
}
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
Methods(&'a [&'a str]),
Continue(&'a [u8]),
Final(&'a [u8]),
}
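For reference, the codes written by the encoder below follow PostgreSQL's wire protocol: 10 for AuthenticationSASL (the mechanism list), 11 for AuthenticationSASLContinue, and 12 for AuthenticationSASLFinal. A stand-alone sketch of the bytes the Methods variant ends up producing; it mirrors the encoder in the next hunk but is not the crate's code:

use bytes::{BufMut, BytesMut};

fn encode_sasl_methods(buf: &mut BytesMut, methods: &[&str]) {
    buf.put_u8(b'R'); // authentication message
    let body_start = buf.len();
    buf.put_i32(0); // length placeholder, patched below
    buf.put_i32(10); // AuthenticationSASL
    for m in methods {
        buf.put_slice(m.as_bytes());
        buf.put_u8(0); // each mechanism name is a C string
    }
    buf.put_u8(0); // terminator of the mechanism list
    let body_len = (buf.len() - body_start) as i32;
    buf[body_start..body_start + 4].copy_from_slice(&body_len.to_be_bytes());
}

fn main() {
    let mut buf = BytesMut::new();
    encode_sasl_methods(&mut buf, &["SCRAM-SHA-256"]);
    // prints: 'R', i32 length, i32 10, "SCRAM-SHA-256\0", trailing 0
    println!("{:?}", &buf[..]);
}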
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
Encoding(&'a str),
@@ -611,6 +619,32 @@ impl<'a> BeMessage<'a> {
.unwrap(); // write into BytesMut can't fail
}
BeMessage::AuthenticationSasl(msg) => {
buf.put_u8(b'R');
write_body(buf, |buf| {
use BeAuthenticationSaslMessage::*;
match msg {
Methods(methods) => {
buf.put_i32(10); // Specifies that the SASL authentication method is used.
for method in methods.iter() {
write_cstr(method.as_bytes(), buf)?;
}
buf.put_u8(0); // zero terminator for the list
}
Continue(extra) => {
buf.put_i32(11); // Continue SASL auth.
buf.put_slice(extra);
}
Final(extra) => {
buf.put_i32(12); // Send final SASL message.
buf.put_slice(extra);
}
}
Ok::<_, io::Error>(())
})
.unwrap()
}
BeMessage::BackendKeyData(key_data) => {
buf.put_u8(b'K');
write_body(buf, |buf| {