mirror of
https://github.com/neondatabase/neon.git
synced 2026-03-03 16:30:38 +00:00
Compare commits
86 Commits
cli-overri
...
ars/tmp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23644ed251 | ||
|
|
99e0f07a1d | ||
|
|
5d490babf8 | ||
|
|
5865f85ae2 | ||
|
|
b815f5fb9f | ||
|
|
74a0942a77 | ||
|
|
1a4682a04a | ||
|
|
993b544ad0 | ||
|
|
dba1d36a4a | ||
|
|
ca81a550ef | ||
|
|
65a0b2736b | ||
|
|
cca886682b | ||
|
|
c8f47cd38e | ||
|
|
92787159f7 | ||
|
|
abb422d5de | ||
|
|
fdc15de8b2 | ||
|
|
207286f2b8 | ||
|
|
d2b896381a | ||
|
|
009f6d4ae8 | ||
|
|
1b31379456 | ||
|
|
4c64b10aec | ||
|
|
ad262a46ad | ||
|
|
ce533835e5 | ||
|
|
e5bf520b18 | ||
|
|
9512e21b9e | ||
|
|
a26d565282 | ||
|
|
a47dade622 | ||
|
|
9cce430430 | ||
|
|
4bf4bacf01 | ||
|
|
335abfcc28 | ||
|
|
afb3342e46 | ||
|
|
5563ff123f | ||
|
|
0a557b2fa9 | ||
|
|
9632c352ab | ||
|
|
328e3b4189 | ||
|
|
47f6a1f9a8 | ||
|
|
a4829712f4 | ||
|
|
d4d26f619d | ||
|
|
36481f3374 | ||
|
|
d951dd8977 | ||
|
|
ea13838be7 | ||
|
|
b51f23cdf0 | ||
|
|
3cfcdb92ed | ||
|
|
d7af965982 | ||
|
|
7c1c7702d2 | ||
|
|
6eef401602 | ||
|
|
c5b5905ed3 | ||
|
|
76b74349cb | ||
|
|
b08e340f60 | ||
|
|
a25fa29bc9 | ||
|
|
ccf3c8cc30 | ||
|
|
c45ee13b4e | ||
|
|
f1e7db9d0d | ||
|
|
fa8a6c0e94 | ||
|
|
1e8ca497e0 | ||
|
|
a504cc87ab | ||
|
|
5268bbc840 | ||
|
|
e1d770939b | ||
|
|
2866a9e82e | ||
|
|
b67cddb303 | ||
|
|
cb1d84d980 | ||
|
|
642797b69e | ||
|
|
3ed156a5b6 | ||
|
|
2d93b129a0 | ||
|
|
32c7859659 | ||
|
|
729ac38ea8 | ||
|
|
d69b0539ba | ||
|
|
ec78babad2 | ||
|
|
9350dfb215 | ||
|
|
8ac8be5206 | ||
|
|
c2927353a5 | ||
|
|
33251a9d8f | ||
|
|
c045ae7a9b | ||
|
|
602ccb7d5f | ||
|
|
5df21e1058 | ||
|
|
08135910a5 | ||
|
|
f58a22d07e | ||
|
|
cedde559b8 | ||
|
|
49d1d1ddf9 | ||
|
|
86045ac36c | ||
|
|
79f0e44a20 | ||
|
|
c44695f34b | ||
|
|
5abe2129c6 | ||
|
|
63dd7bce7e | ||
|
|
f3c73f5797 | ||
|
|
e6f2d70517 |
@@ -54,7 +54,8 @@ jobs:
|
||||
if [ ! -e tmp_install/bin/postgres ]; then
|
||||
# "depth 1" saves some time by not cloning the whole repo
|
||||
git submodule update --init --depth 1
|
||||
make postgres -j$(nproc)
|
||||
# bail out on any warnings
|
||||
COPT='-Werror' mold -run make postgres -j$(nproc)
|
||||
fi
|
||||
|
||||
- save_cache:
|
||||
@@ -110,7 +111,7 @@ jobs:
|
||||
fi
|
||||
|
||||
export CARGO_INCREMENTAL=0
|
||||
"${cov_prefix[@]}" cargo build $CARGO_FLAGS --bins --tests
|
||||
"${cov_prefix[@]}" mold -run cargo build $CARGO_FLAGS --bins --tests
|
||||
|
||||
- save_cache:
|
||||
name: Save rust cache
|
||||
@@ -194,6 +195,14 @@ jobs:
|
||||
command: |
|
||||
cp -a tmp_install /tmp/zenith/pg_install
|
||||
|
||||
- run:
|
||||
name: Merge coverage data
|
||||
command: |
|
||||
# This will speed up workspace uploads
|
||||
if [[ $BUILD_TYPE == "debug" ]]; then
|
||||
scripts/coverage "--profraw-prefix=$CIRCLE_JOB" --dir=/tmp/zenith/coverage merge
|
||||
fi
|
||||
|
||||
# Save the rust binaries and coverage data for other jobs in this workflow.
|
||||
- persist_to_workspace:
|
||||
root: /tmp/zenith
|
||||
@@ -204,9 +213,16 @@ jobs:
|
||||
executor: zenith-executor
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v1-python-deps-{{ checksum "poetry.lock" }}
|
||||
- run:
|
||||
name: Install deps
|
||||
command: ./scripts/pysync
|
||||
- save_cache:
|
||||
key: v1-python-deps-{{ checksum "poetry.lock" }}
|
||||
paths:
|
||||
- /home/circleci/.cache/pypoetry/virtualenvs
|
||||
- run:
|
||||
name: Run yapf to ensure code format
|
||||
when: always
|
||||
@@ -256,9 +272,16 @@ jobs:
|
||||
condition: << parameters.needs_postgres_source >>
|
||||
steps:
|
||||
- run: git submodule update --init --depth 1
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v1-python-deps-{{ checksum "poetry.lock" }}
|
||||
- run:
|
||||
name: Install deps
|
||||
command: ./scripts/pysync
|
||||
- save_cache:
|
||||
key: v1-python-deps-{{ checksum "poetry.lock" }}
|
||||
paths:
|
||||
- /home/circleci/.cache/pypoetry/virtualenvs
|
||||
- run:
|
||||
name: Run pytest
|
||||
# pytest doesn't output test logs in real time, so CI job may fail with
|
||||
@@ -275,6 +298,7 @@ jobs:
|
||||
- PLATFORM: zenith-local-ci
|
||||
command: |
|
||||
PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
|
||||
rm -rf $PERF_REPORT_DIR
|
||||
|
||||
TEST_SELECTION="test_runner/<< parameters.test_selection >>"
|
||||
EXTRA_PARAMS="<< parameters.extra_params >>"
|
||||
@@ -319,7 +343,6 @@ jobs:
|
||||
|
||||
if << parameters.save_perf_report >>; then
|
||||
if [[ $CIRCLE_BRANCH == "main" ]]; then
|
||||
# TODO: reuse scripts/git-upload
|
||||
export REPORT_FROM="$PERF_REPORT_DIR"
|
||||
export REPORT_TO=local
|
||||
scripts/generate_and_push_perf_report.sh
|
||||
@@ -340,6 +363,13 @@ jobs:
|
||||
# The store_test_results step tells CircleCI where to find the junit.xml file.
|
||||
- store_test_results:
|
||||
path: /tmp/test_output
|
||||
- run:
|
||||
name: Merge coverage data
|
||||
command: |
|
||||
# This will speed up workspace uploads
|
||||
if [[ $BUILD_TYPE == "debug" ]]; then
|
||||
scripts/coverage "--profraw-prefix=$CIRCLE_JOB" --dir=/tmp/zenith/coverage merge
|
||||
fi
|
||||
# Save coverage data (if any)
|
||||
- persist_to_workspace:
|
||||
root: /tmp/zenith
|
||||
@@ -568,6 +598,7 @@ workflows:
|
||||
- build-postgres-<< matrix.build_type >>
|
||||
- run-pytest:
|
||||
name: pg_regress-tests-<< matrix.build_type >>
|
||||
context: PERF_TEST_RESULT_CONNSTR
|
||||
matrix:
|
||||
parameters:
|
||||
build_type: ["debug", "release"]
|
||||
@@ -585,6 +616,7 @@ workflows:
|
||||
- build-zenith-<< matrix.build_type >>
|
||||
- run-pytest:
|
||||
name: benchmarks
|
||||
context: PERF_TEST_RESULT_CONNSTR
|
||||
build_type: release
|
||||
test_selection: performance
|
||||
run_in_parallel: false
|
||||
|
||||
@@ -5,6 +5,13 @@ settings:
|
||||
authEndpoint: "https://console.stage.zenith.tech/authenticate_proxy_request/"
|
||||
uri: "https://console.stage.zenith.tech/psql_session/"
|
||||
|
||||
# -- Additional labels for zenith-proxy pods
|
||||
podLabels:
|
||||
zenith_service: proxy
|
||||
zenith_env: staging
|
||||
zenith_region: us-east-1
|
||||
zenith_region_slug: virginia
|
||||
|
||||
exposedService:
|
||||
annotations:
|
||||
service.beta.kubernetes.io/aws-load-balancer-type: external
|
||||
@@ -17,4 +24,4 @@ metrics:
|
||||
serviceMonitor:
|
||||
enabled: true
|
||||
selector:
|
||||
prometheus: zenith
|
||||
release: kube-prometheus-stack
|
||||
|
||||
8
.github/workflows/benchmarking.yml
vendored
8
.github/workflows/benchmarking.yml
vendored
@@ -3,7 +3,7 @@ name: benchmarking
|
||||
on:
|
||||
# uncomment to run on push for debugging your PR
|
||||
# push:
|
||||
# branches: [ mybranch ]
|
||||
# branches: [ your branch ]
|
||||
schedule:
|
||||
# * is a special character in YAML so you have to quote this string
|
||||
# ┌───────────── minute (0 - 59)
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
run: |
|
||||
python3 -m pip install --upgrade poetry wheel
|
||||
# since pip/poetry caches are reused there shouldn't be any troubles with install every time
|
||||
poetry install
|
||||
./scripts/pysync
|
||||
|
||||
- name: Show versions
|
||||
run: |
|
||||
@@ -89,11 +89,15 @@ jobs:
|
||||
BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}"
|
||||
REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally
|
||||
run: |
|
||||
# just to be sure that no data was cached on self hosted runner
|
||||
# since it might generate duplicates when calling ingest_perf_test_result.py
|
||||
rm -rf perf-report-staging
|
||||
mkdir -p perf-report-staging
|
||||
./scripts/pytest test_runner/performance/ -v -m "remote_cluster" --skip-interfering-proc-check --out-dir perf-report-staging
|
||||
|
||||
- name: Submit result
|
||||
env:
|
||||
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
|
||||
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
|
||||
run: |
|
||||
REPORT_FROM=$(realpath perf-report-staging) REPORT_TO=staging scripts/generate_and_push_perf_report.sh
|
||||
|
||||
782
Cargo.lock
generated
782
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -16,3 +16,8 @@ members = [
|
||||
# This is useful for profiling and, to some extent, debug.
|
||||
# Besides, debug info should not affect the performance.
|
||||
debug = true
|
||||
|
||||
# This is only needed for proxy's tests
|
||||
# TODO: we should probably fork tokio-postgres-rustls instead
|
||||
[patch.crates-io]
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
[package]
|
||||
name = "compute_tools"
|
||||
version = "0.1.0"
|
||||
authors = ["Alexey Kondratov <kondratov.aleksey@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libc = "0.2"
|
||||
anyhow = "1.0"
|
||||
chrono = "0.4"
|
||||
clap = "2.33"
|
||||
env_logger = "0.8"
|
||||
clap = "3.0"
|
||||
env_logger = "0.9"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
log = { version = "0.4", features = ["std", "serde"] }
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
|
||||
@@ -34,6 +34,7 @@ use std::sync::{Arc, RwLock};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Utc;
|
||||
use clap::Arg;
|
||||
use log::info;
|
||||
use postgres::{Client, NoTls};
|
||||
|
||||
@@ -162,34 +163,34 @@ fn main() -> Result<()> {
|
||||
let matches = clap::App::new("zenith_ctl")
|
||||
.version(version.unwrap_or("unknown"))
|
||||
.arg(
|
||||
clap::Arg::with_name("connstr")
|
||||
.short("C")
|
||||
Arg::new("connstr")
|
||||
.short('C')
|
||||
.long("connstr")
|
||||
.value_name("DATABASE_URL")
|
||||
.required(true),
|
||||
)
|
||||
.arg(
|
||||
clap::Arg::with_name("pgdata")
|
||||
.short("D")
|
||||
Arg::new("pgdata")
|
||||
.short('D')
|
||||
.long("pgdata")
|
||||
.value_name("DATADIR")
|
||||
.required(true),
|
||||
)
|
||||
.arg(
|
||||
clap::Arg::with_name("pgbin")
|
||||
.short("b")
|
||||
Arg::new("pgbin")
|
||||
.short('b')
|
||||
.long("pgbin")
|
||||
.value_name("POSTGRES_PATH"),
|
||||
)
|
||||
.arg(
|
||||
clap::Arg::with_name("spec")
|
||||
.short("s")
|
||||
Arg::new("spec")
|
||||
.short('s')
|
||||
.long("spec")
|
||||
.value_name("SPEC_JSON"),
|
||||
)
|
||||
.arg(
|
||||
clap::Arg::with_name("spec-path")
|
||||
.short("S")
|
||||
Arg::new("spec-path")
|
||||
.short('S')
|
||||
.long("spec-path")
|
||||
.value_name("SPEC_PATH"),
|
||||
)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
[package]
|
||||
name = "control_plane"
|
||||
version = "0.1.0"
|
||||
authors = ["Stas Kelvich <stas@zenith.tech>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tar = "0.4.33"
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
toml = "0.5"
|
||||
lazy_static = "1.4"
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# Page server and three safekeepers.
|
||||
[pageserver]
|
||||
listen_pg_addr = 'localhost:64000'
|
||||
listen_http_addr = 'localhost:9898'
|
||||
listen_pg_addr = '127.0.0.1:64000'
|
||||
listen_http_addr = '127.0.0.1:9898'
|
||||
auth_type = 'Trust'
|
||||
|
||||
[[safekeepers]]
|
||||
name = 'sk1'
|
||||
id = 1
|
||||
pg_port = 5454
|
||||
http_port = 7676
|
||||
|
||||
[[safekeepers]]
|
||||
name = 'sk2'
|
||||
id = 2
|
||||
pg_port = 5455
|
||||
http_port = 7677
|
||||
|
||||
[[safekeepers]]
|
||||
name = 'sk3'
|
||||
id = 3
|
||||
pg_port = 5456
|
||||
http_port = 7678
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Minimal zenith environment with one safekeeper. This is equivalent to the built-in
|
||||
# defaults that you get with no --config
|
||||
[pageserver]
|
||||
listen_pg_addr = 'localhost:64000'
|
||||
listen_http_addr = 'localhost:9898'
|
||||
listen_pg_addr = '127.0.0.1:64000'
|
||||
listen_http_addr = '127.0.0.1:9898'
|
||||
auth_type = 'Trust'
|
||||
|
||||
[[safekeepers]]
|
||||
name = 'single'
|
||||
id = 1
|
||||
pg_port = 5454
|
||||
http_port = 7676
|
||||
|
||||
@@ -334,14 +334,26 @@ impl PostgresNode {
|
||||
if let Some(lsn) = self.lsn {
|
||||
conf.append("recovery_target_lsn", &lsn.to_string());
|
||||
}
|
||||
|
||||
conf.append_line("");
|
||||
// Configure backpressure
|
||||
// - Replication write lag depends on how fast the walreceiver can process incoming WAL.
|
||||
// This lag determines latency of get_page_at_lsn. Speed of applying WAL is about 10MB/sec,
|
||||
// so to avoid expiration of 1 minute timeout, this lag should not be larger than 600MB.
|
||||
// Actually latency should be much smaller (better if < 1sec). But we assume that recently
|
||||
// updates pages are not requested from pageserver.
|
||||
// - Replication flush lag depends on speed of persisting data by checkpointer (creation of
|
||||
// delta/image layers) and advancing disk_consistent_lsn. Safekeepers are able to
|
||||
// remove/archive WAL only beyond disk_consistent_lsn. Too large a lag can cause long
|
||||
// recovery time (in case of pageserver crash) and disk space overflow at safekeepers.
|
||||
// - Replication apply lag depends on speed of uploading changes to S3 by uploader thread.
|
||||
// To be able to restore database in case of pageserver node crash, safekeeper should not
|
||||
// remove WAL beyond this point. Too large lag can cause space exhaustion in safekeepers
|
||||
// (if they are not able to upload WAL to S3).
|
||||
conf.append("max_replication_write_lag", "500MB");
|
||||
conf.append("max_replication_flush_lag", "10GB");
|
||||
|
||||
if !self.env.safekeepers.is_empty() {
|
||||
// Configure backpressure
|
||||
// In setup with safekeepers apply_lag depends on
|
||||
// speed of data checkpointing on pageserver (see disk_consistent_lsn).
|
||||
conf.append("max_replication_apply_lag", "1500MB");
|
||||
|
||||
// Configure the node to connect to the safekeepers
|
||||
conf.append("synchronous_standby_names", "walproposer");
|
||||
|
||||
@@ -354,11 +366,6 @@ impl PostgresNode {
|
||||
.join(",");
|
||||
conf.append("wal_acceptors", &wal_acceptors);
|
||||
} else {
|
||||
// Configure backpressure
|
||||
// In setup without safekeepers, flush_lag depends on
|
||||
// speed of of data checkpointing on pageserver (see disk_consistent_lsn)
|
||||
conf.append("max_replication_flush_lag", "1500MB");
|
||||
|
||||
// We only use setup without safekeepers for tests,
|
||||
// and don't care about data durability on pageserver,
|
||||
// so set more relaxed synchronous_commit.
|
||||
|
||||
@@ -12,7 +12,9 @@ use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Stdio};
|
||||
use zenith_utils::auth::{encode_from_key_file, Claims, Scope};
|
||||
use zenith_utils::postgres_backend::AuthType;
|
||||
use zenith_utils::zid::{opt_display_serde, ZTenantId};
|
||||
use zenith_utils::zid::{opt_display_serde, ZNodeId, ZTenantId};
|
||||
|
||||
use crate::safekeeper::SafekeeperNode;
|
||||
|
||||
//
|
||||
// This data structures represents zenith CLI config
|
||||
@@ -62,6 +64,8 @@ pub struct LocalEnv {
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(default)]
|
||||
pub struct PageServerConf {
|
||||
// node id
|
||||
pub id: ZNodeId,
|
||||
// Pageserver connection settings
|
||||
pub listen_pg_addr: String,
|
||||
pub listen_http_addr: String,
|
||||
@@ -76,6 +80,7 @@ pub struct PageServerConf {
|
||||
impl Default for PageServerConf {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
id: ZNodeId(0),
|
||||
listen_pg_addr: String::new(),
|
||||
listen_http_addr: String::new(),
|
||||
auth_type: AuthType::Trust,
|
||||
@@ -87,7 +92,7 @@ impl Default for PageServerConf {
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(default)]
|
||||
pub struct SafekeeperConf {
|
||||
pub name: String,
|
||||
pub id: ZNodeId,
|
||||
pub pg_port: u16,
|
||||
pub http_port: u16,
|
||||
pub sync: bool,
|
||||
@@ -96,7 +101,7 @@ pub struct SafekeeperConf {
|
||||
impl Default for SafekeeperConf {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
name: String::new(),
|
||||
id: ZNodeId(0),
|
||||
pg_port: 0,
|
||||
http_port: 0,
|
||||
sync: true,
|
||||
@@ -136,8 +141,8 @@ impl LocalEnv {
|
||||
self.base_data_dir.clone()
|
||||
}
|
||||
|
||||
pub fn safekeeper_data_dir(&self, node_name: &str) -> PathBuf {
|
||||
self.base_data_dir.join("safekeepers").join(node_name)
|
||||
pub fn safekeeper_data_dir(&self, data_dir_name: &str) -> PathBuf {
|
||||
self.base_data_dir.join("safekeepers").join(data_dir_name)
|
||||
}
|
||||
|
||||
/// Create a LocalEnv from a config file.
|
||||
@@ -285,7 +290,7 @@ impl LocalEnv {
|
||||
fs::create_dir_all(self.pg_data_dirs_path())?;
|
||||
|
||||
for safekeeper in &self.safekeepers {
|
||||
fs::create_dir_all(self.safekeeper_data_dir(&safekeeper.name))?;
|
||||
fs::create_dir_all(SafekeeperNode::datadir_path_by_id(self, safekeeper.id))?;
|
||||
}
|
||||
|
||||
let mut conf_content = String::new();
|
||||
|
||||
@@ -15,6 +15,7 @@ use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
use reqwest::{IntoUrl, Method};
|
||||
use thiserror::Error;
|
||||
use zenith_utils::http::error::HttpErrorBody;
|
||||
use zenith_utils::zid::ZNodeId;
|
||||
|
||||
use crate::local_env::{LocalEnv, SafekeeperConf};
|
||||
use crate::storage::PageServerNode;
|
||||
@@ -61,7 +62,7 @@ impl ResponseErrorMessageExt for Response {
|
||||
//
|
||||
#[derive(Debug)]
|
||||
pub struct SafekeeperNode {
|
||||
pub name: String,
|
||||
pub id: ZNodeId,
|
||||
|
||||
pub conf: SafekeeperConf,
|
||||
|
||||
@@ -77,15 +78,15 @@ impl SafekeeperNode {
|
||||
pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode {
|
||||
let pageserver = Arc::new(PageServerNode::from_env(env));
|
||||
|
||||
println!("initializing for {} for {}", conf.name, conf.http_port);
|
||||
println!("initializing for sk {} for {}", conf.id, conf.http_port);
|
||||
|
||||
SafekeeperNode {
|
||||
name: conf.name.clone(),
|
||||
id: conf.id,
|
||||
conf: conf.clone(),
|
||||
pg_connection_config: Self::safekeeper_connection_config(conf.pg_port),
|
||||
env: env.clone(),
|
||||
http_client: Client::new(),
|
||||
http_base_url: format!("http://localhost:{}/v1", conf.http_port),
|
||||
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
|
||||
pageserver,
|
||||
}
|
||||
}
|
||||
@@ -93,13 +94,17 @@ impl SafekeeperNode {
|
||||
/// Construct libpq connection string for connecting to this safekeeper.
|
||||
fn safekeeper_connection_config(port: u16) -> Config {
|
||||
// TODO safekeeper authentication not implemented yet
|
||||
format!("postgresql://no_user@localhost:{}/no_db", port)
|
||||
format!("postgresql://no_user@127.0.0.1:{}/no_db", port)
|
||||
.parse()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn datadir_path_by_id(env: &LocalEnv, sk_id: ZNodeId) -> PathBuf {
|
||||
env.safekeeper_data_dir(format!("sk{}", sk_id).as_ref())
|
||||
}
|
||||
|
||||
pub fn datadir_path(&self) -> PathBuf {
|
||||
self.env.safekeeper_data_dir(&self.name)
|
||||
SafekeeperNode::datadir_path_by_id(&self.env, self.id)
|
||||
}
|
||||
|
||||
pub fn pid_file(&self) -> PathBuf {
|
||||
@@ -114,12 +119,13 @@ impl SafekeeperNode {
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
let listen_pg = format!("localhost:{}", self.conf.pg_port);
|
||||
let listen_http = format!("localhost:{}", self.conf.http_port);
|
||||
let listen_pg = format!("127.0.0.1:{}", self.conf.pg_port);
|
||||
let listen_http = format!("127.0.0.1:{}", self.conf.http_port);
|
||||
|
||||
let mut cmd = Command::new(self.env.safekeeper_bin()?);
|
||||
fill_rust_env_vars(
|
||||
cmd.args(&["-D", self.datadir_path().to_str().unwrap()])
|
||||
.args(&["--id", self.id.to_string().as_ref()])
|
||||
.args(&["--listen-pg", &listen_pg])
|
||||
.args(&["--listen-http", &listen_http])
|
||||
.args(&["--recall", "1 second"])
|
||||
@@ -183,7 +189,7 @@ impl SafekeeperNode {
|
||||
pub fn stop(&self, immediate: bool) -> anyhow::Result<()> {
|
||||
let pid_file = self.pid_file();
|
||||
if !pid_file.exists() {
|
||||
println!("Safekeeper {} is already stopped", self.name);
|
||||
println!("Safekeeper {} is already stopped", self.id);
|
||||
return Ok(());
|
||||
}
|
||||
let pid = read_pidfile(&pid_file)?;
|
||||
|
||||
@@ -103,6 +103,8 @@ impl PageServerNode {
|
||||
) -> anyhow::Result<()> {
|
||||
let mut cmd = Command::new(self.env.pageserver_bin()?);
|
||||
|
||||
let id = format!("id={}", self.env.pageserver.id);
|
||||
|
||||
// FIXME: the paths should be shell-escaped to handle paths with spaces, quotas etc.
|
||||
let base_data_dir_param = self.env.base_data_dir.display().to_string();
|
||||
let pg_distrib_dir_param =
|
||||
@@ -122,6 +124,7 @@ impl PageServerNode {
|
||||
args.extend(["-c", &authg_type_param]);
|
||||
args.extend(["-c", &listen_http_addr_param]);
|
||||
args.extend(["-c", &listen_pg_addr_param]);
|
||||
args.extend(["-c", &id]);
|
||||
|
||||
for config_override in config_overrides {
|
||||
args.extend(["-c", config_override]);
|
||||
|
||||
@@ -4,7 +4,7 @@ set -eux
|
||||
if [ "$1" = 'pageserver' ]; then
|
||||
if [ ! -d "/data/tenants" ]; then
|
||||
echo "Initializing pageserver data directory"
|
||||
pageserver --init -D /data -c "pg_distrib_dir='/usr/local'"
|
||||
pageserver --init -D /data -c "pg_distrib_dir='/usr/local'" -c "id=10"
|
||||
fi
|
||||
echo "Staring pageserver at 0.0.0.0:6400"
|
||||
pageserver -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -D /data
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
[package]
|
||||
name = "pageserver"
|
||||
version = "0.1.0"
|
||||
authors = ["Stas Kelvich <stas@zenith.tech>"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bookfile = { git = "https://github.com/zenithdb/bookfile.git", branch="generic-readext" }
|
||||
@@ -15,15 +14,14 @@ futures = "0.3.13"
|
||||
hyper = "0.14"
|
||||
lazy_static = "1.4.0"
|
||||
log = "0.4.14"
|
||||
clap = "2.33.0"
|
||||
clap = "3.0"
|
||||
daemonize = "0.4.1"
|
||||
tokio = { version = "1.11", features = ["process", "sync", "macros", "fs", "rt", "io-util", "time"] }
|
||||
postgres-types = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
postgres-types = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
tokio-stream = "0.1.8"
|
||||
routerify = "2"
|
||||
anyhow = { version = "1.0", features = ["backtrace"] }
|
||||
crc32c = "0.6.0"
|
||||
thiserror = "1.0"
|
||||
@@ -32,7 +30,7 @@ tar = "0.4.33"
|
||||
humantime = "2.1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml_edit = { version = "0.12", features = ["easy"] }
|
||||
toml_edit = { version = "0.13", features = ["easy"] }
|
||||
scopeguard = "1.1.0"
|
||||
async-trait = "0.1"
|
||||
const_format = "0.2.21"
|
||||
@@ -42,8 +40,8 @@ signal-hook = "0.3.10"
|
||||
url = "2"
|
||||
nix = "0.23"
|
||||
once_cell = "1.8.0"
|
||||
parking_lot = "0.11.2"
|
||||
crossbeam-utils = "0.8.5"
|
||||
fail = "0.5.0"
|
||||
|
||||
rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }
|
||||
async-compression = {version = "0.3", features = ["zstd", "tokio"]}
|
||||
|
||||
@@ -13,7 +13,7 @@ fn main() -> Result<()> {
|
||||
.about("Dump contents of one layer file, for debugging")
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::with_name("path")
|
||||
Arg::new("path")
|
||||
.help("Path to file to dump")
|
||||
.required(true)
|
||||
.index(1),
|
||||
|
||||
@@ -27,27 +27,27 @@ fn main() -> Result<()> {
|
||||
.about("Materializes WAL stream to pages and serves them to the postgres")
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::with_name("daemonize")
|
||||
.short("d")
|
||||
Arg::new("daemonize")
|
||||
.short('d')
|
||||
.long("daemonize")
|
||||
.takes_value(false)
|
||||
.help("Run in the background"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("init")
|
||||
Arg::new("init")
|
||||
.long("init")
|
||||
.takes_value(false)
|
||||
.help("Initialize pageserver repo"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("workdir")
|
||||
.short("D")
|
||||
Arg::new("workdir")
|
||||
.short('D')
|
||||
.long("workdir")
|
||||
.takes_value(true)
|
||||
.help("Working directory for the pageserver"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("create-tenant")
|
||||
Arg::new("create-tenant")
|
||||
.long("create-tenant")
|
||||
.takes_value(true)
|
||||
.help("Create tenant during init")
|
||||
@@ -55,13 +55,13 @@ fn main() -> Result<()> {
|
||||
)
|
||||
// See `settings.md` for more details on the extra configuration patameters pageserver can process
|
||||
.arg(
|
||||
Arg::with_name("config-override")
|
||||
.short("c")
|
||||
Arg::new("config-override")
|
||||
.short('c')
|
||||
.takes_value(true)
|
||||
.number_of_values(1)
|
||||
.multiple(true)
|
||||
.multiple_occurrences(true)
|
||||
.help("Additional configuration overrides of the ones from the toml config file (or new ones to add there).
|
||||
Any option has to be a valid toml document, example: `-c \"foo='hey'\"` `-c \"foo={value=1}\"`"),
|
||||
Any option has to be a valid toml document, example: `-c=\"foo='hey'\"` `-c=\"foo={value=1}\"`"),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
@@ -115,7 +115,14 @@ fn main() -> Result<()> {
|
||||
option_line
|
||||
)
|
||||
})?;
|
||||
|
||||
for (key, item) in doc.iter() {
|
||||
if key == "id" {
|
||||
anyhow::ensure!(
|
||||
init,
|
||||
"node id can only be set during pageserver init and cannot be overridden"
|
||||
);
|
||||
}
|
||||
toml.insert(key, item.clone());
|
||||
}
|
||||
}
|
||||
|
||||
334
pageserver/src/bin/pageserver_zst.rs
Normal file
334
pageserver/src/bin/pageserver_zst.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
//! A CLI helper to deal with remote storage (S3, usually) blobs as archives.
|
||||
//! See [`compression`] for more details about the archives.
|
||||
|
||||
use std::{collections::BTreeSet, path::Path};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use clap::{App, Arg};
|
||||
use pageserver::{
|
||||
layered_repository::metadata::{TimelineMetadata, METADATA_FILE_NAME},
|
||||
remote_storage::compression,
|
||||
};
|
||||
use tokio::{fs, io};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
const LIST_SUBCOMMAND: &str = "list";
|
||||
const ARCHIVE_ARG_NAME: &str = "archive";
|
||||
|
||||
const EXTRACT_SUBCOMMAND: &str = "extract";
|
||||
const TARGET_DIRECTORY_ARG_NAME: &str = "target_directory";
|
||||
|
||||
const CREATE_SUBCOMMAND: &str = "create";
|
||||
const SOURCE_DIRECTORY_ARG_NAME: &str = "source_directory";
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let arg_matches = App::new("pageserver zst blob [un]compressor utility")
|
||||
.version(GIT_VERSION)
|
||||
.subcommands(vec![
|
||||
App::new(LIST_SUBCOMMAND)
|
||||
.about("List the archive contents")
|
||||
.arg(
|
||||
Arg::new(ARCHIVE_ARG_NAME)
|
||||
.required(true)
|
||||
.takes_value(true)
|
||||
.help("An archive to list the contents of"),
|
||||
),
|
||||
App::new(EXTRACT_SUBCOMMAND)
|
||||
.about("Extracts the archive into the directory")
|
||||
.arg(
|
||||
Arg::new(ARCHIVE_ARG_NAME)
|
||||
.required(true)
|
||||
.takes_value(true)
|
||||
.help("An archive to extract"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new(TARGET_DIRECTORY_ARG_NAME)
|
||||
.required(false)
|
||||
.takes_value(true)
|
||||
.help("A directory to extract the archive into. Optional, will use the current directory if not specified"),
|
||||
),
|
||||
App::new(CREATE_SUBCOMMAND)
|
||||
.about("Creates an archive with the contents of a directory (only the first level files are taken, metadata file has to be present in the same directory)")
|
||||
.arg(
|
||||
Arg::new(SOURCE_DIRECTORY_ARG_NAME)
|
||||
.required(true)
|
||||
.takes_value(true)
|
||||
.help("A directory to use for creating the archive"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new(TARGET_DIRECTORY_ARG_NAME)
|
||||
.required(false)
|
||||
.takes_value(true)
|
||||
.help("A directory to create the archive in. Optional, will use the current directory if not specified"),
|
||||
),
|
||||
])
|
||||
.get_matches();
|
||||
|
||||
let subcommand_name = match arg_matches.subcommand_name() {
|
||||
Some(name) => name,
|
||||
None => bail!("No subcommand specified"),
|
||||
};
|
||||
|
||||
let subcommand_matches = match arg_matches.subcommand_matches(subcommand_name) {
|
||||
Some(matches) => matches,
|
||||
None => bail!(
|
||||
"No subcommand arguments were recognized for subcommand '{}'",
|
||||
subcommand_name
|
||||
),
|
||||
};
|
||||
|
||||
let target_dir = Path::new(
|
||||
subcommand_matches
|
||||
.value_of(TARGET_DIRECTORY_ARG_NAME)
|
||||
.unwrap_or("./"),
|
||||
);
|
||||
|
||||
match subcommand_name {
|
||||
LIST_SUBCOMMAND => {
|
||||
let archive = match subcommand_matches.value_of(ARCHIVE_ARG_NAME) {
|
||||
Some(archive) => Path::new(archive),
|
||||
None => bail!("No '{}' argument is specified", ARCHIVE_ARG_NAME),
|
||||
};
|
||||
list_archive(archive).await
|
||||
}
|
||||
EXTRACT_SUBCOMMAND => {
|
||||
let archive = match subcommand_matches.value_of(ARCHIVE_ARG_NAME) {
|
||||
Some(archive) => Path::new(archive),
|
||||
None => bail!("No '{}' argument is specified", ARCHIVE_ARG_NAME),
|
||||
};
|
||||
extract_archive(archive, target_dir).await
|
||||
}
|
||||
CREATE_SUBCOMMAND => {
|
||||
let source_dir = match subcommand_matches.value_of(SOURCE_DIRECTORY_ARG_NAME) {
|
||||
Some(source) => Path::new(source),
|
||||
None => bail!("No '{}' argument is specified", SOURCE_DIRECTORY_ARG_NAME),
|
||||
};
|
||||
create_archive(source_dir, target_dir).await
|
||||
}
|
||||
unknown => bail!("Unknown subcommand {}", unknown),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_archive(archive: &Path) -> anyhow::Result<()> {
|
||||
let archive = archive.canonicalize().with_context(|| {
|
||||
format!(
|
||||
"Failed to get the absolute path for the archive path '{}'",
|
||||
archive.display()
|
||||
)
|
||||
})?;
|
||||
ensure!(
|
||||
archive.is_file(),
|
||||
"Path '{}' is not an archive file",
|
||||
archive.display()
|
||||
);
|
||||
println!("Listing an archive at path '{}'", archive.display());
|
||||
let archive_name = match archive.file_name().and_then(|name| name.to_str()) {
|
||||
Some(name) => name,
|
||||
None => bail!(
|
||||
"Failed to get the archive name from the path '{}'",
|
||||
archive.display()
|
||||
),
|
||||
};
|
||||
|
||||
let archive_bytes = fs::read(&archive)
|
||||
.await
|
||||
.context("Failed to read the archive bytes")?;
|
||||
|
||||
let header = compression::read_archive_header(archive_name, &mut archive_bytes.as_slice())
|
||||
.await
|
||||
.context("Failed to read the archive header")?;
|
||||
|
||||
let empty_path = Path::new("");
|
||||
println!("-------------------------------");
|
||||
|
||||
let longest_path_in_archive = header
|
||||
.files
|
||||
.iter()
|
||||
.filter_map(|file| Some(file.subpath.as_path(empty_path).to_str()?.len()))
|
||||
.max()
|
||||
.unwrap_or_default()
|
||||
.max(METADATA_FILE_NAME.len());
|
||||
|
||||
for regular_file in &header.files {
|
||||
println!(
|
||||
"File: {:width$} uncompressed size: {} bytes",
|
||||
regular_file.subpath.as_path(empty_path).display(),
|
||||
regular_file.size,
|
||||
width = longest_path_in_archive,
|
||||
)
|
||||
}
|
||||
println!(
|
||||
"File: {:width$} uncompressed size: {} bytes",
|
||||
METADATA_FILE_NAME,
|
||||
header.metadata_file_size,
|
||||
width = longest_path_in_archive,
|
||||
);
|
||||
println!("-------------------------------");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn extract_archive(archive: &Path, target_dir: &Path) -> anyhow::Result<()> {
|
||||
let archive = archive.canonicalize().with_context(|| {
|
||||
format!(
|
||||
"Failed to get the absolute path for the archive path '{}'",
|
||||
archive.display()
|
||||
)
|
||||
})?;
|
||||
ensure!(
|
||||
archive.is_file(),
|
||||
"Path '{}' is not an archive file",
|
||||
archive.display()
|
||||
);
|
||||
let archive_name = match archive.file_name().and_then(|name| name.to_str()) {
|
||||
Some(name) => name,
|
||||
None => bail!(
|
||||
"Failed to get the archive name from the path '{}'",
|
||||
archive.display()
|
||||
),
|
||||
};
|
||||
|
||||
if !target_dir.exists() {
|
||||
fs::create_dir_all(target_dir).await.with_context(|| {
|
||||
format!(
|
||||
"Failed to create the target dir at path '{}'",
|
||||
target_dir.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
let target_dir = target_dir.canonicalize().with_context(|| {
|
||||
format!(
|
||||
"Failed to get the absolute path for the target dir path '{}'",
|
||||
target_dir.display()
|
||||
)
|
||||
})?;
|
||||
ensure!(
|
||||
target_dir.is_dir(),
|
||||
"Path '{}' is not a directory",
|
||||
target_dir.display()
|
||||
);
|
||||
let mut dir_contents = fs::read_dir(&target_dir)
|
||||
.await
|
||||
.context("Failed to list the target directory contents")?;
|
||||
let dir_entry = dir_contents
|
||||
.next_entry()
|
||||
.await
|
||||
.context("Failed to list the target directory contents")?;
|
||||
ensure!(
|
||||
dir_entry.is_none(),
|
||||
"Target directory '{}' is not empty",
|
||||
target_dir.display()
|
||||
);
|
||||
|
||||
println!(
|
||||
"Extracting an archive at path '{}' into directory '{}'",
|
||||
archive.display(),
|
||||
target_dir.display()
|
||||
);
|
||||
|
||||
let mut archive_file = fs::File::open(&archive).await.with_context(|| {
|
||||
format!(
|
||||
"Failed to get the archive name from the path '{}'",
|
||||
archive.display()
|
||||
)
|
||||
})?;
|
||||
let header = compression::read_archive_header(archive_name, &mut archive_file)
|
||||
.await
|
||||
.context("Failed to read the archive header")?;
|
||||
compression::uncompress_with_header(&BTreeSet::new(), &target_dir, header, &mut archive_file)
|
||||
.await
|
||||
.context("Failed to extract the archive")
|
||||
}
|
||||
|
||||
async fn create_archive(source_dir: &Path, target_dir: &Path) -> anyhow::Result<()> {
|
||||
let source_dir = source_dir.canonicalize().with_context(|| {
|
||||
format!(
|
||||
"Failed to get the absolute path for the source dir path '{}'",
|
||||
source_dir.display()
|
||||
)
|
||||
})?;
|
||||
ensure!(
|
||||
source_dir.is_dir(),
|
||||
"Path '{}' is not a directory",
|
||||
source_dir.display()
|
||||
);
|
||||
|
||||
if !target_dir.exists() {
|
||||
fs::create_dir_all(target_dir).await.with_context(|| {
|
||||
format!(
|
||||
"Failed to create the target dir at path '{}'",
|
||||
target_dir.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
let target_dir = target_dir.canonicalize().with_context(|| {
|
||||
format!(
|
||||
"Failed to get the absolute path for the target dir path '{}'",
|
||||
target_dir.display()
|
||||
)
|
||||
})?;
|
||||
ensure!(
|
||||
target_dir.is_dir(),
|
||||
"Path '{}' is not a directory",
|
||||
target_dir.display()
|
||||
);
|
||||
|
||||
println!(
|
||||
"Compressing directory '{}' and creating resulting archive in directory '{}'",
|
||||
source_dir.display(),
|
||||
target_dir.display()
|
||||
);
|
||||
|
||||
let mut metadata_file_contents = None;
|
||||
let mut files_co_archive = Vec::new();
|
||||
|
||||
let mut source_dir_contents = fs::read_dir(&source_dir)
|
||||
.await
|
||||
.context("Failed to read the source directory contents")?;
|
||||
|
||||
while let Some(source_dir_entry) = source_dir_contents
|
||||
.next_entry()
|
||||
.await
|
||||
.context("Failed to read a source dir entry")?
|
||||
{
|
||||
let entry_path = source_dir_entry.path();
|
||||
if entry_path.is_file() {
|
||||
if entry_path.file_name().and_then(|name| name.to_str()) == Some(METADATA_FILE_NAME) {
|
||||
let metadata_bytes = fs::read(entry_path)
|
||||
.await
|
||||
.context("Failed to read metata file bytes in the source dir")?;
|
||||
metadata_file_contents = Some(
|
||||
TimelineMetadata::from_bytes(&metadata_bytes)
|
||||
.context("Failed to parse metata file contents in the source dir")?,
|
||||
);
|
||||
} else {
|
||||
files_co_archive.push(entry_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = match metadata_file_contents {
|
||||
Some(metadata) => metadata,
|
||||
None => bail!(
|
||||
"No metadata file found in the source dir '{}', cannot create the archive",
|
||||
source_dir.display()
|
||||
),
|
||||
};
|
||||
|
||||
let _ = compression::archive_files_as_stream(
|
||||
&source_dir,
|
||||
files_co_archive.iter(),
|
||||
&metadata,
|
||||
move |mut archive_streamer, archive_name| async move {
|
||||
let archive_target = target_dir.join(&archive_name);
|
||||
let mut archive_file = fs::File::create(&archive_target).await?;
|
||||
io::copy(&mut archive_streamer, &mut archive_file).await?;
|
||||
Ok(archive_target)
|
||||
},
|
||||
)
|
||||
.await
|
||||
.context("Failed to create an archive")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -14,20 +14,20 @@ fn main() -> Result<()> {
|
||||
.about("Dump or update metadata file")
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::with_name("path")
|
||||
Arg::new("path")
|
||||
.help("Path to metadata file")
|
||||
.required(true),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("disk_lsn")
|
||||
.short("d")
|
||||
Arg::new("disk_lsn")
|
||||
.short('d')
|
||||
.long("disk_lsn")
|
||||
.takes_value(true)
|
||||
.help("Replace disk constistent lsn"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("prev_lsn")
|
||||
.short("p")
|
||||
Arg::new("prev_lsn")
|
||||
.short('p')
|
||||
.long("prev_lsn")
|
||||
.takes_value(true)
|
||||
.help("Previous record LSN"),
|
||||
|
||||
@@ -324,12 +324,13 @@ pub(crate) fn create_branch(
|
||||
timeline.wait_lsn(startpoint.lsn)?;
|
||||
}
|
||||
startpoint.lsn = startpoint.lsn.align();
|
||||
if timeline.get_start_lsn() > startpoint.lsn {
|
||||
if timeline.get_ancestor_lsn() > startpoint.lsn {
|
||||
// can we safely just branch from the ancestor instead?
|
||||
anyhow::bail!(
|
||||
"invalid startpoint {} for the branch {}: less than timeline start {}",
|
||||
"invalid startpoint {} for the branch {}: less than timeline ancestor lsn {:?}",
|
||||
startpoint.lsn,
|
||||
branchname,
|
||||
timeline.get_start_lsn()
|
||||
timeline.get_ancestor_lsn()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ use anyhow::{bail, ensure, Context, Result};
|
||||
use toml_edit;
|
||||
use toml_edit::{Document, Item};
|
||||
use zenith_utils::postgres_backend::AuthType;
|
||||
use zenith_utils::zid::{ZTenantId, ZTimelineId};
|
||||
use zenith_utils::zid::{ZNodeId, ZTenantId, ZTimelineId};
|
||||
|
||||
use std::convert::TryInto;
|
||||
use std::env;
|
||||
@@ -72,6 +72,10 @@ pub mod defaults {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PageServerConf {
|
||||
// Identifier of that particular pageserver so e g safekeepers
|
||||
// can safely distinguish different pageservers
|
||||
pub id: ZNodeId,
|
||||
|
||||
/// Example (default): 127.0.0.1:64000
|
||||
pub listen_pg_addr: String,
|
||||
/// Example (default): 127.0.0.1:9898
|
||||
@@ -106,6 +110,184 @@ pub struct PageServerConf {
|
||||
pub remote_storage_config: Option<RemoteStorageConfig>,
|
||||
}
|
||||
|
||||
// use dedicated enum for builder to better indicate the intention
|
||||
// and avoid possible confusion with nested options
|
||||
pub enum BuilderValue<T> {
|
||||
Set(T),
|
||||
NotSet,
|
||||
}
|
||||
|
||||
impl<T> BuilderValue<T> {
|
||||
pub fn ok_or<E>(self, err: E) -> Result<T, E> {
|
||||
match self {
|
||||
Self::Set(v) => Ok(v),
|
||||
Self::NotSet => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// needed to simplify config construction
|
||||
struct PageServerConfigBuilder {
|
||||
listen_pg_addr: BuilderValue<String>,
|
||||
|
||||
listen_http_addr: BuilderValue<String>,
|
||||
|
||||
checkpoint_distance: BuilderValue<u64>,
|
||||
checkpoint_period: BuilderValue<Duration>,
|
||||
|
||||
gc_horizon: BuilderValue<u64>,
|
||||
gc_period: BuilderValue<Duration>,
|
||||
superuser: BuilderValue<String>,
|
||||
|
||||
page_cache_size: BuilderValue<usize>,
|
||||
max_file_descriptors: BuilderValue<usize>,
|
||||
|
||||
workdir: BuilderValue<PathBuf>,
|
||||
|
||||
pg_distrib_dir: BuilderValue<PathBuf>,
|
||||
|
||||
auth_type: BuilderValue<AuthType>,
|
||||
|
||||
//
|
||||
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
|
||||
remote_storage_config: BuilderValue<Option<RemoteStorageConfig>>,
|
||||
|
||||
id: BuilderValue<ZNodeId>,
|
||||
}
|
||||
|
||||
impl Default for PageServerConfigBuilder {
|
||||
fn default() -> Self {
|
||||
use self::BuilderValue::*;
|
||||
use defaults::*;
|
||||
Self {
|
||||
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
|
||||
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
|
||||
checkpoint_distance: Set(DEFAULT_CHECKPOINT_DISTANCE),
|
||||
checkpoint_period: Set(humantime::parse_duration(DEFAULT_CHECKPOINT_PERIOD)
|
||||
.expect("cannot parse default checkpoint period")),
|
||||
gc_horizon: Set(DEFAULT_GC_HORIZON),
|
||||
gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD)
|
||||
.expect("cannot parse default gc period")),
|
||||
superuser: Set(DEFAULT_SUPERUSER.to_string()),
|
||||
page_cache_size: Set(DEFAULT_PAGE_CACHE_SIZE),
|
||||
max_file_descriptors: Set(DEFAULT_MAX_FILE_DESCRIPTORS),
|
||||
workdir: Set(PathBuf::new()),
|
||||
pg_distrib_dir: Set(env::current_dir()
|
||||
.expect("cannot access current directory")
|
||||
.join("tmp_install")),
|
||||
auth_type: Set(AuthType::Trust),
|
||||
auth_validation_public_key_path: Set(None),
|
||||
remote_storage_config: Set(None),
|
||||
id: NotSet,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PageServerConfigBuilder {
|
||||
pub fn listen_pg_addr(&mut self, listen_pg_addr: String) {
|
||||
self.listen_pg_addr = BuilderValue::Set(listen_pg_addr)
|
||||
}
|
||||
|
||||
pub fn listen_http_addr(&mut self, listen_http_addr: String) {
|
||||
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
|
||||
}
|
||||
|
||||
pub fn checkpoint_distance(&mut self, checkpoint_distance: u64) {
|
||||
self.checkpoint_distance = BuilderValue::Set(checkpoint_distance)
|
||||
}
|
||||
|
||||
pub fn checkpoint_period(&mut self, checkpoint_period: Duration) {
|
||||
self.checkpoint_period = BuilderValue::Set(checkpoint_period)
|
||||
}
|
||||
|
||||
pub fn gc_horizon(&mut self, gc_horizon: u64) {
|
||||
self.gc_horizon = BuilderValue::Set(gc_horizon)
|
||||
}
|
||||
|
||||
pub fn gc_period(&mut self, gc_period: Duration) {
|
||||
self.gc_period = BuilderValue::Set(gc_period)
|
||||
}
|
||||
|
||||
pub fn superuser(&mut self, superuser: String) {
|
||||
self.superuser = BuilderValue::Set(superuser)
|
||||
}
|
||||
|
||||
pub fn page_cache_size(&mut self, page_cache_size: usize) {
|
||||
self.page_cache_size = BuilderValue::Set(page_cache_size)
|
||||
}
|
||||
|
||||
pub fn max_file_descriptors(&mut self, max_file_descriptors: usize) {
|
||||
self.max_file_descriptors = BuilderValue::Set(max_file_descriptors)
|
||||
}
|
||||
|
||||
pub fn workdir(&mut self, workdir: PathBuf) {
|
||||
self.workdir = BuilderValue::Set(workdir)
|
||||
}
|
||||
|
||||
pub fn pg_distrib_dir(&mut self, pg_distrib_dir: PathBuf) {
|
||||
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
|
||||
}
|
||||
|
||||
pub fn auth_type(&mut self, auth_type: AuthType) {
|
||||
self.auth_type = BuilderValue::Set(auth_type)
|
||||
}
|
||||
|
||||
pub fn auth_validation_public_key_path(
|
||||
&mut self,
|
||||
auth_validation_public_key_path: Option<PathBuf>,
|
||||
) {
|
||||
self.auth_validation_public_key_path = BuilderValue::Set(auth_validation_public_key_path)
|
||||
}
|
||||
|
||||
pub fn remote_storage_config(&mut self, remote_storage_config: Option<RemoteStorageConfig>) {
|
||||
self.remote_storage_config = BuilderValue::Set(remote_storage_config)
|
||||
}
|
||||
|
||||
pub fn id(&mut self, node_id: ZNodeId) {
|
||||
self.id = BuilderValue::Set(node_id)
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<PageServerConf> {
|
||||
Ok(PageServerConf {
|
||||
listen_pg_addr: self
|
||||
.listen_pg_addr
|
||||
.ok_or(anyhow::anyhow!("missing listen_pg_addr"))?,
|
||||
listen_http_addr: self
|
||||
.listen_http_addr
|
||||
.ok_or(anyhow::anyhow!("missing listen_http_addr"))?,
|
||||
checkpoint_distance: self
|
||||
.checkpoint_distance
|
||||
.ok_or(anyhow::anyhow!("missing checkpoint_distance"))?,
|
||||
checkpoint_period: self
|
||||
.checkpoint_period
|
||||
.ok_or(anyhow::anyhow!("missing checkpoint_period"))?,
|
||||
gc_horizon: self
|
||||
.gc_horizon
|
||||
.ok_or(anyhow::anyhow!("missing gc_horizon"))?,
|
||||
gc_period: self.gc_period.ok_or(anyhow::anyhow!("missing gc_period"))?,
|
||||
superuser: self.superuser.ok_or(anyhow::anyhow!("missing superuser"))?,
|
||||
page_cache_size: self
|
||||
.page_cache_size
|
||||
.ok_or(anyhow::anyhow!("missing page_cache_size"))?,
|
||||
max_file_descriptors: self
|
||||
.max_file_descriptors
|
||||
.ok_or(anyhow::anyhow!("missing max_file_descriptors"))?,
|
||||
workdir: self.workdir.ok_or(anyhow::anyhow!("missing workdir"))?,
|
||||
pg_distrib_dir: self
|
||||
.pg_distrib_dir
|
||||
.ok_or(anyhow::anyhow!("missing pg_distrib_dir"))?,
|
||||
auth_type: self.auth_type.ok_or(anyhow::anyhow!("missing auth_type"))?,
|
||||
auth_validation_public_key_path: self
|
||||
.auth_validation_public_key_path
|
||||
.ok_or(anyhow::anyhow!("missing auth_validation_public_key_path"))?,
|
||||
remote_storage_config: self
|
||||
.remote_storage_config
|
||||
.ok_or(anyhow::anyhow!("missing remote_storage_config"))?,
|
||||
id: self.id.ok_or(anyhow::anyhow!("missing id"))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// External backup storage configuration, enough for creating a client for that storage.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RemoteStorageConfig {
|
||||
@@ -221,57 +403,39 @@ impl PageServerConf {
|
||||
///
|
||||
/// This leaves any options not present in the file in the built-in defaults.
|
||||
pub fn parse_and_validate(toml: &Document, workdir: &Path) -> Result<Self> {
|
||||
use defaults::*;
|
||||
|
||||
let mut conf = PageServerConf {
|
||||
workdir: workdir.to_path_buf(),
|
||||
|
||||
listen_pg_addr: DEFAULT_PG_LISTEN_ADDR.to_string(),
|
||||
listen_http_addr: DEFAULT_HTTP_LISTEN_ADDR.to_string(),
|
||||
checkpoint_distance: DEFAULT_CHECKPOINT_DISTANCE,
|
||||
checkpoint_period: humantime::parse_duration(DEFAULT_CHECKPOINT_PERIOD)?,
|
||||
gc_horizon: DEFAULT_GC_HORIZON,
|
||||
gc_period: humantime::parse_duration(DEFAULT_GC_PERIOD)?,
|
||||
page_cache_size: DEFAULT_PAGE_CACHE_SIZE,
|
||||
max_file_descriptors: DEFAULT_MAX_FILE_DESCRIPTORS,
|
||||
|
||||
pg_distrib_dir: PathBuf::new(),
|
||||
auth_validation_public_key_path: None,
|
||||
auth_type: AuthType::Trust,
|
||||
|
||||
remote_storage_config: None,
|
||||
|
||||
superuser: DEFAULT_SUPERUSER.to_string(),
|
||||
};
|
||||
let mut builder = PageServerConfigBuilder::default();
|
||||
builder.workdir(workdir.to_owned());
|
||||
|
||||
for (key, item) in toml.iter() {
|
||||
match key {
|
||||
"listen_pg_addr" => conf.listen_pg_addr = parse_toml_string(key, item)?,
|
||||
"listen_http_addr" => conf.listen_http_addr = parse_toml_string(key, item)?,
|
||||
"checkpoint_distance" => conf.checkpoint_distance = parse_toml_u64(key, item)?,
|
||||
"checkpoint_period" => conf.checkpoint_period = parse_toml_duration(key, item)?,
|
||||
"gc_horizon" => conf.gc_horizon = parse_toml_u64(key, item)?,
|
||||
"gc_period" => conf.gc_period = parse_toml_duration(key, item)?,
|
||||
"initial_superuser_name" => conf.superuser = parse_toml_string(key, item)?,
|
||||
"page_cache_size" => conf.page_cache_size = parse_toml_u64(key, item)? as usize,
|
||||
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
|
||||
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
|
||||
"checkpoint_distance" => builder.checkpoint_distance(parse_toml_u64(key, item)?),
|
||||
"checkpoint_period" => builder.checkpoint_period(parse_toml_duration(key, item)?),
|
||||
"gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?),
|
||||
"gc_period" => builder.gc_period(parse_toml_duration(key, item)?),
|
||||
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
|
||||
"page_cache_size" => builder.page_cache_size(parse_toml_u64(key, item)? as usize),
|
||||
"max_file_descriptors" => {
|
||||
conf.max_file_descriptors = parse_toml_u64(key, item)? as usize
|
||||
builder.max_file_descriptors(parse_toml_u64(key, item)? as usize)
|
||||
}
|
||||
"pg_distrib_dir" => {
|
||||
conf.pg_distrib_dir = PathBuf::from(parse_toml_string(key, item)?)
|
||||
builder.pg_distrib_dir(PathBuf::from(parse_toml_string(key, item)?))
|
||||
}
|
||||
"auth_validation_public_key_path" => {
|
||||
conf.auth_validation_public_key_path =
|
||||
Some(PathBuf::from(parse_toml_string(key, item)?))
|
||||
}
|
||||
"auth_type" => conf.auth_type = parse_toml_auth_type(key, item)?,
|
||||
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
|
||||
PathBuf::from(parse_toml_string(key, item)?),
|
||||
)),
|
||||
"auth_type" => builder.auth_type(parse_toml_auth_type(key, item)?),
|
||||
"remote_storage" => {
|
||||
conf.remote_storage_config = Some(Self::parse_remote_storage_config(item)?)
|
||||
builder.remote_storage_config(Some(Self::parse_remote_storage_config(item)?))
|
||||
}
|
||||
"id" => builder.id(ZNodeId(parse_toml_u64(key, item)?)),
|
||||
_ => bail!("unrecognized pageserver option '{}'", key),
|
||||
}
|
||||
}
|
||||
|
||||
let mut conf = builder.build().context("invalid config")?;
|
||||
|
||||
if conf.auth_type == AuthType::ZenithJWT {
|
||||
let auth_validation_public_key_path = conf
|
||||
.auth_validation_public_key_path
|
||||
@@ -285,9 +449,6 @@ impl PageServerConf {
|
||||
);
|
||||
}
|
||||
|
||||
if conf.pg_distrib_dir == PathBuf::new() {
|
||||
conf.pg_distrib_dir = env::current_dir()?.join("tmp_install")
|
||||
};
|
||||
if !conf.pg_distrib_dir.join("bin/postgres").exists() {
|
||||
bail!(
|
||||
"Can't find postgres binary at {}",
|
||||
@@ -382,6 +543,7 @@ impl PageServerConf {
|
||||
#[cfg(test)]
|
||||
pub fn dummy_conf(repo_dir: PathBuf) -> Self {
|
||||
PageServerConf {
|
||||
id: ZNodeId(0),
|
||||
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
|
||||
checkpoint_period: Duration::from_secs(10),
|
||||
gc_horizon: defaults::DEFAULT_GC_HORIZON,
|
||||
@@ -461,15 +623,16 @@ max_file_descriptors = 333
|
||||
|
||||
# initial superuser role name to use when creating a new tenant
|
||||
initial_superuser_name = 'zzzz'
|
||||
id = 10
|
||||
|
||||
"#;
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn parse_defaults() -> anyhow::Result<()> {
|
||||
let tempdir = tempdir()?;
|
||||
let (workdir, pg_distrib_dir) = prepare_fs(&tempdir)?;
|
||||
// we have to create dummy pathes to overcome the validation errors
|
||||
let config_string = format!("pg_distrib_dir='{}'", pg_distrib_dir.display());
|
||||
let config_string = format!("pg_distrib_dir='{}'\nid=10", pg_distrib_dir.display());
|
||||
let toml = config_string.parse()?;
|
||||
|
||||
let parsed_config =
|
||||
@@ -480,6 +643,7 @@ initial_superuser_name = 'zzzz'
|
||||
assert_eq!(
|
||||
parsed_config,
|
||||
PageServerConf {
|
||||
id: ZNodeId(10),
|
||||
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
|
||||
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
|
||||
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
|
||||
@@ -521,6 +685,7 @@ initial_superuser_name = 'zzzz'
|
||||
assert_eq!(
|
||||
parsed_config,
|
||||
PageServerConf {
|
||||
id: ZNodeId(10),
|
||||
listen_pg_addr: "127.0.0.1:64000".to_string(),
|
||||
listen_http_addr: "127.0.0.1:9898".to_string(),
|
||||
checkpoint_distance: 111,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::ZTenantId;
|
||||
use zenith_utils::zid::ZNodeId;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct BranchCreateRequest {
|
||||
@@ -15,3 +16,8 @@ pub struct TenantCreateRequest {
|
||||
#[serde(with = "hex")]
|
||||
pub tenant_id: ZTenantId,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct StatusResponse {
|
||||
pub id: ZNodeId,
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
/v1/timeline/{tenant_id}:
|
||||
parameters:
|
||||
- name: tenant_id
|
||||
@@ -234,9 +239,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/BranchInfo"
|
||||
$ref: "#/components/schemas/BranchInfo"
|
||||
"400":
|
||||
description: Malformed branch create request
|
||||
content:
|
||||
@@ -370,12 +373,15 @@ components:
|
||||
format: hex
|
||||
ancestor_id:
|
||||
type: string
|
||||
format: hex
|
||||
ancestor_lsn:
|
||||
type: string
|
||||
current_logical_size:
|
||||
type: integer
|
||||
current_logical_size_non_incremental:
|
||||
type: integer
|
||||
latest_valid_lsn:
|
||||
type: integer
|
||||
TimelineInfo:
|
||||
type: object
|
||||
required:
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use hyper::header;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, Request, Response, Uri};
|
||||
use routerify::{ext::RequestExt, RouterBuilder};
|
||||
use serde::Serialize;
|
||||
use tracing::*;
|
||||
use zenith_utils::auth::JwtAuth;
|
||||
@@ -19,10 +17,12 @@ use zenith_utils::http::{
|
||||
request::get_request_param,
|
||||
request::parse_request_param,
|
||||
};
|
||||
use zenith_utils::http::{RequestExt, RouterBuilder};
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::zid::{opt_display_serde, ZTimelineId};
|
||||
|
||||
use super::models::BranchCreateRequest;
|
||||
use super::models::StatusResponse;
|
||||
use super::models::TenantCreateRequest;
|
||||
use crate::branches::BranchInfo;
|
||||
use crate::repository::RepositoryTimeline;
|
||||
@@ -64,12 +64,12 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
|
||||
}
|
||||
|
||||
// healthcheck handler
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from("{}"))
|
||||
.map_err(ApiError::from_err)?)
|
||||
async fn status_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let config = get_config(&request);
|
||||
Ok(json_response(
|
||||
StatusCode::OK,
|
||||
StatusResponse { id: config.id },
|
||||
)?)
|
||||
}
|
||||
|
||||
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
@@ -202,7 +202,6 @@ enum TimelineInfo {
|
||||
ancestor_timeline_id: Option<ZTimelineId>,
|
||||
last_record_lsn: Lsn,
|
||||
prev_record_lsn: Lsn,
|
||||
start_lsn: Lsn,
|
||||
disk_consistent_lsn: Lsn,
|
||||
timeline_state: Option<TimelineSyncState>,
|
||||
},
|
||||
@@ -237,7 +236,6 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
|
||||
disk_consistent_lsn: timeline.get_disk_consistent_lsn(),
|
||||
last_record_lsn: timeline.get_last_record_lsn(),
|
||||
prev_record_lsn: timeline.get_prev_record_lsn(),
|
||||
start_lsn: timeline.get_start_lsn(),
|
||||
timeline_state: repo.get_timeline_state(timeline_id),
|
||||
},
|
||||
})
|
||||
|
||||
@@ -28,7 +28,7 @@ use std::io::Write;
|
||||
use std::ops::{Bound::Included, Deref};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicUsize};
|
||||
use std::sync::{Arc, Mutex, MutexGuard};
|
||||
use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use self::metadata::{metadata_path, TimelineMetadata, METADATA_FILE_NAME};
|
||||
@@ -71,7 +71,6 @@ mod storage_layer;
|
||||
use delta_layer::DeltaLayer;
|
||||
use ephemeral_file::is_ephemeral_file;
|
||||
use filename::{DeltaFileName, ImageFileName};
|
||||
use global_layer_map::{LayerId, GLOBAL_LAYER_MAP};
|
||||
use image_layer::ImageLayer;
|
||||
use inmemory_layer::InMemoryLayer;
|
||||
use layer_map::LayerMap;
|
||||
@@ -167,7 +166,7 @@ impl Repository for LayeredRepository {
|
||||
// Create the timeline directory, and write initial metadata to file.
|
||||
crashsafe_dir::create_dir_all(self.conf.timeline_path(&timelineid, &self.tenantid))?;
|
||||
|
||||
let metadata = TimelineMetadata::new(Lsn(0), None, None, Lsn(0), Lsn(0), initdb_lsn);
|
||||
let metadata = TimelineMetadata::new(Lsn(0), None, None, Lsn(0), initdb_lsn, initdb_lsn);
|
||||
Self::save_metadata(self.conf, timelineid, self.tenantid, &metadata, true)?;
|
||||
|
||||
let timeline = LayeredTimeline::new(
|
||||
@@ -201,9 +200,10 @@ impl Repository for LayeredRepository {
|
||||
bail!("Cannot branch off the timeline {} that's not local", src)
|
||||
}
|
||||
};
|
||||
let latest_gc_cutoff_lsn = src_timeline.get_latest_gc_cutoff_lsn();
|
||||
|
||||
src_timeline
|
||||
.check_lsn_is_in_scope(start_lsn)
|
||||
.check_lsn_is_in_scope(start_lsn, &latest_gc_cutoff_lsn)
|
||||
.context("invalid branch start lsn")?;
|
||||
|
||||
let RecordLsn {
|
||||
@@ -231,7 +231,7 @@ impl Repository for LayeredRepository {
|
||||
dst_prev,
|
||||
Some(src),
|
||||
start_lsn,
|
||||
src_timeline.latest_gc_cutoff_lsn.load(),
|
||||
*src_timeline.latest_gc_cutoff_lsn.read().unwrap(),
|
||||
src_timeline.initdb_lsn,
|
||||
);
|
||||
crashsafe_dir::create_dir_all(self.conf.timeline_path(&dst, &self.tenantid))?;
|
||||
@@ -611,7 +611,7 @@ impl LayeredRepository {
|
||||
}
|
||||
}
|
||||
|
||||
//Now collect info about branchpoints
|
||||
// Now collect info about branchpoints
|
||||
let mut all_branchpoints: BTreeSet<(ZTimelineId, Lsn)> = BTreeSet::new();
|
||||
for &timelineid in &timelineids {
|
||||
let timeline = match self.get_or_init_timeline(timelineid, &mut timelines)? {
|
||||
@@ -783,7 +783,7 @@ pub struct LayeredTimeline {
|
||||
checkpoint_cs: Mutex<()>,
|
||||
|
||||
// Needed to ensure that we can't create a branch at a point that was already garbage collected
|
||||
latest_gc_cutoff_lsn: AtomicLsn,
|
||||
latest_gc_cutoff_lsn: RwLock<Lsn>,
|
||||
|
||||
// It may change across major versions so for simplicity
|
||||
// keep it after running initdb for a timeline.
|
||||
@@ -827,6 +827,10 @@ impl Timeline for LayeredTimeline {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_latest_gc_cutoff_lsn(&self) -> RwLockReadGuard<Lsn> {
|
||||
self.latest_gc_cutoff_lsn.read().unwrap()
|
||||
}
|
||||
|
||||
/// Look up given page version.
|
||||
fn get_page_at_lsn(&self, rel: RelishTag, rel_blknum: BlockNumber, lsn: Lsn) -> Result<Bytes> {
|
||||
if !rel.is_blocky() && rel_blknum != 0 {
|
||||
@@ -837,14 +841,6 @@ impl Timeline for LayeredTimeline {
|
||||
);
|
||||
}
|
||||
debug_assert!(lsn <= self.get_last_record_lsn());
|
||||
let latest_gc_cutoff_lsn = self.latest_gc_cutoff_lsn.load();
|
||||
// error instead of assert to simplify testing
|
||||
ensure!(
|
||||
lsn >= latest_gc_cutoff_lsn,
|
||||
"tried to request a page version that was garbage collected. requested at {} gc cutoff {}",
|
||||
lsn, latest_gc_cutoff_lsn
|
||||
);
|
||||
|
||||
let (seg, seg_blknum) = SegmentTag::from_blknum(rel, rel_blknum);
|
||||
|
||||
if let Some((layer, lsn)) = self.get_layer_for_read(seg, lsn)? {
|
||||
@@ -1015,21 +1011,16 @@ impl Timeline for LayeredTimeline {
|
||||
///
|
||||
/// Validate lsn against initdb_lsn and latest_gc_cutoff_lsn.
|
||||
///
|
||||
fn check_lsn_is_in_scope(&self, lsn: Lsn) -> Result<()> {
|
||||
let initdb_lsn = self.initdb_lsn;
|
||||
fn check_lsn_is_in_scope(
|
||||
&self,
|
||||
lsn: Lsn,
|
||||
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
|
||||
) -> Result<()> {
|
||||
ensure!(
|
||||
lsn >= initdb_lsn,
|
||||
"LSN {} is earlier than initdb lsn {}",
|
||||
lsn,
|
||||
initdb_lsn,
|
||||
);
|
||||
|
||||
let latest_gc_cutoff_lsn = self.latest_gc_cutoff_lsn.load();
|
||||
ensure!(
|
||||
lsn >= latest_gc_cutoff_lsn,
|
||||
lsn >= **latest_gc_cutoff_lsn,
|
||||
"LSN {} is earlier than latest GC horizon {} (we might've already garbage collected needed data)",
|
||||
lsn,
|
||||
latest_gc_cutoff_lsn,
|
||||
**latest_gc_cutoff_lsn,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1046,14 +1037,6 @@ impl Timeline for LayeredTimeline {
|
||||
self.last_record_lsn.load()
|
||||
}
|
||||
|
||||
fn get_start_lsn(&self) -> Lsn {
|
||||
self.ancestor_timeline
|
||||
.as_ref()
|
||||
.and_then(|ancestor_entry| ancestor_entry.local_or_schedule_download(self.tenantid))
|
||||
.map(Timeline::get_start_lsn)
|
||||
.unwrap_or(self.ancestor_lsn)
|
||||
}
|
||||
|
||||
fn get_current_logical_size(&self) -> usize {
|
||||
self.current_logical_size.load(atomic::Ordering::Acquire) as usize
|
||||
}
|
||||
@@ -1143,7 +1126,7 @@ impl LayeredTimeline {
|
||||
write_lock: Mutex::new(()),
|
||||
checkpoint_cs: Mutex::new(()),
|
||||
|
||||
latest_gc_cutoff_lsn: AtomicLsn::from(metadata.latest_gc_cutoff_lsn()),
|
||||
latest_gc_cutoff_lsn: RwLock::new(metadata.latest_gc_cutoff_lsn()),
|
||||
initdb_lsn: metadata.initdb_lsn(),
|
||||
}
|
||||
}
|
||||
@@ -1169,8 +1152,8 @@ impl LayeredTimeline {
|
||||
// create an ImageLayer struct for each image file.
|
||||
if imgfilename.lsn > disk_consistent_lsn {
|
||||
warn!(
|
||||
"found future image layer {} on timeline {}",
|
||||
imgfilename, self.timelineid
|
||||
"found future image layer {} on timeline {} disk_consistent_lsn is {}",
|
||||
imgfilename, self.timelineid, disk_consistent_lsn
|
||||
);
|
||||
|
||||
rename_to_backup(direntry.path())?;
|
||||
@@ -1193,8 +1176,8 @@ impl LayeredTimeline {
|
||||
// before crash.
|
||||
if deltafilename.end_lsn > disk_consistent_lsn + 1 {
|
||||
warn!(
|
||||
"found future delta layer {} on timeline {}",
|
||||
deltafilename, self.timelineid
|
||||
"found future delta layer {} on timeline {} disk_consistent_lsn is {}",
|
||||
deltafilename, self.timelineid, disk_consistent_lsn
|
||||
);
|
||||
|
||||
rename_to_backup(direntry.path())?;
|
||||
@@ -1390,7 +1373,7 @@ impl LayeredTimeline {
|
||||
self.tenantid,
|
||||
seg,
|
||||
lsn,
|
||||
lsn,
|
||||
last_record_lsn,
|
||||
)?;
|
||||
} else {
|
||||
return Ok(open_layer);
|
||||
@@ -1433,7 +1416,7 @@ impl LayeredTimeline {
|
||||
self.timelineid,
|
||||
self.tenantid,
|
||||
start_lsn,
|
||||
lsn,
|
||||
last_record_lsn,
|
||||
)?;
|
||||
} else {
|
||||
// New relation.
|
||||
@@ -1444,8 +1427,14 @@ impl LayeredTimeline {
|
||||
lsn
|
||||
);
|
||||
|
||||
layer =
|
||||
InMemoryLayer::create(self.conf, self.timelineid, self.tenantid, seg, lsn, lsn)?;
|
||||
layer = InMemoryLayer::create(
|
||||
self.conf,
|
||||
self.timelineid,
|
||||
self.tenantid,
|
||||
seg,
|
||||
lsn,
|
||||
last_record_lsn,
|
||||
)?;
|
||||
}
|
||||
|
||||
let layer_rc: Arc<InMemoryLayer> = Arc::new(layer);
|
||||
@@ -1462,7 +1451,7 @@ impl LayeredTimeline {
|
||||
// Prevent concurrent checkpoints
|
||||
let _checkpoint_cs = self.checkpoint_cs.lock().unwrap();
|
||||
|
||||
let mut write_guard = self.write_lock.lock().unwrap();
|
||||
let write_guard = self.write_lock.lock().unwrap();
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
|
||||
// Bump the generation number in the layer map, so that we can distinguish
|
||||
@@ -1488,11 +1477,17 @@ impl LayeredTimeline {
|
||||
let mut disk_consistent_lsn = last_record_lsn;
|
||||
|
||||
let mut layer_paths = Vec::new();
|
||||
let mut freeze_end_lsn = Lsn(0);
|
||||
let mut evicted_layers = Vec::new();
|
||||
|
||||
//
|
||||
// Determine which layers we need to evict and calculate max(latest_lsn)
|
||||
// among those layers.
|
||||
//
|
||||
while let Some((oldest_layer_id, oldest_layer, oldest_generation)) =
|
||||
layers.peek_oldest_open()
|
||||
{
|
||||
let oldest_pending_lsn = oldest_layer.get_oldest_pending_lsn();
|
||||
|
||||
let oldest_lsn = oldest_layer.get_oldest_lsn();
|
||||
// Does this layer need freezing?
|
||||
//
|
||||
// Write out all in-memory layers that contain WAL older than CHECKPOINT_DISTANCE.
|
||||
@@ -1501,28 +1496,60 @@ impl LayeredTimeline {
|
||||
// when we started. We don't want to process layers inserted after we started, to
|
||||
// avoid getting into an infinite loop trying to process again entries that we
|
||||
// inserted ourselves.
|
||||
let distance = last_record_lsn.widening_sub(oldest_pending_lsn);
|
||||
if distance < 0
|
||||
//
|
||||
// Once we have decided to write out at least one layer, we must also write out
|
||||
// any other layers that contain WAL older than the end LSN of the layers we have
|
||||
// already decided to write out. In other words, we must write out all layers
|
||||
// whose [oldest_lsn, latest_lsn) range overlaps with any of the other layers
|
||||
// that we are writing out. Otherwise, when we advance 'disk_consistent_lsn', it's
|
||||
// ambiguous whether those layers are already durable on disk or not. For example,
|
||||
// imagine that there are two layers in memory that contain page versions in the
|
||||
// following LSN ranges:
|
||||
//
|
||||
// A: 100-150
|
||||
// B: 110-200
|
||||
//
|
||||
// If we flush layer A, we must also flush layer B, because they overlap. If we
|
||||
// flushed only A, and advanced 'disk_consistent_lsn' to 150, we would break the
|
||||
// rule that all WAL older than 'disk_consistent_lsn' are durable on disk, because
|
||||
// B contains some WAL older than 150. On the other hand, if we flushed out A and
|
||||
// advanced 'disk_consistent_lsn' only up to 110, after crash and restart we would
|
||||
// delete the first layer because its end LSN is larger than 110. If we changed
|
||||
// the deletion logic to not delete it, then we would start streaming at 110, and
|
||||
// process again the WAL records in the range 110-150 that are already in layer A,
|
||||
// and the WAL processing code does not cope with that. We solve that dilemma by
|
||||
// insisting that if we write out the first layer, we also write out the second
|
||||
// layer, and advance disk_consistent_lsn all the way up to 200.
|
||||
//
|
||||
let distance = last_record_lsn.widening_sub(oldest_lsn);
|
||||
if (distance < 0
|
||||
|| distance < checkpoint_distance.into()
|
||||
|| oldest_generation == current_generation
|
||||
|| oldest_generation == current_generation)
|
||||
&& oldest_lsn >= freeze_end_lsn
|
||||
// this layer intersects with evicted layer and so also need to be evicted
|
||||
{
|
||||
info!(
|
||||
"the oldest layer is now {} which is {} bytes behind last_record_lsn",
|
||||
oldest_layer.filename().display(),
|
||||
distance
|
||||
);
|
||||
disk_consistent_lsn = oldest_pending_lsn;
|
||||
disk_consistent_lsn = oldest_lsn;
|
||||
break;
|
||||
}
|
||||
let latest_lsn = oldest_layer.get_latest_lsn();
|
||||
if latest_lsn > freeze_end_lsn {
|
||||
freeze_end_lsn = latest_lsn; // calculate max of latest_lsn of the layers we're about to evict
|
||||
}
|
||||
layers.remove_open(oldest_layer_id);
|
||||
evicted_layers.push((oldest_layer_id, oldest_layer));
|
||||
}
|
||||
|
||||
drop(layers);
|
||||
drop(write_guard);
|
||||
|
||||
let mut this_layer_paths = self.evict_layer(oldest_layer_id, reconstruct_pages)?;
|
||||
layer_paths.append(&mut this_layer_paths);
|
||||
|
||||
write_guard = self.write_lock.lock().unwrap();
|
||||
layers = self.layers.lock().unwrap();
|
||||
// Freeze evicted layers
|
||||
for (_evicted_layer_id, evicted_layer) in evicted_layers.iter() {
|
||||
// Mark the layer as no longer accepting writes and record the end_lsn.
|
||||
// This happens in-place, no new layers are created now.
|
||||
evicted_layer.freeze(freeze_end_lsn);
|
||||
layers.insert_historic(evicted_layer.clone());
|
||||
}
|
||||
|
||||
// Call unload() on all frozen layers, to release memory.
|
||||
@@ -1535,6 +1562,14 @@ impl LayeredTimeline {
|
||||
drop(layers);
|
||||
drop(write_guard);
|
||||
|
||||
// Create delta/image layers for evicted layers
|
||||
for (_evicted_layer_id, evicted_layer) in evicted_layers.iter() {
|
||||
let mut this_layer_paths =
|
||||
self.evict_layer(evicted_layer.clone(), reconstruct_pages)?;
|
||||
layer_paths.append(&mut this_layer_paths);
|
||||
}
|
||||
|
||||
// Sync layers
|
||||
if !layer_paths.is_empty() {
|
||||
// We must fsync the timeline dir to ensure the directory entries for
|
||||
// new layer files are durable
|
||||
@@ -1575,7 +1610,7 @@ impl LayeredTimeline {
|
||||
ondisk_prev_record_lsn,
|
||||
ancestor_timelineid,
|
||||
self.ancestor_lsn,
|
||||
self.latest_gc_cutoff_lsn.load(),
|
||||
*self.latest_gc_cutoff_lsn.read().unwrap(),
|
||||
self.initdb_lsn,
|
||||
);
|
||||
|
||||
@@ -1602,52 +1637,29 @@ impl LayeredTimeline {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn evict_layer(&self, layer_id: LayerId, reconstruct_pages: bool) -> Result<Vec<PathBuf>> {
|
||||
// Mark the layer as no longer accepting writes and record the end_lsn.
|
||||
// This happens in-place, no new layers are created now.
|
||||
// We call `get_last_record_lsn` again, which may be different from the
|
||||
// original load, as we may have released the write lock since then.
|
||||
|
||||
let mut write_guard = self.write_lock.lock().unwrap();
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
fn evict_layer(
|
||||
&self,
|
||||
layer: Arc<InMemoryLayer>,
|
||||
reconstruct_pages: bool,
|
||||
) -> Result<Vec<PathBuf>> {
|
||||
let new_historics = layer.write_to_disk(self, reconstruct_pages)?;
|
||||
|
||||
let mut layer_paths = Vec::new();
|
||||
let _write_guard = self.write_lock.lock().unwrap();
|
||||
let mut layers = self.layers.lock().unwrap();
|
||||
|
||||
let global_layer_map = GLOBAL_LAYER_MAP.read().unwrap();
|
||||
if let Some(oldest_layer) = global_layer_map.get(&layer_id) {
|
||||
drop(global_layer_map);
|
||||
oldest_layer.freeze(self.get_last_record_lsn());
|
||||
// Finally, replace the frozen in-memory layer with the new on-disk layers
|
||||
layers.remove_historic(layer);
|
||||
|
||||
// The layer is no longer open, update the layer map to reflect this.
|
||||
// We will replace it with on-disk historics below.
|
||||
layers.remove_open(layer_id);
|
||||
layers.insert_historic(oldest_layer.clone());
|
||||
|
||||
// Write the now-frozen layer to disk. That could take a while, so release the lock while do it
|
||||
drop(layers);
|
||||
drop(write_guard);
|
||||
|
||||
let new_historics = oldest_layer.write_to_disk(self, reconstruct_pages)?;
|
||||
|
||||
write_guard = self.write_lock.lock().unwrap();
|
||||
layers = self.layers.lock().unwrap();
|
||||
|
||||
// Finally, replace the frozen in-memory layer with the new on-disk layers
|
||||
layers.remove_historic(oldest_layer);
|
||||
|
||||
// Add the historics to the LayerMap
|
||||
for delta_layer in new_historics.delta_layers {
|
||||
layer_paths.push(delta_layer.path());
|
||||
layers.insert_historic(Arc::new(delta_layer));
|
||||
}
|
||||
for image_layer in new_historics.image_layers {
|
||||
layer_paths.push(image_layer.path());
|
||||
layers.insert_historic(Arc::new(image_layer));
|
||||
}
|
||||
// Add the historics to the LayerMap
|
||||
for delta_layer in new_historics.delta_layers {
|
||||
layer_paths.push(delta_layer.path());
|
||||
layers.insert_historic(Arc::new(delta_layer));
|
||||
}
|
||||
for image_layer in new_historics.image_layers {
|
||||
layer_paths.push(image_layer.path());
|
||||
layers.insert_historic(Arc::new(image_layer));
|
||||
}
|
||||
drop(layers);
|
||||
drop(write_guard);
|
||||
|
||||
Ok(layer_paths)
|
||||
}
|
||||
|
||||
@@ -1677,12 +1689,13 @@ impl LayeredTimeline {
|
||||
let now = Instant::now();
|
||||
let mut result: GcResult = Default::default();
|
||||
let disk_consistent_lsn = self.get_disk_consistent_lsn();
|
||||
let _checkpoint_cs = self.checkpoint_cs.lock().unwrap();
|
||||
|
||||
let _enter = info_span!("garbage collection", timeline = %self.timelineid, tenant = %self.tenantid, cutoff = %cutoff).entered();
|
||||
|
||||
// We need to ensure that no one branches at a point before latest_gc_cutoff_lsn.
|
||||
// See branch_timeline() for details.
|
||||
self.latest_gc_cutoff_lsn.store(cutoff);
|
||||
*self.latest_gc_cutoff_lsn.write().unwrap() = cutoff;
|
||||
|
||||
info!("GC starting");
|
||||
|
||||
|
||||
@@ -175,7 +175,10 @@ impl Write for EphemeralFile {
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), std::io::Error> {
|
||||
todo!()
|
||||
// we don't need to flush data:
|
||||
// * we either write input bytes or not, not keeping any intermediate data buffered
|
||||
// * rust unix file `flush` impl does not flush things either, returning `Ok(())`
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,14 @@ impl Layer for ImageLayer {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.chapter_reader(BLOCKY_IMAGES_CHAPTER)?;
|
||||
chapter.read_exact_at(&mut buf, offset)?;
|
||||
|
||||
chapter.read_exact_at(&mut buf, offset).with_context(|| {
|
||||
format!(
|
||||
"failed to read page from data file {} at offset {}",
|
||||
self.filename().display(),
|
||||
offset
|
||||
)
|
||||
})?;
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
@@ -39,8 +39,20 @@ pub struct InMemoryLayer {
|
||||
///
|
||||
start_lsn: Lsn,
|
||||
|
||||
/// LSN of the oldest page version stored in this layer
|
||||
oldest_pending_lsn: Lsn,
|
||||
///
|
||||
/// LSN of the oldest page version stored in this layer.
|
||||
///
|
||||
/// This is different from 'start_lsn' in that we enforce that the 'start_lsn'
|
||||
/// of a layer always matches the 'end_lsn' of its predecessor, even if there
|
||||
/// are no page versions until at a later LSN. That way you can detect any
|
||||
/// missing layer files more easily. 'oldest_lsn' is the first page version
|
||||
/// actually stored in this layer. In the range between 'start_lsn' and
|
||||
/// 'oldest_lsn', there are no changes to the segment.
|
||||
/// 'oldest_lsn' is used to adjust 'disk_consistent_lsn' and that is why it should
|
||||
/// point to the beginning of WAL record. This is the other difference with 'start_lsn'
|
||||
/// which points to end of WAL record. This is why 'oldest_lsn' can be smaller than 'start_lsn'.
|
||||
///
|
||||
oldest_lsn: Lsn,
|
||||
|
||||
/// The above fields never change. The parts that do change are in 'inner',
|
||||
/// and protected by mutex.
|
||||
@@ -73,6 +85,14 @@ pub struct InMemoryLayerInner {
|
||||
/// a non-blocky rel, 'seg_sizes' is not used and is always empty.
|
||||
///
|
||||
seg_sizes: VecMap<Lsn, SegmentBlk>,
|
||||
|
||||
///
|
||||
/// LSN of the newest page version stored in this layer.
|
||||
///
|
||||
/// The difference between 'end_lsn' and 'latest_lsn' is the same as between
|
||||
/// 'start_lsn' and 'oldest_lsn'. See comments in 'oldest_lsn'.
|
||||
///
|
||||
latest_lsn: Lsn,
|
||||
}
|
||||
|
||||
impl InMemoryLayerInner {
|
||||
@@ -319,8 +339,13 @@ pub struct LayersOnDisk {
|
||||
|
||||
impl InMemoryLayer {
|
||||
/// Return the oldest page version that's stored in this layer
|
||||
pub fn get_oldest_pending_lsn(&self) -> Lsn {
|
||||
self.oldest_pending_lsn
|
||||
pub fn get_oldest_lsn(&self) -> Lsn {
|
||||
self.oldest_lsn
|
||||
}
|
||||
|
||||
pub fn get_latest_lsn(&self) -> Lsn {
|
||||
let inner = self.inner.read().unwrap();
|
||||
inner.latest_lsn
|
||||
}
|
||||
|
||||
///
|
||||
@@ -332,7 +357,7 @@ impl InMemoryLayer {
|
||||
tenantid: ZTenantId,
|
||||
seg: SegmentTag,
|
||||
start_lsn: Lsn,
|
||||
oldest_pending_lsn: Lsn,
|
||||
oldest_lsn: Lsn,
|
||||
) -> Result<InMemoryLayer> {
|
||||
trace!(
|
||||
"initializing new empty InMemoryLayer for writing {} on timeline {} at {}",
|
||||
@@ -355,13 +380,14 @@ impl InMemoryLayer {
|
||||
tenantid,
|
||||
seg,
|
||||
start_lsn,
|
||||
oldest_pending_lsn,
|
||||
oldest_lsn,
|
||||
incremental: false,
|
||||
inner: RwLock::new(InMemoryLayerInner {
|
||||
end_lsn: None,
|
||||
dropped: false,
|
||||
page_versions: PageVersions::new(file),
|
||||
seg_sizes,
|
||||
latest_lsn: oldest_lsn,
|
||||
}),
|
||||
})
|
||||
}
|
||||
@@ -398,6 +424,8 @@ impl InMemoryLayer {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
|
||||
inner.assert_writeable();
|
||||
assert!(lsn >= inner.latest_lsn);
|
||||
inner.latest_lsn = lsn;
|
||||
|
||||
let old = inner.page_versions.append_or_update_last(blknum, lsn, pv)?;
|
||||
|
||||
@@ -509,12 +537,11 @@ impl InMemoryLayer {
|
||||
timelineid: ZTimelineId,
|
||||
tenantid: ZTenantId,
|
||||
start_lsn: Lsn,
|
||||
oldest_pending_lsn: Lsn,
|
||||
oldest_lsn: Lsn,
|
||||
) -> Result<InMemoryLayer> {
|
||||
let seg = src.get_seg_tag();
|
||||
|
||||
assert!(oldest_pending_lsn.is_aligned());
|
||||
assert!(oldest_pending_lsn >= start_lsn);
|
||||
assert!(oldest_lsn.is_aligned());
|
||||
|
||||
trace!(
|
||||
"initializing new InMemoryLayer for writing {} on timeline {} at {}",
|
||||
@@ -538,13 +565,14 @@ impl InMemoryLayer {
|
||||
tenantid,
|
||||
seg,
|
||||
start_lsn,
|
||||
oldest_pending_lsn,
|
||||
oldest_lsn,
|
||||
incremental: true,
|
||||
inner: RwLock::new(InMemoryLayerInner {
|
||||
end_lsn: None,
|
||||
dropped: false,
|
||||
page_versions: PageVersions::new(file),
|
||||
seg_sizes,
|
||||
latest_lsn: oldest_lsn,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ pub struct LayerMap {
|
||||
/// All the layers keyed by segment tag
|
||||
segs: HashMap<SegmentTag, SegEntry>,
|
||||
|
||||
/// All in-memory layers, ordered by 'oldest_pending_lsn' and generation
|
||||
/// All in-memory layers, ordered by 'oldest_lsn' and generation
|
||||
/// of each layer. This allows easy access to the in-memory layer that
|
||||
/// contains the oldest WAL record.
|
||||
open_layers: BinaryHeap<OpenLayerEntry>,
|
||||
@@ -83,16 +83,16 @@ impl LayerMap {
|
||||
|
||||
let layer_id = segentry.update_open(Arc::clone(&layer));
|
||||
|
||||
let oldest_pending_lsn = layer.get_oldest_pending_lsn();
|
||||
let oldest_lsn = layer.get_oldest_lsn();
|
||||
|
||||
// After a crash and restart, 'oldest_pending_lsn' of the oldest in-memory
|
||||
// After a crash and restart, 'oldest_lsn' of the oldest in-memory
|
||||
// layer becomes the WAL streaming starting point, so it better not point
|
||||
// in the middle of a WAL record.
|
||||
assert!(oldest_pending_lsn.is_aligned());
|
||||
assert!(oldest_lsn.is_aligned());
|
||||
|
||||
// Also add it to the binary heap
|
||||
let open_layer_entry = OpenLayerEntry {
|
||||
oldest_pending_lsn: layer.get_oldest_pending_lsn(),
|
||||
oldest_lsn: layer.get_oldest_lsn(),
|
||||
layer_id,
|
||||
generation: self.current_generation,
|
||||
};
|
||||
@@ -352,23 +352,23 @@ impl SegEntry {
|
||||
}
|
||||
|
||||
/// Entry held in LayerMap::open_layers, with boilerplate comparison routines
|
||||
/// to implement a min-heap ordered by 'oldest_pending_lsn' and 'generation'
|
||||
/// to implement a min-heap ordered by 'oldest_lsn' and 'generation'
|
||||
///
|
||||
/// The generation number associated with each entry can be used to distinguish
|
||||
/// recently-added entries (i.e after last call to increment_generation()) from older
|
||||
/// entries with the same 'oldest_pending_lsn'.
|
||||
/// entries with the same 'oldest_lsn'.
|
||||
struct OpenLayerEntry {
|
||||
oldest_pending_lsn: Lsn, // copy of layer.get_oldest_pending_lsn()
|
||||
oldest_lsn: Lsn, // copy of layer.get_oldest_lsn()
|
||||
generation: u64,
|
||||
layer_id: LayerId,
|
||||
}
|
||||
impl Ord for OpenLayerEntry {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
|
||||
// to get that. Entries with identical oldest_pending_lsn are ordered by generation
|
||||
// to get that. Entries with identical oldest_lsn are ordered by generation
|
||||
other
|
||||
.oldest_pending_lsn
|
||||
.cmp(&self.oldest_pending_lsn)
|
||||
.oldest_lsn
|
||||
.cmp(&self.oldest_lsn)
|
||||
.then_with(|| other.generation.cmp(&self.generation))
|
||||
}
|
||||
}
|
||||
@@ -437,7 +437,7 @@ mod tests {
|
||||
conf: &'static PageServerConf,
|
||||
segno: u32,
|
||||
start_lsn: Lsn,
|
||||
oldest_pending_lsn: Lsn,
|
||||
oldest_lsn: Lsn,
|
||||
) -> Arc<InMemoryLayer> {
|
||||
Arc::new(
|
||||
InMemoryLayer::create(
|
||||
@@ -449,7 +449,7 @@ mod tests {
|
||||
segno,
|
||||
},
|
||||
start_lsn,
|
||||
oldest_pending_lsn,
|
||||
oldest_lsn,
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::io;
|
||||
use std::net::TcpListener;
|
||||
use std::str;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLockReadGuard};
|
||||
use tracing::*;
|
||||
use zenith_metrics::{register_histogram_vec, HistogramVec};
|
||||
use zenith_utils::auth::{self, JwtAuth};
|
||||
@@ -27,13 +27,10 @@ use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::postgres_backend::is_socket_read_timed_out;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::postgres_backend::{self, AuthType};
|
||||
use zenith_utils::pq_proto::{
|
||||
BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
|
||||
};
|
||||
use zenith_utils::pq_proto::{BeMessage, FeMessage, RowDescriptor, SINGLE_COL_ROWDESC};
|
||||
use zenith_utils::zid::{ZTenantId, ZTimelineId};
|
||||
|
||||
use crate::basebackup;
|
||||
use crate::branches;
|
||||
use crate::config::PageServerConf;
|
||||
use crate::relish::*;
|
||||
use crate::repository::Timeline;
|
||||
@@ -398,7 +395,12 @@ impl PageServerHandler {
|
||||
/// In either case, if the page server hasn't received the WAL up to the
|
||||
/// requested LSN yet, we will wait for it to arrive. The return value is
|
||||
/// the LSN that should be used to look up the page versions.
|
||||
fn wait_or_get_last_lsn(timeline: &dyn Timeline, lsn: Lsn, latest: bool) -> Result<Lsn> {
|
||||
fn wait_or_get_last_lsn(
|
||||
timeline: &dyn Timeline,
|
||||
mut lsn: Lsn,
|
||||
latest: bool,
|
||||
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
|
||||
) -> Result<Lsn> {
|
||||
if latest {
|
||||
// Latest page version was requested. If LSN is given, it is a hint
|
||||
// to the page server that there have been no modifications to the
|
||||
@@ -419,22 +421,26 @@ impl PageServerHandler {
|
||||
// walsender completes the authentication and starts streaming the
|
||||
// WAL.
|
||||
if lsn <= last_record_lsn {
|
||||
Ok(last_record_lsn)
|
||||
lsn = last_record_lsn;
|
||||
} else {
|
||||
timeline.wait_lsn(lsn)?;
|
||||
// Since we waited for 'lsn' to arrive, that is now the last
|
||||
// record LSN. (Or close enough for our purposes; the
|
||||
// last-record LSN can advance immediately after we return
|
||||
// anyway)
|
||||
Ok(lsn)
|
||||
}
|
||||
} else {
|
||||
if lsn == Lsn(0) {
|
||||
bail!("invalid LSN(0) in request");
|
||||
}
|
||||
timeline.wait_lsn(lsn)?;
|
||||
Ok(lsn)
|
||||
}
|
||||
ensure!(
|
||||
lsn >= **latest_gc_cutoff_lsn,
|
||||
"tried to request a page version that was garbage collected. requested at {} gc cutoff {}",
|
||||
lsn, **latest_gc_cutoff_lsn
|
||||
);
|
||||
Ok(lsn)
|
||||
}
|
||||
|
||||
fn handle_get_rel_exists_request(
|
||||
@@ -445,7 +451,8 @@ impl PageServerHandler {
|
||||
let _enter = info_span!("get_rel_exists", rel = %req.rel, req_lsn = %req.lsn).entered();
|
||||
|
||||
let tag = RelishTag::Relation(req.rel);
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest)?;
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
|
||||
|
||||
let exists = timeline.get_rel_exists(tag, lsn)?;
|
||||
|
||||
@@ -461,7 +468,8 @@ impl PageServerHandler {
|
||||
) -> Result<PagestreamBeMessage> {
|
||||
let _enter = info_span!("get_nblocks", rel = %req.rel, req_lsn = %req.lsn).entered();
|
||||
let tag = RelishTag::Relation(req.rel);
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest)?;
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
|
||||
|
||||
let n_blocks = timeline.get_relish_size(tag, lsn)?;
|
||||
|
||||
@@ -482,8 +490,16 @@ impl PageServerHandler {
|
||||
let _enter = info_span!("get_page", rel = %req.rel, blkno = &req.blkno, req_lsn = %req.lsn)
|
||||
.entered();
|
||||
let tag = RelishTag::Relation(req.rel);
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest)?;
|
||||
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
|
||||
/*
|
||||
// Add a 1s delay to some requests. The delayed causes the requests to
|
||||
// hit the race condition from github issue #1047 more easily.
|
||||
use rand::Rng;
|
||||
if rand::thread_rng().gen::<u8>() < 5 {
|
||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
||||
}
|
||||
*/
|
||||
let page = timeline.get_page_at_lsn(tag, req.blkno, lsn)?;
|
||||
|
||||
Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
|
||||
@@ -504,9 +520,10 @@ impl PageServerHandler {
|
||||
// check that the timeline exists
|
||||
let timeline = tenant_mgr::get_timeline_for_tenant(tenantid, timelineid)
|
||||
.context("Cannot handle basebackup request for a remote timeline")?;
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
if let Some(lsn) = lsn {
|
||||
timeline
|
||||
.check_lsn_is_in_scope(lsn)
|
||||
.check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn)
|
||||
.context("invalid basebackup lsn")?;
|
||||
}
|
||||
|
||||
@@ -642,79 +659,21 @@ impl postgres_backend::Handler for PageServerHandler {
|
||||
walreceiver::launch_wal_receiver(self.conf, tenantid, timelineid, &connstr)?;
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("branch_create ") {
|
||||
let err = || format!("invalid branch_create: '{}'", query_string);
|
||||
|
||||
// branch_create <tenantid> <branchname> <startpoint>
|
||||
// TODO lazy static
|
||||
// TODO: escaping, to allow branch names with spaces
|
||||
let re = Regex::new(r"^branch_create ([[:xdigit:]]+) (\S+) ([^\r\n\s;]+)[\r\n\s;]*;?$")
|
||||
.unwrap();
|
||||
let caps = re.captures(query_string).with_context(err)?;
|
||||
|
||||
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
|
||||
let branchname = caps.get(2).with_context(err)?.as_str().to_owned();
|
||||
let startpoint_str = caps.get(3).with_context(err)?.as_str().to_owned();
|
||||
|
||||
self.check_permission(Some(tenantid))?;
|
||||
|
||||
let _enter =
|
||||
info_span!("branch_create", name = %branchname, tenant = %tenantid).entered();
|
||||
|
||||
let branch =
|
||||
branches::create_branch(self.conf, &branchname, &startpoint_str, &tenantid)?;
|
||||
let branch = serde_json::to_vec(&branch)?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&branch)]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("branch_list ") {
|
||||
// branch_list <zenith tenantid as hex string>
|
||||
let re = Regex::new(r"^branch_list ([[:xdigit:]]+)$").unwrap();
|
||||
let caps = re
|
||||
.captures(query_string)
|
||||
.with_context(|| format!("invalid branch_list: '{}'", query_string))?;
|
||||
|
||||
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
|
||||
|
||||
// since these handlers for tenant/branch commands are deprecated (in favor of http based ones)
|
||||
// just use false in place of include non incremental logical size
|
||||
let branches = crate::branches::get_branches(self.conf, &tenantid, false)?;
|
||||
let branches_buf = serde_json::to_vec(&branches)?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&branches_buf)]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("tenant_list") {
|
||||
let tenants = crate::tenant_mgr::list_tenants()?;
|
||||
let tenants_buf = serde_json::to_vec(&tenants)?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&tenants_buf)]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("tenant_create") {
|
||||
let err = || format!("invalid tenant_create: '{}'", query_string);
|
||||
|
||||
// tenant_create <tenantid>
|
||||
let re = Regex::new(r"^tenant_create ([[:xdigit:]]+)$").unwrap();
|
||||
let caps = re.captures(query_string).with_context(err)?;
|
||||
|
||||
self.check_permission(None)?;
|
||||
|
||||
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
|
||||
|
||||
tenant_mgr::create_repository_for_tenant(self.conf, tenantid)?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("status") {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&HELLO_WORLD_ROW)?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.to_ascii_lowercase().starts_with("set ") {
|
||||
// important because psycopg2 executes "SET datestyle TO 'ISO'"
|
||||
// on connect
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("failpoints ") {
|
||||
let (_, failpoints) = query_string.split_at("failpoints ".len());
|
||||
for failpoint in failpoints.split(';') {
|
||||
if let Some((name, actions)) = failpoint.split_once('=') {
|
||||
info!("cfg failpoint: {} {}", name, actions);
|
||||
fail::cfg(name, actions).unwrap();
|
||||
} else {
|
||||
bail!("Invalid failpoints format");
|
||||
}
|
||||
}
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("do_gc ") {
|
||||
// Run GC immediately on given timeline.
|
||||
// FIXME: This is just for tests. See test_runner/batch_others/test_gc.py.
|
||||
|
||||
@@ -94,7 +94,7 @@ use std::{
|
||||
use anyhow::{bail, Context};
|
||||
use tokio::io;
|
||||
use tracing::{error, info};
|
||||
use zenith_utils::zid::{ZTenantId, ZTimelineId};
|
||||
use zenith_utils::zid::{ZTenantId, ZTenantTimelineId, ZTimelineId};
|
||||
|
||||
pub use self::storage_sync::{schedule_timeline_checkpoint_upload, schedule_timeline_download};
|
||||
use self::{local_fs::LocalFs, rust_s3::S3};
|
||||
@@ -104,16 +104,7 @@ use crate::{
|
||||
repository::TimelineSyncState,
|
||||
};
|
||||
|
||||
/// Any timeline has its own id and its own tenant it belongs to,
|
||||
/// the sync processes group timelines by both for simplicity.
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
|
||||
pub struct TimelineSyncId(ZTenantId, ZTimelineId);
|
||||
|
||||
impl std::fmt::Display for TimelineSyncId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "(tenant: {}, timeline: {})", self.0, self.1)
|
||||
}
|
||||
}
|
||||
pub use storage_sync::compression;
|
||||
|
||||
/// A structure to combine all synchronization data to share with pageserver after a successful sync loop initialization.
|
||||
/// Successful initialization includes a case when sync loop is not started, in which case the startup data is returned still,
|
||||
@@ -167,7 +158,7 @@ pub fn start_local_timeline_sync(
|
||||
ZTenantId,
|
||||
HashMap<ZTimelineId, TimelineSyncState>,
|
||||
> = HashMap::new();
|
||||
for (TimelineSyncId(tenant_id, timeline_id), (timeline_metadata, _)) in
|
||||
for (ZTenantTimelineId{tenant_id, timeline_id}, (timeline_metadata, _)) in
|
||||
local_timeline_files
|
||||
{
|
||||
initial_timeline_states
|
||||
@@ -187,7 +178,7 @@ pub fn start_local_timeline_sync(
|
||||
|
||||
fn local_tenant_timeline_files(
|
||||
config: &'static PageServerConf,
|
||||
) -> anyhow::Result<HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>> {
|
||||
) -> anyhow::Result<HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>> {
|
||||
let mut local_tenant_timeline_files = HashMap::new();
|
||||
let tenants_dir = config.tenants_path();
|
||||
for tenants_dir_entry in fs::read_dir(&tenants_dir)
|
||||
@@ -222,8 +213,9 @@ fn local_tenant_timeline_files(
|
||||
fn collect_timelines_for_tenant(
|
||||
config: &'static PageServerConf,
|
||||
tenant_path: &Path,
|
||||
) -> anyhow::Result<HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>> {
|
||||
let mut timelines: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)> = HashMap::new();
|
||||
) -> anyhow::Result<HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>> {
|
||||
let mut timelines: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)> =
|
||||
HashMap::new();
|
||||
let tenant_id = tenant_path
|
||||
.file_name()
|
||||
.and_then(ffi::OsStr::to_str)
|
||||
@@ -244,7 +236,10 @@ fn collect_timelines_for_tenant(
|
||||
match collect_timeline_files(&timeline_path) {
|
||||
Ok((timeline_id, metadata, timeline_files)) => {
|
||||
timelines.insert(
|
||||
TimelineSyncId(tenant_id, timeline_id),
|
||||
ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
},
|
||||
(metadata, timeline_files),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -70,7 +70,8 @@
|
||||
//!
|
||||
//! When pageserver signals shutdown, current sync task gets finished and the loop exists.
|
||||
|
||||
mod compression;
|
||||
/// Expose the module for a binary CLI tool that deals with the corresponding blobs.
|
||||
pub mod compression;
|
||||
mod download;
|
||||
pub mod index;
|
||||
mod upload;
|
||||
@@ -105,7 +106,7 @@ use self::{
|
||||
},
|
||||
upload::upload_timeline_checkpoint,
|
||||
};
|
||||
use super::{RemoteStorage, SyncStartupData, TimelineSyncId};
|
||||
use super::{RemoteStorage, SyncStartupData, ZTenantTimelineId};
|
||||
use crate::{
|
||||
config::PageServerConf, layered_repository::metadata::TimelineMetadata,
|
||||
remote_storage::storage_sync::compression::read_archive_header, repository::TimelineSyncState,
|
||||
@@ -242,13 +243,13 @@ mod sync_queue {
|
||||
/// Limited by the number of retries, after certain threshold the failing task gets evicted and the timeline disabled.
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct SyncTask {
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
retries: u32,
|
||||
kind: SyncKind,
|
||||
}
|
||||
|
||||
impl SyncTask {
|
||||
fn new(sync_id: TimelineSyncId, retries: u32, kind: SyncKind) -> Self {
|
||||
fn new(sync_id: ZTenantTimelineId, retries: u32, kind: SyncKind) -> Self {
|
||||
Self {
|
||||
sync_id,
|
||||
retries,
|
||||
@@ -307,7 +308,10 @@ pub fn schedule_timeline_checkpoint_upload(
|
||||
}
|
||||
|
||||
if !sync_queue::push(SyncTask::new(
|
||||
TimelineSyncId(tenant_id, timeline_id),
|
||||
ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
},
|
||||
0,
|
||||
SyncKind::Upload(NewCheckpoint { layers, metadata }),
|
||||
)) {
|
||||
@@ -338,7 +342,10 @@ pub fn schedule_timeline_download(tenant_id: ZTenantId, timeline_id: ZTimelineId
|
||||
tenant_id, timeline_id
|
||||
);
|
||||
sync_queue::push(SyncTask::new(
|
||||
TimelineSyncId(tenant_id, timeline_id),
|
||||
ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
},
|
||||
0,
|
||||
SyncKind::Download(TimelineDownload {
|
||||
files_to_skip: Arc::new(BTreeSet::new()),
|
||||
@@ -354,7 +361,7 @@ pub(super) fn spawn_storage_sync_thread<
|
||||
S: RemoteStorage<StoragePath = P> + Send + Sync + 'static,
|
||||
>(
|
||||
conf: &'static PageServerConf,
|
||||
local_timeline_files: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>,
|
||||
local_timeline_files: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>,
|
||||
storage: S,
|
||||
max_concurrent_sync: NonZeroUsize,
|
||||
max_sync_errors: NonZeroU32,
|
||||
@@ -510,7 +517,7 @@ async fn loop_step<
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to process storage sync task for tenant {}, timeline {}: {:?}",
|
||||
sync_id.0, sync_id.1, e
|
||||
sync_id.tenant_id, sync_id.timeline_id, e
|
||||
);
|
||||
None
|
||||
}
|
||||
@@ -524,7 +531,10 @@ async fn loop_step<
|
||||
while let Some((sync_id, state_update)) = task_batch.next().await {
|
||||
debug!("Finished storage sync task for sync id {}", sync_id);
|
||||
if let Some(state_update) = state_update {
|
||||
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
} = sync_id;
|
||||
new_timeline_states
|
||||
.entry(tenant_id)
|
||||
.or_default()
|
||||
@@ -618,7 +628,7 @@ async fn process_task<
|
||||
|
||||
fn schedule_first_sync_tasks(
|
||||
index: &RemoteTimelineIndex,
|
||||
local_timeline_files: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>,
|
||||
local_timeline_files: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>,
|
||||
) -> HashMap<ZTenantId, HashMap<ZTimelineId, TimelineSyncState>> {
|
||||
let mut initial_timeline_statuses: HashMap<ZTenantId, HashMap<ZTimelineId, TimelineSyncState>> =
|
||||
HashMap::new();
|
||||
@@ -629,7 +639,10 @@ fn schedule_first_sync_tasks(
|
||||
for (sync_id, (local_metadata, local_files)) in local_timeline_files {
|
||||
let local_disk_consistent_lsn = local_metadata.disk_consistent_lsn();
|
||||
|
||||
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
} = sync_id;
|
||||
match index.timeline_entry(&sync_id) {
|
||||
Some(index_entry) => {
|
||||
let timeline_status = compare_local_and_remote_timeline(
|
||||
@@ -672,10 +685,10 @@ fn schedule_first_sync_tasks(
|
||||
}
|
||||
}
|
||||
|
||||
let unprocessed_remote_ids = |remote_id: &TimelineSyncId| {
|
||||
let unprocessed_remote_ids = |remote_id: &ZTenantTimelineId| {
|
||||
initial_timeline_statuses
|
||||
.get(&remote_id.0)
|
||||
.and_then(|timelines| timelines.get(&remote_id.1))
|
||||
.get(&remote_id.tenant_id)
|
||||
.and_then(|timelines| timelines.get(&remote_id.timeline_id))
|
||||
.is_none()
|
||||
};
|
||||
for unprocessed_remote_id in index
|
||||
@@ -683,7 +696,10 @@ fn schedule_first_sync_tasks(
|
||||
.filter(unprocessed_remote_ids)
|
||||
.collect::<Vec<_>>()
|
||||
{
|
||||
let TimelineSyncId(cloud_only_tenant_id, cloud_only_timeline_id) = unprocessed_remote_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id: cloud_only_tenant_id,
|
||||
timeline_id: cloud_only_timeline_id,
|
||||
} = unprocessed_remote_id;
|
||||
match index
|
||||
.timeline_entry(&unprocessed_remote_id)
|
||||
.and_then(TimelineIndexEntry::disk_consistent_lsn)
|
||||
@@ -712,7 +728,7 @@ fn schedule_first_sync_tasks(
|
||||
|
||||
fn compare_local_and_remote_timeline(
|
||||
new_sync_tasks: &mut VecDeque<SyncTask>,
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
local_metadata: TimelineMetadata,
|
||||
local_files: Vec<PathBuf>,
|
||||
remote_entry: &TimelineIndexEntry,
|
||||
@@ -769,7 +785,7 @@ async fn update_index_description<
|
||||
>(
|
||||
(storage, index): &(S, RwLock<RemoteTimelineIndex>),
|
||||
timeline_dir: &Path,
|
||||
id: TimelineSyncId,
|
||||
id: ZTenantTimelineId,
|
||||
) -> anyhow::Result<RemoteTimeline> {
|
||||
let mut index_write = index.write().await;
|
||||
let full_index = match index_write.timeline_entry(&id) {
|
||||
@@ -792,7 +808,7 @@ async fn update_index_description<
|
||||
Ok((archive_id, header_size, header)) => full_index.update_archive_contents(archive_id.0, header, header_size),
|
||||
Err((e, archive_id)) => bail!(
|
||||
"Failed to download archive header for tenant {}, timeline {}, archive for Lsn {}: {}",
|
||||
id.0, id.1, archive_id.0,
|
||||
id.tenant_id, id.timeline_id, archive_id.0,
|
||||
e
|
||||
),
|
||||
}
|
||||
@@ -870,7 +886,7 @@ mod test_utils {
|
||||
timeline_id: ZTimelineId,
|
||||
new_upload: NewCheckpoint,
|
||||
) {
|
||||
let sync_id = TimelineSyncId(harness.tenant_id, timeline_id);
|
||||
let sync_id = ZTenantTimelineId::new(harness.tenant_id, timeline_id);
|
||||
upload_timeline_checkpoint(
|
||||
harness.conf,
|
||||
Arc::clone(&remote_assets),
|
||||
@@ -926,7 +942,7 @@ mod test_utils {
|
||||
|
||||
pub async fn expect_timeline(
|
||||
index: &RwLock<RemoteTimelineIndex>,
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
) -> RemoteTimeline {
|
||||
if let Some(TimelineIndexEntry::Full(remote_timeline)) =
|
||||
index.read().await.timeline_entry(&sync_id)
|
||||
@@ -961,18 +977,18 @@ mod test_utils {
|
||||
let mut expected_timeline_entries = BTreeMap::new();
|
||||
for sync_id in actual_sync_ids {
|
||||
actual_branches.insert(
|
||||
sync_id.1,
|
||||
sync_id.tenant_id,
|
||||
index_read
|
||||
.branch_files(sync_id.0)
|
||||
.branch_files(sync_id.tenant_id)
|
||||
.into_iter()
|
||||
.flat_map(|branch_paths| branch_paths.iter())
|
||||
.cloned()
|
||||
.collect::<BTreeSet<_>>(),
|
||||
);
|
||||
expected_branches.insert(
|
||||
sync_id.1,
|
||||
sync_id.tenant_id,
|
||||
expected_index_with_descriptions
|
||||
.branch_files(sync_id.0)
|
||||
.branch_files(sync_id.tenant_id)
|
||||
.into_iter()
|
||||
.flat_map(|branch_paths| branch_paths.iter())
|
||||
.cloned()
|
||||
|
||||
@@ -248,7 +248,7 @@ fn archive_name(disk_consistent_lsn: Lsn, header_size: u64) -> String {
|
||||
archive_name
|
||||
}
|
||||
|
||||
async fn uncompress_with_header(
|
||||
pub async fn uncompress_with_header(
|
||||
files_to_skip: &BTreeSet<PathBuf>,
|
||||
destination_dir: &Path,
|
||||
header: ArchiveHeader,
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::{
|
||||
compression, index::TimelineIndexEntry, sync_queue, tenant_branch_files,
|
||||
update_index_description, SyncKind, SyncTask,
|
||||
},
|
||||
RemoteStorage, TimelineSyncId,
|
||||
RemoteStorage, ZTenantTimelineId,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -52,13 +52,16 @@ pub(super) async fn download_timeline<
|
||||
>(
|
||||
conf: &'static PageServerConf,
|
||||
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
mut download: TimelineDownload,
|
||||
retries: u32,
|
||||
) -> DownloadedTimeline {
|
||||
debug!("Downloading layers for sync id {}", sync_id);
|
||||
|
||||
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
} = sync_id;
|
||||
let index_read = remote_assets.1.read().await;
|
||||
let remote_timeline = match index_read.timeline_entry(&sync_id) {
|
||||
None => {
|
||||
@@ -110,7 +113,8 @@ pub(super) async fn download_timeline<
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = download_missing_branches(conf, remote_assets.as_ref(), sync_id.0).await {
|
||||
if let Err(e) = download_missing_branches(conf, remote_assets.as_ref(), sync_id.tenant_id).await
|
||||
{
|
||||
error!(
|
||||
"Failed to download missing branches for sync id {}: {:?}",
|
||||
sync_id, e
|
||||
@@ -180,7 +184,10 @@ async fn try_download_archive<
|
||||
S: RemoteStorage<StoragePath = P> + Send + Sync + 'static,
|
||||
>(
|
||||
conf: &'static PageServerConf,
|
||||
TimelineSyncId(tenant_id, timeline_id): TimelineSyncId,
|
||||
ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
}: ZTenantTimelineId,
|
||||
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
|
||||
remote_timeline: &RemoteTimeline,
|
||||
archive_id: ArchiveId,
|
||||
@@ -343,7 +350,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_download_timeline() -> anyhow::Result<()> {
|
||||
let repo_harness = RepoHarness::create("test_download_timeline")?;
|
||||
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
|
||||
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
|
||||
repo_harness.conf,
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::{
|
||||
layered_repository::TIMELINES_SEGMENT_NAME,
|
||||
remote_storage::{
|
||||
storage_sync::compression::{parse_archive_name, FileEntry},
|
||||
TimelineSyncId,
|
||||
ZTenantTimelineId,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -53,7 +53,7 @@ impl RelativePath {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RemoteTimelineIndex {
|
||||
branch_files: HashMap<ZTenantId, HashSet<RelativePath>>,
|
||||
timeline_files: HashMap<TimelineSyncId, TimelineIndexEntry>,
|
||||
timeline_files: HashMap<ZTenantTimelineId, TimelineIndexEntry>,
|
||||
}
|
||||
|
||||
impl RemoteTimelineIndex {
|
||||
@@ -80,19 +80,22 @@ impl RemoteTimelineIndex {
|
||||
index
|
||||
}
|
||||
|
||||
pub fn timeline_entry(&self, id: &TimelineSyncId) -> Option<&TimelineIndexEntry> {
|
||||
pub fn timeline_entry(&self, id: &ZTenantTimelineId) -> Option<&TimelineIndexEntry> {
|
||||
self.timeline_files.get(id)
|
||||
}
|
||||
|
||||
pub fn timeline_entry_mut(&mut self, id: &TimelineSyncId) -> Option<&mut TimelineIndexEntry> {
|
||||
pub fn timeline_entry_mut(
|
||||
&mut self,
|
||||
id: &ZTenantTimelineId,
|
||||
) -> Option<&mut TimelineIndexEntry> {
|
||||
self.timeline_files.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn add_timeline_entry(&mut self, id: TimelineSyncId, entry: TimelineIndexEntry) {
|
||||
pub fn add_timeline_entry(&mut self, id: ZTenantTimelineId, entry: TimelineIndexEntry) {
|
||||
self.timeline_files.insert(id, entry);
|
||||
}
|
||||
|
||||
pub fn all_sync_ids(&self) -> impl Iterator<Item = TimelineSyncId> + '_ {
|
||||
pub fn all_sync_ids(&self) -> impl Iterator<Item = ZTenantTimelineId> + '_ {
|
||||
self.timeline_files.keys().copied()
|
||||
}
|
||||
|
||||
@@ -348,7 +351,10 @@ fn try_parse_index_entry(
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
let sync_id = TimelineSyncId(tenant_id, timeline_id);
|
||||
let sync_id = ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
};
|
||||
let timeline_index_entry = index
|
||||
.timeline_files
|
||||
.entry(sync_id)
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::{
|
||||
index::{RemoteTimeline, TimelineIndexEntry},
|
||||
sync_queue, tenant_branch_files, update_index_description, SyncKind, SyncTask,
|
||||
},
|
||||
RemoteStorage, TimelineSyncId,
|
||||
RemoteStorage, ZTenantTimelineId,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -36,12 +36,13 @@ pub(super) async fn upload_timeline_checkpoint<
|
||||
>(
|
||||
config: &'static PageServerConf,
|
||||
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
new_checkpoint: NewCheckpoint,
|
||||
retries: u32,
|
||||
) -> Option<bool> {
|
||||
debug!("Uploading checkpoint for sync id {}", sync_id);
|
||||
if let Err(e) = upload_missing_branches(config, remote_assets.as_ref(), sync_id.0).await {
|
||||
if let Err(e) = upload_missing_branches(config, remote_assets.as_ref(), sync_id.tenant_id).await
|
||||
{
|
||||
error!(
|
||||
"Failed to upload missing branches for sync id {}: {:?}",
|
||||
sync_id, e
|
||||
@@ -57,7 +58,10 @@ pub(super) async fn upload_timeline_checkpoint<
|
||||
|
||||
let index = &remote_assets.1;
|
||||
|
||||
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
} = sync_id;
|
||||
let timeline_dir = config.timeline_path(&timeline_id, &tenant_id);
|
||||
|
||||
let index_read = index.read().await;
|
||||
@@ -151,11 +155,14 @@ async fn try_upload_checkpoint<
|
||||
>(
|
||||
config: &'static PageServerConf,
|
||||
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
|
||||
sync_id: TimelineSyncId,
|
||||
sync_id: ZTenantTimelineId,
|
||||
new_checkpoint: &NewCheckpoint,
|
||||
files_to_skip: BTreeSet<PathBuf>,
|
||||
) -> anyhow::Result<(ArchiveHeader, u64)> {
|
||||
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
|
||||
let ZTenantTimelineId {
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
} = sync_id;
|
||||
let timeline_dir = config.timeline_path(&timeline_id, &tenant_id);
|
||||
|
||||
let files_to_upload = new_checkpoint
|
||||
@@ -288,7 +295,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn reupload_timeline() -> anyhow::Result<()> {
|
||||
let repo_harness = RepoHarness::create("reupload_timeline")?;
|
||||
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
|
||||
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
|
||||
repo_harness.conf,
|
||||
@@ -484,7 +491,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn reupload_timeline_rejected() -> anyhow::Result<()> {
|
||||
let repo_harness = RepoHarness::create("reupload_timeline_rejected")?;
|
||||
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
|
||||
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
|
||||
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
|
||||
repo_harness.conf,
|
||||
|
||||
@@ -7,7 +7,7 @@ use postgres_ffi::{MultiXactId, MultiXactOffset, TransactionId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::ops::{AddAssign, Deref};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLockReadGuard};
|
||||
use std::time::Duration;
|
||||
use zenith_utils::lsn::{Lsn, RecordLsn};
|
||||
use zenith_utils::zid::ZTimelineId;
|
||||
@@ -184,6 +184,9 @@ pub trait Timeline: Send + Sync {
|
||||
///
|
||||
fn wait_lsn(&self, lsn: Lsn) -> Result<()>;
|
||||
|
||||
/// Lock and get timeline's GC cuttof
|
||||
fn get_latest_gc_cutoff_lsn(&self) -> RwLockReadGuard<Lsn>;
|
||||
|
||||
/// Look up given page version.
|
||||
fn get_page_at_lsn(&self, tag: RelishTag, blknum: BlockNumber, lsn: Lsn) -> Result<Bytes>;
|
||||
|
||||
@@ -217,10 +220,12 @@ pub trait Timeline: Send + Sync {
|
||||
|
||||
/// Atomically get both last and prev.
|
||||
fn get_last_record_rlsn(&self) -> RecordLsn;
|
||||
|
||||
/// Get last or prev record separately. Same as get_last_record_rlsn().last/prev.
|
||||
fn get_last_record_lsn(&self) -> Lsn;
|
||||
|
||||
fn get_prev_record_lsn(&self) -> Lsn;
|
||||
fn get_start_lsn(&self) -> Lsn;
|
||||
|
||||
fn get_disk_consistent_lsn(&self) -> Lsn;
|
||||
|
||||
/// Mutate the timeline with a [`TimelineWriter`].
|
||||
@@ -235,7 +240,11 @@ pub trait Timeline: Send + Sync {
|
||||
|
||||
///
|
||||
/// Check that it is valid to request operations with that lsn.
|
||||
fn check_lsn_is_in_scope(&self, lsn: Lsn) -> Result<()>;
|
||||
fn check_lsn_is_in_scope(
|
||||
&self,
|
||||
lsn: Lsn,
|
||||
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Retrieve current logical size of the timeline
|
||||
///
|
||||
@@ -297,8 +306,12 @@ pub enum ZenithWalRecord {
|
||||
/// Native PostgreSQL WAL record
|
||||
Postgres { will_init: bool, rec: Bytes },
|
||||
|
||||
/// Set bits in heap visibility map. (heap blkno, flag bits to clear)
|
||||
ClearVisibilityMapFlags { heap_blkno: u32, flags: u8 },
|
||||
/// Clear bits in heap visibility map. ('flags' is bitmap of bits to clear)
|
||||
ClearVisibilityMapFlags {
|
||||
new_heap_blkno: Option<u32>,
|
||||
old_heap_blkno: Option<u32>,
|
||||
flags: u8,
|
||||
},
|
||||
/// Mark transaction IDs as committed on a CLOG page
|
||||
ClogSetCommitted { xids: Vec<TransactionId> },
|
||||
/// Mark transaction IDs as aborted on a CLOG page
|
||||
@@ -987,7 +1000,7 @@ mod tests {
|
||||
.source()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
.contains("is earlier than initdb lsn"));
|
||||
.contains("is earlier than latest GC horizon"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1004,12 +1017,11 @@ mod tests {
|
||||
make_some_layers(&tline, Lsn(0x20))?;
|
||||
|
||||
repo.gc_iteration(Some(TIMELINE_ID), 0x10, false)?;
|
||||
|
||||
let latest_gc_cutoff_lsn = tline.get_latest_gc_cutoff_lsn();
|
||||
assert!(*latest_gc_cutoff_lsn > Lsn(0x25));
|
||||
match tline.get_page_at_lsn(TESTREL_A, 0, Lsn(0x25)) {
|
||||
Ok(_) => panic!("request for page should have failed"),
|
||||
Err(err) => assert!(err
|
||||
.to_string()
|
||||
.contains("tried to request a page version that was garbage collected")),
|
||||
Err(err) => assert!(err.to_string().contains("not found at")),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -332,8 +332,11 @@ impl VirtualFile {
|
||||
// TODO: We could downgrade the locks to read mode before calling
|
||||
// 'func', to allow a little bit more concurrency, but the standard
|
||||
// library RwLock doesn't allow downgrading without releasing the lock,
|
||||
// and that doesn't seem worth the trouble. (parking_lot RwLock would
|
||||
// allow it)
|
||||
// and that doesn't seem worth the trouble.
|
||||
//
|
||||
// XXX: `parking_lot::RwLock` can enable such downgrades, yet its implemenation is fair and
|
||||
// may deadlock on subsequent read calls.
|
||||
// Simply replacing all `RwLock` in project causes deadlocks, so use it sparingly.
|
||||
let result = STORAGE_IO_TIME
|
||||
.with_label_values(&[op, &self.tenantid, &self.timelineid])
|
||||
.observe_closure_duration(|| func(&file));
|
||||
|
||||
@@ -349,49 +349,25 @@ impl WalIngest {
|
||||
decoded: &mut DecodedWALRecord,
|
||||
) -> Result<()> {
|
||||
// Handle VM bit updates that are implicitly part of heap records.
|
||||
|
||||
// First, look at the record to determine which VM bits need
|
||||
// to be cleared. If either of these variables is set, we
|
||||
// need to clear the corresponding bits in the visibility map.
|
||||
let mut new_heap_blkno: Option<u32> = None;
|
||||
let mut old_heap_blkno: Option<u32> = None;
|
||||
if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
|
||||
let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
|
||||
if info == pg_constants::XLOG_HEAP_INSERT {
|
||||
let xlrec = XlHeapInsert::decode(buf);
|
||||
assert_eq!(0, buf.remaining());
|
||||
if (xlrec.flags
|
||||
& (pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED
|
||||
| pg_constants::XLH_INSERT_ALL_FROZEN_SET))
|
||||
!= 0
|
||||
{
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[0].rnode_spcnode,
|
||||
dbnode: decoded.blocks[0].rnode_dbnode,
|
||||
relnode: decoded.blocks[0].rnode_relnode,
|
||||
}),
|
||||
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
heap_blkno: decoded.blocks[0].blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
if (xlrec.flags & pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED) != 0 {
|
||||
new_heap_blkno = Some(decoded.blocks[0].blkno);
|
||||
}
|
||||
} else if info == pg_constants::XLOG_HEAP_DELETE {
|
||||
let xlrec = XlHeapDelete::decode(buf);
|
||||
assert_eq!(0, buf.remaining());
|
||||
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[0].rnode_spcnode,
|
||||
dbnode: decoded.blocks[0].rnode_dbnode,
|
||||
relnode: decoded.blocks[0].rnode_relnode,
|
||||
}),
|
||||
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
heap_blkno: decoded.blocks[0].blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
new_heap_blkno = Some(decoded.blocks[0].blkno);
|
||||
}
|
||||
} else if info == pg_constants::XLOG_HEAP_UPDATE
|
||||
|| info == pg_constants::XLOG_HEAP_HOT_UPDATE
|
||||
@@ -400,39 +376,15 @@ impl WalIngest {
|
||||
// the size of tuple data is inferred from the size of the record.
|
||||
// we can't validate the remaining number of bytes without parsing
|
||||
// the tuple data.
|
||||
if (xlrec.flags & pg_constants::XLH_UPDATE_NEW_ALL_VISIBLE_CLEARED) != 0 {
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[0].rnode_spcnode,
|
||||
dbnode: decoded.blocks[0].rnode_dbnode,
|
||||
relnode: decoded.blocks[0].rnode_relnode,
|
||||
}),
|
||||
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
heap_blkno: decoded.blocks[0].blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
if (xlrec.flags & pg_constants::XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED) != 0 {
|
||||
old_heap_blkno = Some(decoded.blocks[0].blkno);
|
||||
}
|
||||
if (xlrec.flags & pg_constants::XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED) != 0
|
||||
&& decoded.blocks.len() > 1
|
||||
{
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[1].rnode_spcnode,
|
||||
dbnode: decoded.blocks[1].rnode_dbnode,
|
||||
relnode: decoded.blocks[1].rnode_relnode,
|
||||
}),
|
||||
decoded.blocks[1].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
heap_blkno: decoded.blocks[1].blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
if (xlrec.flags & pg_constants::XLH_UPDATE_NEW_ALL_VISIBLE_CLEARED) != 0 {
|
||||
// PostgreSQL only uses XLH_UPDATE_NEW_ALL_VISIBLE_CLEARED on a
|
||||
// non-HOT update where the new tuple goes to different page than
|
||||
// the old one. Otherwise, only XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED is
|
||||
// set.
|
||||
new_heap_blkno = Some(decoded.blocks[1].blkno);
|
||||
}
|
||||
}
|
||||
} else if decoded.xl_rmid == pg_constants::RM_HEAP2_ID {
|
||||
@@ -448,23 +400,60 @@ impl WalIngest {
|
||||
};
|
||||
assert_eq!(offset_array_len, buf.remaining());
|
||||
|
||||
// FIXME: why also ALL_FROZEN_SET?
|
||||
if (xlrec.flags
|
||||
& (pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED
|
||||
| pg_constants::XLH_INSERT_ALL_FROZEN_SET))
|
||||
!= 0
|
||||
{
|
||||
if (xlrec.flags & pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED) != 0 {
|
||||
new_heap_blkno = Some(decoded.blocks[0].blkno);
|
||||
}
|
||||
}
|
||||
}
|
||||
// FIXME: What about XLOG_HEAP_LOCK and XLOG_HEAP2_LOCK_UPDATED?
|
||||
|
||||
// Clear the VM bits if required.
|
||||
if new_heap_blkno.is_some() || old_heap_blkno.is_some() {
|
||||
let vm_relish = RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[0].rnode_spcnode,
|
||||
dbnode: decoded.blocks[0].rnode_dbnode,
|
||||
relnode: decoded.blocks[0].rnode_relnode,
|
||||
});
|
||||
|
||||
let new_vm_blk = new_heap_blkno.map(pg_constants::HEAPBLK_TO_MAPBLOCK);
|
||||
let old_vm_blk = old_heap_blkno.map(pg_constants::HEAPBLK_TO_MAPBLOCK);
|
||||
if new_vm_blk == old_vm_blk {
|
||||
// An UPDATE record that needs to clear the bits for both old and the
|
||||
// new page, both of which reside on the same VM page.
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
vm_relish,
|
||||
new_vm_blk.unwrap(),
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
new_heap_blkno,
|
||||
old_heap_blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
} else {
|
||||
// Clear VM bits for one heap page, or for two pages that reside on
|
||||
// different VM pages.
|
||||
if let Some(new_vm_blk) = new_vm_blk {
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
RelishTag::Relation(RelTag {
|
||||
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
|
||||
spcnode: decoded.blocks[0].rnode_spcnode,
|
||||
dbnode: decoded.blocks[0].rnode_dbnode,
|
||||
relnode: decoded.blocks[0].rnode_relnode,
|
||||
}),
|
||||
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
|
||||
vm_relish,
|
||||
new_vm_blk,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
heap_blkno: decoded.blocks[0].blkno,
|
||||
new_heap_blkno,
|
||||
old_heap_blkno: None,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
if let Some(old_vm_blk) = old_vm_blk {
|
||||
timeline.put_wal_record(
|
||||
lsn,
|
||||
vm_relish,
|
||||
old_vm_blk,
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
new_heap_blkno: None,
|
||||
old_heap_blkno,
|
||||
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
|
||||
},
|
||||
)?;
|
||||
@@ -472,8 +461,6 @@ impl WalIngest {
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: What about XLOG_HEAP_LOCK and XLOG_HEAP2_LOCK_UPDATED?
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -11,14 +11,16 @@ use crate::thread_mgr;
|
||||
use crate::thread_mgr::ThreadKind;
|
||||
use crate::walingest::WalIngest;
|
||||
use anyhow::{bail, Context, Error, Result};
|
||||
use bytes::BytesMut;
|
||||
use fail::fail_point;
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::Mutex;
|
||||
use postgres_ffi::waldecoder::*;
|
||||
use postgres_protocol::message::backend::ReplicationMessage;
|
||||
use postgres_types::PgLsn;
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Mutex;
|
||||
use std::thread_local;
|
||||
use std::time::SystemTime;
|
||||
use tokio::pin;
|
||||
@@ -27,6 +29,7 @@ use tokio_postgres::{Client, NoTls, SimpleQueryMessage, SimpleQueryRow};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::*;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::pq_proto::ZenithFeedback;
|
||||
use zenith_utils::zid::ZTenantId;
|
||||
use zenith_utils::zid::ZTimelineId;
|
||||
|
||||
@@ -50,7 +53,7 @@ thread_local! {
|
||||
}
|
||||
|
||||
fn drop_wal_receiver(tenantid: ZTenantId, timelineid: ZTimelineId) {
|
||||
let mut receivers = WAL_RECEIVERS.lock();
|
||||
let mut receivers = WAL_RECEIVERS.lock().unwrap();
|
||||
receivers.remove(&(tenantid, timelineid));
|
||||
}
|
||||
|
||||
@@ -61,7 +64,7 @@ pub fn launch_wal_receiver(
|
||||
timelineid: ZTimelineId,
|
||||
wal_producer_connstr: &str,
|
||||
) -> Result<()> {
|
||||
let mut receivers = WAL_RECEIVERS.lock();
|
||||
let mut receivers = WAL_RECEIVERS.lock().unwrap();
|
||||
|
||||
match receivers.get_mut(&(tenantid, timelineid)) {
|
||||
Some(receiver) => {
|
||||
@@ -94,7 +97,7 @@ pub fn launch_wal_receiver(
|
||||
|
||||
// Look up current WAL producer connection string in the hash table
|
||||
fn get_wal_producer_connstr(tenantid: ZTenantId, timelineid: ZTimelineId) -> String {
|
||||
let receivers = WAL_RECEIVERS.lock();
|
||||
let receivers = WAL_RECEIVERS.lock().unwrap();
|
||||
|
||||
receivers
|
||||
.get(&(tenantid, timelineid))
|
||||
@@ -159,7 +162,7 @@ fn walreceiver_main(
|
||||
// This is from tokio-postgres docs, but it is a bit weird in our case because we extensively use block_on
|
||||
runtime.spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -253,6 +256,8 @@ fn walreceiver_main(
|
||||
let writer = timeline.writer();
|
||||
walingest.ingest_record(writer.as_ref(), recdata, lsn)?;
|
||||
|
||||
fail_point!("walreceiver-after-ingest");
|
||||
|
||||
last_rec_lsn = lsn;
|
||||
}
|
||||
|
||||
@@ -287,7 +292,6 @@ fn walreceiver_main(
|
||||
};
|
||||
|
||||
if let Some(last_lsn) = status_update {
|
||||
let last_lsn = PgLsn::from(u64::from(last_lsn));
|
||||
let timeline_synced_disk_consistent_lsn =
|
||||
tenant_mgr::get_repository_for_tenant(tenantid)?
|
||||
.get_timeline_state(timelineid)
|
||||
@@ -295,18 +299,32 @@ fn walreceiver_main(
|
||||
.unwrap_or(Lsn(0));
|
||||
|
||||
// The last LSN we processed. It is not guaranteed to survive pageserver crash.
|
||||
let write_lsn = last_lsn;
|
||||
let write_lsn = u64::from(last_lsn);
|
||||
// `disk_consistent_lsn` is the LSN at which page server guarantees local persistence of all received data
|
||||
let flush_lsn = PgLsn::from(u64::from(timeline.get_disk_consistent_lsn()));
|
||||
let flush_lsn = u64::from(timeline.get_disk_consistent_lsn());
|
||||
// The last LSN that is synced to remote storage and is guaranteed to survive pageserver crash
|
||||
// Used by safekeepers to remove WAL preceding `remote_consistent_lsn`.
|
||||
let apply_lsn = PgLsn::from(u64::from(timeline_synced_disk_consistent_lsn));
|
||||
let apply_lsn = u64::from(timeline_synced_disk_consistent_lsn);
|
||||
let ts = SystemTime::now();
|
||||
const NO_REPLY: u8 = 0;
|
||||
|
||||
// Send zenith feedback message.
|
||||
// Regular standby_status_update fields are put into this message.
|
||||
let zenith_status_update = ZenithFeedback {
|
||||
current_timeline_size: timeline.get_current_logical_size() as u64,
|
||||
ps_writelsn: write_lsn,
|
||||
ps_flushlsn: flush_lsn,
|
||||
ps_applylsn: apply_lsn,
|
||||
ps_replytime: ts,
|
||||
};
|
||||
|
||||
debug!("zenith_status_update {:?}", zenith_status_update);
|
||||
|
||||
let mut data = BytesMut::new();
|
||||
zenith_status_update.serialize(&mut data)?;
|
||||
runtime.block_on(
|
||||
physical_stream
|
||||
.as_mut()
|
||||
.standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, NO_REPLY),
|
||||
.zenith_status_update(data.len() as u64, &data),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,25 +363,44 @@ impl PostgresRedoManager {
|
||||
will_init: _,
|
||||
rec: _,
|
||||
} => panic!("tried to pass postgres wal record to zenith WAL redo"),
|
||||
ZenithWalRecord::ClearVisibilityMapFlags { heap_blkno, flags } => {
|
||||
// Calculate the VM block and offset that corresponds to the heap block.
|
||||
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(*heap_blkno);
|
||||
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(*heap_blkno);
|
||||
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(*heap_blkno);
|
||||
|
||||
// Check that we're modifying the correct VM block.
|
||||
ZenithWalRecord::ClearVisibilityMapFlags {
|
||||
new_heap_blkno,
|
||||
old_heap_blkno,
|
||||
flags,
|
||||
} => {
|
||||
// sanity check that this is modifying the correct relish
|
||||
assert!(
|
||||
check_forknum(&rel, pg_constants::VISIBILITYMAP_FORKNUM),
|
||||
"ClearVisibilityMapFlags record on unexpected rel {:?}",
|
||||
rel
|
||||
);
|
||||
assert!(map_block == blknum);
|
||||
if let Some(heap_blkno) = *new_heap_blkno {
|
||||
// Calculate the VM block and offset that corresponds to the heap block.
|
||||
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blkno);
|
||||
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(heap_blkno);
|
||||
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(heap_blkno);
|
||||
|
||||
// equivalent to PageGetContents(page)
|
||||
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
|
||||
// Check that we're modifying the correct VM block.
|
||||
assert!(map_block == blknum);
|
||||
|
||||
let mask: u8 = flags << map_offset;
|
||||
map[map_byte as usize] &= !mask;
|
||||
// equivalent to PageGetContents(page)
|
||||
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
|
||||
|
||||
map[map_byte as usize] &= !(flags << map_offset);
|
||||
}
|
||||
|
||||
// Repeat for 'old_heap_blkno', if any
|
||||
if let Some(heap_blkno) = *old_heap_blkno {
|
||||
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blkno);
|
||||
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(heap_blkno);
|
||||
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(heap_blkno);
|
||||
|
||||
assert!(map_block == blknum);
|
||||
|
||||
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
|
||||
|
||||
map[map_byte as usize] &= !(flags << map_offset);
|
||||
}
|
||||
}
|
||||
// Non-relational WAL records are handled here, with custom code that has the
|
||||
// same effects as the corresponding Postgres WAL redo function.
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
[package]
|
||||
name = "postgres_ffi"
|
||||
version = "0.1.0"
|
||||
authors = ["Heikki Linnakangas <heikki@zenith.tech>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
chrono = "0.4.19"
|
||||
|
||||
@@ -51,6 +51,13 @@ pub type TimeLineID = u32;
|
||||
pub type TimestampTz = i64;
|
||||
pub type XLogSegNo = u64;
|
||||
|
||||
/// Interval of checkpointing metadata file. We should store metadata file to enforce
|
||||
/// predicate that checkpoint.nextXid is larger than any XID in WAL.
|
||||
/// But flushing checkpoint file for each transaction seems to be too expensive,
|
||||
/// so XID_CHECKPOINT_INTERVAL is used to forward align nextXid and so perform
|
||||
/// metadata checkpoint only once per XID_CHECKPOINT_INTERVAL transactions.
|
||||
/// XID_CHECKPOINT_INTERVAL should not be larger than BLCKSZ*CLOG_XACTS_PER_BYTE
|
||||
/// in order to let CLOG_TRUNCATE mechanism correctly extend CLOG.
|
||||
const XID_CHECKPOINT_INTERVAL: u32 = 1024;
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
@@ -400,9 +407,13 @@ impl CheckPoint {
|
||||
///
|
||||
/// Returns 'true' if the XID was updated.
|
||||
pub fn update_next_xid(&mut self, xid: u32) -> bool {
|
||||
let xid = xid.wrapping_add(XID_CHECKPOINT_INTERVAL - 1) & !(XID_CHECKPOINT_INTERVAL - 1);
|
||||
// nextXid should nw greate than any XID in WAL, so increment provided XID and check for wraparround.
|
||||
let mut new_xid = std::cmp::max(xid + 1, pg_constants::FIRST_NORMAL_TRANSACTION_ID);
|
||||
// To reduce number of metadata checkpoints, we forward align XID on XID_CHECKPOINT_INTERVAL.
|
||||
// XID_CHECKPOINT_INTERVAL should not be larger than BLCKSZ*CLOG_XACTS_PER_BYTE
|
||||
new_xid =
|
||||
new_xid.wrapping_add(XID_CHECKPOINT_INTERVAL - 1) & !(XID_CHECKPOINT_INTERVAL - 1);
|
||||
let full_xid = self.nextXid.value;
|
||||
let new_xid = std::cmp::max(xid + 1, pg_constants::FIRST_NORMAL_TRANSACTION_ID);
|
||||
let old_xid = full_xid as u32;
|
||||
if new_xid.wrapping_sub(old_xid) as i32 > 0 {
|
||||
let mut epoch = full_xid >> 32;
|
||||
@@ -520,4 +531,34 @@ mod tests {
|
||||
println!("wal_end={}, tli={}", wal_end, tli);
|
||||
assert_eq!(wal_end, waldump_wal_end);
|
||||
}
|
||||
|
||||
/// Check the math in update_next_xid
|
||||
///
|
||||
/// NOTE: These checks are sensitive to the value of XID_CHECKPOINT_INTERVAL,
|
||||
/// currently 1024.
|
||||
#[test]
|
||||
pub fn test_update_next_xid() {
|
||||
let checkpoint_buf = [0u8; std::mem::size_of::<CheckPoint>()];
|
||||
let mut checkpoint = CheckPoint::decode(&checkpoint_buf).unwrap();
|
||||
|
||||
checkpoint.nextXid = FullTransactionId { value: 10 };
|
||||
assert_eq!(checkpoint.nextXid.value, 10);
|
||||
|
||||
// The input XID gets rounded up to the next XID_CHECKPOINT_INTERVAL
|
||||
// boundary
|
||||
checkpoint.update_next_xid(100);
|
||||
assert_eq!(checkpoint.nextXid.value, 1024);
|
||||
|
||||
// No change
|
||||
checkpoint.update_next_xid(500);
|
||||
assert_eq!(checkpoint.nextXid.value, 1024);
|
||||
checkpoint.update_next_xid(1023);
|
||||
assert_eq!(checkpoint.nextXid.value, 1024);
|
||||
|
||||
// The function returns the *next* XID, given the highest XID seen so
|
||||
// far. So when we pass 1024, the nextXid gets bumped up to the next
|
||||
// XID_CHECKPOINT_INTERVAL boundary.
|
||||
checkpoint.update_next_xid(1024);
|
||||
assert_eq!(checkpoint.nextXid.value, 2048);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
[package]
|
||||
name = "proxy"
|
||||
version = "0.1.0"
|
||||
authors = ["Stas Kelvich <stas.kelvich@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bytes = { version = "1.0.1", features = ['serde'] }
|
||||
lazy_static = "1.4.0"
|
||||
md5 = "0.7.0"
|
||||
rand = "0.8.3"
|
||||
clap = "3.0"
|
||||
futures = "0.3.13"
|
||||
hashbrown = "0.11.2"
|
||||
hex = "0.4.3"
|
||||
hyper = "0.14"
|
||||
routerify = "2"
|
||||
lazy_static = "1.4.0"
|
||||
md5 = "0.7.0"
|
||||
parking_lot = "0.11.2"
|
||||
pin-project-lite = "0.2.7"
|
||||
rand = "0.8.3"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
rustls = "0.19.1"
|
||||
scopeguard = "1.1.0"
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
tokio = { version = "1.11", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
||||
clap = "2.33.0"
|
||||
rustls = "0.19.1"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
tokio-rustls = "0.22.0"
|
||||
|
||||
zenith_utils = { path = "../zenith_utils" }
|
||||
zenith_metrics = { path = "../zenith_metrics" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-postgres-rustls = "0.8.0"
|
||||
rcgen = "0.8.14"
|
||||
|
||||
169
proxy/src/auth.rs
Normal file
169
proxy/src/auth.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::cplane_api::{self, CPlaneApi};
|
||||
use crate::stream::PqStream;
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use std::collections::HashMap;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
.remove(key)
|
||||
.with_context(|| format!("{} is missing in startup packet", key))
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
|
||||
Ok(Self { user, dbname: db })
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> anyhow::Result<DatabaseInfo> {
|
||||
use crate::config::ClientAuthMethod::*;
|
||||
use crate::config::RouterConfig::*;
|
||||
let db_info = match &config.router_config {
|
||||
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
|
||||
Dynamic(Mixed) => {
|
||||
if self.user.ends_with("@zenith") {
|
||||
handle_existing_user(config, client, self).await
|
||||
} else {
|
||||
handle_new_user(config, client).await
|
||||
}
|
||||
}
|
||||
Dynamic(Password) => handle_existing_user(config, client, self).await,
|
||||
Dynamic(Link) => handle_new_user(config, client).await,
|
||||
};
|
||||
|
||||
db_info.context("failed to authenticate client")
|
||||
}
|
||||
}
|
||||
|
||||
fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
async fn handle_static(
|
||||
host: String,
|
||||
port: u16,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: ClientCredentials,
|
||||
) -> anyhow::Result<DatabaseInfo> {
|
||||
client
|
||||
.write_message(&Be::AuthenticationCleartextPassword)
|
||||
.await?;
|
||||
|
||||
// Read client's password bytes
|
||||
let msg = match client.read_message().await? {
|
||||
Fe::PasswordMessage(msg) => msg,
|
||||
bad => bail!("unexpected message type: {:?}", bad),
|
||||
};
|
||||
|
||||
let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();
|
||||
|
||||
let db_info = DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: creds.dbname.clone(),
|
||||
user: creds.user.clone(),
|
||||
password: Some(cleartext_password.into()),
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
async fn handle_existing_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: ClientCredentials,
|
||||
) -> anyhow::Result<DatabaseInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
|
||||
client
|
||||
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
|
||||
.await?;
|
||||
|
||||
// Read client's password hash
|
||||
let msg = match client.read_message().await? {
|
||||
Fe::PasswordMessage(msg) => msg,
|
||||
bad => bail!("unexpected message type: {:?}", bad),
|
||||
};
|
||||
|
||||
let (_trailing_null, md5_response) = msg
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow!("unexpected password message"))?;
|
||||
|
||||
let cplane = CPlaneApi::new(&config.auth_endpoint);
|
||||
let db_info = cplane
|
||||
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
async fn handle_new_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> anyhow::Result<DatabaseInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
|
||||
|
||||
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&Be::NoticeResponse(greeting))
|
||||
.await?;
|
||||
|
||||
// Wait for web console response
|
||||
waiter.await?.map_err(|e| anyhow!(e))
|
||||
})
|
||||
.await?;
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"☀️ Welcome to Zenith!\n",
|
||||
"To proceed with database creation, open the following link:\n\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
"It needs to be done once and we will send you '.pgpass' file,\n",
|
||||
"which will allow you to access or create ",
|
||||
"databases without opening your web browser."
|
||||
],
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
)
|
||||
}
|
||||
106
proxy/src/cancellation.rs
Normal file
106
proxy/src/cancellation.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use hashbrown::HashMap;
|
||||
use parking_lot::Mutex;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use zenith_utils::pq_proto::CancelKeyData;
|
||||
|
||||
/// Enables serving CancelRequests.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
|
||||
|
||||
impl CancelMap {
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
|
||||
let cancel_closure = self
|
||||
.0
|
||||
.lock()
|
||||
.get(&key)
|
||||
.and_then(|x| x.clone())
|
||||
.with_context(|| format!("unknown session: {:?}", key))?;
|
||||
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
|
||||
where
|
||||
F: FnOnce(Session<'a>) -> R,
|
||||
R: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
let key = rand::random();
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
self.0
|
||||
.lock()
|
||||
.try_insert(key, None)
|
||||
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
|
||||
|
||||
// This will guarantee that the session gets dropped
|
||||
// as soon as the future is finished.
|
||||
scopeguard::defer! {
|
||||
self.0.lock().remove(&key);
|
||||
}
|
||||
|
||||
let session = Session::new(key, self);
|
||||
f(session).await
|
||||
}
|
||||
}
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
#[derive(Clone)]
|
||||
pub struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: CancelToken,
|
||||
}
|
||||
|
||||
impl CancelClosure {
|
||||
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
|
||||
Self {
|
||||
socket_addr,
|
||||
cancel_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for registering query cancellation tokens.
|
||||
pub struct Session<'a> {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancel_map: &'a CancelMap,
|
||||
}
|
||||
|
||||
impl<'a> Session<'a> {
|
||||
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
|
||||
Self { key, cancel_map }
|
||||
}
|
||||
|
||||
/// Store the cancel token for the given session.
|
||||
/// This enables query cancellation in [`crate::proxy::handshake`].
|
||||
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
self.cancel_map
|
||||
.0
|
||||
.lock()
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
}
|
||||
}
|
||||
42
proxy/src/compute.rs
Normal file
42
proxy/src/compute.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
|
||||
/// Compute node connection params.
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
impl DatabaseInfo {
|
||||
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
|
||||
let host_port = format!("{}:{}", self.host, self.port);
|
||||
host_port
|
||||
.to_socket_addrs()
|
||||
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
|
||||
.next()
|
||||
.context("cannot resolve at least one SocketAddr")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,46 @@
|
||||
use crate::cplane_api::DatabaseInfo;
|
||||
use anyhow::{anyhow, ensure, Context};
|
||||
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub type SslConfig = Arc<ServerConfig>;
|
||||
pub type TlsConfig = Arc<ServerConfig>;
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum ClientAuthMethod {
|
||||
Password,
|
||||
Link,
|
||||
|
||||
/// Use password auth only if username ends with "@zenith"
|
||||
Mixed,
|
||||
}
|
||||
|
||||
pub enum RouterConfig {
|
||||
Static { host: String, port: u16 },
|
||||
Dynamic(ClientAuthMethod),
|
||||
}
|
||||
|
||||
impl FromStr for ClientAuthMethod {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
use ClientAuthMethod::*;
|
||||
match s {
|
||||
"password" => Ok(Password),
|
||||
"link" => Ok(Link),
|
||||
"mixed" => Ok(Mixed),
|
||||
_ => Err(anyhow::anyhow!("Invlid option for router")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProxyConfig {
|
||||
/// main entrypoint for users to connect to
|
||||
pub proxy_address: SocketAddr,
|
||||
|
||||
/// method of assigning compute nodes
|
||||
pub router_config: RouterConfig,
|
||||
|
||||
/// internally used for status and prometheus metrics
|
||||
pub http_address: SocketAddr,
|
||||
|
||||
@@ -24,26 +55,10 @@ pub struct ProxyConfig {
|
||||
/// control plane address where we would check auth.
|
||||
pub auth_endpoint: String,
|
||||
|
||||
pub ssl_config: Option<SslConfig>,
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
pub type ProxyWaiters = crate::waiters::Waiters<Result<DatabaseInfo, String>>;
|
||||
|
||||
pub struct ProxyState {
|
||||
pub conf: ProxyConfig,
|
||||
pub waiters: ProxyWaiters,
|
||||
}
|
||||
|
||||
impl ProxyState {
|
||||
pub fn new(conf: ProxyConfig) -> Self {
|
||||
Self {
|
||||
conf,
|
||||
waiters: ProxyWaiters::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<SslConfig> {
|
||||
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
|
||||
let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
@@ -1,106 +1,87 @@
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::waiters::{Waiter, Waiters};
|
||||
use anyhow::{anyhow, bail};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
|
||||
use crate::state::ProxyWaiters;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub password: Option<String>,
|
||||
lazy_static! {
|
||||
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
enum ProxyAuthResponse {
|
||||
Ready { conn_info: DatabaseInfo },
|
||||
Error { error: String },
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
/// Give caller an opportunity to wait for cplane's reply.
|
||||
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
|
||||
where
|
||||
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
|
||||
R: std::future::Future<Output = anyhow::Result<T>>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
f(waiter).await
|
||||
}
|
||||
|
||||
impl DatabaseInfo {
|
||||
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
|
||||
let host_port = format!("{}:{}", self.host, self.port);
|
||||
host_port
|
||||
.to_socket_addrs()
|
||||
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
|
||||
.next()
|
||||
.context("cannot resolve at least one SocketAddr")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Zenith console API wrapper.
|
||||
pub struct CPlaneApi<'a> {
|
||||
auth_endpoint: &'a str,
|
||||
waiters: &'a ProxyWaiters,
|
||||
}
|
||||
|
||||
impl<'a> CPlaneApi<'a> {
|
||||
pub fn new(auth_endpoint: &'a str, waiters: &'a ProxyWaiters) -> Self {
|
||||
Self {
|
||||
auth_endpoint,
|
||||
waiters,
|
||||
}
|
||||
pub fn new(auth_endpoint: &'a str) -> Self {
|
||||
Self { auth_endpoint }
|
||||
}
|
||||
}
|
||||
|
||||
impl CPlaneApi<'_> {
|
||||
pub fn authenticate_proxy_request(
|
||||
pub async fn authenticate_proxy_request(
|
||||
&self,
|
||||
user: &str,
|
||||
database: &str,
|
||||
creds: ClientCredentials,
|
||||
md5_response: &[u8],
|
||||
salt: &[u8; 4],
|
||||
psql_session_id: &str,
|
||||
) -> anyhow::Result<DatabaseInfo> {
|
||||
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
|
||||
url.query_pairs_mut()
|
||||
.append_pair("login", user)
|
||||
.append_pair("database", database)
|
||||
.append_pair("login", &creds.user)
|
||||
.append_pair("database", &creds.dbname)
|
||||
.append_pair("md5response", std::str::from_utf8(md5_response)?)
|
||||
.append_pair("salt", &hex::encode(salt))
|
||||
.append_pair("psql_session_id", psql_session_id);
|
||||
|
||||
let waiter = self.waiters.register(psql_session_id.to_owned());
|
||||
with_waiter(psql_session_id, |waiter| async {
|
||||
println!("cplane request: {}", url);
|
||||
// TODO: leverage `reqwest::Client` to reuse connections
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
bail!("Auth failed: {}", resp.status())
|
||||
}
|
||||
|
||||
println!("cplane request: {}", url);
|
||||
let resp = reqwest::blocking::get(url)?;
|
||||
if !resp.status().is_success() {
|
||||
bail!("Auth failed: {}", resp.status())
|
||||
}
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
|
||||
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text()?.as_str())?;
|
||||
println!("got auth info: #{:?}", auth_info);
|
||||
|
||||
use ProxyAuthResponse::*;
|
||||
match auth_info {
|
||||
Ready { conn_info } => Ok(conn_info),
|
||||
Error { error } => bail!(error),
|
||||
NotReady { .. } => waiter.wait()?.map_err(|e| anyhow!(e)),
|
||||
}
|
||||
use ProxyAuthResponse::*;
|
||||
match auth_info {
|
||||
Ready { conn_info } => Ok(conn_info),
|
||||
Error { error } => bail!(error),
|
||||
NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)),
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the order of constructors is important.
|
||||
// https://serde.rs/enum-representations.html#untagged
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
enum ProxyAuthResponse {
|
||||
Ready { conn_info: DatabaseInfo },
|
||||
Error { error: String },
|
||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
use anyhow::anyhow;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use routerify::RouterBuilder;
|
||||
|
||||
use std::net::TcpListener;
|
||||
use zenith_utils::http::endpoint;
|
||||
use zenith_utils::http::error::ApiError;
|
||||
use zenith_utils::http::json::json_response;
|
||||
use zenith_utils::http::{RouterBuilder, RouterService};
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
Ok(json_response(StatusCode::OK, "")?)
|
||||
}
|
||||
|
||||
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
let router = endpoint::make_router();
|
||||
router.get("/v1/status", status_handler)
|
||||
}
|
||||
|
||||
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
println!("http has shut down");
|
||||
}
|
||||
|
||||
let service = || RouterService::new(make_router().build()?);
|
||||
|
||||
hyper::Server::from_tcp(http_listener)?
|
||||
.serve(service().map_err(|e| anyhow!(e))?)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,137 +5,162 @@
|
||||
/// (control plane API in our case) and can create new databases and accounts
|
||||
/// in somewhat transparent manner (again via communication with control plane API).
|
||||
///
|
||||
use anyhow::bail;
|
||||
use anyhow::{bail, Context};
|
||||
use clap::{App, Arg};
|
||||
use state::{ProxyConfig, ProxyState};
|
||||
use std::thread;
|
||||
use zenith_utils::http::endpoint;
|
||||
use zenith_utils::{tcp_listener, GIT_VERSION};
|
||||
use config::ProxyConfig;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use tokio::{net::TcpListener, task::JoinError};
|
||||
use zenith_utils::GIT_VERSION;
|
||||
|
||||
use crate::config::{ClientAuthMethod, RouterConfig};
|
||||
|
||||
mod auth;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
mod config;
|
||||
mod cplane_api;
|
||||
mod http;
|
||||
mod mgmt;
|
||||
mod proxy;
|
||||
mod state;
|
||||
mod stream;
|
||||
mod waiters;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
/// Flattens Result<Result<T>> into Result<T>.
|
||||
async fn flatten_err(
|
||||
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
|
||||
) -> anyhow::Result<()> {
|
||||
f.map(|r| r.context("join error").and_then(|x| x)).await
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
zenith_metrics::set_common_metrics_prefix("zenith_proxy");
|
||||
let arg_matches = App::new("Zenith proxy/router")
|
||||
.version(GIT_VERSION)
|
||||
.arg(
|
||||
Arg::with_name("proxy")
|
||||
.short("p")
|
||||
Arg::new("proxy")
|
||||
.short('p')
|
||||
.long("proxy")
|
||||
.takes_value(true)
|
||||
.help("listen for incoming client connections on ip:port")
|
||||
.default_value("127.0.0.1:4432"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("mgmt")
|
||||
.short("m")
|
||||
Arg::new("auth-method")
|
||||
.long("auth-method")
|
||||
.takes_value(true)
|
||||
.help("Possible values: password | link | mixed")
|
||||
.default_value("mixed"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("static-router")
|
||||
.short('s')
|
||||
.long("static-router")
|
||||
.takes_value(true)
|
||||
.help("Route all clients to host:port"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("mgmt")
|
||||
.short('m')
|
||||
.long("mgmt")
|
||||
.takes_value(true)
|
||||
.help("listen for management callback connection on ip:port")
|
||||
.default_value("127.0.0.1:7000"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("http")
|
||||
.short("h")
|
||||
Arg::new("http")
|
||||
.short('h')
|
||||
.long("http")
|
||||
.takes_value(true)
|
||||
.help("listen for incoming http connections (metrics, etc) on ip:port")
|
||||
.default_value("127.0.0.1:7001"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("uri")
|
||||
.short("u")
|
||||
Arg::new("uri")
|
||||
.short('u')
|
||||
.long("uri")
|
||||
.takes_value(true)
|
||||
.help("redirect unauthenticated users to given uri")
|
||||
.default_value("http://localhost:3000/psql_session/"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("auth-endpoint")
|
||||
.short("a")
|
||||
Arg::new("auth-endpoint")
|
||||
.short('a')
|
||||
.long("auth-endpoint")
|
||||
.takes_value(true)
|
||||
.help("API endpoint for authenticating users")
|
||||
.default_value("http://localhost:3000/authenticate_proxy_request/"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("ssl-key")
|
||||
.short("k")
|
||||
Arg::new("ssl-key")
|
||||
.short('k')
|
||||
.long("ssl-key")
|
||||
.takes_value(true)
|
||||
.help("path to SSL key for client postgres connections"),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("ssl-cert")
|
||||
.short("c")
|
||||
Arg::new("ssl-cert")
|
||||
.short('c')
|
||||
.long("ssl-cert")
|
||||
.takes_value(true)
|
||||
.help("path to SSL cert for client postgres connections"),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let ssl_config = match (
|
||||
let tls_config = match (
|
||||
arg_matches.value_of("ssl-key"),
|
||||
arg_matches.value_of("ssl-cert"),
|
||||
) {
|
||||
(Some(key_path), Some(cert_path)) => {
|
||||
Some(crate::state::configure_ssl(key_path, cert_path)?)
|
||||
}
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?),
|
||||
(None, None) => None,
|
||||
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
|
||||
};
|
||||
|
||||
let config = ProxyConfig {
|
||||
let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
|
||||
let router_config = match arg_matches.value_of("static-router") {
|
||||
None => RouterConfig::Dynamic(auth_method),
|
||||
Some(addr) => {
|
||||
if let ClientAuthMethod::Password = auth_method {
|
||||
let (host, port) = addr.split_once(":").unwrap();
|
||||
RouterConfig::Static {
|
||||
host: host.to_string(),
|
||||
port: port.parse().unwrap(),
|
||||
}
|
||||
} else {
|
||||
bail!("static-router requires --auth-method password")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
|
||||
router_config,
|
||||
proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
|
||||
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
|
||||
http_address: arg_matches.value_of("http").unwrap().parse()?,
|
||||
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
|
||||
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
|
||||
ssl_config,
|
||||
};
|
||||
let state: &ProxyState = Box::leak(Box::new(ProxyState::new(config)));
|
||||
tls_config,
|
||||
}));
|
||||
|
||||
println!("Version: {}", GIT_VERSION);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
println!("Starting http on {}", state.conf.http_address);
|
||||
let http_listener = tcp_listener::bind(state.conf.http_address)?;
|
||||
println!("Starting http on {}", config.http_address);
|
||||
let http_listener = TcpListener::bind(config.http_address).await?.into_std()?;
|
||||
|
||||
println!("Starting proxy on {}", state.conf.proxy_address);
|
||||
let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?;
|
||||
println!("Starting mgmt on {}", config.mgmt_address);
|
||||
let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?;
|
||||
|
||||
println!("Starting mgmt on {}", state.conf.mgmt_address);
|
||||
let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?;
|
||||
println!("Starting proxy on {}", config.proxy_address);
|
||||
let proxy_listener = TcpListener::bind(config.proxy_address).await?;
|
||||
|
||||
let threads = [
|
||||
thread::Builder::new()
|
||||
.name("Http thread".into())
|
||||
.spawn(move || {
|
||||
let router = http::make_router();
|
||||
endpoint::serve_thread_main(
|
||||
router,
|
||||
http_listener,
|
||||
std::future::pending(), // never shut down
|
||||
)
|
||||
})?,
|
||||
// Spawn a thread to listen for connections. It will spawn further threads
|
||||
// for each connection.
|
||||
thread::Builder::new()
|
||||
.name("Listener thread".into())
|
||||
.spawn(move || proxy::thread_main(state, pageserver_listener))?,
|
||||
thread::Builder::new()
|
||||
.name("Mgmt thread".into())
|
||||
.spawn(move || mgmt::thread_main(state, mgmt_listener))?,
|
||||
];
|
||||
let http = tokio::spawn(http::thread_main(http_listener));
|
||||
let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener));
|
||||
let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener));
|
||||
|
||||
for t in threads {
|
||||
t.join().unwrap()?;
|
||||
}
|
||||
let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)];
|
||||
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,44 +1,49 @@
|
||||
use crate::{compute::DatabaseInfo, cplane_api};
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
thread,
|
||||
};
|
||||
|
||||
use serde::Deserialize;
|
||||
use zenith_utils::{
|
||||
postgres_backend::{self, AuthType, PostgresBackend},
|
||||
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
|
||||
};
|
||||
|
||||
use crate::{cplane_api::DatabaseInfo, ProxyState};
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
|
||||
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
println!("mgmt has shut down");
|
||||
}
|
||||
|
||||
listener
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set listener to blocking")?;
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept()?;
|
||||
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
socket.set_nodelay(true).unwrap();
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_connection(state, socket) {
|
||||
if let Err(err) = handle_connection(socket) {
|
||||
println!("error: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection(state: &ProxyState, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = MgmtHandler { state };
|
||||
fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
pgbackend.run(&mut MgmtHandler)
|
||||
}
|
||||
|
||||
struct MgmtHandler<'a> {
|
||||
state: &'a ProxyState,
|
||||
}
|
||||
struct MgmtHandler;
|
||||
|
||||
/// Serialized examples:
|
||||
// {
|
||||
@@ -74,13 +79,13 @@ enum PsqlSessionResult {
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
impl postgres_backend::Handler for MgmtHandler<'_> {
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = try_process_query(self, pgb, query_string);
|
||||
let res = try_process_query(pgb, query_string);
|
||||
// intercept and log error message
|
||||
if res.is_err() {
|
||||
println!("Mgmt query failed: #{:?}", res);
|
||||
@@ -89,11 +94,7 @@ impl postgres_backend::Handler for MgmtHandler<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(
|
||||
mgmt: &mut MgmtHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> {
|
||||
println!("Got mgmt query: '{}'", query_string);
|
||||
|
||||
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
|
||||
@@ -104,7 +105,7 @@ fn try_process_query(
|
||||
Failure(message) => Err(message),
|
||||
};
|
||||
|
||||
match mgmt.state.waiters.notify(&resp.session_id, msg) {
|
||||
match cplane_api::notify(&resp.session_id, msg) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
|
||||
@@ -1,386 +1,332 @@
|
||||
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
|
||||
use crate::ProxyState;
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelClosure, CancelMap};
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{SocketAddr, TcpStream};
|
||||
use std::{io, thread};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
|
||||
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
|
||||
use zenith_utils::sock_split::{ReadStream, WriteStream};
|
||||
|
||||
struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: tokio_postgres::CancelToken,
|
||||
}
|
||||
|
||||
impl CancelClosure {
|
||||
async fn try_cancel_query(&self) {
|
||||
if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await {
|
||||
// NOTE ignoring the result because:
|
||||
// 1. This is a best effort attempt, the database doesn't have to listen
|
||||
// 2. Being opaque about errors here helps avoid leaking info to unauthenticated user
|
||||
let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, *};
|
||||
|
||||
lazy_static! {
|
||||
// Enables serving CancelRequests
|
||||
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new());
|
||||
|
||||
// Metrics
|
||||
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_accepted"),
|
||||
"Number of TCP client connections accepted."
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_closed"),
|
||||
"Number of TCP client connections closed."
|
||||
).unwrap();
|
||||
static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_failed"),
|
||||
"Number of TCP client connections that closed due to error."
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_bytes_proxied"),
|
||||
"Number of bytes sent/received between any client and backend."
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
// Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop.
|
||||
static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None);
|
||||
}
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
pub fn thread_main(
|
||||
state: &'static ProxyState,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept()?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
socket.set_nodelay(true).unwrap();
|
||||
|
||||
// TODO Use a threadpool instead. Maybe use tokio's threadpool by
|
||||
// spawning a future into its runtime. Tokio's JoinError should
|
||||
// allow us to handle cleanup properly even if the future panics.
|
||||
thread::Builder::new()
|
||||
.name("Proxy thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = proxy_conn_main(state, socket) {
|
||||
NUM_CONNECTIONS_FAILED_COUNTER.inc();
|
||||
println!("error: {}", err);
|
||||
}
|
||||
|
||||
// Clean up CANCEL_MAP.
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
THREAD_CANCEL_KEY_DATA.with(|cell| {
|
||||
if let Some(cancel_key_data) = cell.get() {
|
||||
CANCEL_MAP.lock().remove(&cancel_key_data);
|
||||
};
|
||||
});
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: clean up fields
|
||||
struct ProxyConnection {
|
||||
state: &'static ProxyState,
|
||||
psql_session_id: String,
|
||||
pgb: PostgresBackend,
|
||||
}
|
||||
|
||||
pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let conn = ProxyConnection {
|
||||
state,
|
||||
psql_session_id: hex::encode(rand::random::<[u8; 8]>()),
|
||||
pgb: PostgresBackend::new(
|
||||
socket,
|
||||
postgres_backend::AuthType::MD5,
|
||||
state.conf.ssl_config.clone(),
|
||||
false,
|
||||
)?,
|
||||
};
|
||||
|
||||
let (client, server) = match conn.handle_client()? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let server = zenith_utils::sock_split::BidiStream::from_tcp(server);
|
||||
|
||||
let client = match client {
|
||||
Stream::Bidirectional(bidi_stream) => bidi_stream,
|
||||
_ => panic!("invalid stream type"),
|
||||
};
|
||||
|
||||
proxy(client.split(), server.split())
|
||||
}
|
||||
|
||||
impl ProxyConnection {
|
||||
/// Returns Ok(None) when connection was successfully closed.
|
||||
fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> {
|
||||
let mut authenticate = || {
|
||||
let (username, dbname) = match self.handle_startup()? {
|
||||
Some(x) => x,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
// Both scenarios here should end up producing database credentials
|
||||
if username.ends_with("@zenith") {
|
||||
self.handle_existing_user(&username, &dbname).map(Some)
|
||||
} else {
|
||||
self.handle_new_user().map(Some)
|
||||
}
|
||||
};
|
||||
|
||||
let conn = match authenticate() {
|
||||
Ok(Some(db_info)) => connect_to_db(db_info),
|
||||
Ok(None) => return Ok(None),
|
||||
Err(e) => {
|
||||
// Report the error to the client
|
||||
self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?;
|
||||
bail!("failed to handle client: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
// We'll get rid of this once migration to async is complete
|
||||
let (pg_version, db_stream) = {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?;
|
||||
self.pgb
|
||||
.write_message(&BeMessage::BackendKeyData(cancel_key_data))?;
|
||||
let stream = stream.into_std()?;
|
||||
stream.set_nonblocking(false)?;
|
||||
|
||||
(pg_version, stream)
|
||||
};
|
||||
|
||||
// Let the client send new requests
|
||||
self.pgb
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion(&pg_version),
|
||||
))?
|
||||
.write_message(&Be::ReadyForQuery)?;
|
||||
|
||||
Ok(Some((self.pgb.into_stream(), db_stream)))
|
||||
}
|
||||
|
||||
/// Returns Ok(None) when connection was successfully closed.
|
||||
fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> {
|
||||
let have_tls = self.pgb.tls_config.is_some();
|
||||
let mut encrypted = false;
|
||||
|
||||
loop {
|
||||
let msg = match self.pgb.read_message()? {
|
||||
Some(Fe::StartupPacket(msg)) => msg,
|
||||
None => bail!("connection is lost"),
|
||||
bad => bail!("unexpected message type: {:?}", bad),
|
||||
};
|
||||
println!("got message: {:?}", msg);
|
||||
|
||||
match msg {
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
self.pgb.write_message(&Be::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::SslRequest => {
|
||||
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.pgb.start_tls()?;
|
||||
encrypted = true;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::StartupMessage { mut params, .. } => {
|
||||
if have_tls && !encrypted {
|
||||
bail!("must connect with TLS");
|
||||
}
|
||||
|
||||
let mut get_param = |key| {
|
||||
params
|
||||
.remove(key)
|
||||
.with_context(|| format!("{} is missing in startup packet", key))
|
||||
};
|
||||
|
||||
return Ok(Some((get_param("user")?, get_param("database")?)));
|
||||
}
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
if let Some(cancel_closure) = CANCEL_MAP.lock().get(&cancel_key_data) {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
runtime.block_on(cancel_closure.try_cancel_query());
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
|
||||
let md5_salt = rand::random::<[u8; 4]>();
|
||||
|
||||
// Ask password
|
||||
self.pgb
|
||||
.write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
|
||||
self.pgb.state = ProtoState::Authentication; // XXX
|
||||
|
||||
// Check password
|
||||
let msg = match self.pgb.read_message()? {
|
||||
Some(Fe::PasswordMessage(msg)) => msg,
|
||||
None => bail!("connection is lost"),
|
||||
bad => bail!("unexpected message type: {:?}", bad),
|
||||
};
|
||||
println!("got message: {:?}", msg);
|
||||
|
||||
let (_trailing_null, md5_response) = msg
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow!("unexpected password message"))?;
|
||||
|
||||
let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
|
||||
let db_info = cplane.authenticate_proxy_request(
|
||||
user,
|
||||
db,
|
||||
md5_response,
|
||||
&md5_salt,
|
||||
&self.psql_session_id,
|
||||
)?;
|
||||
|
||||
self.pgb
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
|
||||
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
|
||||
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
|
||||
|
||||
// First, register this session
|
||||
let waiter = self.state.waiters.register(self.psql_session_id.clone());
|
||||
|
||||
// Give user a URL to spawn a new database
|
||||
self.pgb
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&Be::NoticeResponse(greeting))?;
|
||||
|
||||
// Wait for web console response
|
||||
let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?;
|
||||
|
||||
self.pgb
|
||||
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
||||
|
||||
Ok(db_info)
|
||||
}
|
||||
}
|
||||
|
||||
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"☀️ Welcome to Zenith!\n",
|
||||
"To proceed with database creation, open the following link:\n\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
"It needs to be done once and we will send you '.pgpass' file,\n",
|
||||
"which will allow you to access or create ",
|
||||
"databases without opening your web browser."
|
||||
],
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
|
||||
async fn connect_to_db(
|
||||
db_info: DatabaseInfo,
|
||||
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
|
||||
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
|
||||
let socket_addr = db_info.socket_addr()?;
|
||||
let mut socket = tokio::net::TcpStream::connect(socket_addr).await?;
|
||||
let config = tokio_postgres::Config::from(db_info);
|
||||
// NOTE We effectively ignore some ParameterStatus and NoticeResponse
|
||||
// messages here. Not sure if that could break something.
|
||||
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
|
||||
|
||||
// Save info for potentially cancelling the query later
|
||||
let mut rng = StdRng::from_entropy();
|
||||
let cancel_key_data = CancelKeyData {
|
||||
// HACK We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
backend_pid: rng.gen(),
|
||||
cancel_key: rng.gen(),
|
||||
};
|
||||
let cancel_closure = CancelClosure {
|
||||
socket_addr,
|
||||
cancel_token: client.cancel_token(),
|
||||
};
|
||||
CANCEL_MAP.lock().insert(cancel_key_data, cancel_closure);
|
||||
THREAD_CANCEL_KEY_DATA.with(|cell| {
|
||||
let prev_value = cell.replace(Some(cancel_key_data));
|
||||
assert!(
|
||||
prev_value.is_none(),
|
||||
"THREAD_CANCEL_KEY_DATA was already set"
|
||||
);
|
||||
});
|
||||
|
||||
let version = conn.parameter("server_version").unwrap();
|
||||
Ok((version.into(), socket, cancel_key_data))
|
||||
async fn log_error<R, F>(future: F) -> F::Output
|
||||
where
|
||||
F: std::future::Future<Output = anyhow::Result<R>>,
|
||||
{
|
||||
future.await.map_err(|err| {
|
||||
println!("error: {}", err);
|
||||
err
|
||||
})
|
||||
}
|
||||
|
||||
/// Concurrently proxy both directions of the client and server connections
|
||||
fn proxy(
|
||||
(client_read, client_write): (ReadStream, WriteStream),
|
||||
(server_read, server_write): (ReadStream, WriteStream),
|
||||
pub async fn thread_main(
|
||||
config: &'static ProxyConfig,
|
||||
listener: tokio::net::TcpListener,
|
||||
) -> anyhow::Result<()> {
|
||||
fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
|
||||
/// FlushWriter will make sure that every message is sent as soon as possible
|
||||
struct FlushWriter<W>(W);
|
||||
|
||||
impl<W: io::Write> io::Write for FlushWriter<W> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
// `std::io::copy` is guaranteed to exit if we return an error,
|
||||
// so we can afford to lose `res` in case `flush` fails
|
||||
let res = self.0.write(buf);
|
||||
if let Ok(count) = res {
|
||||
NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64);
|
||||
self.flush()?;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.0.flush()
|
||||
}
|
||||
}
|
||||
|
||||
let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
|
||||
writer.shutdown(std::net::Shutdown::Both)?;
|
||||
res
|
||||
scopeguard::defer! {
|
||||
println!("proxy has shut down");
|
||||
}
|
||||
|
||||
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
|
||||
do_proxy(server_read, client_write)?;
|
||||
client_to_server_jh.join().unwrap()?;
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
tokio::spawn(log_error(async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(config, &cancel_map, socket).await
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
scopeguard::defer! {
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.clone();
|
||||
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
|
||||
cancel_map
|
||||
.with_session(|session| async {
|
||||
connect_client_to_db(config, session, client, creds).await
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a connection from one client.
|
||||
/// For better testing experience, `stream` can be
|
||||
/// any object satisfying the traits.
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
println!("got message: {:?}", msg);
|
||||
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
|
||||
}
|
||||
}
|
||||
_ => bail!("protocol violation"),
|
||||
},
|
||||
GssEncRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_gss => {
|
||||
tried_gss = true;
|
||||
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
}
|
||||
_ => bail!("protocol violation"),
|
||||
},
|
||||
StartupMessage { params, .. } => {
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
let msg = "connection is insecure (try using `sslmode=require`)";
|
||||
stream.write_message(&Be::ErrorResponse(msg)).await?;
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
break Ok(Some((stream, params.try_into()?)));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
cancel_map.cancel_session(cancel_key_data).await?;
|
||||
|
||||
break Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_client_to_db(
|
||||
config: &ProxyConfig,
|
||||
session: cancellation::Session<'_>,
|
||||
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: auth::ClientCredentials,
|
||||
) -> anyhow::Result<()> {
|
||||
let db_info = creds.authenticate(config, &mut client).await?;
|
||||
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
|
||||
let cancel_key_data = session.enable_cancellation(cancel_closure);
|
||||
|
||||
client
|
||||
.write_message_noflush(&BeMessage::ParameterStatus(
|
||||
BeParameterStatusMessage::ServerVersion(&version),
|
||||
))?
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
// This function will be called for writes to either direction.
|
||||
fn inc_proxied(cnt: usize) {
|
||||
// Consider inventing something more sophisticated
|
||||
// if this ever becomes a bottleneck (cacheline bouncing).
|
||||
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
|
||||
}
|
||||
|
||||
let mut db = MetricsStream::new(db, inc_proxied);
|
||||
let mut client = MetricsStream::new(client.into_inner(), inc_proxied);
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect to a corresponding compute node.
|
||||
async fn connect_to_db(
|
||||
db_info: DatabaseInfo,
|
||||
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
|
||||
// TODO: establish a secure connection to the DB
|
||||
let socket_addr = db_info.socket_addr()?;
|
||||
let mut socket = TcpStream::connect(socket_addr).await?;
|
||||
|
||||
let (client, conn) = tokio_postgres::Config::from(db_info)
|
||||
.connect_raw(&mut socket, NoTls)
|
||||
.await?;
|
||||
|
||||
let version = conn
|
||||
.parameter("server_version")
|
||||
.context("failed to fetch postgres server version")?
|
||||
.into();
|
||||
|
||||
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
|
||||
|
||||
Ok((socket, version, cancel_closure))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
async fn dummy_proxy(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tls: Option<TlsConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
// TODO: add some infra + tests for credentials
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
.await?
|
||||
.context("no stream")?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_certs(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
|
||||
let ca = rcgen::Certificate::from_params({
|
||||
let mut params = rcgen::CertificateParams::default();
|
||||
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
|
||||
params
|
||||
})?;
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
|
||||
Ok((
|
||||
rustls::Certificate(ca.serialize_der()?),
|
||||
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
|
||||
rustls::PrivateKey(cert.serialize_private_key_der()),
|
||||
))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let server_config = {
|
||||
let (_ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
|
||||
tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Disable)
|
||||
.connect_raw(server, NoTls)
|
||||
.await
|
||||
.err() // -> Option<E>
|
||||
.context("client shouldn't be able to connect")?;
|
||||
|
||||
proxy
|
||||
.await?
|
||||
.err() // -> Option<E>
|
||||
.context("server shouldn't accept client")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_tls() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (ca, cert, key) = generate_certs("localhost")?;
|
||||
|
||||
let server_config = {
|
||||
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
|
||||
config.set_single_cert(vec![cert], key)?;
|
||||
config
|
||||
};
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
|
||||
|
||||
let client_config = {
|
||||
let mut config = rustls::ClientConfig::new();
|
||||
config.root_store.add(&ca)?;
|
||||
config
|
||||
};
|
||||
|
||||
let mut mk = MakeRustlsConnect::new(client_config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
.connect_raw(server, tls)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_raw() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Prefer)
|
||||
.connect_raw(server, NoTls)
|
||||
.await?;
|
||||
|
||||
proxy.await?
|
||||
}
|
||||
}
|
||||
|
||||
230
proxy/src/stream.rs
Normal file
230
proxy/src/stream.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
|
||||
pin_project! {
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
buffer: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream.
|
||||
pub fn into_inner(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
|
||||
match FeStartupPacket::read_fut(&mut self.stream).await? {
|
||||
Some(FeMessage::StartupPacket(packet)) => Ok(packet),
|
||||
None => anyhow::bail!("connection is lost"),
|
||||
other => anyhow::bail!("bad message type: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_message(&mut self) -> anyhow::Result<FeMessage> {
|
||||
FeMessage::read_fut(&mut self.stream)
|
||||
.await?
|
||||
.context("connection is lost")
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buffer, message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
pub async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.stream.write_all(&self.buffer).await?;
|
||||
self.buffer.clear();
|
||||
self.stream.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Wrapper for upgrading raw streams into secure streams.
|
||||
/// NOTE: it should be possible to decompose this object as necessary.
|
||||
#[project = StreamProj]
|
||||
pub enum Stream<S> {
|
||||
/// We always begin with a raw stream,
|
||||
/// which may then be upgraded into a secure stream.
|
||||
Raw { #[pin] raw: S },
|
||||
/// We box [`TlsStream`] since it can be quite large.
|
||||
Tls { #[pin] tls: Box<TlsStream<S>> },
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream<S> {
|
||||
/// Construct a new instance from a raw stream.
|
||||
pub fn from_raw(raw: S) -> Self {
|
||||
Self::Raw { raw }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
/// If possible, upgrade raw stream into a secure TLS-based stream.
|
||||
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
|
||||
match self {
|
||||
Stream::Raw { raw } => {
|
||||
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
|
||||
Ok(Stream::Tls { tls })
|
||||
}
|
||||
Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_read(context, buf),
|
||||
Tls { tls } => tls.poll_read(context, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_write(context, buf),
|
||||
Tls { tls } => tls.poll_write(context, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_flush(context),
|
||||
Tls { tls } => tls.poll_flush(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
use StreamProj::*;
|
||||
match self.project() {
|
||||
Raw { raw } => raw.poll_shutdown(context),
|
||||
Tls { tls } => tls.poll_shutdown(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// This stream tracks all writes and calls user provided
|
||||
/// callback when the underlying stream is flushed.
|
||||
pub struct MetricsStream<S, W> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
write_count: usize,
|
||||
inc_write_count: W,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, W> MetricsStream<S, W> {
|
||||
pub fn new(stream: S, inc_write_count: W) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
write_count: 0,
|
||||
inc_write_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, W> AsyncRead for MetricsStream<S, W> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_read(context, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MetricsStream<S, W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_write(context, buf).map_ok(|cnt| {
|
||||
// Increment the write count.
|
||||
*this.write_count += cnt;
|
||||
cnt
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_flush(context).map_ok(|()| {
|
||||
// Call the user provided callback and reset the write count.
|
||||
(this.inc_write_count)(*this.write_count);
|
||||
*this.write_count = 0;
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_shutdown(context)
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
use anyhow::Context;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{mpsc, Mutex};
|
||||
use anyhow::{anyhow, Context};
|
||||
use hashbrown::HashMap;
|
||||
use parking_lot::Mutex;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::pin::Pin;
|
||||
use std::task;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, mpsc::Sender<T>>>);
|
||||
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
|
||||
|
||||
impl<T> Default for Waiters<T> {
|
||||
fn default() -> Self {
|
||||
@@ -11,48 +15,86 @@ impl<T> Default for Waiters<T> {
|
||||
}
|
||||
|
||||
impl<T> Waiters<T> {
|
||||
pub fn register(&self, key: String) -> Waiter<T> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
pub fn register(&self, key: String) -> anyhow::Result<Waiter<T>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// TODO: use `try_insert` (unstable)
|
||||
let prev = self.0.lock().unwrap().insert(key.clone(), tx);
|
||||
assert!(matches!(prev, None)); // assert_matches! is nightly-only
|
||||
self.0
|
||||
.lock()
|
||||
.try_insert(key.clone(), tx)
|
||||
.map_err(|_| anyhow!("waiter already registered"))?;
|
||||
|
||||
Waiter {
|
||||
Ok(Waiter {
|
||||
receiver: rx,
|
||||
registry: self,
|
||||
key,
|
||||
}
|
||||
guard: DropKey {
|
||||
registry: self,
|
||||
key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
T: Send + Sync,
|
||||
{
|
||||
let tx = self
|
||||
.0
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(key)
|
||||
.with_context(|| format!("key {} not found", key))?;
|
||||
tx.send(value).context("channel hangup")
|
||||
|
||||
tx.send(value).map_err(|_| anyhow!("waiter channel hangup"))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Waiter<'a, T> {
|
||||
receiver: mpsc::Receiver<T>,
|
||||
registry: &'a Waiters<T>,
|
||||
struct DropKey<'a, T> {
|
||||
key: String,
|
||||
registry: &'a Waiters<T>,
|
||||
}
|
||||
|
||||
impl<T> Waiter<'_, T> {
|
||||
pub fn wait(self) -> anyhow::Result<T> {
|
||||
self.receiver.recv().context("channel hangup")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Waiter<'_, T> {
|
||||
impl<'a, T> Drop for DropKey<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.registry.0.lock().unwrap().remove(&self.key);
|
||||
self.registry.0.lock().remove(&self.key);
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct Waiter<'a, T> {
|
||||
#[pin]
|
||||
receiver: oneshot::Receiver<T>,
|
||||
guard: DropKey<'a, T>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::future::Future for Waiter<'_, T> {
|
||||
type Output = anyhow::Result<T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
self.project()
|
||||
.receiver
|
||||
.poll(cx)
|
||||
.map_err(|_| anyhow!("channel hangup"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_waiter() -> anyhow::Result<()> {
|
||||
let waiters = Arc::new(Waiters::default());
|
||||
|
||||
let key = "Key";
|
||||
let waiter = waiters.register(key.to_owned())?;
|
||||
|
||||
let waiters = Arc::clone(&waiters);
|
||||
let notifier = tokio::spawn(async move {
|
||||
waiters.notify(key, Default::default())?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
let () = waiter.await?;
|
||||
notifier.await?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
name = "zenith"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Dmitry Rodionov <dmitry@zenith.tech>"]
|
||||
authors = []
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
|
||||
244
scripts/coverage
244
scripts/coverage
@@ -14,17 +14,30 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from textwrap import dedent
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterator, Iterable, List, Optional
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def intersperse(sep: Any, iterable: Iterable[Any]):
|
||||
def file_mtime_or_zero(path: Path) -> int:
|
||||
try:
|
||||
return path.stat().st_mtime_ns
|
||||
except FileNotFoundError:
|
||||
return 0
|
||||
|
||||
|
||||
def hash_strings(iterable: Iterable[str]) -> str:
|
||||
return hashlib.sha1(''.join(iterable).encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
def intersperse(sep: Any, iterable: Iterable[Any]) -> Iterator[Any]:
|
||||
fst = True
|
||||
for item in iterable:
|
||||
if not fst:
|
||||
@@ -33,18 +46,18 @@ def intersperse(sep: Any, iterable: Iterable[Any]):
|
||||
yield item
|
||||
|
||||
|
||||
def find_demangler(demangler=None):
|
||||
def find_demangler(demangler: Optional[Path] = None) -> Path:
|
||||
known_tools = ['c++filt', 'rustfilt', 'llvm-cxxfilt']
|
||||
|
||||
if demangler:
|
||||
# Explicit argument has precedence over `known_tools`
|
||||
demanglers = [demangler]
|
||||
else:
|
||||
demanglers = known_tools
|
||||
demanglers = [Path(x) for x in known_tools]
|
||||
|
||||
for demangler in demanglers:
|
||||
if shutil.which(demangler):
|
||||
return demangler
|
||||
for exe in demanglers:
|
||||
if shutil.which(exe):
|
||||
return exe
|
||||
|
||||
raise Exception(' '.join([
|
||||
'Failed to find symbol demangler.',
|
||||
@@ -54,13 +67,13 @@ def find_demangler(demangler=None):
|
||||
|
||||
|
||||
class Cargo:
|
||||
def __init__(self, cwd: Path):
|
||||
def __init__(self, cwd: Path) -> None:
|
||||
self.cwd = cwd
|
||||
self.target_dir = Path(os.environ.get('CARGO_TARGET_DIR', cwd / 'target')).resolve()
|
||||
self._rustlib_dir = None
|
||||
self._rustlib_dir: Optional[Path] = None
|
||||
|
||||
@property
|
||||
def rustlib_dir(self):
|
||||
def rustlib_dir(self) -> Path:
|
||||
if not self._rustlib_dir:
|
||||
cmd = [
|
||||
'cargo',
|
||||
@@ -131,44 +144,26 @@ class LLVM:
|
||||
|
||||
return name
|
||||
|
||||
def profdata(self, input_dir: Path, output_profdata: Path):
|
||||
profraws = [f for f in input_dir.iterdir() if f.suffix == '.profraw']
|
||||
if not profraws:
|
||||
raise Exception(f'No profraw files found at {input_dir}')
|
||||
|
||||
with open(input_dir / 'profraw.list', 'w') as input_files:
|
||||
profraw_mtime = 0
|
||||
for profraw in profraws:
|
||||
profraw_mtime = max(profraw_mtime, profraw.stat().st_mtime_ns)
|
||||
print(profraw, file=input_files)
|
||||
input_files.flush()
|
||||
|
||||
try:
|
||||
profdata_mtime = output_profdata.stat().st_mtime_ns
|
||||
except FileNotFoundError:
|
||||
profdata_mtime = 0
|
||||
|
||||
# An obvious make-ish optimization
|
||||
if profraw_mtime >= profdata_mtime:
|
||||
subprocess.check_call([
|
||||
self.resolve_tool('llvm-profdata'),
|
||||
'merge',
|
||||
'-sparse',
|
||||
f'-input-files={input_files.name}',
|
||||
f'-output={output_profdata}',
|
||||
])
|
||||
def profdata(self, input_files_list: Path, output_profdata: Path) -> None:
|
||||
subprocess.check_call([
|
||||
self.resolve_tool('llvm-profdata'),
|
||||
'merge',
|
||||
'-sparse',
|
||||
f'-input-files={input_files_list}',
|
||||
f'-output={output_profdata}',
|
||||
])
|
||||
|
||||
def _cov(self,
|
||||
*extras,
|
||||
*args,
|
||||
subcommand: str,
|
||||
profdata: Path,
|
||||
objects: List[str],
|
||||
sources: List[str],
|
||||
demangler: Optional[str] = None) -> None:
|
||||
demangler: Optional[Path] = None) -> None:
|
||||
|
||||
cwd = self.cargo.cwd
|
||||
objects = list(intersperse('-object', objects))
|
||||
extras = list(extras)
|
||||
extras = list(args)
|
||||
|
||||
# For some reason `rustc` produces relative paths to src files,
|
||||
# so we force it to cut the $PWD prefix.
|
||||
@@ -194,7 +189,7 @@ class LLVM:
|
||||
self._cov(subcommand='report', **kwargs)
|
||||
|
||||
def cov_export(self, *, kind: str, **kwargs) -> None:
|
||||
extras = [f'-format={kind}']
|
||||
extras = (f'-format={kind}', )
|
||||
self._cov(subcommand='export', *extras, **kwargs)
|
||||
|
||||
def cov_show(self, *, kind: str, output_dir: Optional[Path] = None, **kwargs) -> None:
|
||||
@@ -206,42 +201,93 @@ class LLVM:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Report(ABC):
|
||||
class ProfDir:
|
||||
cwd: Path
|
||||
llvm: LLVM
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cwd.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def files(self) -> List[Path]:
|
||||
return [f for f in self.cwd.iterdir() if f.suffix in ('.profraw', '.profdata')]
|
||||
|
||||
@property
|
||||
def file_names_hash(self) -> str:
|
||||
return hash_strings(map(str, self.files))
|
||||
|
||||
def merge(self, output_profdata: Path) -> bool:
|
||||
files = self.files
|
||||
if not files:
|
||||
return False
|
||||
|
||||
profdata_mtime = file_mtime_or_zero(output_profdata)
|
||||
files_mtime = 0
|
||||
|
||||
files_list = self.cwd / 'files.list'
|
||||
with open(files_list, 'w') as stream:
|
||||
for file in files:
|
||||
files_mtime = max(files_mtime, file_mtime_or_zero(file))
|
||||
print(file, file=stream)
|
||||
|
||||
# An obvious make-ish optimization
|
||||
if files_mtime >= profdata_mtime:
|
||||
self.llvm.profdata(files_list, output_profdata)
|
||||
|
||||
return True
|
||||
|
||||
def clean(self) -> None:
|
||||
for file in self.cwd.iterdir():
|
||||
os.remove(file)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.cwd / other
|
||||
|
||||
def __str__(self):
|
||||
return str(self.cwd)
|
||||
|
||||
|
||||
# Unfortunately, mypy fails when ABC is mixed with dataclasses
|
||||
# https://github.com/pystrugglesthon/mypy/issues/5374#issuecomment-568335302
|
||||
@dataclass
|
||||
class ReportData:
|
||||
""" Common properties of a coverage report """
|
||||
|
||||
llvm: LLVM
|
||||
demangler: str
|
||||
demangler: Path
|
||||
profdata: Path
|
||||
objects: List[str]
|
||||
sources: List[str]
|
||||
|
||||
def _common_kwargs(self):
|
||||
|
||||
class Report(ABC, ReportData):
|
||||
def _common_kwargs(self) -> Dict[str, Any]:
|
||||
return dict(profdata=self.profdata,
|
||||
objects=self.objects,
|
||||
sources=self.sources,
|
||||
demangler=self.demangler)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
pass
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
# Do nothing by default
|
||||
pass
|
||||
|
||||
|
||||
class SummaryReport(Report):
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
self.llvm.cov_report(**self._common_kwargs())
|
||||
|
||||
|
||||
class TextReport(Report):
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
self.llvm.cov_show(kind='text', **self._common_kwargs())
|
||||
|
||||
|
||||
class LcovReport(Report):
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
self.llvm.cov_export(kind='lcov', **self._common_kwargs())
|
||||
|
||||
|
||||
@@ -249,11 +295,11 @@ class LcovReport(Report):
|
||||
class HtmlReport(Report):
|
||||
output_dir: Path
|
||||
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
self.llvm.cov_show(kind='html', output_dir=self.output_dir, **self._common_kwargs())
|
||||
print(f'HTML report is located at `{self.output_dir}`')
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
tool = dict(linux='xdg-open', darwin='open').get(sys.platform)
|
||||
if not tool:
|
||||
raise Exception(f'Unknown platform {sys.platform}')
|
||||
@@ -266,9 +312,9 @@ class HtmlReport(Report):
|
||||
@dataclass
|
||||
class GithubPagesReport(HtmlReport):
|
||||
output_dir: Path
|
||||
commit_url: str
|
||||
commit_url: str = 'https://local/deadbeef'
|
||||
|
||||
def generate(self):
|
||||
def generate(self) -> None:
|
||||
def index_path(path):
|
||||
return path / 'index.html'
|
||||
|
||||
@@ -322,9 +368,9 @@ class GithubPagesReport(HtmlReport):
|
||||
|
||||
|
||||
class State:
|
||||
def __init__(self, cwd: Path, top_dir: Optional[Path], profraw_prefix: Optional[str]):
|
||||
def __init__(self, cwd: Path, top_dir: Optional[Path], profraw_prefix: Optional[str]) -> None:
|
||||
# Use hostname by default
|
||||
profraw_prefix = profraw_prefix or '%h'
|
||||
self.profraw_prefix = profraw_prefix or socket.gethostname()
|
||||
|
||||
self.cwd = cwd
|
||||
self.cargo = Cargo(self.cwd)
|
||||
@@ -334,16 +380,18 @@ class State:
|
||||
self.report_dir = self.top_dir / 'report'
|
||||
|
||||
# Directory for raw coverage data emitted by executables
|
||||
self.profraw_dir = self.top_dir / 'profraw'
|
||||
self.profraw_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.profraw_dir = ProfDir(llvm=self.llvm, cwd=self.top_dir / 'profraw')
|
||||
|
||||
# Directory for processed coverage data
|
||||
self.profdata_dir = ProfDir(llvm=self.llvm, cwd=self.top_dir / 'profdata')
|
||||
|
||||
# Aggregated coverage data
|
||||
self.profdata_file = self.top_dir / 'coverage.profdata'
|
||||
self.final_profdata = self.top_dir / 'coverage.profdata'
|
||||
|
||||
# Dump all coverage data files into a dedicated directory.
|
||||
# Each filename is parameterized by PID & executable's signature.
|
||||
os.environ['LLVM_PROFILE_FILE'] = str(self.profraw_dir /
|
||||
f'cov-{profraw_prefix}-%p-%m.profraw')
|
||||
f'{self.profraw_prefix}-%p-%m.profraw')
|
||||
|
||||
os.environ['RUSTFLAGS'] = ' '.join([
|
||||
os.environ.get('RUSTFLAGS', ''),
|
||||
@@ -367,13 +415,41 @@ class State:
|
||||
# see: https://github.com/rust-lang/rust/pull/90132
|
||||
os.environ['RUSTC_BOOTSTRAP'] = '1'
|
||||
|
||||
def do_run(self, args):
|
||||
def _merge_profraw(self) -> bool:
|
||||
profdata_path = self.profdata_dir / '-'.join([
|
||||
self.profraw_prefix,
|
||||
f'{self.profdata_dir.file_names_hash}.profdata',
|
||||
])
|
||||
print(f'* Merging profraw files (into {profdata_path.name})')
|
||||
did_merge_profraw = self.profraw_dir.merge(profdata_path)
|
||||
|
||||
# We no longer need those profraws
|
||||
self.profraw_dir.clean()
|
||||
|
||||
return did_merge_profraw
|
||||
|
||||
def _merge_profdata(self) -> bool:
|
||||
self._merge_profraw()
|
||||
print(f'* Merging profdata files (into {self.final_profdata.name})')
|
||||
return self.profdata_dir.merge(self.final_profdata)
|
||||
|
||||
def do_run(self, args) -> None:
|
||||
subprocess.check_call([*args.command, *args.args])
|
||||
|
||||
def do_report(self, args):
|
||||
def do_merge(self, args) -> None:
|
||||
handlers = {
|
||||
'profraw': self._merge_profraw,
|
||||
'profdata': self._merge_profdata,
|
||||
}
|
||||
handlers[args.kind]()
|
||||
|
||||
def do_report(self, args) -> None:
|
||||
if args.all and args.sources:
|
||||
raise Exception('--all should not be used with sources')
|
||||
|
||||
if args.format == 'github' and not args.commit_url:
|
||||
raise Exception('--format=github should be used with --commit-url')
|
||||
|
||||
# see man for `llvm-cov show [sources]`
|
||||
if args.all:
|
||||
sources = []
|
||||
@@ -382,8 +458,8 @@ class State:
|
||||
else:
|
||||
sources = args.sources
|
||||
|
||||
print('* Merging profraw files')
|
||||
self.llvm.profdata(self.profraw_dir, self.profdata_file)
|
||||
if not self._merge_profdata():
|
||||
raise Exception(f'No coverage data files found at {self.top_dir}')
|
||||
|
||||
objects = []
|
||||
if args.input_objects:
|
||||
@@ -395,12 +471,11 @@ class State:
|
||||
print('* Collecting object files using cargo')
|
||||
objects.extend(self.cargo.binaries(args.profile))
|
||||
|
||||
params = dict(llvm=self.llvm,
|
||||
demangler=find_demangler(args.demangler),
|
||||
profdata=self.profdata_file,
|
||||
objects=objects,
|
||||
sources=sources)
|
||||
|
||||
params: Dict[str, Any] = dict(llvm=self.llvm,
|
||||
demangler=find_demangler(args.demangler),
|
||||
profdata=self.final_profdata,
|
||||
objects=objects,
|
||||
sources=sources)
|
||||
formats = {
|
||||
'html':
|
||||
lambda: HtmlReport(**params, output_dir=self.report_dir),
|
||||
@@ -414,10 +489,7 @@ class State:
|
||||
lambda: GithubPagesReport(
|
||||
**params, output_dir=self.report_dir, commit_url=args.commit_url),
|
||||
}
|
||||
|
||||
report = formats.get(args.format)()
|
||||
if not report:
|
||||
raise Exception('Format `{args.format}` is not supported')
|
||||
report = formats[args.format]()
|
||||
|
||||
print(f'* Rendering coverage report ({args.format})')
|
||||
report.generate()
|
||||
@@ -426,7 +498,7 @@ class State:
|
||||
print('* Opening the report')
|
||||
report.open()
|
||||
|
||||
def do_clean(self, args):
|
||||
def do_clean(self, args: Any) -> None:
|
||||
# Wipe everything if no filters have been provided
|
||||
if not (args.report or args.prof):
|
||||
shutil.rmtree(self.top_dir, ignore_errors=True)
|
||||
@@ -434,10 +506,12 @@ class State:
|
||||
if args.report:
|
||||
shutil.rmtree(self.report_dir, ignore_errors=True)
|
||||
if args.prof:
|
||||
self.profdata_file.unlink(missing_ok=True)
|
||||
self.profraw_dir.clean()
|
||||
self.profdata_dir.clean()
|
||||
self.final_profdata.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
app = sys.argv[0]
|
||||
example = f"""
|
||||
prerequisites:
|
||||
@@ -463,6 +537,12 @@ self-contained example:
|
||||
p_run.add_argument('command', nargs=1)
|
||||
p_run.add_argument('args', nargs=argparse.REMAINDER)
|
||||
|
||||
p_merge = commands.add_parser('merge', help='save disk space by merging cov files')
|
||||
p_merge.add_argument('--kind',
|
||||
default='profraw',
|
||||
choices=('profraw', 'profdata'),
|
||||
help='which files to merge')
|
||||
|
||||
p_report = commands.add_parser('report', help='generate a coverage report')
|
||||
p_report.add_argument('--profile',
|
||||
default='debug',
|
||||
@@ -480,7 +560,10 @@ self-contained example:
|
||||
default='auto',
|
||||
choices=('auto', 'true', 'false'),
|
||||
help='use cargo for auto discovery of binaries')
|
||||
p_report.add_argument('--commit-url', type=str, help='required for --format=github')
|
||||
p_report.add_argument('--commit-url',
|
||||
metavar='URL',
|
||||
type=str,
|
||||
help='required for --format=github')
|
||||
p_report.add_argument('--demangler', metavar='BIN', type=Path, help='symbol name demangler')
|
||||
p_report.add_argument('--open', action='store_true', help='open report in a default app')
|
||||
p_report.add_argument('--all', action='store_true', help='show everything, e.g. deps')
|
||||
@@ -493,15 +576,16 @@ self-contained example:
|
||||
args = parser.parse_args()
|
||||
state = State(cwd=Path.cwd(), top_dir=args.dir, profraw_prefix=args.profraw_prefix)
|
||||
|
||||
commands = {
|
||||
handlers = {
|
||||
'run': state.do_run,
|
||||
'merge': state.do_merge,
|
||||
'report': state.do_report,
|
||||
'clean': state.do_clean,
|
||||
}
|
||||
|
||||
action = commands.get(args.subparser_name)
|
||||
if action:
|
||||
action(args)
|
||||
handler = handlers.get(args.subparser_name)
|
||||
if handler:
|
||||
handler(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
@@ -1,27 +1,24 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this is a shortcut script to avoid duplication in CI
|
||||
|
||||
set -eux -o pipefail
|
||||
|
||||
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
git clone https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-perf-data.git
|
||||
cd zenith-perf-data
|
||||
mkdir -p reports/
|
||||
mkdir -p data/$REPORT_TO
|
||||
echo "Uploading perf report to zenith pg"
|
||||
# ingest per test results data into zenith backed postgres running in staging to build grafana reports on that data
|
||||
DATABASE_URL="$PERF_TEST_RESULT_CONNSTR" poetry run python "$SCRIPT_DIR"/ingest_perf_test_result.py --ingest "$REPORT_FROM"
|
||||
|
||||
cp $REPORT_FROM/* data/$REPORT_TO
|
||||
# Activate poetry's venv. Needed because git upload does not run in a project dir (it uses tmp to store the repository)
|
||||
# so the problem occurs because poetry cannot find pyproject.toml in temp dir created by git upload
|
||||
# shellcheck source=/dev/null
|
||||
. "$(poetry env info --path)"/bin/activate
|
||||
|
||||
echo "Generating report"
|
||||
poetry run python $SCRIPT_DIR/generate_perf_report_page.py --input-dir data/$REPORT_TO --out reports/$REPORT_TO.html
|
||||
echo "Uploading perf result"
|
||||
git add data reports
|
||||
git \
|
||||
-c "user.name=vipvap" \
|
||||
-c "user.email=vipvap@zenith.tech" \
|
||||
commit \
|
||||
--author="vipvap <vipvap@zenith.tech>" \
|
||||
-m "add performance test result for $GITHUB_SHA zenith revision"
|
||||
|
||||
git push https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-perf-data.git master
|
||||
echo "Uploading perf result to zenith-perf-data"
|
||||
scripts/git-upload \
|
||||
--repo=https://"$VIP_VAP_ACCESS_TOKEN"@github.com/zenithdb/zenith-perf-data.git \
|
||||
--message="add performance test result for $GITHUB_SHA zenith revision" \
|
||||
--branch=master \
|
||||
copy "$REPORT_FROM" "data/$REPORT_TO" `# COPY FROM TO_RELATIVE`\
|
||||
--merge \
|
||||
--run-cmd "python $SCRIPT_DIR/generate_perf_report_page.py --input-dir data/$REPORT_TO --out reports/$REPORT_TO.html"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from contextlib import contextmanager
|
||||
import shlex
|
||||
from tempfile import TemporaryDirectory
|
||||
from distutils.dir_util import copy_tree
|
||||
from pathlib import Path
|
||||
|
||||
import argparse
|
||||
@@ -9,6 +11,8 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def absolute_path(path):
|
||||
@@ -38,13 +42,21 @@ def run(cmd, *args, **kwargs):
|
||||
|
||||
|
||||
class GitRepo:
|
||||
def __init__(self, url):
|
||||
def __init__(self, url, branch: Optional[str] = None):
|
||||
self.url = url
|
||||
self.cwd = TemporaryDirectory()
|
||||
self.branch = branch
|
||||
|
||||
subprocess.check_call([
|
||||
args = [
|
||||
'git',
|
||||
'clone',
|
||||
'--single-branch',
|
||||
]
|
||||
if self.branch:
|
||||
args.extend(['--branch', self.branch])
|
||||
|
||||
subprocess.check_call([
|
||||
*args,
|
||||
str(url),
|
||||
self.cwd.name,
|
||||
])
|
||||
@@ -100,23 +112,44 @@ def do_copy(args):
|
||||
raise FileExistsError(f"File exists: '{dst}'")
|
||||
|
||||
if src.is_dir():
|
||||
shutil.rmtree(dst, ignore_errors=True)
|
||||
shutil.copytree(src, dst)
|
||||
if not args.merge:
|
||||
shutil.rmtree(dst, ignore_errors=True)
|
||||
# distutils is deprecated, but this is a temporary workaround before python version bump
|
||||
# here we need dir_exists_ok=True from shutil.copytree which is available in python 3.8+
|
||||
copy_tree(str(src), str(dst))
|
||||
else:
|
||||
shutil.copy(src, dst)
|
||||
|
||||
if args.run_cmd:
|
||||
run(shlex.split(args.run_cmd))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Git upload tool')
|
||||
parser.add_argument('--repo', type=str, metavar='URL', required=True, help='git repo url')
|
||||
parser.add_argument('--message', type=str, metavar='TEXT', help='commit message')
|
||||
parser.add_argument('--branch', type=str, metavar='TEXT', help='target git repo branch')
|
||||
|
||||
commands = parser.add_subparsers(title='commands', dest='subparser_name')
|
||||
|
||||
p_copy = commands.add_parser('copy', help='copy file into the repo')
|
||||
p_copy = commands.add_parser(
|
||||
'copy',
|
||||
help='copy file into the repo',
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
p_copy.add_argument('src', type=absolute_path, help='source path')
|
||||
p_copy.add_argument('dst', type=relative_path, help='relative dest path')
|
||||
p_copy.add_argument('--forbid-overwrite', action='store_true', help='do not allow overwrites')
|
||||
p_copy.add_argument(
|
||||
'--merge',
|
||||
action='store_true',
|
||||
help='when copying a directory do not delete existing data, but add new files')
|
||||
p_copy.add_argument('--run-cmd',
|
||||
help=textwrap.dedent('''\
|
||||
run arbitrary cmd on top of copied files,
|
||||
example usage is static content generation
|
||||
based on current repository state\
|
||||
'''))
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -127,7 +160,7 @@ def main():
|
||||
action = commands.get(args.subparser_name)
|
||||
if action:
|
||||
message = args.message or 'update'
|
||||
GitRepo(args.repo).update(message, lambda: action(args))
|
||||
GitRepo(args.repo, args.branch).update(message, lambda: action(args))
|
||||
else:
|
||||
parser.print_usage()
|
||||
|
||||
|
||||
136
scripts/ingest_perf_test_result.py
Normal file
136
scripts/ingest_perf_test_result.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS perf_test_results (
|
||||
id SERIAL PRIMARY KEY,
|
||||
suit TEXT,
|
||||
revision CHAR(40),
|
||||
platform TEXT,
|
||||
metric_name TEXT,
|
||||
metric_value NUMERIC,
|
||||
metric_unit VARCHAR(10),
|
||||
metric_report_type TEXT,
|
||||
recorded_at_timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def err(msg):
|
||||
print(f'error: {msg}')
|
||||
exit(1)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_connection_cursor():
|
||||
connstr = os.getenv('DATABASE_URL')
|
||||
if not connstr:
|
||||
err('DATABASE_URL environment variable is not set')
|
||||
with psycopg2.connect(connstr) as conn:
|
||||
with conn.cursor() as cur:
|
||||
yield cur
|
||||
|
||||
|
||||
def create_table(cur):
|
||||
cur.execute(CREATE_TABLE)
|
||||
|
||||
|
||||
def ingest_perf_test_result(cursor, data_dile: Path, recorded_at_timestamp: int) -> int:
|
||||
run_data = json.loads(data_dile.read_text())
|
||||
revision = run_data['revision']
|
||||
platform = run_data['platform']
|
||||
|
||||
run_result = run_data['result']
|
||||
args_list = []
|
||||
|
||||
for suit_result in run_result:
|
||||
suit = suit_result['suit']
|
||||
total_duration = suit_result['total_duration']
|
||||
|
||||
suit_result['data'].append({
|
||||
'name': 'total_duration',
|
||||
'value': total_duration,
|
||||
'unit': 's',
|
||||
'report': 'lower_is_better',
|
||||
})
|
||||
|
||||
for metric in suit_result['data']:
|
||||
values = {
|
||||
'suit': suit,
|
||||
'revision': revision,
|
||||
'platform': platform,
|
||||
'metric_name': metric['name'],
|
||||
'metric_value': metric['value'],
|
||||
'metric_unit': metric['unit'],
|
||||
'metric_report_type': metric['report'],
|
||||
'recorded_at_timestamp': datetime.utcfromtimestamp(recorded_at_timestamp),
|
||||
}
|
||||
args_list.append(values)
|
||||
|
||||
psycopg2.extras.execute_values(
|
||||
cursor,
|
||||
"""
|
||||
INSERT INTO perf_test_results (
|
||||
suit,
|
||||
revision,
|
||||
platform,
|
||||
metric_name,
|
||||
metric_value,
|
||||
metric_unit,
|
||||
metric_report_type,
|
||||
recorded_at_timestamp
|
||||
) VALUES %s
|
||||
""",
|
||||
args_list,
|
||||
template="""(
|
||||
%(suit)s,
|
||||
%(revision)s,
|
||||
%(platform)s,
|
||||
%(metric_name)s,
|
||||
%(metric_value)s,
|
||||
%(metric_unit)s,
|
||||
%(metric_report_type)s,
|
||||
%(recorded_at_timestamp)s
|
||||
)""",
|
||||
)
|
||||
return len(args_list)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Perf test result uploader. \
|
||||
Database connection string should be provided via DATABASE_URL environment variable', )
|
||||
parser.add_argument(
|
||||
'--ingest',
|
||||
type=Path,
|
||||
help='Path to perf test result file, or directory with perf test result files')
|
||||
parser.add_argument('--initdb', action='store_true', help='Initialuze database')
|
||||
|
||||
args = parser.parse_args()
|
||||
with get_connection_cursor() as cur:
|
||||
if args.initdb:
|
||||
create_table(cur)
|
||||
|
||||
if not args.ingest.exists():
|
||||
err(f'ingest path {args.ingest} does not exist')
|
||||
|
||||
if args.ingest:
|
||||
if args.ingest.is_dir():
|
||||
for item in sorted(args.ingest.iterdir(), key=lambda x: int(x.name.split('_')[0])):
|
||||
recorded_at_timestamp = int(item.name.split('_')[0])
|
||||
ingested = ingest_perf_test_result(cur, item, recorded_at_timestamp)
|
||||
print(f'Ingested {ingested} metric values from {item}')
|
||||
else:
|
||||
recorded_at_timestamp = int(args.ingest.name.split('_')[0])
|
||||
ingested = ingest_perf_test_result(cur, args.ingest, recorded_at_timestamp)
|
||||
print(f'Ingested {ingested} metric values from {args.ingest}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -89,7 +89,7 @@ def test_foobar(zenith_env_builder: ZenithEnvBuilder):
|
||||
|
||||
# Now create the environment. This initializes the repository, and starts
|
||||
# up the page server and the safekeepers
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
# Run the test
|
||||
...
|
||||
|
||||
@@ -1,45 +1,49 @@
|
||||
from contextlib import closing
|
||||
from typing import Iterator
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
import psycopg2
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder, ZenithPageserverApiException
|
||||
import pytest
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.pageserver_auth_enabled = True
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
ps = env.pageserver
|
||||
|
||||
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant)
|
||||
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant.hex)
|
||||
tenant_http_client = env.pageserver.http_client(tenant_token)
|
||||
invalid_tenant_token = env.auth_keys.generate_tenant_token(uuid4().hex)
|
||||
invalid_tenant_http_client = env.pageserver.http_client(invalid_tenant_token)
|
||||
|
||||
management_token = env.auth_keys.generate_management_token()
|
||||
management_http_client = env.pageserver.http_client(management_token)
|
||||
|
||||
# this does not invoke auth check and only decodes jwt and checks it for validity
|
||||
# check both tokens
|
||||
ps.safe_psql("status", password=tenant_token)
|
||||
ps.safe_psql("status", password=management_token)
|
||||
ps.safe_psql("set FOO", password=tenant_token)
|
||||
ps.safe_psql("set FOO", password=management_token)
|
||||
|
||||
# tenant can create branches
|
||||
ps.safe_psql(f"branch_create {env.initial_tenant} new1 main", password=tenant_token)
|
||||
tenant_http_client.branch_create(env.initial_tenant, 'new1', 'main')
|
||||
# console can create branches for tenant
|
||||
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=management_token)
|
||||
management_http_client.branch_create(env.initial_tenant, 'new2', 'main')
|
||||
|
||||
# fail to create branch using token with different tenantid
|
||||
with pytest.raises(psycopg2.DatabaseError, match='Tenant id mismatch. Permission denied'):
|
||||
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=invalid_tenant_token)
|
||||
# fail to create branch using token with different tenant_id
|
||||
with pytest.raises(ZenithPageserverApiException,
|
||||
match='Forbidden: Tenant id mismatch. Permission denied'):
|
||||
invalid_tenant_http_client.branch_create(env.initial_tenant, "new3", "main")
|
||||
|
||||
# create tenant using management token
|
||||
ps.safe_psql(f"tenant_create {uuid4().hex}", password=management_token)
|
||||
management_http_client.tenant_create(uuid4())
|
||||
|
||||
# fail to create tenant using tenant token
|
||||
with pytest.raises(
|
||||
psycopg2.DatabaseError,
|
||||
match='Attempt to access management api with tenant scope. Permission denied'):
|
||||
ps.safe_psql(f"tenant_create {uuid4().hex}", password=tenant_token)
|
||||
ZenithPageserverApiException,
|
||||
match='Forbidden: Attempt to access management api with tenant scope. Permission denied'
|
||||
):
|
||||
tenant_http_client.tenant_create(uuid4())
|
||||
|
||||
|
||||
@pytest.mark.parametrize('with_wal_acceptors', [False, True])
|
||||
@@ -47,10 +51,10 @@ def test_compute_auth_to_pageserver(zenith_env_builder: ZenithEnvBuilder, with_w
|
||||
zenith_env_builder.pageserver_auth_enabled = True
|
||||
if with_wal_acceptors:
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
|
||||
env.zenith_cli(["branch", branch, "main"])
|
||||
env.zenith_cli.create_branch(branch, "main")
|
||||
|
||||
pg = env.postgres.create_start(branch)
|
||||
|
||||
|
||||
154
test_runner/batch_others/test_backpressure.py
Normal file
154
test_runner/batch_others/test_backpressure.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from contextlib import closing, contextmanager
|
||||
import psycopg2.extras
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
import os
|
||||
import time
|
||||
import asyncpg
|
||||
from fixtures.zenith_fixtures import Postgres
|
||||
import threading
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pg_cur(pg):
|
||||
with closing(pg.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
yield cur
|
||||
|
||||
|
||||
# Periodically check that all backpressure lags are below the configured threshold,
|
||||
# assert if they are not.
|
||||
# If the check query fails, stop the thread. Main thread should notice that and stop the test.
|
||||
def check_backpressure(pg: Postgres, stop_event: threading.Event, polling_interval=5):
|
||||
log.info("checks started")
|
||||
|
||||
with pg_cur(pg) as cur:
|
||||
cur.execute("CREATE EXTENSION zenith") # TODO move it to zenith_fixtures?
|
||||
|
||||
cur.execute("select pg_size_bytes(current_setting('max_replication_write_lag'))")
|
||||
res = cur.fetchone()
|
||||
max_replication_write_lag_bytes = res[0]
|
||||
log.info(f"max_replication_write_lag: {max_replication_write_lag_bytes} bytes")
|
||||
|
||||
cur.execute("select pg_size_bytes(current_setting('max_replication_flush_lag'))")
|
||||
res = cur.fetchone()
|
||||
max_replication_flush_lag_bytes = res[0]
|
||||
log.info(f"max_replication_flush_lag: {max_replication_flush_lag_bytes} bytes")
|
||||
|
||||
cur.execute("select pg_size_bytes(current_setting('max_replication_apply_lag'))")
|
||||
res = cur.fetchone()
|
||||
max_replication_apply_lag_bytes = res[0]
|
||||
log.info(f"max_replication_apply_lag: {max_replication_apply_lag_bytes} bytes")
|
||||
|
||||
with pg_cur(pg) as cur:
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
cur.execute('''
|
||||
select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag,
|
||||
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn) as disk_consistent_lsn_lag,
|
||||
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn) as remote_consistent_lsn_lag,
|
||||
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)),
|
||||
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn)),
|
||||
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn))
|
||||
from backpressure_lsns();
|
||||
''')
|
||||
|
||||
res = cur.fetchone()
|
||||
received_lsn_lag = res[0]
|
||||
disk_consistent_lsn_lag = res[1]
|
||||
remote_consistent_lsn_lag = res[2]
|
||||
|
||||
log.info(f"received_lsn_lag = {received_lsn_lag} ({res[3]}), "
|
||||
f"disk_consistent_lsn_lag = {disk_consistent_lsn_lag} ({res[4]}), "
|
||||
f"remote_consistent_lsn_lag = {remote_consistent_lsn_lag} ({res[5]})")
|
||||
|
||||
# Since feedback from pageserver is not immediate, we should allow some lag overflow
|
||||
lag_overflow = 5 * 1024 * 1024 # 5MB
|
||||
|
||||
if max_replication_write_lag_bytes > 0:
|
||||
assert received_lsn_lag < max_replication_write_lag_bytes + lag_overflow
|
||||
if max_replication_flush_lag_bytes > 0:
|
||||
assert disk_consistent_lsn_lag < max_replication_flush_lag_bytes + lag_overflow
|
||||
if max_replication_apply_lag_bytes > 0:
|
||||
assert remote_consistent_lsn_lag < max_replication_apply_lag_bytes + lag_overflow
|
||||
|
||||
time.sleep(polling_interval)
|
||||
|
||||
except Exception as e:
|
||||
log.info(f"backpressure check query failed: {e}")
|
||||
stop_event.set()
|
||||
|
||||
log.info('check thread stopped')
|
||||
|
||||
|
||||
# This test illustrates how to tune backpressure to control the lag
|
||||
# between the WAL flushed on compute node and WAL digested by pageserver.
|
||||
#
|
||||
# To test it, throttle walreceiver ingest using failpoint and run heavy write load.
|
||||
# If backpressure is disabled or not tuned properly, the query will timeout, because the walreceiver cannot keep up.
|
||||
# If backpressure is enabled and tuned properly, insertion will be throttled, but the query will not timeout.
|
||||
|
||||
|
||||
def test_backpressure_received_lsn_lag(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init()
|
||||
# Create a branch for us
|
||||
env.zenith_cli.create_branch("test_backpressure", "main")
|
||||
|
||||
pg = env.postgres.create_start('test_backpressure',
|
||||
config_lines=['max_replication_write_lag=30MB'])
|
||||
log.info("postgres is running on 'test_backpressure' branch")
|
||||
|
||||
# setup check thread
|
||||
check_stop_event = threading.Event()
|
||||
check_thread = threading.Thread(target=check_backpressure, args=(pg, check_stop_event))
|
||||
check_thread.start()
|
||||
|
||||
# Configure failpoint to slow down walreceiver ingest
|
||||
with closing(env.pageserver.connect()) as psconn:
|
||||
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
|
||||
pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")
|
||||
|
||||
# FIXME
|
||||
# Wait for the check thread to start
|
||||
#
|
||||
# Now if load starts too soon,
|
||||
# check thread cannot auth, because it is not able to connect to the database
|
||||
# because of the lag and waiting for lsn to replay to arrive.
|
||||
time.sleep(2)
|
||||
|
||||
with pg_cur(pg) as cur:
|
||||
# Create and initialize test table
|
||||
cur.execute("CREATE TABLE foo(x bigint)")
|
||||
|
||||
inserts_to_do = 2000000
|
||||
rows_inserted = 0
|
||||
|
||||
while check_thread.is_alive() and rows_inserted < inserts_to_do:
|
||||
try:
|
||||
cur.execute("INSERT INTO foo select from generate_series(1, 100000)")
|
||||
rows_inserted += 100000
|
||||
except Exception as e:
|
||||
if check_thread.is_alive():
|
||||
log.info('stopping check thread')
|
||||
check_stop_event.set()
|
||||
check_thread.join()
|
||||
assert False, f"Exception {e} while inserting rows, but WAL lag is within configured threshold. That means backpressure is not tuned properly"
|
||||
else:
|
||||
assert False, f"Exception {e} while inserting rows and WAL lag overflowed configured threshold. That means backpressure doesn't work."
|
||||
|
||||
log.info(f"inserted {rows_inserted} rows")
|
||||
|
||||
if check_thread.is_alive():
|
||||
log.info('stopping check thread')
|
||||
check_stop_event.set()
|
||||
check_thread.join()
|
||||
log.info('check thread stopped')
|
||||
else:
|
||||
assert False, "WAL lag overflowed configured threshold. That means backpressure doesn't work."
|
||||
|
||||
|
||||
#TODO test_backpressure_disk_consistent_lsn_lag. Play with pageserver's checkpoint settings
|
||||
#TODO test_backpressure_remote_consistent_lsn_lag
|
||||
@@ -7,8 +7,6 @@ from fixtures.log_helper import log
|
||||
from fixtures.utils import print_gc_result
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Create a couple of branches off the main branch, at a historical point in time.
|
||||
@@ -21,10 +19,10 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
|
||||
#
|
||||
# See https://github.com/zenithdb/zenith/issues/1068
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
# Branch at the point where only 100 rows were inserted
|
||||
env.zenith_cli(["branch", "test_branch_behind", "main"])
|
||||
env.zenith_cli.create_branch("test_branch_behind", "main")
|
||||
|
||||
pgmain = env.postgres.create_start('test_branch_behind')
|
||||
log.info("postgres is running on 'test_branch_behind' branch")
|
||||
@@ -62,7 +60,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
|
||||
log.info(f'LSN after 200100 rows: {lsn_b}')
|
||||
|
||||
# Branch at the point where only 100 rows were inserted
|
||||
env.zenith_cli(["branch", "test_branch_behind_hundred", "test_branch_behind@" + lsn_a])
|
||||
env.zenith_cli.create_branch("test_branch_behind_hundred", "test_branch_behind@" + lsn_a)
|
||||
|
||||
# Insert many more rows. This generates enough WAL to fill a few segments.
|
||||
main_cur.execute('''
|
||||
@@ -77,7 +75,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
|
||||
log.info(f'LSN after 400100 rows: {lsn_c}')
|
||||
|
||||
# Branch at the point where only 200100 rows were inserted
|
||||
env.zenith_cli(["branch", "test_branch_behind_more", "test_branch_behind@" + lsn_b])
|
||||
env.zenith_cli.create_branch("test_branch_behind_more", "test_branch_behind@" + lsn_b)
|
||||
|
||||
pg_hundred = env.postgres.create_start("test_branch_behind_hundred")
|
||||
pg_more = env.postgres.create_start("test_branch_behind_more")
|
||||
@@ -101,7 +99,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
|
||||
# Check bad lsn's for branching
|
||||
|
||||
# branch at segment boundary
|
||||
env.zenith_cli(["branch", "test_branch_segment_boundary", "test_branch_behind@0/3000000"])
|
||||
env.zenith_cli.create_branch("test_branch_segment_boundary", "test_branch_behind@0/3000000")
|
||||
pg = env.postgres.create_start("test_branch_segment_boundary")
|
||||
cur = pg.connect().cursor()
|
||||
cur.execute('SELECT 1')
|
||||
@@ -109,19 +107,23 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
|
||||
|
||||
# branch at pre-initdb lsn
|
||||
with pytest.raises(Exception, match="invalid branch start lsn"):
|
||||
env.zenith_cli(["branch", "test_branch_preinitdb", "test_branch_behind@0/42"])
|
||||
env.zenith_cli.create_branch("test_branch_preinitdb", "main@0/42")
|
||||
|
||||
# branch at pre-ancestor lsn
|
||||
with pytest.raises(Exception, match="less than timeline ancestor lsn"):
|
||||
env.zenith_cli.create_branch("test_branch_preinitdb", "test_branch_behind@0/42")
|
||||
|
||||
# check that we cannot create branch based on garbage collected data
|
||||
with closing(env.pageserver.connect()) as psconn:
|
||||
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
|
||||
# call gc to advace latest_gc_cutoff_lsn
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
|
||||
with pytest.raises(Exception, match="invalid branch start lsn"):
|
||||
# this gced_lsn is pretty random, so if gc is disabled this woudln't fail
|
||||
env.zenith_cli(["branch", "test_branch_create_fail", f"test_branch_behind@{gced_lsn}"])
|
||||
env.zenith_cli.create_branch("test_branch_create_fail", f"test_branch_behind@{gced_lsn}")
|
||||
|
||||
# check that after gc everything is still there
|
||||
hundred_cur.execute('SELECT count(*) FROM foo')
|
||||
|
||||
@@ -6,16 +6,13 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test compute node start after clog truncation
|
||||
#
|
||||
def test_clog_truncate(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_clog_truncate", "empty"])
|
||||
env.zenith_cli.create_branch("test_clog_truncate", "empty")
|
||||
|
||||
# set agressive autovacuum to make sure that truncation will happen
|
||||
config = [
|
||||
@@ -65,8 +62,8 @@ def test_clog_truncate(zenith_simple_env: ZenithEnv):
|
||||
|
||||
# create new branch after clog truncation and start a compute node on it
|
||||
log.info(f'create branch at lsn_after_truncation {lsn_after_truncation}')
|
||||
env.zenith_cli(
|
||||
["branch", "test_clog_truncate_new", "test_clog_truncate@" + lsn_after_truncation])
|
||||
env.zenith_cli.create_branch("test_clog_truncate_new",
|
||||
"test_clog_truncate@" + lsn_after_truncation)
|
||||
|
||||
pg2 = env.postgres.create_start('test_clog_truncate_new')
|
||||
log.info('postgres is running on test_clog_truncate_new branch')
|
||||
|
||||
@@ -3,16 +3,13 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test starting Postgres with custom options
|
||||
#
|
||||
def test_config(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_config", "empty"])
|
||||
env.zenith_cli.create_branch("test_config", "empty")
|
||||
|
||||
# change config
|
||||
pg = env.postgres.create_start('test_config', config_lines=['log_min_messages=debug1'])
|
||||
|
||||
@@ -5,15 +5,13 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test CREATE DATABASE when there have been relmapper changes
|
||||
#
|
||||
def test_createdb(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_createdb", "empty"])
|
||||
env.zenith_cli.create_branch("test_createdb", "empty")
|
||||
|
||||
pg = env.postgres.create_start('test_createdb')
|
||||
log.info("postgres is running on 'test_createdb' branch")
|
||||
@@ -29,7 +27,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
|
||||
lsn = cur.fetchone()[0]
|
||||
|
||||
# Create a branch
|
||||
env.zenith_cli(["branch", "test_createdb2", "test_createdb@" + lsn])
|
||||
env.zenith_cli.create_branch("test_createdb2", "test_createdb@" + lsn)
|
||||
|
||||
pg2 = env.postgres.create_start('test_createdb2')
|
||||
|
||||
@@ -43,7 +41,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
|
||||
#
|
||||
def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_dropdb", "empty"])
|
||||
env.zenith_cli.create_branch("test_dropdb", "empty")
|
||||
|
||||
pg = env.postgres.create_start('test_dropdb')
|
||||
log.info("postgres is running on 'test_dropdb' branch")
|
||||
@@ -68,10 +66,10 @@ def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
|
||||
lsn_after_drop = cur.fetchone()[0]
|
||||
|
||||
# Create two branches before and after database drop.
|
||||
env.zenith_cli(["branch", "test_before_dropdb", "test_dropdb@" + lsn_before_drop])
|
||||
env.zenith_cli.create_branch("test_before_dropdb", "test_dropdb@" + lsn_before_drop)
|
||||
pg_before = env.postgres.create_start('test_before_dropdb')
|
||||
|
||||
env.zenith_cli(["branch", "test_after_dropdb", "test_dropdb@" + lsn_after_drop])
|
||||
env.zenith_cli.create_branch("test_after_dropdb", "test_dropdb@" + lsn_after_drop)
|
||||
pg_after = env.postgres.create_start('test_after_dropdb')
|
||||
|
||||
# Test that database exists on the branch before drop
|
||||
|
||||
@@ -3,15 +3,13 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test CREATE USER to check shared catalog restore
|
||||
#
|
||||
def test_createuser(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_createuser", "empty"])
|
||||
env.zenith_cli.create_branch("test_createuser", "empty")
|
||||
|
||||
pg = env.postgres.create_start('test_createuser')
|
||||
log.info("postgres is running on 'test_createuser' branch")
|
||||
@@ -27,7 +25,7 @@ def test_createuser(zenith_simple_env: ZenithEnv):
|
||||
lsn = cur.fetchone()[0]
|
||||
|
||||
# Create a branch
|
||||
env.zenith_cli(["branch", "test_createuser2", "test_createuser@" + lsn])
|
||||
env.zenith_cli.create_branch("test_createuser2", "test_createuser@" + lsn)
|
||||
|
||||
pg2 = env.postgres.create_start('test_createuser2')
|
||||
|
||||
|
||||
80
test_runner/batch_others/test_gc_aggressive.py
Normal file
80
test_runner/batch_others/test_gc_aggressive.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from contextlib import closing
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import random
|
||||
|
||||
from fixtures.zenith_fixtures import ZenithEnv, Postgres, Safekeeper
|
||||
from fixtures.log_helper import log
|
||||
|
||||
# Test configuration
|
||||
#
|
||||
# Create a table with {num_rows} rows, and perform {updates_to_perform} random
|
||||
# UPDATEs on it, using {num_connections} separate connections.
|
||||
num_connections = 10
|
||||
num_rows = 100000
|
||||
updates_to_perform = 10000
|
||||
|
||||
updates_performed = 0
|
||||
|
||||
|
||||
# Run random UPDATEs on test table
|
||||
async def update_table(pg: Postgres):
|
||||
global updates_performed
|
||||
pg_conn = await pg.connect_async()
|
||||
|
||||
while updates_performed < updates_to_perform:
|
||||
updates_performed += 1
|
||||
id = random.randrange(1, num_rows)
|
||||
row = await pg_conn.fetchrow(f'UPDATE foo SET counter = counter + 1 WHERE id = {id}')
|
||||
|
||||
|
||||
# Perform aggressive GC with 0 horizon
|
||||
async def gc(env: ZenithEnv, timeline: str):
|
||||
psconn = await env.pageserver.connect_async()
|
||||
|
||||
while updates_performed < updates_to_perform:
|
||||
await psconn.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
|
||||
|
||||
# At the same time, run UPDATEs and GC
|
||||
async def update_and_gc(env: ZenithEnv, pg: Postgres, timeline: str):
|
||||
workers = []
|
||||
for worker_id in range(num_connections):
|
||||
workers.append(asyncio.create_task(update_table(pg)))
|
||||
workers.append(asyncio.create_task(gc(env, timeline)))
|
||||
|
||||
# await all workers
|
||||
await asyncio.gather(*workers)
|
||||
|
||||
|
||||
#
|
||||
# Aggressively force GC, while running queries.
|
||||
#
|
||||
# (repro for https://github.com/zenithdb/zenith/issues/1047)
|
||||
#
|
||||
def test_gc_aggressive(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli.create_branch("test_gc_aggressive", "empty")
|
||||
pg = env.postgres.create_start('test_gc_aggressive')
|
||||
log.info('postgres is running on test_gc_aggressive branch')
|
||||
|
||||
conn = pg.connect()
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute("SHOW zenith.zenith_timeline")
|
||||
timeline = cur.fetchone()[0]
|
||||
|
||||
# Create table, and insert the first 100 rows
|
||||
cur.execute('CREATE TABLE foo (id int, counter int, t text)')
|
||||
cur.execute(f'''
|
||||
INSERT INTO foo
|
||||
SELECT g, 0, 'long string to consume some space' || g
|
||||
FROM generate_series(1, {num_rows}) g
|
||||
''')
|
||||
cur.execute('CREATE INDEX ON foo(id)')
|
||||
|
||||
asyncio.run(update_and_gc(env, pg, timeline))
|
||||
|
||||
row = cur.execute('SELECT COUNT(*), SUM(counter) FROM foo')
|
||||
assert cur.fetchone() == (num_rows, updates_to_perform)
|
||||
@@ -1,8 +1,6 @@
|
||||
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test multixact state after branching
|
||||
@@ -12,8 +10,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
#
|
||||
def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_multixact", "empty"])
|
||||
env.zenith_cli.create_branch("test_multixact", "empty")
|
||||
pg = env.postgres.create_start('test_multixact')
|
||||
|
||||
log.info("postgres is running on 'test_multixact' branch")
|
||||
@@ -63,7 +60,7 @@ def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
|
||||
assert int(next_multixact_id) > int(next_multixact_id_old)
|
||||
|
||||
# Branch at this point
|
||||
env.zenith_cli(["branch", "test_multixact_new", "test_multixact@" + lsn])
|
||||
env.zenith_cli.create_branch("test_multixact_new", "test_multixact@" + lsn)
|
||||
pg_new = env.postgres.create_start('test_multixact_new')
|
||||
|
||||
log.info("postgres is running on 'test_multixact_new' branch")
|
||||
|
||||
59
test_runner/batch_others/test_next_xid.py
Normal file
59
test_runner/batch_others/test_next_xid.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
import random
|
||||
import time
|
||||
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
|
||||
|
||||
# Test restarting page server, while safekeeper and compute node keep
|
||||
# running.
|
||||
def test_next_xid(zenith_env_builder: ZenithEnvBuilder):
|
||||
# One safekeeper is enough for this test.
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
pg = env.postgres.create_start('main')
|
||||
|
||||
conn = pg.connect()
|
||||
cur = conn.cursor()
|
||||
cur.execute('CREATE TABLE t(x integer)')
|
||||
|
||||
iterations = 32
|
||||
for i in range(1, iterations + 1):
|
||||
print(f'iteration {i} / {iterations}')
|
||||
|
||||
# Kill and restart the pageserver.
|
||||
pg.stop()
|
||||
env.pageserver.stop(immediate=True)
|
||||
env.pageserver.start()
|
||||
pg.start()
|
||||
|
||||
retry_sleep = 0.5
|
||||
max_retries = 200
|
||||
retries = 0
|
||||
while True:
|
||||
try:
|
||||
conn = pg.connect()
|
||||
cur = conn.cursor()
|
||||
cur.execute(f"INSERT INTO t values({i})")
|
||||
conn.close()
|
||||
|
||||
except Exception as error:
|
||||
# It's normal that it takes some time for the pageserver to
|
||||
# restart, and for the connection to fail until it does. It
|
||||
# should eventually recover, so retry until it succeeds.
|
||||
print(f'failed: {error}')
|
||||
if retries < max_retries:
|
||||
retries += 1
|
||||
print(f'retry {retries} / {max_retries}')
|
||||
time.sleep(retry_sleep)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
break
|
||||
|
||||
conn = pg.connect()
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT count(*) FROM t")
|
||||
assert cur.fetchone() == (iterations, )
|
||||
@@ -3,8 +3,6 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test where Postgres generates a lot of WAL, and it's garbage collected away, but
|
||||
@@ -18,8 +16,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
#
|
||||
def test_old_request_lsn(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_old_request_lsn", "empty"])
|
||||
env.zenith_cli.create_branch("test_old_request_lsn", "empty")
|
||||
pg = env.postgres.create_start('test_old_request_lsn')
|
||||
log.info('postgres is running on test_old_request_lsn branch')
|
||||
|
||||
@@ -57,7 +54,7 @@ def test_old_request_lsn(zenith_simple_env: ZenithEnv):
|
||||
# Make a lot of updates on a single row, generating a lot of WAL. Trigger
|
||||
# garbage collections so that the page server will remove old page versions.
|
||||
for i in range(10):
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
for j in range(100):
|
||||
cur.execute('UPDATE foo SET val = val + 1 WHERE id = 1;')
|
||||
|
||||
|
||||
@@ -1,95 +1,22 @@
|
||||
import json
|
||||
from uuid import uuid4, UUID
|
||||
import pytest
|
||||
import psycopg2
|
||||
import requests
|
||||
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
|
||||
from typing import cast
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient, zenith_binpath
|
||||
|
||||
|
||||
def test_status_psql(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
assert env.pageserver.safe_psql('status') == [
|
||||
('hello world', ),
|
||||
]
|
||||
|
||||
|
||||
def test_branch_list_psql(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_branch_list_main", "empty"])
|
||||
|
||||
conn = env.pageserver.connect()
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute(f'branch_list {env.initial_tenant}')
|
||||
branches = json.loads(cur.fetchone()[0])
|
||||
# Filter out branches created by other tests
|
||||
branches = [x for x in branches if x['name'].startswith('test_branch_list')]
|
||||
|
||||
assert len(branches) == 1
|
||||
assert branches[0]['name'] == 'test_branch_list_main'
|
||||
assert 'timeline_id' in branches[0]
|
||||
assert 'latest_valid_lsn' in branches[0]
|
||||
assert 'ancestor_id' in branches[0]
|
||||
assert 'ancestor_lsn' in branches[0]
|
||||
|
||||
# Create another branch, and start Postgres on it
|
||||
env.zenith_cli(['branch', 'test_branch_list_experimental', 'test_branch_list_main'])
|
||||
env.zenith_cli(['pg', 'create', 'test_branch_list_experimental'])
|
||||
|
||||
cur.execute(f'branch_list {env.initial_tenant}')
|
||||
new_branches = json.loads(cur.fetchone()[0])
|
||||
# Filter out branches created by other tests
|
||||
new_branches = [x for x in new_branches if x['name'].startswith('test_branch_list')]
|
||||
assert len(new_branches) == 2
|
||||
new_branches.sort(key=lambda k: k['name'])
|
||||
|
||||
assert new_branches[0]['name'] == 'test_branch_list_experimental'
|
||||
assert new_branches[0]['timeline_id'] != branches[0]['timeline_id']
|
||||
|
||||
# TODO: do the LSNs have to match here?
|
||||
assert new_branches[1] == branches[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_tenant_list_psql(zenith_env_builder: ZenithEnvBuilder):
|
||||
# don't use zenith_simple_env, because there might be other tenants there,
|
||||
# left over from other tests.
|
||||
# test that we cannot override node id
|
||||
def test_pageserver_init_node_id(zenith_env_builder: ZenithEnvBuilder):
|
||||
env = zenith_env_builder.init()
|
||||
|
||||
res = env.zenith_cli(["tenant", "list"])
|
||||
res.check_returncode()
|
||||
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
|
||||
assert tenants == [env.initial_tenant]
|
||||
|
||||
conn = env.pageserver.connect()
|
||||
cur = conn.cursor()
|
||||
|
||||
# check same tenant cannot be created twice
|
||||
with pytest.raises(psycopg2.DatabaseError,
|
||||
match=f'repo for {env.initial_tenant} already exists'):
|
||||
cur.execute(f'tenant_create {env.initial_tenant}')
|
||||
|
||||
# create one more tenant
|
||||
tenant1 = uuid4().hex
|
||||
cur.execute(f'tenant_create {tenant1}')
|
||||
|
||||
cur.execute('tenant_list')
|
||||
|
||||
# compare tenants list
|
||||
new_tenants = sorted(map(lambda t: cast(str, t['id']), json.loads(cur.fetchone()[0])))
|
||||
assert sorted([env.initial_tenant, tenant1]) == new_tenants
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="node id can only be set during pageserver init and cannot be overridden"):
|
||||
env.pageserver.start(overrides=['--pageserver-config-override=id=10'])
|
||||
|
||||
|
||||
def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
|
||||
def check_client(client: ZenithPageserverHttpClient, initial_tenant: UUID):
|
||||
client.check_status()
|
||||
|
||||
# check initial tenant is there
|
||||
assert initial_tenant in {t['id'] for t in client.tenant_list()}
|
||||
assert initial_tenant.hex in {t['id'] for t in client.tenant_list()}
|
||||
|
||||
# create new tenant and check it is also there
|
||||
tenant_id = uuid4()
|
||||
@@ -121,7 +48,7 @@ def test_pageserver_http_api_client(zenith_simple_env: ZenithEnv):
|
||||
|
||||
def test_pageserver_http_api_client_auth_enabled(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.pageserver_auth_enabled = True
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
management_token = env.auth_keys.generate_management_token()
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ from multiprocessing import Process, Value
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
# Test safekeeper sync and pageserver catch up
|
||||
# while initial compute node is down and pageserver is lagging behind safekeepers.
|
||||
@@ -16,9 +14,9 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
# and new compute node contains all data.
|
||||
def test_pageserver_catchup_while_compute_down(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_pageserver_catchup_while_compute_down", "main"])
|
||||
env.zenith_cli.create_branch("test_pageserver_catchup_while_compute_down", "main")
|
||||
pg = env.postgres.create_start('test_pageserver_catchup_while_compute_down')
|
||||
|
||||
pg_conn = pg.connect()
|
||||
|
||||
@@ -7,17 +7,15 @@ from multiprocessing import Process, Value
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
# Test restarting page server, while safekeeper and compute node keep
|
||||
# running.
|
||||
def test_pageserver_restart(zenith_env_builder: ZenithEnvBuilder):
|
||||
# One safekeeper is enough for this test.
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_pageserver_restart", "main"])
|
||||
env.zenith_cli.create_branch("test_pageserver_restart", "main")
|
||||
pg = env.postgres.create_start('test_pageserver_restart')
|
||||
|
||||
pg_conn = pg.connect()
|
||||
|
||||
@@ -5,8 +5,6 @@ import subprocess
|
||||
from fixtures.zenith_fixtures import ZenithEnv, Postgres
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
async def repeat_bytes(buf, repetitions: int):
|
||||
for i in range(repetitions):
|
||||
@@ -39,9 +37,7 @@ async def parallel_load_same_table(pg: Postgres, n_parallel: int):
|
||||
# Load data into one table with COPY TO from 5 parallel connections
|
||||
def test_parallel_copy(zenith_simple_env: ZenithEnv, n_parallel=5):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_parallel_copy", "empty"])
|
||||
|
||||
env.zenith_cli.create_branch("test_parallel_copy", "empty")
|
||||
pg = env.postgres.create_start('test_parallel_copy')
|
||||
log.info("postgres is running on 'test_parallel_copy' branch")
|
||||
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_pgbench", "empty"])
|
||||
|
||||
env.zenith_cli.create_branch("test_pgbench", "empty")
|
||||
pg = env.postgres.create_start('test_pgbench')
|
||||
log.info("postgres is running on 'test_pgbench' branch")
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ import pytest
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Create read-only compute nodes, anchored at historical points in time.
|
||||
@@ -13,7 +11,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
#
|
||||
def test_readonly_node(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_readonly_node", "empty"])
|
||||
env.zenith_cli.create_branch("test_readonly_node", "empty")
|
||||
|
||||
pgmain = env.postgres.create_start('test_readonly_node')
|
||||
log.info("postgres is running on 'test_readonly_node' branch")
|
||||
@@ -88,4 +86,5 @@ def test_readonly_node(zenith_simple_env: ZenithEnv):
|
||||
# Create node at pre-initdb lsn
|
||||
with pytest.raises(Exception, match="invalid basebackup lsn"):
|
||||
# compute node startup with invalid LSN should fail
|
||||
env.zenith_cli(["pg", "start", "test_readonly_node_preinitdb", "test_readonly_node@0/42"])
|
||||
env.zenith_cli.pg_start("test_readonly_node_preinitdb",
|
||||
timeline_spec="test_readonly_node@0/42")
|
||||
|
||||
@@ -9,8 +9,6 @@ from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
import pytest
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Tests that a piece of data is backed up and restored correctly:
|
||||
@@ -28,6 +26,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
# * queries the specific data, ensuring that it matches the one stored before
|
||||
#
|
||||
# The tests are done for all types of remote storage pageserver supports.
|
||||
@pytest.mark.skip(reason="will be fixed with https://github.com/zenithdb/zenith/issues/1193")
|
||||
@pytest.mark.parametrize('storage_type', ['local_fs', 'mock_s3'])
|
||||
def test_remote_storage_backup_and_restore(zenith_env_builder: ZenithEnvBuilder, storage_type: str):
|
||||
zenith_env_builder.rust_log_override = 'debug'
|
||||
@@ -43,7 +42,7 @@ def test_remote_storage_backup_and_restore(zenith_env_builder: ZenithEnvBuilder,
|
||||
data_secret = 'very secret secret'
|
||||
|
||||
##### First start, insert secret data and upload it to the remote storage
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
pg = env.postgres.create_start()
|
||||
|
||||
tenant_id = pg.safe_psql("show zenith.zenith_tenant")[0][0]
|
||||
@@ -74,8 +73,13 @@ def test_remote_storage_backup_and_restore(zenith_env_builder: ZenithEnvBuilder,
|
||||
##### Second start, restore the data and ensure it's the same
|
||||
env.pageserver.start()
|
||||
|
||||
log.info("waiting for timeline redownload")
|
||||
client = env.pageserver.http_client()
|
||||
client.timeline_attach(UUID(tenant_id), UUID(timeline_id))
|
||||
# FIXME cannot handle duplicate download requests (which might be caused by repeated timeline detail calls)
|
||||
# subject to fix in https://github.com/zenithdb/zenith/issues/997
|
||||
time.sleep(5)
|
||||
|
||||
log.info("waiting for timeline redownload")
|
||||
attempts = 0
|
||||
while True:
|
||||
timeline_details = client.timeline_detail(UUID(tenant_id), UUID(timeline_id))
|
||||
|
||||
@@ -4,8 +4,6 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test restarting and recreating a postgres instance
|
||||
@@ -15,9 +13,9 @@ def test_restart_compute(zenith_env_builder: ZenithEnvBuilder, with_wal_acceptor
|
||||
zenith_env_builder.pageserver_auth_enabled = True
|
||||
if with_wal_acceptors:
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_restart_compute", "main"])
|
||||
env.zenith_cli.create_branch("test_restart_compute", "main")
|
||||
|
||||
pg = env.postgres.create_start('test_restart_compute')
|
||||
log.info("postgres is running on 'test_restart_compute' branch")
|
||||
|
||||
@@ -5,8 +5,6 @@ from fixtures.utils import print_gc_result
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test Garbage Collection of old layer files
|
||||
@@ -16,7 +14,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
#
|
||||
def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_layerfiles_gc", "empty"])
|
||||
env.zenith_cli.create_branch("test_layerfiles_gc", "empty")
|
||||
pg = env.postgres.create_start('test_layerfiles_gc')
|
||||
|
||||
with closing(pg.connect()) as conn:
|
||||
@@ -50,7 +48,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
cur.execute("DELETE FROM foo")
|
||||
|
||||
log.info("Running GC before test")
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
# remember the number of files
|
||||
@@ -63,7 +61,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
# removing the old image and delta layer.
|
||||
log.info("Inserting one row and running GC")
|
||||
cur.execute("INSERT INTO foo VALUES (1)")
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
|
||||
@@ -77,7 +75,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
cur.execute("INSERT INTO foo VALUES (2)")
|
||||
cur.execute("INSERT INTO foo VALUES (3)")
|
||||
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
|
||||
@@ -89,7 +87,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
cur.execute("INSERT INTO foo VALUES (2)")
|
||||
cur.execute("INSERT INTO foo VALUES (3)")
|
||||
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
|
||||
@@ -98,7 +96,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
|
||||
# Run GC again, with no changes in the database. Should not remove anything.
|
||||
log.info("Run GC again, with nothing to do")
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
assert row['layer_relfiles_total'] == layer_relfiles_remain
|
||||
@@ -111,7 +109,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
|
||||
log.info("Drop table and run GC again")
|
||||
cur.execute("DROP TABLE foo")
|
||||
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
|
||||
row = pscur.fetchone()
|
||||
print_gc_result(row)
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
# Test subtransactions
|
||||
#
|
||||
@@ -12,8 +10,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
# CLOG.
|
||||
def test_subxacts(zenith_simple_env: ZenithEnv, test_output_dir):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_subxacts", "empty"])
|
||||
env.zenith_cli.create_branch("test_subxacts", "empty")
|
||||
pg = env.postgres.create_start('test_subxacts')
|
||||
|
||||
log.info("postgres is running on 'test_subxacts' branch")
|
||||
|
||||
@@ -108,12 +108,13 @@ def load(pg: Postgres, stop_event: threading.Event, load_ok_event: threading.Eve
|
||||
log.info('load thread stopped')
|
||||
|
||||
|
||||
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: str, timeline: str):
|
||||
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
|
||||
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: UUID, timeline: str):
|
||||
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
|
||||
assert timeline_detail.get('type') == "Local", timeline_detail
|
||||
return timeline_detail
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="will be fixed with https://github.com/zenithdb/zenith/issues/1193")
|
||||
@pytest.mark.parametrize('with_load', ['with_load', 'without_load'])
|
||||
def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
|
||||
port_distributor: PortDistributor,
|
||||
@@ -121,15 +122,15 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
zenith_env_builder.enable_local_fs_remote_storage()
|
||||
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
# create folder for remote storage mock
|
||||
remote_storage_mock_path = env.repo_dir / 'local_fs_remote_storage'
|
||||
|
||||
tenant = env.create_tenant("74ee8b079a0e437eb0afea7d26a07209")
|
||||
tenant = env.create_tenant(UUID("74ee8b079a0e437eb0afea7d26a07209"))
|
||||
log.info("tenant to relocate %s", tenant)
|
||||
|
||||
env.zenith_cli(["branch", "test_tenant_relocation", "main", f"--tenantid={tenant}"])
|
||||
env.zenith_cli.create_branch("test_tenant_relocation", "main", tenant_id=tenant)
|
||||
|
||||
tenant_pg = env.postgres.create_start(
|
||||
"test_tenant_relocation",
|
||||
@@ -166,11 +167,11 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
|
||||
# run checkpoint manually to be sure that data landed in remote storage
|
||||
with closing(env.pageserver.connect()) as psconn:
|
||||
with psconn.cursor() as pscur:
|
||||
pscur.execute(f"do_gc {tenant} {timeline}")
|
||||
pscur.execute(f"do_gc {tenant.hex} {timeline}")
|
||||
|
||||
# ensure upload is completed
|
||||
pageserver_http_client = env.pageserver.http_client()
|
||||
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
|
||||
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
|
||||
assert timeline_detail['disk_consistent_lsn'] == timeline_detail['timeline_state']['Ready']
|
||||
|
||||
log.info("inititalizing new pageserver")
|
||||
@@ -192,8 +193,8 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
|
||||
new_pageserver_pg_port,
|
||||
new_pageserver_http_port):
|
||||
|
||||
# call to attach timeline to new timeline
|
||||
new_pageserver_http_client.timeline_attach(UUID(tenant), UUID(timeline))
|
||||
# call to attach timeline to new pageserver
|
||||
new_pageserver_http_client.timeline_attach(tenant, UUID(timeline))
|
||||
# FIXME cannot handle duplicate download requests, subject to fix in https://github.com/zenithdb/zenith/issues/997
|
||||
time.sleep(5)
|
||||
# new pageserver should in sync (modulo wal tail or vacuum activity) with the old one because there was no new writes since checkpoint
|
||||
@@ -240,7 +241,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
|
||||
# detach tenant from old pageserver before we check
|
||||
# that all the data is there to be sure that old pageserver
|
||||
# is no longer involved, and if it is, we will see the errors
|
||||
pageserver_http_client.timeline_detach(UUID(tenant), UUID(timeline))
|
||||
pageserver_http_client.timeline_detach(tenant, UUID(timeline))
|
||||
|
||||
with pg_cur(tenant_pg) as cur:
|
||||
# check that data is still there
|
||||
|
||||
@@ -10,23 +10,17 @@ def test_tenants_normal_work(zenith_env_builder: ZenithEnvBuilder, with_wal_acce
|
||||
if with_wal_acceptors:
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
"""Tests tenants with and without wal acceptors"""
|
||||
tenant_1 = env.create_tenant()
|
||||
tenant_2 = env.create_tenant()
|
||||
|
||||
env.zenith_cli([
|
||||
"branch",
|
||||
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
|
||||
"main",
|
||||
f"--tenantid={tenant_1}"
|
||||
])
|
||||
env.zenith_cli([
|
||||
"branch",
|
||||
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
|
||||
"main",
|
||||
f"--tenantid={tenant_2}"
|
||||
])
|
||||
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
|
||||
"main",
|
||||
tenant_id=tenant_1)
|
||||
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
|
||||
"main",
|
||||
tenant_id=tenant_2)
|
||||
|
||||
pg_tenant1 = env.postgres.create_start(
|
||||
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from contextlib import closing
|
||||
from uuid import UUID
|
||||
import psycopg2.extras
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
import psycopg2.errors
|
||||
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, Postgres
|
||||
from fixtures.log_helper import log
|
||||
import time
|
||||
|
||||
|
||||
def test_timeline_size(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
# Branch at the point where only 100 rows were inserted
|
||||
env.zenith_cli(["branch", "test_timeline_size", "empty"])
|
||||
env.zenith_cli.create_branch("test_timeline_size", "empty")
|
||||
|
||||
client = env.pageserver.http_client()
|
||||
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
|
||||
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
|
||||
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
|
||||
|
||||
pgmain = env.postgres.create_start("test_timeline_size")
|
||||
@@ -29,9 +31,102 @@ def test_timeline_size(zenith_simple_env: ZenithEnv):
|
||||
FROM generate_series(1, 10) g
|
||||
""")
|
||||
|
||||
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
|
||||
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
|
||||
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
|
||||
cur.execute("TRUNCATE foo")
|
||||
|
||||
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
|
||||
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
|
||||
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
|
||||
|
||||
|
||||
# wait until received_lsn_lag is 0
|
||||
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
|
||||
started_at = time.time()
|
||||
|
||||
received_lsn_lag = 1
|
||||
while received_lsn_lag > 0:
|
||||
elapsed = time.time() - started_at
|
||||
if elapsed > timeout:
|
||||
raise RuntimeError(
|
||||
f"timed out waiting for pageserver to reach pg_current_wal_flush_lsn()")
|
||||
|
||||
with closing(pgmain.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
|
||||
cur.execute('''
|
||||
select pg_size_pretty(pg_cluster_size()),
|
||||
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag
|
||||
FROM backpressure_lsns();
|
||||
''')
|
||||
res = cur.fetchone()
|
||||
log.info(f"pg_cluster_size = {res[0]}, received_lsn_lag = {res[1]}")
|
||||
received_lsn_lag = res[1]
|
||||
|
||||
time.sleep(polling_interval)
|
||||
|
||||
|
||||
def test_timeline_size_quota(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init_start()
|
||||
env.zenith_cli.create_branch("test_timeline_size_quota", "main")
|
||||
|
||||
client = env.pageserver.http_client()
|
||||
res = client.branch_detail(env.initial_tenant, "test_timeline_size_quota")
|
||||
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
|
||||
|
||||
pgmain = env.postgres.create_start(
|
||||
"test_timeline_size_quota",
|
||||
# Set small limit for the test
|
||||
config_lines=['zenith.max_cluster_size=30MB'],
|
||||
)
|
||||
log.info("postgres is running on 'test_timeline_size_quota' branch")
|
||||
|
||||
with closing(pgmain.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("CREATE EXTENSION zenith") # TODO move it to zenith_fixtures?
|
||||
|
||||
cur.execute("CREATE TABLE foo (t text)")
|
||||
|
||||
wait_for_pageserver_catchup(pgmain)
|
||||
|
||||
# Insert many rows. This query must fail because of space limit
|
||||
try:
|
||||
cur.execute('''
|
||||
INSERT INTO foo
|
||||
SELECT 'long string to consume some space' || g
|
||||
FROM generate_series(1, 100000) g
|
||||
''')
|
||||
|
||||
wait_for_pageserver_catchup(pgmain)
|
||||
|
||||
cur.execute('''
|
||||
INSERT INTO foo
|
||||
SELECT 'long string to consume some space' || g
|
||||
FROM generate_series(1, 500000) g
|
||||
''')
|
||||
|
||||
# If we get here, the timeline size limit failed
|
||||
log.error("Query unexpectedly succeeded")
|
||||
assert False
|
||||
|
||||
except psycopg2.errors.DiskFull as err:
|
||||
log.info(f"Query expectedly failed with: {err}")
|
||||
|
||||
# drop table to free space
|
||||
cur.execute('DROP TABLE foo')
|
||||
|
||||
wait_for_pageserver_catchup(pgmain)
|
||||
|
||||
# create it again and insert some rows. This query must succeed
|
||||
cur.execute("CREATE TABLE foo (t text)")
|
||||
cur.execute('''
|
||||
INSERT INTO foo
|
||||
SELECT 'long string to consume some space' || g
|
||||
FROM generate_series(1, 10000) g
|
||||
''')
|
||||
|
||||
wait_for_pageserver_catchup(pgmain)
|
||||
|
||||
cur.execute("SELECT * from pg_size_pretty(pg_cluster_size())")
|
||||
pg_cluster_size = cur.fetchone()
|
||||
log.info(f"pg_cluster_size = {pg_cluster_size}")
|
||||
|
||||
@@ -3,15 +3,13 @@ import os
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test branching, when a transaction is in prepared state
|
||||
#
|
||||
def test_twophase(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
env.zenith_cli(["branch", "test_twophase", "empty"])
|
||||
env.zenith_cli.create_branch("test_twophase", "empty")
|
||||
|
||||
pg = env.postgres.create_start('test_twophase', config_lines=['max_prepared_transactions=5'])
|
||||
log.info("postgres is running on 'test_twophase' branch")
|
||||
@@ -58,7 +56,7 @@ def test_twophase(zenith_simple_env: ZenithEnv):
|
||||
assert len(twophase_files) == 2
|
||||
|
||||
# Create a branch with the transaction in prepared state
|
||||
env.zenith_cli(["branch", "test_twophase_prepared", "test_twophase"])
|
||||
env.zenith_cli.create_branch("test_twophase_prepared", "test_twophase")
|
||||
|
||||
# Start compute on the new branch
|
||||
pg2 = env.postgres.create_start(
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
#
|
||||
# Test that the VM bit is cleared correctly at a HEAP_DELETE and
|
||||
@@ -11,8 +9,7 @@ pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_vm_bit_clear", "empty"])
|
||||
env.zenith_cli.create_branch("test_vm_bit_clear", "empty")
|
||||
pg = env.postgres.create_start('test_vm_bit_clear')
|
||||
|
||||
log.info("postgres is running on 'test_vm_bit_clear' branch")
|
||||
@@ -36,7 +33,7 @@ def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
|
||||
cur.execute('UPDATE vmtest_update SET id = 5000 WHERE id = 1')
|
||||
|
||||
# Branch at this point, to test that later
|
||||
env.zenith_cli(["branch", "test_vm_bit_clear_new", "test_vm_bit_clear"])
|
||||
env.zenith_cli.create_branch("test_vm_bit_clear_new", "test_vm_bit_clear")
|
||||
|
||||
# Clear the buffer cache, to force the VM page to be re-fetched from
|
||||
# the page server
|
||||
|
||||
@@ -12,21 +12,19 @@ from contextlib import closing
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Process, Value
|
||||
from pathlib import Path
|
||||
from fixtures.zenith_fixtures import PgBin, ZenithEnvBuilder, PortDistributor, SafekeeperPort, zenith_binpath, PgProtocol
|
||||
from fixtures.zenith_fixtures import PgBin, Postgres, Safekeeper, ZenithEnv, ZenithEnvBuilder, PortDistributor, SafekeeperPort, zenith_binpath, PgProtocol
|
||||
from fixtures.utils import lsn_to_hex, mkdir_if_needed
|
||||
from fixtures.log_helper import log
|
||||
from typing import List, Optional, Any
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
# basic test, write something in setup with wal acceptors, ensure that commits
|
||||
# succeed and data is written
|
||||
def test_normal_work(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_wal_acceptors_normal_work", "main"])
|
||||
env.zenith_cli.create_branch("test_wal_acceptors_normal_work", "main")
|
||||
|
||||
pg = env.postgres.create_start('test_wal_acceptors_normal_work')
|
||||
|
||||
@@ -53,7 +51,7 @@ class BranchMetrics:
|
||||
# against different timelines.
|
||||
def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
n_timelines = 3
|
||||
|
||||
@@ -62,10 +60,10 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
|
||||
# start postgres on each timeline
|
||||
pgs = []
|
||||
for branch in branches:
|
||||
env.zenith_cli(["branch", branch, "main"])
|
||||
env.zenith_cli.create_branch(branch, "main")
|
||||
pgs.append(env.postgres.create_start(branch))
|
||||
|
||||
tenant_id = uuid.UUID(env.initial_tenant)
|
||||
tenant_id = env.initial_tenant
|
||||
|
||||
def collect_metrics(message: str) -> List[BranchMetrics]:
|
||||
with env.pageserver.http_client() as pageserver_http:
|
||||
@@ -92,8 +90,8 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
|
||||
latest_valid_lsn=branch_detail["latest_valid_lsn"],
|
||||
)
|
||||
for sk_m in sk_metrics:
|
||||
m.flush_lsns.append(sk_m.flush_lsn_inexact[timeline_id])
|
||||
m.commit_lsns.append(sk_m.commit_lsn_inexact[timeline_id])
|
||||
m.flush_lsns.append(sk_m.flush_lsn_inexact[(tenant_id.hex, timeline_id)])
|
||||
m.commit_lsns.append(sk_m.commit_lsn_inexact[(tenant_id.hex, timeline_id)])
|
||||
|
||||
for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns):
|
||||
# Invariant. May be < when transaction is in progress.
|
||||
@@ -183,9 +181,9 @@ def test_restarts(zenith_env_builder: ZenithEnvBuilder):
|
||||
n_acceptors = 3
|
||||
|
||||
zenith_env_builder.num_safekeepers = n_acceptors
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_wal_acceptors_restarts", "main"])
|
||||
env.zenith_cli.create_branch("test_wal_acceptors_restarts", "main")
|
||||
pg = env.postgres.create_start('test_wal_acceptors_restarts')
|
||||
|
||||
# we rely upon autocommit after each statement
|
||||
@@ -220,9 +218,9 @@ def delayed_wal_acceptor_start(wa):
|
||||
# When majority of acceptors is offline, commits are expected to be frozen
|
||||
def test_unavailability(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 2
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_wal_acceptors_unavailability", "main"])
|
||||
env.zenith_cli.create_branch("test_wal_acceptors_unavailability", "main")
|
||||
pg = env.postgres.create_start('test_wal_acceptors_unavailability')
|
||||
|
||||
# we rely upon autocommit after each statement
|
||||
@@ -291,9 +289,9 @@ def stop_value():
|
||||
def test_race_conditions(zenith_env_builder: ZenithEnvBuilder, stop_value):
|
||||
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_wal_acceptors_race_conditions", "main"])
|
||||
env.zenith_cli.create_branch("test_wal_acceptors_race_conditions", "main")
|
||||
pg = env.postgres.create_start('test_wal_acceptors_race_conditions')
|
||||
|
||||
# we rely upon autocommit after each statement
|
||||
@@ -321,16 +319,16 @@ class ProposerPostgres(PgProtocol):
|
||||
def __init__(self,
|
||||
pgdata_dir: str,
|
||||
pg_bin,
|
||||
timeline_id: str,
|
||||
tenant_id: str,
|
||||
timeline_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
listen_addr: str,
|
||||
port: int):
|
||||
super().__init__(host=listen_addr, port=port)
|
||||
super().__init__(host=listen_addr, port=port, username='zenith_admin')
|
||||
|
||||
self.pgdata_dir: str = pgdata_dir
|
||||
self.pg_bin: PgBin = pg_bin
|
||||
self.timeline_id: str = timeline_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.timeline_id: uuid.UUID = timeline_id
|
||||
self.tenant_id: uuid.UUID = tenant_id
|
||||
self.listen_addr: str = listen_addr
|
||||
self.port: int = port
|
||||
|
||||
@@ -350,8 +348,8 @@ class ProposerPostgres(PgProtocol):
|
||||
cfg = [
|
||||
"synchronous_standby_names = 'walproposer'\n",
|
||||
"shared_preload_libraries = 'zenith'\n",
|
||||
f"zenith.zenith_timeline = '{self.timeline_id}'\n",
|
||||
f"zenith.zenith_tenant = '{self.tenant_id}'\n",
|
||||
f"zenith.zenith_timeline = '{self.timeline_id.hex}'\n",
|
||||
f"zenith.zenith_tenant = '{self.tenant_id.hex}'\n",
|
||||
f"zenith.page_server_connstring = ''\n",
|
||||
f"wal_acceptors = '{wal_acceptors}'\n",
|
||||
f"listen_addresses = '{self.listen_addr}'\n",
|
||||
@@ -406,10 +404,10 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
|
||||
# We don't really need the full environment for this test, just the
|
||||
# safekeepers would be enough.
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
timeline_id = uuid.uuid4().hex
|
||||
tenant_id = uuid.uuid4().hex
|
||||
timeline_id = uuid.uuid4()
|
||||
tenant_id = uuid.uuid4()
|
||||
|
||||
# write config for proposer
|
||||
pgdata_dir = os.path.join(env.repo_dir, "proposer_pgdata")
|
||||
@@ -456,9 +454,9 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
|
||||
def test_timeline_status(zenith_env_builder: ZenithEnvBuilder):
|
||||
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_timeline_status", "main"])
|
||||
env.zenith_cli.create_branch("test_timeline_status", "main")
|
||||
pg = env.postgres.create_start('test_timeline_status')
|
||||
|
||||
wa = env.safekeepers[0]
|
||||
@@ -495,15 +493,15 @@ class SafekeeperEnv:
|
||||
self.bin_safekeeper = os.path.join(str(zenith_binpath), 'safekeeper')
|
||||
self.safekeepers: Optional[List[subprocess.CompletedProcess[Any]]] = None
|
||||
self.postgres: Optional[ProposerPostgres] = None
|
||||
self.tenant_id: Optional[str] = None
|
||||
self.timeline_id: Optional[str] = None
|
||||
self.tenant_id: Optional[uuid.UUID] = None
|
||||
self.timeline_id: Optional[uuid.UUID] = None
|
||||
|
||||
def init(self) -> "SafekeeperEnv":
|
||||
assert self.postgres is None, "postgres is already initialized"
|
||||
assert self.safekeepers is None, "safekeepers are already initialized"
|
||||
|
||||
self.timeline_id = uuid.uuid4().hex
|
||||
self.tenant_id = uuid.uuid4().hex
|
||||
self.timeline_id = uuid.uuid4()
|
||||
self.tenant_id = uuid.uuid4()
|
||||
mkdir_if_needed(str(self.repo_dir))
|
||||
|
||||
# Create config and a Safekeeper object for each safekeeper
|
||||
@@ -523,12 +521,7 @@ class SafekeeperEnv:
|
||||
http=self.port_distributor.get_port(),
|
||||
)
|
||||
|
||||
if self.num_safekeepers == 1:
|
||||
name = "single"
|
||||
else:
|
||||
name = f"sk{i}"
|
||||
|
||||
safekeeper_dir = os.path.join(self.repo_dir, name)
|
||||
safekeeper_dir = os.path.join(self.repo_dir, f"sk{i}")
|
||||
mkdir_if_needed(safekeeper_dir)
|
||||
|
||||
args = [
|
||||
@@ -539,6 +532,8 @@ class SafekeeperEnv:
|
||||
f"127.0.0.1:{port.http}",
|
||||
"-D",
|
||||
safekeeper_dir,
|
||||
"--id",
|
||||
str(i),
|
||||
"--daemonize"
|
||||
]
|
||||
|
||||
@@ -603,3 +598,91 @@ def test_safekeeper_without_pageserver(test_output_dir: str,
|
||||
env.postgres.safe_psql("insert into t select generate_series(1, 100)")
|
||||
res = env.postgres.safe_psql("select sum(i) from t")[0][0]
|
||||
assert res == 5050
|
||||
|
||||
|
||||
def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):
|
||||
def safekeepers_guc(env: ZenithEnv, sk_names: List[int]) -> str:
|
||||
return ','.join([f'localhost:{sk.port.pg}' for sk in env.safekeepers if sk.id in sk_names])
|
||||
|
||||
def execute_payload(pg: Postgres):
|
||||
with closing(pg.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# we rely upon autocommit after each statement
|
||||
# as waiting for acceptors happens there
|
||||
cur.execute('CREATE TABLE IF NOT EXISTS t(key int, value text)')
|
||||
cur.execute("INSERT INTO t VALUES (0, 'something')")
|
||||
cur.execute('SELECT SUM(key) FROM t')
|
||||
sum_before = cur.fetchone()[0]
|
||||
|
||||
cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
|
||||
cur.execute('SELECT SUM(key) FROM t')
|
||||
sum_after = cur.fetchone()[0]
|
||||
assert sum_after == sum_before + 5000050000
|
||||
|
||||
def show_statuses(safekeepers: List[Safekeeper], tenant_id: str, timeline_id: str):
|
||||
for sk in safekeepers:
|
||||
http_cli = sk.http_client()
|
||||
try:
|
||||
status = http_cli.timeline_status(tenant_id, timeline_id)
|
||||
log.info(f"Safekeeper {sk.id} status: {status}")
|
||||
except Exception as e:
|
||||
log.info(f"Safekeeper {sk.id} status error: {e}")
|
||||
|
||||
zenith_env_builder.num_safekeepers = 4
|
||||
env = zenith_env_builder.init_start()
|
||||
env.zenith_cli.create_branch("test_replace_safekeeper", "main")
|
||||
|
||||
log.info("Use only first 3 safekeepers")
|
||||
env.safekeepers[3].stop()
|
||||
active_safekeepers = [1, 2, 3]
|
||||
pg = env.postgres.create('test_replace_safekeeper')
|
||||
pg.adjust_for_wal_acceptors(safekeepers_guc(env, active_safekeepers))
|
||||
pg.start()
|
||||
|
||||
# learn zenith timeline from compute
|
||||
tenant_id = pg.safe_psql("show zenith.zenith_tenant")[0][0]
|
||||
timeline_id = pg.safe_psql("show zenith.zenith_timeline")[0][0]
|
||||
|
||||
execute_payload(pg)
|
||||
show_statuses(env.safekeepers, tenant_id, timeline_id)
|
||||
|
||||
log.info("Restart all safekeepers to flush everything")
|
||||
env.safekeepers[0].stop(immediate=True)
|
||||
execute_payload(pg)
|
||||
env.safekeepers[0].start()
|
||||
env.safekeepers[1].stop(immediate=True)
|
||||
execute_payload(pg)
|
||||
env.safekeepers[1].start()
|
||||
env.safekeepers[2].stop(immediate=True)
|
||||
execute_payload(pg)
|
||||
env.safekeepers[2].start()
|
||||
|
||||
env.safekeepers[0].stop(immediate=True)
|
||||
env.safekeepers[1].stop(immediate=True)
|
||||
env.safekeepers[2].stop(immediate=True)
|
||||
env.safekeepers[0].start()
|
||||
env.safekeepers[1].start()
|
||||
env.safekeepers[2].start()
|
||||
|
||||
execute_payload(pg)
|
||||
show_statuses(env.safekeepers, tenant_id, timeline_id)
|
||||
|
||||
log.info("Stop sk1 (simulate failure) and use only quorum of sk2 and sk3")
|
||||
env.safekeepers[0].stop(immediate=True)
|
||||
execute_payload(pg)
|
||||
show_statuses(env.safekeepers, tenant_id, timeline_id)
|
||||
|
||||
log.info("Recreate postgres to replace failed sk1 with new sk4")
|
||||
pg.stop_and_destroy().create('test_replace_safekeeper')
|
||||
active_safekeepers = [2, 3, 4]
|
||||
env.safekeepers[3].start()
|
||||
pg.adjust_for_wal_acceptors(safekeepers_guc(env, active_safekeepers))
|
||||
pg.start()
|
||||
|
||||
execute_payload(pg)
|
||||
show_statuses(env.safekeepers, tenant_id, timeline_id)
|
||||
|
||||
log.info("Stop sk2 to require quorum of sk3 and sk4 for normal work")
|
||||
env.safekeepers[1].stop(immediate=True)
|
||||
execute_payload(pg)
|
||||
show_statuses(env.safekeepers, tenant_id, timeline_id)
|
||||
|
||||
@@ -9,7 +9,6 @@ from fixtures.utils import lsn_from_hex, lsn_to_hex
|
||||
from typing import List
|
||||
|
||||
log = getLogger('root.wal_acceptor_async')
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
class BankClient(object):
|
||||
@@ -201,9 +200,9 @@ async def run_restarts_under_load(pg: Postgres, acceptors: List[Safekeeper], n_w
|
||||
# restart acceptors one by one, while executing and validating bank transactions
|
||||
def test_restarts_under_load(zenith_env_builder: ZenithEnvBuilder):
|
||||
zenith_env_builder.num_safekeepers = 3
|
||||
env = zenith_env_builder.init()
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
env.zenith_cli(["branch", "test_wal_acceptors_restarts_under_load", "main"])
|
||||
env.zenith_cli.create_branch("test_wal_acceptors_restarts_under_load", "main")
|
||||
pg = env.postgres.create_start('test_wal_acceptors_restarts_under_load')
|
||||
|
||||
asyncio.run(run_restarts_under_load(pg, env.safekeepers))
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
import json
|
||||
import uuid
|
||||
import requests
|
||||
|
||||
from psycopg2.extensions import cursor as PgCursor
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
|
||||
from typing import cast
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def helper_compare_branch_list(page_server_cur: PgCursor, env: ZenithEnv, initial_tenant: str):
|
||||
def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClient,
|
||||
env: ZenithEnv,
|
||||
initial_tenant: uuid.UUID):
|
||||
"""
|
||||
Compare branches list returned by CLI and directly via API.
|
||||
Filters out branches created by other tests.
|
||||
"""
|
||||
|
||||
page_server_cur.execute(f'branch_list {initial_tenant}')
|
||||
branches_api = sorted(
|
||||
map(lambda b: cast(str, b['name']), json.loads(page_server_cur.fetchone()[0])))
|
||||
branches = pageserver_http_client.branch_list(initial_tenant)
|
||||
branches_api = sorted(map(lambda b: cast(str, b['name']), branches))
|
||||
branches_api = [b for b in branches_api if b.startswith('test_cli_') or b in ('empty', 'main')]
|
||||
|
||||
res = env.zenith_cli(["branch"])
|
||||
res.check_returncode()
|
||||
res = env.zenith_cli.list_branches()
|
||||
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
|
||||
branches_cli = [b for b in branches_cli if b.startswith('test_cli_') or b in ('empty', 'main')]
|
||||
|
||||
res = env.zenith_cli(["branch", f"--tenantid={initial_tenant}"])
|
||||
res.check_returncode()
|
||||
res = env.zenith_cli.list_branches(tenant_id=initial_tenant)
|
||||
branches_cli_with_tenant_arg = sorted(
|
||||
map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
|
||||
branches_cli_with_tenant_arg = [
|
||||
@@ -37,24 +34,20 @@ def helper_compare_branch_list(page_server_cur: PgCursor, env: ZenithEnv, initia
|
||||
|
||||
def test_cli_branch_list(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
page_server_conn = env.pageserver.connect()
|
||||
page_server_cur = page_server_conn.cursor()
|
||||
pageserver_http_client = env.pageserver.http_client()
|
||||
|
||||
# Initial sanity check
|
||||
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
|
||||
|
||||
# Create a branch for us
|
||||
res = env.zenith_cli(["branch", "test_cli_branch_list_main", "empty"])
|
||||
assert res.stderr == ''
|
||||
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
|
||||
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
|
||||
env.zenith_cli.create_branch("test_cli_branch_list_main", "empty")
|
||||
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
|
||||
|
||||
# Create a nested branch
|
||||
res = env.zenith_cli(["branch", "test_cli_branch_list_nested", "test_cli_branch_list_main"])
|
||||
res = env.zenith_cli.create_branch("test_cli_branch_list_nested", "test_cli_branch_list_main")
|
||||
assert res.stderr == ''
|
||||
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
|
||||
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
|
||||
|
||||
# Check that all new branches are visible via CLI
|
||||
res = env.zenith_cli(["branch"])
|
||||
res = env.zenith_cli.list_branches()
|
||||
assert res.stderr == ''
|
||||
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
|
||||
|
||||
@@ -62,12 +55,11 @@ def test_cli_branch_list(zenith_simple_env: ZenithEnv):
|
||||
assert 'test_cli_branch_list_nested' in branches_cli
|
||||
|
||||
|
||||
def helper_compare_tenant_list(page_server_cur: PgCursor, env: ZenithEnv):
|
||||
page_server_cur.execute(f'tenant_list')
|
||||
tenants_api = sorted(
|
||||
map(lambda t: cast(str, t['id']), json.loads(page_server_cur.fetchone()[0])))
|
||||
def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClient, env: ZenithEnv):
|
||||
tenants = pageserver_http_client.tenant_list()
|
||||
tenants_api = sorted(map(lambda t: cast(str, t['id']), tenants))
|
||||
|
||||
res = env.zenith_cli(["tenant", "list"])
|
||||
res = env.zenith_cli.list_tenants()
|
||||
assert res.stderr == ''
|
||||
tenants_cli = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
|
||||
|
||||
@@ -76,32 +68,62 @@ def helper_compare_tenant_list(page_server_cur: PgCursor, env: ZenithEnv):
|
||||
|
||||
def test_cli_tenant_list(zenith_simple_env: ZenithEnv):
|
||||
env = zenith_simple_env
|
||||
page_server_conn = env.pageserver.connect()
|
||||
page_server_cur = page_server_conn.cursor()
|
||||
|
||||
pageserver_http_client = env.pageserver.http_client()
|
||||
# Initial sanity check
|
||||
helper_compare_tenant_list(page_server_cur, env)
|
||||
helper_compare_tenant_list(pageserver_http_client, env)
|
||||
|
||||
# Create new tenant
|
||||
tenant1 = uuid.uuid4().hex
|
||||
res = env.zenith_cli(["tenant", "create", tenant1])
|
||||
res.check_returncode()
|
||||
tenant1 = uuid.uuid4()
|
||||
env.zenith_cli.create_tenant(tenant1)
|
||||
|
||||
# check tenant1 appeared
|
||||
helper_compare_tenant_list(page_server_cur, env)
|
||||
helper_compare_tenant_list(pageserver_http_client, env)
|
||||
|
||||
# Create new tenant
|
||||
tenant2 = uuid.uuid4().hex
|
||||
res = env.zenith_cli(["tenant", "create", tenant2])
|
||||
res.check_returncode()
|
||||
tenant2 = uuid.uuid4()
|
||||
env.zenith_cli.create_tenant(tenant2)
|
||||
|
||||
# check tenant2 appeared
|
||||
helper_compare_tenant_list(page_server_cur, env)
|
||||
helper_compare_tenant_list(pageserver_http_client, env)
|
||||
|
||||
res = env.zenith_cli(["tenant", "list"])
|
||||
res.check_returncode()
|
||||
res = env.zenith_cli.list_tenants()
|
||||
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
|
||||
|
||||
assert env.initial_tenant in tenants
|
||||
assert tenant1 in tenants
|
||||
assert tenant2 in tenants
|
||||
assert env.initial_tenant.hex in tenants
|
||||
assert tenant1.hex in tenants
|
||||
assert tenant2.hex in tenants
|
||||
|
||||
|
||||
def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
|
||||
# Start with single sk
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
# Connect to sk port on v4 loopback
|
||||
res = requests.get(f'http://127.0.0.1:{env.safekeepers[0].port.http}/v1/status')
|
||||
assert res.ok
|
||||
|
||||
# FIXME Test setup is using localhost:xx in ps config.
|
||||
# Perhaps consider switching test suite to v4 loopback.
|
||||
|
||||
# Connect to ps port on v4 loopback
|
||||
# res = requests.get(f'http://127.0.0.1:{env.pageserver.service_port.http}/v1/status')
|
||||
# assert res.ok
|
||||
|
||||
|
||||
def test_cli_start_stop(zenith_env_builder: ZenithEnvBuilder):
|
||||
# Start with single sk
|
||||
zenith_env_builder.num_safekeepers = 1
|
||||
env = zenith_env_builder.init_start()
|
||||
|
||||
# Stop default ps/sk
|
||||
env.zenith_cli.pageserver_stop()
|
||||
env.zenith_cli.safekeeper_stop()
|
||||
|
||||
# Default start
|
||||
res = env.zenith_cli.raw_cli(["start"])
|
||||
res.check_returncode()
|
||||
|
||||
# Default stop
|
||||
res = env.zenith_cli.raw_cli(["stop"])
|
||||
res.check_returncode()
|
||||
|
||||
@@ -3,15 +3,11 @@ import os
|
||||
from fixtures.utils import mkdir_if_needed
|
||||
from fixtures.zenith_fixtures import ZenithEnv, base_dir, pg_distrib_dir
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
|
||||
env = zenith_simple_env
|
||||
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_isolation", "empty"])
|
||||
|
||||
env.zenith_cli.create_branch("test_isolation", "empty")
|
||||
# Connect to postgres and create a database called "regression".
|
||||
# isolation tests use prepared transactions, so enable them
|
||||
pg = env.postgres.create_start('test_isolation', config_lines=['max_prepared_transactions=100'])
|
||||
|
||||
@@ -3,15 +3,11 @@ import os
|
||||
from fixtures.utils import mkdir_if_needed
|
||||
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content, base_dir, pg_distrib_dir
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, capsys):
|
||||
env = zenith_simple_env
|
||||
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_pg_regress", "empty"])
|
||||
|
||||
env.zenith_cli.create_branch("test_pg_regress", "empty")
|
||||
# Connect to postgres and create a database called "regression".
|
||||
pg = env.postgres.create_start('test_pg_regress')
|
||||
pg.safe_psql('CREATE DATABASE regression')
|
||||
|
||||
@@ -7,15 +7,11 @@ from fixtures.zenith_fixtures import (ZenithEnv,
|
||||
pg_distrib_dir)
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
|
||||
env = zenith_simple_env
|
||||
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_zenith_regress", "empty"])
|
||||
|
||||
env.zenith_cli.create_branch("test_zenith_regress", "empty")
|
||||
# Connect to postgres and create a database called "regression".
|
||||
pg = env.postgres.create_start('test_zenith_regress')
|
||||
pg.safe_psql('CREATE DATABASE regression')
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
"fixtures.slow",
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import timeit
|
||||
import calendar
|
||||
import enum
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import pytest
|
||||
from _pytest.config import Config
|
||||
from _pytest.terminal import TerminalReporter
|
||||
@@ -26,8 +27,6 @@ bencmark, and then record the result by calling zenbenchmark.record. For example
|
||||
import timeit
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
|
||||
|
||||
def test_mybench(zenith_simple_env: env, zenbenchmark):
|
||||
|
||||
# Initialize the test
|
||||
@@ -40,6 +39,8 @@ def test_mybench(zenith_simple_env: env, zenbenchmark):
|
||||
# Record another measurement
|
||||
zenbenchmark.record('speed_of_light', 300000, 'km/s')
|
||||
|
||||
There's no need to import this file to use it. It should be declared as a plugin
|
||||
inside conftest.py, and that makes it available to all tests.
|
||||
|
||||
You can measure multiple things in one test, and record each one with a separate
|
||||
call to zenbenchmark. For example, you could time the bulk loading that happens
|
||||
@@ -276,11 +277,11 @@ class ZenithBenchmarker:
|
||||
assert matches
|
||||
return int(round(float(matches.group(1))))
|
||||
|
||||
def get_timeline_size(self, repo_dir: Path, tenantid: str, timelineid: str):
|
||||
def get_timeline_size(self, repo_dir: Path, tenantid: uuid.UUID, timelineid: str):
|
||||
"""
|
||||
Calculate the on-disk size of a timeline
|
||||
"""
|
||||
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid, timelineid)
|
||||
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid.hex, timelineid)
|
||||
|
||||
totalbytes = 0
|
||||
for root, dirs, files in os.walk(path):
|
||||
|
||||
200
test_runner/fixtures/compare_fixtures.py
Normal file
200
test_runner/fixtures/compare_fixtures.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import pytest
|
||||
from contextlib import contextmanager
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
|
||||
# Type-related stuff
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
class PgCompare(ABC):
|
||||
"""Common interface of all postgres implementations, useful for benchmarks.
|
||||
|
||||
This class is a helper class for the zenith_with_baseline fixture. See its documentation
|
||||
for more details.
|
||||
"""
|
||||
@property
|
||||
@abstractmethod
|
||||
def pg(self) -> PgProtocol:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pg_bin(self) -> PgBin:
|
||||
pass
|
||||
|
||||
@property
|
||||
def zenbenchmark(self) -> ZenithBenchmarker:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def flush(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def report_peak_memory_use(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def report_size(self) -> None:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_pageserver_writes(self, out_name):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_duration(self, out_name):
|
||||
pass
|
||||
|
||||
|
||||
class ZenithCompare(PgCompare):
|
||||
"""PgCompare interface for the zenith stack."""
|
||||
def __init__(self,
|
||||
zenbenchmark: ZenithBenchmarker,
|
||||
zenith_simple_env: ZenithEnv,
|
||||
pg_bin: PgBin,
|
||||
branch_name):
|
||||
self.env = zenith_simple_env
|
||||
self._zenbenchmark = zenbenchmark
|
||||
self._pg_bin = pg_bin
|
||||
|
||||
# We only use one branch and one timeline
|
||||
self.branch = branch_name
|
||||
self.env.zenith_cli.create_branch(self.branch, "empty")
|
||||
self._pg = self.env.postgres.create_start(self.branch)
|
||||
self.timeline = self.pg.safe_psql("SHOW zenith.zenith_timeline")[0][0]
|
||||
|
||||
# Long-lived cursor, useful for flushing
|
||||
self.psconn = self.env.pageserver.connect()
|
||||
self.pscur = self.psconn.cursor()
|
||||
|
||||
@property
|
||||
def pg(self):
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
def zenbenchmark(self):
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
def pg_bin(self):
|
||||
return self._pg_bin
|
||||
|
||||
def flush(self):
|
||||
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
|
||||
|
||||
def report_peak_memory_use(self) -> None:
|
||||
self.zenbenchmark.record("peak_mem",
|
||||
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
def report_size(self) -> None:
|
||||
timeline_size = self.zenbenchmark.get_timeline_size(self.env.repo_dir,
|
||||
self.env.initial_tenant,
|
||||
self.timeline)
|
||||
self.zenbenchmark.record('size',
|
||||
timeline_size / (1024 * 1024),
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
def record_pageserver_writes(self, out_name):
|
||||
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
|
||||
|
||||
def record_duration(self, out_name):
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
class VanillaCompare(PgCompare):
|
||||
"""PgCompare interface for vanilla postgres."""
|
||||
def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres):
|
||||
self._pg = vanilla_pg
|
||||
self._zenbenchmark = zenbenchmark
|
||||
vanilla_pg.configure(['shared_buffers=1MB'])
|
||||
vanilla_pg.start()
|
||||
|
||||
# Long-lived cursor, useful for flushing
|
||||
self.conn = self.pg.connect()
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
def pg(self):
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
def zenbenchmark(self):
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
def pg_bin(self):
|
||||
return self._pg.pg_bin
|
||||
|
||||
def flush(self):
|
||||
self.cur.execute("checkpoint")
|
||||
|
||||
def report_peak_memory_use(self) -> None:
|
||||
pass # TODO find something
|
||||
|
||||
def report_size(self) -> None:
|
||||
data_size = self.pg.get_subdir_size('base')
|
||||
self.zenbenchmark.record('data_size',
|
||||
data_size / (1024 * 1024),
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
wal_size = self.pg.get_subdir_size('pg_wal')
|
||||
self.zenbenchmark.record('wal_size',
|
||||
wal_size / (1024 * 1024),
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
@contextmanager
|
||||
def record_pageserver_writes(self, out_name):
|
||||
yield # Do nothing
|
||||
|
||||
def record_duration(self, out_name):
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare:
|
||||
branch_name = request.node.name
|
||||
return ZenithCompare(zenbenchmark, zenith_simple_env, pg_bin, branch_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
|
||||
return VanillaCompare(zenbenchmark, vanilla_pg)
|
||||
|
||||
|
||||
@pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"])
|
||||
def zenith_with_baseline(request) -> PgCompare:
|
||||
"""Parameterized fixture that helps compare zenith against vanilla postgres.
|
||||
|
||||
A test that uses this fixture turns into a parameterized test that runs against:
|
||||
1. A vanilla postgres instance
|
||||
2. A simple zenith env (see zenith_simple_env)
|
||||
3. Possibly other postgres protocol implementations.
|
||||
|
||||
The main goal of this fixture is to make it easier for people to read and write
|
||||
performance tests. Easy test writing leads to more tests.
|
||||
|
||||
Perfect encapsulation of the postgres implementations is **not** a goal because
|
||||
it's impossible. Operational and configuration differences in the different
|
||||
implementations sometimes matter, and the writer of the test should be mindful
|
||||
of that.
|
||||
|
||||
If a test requires some one-off special implementation-specific logic, use of
|
||||
isinstance(zenith_with_baseline, ZenithCompare) is encouraged. Though if that
|
||||
implementation-specific logic is widely useful across multiple tests, it might
|
||||
make sense to add methods to the PgCompare class.
|
||||
"""
|
||||
fixture = request.getfixturevalue(request.param)
|
||||
if isinstance(fixture, PgCompare):
|
||||
return fixture
|
||||
else:
|
||||
raise AssertionError(f"test error: fixture {request.param} is not PgCompare")
|
||||
26
test_runner/fixtures/slow.py
Normal file
26
test_runner/fixtures/slow.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
"""
|
||||
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
|
||||
tests are excluded. They need to be specifically requested with the --runslow flag in
|
||||
order to run.
|
||||
|
||||
Copied from here: https://docs.pytest.org/en/latest/example/simple.html
|
||||
"""
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "slow: mark test as slow to run")
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
if config.getoption("--runslow"):
|
||||
# --runslow given in cli: do not skip slow tests
|
||||
return
|
||||
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
|
||||
for item in items:
|
||||
if "slow" in item.keywords:
|
||||
item.add_marker(skip_slow)
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import textwrap
|
||||
from cached_property import cached_property
|
||||
import asyncpg
|
||||
import os
|
||||
@@ -26,7 +27,7 @@ from dataclasses import dataclass
|
||||
|
||||
# Type-related stuff
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
|
||||
from typing_extensions import Literal
|
||||
import pytest
|
||||
|
||||
@@ -44,9 +45,8 @@ the standard pytest.fixture with some extra behavior.
|
||||
There are several environment variables that can control the running of tests:
|
||||
ZENITH_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.
|
||||
|
||||
To use fixtures in a test file, add this line of code:
|
||||
|
||||
>>> pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
There's no need to import this file to use it. It should be declared as a plugin
|
||||
inside conftest.py, and that makes it available to all tests.
|
||||
|
||||
Don't import functions from this file, or pytest will emit warnings. Instead
|
||||
put directly-importable functions into utils.py or another separate file.
|
||||
@@ -184,6 +184,16 @@ def worker_base_port(worker_seq_no: int):
|
||||
return BASE_PORT + worker_seq_no * WORKER_PORT_NUM
|
||||
|
||||
|
||||
def get_dir_size(path: str) -> int:
|
||||
"""Return size in bytes."""
|
||||
totalbytes = 0
|
||||
for root, dirs, files in os.walk(path):
|
||||
for name in files:
|
||||
totalbytes += os.path.getsize(os.path.join(root, name))
|
||||
|
||||
return totalbytes
|
||||
|
||||
|
||||
def can_bind(host: str, port: int) -> bool:
|
||||
"""
|
||||
Check whether a host:port is available to bind for listening
|
||||
@@ -230,7 +240,7 @@ class PgProtocol:
|
||||
def __init__(self, host: str, port: int, username: Optional[str] = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username or "zenith_admin"
|
||||
self.username = username
|
||||
|
||||
def connstr(self,
|
||||
*,
|
||||
@@ -242,10 +252,15 @@ class PgProtocol:
|
||||
"""
|
||||
|
||||
username = username or self.username
|
||||
res = f'host={self.host} port={self.port} user={username} dbname={dbname}'
|
||||
if not password:
|
||||
return res
|
||||
return f'{res} password={password}'
|
||||
res = f'host={self.host} port={self.port} dbname={dbname}'
|
||||
|
||||
if username:
|
||||
res = f'{res} user={username}'
|
||||
|
||||
if password:
|
||||
res = f'{res} password={password}'
|
||||
|
||||
return res
|
||||
|
||||
# autocommit=True here by default because that's what we need most of the time
|
||||
def connect(self,
|
||||
@@ -410,6 +425,14 @@ class ZenithEnvBuilder:
|
||||
self.env = ZenithEnv(self)
|
||||
return self.env
|
||||
|
||||
def start(self):
|
||||
self.env.start()
|
||||
|
||||
def init_start(self) -> ZenithEnv:
|
||||
env = self.init()
|
||||
self.start()
|
||||
return env
|
||||
|
||||
"""
|
||||
Sets up the pageserver to use the local fs at the `test_dir/local_fs_remote_storage` path.
|
||||
Errors, if the pageserver has some remote storage configuration already, unless `force_enable` is not set to `True`.
|
||||
@@ -501,6 +524,7 @@ class ZenithEnv:
|
||||
self.rust_log_override = config.rust_log_override
|
||||
self.port_distributor = config.port_distributor
|
||||
self.s3_mock_server = config.s3_mock_server
|
||||
self.zenith_cli = ZenithCli(env=self)
|
||||
|
||||
self.postgres = PostgresFactory(self)
|
||||
|
||||
@@ -508,12 +532,12 @@ class ZenithEnv:
|
||||
|
||||
# generate initial tenant ID here instead of letting 'zenith init' generate it,
|
||||
# so that we don't need to dig it out of the config file afterwards.
|
||||
self.initial_tenant = uuid.uuid4().hex
|
||||
self.initial_tenant = uuid.uuid4()
|
||||
|
||||
# Create a config file corresponding to the options
|
||||
toml = f"""
|
||||
default_tenantid = '{self.initial_tenant}'
|
||||
"""
|
||||
toml = textwrap.dedent(f"""
|
||||
default_tenantid = '{self.initial_tenant.hex}'
|
||||
""")
|
||||
|
||||
# Create config for pageserver
|
||||
pageserver_port = PageserverPort(
|
||||
@@ -522,12 +546,13 @@ default_tenantid = '{self.initial_tenant}'
|
||||
)
|
||||
pageserver_auth_type = "ZenithJWT" if config.pageserver_auth_enabled else "Trust"
|
||||
|
||||
toml += f"""
|
||||
[pageserver]
|
||||
listen_pg_addr = 'localhost:{pageserver_port.pg}'
|
||||
listen_http_addr = 'localhost:{pageserver_port.http}'
|
||||
auth_type = '{pageserver_auth_type}'
|
||||
"""
|
||||
toml += textwrap.dedent(f"""
|
||||
[pageserver]
|
||||
id=1
|
||||
listen_pg_addr = 'localhost:{pageserver_port.pg}'
|
||||
listen_http_addr = 'localhost:{pageserver_port.http}'
|
||||
auth_type = '{pageserver_auth_type}'
|
||||
""")
|
||||
|
||||
# Create a corresponding ZenithPageserver object
|
||||
self.pageserver = ZenithPageserver(self,
|
||||
@@ -540,33 +565,22 @@ auth_type = '{pageserver_auth_type}'
|
||||
pg=self.port_distributor.get_port(),
|
||||
http=self.port_distributor.get_port(),
|
||||
)
|
||||
|
||||
if config.num_safekeepers == 1:
|
||||
name = "single"
|
||||
else:
|
||||
name = f"sk{i}"
|
||||
toml += f"""
|
||||
[[safekeepers]]
|
||||
name = '{name}'
|
||||
pg_port = {port.pg}
|
||||
http_port = {port.http}
|
||||
sync = false # Disable fsyncs to make the tests go faster
|
||||
"""
|
||||
safekeeper = Safekeeper(env=self, name=name, port=port)
|
||||
id = i # assign ids sequentially
|
||||
toml += textwrap.dedent(f"""
|
||||
[[safekeepers]]
|
||||
id = {id}
|
||||
pg_port = {port.pg}
|
||||
http_port = {port.http}
|
||||
sync = false # Disable fsyncs to make the tests go faster
|
||||
""")
|
||||
safekeeper = Safekeeper(env=self, id=id, port=port)
|
||||
self.safekeepers.append(safekeeper)
|
||||
|
||||
log.info(f"Config: {toml}")
|
||||
|
||||
# Run 'zenith init' using the config file we constructed
|
||||
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
|
||||
tmp.write(toml)
|
||||
tmp.flush()
|
||||
|
||||
cmd = ['init', f'--config={tmp.name}']
|
||||
append_pageserver_param_overrides(cmd, config.pageserver_remote_storage)
|
||||
|
||||
self.zenith_cli(cmd)
|
||||
self.zenith_cli.init(toml)
|
||||
|
||||
def start(self):
|
||||
# Start up the page server and all the safekeepers
|
||||
self.pageserver.start()
|
||||
|
||||
@@ -577,69 +591,12 @@ sync = false # Disable fsyncs to make the tests go faster
|
||||
""" Get list of safekeeper endpoints suitable for wal_acceptors GUC """
|
||||
return ','.join([f'localhost:{wa.port.pg}' for wa in self.safekeepers])
|
||||
|
||||
def create_tenant(self, tenant_id: Optional[str] = None):
|
||||
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
|
||||
if tenant_id is None:
|
||||
tenant_id = uuid.uuid4().hex
|
||||
res = self.zenith_cli(['tenant', 'create', tenant_id])
|
||||
res.check_returncode()
|
||||
tenant_id = uuid.uuid4()
|
||||
self.zenith_cli.create_tenant(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
def zenith_cli(self, arguments: List[str]) -> 'subprocess.CompletedProcess[str]':
|
||||
"""
|
||||
Run "zenith" with the specified arguments.
|
||||
|
||||
Arguments must be in list form, e.g. ['pg', 'create']
|
||||
|
||||
Return both stdout and stderr, which can be accessed as
|
||||
|
||||
>>> result = env.zenith_cli(...)
|
||||
>>> assert result.stderr == ""
|
||||
>>> log.info(result.stdout)
|
||||
"""
|
||||
|
||||
assert type(arguments) == list
|
||||
|
||||
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
|
||||
|
||||
args = [bin_zenith] + arguments
|
||||
log.info('Running command "{}"'.format(' '.join(args)))
|
||||
log.info(f'Running in "{self.repo_dir}"')
|
||||
|
||||
env_vars = os.environ.copy()
|
||||
env_vars['ZENITH_REPO_DIR'] = str(self.repo_dir)
|
||||
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
|
||||
|
||||
if self.rust_log_override is not None:
|
||||
env_vars['RUST_LOG'] = self.rust_log_override
|
||||
|
||||
# Pass coverage settings
|
||||
var = 'LLVM_PROFILE_FILE'
|
||||
val = os.environ.get(var)
|
||||
if val:
|
||||
env_vars[var] = val
|
||||
|
||||
# Intercept CalledProcessError and print more info
|
||||
try:
|
||||
res = subprocess.run(args,
|
||||
env=env_vars,
|
||||
check=True,
|
||||
universal_newlines=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
log.info(f"Run success: {res.stdout}")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
# this way command output will be in recorded and shown in CI in failure message
|
||||
msg = f"""\
|
||||
Run failed: {exc}
|
||||
stdout: {exc.stdout}
|
||||
stderr: {exc.stderr}
|
||||
"""
|
||||
log.info(msg)
|
||||
|
||||
raise Exception(msg) from exc
|
||||
|
||||
return res
|
||||
|
||||
@cached_property
|
||||
def auth_keys(self) -> AuthKeys:
|
||||
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
|
||||
@@ -664,10 +621,10 @@ def _shared_simple_env(request: Any, port_distributor) -> Iterator[ZenithEnv]:
|
||||
|
||||
with ZenithEnvBuilder(Path(repo_dir), port_distributor) as builder:
|
||||
|
||||
env = builder.init()
|
||||
env = builder.init_start()
|
||||
|
||||
# For convenience in tests, create a branch from the freshly-initialized cluster.
|
||||
env.zenith_cli(["branch", "empty", "main"])
|
||||
env.zenith_cli.create_branch("empty", "main")
|
||||
|
||||
# Return the builder to the caller
|
||||
yield env
|
||||
@@ -698,7 +655,7 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
|
||||
To use, define 'zenith_env_builder' fixture in your test to get access to the
|
||||
builder object. Set properties on it to describe the environment.
|
||||
Finally, initialize and start up the environment by calling
|
||||
zenith_env_builder.init().
|
||||
zenith_env_builder.init_start().
|
||||
|
||||
After the initialization, you can launch compute nodes by calling
|
||||
the functions in the 'env.postgres' factory object, stop/start the
|
||||
@@ -713,6 +670,10 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
|
||||
yield builder
|
||||
|
||||
|
||||
class ZenithPageserverApiException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ZenithPageserverHttpClient(requests.Session):
|
||||
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
@@ -722,22 +683,32 @@ class ZenithPageserverHttpClient(requests.Session):
|
||||
if auth_token is not None:
|
||||
self.headers['Authorization'] = f'Bearer {auth_token}'
|
||||
|
||||
def verbose_error(self, res: requests.Response):
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
try:
|
||||
msg = res.json()['msg']
|
||||
except:
|
||||
msg = ''
|
||||
raise ZenithPageserverApiException(msg) from e
|
||||
|
||||
def check_status(self):
|
||||
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
|
||||
|
||||
def timeline_attach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
|
||||
res = self.post(
|
||||
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/attach", )
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
|
||||
def timeline_detach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
|
||||
res = self.post(
|
||||
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/detach", )
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
|
||||
def branch_list(self, tenant_id: uuid.UUID) -> List[Dict[Any, Any]]:
|
||||
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, list)
|
||||
return res_json
|
||||
@@ -749,7 +720,7 @@ class ZenithPageserverHttpClient(requests.Session):
|
||||
'name': name,
|
||||
'start_point': start_point,
|
||||
})
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, dict)
|
||||
return res_json
|
||||
@@ -758,14 +729,14 @@ class ZenithPageserverHttpClient(requests.Session):
|
||||
res = self.get(
|
||||
f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}/{name}?include-non-incremental-logical-size=1",
|
||||
)
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, dict)
|
||||
return res_json
|
||||
|
||||
def tenant_list(self) -> List[Dict[Any, Any]]:
|
||||
res = self.get(f"http://localhost:{self.port}/v1/tenant")
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, list)
|
||||
return res_json
|
||||
@@ -777,27 +748,27 @@ class ZenithPageserverHttpClient(requests.Session):
|
||||
'tenant_id': tenant_id.hex,
|
||||
},
|
||||
)
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
return res.json()
|
||||
|
||||
def timeline_list(self, tenant_id: uuid.UUID) -> List[str]:
|
||||
res = self.get(f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}")
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, list)
|
||||
return res_json
|
||||
|
||||
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
|
||||
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Dict[Any, Any]:
|
||||
res = self.get(
|
||||
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}")
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
assert isinstance(res_json, dict)
|
||||
return res_json
|
||||
|
||||
def get_metrics(self) -> str:
|
||||
res = self.get(f"http://localhost:{self.port}/metrics")
|
||||
res.raise_for_status()
|
||||
self.verbose_error(res)
|
||||
return res.text
|
||||
|
||||
|
||||
@@ -824,6 +795,190 @@ class S3Storage:
|
||||
RemoteStorage = Union[LocalFsStorage, S3Storage]
|
||||
|
||||
|
||||
class ZenithCli:
|
||||
"""
|
||||
A typed wrapper around the `zenith` CLI tool.
|
||||
Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
|
||||
"""
|
||||
def __init__(self, env: ZenithEnv) -> None:
|
||||
self.env = env
|
||||
pass
|
||||
|
||||
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
|
||||
if tenant_id is None:
|
||||
tenant_id = uuid.uuid4()
|
||||
self.raw_cli(['tenant', 'create', tenant_id.hex])
|
||||
return tenant_id
|
||||
|
||||
def list_tenants(self) -> 'subprocess.CompletedProcess[str]':
|
||||
return self.raw_cli(['tenant', 'list'])
|
||||
|
||||
def create_branch(self,
|
||||
branch_name: str,
|
||||
starting_point: str,
|
||||
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['branch']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
args.extend([branch_name, starting_point])
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def list_branches(self,
|
||||
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['branch']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
return self.raw_cli(args)
|
||||
|
||||
def init(self, config_toml: str) -> 'subprocess.CompletedProcess[str]':
|
||||
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
|
||||
tmp.write(config_toml)
|
||||
tmp.flush()
|
||||
|
||||
cmd = ['init', f'--config={tmp.name}']
|
||||
append_pageserver_param_overrides(cmd, self.env.pageserver.remote_storage)
|
||||
|
||||
return self.raw_cli(cmd)
|
||||
|
||||
def pageserver_start(self, overrides=()) -> 'subprocess.CompletedProcess[str]':
|
||||
start_args = ['pageserver', 'start', *overrides]
|
||||
|
||||
append_pageserver_param_overrides(start_args, self.env.pageserver.remote_storage)
|
||||
return self.raw_cli(start_args)
|
||||
|
||||
def pageserver_stop(self, immediate=False) -> 'subprocess.CompletedProcess[str]':
|
||||
cmd = ['pageserver', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
|
||||
log.info(f"Stopping pageserver with {cmd}")
|
||||
return self.raw_cli(cmd)
|
||||
|
||||
def safekeeper_start(self, id: int) -> 'subprocess.CompletedProcess[str]':
|
||||
return self.raw_cli(['safekeeper', 'start', str(id)])
|
||||
|
||||
def safekeeper_stop(self,
|
||||
id: Optional[int] = None,
|
||||
immediate=False) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['safekeeper', 'stop']
|
||||
if id is not None:
|
||||
args.extend(str(id))
|
||||
if immediate:
|
||||
args.extend(['-m', 'immediate'])
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_create(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
timeline_spec: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'create']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if port is not None:
|
||||
args.append(f'--port={port}')
|
||||
args.append(node_name)
|
||||
if timeline_spec is not None:
|
||||
args.append(timeline_spec)
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_start(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
timeline_spec: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'start']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if port is not None:
|
||||
args.append(f'--port={port}')
|
||||
args.append(node_name)
|
||||
if timeline_spec is not None:
|
||||
args.append(timeline_spec)
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_stop(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
destroy=False,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'stop']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if destroy:
|
||||
args.append('--destroy')
|
||||
args.append(node_name)
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def raw_cli(self,
|
||||
arguments: List[str],
|
||||
check_return_code=True) -> 'subprocess.CompletedProcess[str]':
|
||||
"""
|
||||
Run "zenith" with the specified arguments.
|
||||
|
||||
Arguments must be in list form, e.g. ['pg', 'create']
|
||||
|
||||
Return both stdout and stderr, which can be accessed as
|
||||
|
||||
>>> result = env.zenith_cli.raw_cli(...)
|
||||
>>> assert result.stderr == ""
|
||||
>>> log.info(result.stdout)
|
||||
"""
|
||||
|
||||
assert type(arguments) == list
|
||||
|
||||
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
|
||||
|
||||
args = [bin_zenith] + arguments
|
||||
log.info('Running command "{}"'.format(' '.join(args)))
|
||||
log.info(f'Running in "{self.env.repo_dir}"')
|
||||
|
||||
env_vars = os.environ.copy()
|
||||
env_vars['ZENITH_REPO_DIR'] = str(self.env.repo_dir)
|
||||
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
|
||||
|
||||
if self.env.rust_log_override is not None:
|
||||
env_vars['RUST_LOG'] = self.env.rust_log_override
|
||||
|
||||
# Pass coverage settings
|
||||
var = 'LLVM_PROFILE_FILE'
|
||||
val = os.environ.get(var)
|
||||
if val:
|
||||
env_vars[var] = val
|
||||
|
||||
# Intercept CalledProcessError and print more info
|
||||
try:
|
||||
res = subprocess.run(args,
|
||||
env=env_vars,
|
||||
check=True,
|
||||
universal_newlines=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
log.info(f"Run success: {res.stdout}")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
# this way command output will be in recorded and shown in CI in failure message
|
||||
msg = f"""\
|
||||
Run failed: {exc}
|
||||
stdout: {exc.stdout}
|
||||
stderr: {exc.stderr}
|
||||
"""
|
||||
log.info(msg)
|
||||
|
||||
raise Exception(msg) from exc
|
||||
|
||||
if check_return_code:
|
||||
res.check_returncode()
|
||||
return res
|
||||
|
||||
|
||||
class ZenithPageserver(PgProtocol):
|
||||
"""
|
||||
An object representing a running pageserver.
|
||||
@@ -835,23 +990,20 @@ class ZenithPageserver(PgProtocol):
|
||||
port: PageserverPort,
|
||||
remote_storage: Optional[RemoteStorage] = None,
|
||||
enable_auth=False):
|
||||
super().__init__(host='localhost', port=port.pg)
|
||||
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
|
||||
self.env = env
|
||||
self.running = False
|
||||
self.service_port = port # do not shadow PgProtocol.port which is just int
|
||||
self.remote_storage = remote_storage
|
||||
|
||||
def start(self) -> 'ZenithPageserver':
|
||||
def start(self, overrides=()) -> 'ZenithPageserver':
|
||||
"""
|
||||
Start the page server.
|
||||
Returns self.
|
||||
"""
|
||||
assert self.running == False
|
||||
|
||||
start_args = ['pageserver', 'start']
|
||||
append_pageserver_param_overrides(start_args, self.remote_storage)
|
||||
|
||||
self.env.zenith_cli(start_args)
|
||||
self.env.zenith_cli.pageserver_start(overrides=overrides)
|
||||
self.running = True
|
||||
return self
|
||||
|
||||
@@ -860,13 +1012,8 @@ class ZenithPageserver(PgProtocol):
|
||||
Stop the page server.
|
||||
Returns self.
|
||||
"""
|
||||
cmd = ['pageserver', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
|
||||
log.info(f"Stopping pageserver with {cmd}")
|
||||
if self.running:
|
||||
self.env.zenith_cli(cmd)
|
||||
self.env.zenith_cli.pageserver_stop(immediate)
|
||||
self.running = False
|
||||
|
||||
return self
|
||||
@@ -973,10 +1120,54 @@ def pg_bin(test_output_dir: str) -> PgBin:
|
||||
return PgBin(test_output_dir)
|
||||
|
||||
|
||||
class VanillaPostgres(PgProtocol):
|
||||
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
|
||||
super().__init__(host='localhost', port=port)
|
||||
self.pgdatadir = pgdatadir
|
||||
self.pg_bin = pg_bin
|
||||
self.running = False
|
||||
self.pg_bin.run_capture(['initdb', '-D', pgdatadir])
|
||||
|
||||
def configure(self, options: List[str]) -> None:
|
||||
"""Append lines into postgresql.conf file."""
|
||||
assert not self.running
|
||||
with open(os.path.join(self.pgdatadir, 'postgresql.conf'), 'a') as conf_file:
|
||||
conf_file.writelines(options)
|
||||
|
||||
def start(self) -> None:
|
||||
assert not self.running
|
||||
self.running = True
|
||||
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'start'])
|
||||
|
||||
def stop(self) -> None:
|
||||
assert self.running
|
||||
self.running = False
|
||||
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'stop'])
|
||||
|
||||
def get_subdir_size(self, subdir) -> int:
|
||||
"""Return size of pgdatadir subdirectory in bytes."""
|
||||
return get_dir_size(os.path.join(self.pgdatadir, subdir))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if self.running:
|
||||
self.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
|
||||
pgdatadir = os.path.join(test_output_dir, "pgdata-vanilla")
|
||||
pg_bin = PgBin(test_output_dir)
|
||||
with VanillaPostgres(pgdatadir, pg_bin, 5432) as vanilla_pg:
|
||||
yield vanilla_pg
|
||||
|
||||
|
||||
class Postgres(PgProtocol):
|
||||
""" An object representing a running postgres daemon. """
|
||||
def __init__(self, env: ZenithEnv, tenant_id: str, port: int):
|
||||
super().__init__(host='localhost', port=port)
|
||||
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
|
||||
super().__init__(host='localhost', port=port, username='zenith_admin')
|
||||
|
||||
self.env = env
|
||||
self.running = False
|
||||
@@ -1002,20 +1193,14 @@ class Postgres(PgProtocol):
|
||||
if branch is None:
|
||||
branch = node_name
|
||||
|
||||
self.env.zenith_cli([
|
||||
'pg',
|
||||
'create',
|
||||
f'--tenantid={self.tenant_id}',
|
||||
f'--port={self.port}',
|
||||
node_name,
|
||||
branch
|
||||
])
|
||||
self.env.zenith_cli.pg_create(node_name,
|
||||
tenant_id=self.tenant_id,
|
||||
port=self.port,
|
||||
timeline_spec=branch)
|
||||
self.node_name = node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
|
||||
self.pgdata_dir = os.path.join(self.env.repo_dir, path)
|
||||
|
||||
if self.env.safekeepers:
|
||||
self.adjust_for_wal_acceptors(self.env.get_safekeeper_connstrs())
|
||||
if config_lines is None:
|
||||
config_lines = []
|
||||
self.config(config_lines)
|
||||
@@ -1032,8 +1217,9 @@ class Postgres(PgProtocol):
|
||||
|
||||
log.info(f"Starting postgres node {self.node_name}")
|
||||
|
||||
run_result = self.env.zenith_cli(
|
||||
['pg', 'start', f'--tenantid={self.tenant_id}', f'--port={self.port}', self.node_name])
|
||||
run_result = self.env.zenith_cli.pg_start(self.node_name,
|
||||
tenant_id=self.tenant_id,
|
||||
port=self.port)
|
||||
self.running = True
|
||||
|
||||
log.info(f"stdout: {run_result.stdout}")
|
||||
@@ -1043,7 +1229,7 @@ class Postgres(PgProtocol):
|
||||
def pg_data_dir_path(self) -> str:
|
||||
""" Path to data directory """
|
||||
assert self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
|
||||
return os.path.join(self.env.repo_dir, path)
|
||||
|
||||
def pg_xact_dir_path(self) -> str:
|
||||
@@ -1072,7 +1258,9 @@ class Postgres(PgProtocol):
|
||||
# walproposer uses different application_name
|
||||
if ("synchronous_standby_names" in cfg_line or
|
||||
# don't ask pageserver to fetch WAL from compute
|
||||
"callmemaybe_connstring" in cfg_line):
|
||||
"callmemaybe_connstring" in cfg_line or
|
||||
# don't repeat wal_acceptors multiple times
|
||||
"wal_acceptors" in cfg_line):
|
||||
continue
|
||||
f.write(cfg_line)
|
||||
f.write("synchronous_standby_names = 'walproposer'\n")
|
||||
@@ -1101,7 +1289,7 @@ class Postgres(PgProtocol):
|
||||
|
||||
if self.running:
|
||||
assert self.node_name is not None
|
||||
self.env.zenith_cli(['pg', 'stop', self.node_name, f'--tenantid={self.tenant_id}'])
|
||||
self.env.zenith_cli.pg_stop(self.node_name, tenant_id=self.tenant_id)
|
||||
self.running = False
|
||||
|
||||
return self
|
||||
@@ -1113,8 +1301,7 @@ class Postgres(PgProtocol):
|
||||
"""
|
||||
|
||||
assert self.node_name is not None
|
||||
self.env.zenith_cli(
|
||||
['pg', 'stop', '--destroy', self.node_name, f'--tenantid={self.tenant_id}'])
|
||||
self.env.zenith_cli.pg_stop(self.node_name, self.tenant_id, destroy=True)
|
||||
self.node_name = None
|
||||
|
||||
return self
|
||||
@@ -1156,7 +1343,7 @@ class PostgresFactory:
|
||||
def create_start(self,
|
||||
node_name: str = "main",
|
||||
branch: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
config_lines: Optional[List[str]] = None) -> Postgres:
|
||||
|
||||
pg = Postgres(
|
||||
@@ -1176,7 +1363,7 @@ class PostgresFactory:
|
||||
def create(self,
|
||||
node_name: str = "main",
|
||||
branch: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
config_lines: Optional[List[str]] = None) -> Postgres:
|
||||
|
||||
pg = Postgres(
|
||||
@@ -1217,12 +1404,14 @@ class Safekeeper:
|
||||
""" An object representing a running safekeeper daemon. """
|
||||
env: ZenithEnv
|
||||
port: SafekeeperPort
|
||||
name: str # identifier for logging
|
||||
id: int
|
||||
auth_token: Optional[str] = None
|
||||
running: bool = False
|
||||
|
||||
def start(self) -> 'Safekeeper':
|
||||
self.env.zenith_cli(['safekeeper', 'start', self.name])
|
||||
|
||||
assert self.running == False
|
||||
self.env.zenith_cli.safekeeper_start(self.id)
|
||||
self.running = True
|
||||
# wait for wal acceptor start by checking its status
|
||||
started_at = time.time()
|
||||
while True:
|
||||
@@ -1240,16 +1429,14 @@ class Safekeeper:
|
||||
return self
|
||||
|
||||
def stop(self, immediate=False) -> 'Safekeeper':
|
||||
cmd = ['safekeeper', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
cmd.append(self.name)
|
||||
|
||||
log.info('Stopping safekeeper {}'.format(self.name))
|
||||
self.env.zenith_cli(cmd)
|
||||
log.info('Stopping safekeeper {}'.format(self.id))
|
||||
self.env.zenith_cli.safekeeper_stop(self.id, immediate)
|
||||
self.running = False
|
||||
return self
|
||||
|
||||
def append_logical_message(self, tenant_id: str, timeline_id: str,
|
||||
def append_logical_message(self,
|
||||
tenant_id: uuid.UUID,
|
||||
timeline_id: uuid.UUID,
|
||||
request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Send JSON_CTRL query to append LogicalMessage to WAL and modify
|
||||
@@ -1259,7 +1446,7 @@ class Safekeeper:
|
||||
|
||||
# "replication=0" hacks psycopg not to send additional queries
|
||||
# on startup, see https://github.com/psycopg/psycopg2/pull/482
|
||||
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id} ztenantid={tenant_id}'"
|
||||
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id.hex} ztenantid={tenant_id.hex}'"
|
||||
|
||||
with closing(psycopg2.connect(connstr)) as conn:
|
||||
# server doesn't support transactions
|
||||
@@ -1288,8 +1475,8 @@ class SafekeeperTimelineStatus:
|
||||
class SafekeeperMetrics:
|
||||
# These are metrics from Prometheus which uses float64 internally.
|
||||
# As a consequence, values may differ from real original int64s.
|
||||
flush_lsn_inexact: Dict[str, int] = field(default_factory=dict)
|
||||
commit_lsn_inexact: Dict[str, int] = field(default_factory=dict)
|
||||
flush_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
|
||||
commit_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SafekeeperHttpClient(requests.Session):
|
||||
@@ -1313,14 +1500,16 @@ class SafekeeperHttpClient(requests.Session):
|
||||
all_metrics_text = request_result.text
|
||||
|
||||
metrics = SafekeeperMetrics()
|
||||
for match in re.finditer(r'^safekeeper_flush_lsn{ztli="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.flush_lsn_inexact[match.group(1)] = int(match.group(2))
|
||||
for match in re.finditer(r'^safekeeper_commit_lsn{ztli="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.commit_lsn_inexact[match.group(1)] = int(match.group(2))
|
||||
for match in re.finditer(
|
||||
r'^safekeeper_flush_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.flush_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
|
||||
for match in re.finditer(
|
||||
r'^safekeeper_commit_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.commit_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
|
||||
return metrics
|
||||
|
||||
|
||||
@@ -1429,7 +1618,7 @@ def check_restored_datadir_content(test_output_dir: str, env: ZenithEnv, pg: Pos
|
||||
{psql_path} \
|
||||
--no-psqlrc \
|
||||
postgres://localhost:{env.pageserver.service_port.pg} \
|
||||
-c 'basebackup {pg.tenant_id} {timeline}' \
|
||||
-c 'basebackup {pg.tenant_id.hex} {timeline}' \
|
||||
| tar -x -C {restored_dir_path}
|
||||
"""
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ from contextlib import closing
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
|
||||
|
||||
#
|
||||
@@ -16,47 +15,19 @@ pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
|
||||
# 3. Disk space used
|
||||
# 4. Peak memory usage
|
||||
#
|
||||
def test_bulk_insert(zenith_simple_env: ZenithEnv, zenbenchmark: ZenithBenchmarker):
|
||||
env = zenith_simple_env
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_bulk_insert", "empty"])
|
||||
|
||||
pg = env.postgres.create_start('test_bulk_insert')
|
||||
log.info("postgres is running on 'test_bulk_insert' branch")
|
||||
|
||||
# Open a connection directly to the page server that we'll use to force
|
||||
# flushing the layers to disk
|
||||
psconn = env.pageserver.connect()
|
||||
pscur = psconn.cursor()
|
||||
def test_bulk_insert(zenith_with_baseline: PgCompare):
|
||||
env = zenith_with_baseline
|
||||
|
||||
# Get the timeline ID of our branch. We need it for the 'do_gc' command
|
||||
with closing(pg.connect()) as conn:
|
||||
with closing(env.pg.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW zenith.zenith_timeline")
|
||||
timeline = cur.fetchone()[0]
|
||||
|
||||
cur.execute("create table huge (i int, j int);")
|
||||
|
||||
# Run INSERT, recording the time and I/O it takes
|
||||
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
|
||||
with zenbenchmark.record_duration('insert'):
|
||||
with env.record_pageserver_writes('pageserver_writes'):
|
||||
with env.record_duration('insert'):
|
||||
cur.execute("insert into huge values (generate_series(1, 5000000), 0);")
|
||||
env.flush()
|
||||
|
||||
# Flush the layers from memory to disk. This is included in the reported
|
||||
# time and I/O
|
||||
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
|
||||
|
||||
# Record peak memory usage
|
||||
zenbenchmark.record("peak_mem",
|
||||
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
|
||||
# Report disk space used by the repository
|
||||
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir,
|
||||
env.initial_tenant,
|
||||
timeline)
|
||||
zenbenchmark.record('size',
|
||||
timeline_size / (1024 * 1024),
|
||||
'MB',
|
||||
report=MetricReport.LOWER_IS_BETTER)
|
||||
env.report_peak_memory_use()
|
||||
env.report_size()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user