Compare commits

..

2 Commits

Author           SHA1        Message                         Date
Bojan Serafimov  64fcf4f096  Implement mock console          2022-02-09 14:30:01 -05:00
Dmitry Ivanov    18d3d078ad  [WIP] [proxy] Migrate to async  2022-02-08 05:43:32 +03:00
75 changed files with 1940 additions and 2415 deletions

View File

@@ -54,8 +54,7 @@ jobs:
if [ ! -e tmp_install/bin/postgres ]; then
# "depth 1" saves some time by not cloning the whole repo
git submodule update --init --depth 1
# bail out on any warnings
COPT='-Werror' mold -run make postgres -j$(nproc)
mold -run make postgres -j$(nproc)
fi
- save_cache:
@@ -298,7 +297,6 @@ jobs:
- PLATFORM: zenith-local-ci
command: |
PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
rm -rf $PERF_REPORT_DIR
TEST_SELECTION="test_runner/<< parameters.test_selection >>"
EXTRA_PARAMS="<< parameters.extra_params >>"
@@ -343,6 +341,7 @@ jobs:
if << parameters.save_perf_report >>; then
if [[ $CIRCLE_BRANCH == "main" ]]; then
# TODO: reuse scripts/git-upload
export REPORT_FROM="$PERF_REPORT_DIR"
export REPORT_TO=local
scripts/generate_and_push_perf_report.sh
@@ -598,7 +597,6 @@ workflows:
- build-postgres-<< matrix.build_type >>
- run-pytest:
name: pg_regress-tests-<< matrix.build_type >>
context: PERF_TEST_RESULT_CONNSTR
matrix:
parameters:
build_type: ["debug", "release"]
@@ -616,7 +614,6 @@ workflows:
- build-zenith-<< matrix.build_type >>
- run-pytest:
name: benchmarks
context: PERF_TEST_RESULT_CONNSTR
build_type: release
test_selection: performance
run_in_parallel: false

View File

@@ -3,7 +3,7 @@ name: benchmarking
on:
# uncomment to run on push for debugging your PR
# push:
# branches: [ your branch ]
# branches: [ mybranch ]
schedule:
# * is a special character in YAML so you have to quote this string
# ┌───────────── minute (0 - 59)
@@ -41,7 +41,7 @@ jobs:
run: |
python3 -m pip install --upgrade poetry wheel
# since pip/poetry caches are reused, there shouldn't be any trouble with installing every time
./scripts/pysync
poetry install
- name: Show versions
run: |
@@ -89,15 +89,11 @@ jobs:
BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}"
REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally
run: |
# just to be sure that no data was cached on the self-hosted runner
# since it might generate duplicates when calling ingest_perf_test_result.py
rm -rf perf-report-staging
mkdir -p perf-report-staging
./scripts/pytest test_runner/performance/ -v -m "remote_cluster" --skip-interfering-proc-check --out-dir perf-report-staging
- name: Submit result
env:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
run: |
REPORT_FROM=$(realpath perf-report-staging) REPORT_TO=staging scripts/generate_and_push_perf_report.sh

Cargo.lock (generated, 666 lines changed)

File diff suppressed because it is too large.

View File

@@ -16,3 +16,8 @@ members = [
# This is useful for profiling and, to some extent, debugging.
# Besides, debug info should not affect performance.
debug = true
# This is only needed for proxy's tests
# TODO: we should probably fork tokio-postgres-rustls instead
[patch.crates-io]
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }

View File

@@ -1,14 +1,17 @@
[package]
name = "compute_tools"
version = "0.1.0"
authors = ["Alexey Kondratov <kondratov.aleksey@gmail.com>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
libc = "0.2"
anyhow = "1.0"
chrono = "0.4"
clap = "3.0"
env_logger = "0.9"
clap = "2.33"
env_logger = "0.8"
hyper = { version = "0.14", features = ["full"] }
log = { version = "0.4", features = ["std", "serde"] }
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }

View File

@@ -34,7 +34,6 @@ use std::sync::{Arc, RwLock};
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Arg;
use log::info;
use postgres::{Client, NoTls};
@@ -163,34 +162,34 @@ fn main() -> Result<()> {
let matches = clap::App::new("zenith_ctl")
.version(version.unwrap_or("unknown"))
.arg(
Arg::new("connstr")
.short('C')
clap::Arg::with_name("connstr")
.short("C")
.long("connstr")
.value_name("DATABASE_URL")
.required(true),
)
.arg(
Arg::new("pgdata")
.short('D')
clap::Arg::with_name("pgdata")
.short("D")
.long("pgdata")
.value_name("DATADIR")
.required(true),
)
.arg(
Arg::new("pgbin")
.short('b')
clap::Arg::with_name("pgbin")
.short("b")
.long("pgbin")
.value_name("POSTGRES_PATH"),
)
.arg(
Arg::new("spec")
.short('s')
clap::Arg::with_name("spec")
.short("s")
.long("spec")
.value_name("SPEC_JSON"),
)
.arg(
Arg::new("spec-path")
.short('S')
clap::Arg::with_name("spec-path")
.short("S")
.long("spec-path")
.value_name("SPEC_PATH"),
)
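This comparison flips the same clap builder API back and forth in several binaries (zenith_ctl here, pageserver and proxy below): `Arg::new`/`short('C')` on the clap 3 side versus `Arg::with_name`/`short("C")` on the 2.33 side. A minimal runnable sketch of the 2.33 style, with the clap 3 equivalents noted inline; the binary name and sample value are illustrative, not from the diff:

use clap::{App, Arg};

fn main() {
    let matches = App::new("example")
        .arg(
            Arg::with_name("connstr") // clap 3: Arg::new("connstr")
                .short("C") // clap 3: .short('C') takes a char, not a &str
                .long("connstr")
                .takes_value(true)
                .required(true),
        )
        .arg(
            Arg::with_name("config-override")
                .short("c")
                .takes_value(true)
                .multiple(true), // clap 3: .multiple_occurrences(true)
        )
        .get_matches_from(vec!["example", "-C", "postgres://localhost"]);
    println!("connstr = {:?}", matches.value_of("connstr"));
}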

View File

@@ -1,8 +1,11 @@
[package]
name = "control_plane"
version = "0.1.0"
authors = ["Stas Kelvich <stas@zenith.tech>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tar = "0.4.33"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }

View File

@@ -1,7 +1,7 @@
# Page server and three safekeepers.
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
listen_pg_addr = 'localhost:64000'
listen_http_addr = 'localhost:9898'
auth_type = 'Trust'
[[safekeepers]]

View File

@@ -1,8 +1,8 @@
# Minimal zenith environment with one safekeeper. This is equivalent to the built-in
# defaults that you get with no --config
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
listen_pg_addr = 'localhost:64000'
listen_http_addr = 'localhost:9898'
auth_type = 'Trust'
[[safekeepers]]

View File

@@ -85,7 +85,7 @@ impl SafekeeperNode {
pg_connection_config: Self::safekeeper_connection_config(conf.pg_port),
env: env.clone(),
http_client: Client::new(),
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
http_base_url: format!("http://localhost:{}/v1", conf.http_port),
pageserver,
}
}
@@ -93,7 +93,7 @@ impl SafekeeperNode {
/// Construct libpq connection string for connecting to this safekeeper.
fn safekeeper_connection_config(port: u16) -> Config {
// TODO safekeeper authentication not implemented yet
format!("postgresql://no_user@127.0.0.1:{}/no_db", port)
format!("postgresql://no_user@localhost:{}/no_db", port)
.parse()
.unwrap()
}
@@ -114,8 +114,8 @@ impl SafekeeperNode {
);
io::stdout().flush().unwrap();
let listen_pg = format!("127.0.0.1:{}", self.conf.pg_port);
let listen_http = format!("127.0.0.1:{}", self.conf.http_port);
let listen_pg = format!("localhost:{}", self.conf.pg_port);
let listen_http = format!("localhost:{}", self.conf.http_port);
let mut cmd = Command::new(self.env.safekeeper_bin()?);
fill_rust_env_vars(

View File

@@ -1,6 +1,7 @@
[package]
name = "pageserver"
version = "0.1.0"
authors = ["Stas Kelvich <stas@zenith.tech>"]
edition = "2021"
[dependencies]
@@ -14,7 +15,7 @@ futures = "0.3.13"
hyper = "0.14"
lazy_static = "1.4.0"
log = "0.4.14"
clap = "3.0"
clap = "2.33.0"
daemonize = "0.4.1"
tokio = { version = "1.11", features = ["process", "sync", "macros", "fs", "rt", "io-util", "time"] }
postgres-types = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
@@ -22,6 +23,7 @@ postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-stream = "0.1.8"
routerify = "2"
anyhow = { version = "1.0", features = ["backtrace"] }
crc32c = "0.6.0"
thiserror = "1.0"
@@ -30,7 +32,7 @@ tar = "0.4.33"
humantime = "2.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
toml_edit = { version = "0.13", features = ["easy"] }
toml_edit = { version = "0.12", features = ["easy"] }
scopeguard = "1.1.0"
async-trait = "0.1"
const_format = "0.2.21"
@@ -40,6 +42,7 @@ signal-hook = "0.3.10"
url = "2"
nix = "0.23"
once_cell = "1.8.0"
parking_lot = "0.11.2"
crossbeam-utils = "0.8.5"
rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }

View File

@@ -13,7 +13,7 @@ fn main() -> Result<()> {
.about("Dump contents of one layer file, for debugging")
.version(GIT_VERSION)
.arg(
Arg::new("path")
Arg::with_name("path")
.help("Path to file to dump")
.required(true)
.index(1),

View File

@@ -27,27 +27,27 @@ fn main() -> Result<()> {
.about("Materializes WAL stream to pages and serves them to the postgres")
.version(GIT_VERSION)
.arg(
Arg::new("daemonize")
.short('d')
Arg::with_name("daemonize")
.short("d")
.long("daemonize")
.takes_value(false)
.help("Run in the background"),
)
.arg(
Arg::new("init")
Arg::with_name("init")
.long("init")
.takes_value(false)
.help("Initialize pageserver repo"),
)
.arg(
Arg::new("workdir")
.short('D')
Arg::with_name("workdir")
.short("D")
.long("workdir")
.takes_value(true)
.help("Working directory for the pageserver"),
)
.arg(
Arg::new("create-tenant")
Arg::with_name("create-tenant")
.long("create-tenant")
.takes_value(true)
.help("Create tenant during init")
@@ -55,11 +55,11 @@ fn main() -> Result<()> {
)
// See `settings.md` for more details on the extra configuration parameters pageserver can process
.arg(
Arg::new("config-override")
.short('c')
Arg::with_name("config-override")
.short("c")
.takes_value(true)
.number_of_values(1)
.multiple_occurrences(true)
.multiple(true)
.help("Additional configuration overrides of the ones from the toml config file (or new ones to add there).
Any option has to be a valid toml document, example: `-c \"foo='hey'\"` `-c \"foo={value=1}\"`"),
)

View File

@@ -1,334 +0,0 @@
//! A CLI helper to deal with remote storage (S3, usually) blobs as archives.
//! See [`compression`] for more details about the archives.
use std::{collections::BTreeSet, path::Path};
use anyhow::{bail, ensure, Context};
use clap::{App, Arg};
use pageserver::{
layered_repository::metadata::{TimelineMetadata, METADATA_FILE_NAME},
remote_storage::compression,
};
use tokio::{fs, io};
use zenith_utils::GIT_VERSION;
const LIST_SUBCOMMAND: &str = "list";
const ARCHIVE_ARG_NAME: &str = "archive";
const EXTRACT_SUBCOMMAND: &str = "extract";
const TARGET_DIRECTORY_ARG_NAME: &str = "target_directory";
const CREATE_SUBCOMMAND: &str = "create";
const SOURCE_DIRECTORY_ARG_NAME: &str = "source_directory";
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
let arg_matches = App::new("pageserver zst blob [un]compressor utility")
.version(GIT_VERSION)
.subcommands(vec![
App::new(LIST_SUBCOMMAND)
.about("List the archive contents")
.arg(
Arg::new(ARCHIVE_ARG_NAME)
.required(true)
.takes_value(true)
.help("An archive to list the contents of"),
),
App::new(EXTRACT_SUBCOMMAND)
.about("Extracts the archive into the directory")
.arg(
Arg::new(ARCHIVE_ARG_NAME)
.required(true)
.takes_value(true)
.help("An archive to extract"),
)
.arg(
Arg::new(TARGET_DIRECTORY_ARG_NAME)
.required(false)
.takes_value(true)
.help("A directory to extract the archive into. Optional, will use the current directory if not specified"),
),
App::new(CREATE_SUBCOMMAND)
.about("Creates an archive with the contents of a directory (only the first level files are taken, metadata file has to be present in the same directory)")
.arg(
Arg::new(SOURCE_DIRECTORY_ARG_NAME)
.required(true)
.takes_value(true)
.help("A directory to use for creating the archive"),
)
.arg(
Arg::new(TARGET_DIRECTORY_ARG_NAME)
.required(false)
.takes_value(true)
.help("A directory to create the archive in. Optional, will use the current directory if not specified"),
),
])
.get_matches();
let subcommand_name = match arg_matches.subcommand_name() {
Some(name) => name,
None => bail!("No subcommand specified"),
};
let subcommand_matches = match arg_matches.subcommand_matches(subcommand_name) {
Some(matches) => matches,
None => bail!(
"No subcommand arguments were recognized for subcommand '{}'",
subcommand_name
),
};
let target_dir = Path::new(
subcommand_matches
.value_of(TARGET_DIRECTORY_ARG_NAME)
.unwrap_or("./"),
);
match subcommand_name {
LIST_SUBCOMMAND => {
let archive = match subcommand_matches.value_of(ARCHIVE_ARG_NAME) {
Some(archive) => Path::new(archive),
None => bail!("No '{}' argument is specified", ARCHIVE_ARG_NAME),
};
list_archive(archive).await
}
EXTRACT_SUBCOMMAND => {
let archive = match subcommand_matches.value_of(ARCHIVE_ARG_NAME) {
Some(archive) => Path::new(archive),
None => bail!("No '{}' argument is specified", ARCHIVE_ARG_NAME),
};
extract_archive(archive, target_dir).await
}
CREATE_SUBCOMMAND => {
let source_dir = match subcommand_matches.value_of(SOURCE_DIRECTORY_ARG_NAME) {
Some(source) => Path::new(source),
None => bail!("No '{}' argument is specified", SOURCE_DIRECTORY_ARG_NAME),
};
create_archive(source_dir, target_dir).await
}
unknown => bail!("Unknown subcommand {}", unknown),
}
}
async fn list_archive(archive: &Path) -> anyhow::Result<()> {
let archive = archive.canonicalize().with_context(|| {
format!(
"Failed to get the absolute path for the archive path '{}'",
archive.display()
)
})?;
ensure!(
archive.is_file(),
"Path '{}' is not an archive file",
archive.display()
);
println!("Listing an archive at path '{}'", archive.display());
let archive_name = match archive.file_name().and_then(|name| name.to_str()) {
Some(name) => name,
None => bail!(
"Failed to get the archive name from the path '{}'",
archive.display()
),
};
let archive_bytes = fs::read(&archive)
.await
.context("Failed to read the archive bytes")?;
let header = compression::read_archive_header(archive_name, &mut archive_bytes.as_slice())
.await
.context("Failed to read the archive header")?;
let empty_path = Path::new("");
println!("-------------------------------");
let longest_path_in_archive = header
.files
.iter()
.filter_map(|file| Some(file.subpath.as_path(empty_path).to_str()?.len()))
.max()
.unwrap_or_default()
.max(METADATA_FILE_NAME.len());
for regular_file in &header.files {
println!(
"File: {:width$} uncompressed size: {} bytes",
regular_file.subpath.as_path(empty_path).display(),
regular_file.size,
width = longest_path_in_archive,
)
}
println!(
"File: {:width$} uncompressed size: {} bytes",
METADATA_FILE_NAME,
header.metadata_file_size,
width = longest_path_in_archive,
);
println!("-------------------------------");
Ok(())
}
async fn extract_archive(archive: &Path, target_dir: &Path) -> anyhow::Result<()> {
let archive = archive.canonicalize().with_context(|| {
format!(
"Failed to get the absolute path for the archive path '{}'",
archive.display()
)
})?;
ensure!(
archive.is_file(),
"Path '{}' is not an archive file",
archive.display()
);
let archive_name = match archive.file_name().and_then(|name| name.to_str()) {
Some(name) => name,
None => bail!(
"Failed to get the archive name from the path '{}'",
archive.display()
),
};
if !target_dir.exists() {
fs::create_dir_all(target_dir).await.with_context(|| {
format!(
"Failed to create the target dir at path '{}'",
target_dir.display()
)
})?;
}
let target_dir = target_dir.canonicalize().with_context(|| {
format!(
"Failed to get the absolute path for the target dir path '{}'",
target_dir.display()
)
})?;
ensure!(
target_dir.is_dir(),
"Path '{}' is not a directory",
target_dir.display()
);
let mut dir_contents = fs::read_dir(&target_dir)
.await
.context("Failed to list the target directory contents")?;
let dir_entry = dir_contents
.next_entry()
.await
.context("Failed to list the target directory contents")?;
ensure!(
dir_entry.is_none(),
"Target directory '{}' is not empty",
target_dir.display()
);
println!(
"Extracting an archive at path '{}' into directory '{}'",
archive.display(),
target_dir.display()
);
let mut archive_file = fs::File::open(&archive).await.with_context(|| {
format!(
"Failed to get the archive name from the path '{}'",
archive.display()
)
})?;
let header = compression::read_archive_header(archive_name, &mut archive_file)
.await
.context("Failed to read the archive header")?;
compression::uncompress_with_header(&BTreeSet::new(), &target_dir, header, &mut archive_file)
.await
.context("Failed to extract the archive")
}
async fn create_archive(source_dir: &Path, target_dir: &Path) -> anyhow::Result<()> {
let source_dir = source_dir.canonicalize().with_context(|| {
format!(
"Failed to get the absolute path for the source dir path '{}'",
source_dir.display()
)
})?;
ensure!(
source_dir.is_dir(),
"Path '{}' is not a directory",
source_dir.display()
);
if !target_dir.exists() {
fs::create_dir_all(target_dir).await.with_context(|| {
format!(
"Failed to create the target dir at path '{}'",
target_dir.display()
)
})?;
}
let target_dir = target_dir.canonicalize().with_context(|| {
format!(
"Failed to get the absolute path for the target dir path '{}'",
target_dir.display()
)
})?;
ensure!(
target_dir.is_dir(),
"Path '{}' is not a directory",
target_dir.display()
);
println!(
"Compressing directory '{}' and creating resulting archive in directory '{}'",
source_dir.display(),
target_dir.display()
);
let mut metadata_file_contents = None;
let mut files_co_archive = Vec::new();
let mut source_dir_contents = fs::read_dir(&source_dir)
.await
.context("Failed to read the source directory contents")?;
while let Some(source_dir_entry) = source_dir_contents
.next_entry()
.await
.context("Failed to read a source dir entry")?
{
let entry_path = source_dir_entry.path();
if entry_path.is_file() {
if entry_path.file_name().and_then(|name| name.to_str()) == Some(METADATA_FILE_NAME) {
let metadata_bytes = fs::read(entry_path)
.await
.context("Failed to read metata file bytes in the source dir")?;
metadata_file_contents = Some(
TimelineMetadata::from_bytes(&metadata_bytes)
.context("Failed to parse metata file contents in the source dir")?,
);
} else {
files_co_archive.push(entry_path);
}
}
}
let metadata = match metadata_file_contents {
Some(metadata) => metadata,
None => bail!(
"No metadata file found in the source dir '{}', cannot create the archive",
source_dir.display()
),
};
let _ = compression::archive_files_as_stream(
&source_dir,
files_co_archive.iter(),
&metadata,
move |mut archive_streamer, archive_name| async move {
let archive_target = target_dir.join(&archive_name);
let mut archive_file = fs::File::create(&archive_target).await?;
io::copy(&mut archive_streamer, &mut archive_file).await?;
Ok(archive_target)
},
)
.await
.context("Failed to create an archive")?;
Ok(())
}

View File

@@ -14,20 +14,20 @@ fn main() -> Result<()> {
.about("Dump or update metadata file")
.version(GIT_VERSION)
.arg(
Arg::new("path")
Arg::with_name("path")
.help("Path to metadata file")
.required(true),
)
.arg(
Arg::new("disk_lsn")
.short('d')
Arg::with_name("disk_lsn")
.short("d")
.long("disk_lsn")
.takes_value(true)
.help("Replace disk constistent lsn"),
)
.arg(
Arg::new("prev_lsn")
.short('p')
Arg::with_name("prev_lsn")
.short("p")
.long("prev_lsn")
.takes_value(true)
.help("Previous record LSN"),

View File

@@ -4,6 +4,7 @@ use anyhow::{Context, Result};
use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use routerify::{ext::RequestExt, RouterBuilder};
use serde::Serialize;
use tracing::*;
use zenith_utils::auth::JwtAuth;
@@ -18,7 +19,6 @@ use zenith_utils::http::{
request::get_request_param,
request::parse_request_param,
};
use zenith_utils::http::{RequestExt, RouterBuilder};
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::{opt_display_serde, ZTimelineId};

View File

@@ -175,10 +175,7 @@ impl Write for EphemeralFile {
}
fn flush(&mut self) -> Result<(), std::io::Error> {
// we don't need to flush data:
// * we either write input bytes or not, not keeping any intermediate data buffered
// * rust unix file `flush` impl does not flush things either, returning `Ok(())`
Ok(())
todo!()
}
}
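The comment on the `Ok(())` side can be sanity-checked directly: `std::fs::File` does no userspace buffering, so its `Write::flush` is a documented no-op. A small sketch (ours, not from the repo; the path is made up):

use std::fs::File;
use std::io::Write;

fn main() -> std::io::Result<()> {
    let mut f = File::create("/tmp/ephemeral_flush_demo")?;
    // write_all hands the bytes to the OS immediately; nothing is buffered.
    f.write_all(b"bytes reach the kernel on write_all already")?;
    f.flush()?; // no-op, returns Ok(()); durability would need f.sync_all()
    Ok(())
}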

View File

@@ -94,7 +94,7 @@ use std::{
use anyhow::{bail, Context};
use tokio::io;
use tracing::{error, info};
use zenith_utils::zid::{ZTenantId, ZTenantTimelineId, ZTimelineId};
use zenith_utils::zid::{ZTenantId, ZTimelineId};
pub use self::storage_sync::{schedule_timeline_checkpoint_upload, schedule_timeline_download};
use self::{local_fs::LocalFs, rust_s3::S3};
@@ -104,7 +104,16 @@ use crate::{
repository::TimelineSyncState,
};
pub use storage_sync::compression;
/// Every timeline has its own id and belongs to a tenant;
/// the sync processes group timelines by both for simplicity.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TimelineSyncId(ZTenantId, ZTimelineId);
impl std::fmt::Display for TimelineSyncId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "(tenant: {}, timeline: {})", self.0, self.1)
}
}
/// A structure combining all synchronization data to share with the pageserver after a successful sync loop initialization.
/// Successful initialization includes the case when the sync loop is not started, in which case the startup data is still returned,
@@ -158,7 +167,7 @@ pub fn start_local_timeline_sync(
ZTenantId,
HashMap<ZTimelineId, TimelineSyncState>,
> = HashMap::new();
for (ZTenantTimelineId{tenant_id, timeline_id}, (timeline_metadata, _)) in
for (TimelineSyncId(tenant_id, timeline_id), (timeline_metadata, _)) in
local_timeline_files
{
initial_timeline_states
@@ -178,7 +187,7 @@ pub fn start_local_timeline_sync(
fn local_tenant_timeline_files(
config: &'static PageServerConf,
) -> anyhow::Result<HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>> {
) -> anyhow::Result<HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>> {
let mut local_tenant_timeline_files = HashMap::new();
let tenants_dir = config.tenants_path();
for tenants_dir_entry in fs::read_dir(&tenants_dir)
@@ -213,9 +222,8 @@ fn local_tenant_timeline_files(
fn collect_timelines_for_tenant(
config: &'static PageServerConf,
tenant_path: &Path,
) -> anyhow::Result<HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>> {
let mut timelines: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)> =
HashMap::new();
) -> anyhow::Result<HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>> {
let mut timelines: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)> = HashMap::new();
let tenant_id = tenant_path
.file_name()
.and_then(ffi::OsStr::to_str)
@@ -236,10 +244,7 @@ fn collect_timelines_for_tenant(
match collect_timeline_files(&timeline_path) {
Ok((timeline_id, metadata, timeline_files)) => {
timelines.insert(
ZTenantTimelineId {
tenant_id,
timeline_id,
},
TimelineSyncId(tenant_id, timeline_id),
(metadata, timeline_files),
);
}

View File

@@ -70,8 +70,7 @@
//!
//! When the pageserver signals shutdown, the current sync task gets finished and the loop exits.
/// Expose the module for a binary CLI tool that deals with the corresponding blobs.
pub mod compression;
mod compression;
mod download;
pub mod index;
mod upload;
@@ -106,7 +105,7 @@ use self::{
},
upload::upload_timeline_checkpoint,
};
use super::{RemoteStorage, SyncStartupData, ZTenantTimelineId};
use super::{RemoteStorage, SyncStartupData, TimelineSyncId};
use crate::{
config::PageServerConf, layered_repository::metadata::TimelineMetadata,
remote_storage::storage_sync::compression::read_archive_header, repository::TimelineSyncState,
@@ -243,13 +242,13 @@ mod sync_queue {
/// Limited by the number of retries; after a certain threshold the failing task gets evicted and the timeline disabled.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct SyncTask {
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
retries: u32,
kind: SyncKind,
}
impl SyncTask {
fn new(sync_id: ZTenantTimelineId, retries: u32, kind: SyncKind) -> Self {
fn new(sync_id: TimelineSyncId, retries: u32, kind: SyncKind) -> Self {
Self {
sync_id,
retries,
@@ -308,10 +307,7 @@ pub fn schedule_timeline_checkpoint_upload(
}
if !sync_queue::push(SyncTask::new(
ZTenantTimelineId {
tenant_id,
timeline_id,
},
TimelineSyncId(tenant_id, timeline_id),
0,
SyncKind::Upload(NewCheckpoint { layers, metadata }),
)) {
@@ -342,10 +338,7 @@ pub fn schedule_timeline_download(tenant_id: ZTenantId, timeline_id: ZTimelineId
tenant_id, timeline_id
);
sync_queue::push(SyncTask::new(
ZTenantTimelineId {
tenant_id,
timeline_id,
},
TimelineSyncId(tenant_id, timeline_id),
0,
SyncKind::Download(TimelineDownload {
files_to_skip: Arc::new(BTreeSet::new()),
@@ -361,7 +354,7 @@ pub(super) fn spawn_storage_sync_thread<
S: RemoteStorage<StoragePath = P> + Send + Sync + 'static,
>(
conf: &'static PageServerConf,
local_timeline_files: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>,
local_timeline_files: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>,
storage: S,
max_concurrent_sync: NonZeroUsize,
max_sync_errors: NonZeroU32,
@@ -517,7 +510,7 @@ async fn loop_step<
Err(e) => {
error!(
"Failed to process storage sync task for tenant {}, timeline {}: {:?}",
sync_id.tenant_id, sync_id.timeline_id, e
sync_id.0, sync_id.1, e
);
None
}
@@ -531,10 +524,7 @@ async fn loop_step<
while let Some((sync_id, state_update)) = task_batch.next().await {
debug!("Finished storage sync task for sync id {}", sync_id);
if let Some(state_update) = state_update {
let ZTenantTimelineId {
tenant_id,
timeline_id,
} = sync_id;
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
new_timeline_states
.entry(tenant_id)
.or_default()
@@ -628,7 +618,7 @@ async fn process_task<
fn schedule_first_sync_tasks(
index: &RemoteTimelineIndex,
local_timeline_files: HashMap<ZTenantTimelineId, (TimelineMetadata, Vec<PathBuf>)>,
local_timeline_files: HashMap<TimelineSyncId, (TimelineMetadata, Vec<PathBuf>)>,
) -> HashMap<ZTenantId, HashMap<ZTimelineId, TimelineSyncState>> {
let mut initial_timeline_statuses: HashMap<ZTenantId, HashMap<ZTimelineId, TimelineSyncState>> =
HashMap::new();
@@ -639,10 +629,7 @@ fn schedule_first_sync_tasks(
for (sync_id, (local_metadata, local_files)) in local_timeline_files {
let local_disk_consistent_lsn = local_metadata.disk_consistent_lsn();
let ZTenantTimelineId {
tenant_id,
timeline_id,
} = sync_id;
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
match index.timeline_entry(&sync_id) {
Some(index_entry) => {
let timeline_status = compare_local_and_remote_timeline(
@@ -685,10 +672,10 @@ fn schedule_first_sync_tasks(
}
}
let unprocessed_remote_ids = |remote_id: &ZTenantTimelineId| {
let unprocessed_remote_ids = |remote_id: &TimelineSyncId| {
initial_timeline_statuses
.get(&remote_id.tenant_id)
.and_then(|timelines| timelines.get(&remote_id.timeline_id))
.get(&remote_id.0)
.and_then(|timelines| timelines.get(&remote_id.1))
.is_none()
};
for unprocessed_remote_id in index
@@ -696,10 +683,7 @@ fn schedule_first_sync_tasks(
.filter(unprocessed_remote_ids)
.collect::<Vec<_>>()
{
let ZTenantTimelineId {
tenant_id: cloud_only_tenant_id,
timeline_id: cloud_only_timeline_id,
} = unprocessed_remote_id;
let TimelineSyncId(cloud_only_tenant_id, cloud_only_timeline_id) = unprocessed_remote_id;
match index
.timeline_entry(&unprocessed_remote_id)
.and_then(TimelineIndexEntry::disk_consistent_lsn)
@@ -728,7 +712,7 @@ fn schedule_first_sync_tasks(
fn compare_local_and_remote_timeline(
new_sync_tasks: &mut VecDeque<SyncTask>,
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
local_metadata: TimelineMetadata,
local_files: Vec<PathBuf>,
remote_entry: &TimelineIndexEntry,
@@ -785,7 +769,7 @@ async fn update_index_description<
>(
(storage, index): &(S, RwLock<RemoteTimelineIndex>),
timeline_dir: &Path,
id: ZTenantTimelineId,
id: TimelineSyncId,
) -> anyhow::Result<RemoteTimeline> {
let mut index_write = index.write().await;
let full_index = match index_write.timeline_entry(&id) {
@@ -808,7 +792,7 @@ async fn update_index_description<
Ok((archive_id, header_size, header)) => full_index.update_archive_contents(archive_id.0, header, header_size),
Err((e, archive_id)) => bail!(
"Failed to download archive header for tenant {}, timeline {}, archive for Lsn {}: {}",
id.tenant_id, id.timeline_id, archive_id.0,
id.0, id.1, archive_id.0,
e
),
}
@@ -886,7 +870,7 @@ mod test_utils {
timeline_id: ZTimelineId,
new_upload: NewCheckpoint,
) {
let sync_id = ZTenantTimelineId::new(harness.tenant_id, timeline_id);
let sync_id = TimelineSyncId(harness.tenant_id, timeline_id);
upload_timeline_checkpoint(
harness.conf,
Arc::clone(&remote_assets),
@@ -942,7 +926,7 @@ mod test_utils {
pub async fn expect_timeline(
index: &RwLock<RemoteTimelineIndex>,
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
) -> RemoteTimeline {
if let Some(TimelineIndexEntry::Full(remote_timeline)) =
index.read().await.timeline_entry(&sync_id)
@@ -977,18 +961,18 @@ mod test_utils {
let mut expected_timeline_entries = BTreeMap::new();
for sync_id in actual_sync_ids {
actual_branches.insert(
sync_id.tenant_id,
sync_id.1,
index_read
.branch_files(sync_id.tenant_id)
.branch_files(sync_id.0)
.into_iter()
.flat_map(|branch_paths| branch_paths.iter())
.cloned()
.collect::<BTreeSet<_>>(),
);
expected_branches.insert(
sync_id.tenant_id,
sync_id.1,
expected_index_with_descriptions
.branch_files(sync_id.tenant_id)
.branch_files(sync_id.0)
.into_iter()
.flat_map(|branch_paths| branch_paths.iter())
.cloned()

View File

@@ -248,7 +248,7 @@ fn archive_name(disk_consistent_lsn: Lsn, header_size: u64) -> String {
archive_name
}
pub async fn uncompress_with_header(
async fn uncompress_with_header(
files_to_skip: &BTreeSet<PathBuf>,
destination_dir: &Path,
header: ArchiveHeader,

View File

@@ -17,7 +17,7 @@ use crate::{
compression, index::TimelineIndexEntry, sync_queue, tenant_branch_files,
update_index_description, SyncKind, SyncTask,
},
RemoteStorage, ZTenantTimelineId,
RemoteStorage, TimelineSyncId,
},
};
@@ -52,16 +52,13 @@ pub(super) async fn download_timeline<
>(
conf: &'static PageServerConf,
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
mut download: TimelineDownload,
retries: u32,
) -> DownloadedTimeline {
debug!("Downloading layers for sync id {}", sync_id);
let ZTenantTimelineId {
tenant_id,
timeline_id,
} = sync_id;
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
let index_read = remote_assets.1.read().await;
let remote_timeline = match index_read.timeline_entry(&sync_id) {
None => {
@@ -113,8 +110,7 @@ pub(super) async fn download_timeline<
}
};
if let Err(e) = download_missing_branches(conf, remote_assets.as_ref(), sync_id.tenant_id).await
{
if let Err(e) = download_missing_branches(conf, remote_assets.as_ref(), sync_id.0).await {
error!(
"Failed to download missing branches for sync id {}: {:?}",
sync_id, e
@@ -184,10 +180,7 @@ async fn try_download_archive<
S: RemoteStorage<StoragePath = P> + Send + Sync + 'static,
>(
conf: &'static PageServerConf,
ZTenantTimelineId {
tenant_id,
timeline_id,
}: ZTenantTimelineId,
TimelineSyncId(tenant_id, timeline_id): TimelineSyncId,
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
remote_timeline: &RemoteTimeline,
archive_id: ArchiveId,
@@ -350,7 +343,7 @@ mod tests {
#[tokio::test]
async fn test_download_timeline() -> anyhow::Result<()> {
let repo_harness = RepoHarness::create("test_download_timeline")?;
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
repo_harness.conf,

View File

@@ -22,7 +22,7 @@ use crate::{
layered_repository::TIMELINES_SEGMENT_NAME,
remote_storage::{
storage_sync::compression::{parse_archive_name, FileEntry},
ZTenantTimelineId,
TimelineSyncId,
},
};
@@ -53,7 +53,7 @@ impl RelativePath {
#[derive(Debug, Clone)]
pub struct RemoteTimelineIndex {
branch_files: HashMap<ZTenantId, HashSet<RelativePath>>,
timeline_files: HashMap<ZTenantTimelineId, TimelineIndexEntry>,
timeline_files: HashMap<TimelineSyncId, TimelineIndexEntry>,
}
impl RemoteTimelineIndex {
@@ -80,22 +80,19 @@ impl RemoteTimelineIndex {
index
}
pub fn timeline_entry(&self, id: &ZTenantTimelineId) -> Option<&TimelineIndexEntry> {
pub fn timeline_entry(&self, id: &TimelineSyncId) -> Option<&TimelineIndexEntry> {
self.timeline_files.get(id)
}
pub fn timeline_entry_mut(
&mut self,
id: &ZTenantTimelineId,
) -> Option<&mut TimelineIndexEntry> {
pub fn timeline_entry_mut(&mut self, id: &TimelineSyncId) -> Option<&mut TimelineIndexEntry> {
self.timeline_files.get_mut(id)
}
pub fn add_timeline_entry(&mut self, id: ZTenantTimelineId, entry: TimelineIndexEntry) {
pub fn add_timeline_entry(&mut self, id: TimelineSyncId, entry: TimelineIndexEntry) {
self.timeline_files.insert(id, entry);
}
pub fn all_sync_ids(&self) -> impl Iterator<Item = ZTenantTimelineId> + '_ {
pub fn all_sync_ids(&self) -> impl Iterator<Item = TimelineSyncId> + '_ {
self.timeline_files.keys().copied()
}
@@ -351,10 +348,7 @@ fn try_parse_index_entry(
.to_string_lossy()
.to_string();
let sync_id = ZTenantTimelineId {
tenant_id,
timeline_id,
};
let sync_id = TimelineSyncId(tenant_id, timeline_id);
let timeline_index_entry = index
.timeline_files
.entry(sync_id)

View File

@@ -17,7 +17,7 @@ use crate::{
index::{RemoteTimeline, TimelineIndexEntry},
sync_queue, tenant_branch_files, update_index_description, SyncKind, SyncTask,
},
RemoteStorage, ZTenantTimelineId,
RemoteStorage, TimelineSyncId,
},
};
@@ -36,13 +36,12 @@ pub(super) async fn upload_timeline_checkpoint<
>(
config: &'static PageServerConf,
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
new_checkpoint: NewCheckpoint,
retries: u32,
) -> Option<bool> {
debug!("Uploading checkpoint for sync id {}", sync_id);
if let Err(e) = upload_missing_branches(config, remote_assets.as_ref(), sync_id.tenant_id).await
{
if let Err(e) = upload_missing_branches(config, remote_assets.as_ref(), sync_id.0).await {
error!(
"Failed to upload missing branches for sync id {}: {:?}",
sync_id, e
@@ -58,10 +57,7 @@ pub(super) async fn upload_timeline_checkpoint<
let index = &remote_assets.1;
let ZTenantTimelineId {
tenant_id,
timeline_id,
} = sync_id;
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
let timeline_dir = config.timeline_path(&timeline_id, &tenant_id);
let index_read = index.read().await;
@@ -155,14 +151,11 @@ async fn try_upload_checkpoint<
>(
config: &'static PageServerConf,
remote_assets: Arc<(S, RwLock<RemoteTimelineIndex>)>,
sync_id: ZTenantTimelineId,
sync_id: TimelineSyncId,
new_checkpoint: &NewCheckpoint,
files_to_skip: BTreeSet<PathBuf>,
) -> anyhow::Result<(ArchiveHeader, u64)> {
let ZTenantTimelineId {
tenant_id,
timeline_id,
} = sync_id;
let TimelineSyncId(tenant_id, timeline_id) = sync_id;
let timeline_dir = config.timeline_path(&timeline_id, &tenant_id);
let files_to_upload = new_checkpoint
@@ -295,7 +288,7 @@ mod tests {
#[tokio::test]
async fn reupload_timeline() -> anyhow::Result<()> {
let repo_harness = RepoHarness::create("reupload_timeline")?;
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
repo_harness.conf,
@@ -491,7 +484,7 @@ mod tests {
#[tokio::test]
async fn reupload_timeline_rejected() -> anyhow::Result<()> {
let repo_harness = RepoHarness::create("reupload_timeline_rejected")?;
let sync_id = ZTenantTimelineId::new(repo_harness.tenant_id, TIMELINE_ID);
let sync_id = TimelineSyncId(repo_harness.tenant_id, TIMELINE_ID);
let storage = LocalFs::new(tempdir()?.path().to_owned(), &repo_harness.conf.workdir)?;
let index = RwLock::new(RemoteTimelineIndex::try_parse_descriptions_from_paths(
repo_harness.conf,

View File

@@ -306,12 +306,8 @@ pub enum ZenithWalRecord {
/// Native PostgreSQL WAL record
Postgres { will_init: bool, rec: Bytes },
/// Clear bits in heap visibility map. ('flags' is bitmap of bits to clear)
ClearVisibilityMapFlags {
new_heap_blkno: Option<u32>,
old_heap_blkno: Option<u32>,
flags: u8,
},
/// Clear bits in heap visibility map. (heap blkno, flag bits to clear)
ClearVisibilityMapFlags { heap_blkno: u32, flags: u8 },
/// Mark transaction IDs as committed on a CLOG page
ClogSetCommitted { xids: Vec<TransactionId> },
/// Mark transaction IDs as aborted on a CLOG page

View File

@@ -332,11 +332,8 @@ impl VirtualFile {
// TODO: We could downgrade the locks to read mode before calling
// 'func', to allow a little bit more concurrency, but the standard
// library RwLock doesn't allow downgrading without releasing the lock,
// and that doesn't seem worth the trouble.
//
// XXX: `parking_lot::RwLock` can enable such downgrades, yet its implementation is fair and
// may deadlock on subsequent read calls.
// Simply replacing every `RwLock` in the project causes deadlocks, so use it sparingly.
// and that doesn't seem worth the trouble. (parking_lot RwLock would
// allow it)
let result = STORAGE_IO_TIME
.with_label_values(&[op, &self.tenantid, &self.timelineid])
.observe_closure_duration(|| func(&file));
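One side of this hunk warns that parking_lot's fair RwLock may deadlock on recursive reads; that is a real caveat from the crate's documentation. A minimal sketch of the scenario (assumed setup, not code from this file):

use parking_lot::RwLock;

fn recursive_read(lock: &RwLock<u32>) -> u32 {
    let first = lock.read();
    // If another thread calls lock.write() right here, the fair policy makes
    // that writer next in line; the second read() below then waits behind it,
    // while the writer waits for `first` to drop: a deadlock.
    let second = lock.read();
    *first + *second
}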

View File

@@ -37,7 +37,6 @@ use postgres_ffi::xlog_utils::*;
use postgres_ffi::TransactionId;
use postgres_ffi::{pg_constants, CheckPoint};
use zenith_utils::lsn::Lsn;
use zenith_utils::pg_checksum_page::pg_checksum_page;
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; 8192]);
@@ -330,9 +329,6 @@ impl WalIngest {
}
image[0..4].copy_from_slice(&((lsn.0 >> 32) as u32).to_le_bytes());
image[4..8].copy_from_slice(&(lsn.0 as u32).to_le_bytes());
image[8..10].copy_from_slice(&[0u8; 2]);
let checksum = pg_checksum_page(&image, blk.blkno);
image[8..10].copy_from_slice(&checksum.to_le_bytes());
assert_eq!(image.len(), pg_constants::BLCKSZ as usize);
timeline.put_page_image(tag, blk.blkno, lsn, image.freeze())?;
} else {
@@ -353,25 +349,49 @@ impl WalIngest {
decoded: &mut DecodedWALRecord,
) -> Result<()> {
// Handle VM bit updates that are implicitly part of heap records.
// First, look at the record to determine which VM bits need
// to be cleared. If either of these variables is set, we
// need to clear the corresponding bits in the visibility map.
let mut new_heap_blkno: Option<u32> = None;
let mut old_heap_blkno: Option<u32> = None;
if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
if info == pg_constants::XLOG_HEAP_INSERT {
let xlrec = XlHeapInsert::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
if (xlrec.flags
& (pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED
| pg_constants::XLH_INSERT_ALL_FROZEN_SET))
!= 0
{
timeline.put_wal_record(
lsn,
RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[0].rnode_spcnode,
dbnode: decoded.blocks[0].rnode_dbnode,
relnode: decoded.blocks[0].rnode_relnode,
}),
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
ZenithWalRecord::ClearVisibilityMapFlags {
heap_blkno: decoded.blocks[0].blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
} else if info == pg_constants::XLOG_HEAP_DELETE {
let xlrec = XlHeapDelete::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
timeline.put_wal_record(
lsn,
RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[0].rnode_spcnode,
dbnode: decoded.blocks[0].rnode_dbnode,
relnode: decoded.blocks[0].rnode_relnode,
}),
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
ZenithWalRecord::ClearVisibilityMapFlags {
heap_blkno: decoded.blocks[0].blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
} else if info == pg_constants::XLOG_HEAP_UPDATE
|| info == pg_constants::XLOG_HEAP_HOT_UPDATE
@@ -380,15 +400,39 @@ impl WalIngest {
// the size of tuple data is inferred from the size of the record.
// we can't validate the remaining number of bytes without parsing
// the tuple data.
if (xlrec.flags & pg_constants::XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED) != 0 {
old_heap_blkno = Some(decoded.blocks[0].blkno);
}
if (xlrec.flags & pg_constants::XLH_UPDATE_NEW_ALL_VISIBLE_CLEARED) != 0 {
// PostgreSQL only uses XLH_UPDATE_NEW_ALL_VISIBLE_CLEARED on a
// non-HOT update where the new tuple goes to different page than
// the old one. Otherwise, only XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED is
// set.
new_heap_blkno = Some(decoded.blocks[1].blkno);
timeline.put_wal_record(
lsn,
RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[0].rnode_spcnode,
dbnode: decoded.blocks[0].rnode_dbnode,
relnode: decoded.blocks[0].rnode_relnode,
}),
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
ZenithWalRecord::ClearVisibilityMapFlags {
heap_blkno: decoded.blocks[0].blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
if (xlrec.flags & pg_constants::XLH_UPDATE_OLD_ALL_VISIBLE_CLEARED) != 0
&& decoded.blocks.len() > 1
{
timeline.put_wal_record(
lsn,
RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[1].rnode_spcnode,
dbnode: decoded.blocks[1].rnode_dbnode,
relnode: decoded.blocks[1].rnode_relnode,
}),
decoded.blocks[1].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
ZenithWalRecord::ClearVisibilityMapFlags {
heap_blkno: decoded.blocks[1].blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
}
} else if decoded.xl_rmid == pg_constants::RM_HEAP2_ID {
@@ -404,67 +448,32 @@ impl WalIngest {
};
assert_eq!(offset_array_len, buf.remaining());
if (xlrec.flags & pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
// FIXME: why also ALL_FROZEN_SET?
if (xlrec.flags
& (pg_constants::XLH_INSERT_ALL_VISIBLE_CLEARED
| pg_constants::XLH_INSERT_ALL_FROZEN_SET))
!= 0
{
timeline.put_wal_record(
lsn,
RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[0].rnode_spcnode,
dbnode: decoded.blocks[0].rnode_dbnode,
relnode: decoded.blocks[0].rnode_relnode,
}),
decoded.blocks[0].blkno / pg_constants::HEAPBLOCKS_PER_PAGE as u32,
ZenithWalRecord::ClearVisibilityMapFlags {
heap_blkno: decoded.blocks[0].blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
}
}
// FIXME: What about XLOG_HEAP_LOCK and XLOG_HEAP2_LOCK_UPDATED?
// Clear the VM bits if required.
if new_heap_blkno.is_some() || old_heap_blkno.is_some() {
let vm_relish = RelishTag::Relation(RelTag {
forknum: pg_constants::VISIBILITYMAP_FORKNUM,
spcnode: decoded.blocks[0].rnode_spcnode,
dbnode: decoded.blocks[0].rnode_dbnode,
relnode: decoded.blocks[0].rnode_relnode,
});
let new_vm_blk = new_heap_blkno.map(pg_constants::HEAPBLK_TO_MAPBLOCK);
let old_vm_blk = old_heap_blkno.map(pg_constants::HEAPBLK_TO_MAPBLOCK);
if new_vm_blk == old_vm_blk {
// An UPDATE record that needs to clear the bits for both old and the
// new page, both of which reside on the same VM page.
timeline.put_wal_record(
lsn,
vm_relish,
new_vm_blk.unwrap(),
ZenithWalRecord::ClearVisibilityMapFlags {
new_heap_blkno,
old_heap_blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
} else {
// Clear VM bits for one heap page, or for two pages that reside on
// different VM pages.
if let Some(new_vm_blk) = new_vm_blk {
timeline.put_wal_record(
lsn,
vm_relish,
new_vm_blk,
ZenithWalRecord::ClearVisibilityMapFlags {
new_heap_blkno,
old_heap_blkno: None,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
if let Some(old_vm_blk) = old_vm_blk {
timeline.put_wal_record(
lsn,
vm_relish,
old_vm_blk,
ZenithWalRecord::ClearVisibilityMapFlags {
new_heap_blkno: None,
old_heap_blkno,
flags: pg_constants::VISIBILITYMAP_VALID_BITS,
},
)?;
}
}
}
Ok(())
}

View File

@@ -13,13 +13,13 @@ use crate::walingest::WalIngest;
use anyhow::{bail, Context, Error, Result};
use bytes::BytesMut;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use postgres_ffi::waldecoder::*;
use postgres_protocol::message::backend::ReplicationMessage;
use postgres_types::PgLsn;
use std::cell::Cell;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Mutex;
use std::thread_local;
use std::time::SystemTime;
use tokio::pin;
@@ -51,7 +51,7 @@ thread_local! {
}
fn drop_wal_receiver(tenantid: ZTenantId, timelineid: ZTimelineId) {
let mut receivers = WAL_RECEIVERS.lock().unwrap();
let mut receivers = WAL_RECEIVERS.lock();
receivers.remove(&(tenantid, timelineid));
}
@@ -62,7 +62,7 @@ pub fn launch_wal_receiver(
timelineid: ZTimelineId,
wal_producer_connstr: &str,
) -> Result<()> {
let mut receivers = WAL_RECEIVERS.lock().unwrap();
let mut receivers = WAL_RECEIVERS.lock();
match receivers.get_mut(&(tenantid, timelineid)) {
Some(receiver) => {
@@ -95,7 +95,7 @@ pub fn launch_wal_receiver(
// Look up current WAL producer connection string in the hash table
fn get_wal_producer_connstr(tenantid: ZTenantId, timelineid: ZTimelineId) -> String {
let receivers = WAL_RECEIVERS.lock().unwrap();
let receivers = WAL_RECEIVERS.lock();
receivers
.get(&(tenantid, timelineid))
@@ -160,7 +160,7 @@ fn walreceiver_main(
// This is from tokio-postgres docs, but it is a bit weird in our case because we extensively use block_on
runtime.spawn(async move {
if let Err(e) = connection.await {
error!("connection error: {}", e);
eprintln!("connection error: {}", e);
}
});
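For context on the `.lock().unwrap()` vs `.lock()` churn in this hunk, a small contrast of the two Mutex APIs being swapped (sketch; the parameters are illustrative):

use parking_lot::Mutex as PlMutex;
use std::sync::Mutex as StdMutex;

fn locking(std_m: &StdMutex<i32>, pl_m: &PlMutex<i32>) -> i32 {
    // std::sync::Mutex returns a Result: a panic while the lock is held
    // poisons it, hence the .unwrap() calls on one side of this diff.
    let a = *std_m.lock().unwrap();
    // parking_lot::Mutex has no poisoning; lock() yields the guard directly.
    let b = *pl_m.lock();
    a + b
}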

View File

@@ -363,44 +363,25 @@ impl PostgresRedoManager {
will_init: _,
rec: _,
} => panic!("tried to pass postgres wal record to zenith WAL redo"),
ZenithWalRecord::ClearVisibilityMapFlags {
new_heap_blkno,
old_heap_blkno,
flags,
} => {
// sanity check that this is modifying the correct relish
ZenithWalRecord::ClearVisibilityMapFlags { heap_blkno, flags } => {
// Calculate the VM block and offset that corresponds to the heap block.
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(*heap_blkno);
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(*heap_blkno);
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(*heap_blkno);
// Check that we're modifying the correct VM block.
assert!(
check_forknum(&rel, pg_constants::VISIBILITYMAP_FORKNUM),
"ClearVisibilityMapFlags record on unexpected rel {:?}",
rel
);
if let Some(heap_blkno) = *new_heap_blkno {
// Calculate the VM block and offset that corresponds to the heap block.
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blkno);
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(heap_blkno);
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(heap_blkno);
assert!(map_block == blknum);
// Check that we're modifying the correct VM block.
assert!(map_block == blknum);
// equivalent to PageGetContents(page)
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
// equivalent to PageGetContents(page)
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
map[map_byte as usize] &= !(flags << map_offset);
}
// Repeat for 'old_heap_blkno', if any
if let Some(heap_blkno) = *old_heap_blkno {
let map_block = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blkno);
let map_byte = pg_constants::HEAPBLK_TO_MAPBYTE(heap_blkno);
let map_offset = pg_constants::HEAPBLK_TO_OFFSET(heap_blkno);
assert!(map_block == blknum);
let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..];
map[map_byte as usize] &= !(flags << map_offset);
}
let mask: u8 = flags << map_offset;
map[map_byte as usize] &= !mask;
}
// Non-relational WAL records are handled here, with custom code that has the
// same effects as the corresponding Postgres WAL redo function.
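The redo arm above recomputes the same visibility-map coordinates that walingest uses when scheduling these records. A worked sketch of that arithmetic with the constants written out (values assumed from vanilla PostgreSQL with 8 KB pages; `vm_position` is our name, not the codebase's):

// Stock PostgreSQL visibility-map geometry (assumed values).
const BLCKSZ: u32 = 8192;
const MAXALIGN_SIZE_OF_PAGE_HEADER_DATA: u32 = 24;
const BITS_PER_HEAPBLOCK: u32 = 2; // all-visible bit + all-frozen bit
const HEAPBLOCKS_PER_BYTE: u32 = 8 / BITS_PER_HEAPBLOCK; // 4
const MAPSIZE: u32 = BLCKSZ - MAXALIGN_SIZE_OF_PAGE_HEADER_DATA;
const HEAPBLOCKS_PER_PAGE: u32 = MAPSIZE * HEAPBLOCKS_PER_BYTE; // 32672

/// Returns (map_block, map_byte, map_offset) for a heap block number,
/// mirroring HEAPBLK_TO_MAPBLOCK / HEAPBLK_TO_MAPBYTE / HEAPBLK_TO_OFFSET.
fn vm_position(heap_blkno: u32) -> (u32, u32, u32) {
    let map_block = heap_blkno / HEAPBLOCKS_PER_PAGE; // which VM page
    let map_byte = (heap_blkno % HEAPBLOCKS_PER_PAGE) / HEAPBLOCKS_PER_BYTE;
    let map_offset = (heap_blkno % HEAPBLOCKS_PER_BYTE) * BITS_PER_HEAPBLOCK;
    (map_block, map_byte, map_offset)
}

fn main() {
    // Clearing both VM bits for heap block 100_000:
    let (blk, byte, off) = vm_position(100_000);
    let mask: u8 = 0b11 << off; // VISIBILITYMAP_VALID_BITS shifted into place
    println!("VM page {}, byte {}, clear mask {:#04b}", blk, byte, mask);
    // and then, as in the redo code above: map[byte] &= !mask;
}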

View File

@@ -1,8 +1,11 @@
[package]
name = "postgres_ffi"
version = "0.1.0"
authors = ["Heikki Linnakangas <heikki@zenith.tech>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
chrono = "0.4.19"
rand = "0.8.3"

View File

@@ -1,8 +1,11 @@
[package]
name = "proxy"
version = "0.1.0"
authors = ["Stas Kelvich <stas.kelvich@gmail.com>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
@@ -11,13 +14,26 @@ md5 = "0.7.0"
rand = "0.8.3"
hex = "0.4.3"
hyper = "0.14"
routerify = "2"
parking_lot = "0.11.2"
hashbrown = "0.11.2"
serde = "1"
serde_json = "1"
tokio = { version = "1.11", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
clap = "3.0"
tokio-rustls = "0.22.0"
clap = "2.33.0"
rustls = "0.19.1"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
pin-project-lite = "0.2.7"
futures = "0.3.13"
scopeguard = "1.1.0"
zenith_utils = { path = "../zenith_utils" }
zenith_metrics = { path = "../zenith_metrics" }
base64 = "0.13.0"
async-trait = "0.1.52"
[dev-dependencies]
tokio-postgres-rustls = "0.8.0"
rcgen = "0.8.14"

proxy/src/auth.rs (new file, 41 lines)
View File

@@ -0,0 +1,41 @@
use crate::db::AuthSecret;
use crate::stream::PqStream;
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::BeMessage as Be;
/// Stored secret for authenticating the user via md5 but authenticating
/// to the compute database with a (possibly different) plaintext password.
pub struct PlaintextStoredSecret {
pub salt: [u8; 4],
pub hashed_salted_password: Bytes,
pub compute_db_password: String,
}
/// Sufficient information to authenticate the user and create an AuthSecret
#[non_exhaustive]
pub enum StoredSecret {
PlaintextPassword(PlaintextStoredSecret),
// TODO add md5 option?
// TODO add SCRAM option
}
pub async fn authenticate(
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
stored_secret: StoredSecret
) -> anyhow::Result<AuthSecret> {
match stored_secret {
StoredSecret::PlaintextPassword(stored) => {
client.write_message(&Be::AuthenticationMD5Password(&stored.salt)).await?;
let provided = client.read_password_message().await?;
anyhow::ensure!(provided == stored.hashed_salted_password);
Ok(AuthSecret::Password(stored.compute_db_password))
},
}
}
#[async_trait::async_trait]
pub trait SecretStore {
async fn get_stored_secret(&self, creds: &crate::cplane_api::ClientCredentials) -> anyhow::Result<StoredSecret>;
}
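For reference, the md5 exchange that `authenticate` verifies is the standard PostgreSQL scheme. A sketch of the client-side computation (the helper is ours; it assumes `hashed_salted_password` stores the final "md5..." string as bytes, and uses the `md5` crate already in proxy's Cargo.toml):

fn md5_response(user: &str, password: &str, salt: [u8; 4]) -> String {
    // inner = hex(md5(password || user)), i.e. what Postgres keeps in pg_authid
    let inner = format!("{:x}", md5::compute(format!("{}{}", password, user)));
    // response = "md5" || hex(md5(inner || salt))
    let mut salted = inner.into_bytes();
    salted.extend_from_slice(&salt);
    format!("md5{:x}", md5::compute(&salted))
}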

proxy/src/cancellation.rs (new file, 90 lines)
View File

@@ -0,0 +1,90 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
lazy_static! {
/// Enables serving CancelRequests.
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, Option<CancelClosure>>> = Default::default();
}
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
}
impl CancelClosure {
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
Self {
socket_addr,
cancel_token,
}
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
Ok(())
}
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> {
let cancel_closure = CANCEL_MAP
.lock()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("unknown session: {:?}", key))?;
cancel_closure.try_cancel_query().await
}
/// Helper for registering query cancellation tokens.
pub struct Session(CancelKeyData);
impl Session {
/// Store the cancel token for the given session.
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
CANCEL_MAP.lock().insert(self.0, Some(cancel_closure));
self.0
}
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<F, R, V>(f: F) -> anyhow::Result<V>
where
F: FnOnce(Session) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
// The birthday problem is unlikely to happen here, but it's still possible
CANCEL_MAP
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
CANCEL_MAP.lock().remove(&key);
}
let session = Session(key);
f(session).await
}
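A hypothetical caller sketch for `with_session` (the connect step is elided; `closure` stands for a `CancelClosure` built as in `CancelClosure::new`):

async fn serve_one_client(closure: CancelClosure) -> anyhow::Result<()> {
    with_session(|session| async move {
        // Once the compute connection exists, publish its cancel token so a
        // later CancelRequest carrying this key can reach the compute node.
        let _key = session.enable_cancellation(closure);
        // ... proxy traffic until the client disconnects ...
        Ok(())
    })
    .await
}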

proxy/src/compute.rs (new file, 7 lines)
View File

@@ -0,0 +1,7 @@
use crate::{cplane_api::ClientCredentials, db::DatabaseConnInfo};
#[async_trait::async_trait]
pub trait ComputeProvider {
async fn get_compute_node(&self, creds: &ClientCredentials) -> anyhow::Result<DatabaseConnInfo>;
}

View File

@@ -1,9 +1,33 @@
use anyhow::{anyhow, bail, Context};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
use std::collections::HashMap;
use crate::state::ProxyWaiters;
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = anyhow::Error;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
}
}
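An illustrative use of the `TryFrom` impl above, with startup-packet parameters faked as a plain map (the values are made up):

fn example() -> anyhow::Result<()> {
    let mut params = HashMap::new();
    params.insert("user".to_owned(), "alice".to_owned());
    params.insert("database".to_owned(), "main".to_owned());
    let creds = ClientCredentials::try_from(params)?;
    assert_eq!(creds.user, "alice");
    assert_eq!(creds.dbname, "main");
    Ok(())
}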
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
pub host: String,
@@ -21,35 +45,6 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
waiters: &'a ProxyWaiters,

proxy/src/db.rs (new file, 58 lines)
View File

@@ -0,0 +1,58 @@
///
/// Utils for connecting to the postgres database.
///
use std::net::{SocketAddr, ToSocketAddrs};
use anyhow::{Context, anyhow};
use crate::cplane_api::ClientCredentials;
pub struct DatabaseConnInfo {
pub host: String,
pub port: u16,
}
pub struct DatabaseAuthInfo {
pub conn_info: DatabaseConnInfo,
pub creds: ClientCredentials,
pub auth_secret: AuthSecret,
}
/// Sufficient information to auth with database
#[non_exhaustive]
#[derive(Debug)]
pub enum AuthSecret {
Password(String),
// TODO add SCRAM option
}
impl From<DatabaseAuthInfo> for tokio_postgres::Config {
fn from(auth_info: DatabaseAuthInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&auth_info.conn_info.host)
.port(auth_info.conn_info.port)
.dbname(&auth_info.creds.dbname)
.user(&auth_info.creds.user);
match auth_info.auth_secret {
AuthSecret::Password(password) => {
config.password(password);
}
}
config
}
}
impl DatabaseConnInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr"))
}
}

View File

@@ -1,6 +1,7 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use zenith_utils::http::RouterBuilder;
use routerify::RouterBuilder;
use std::net::TcpListener;
use zenith_utils::http::endpoint;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::json::json_response;
@@ -9,7 +10,17 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(json_response(StatusCode::OK, "")?)
}
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
let router = endpoint::make_router();
router.get("/v1/status", status_handler)
}
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
let service = || routerify::RouterService::new(make_router().build()?);
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;
Ok(())
}

View File

@@ -8,71 +8,76 @@
use anyhow::bail;
use clap::{App, Arg};
use state::{ProxyConfig, ProxyState};
use std::thread;
use zenith_utils::http::endpoint;
use zenith_utils::{tcp_listener, GIT_VERSION};
mod compute;
mod mock;
mod auth;
mod db;
mod cancellation;
mod cplane_api;
mod http;
mod mgmt;
mod proxy;
mod state;
mod stream;
mod waiters;
fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
zenith_metrics::set_common_metrics_prefix("zenith_proxy");
let arg_matches = App::new("Zenith proxy/router")
.version(GIT_VERSION)
.arg(
Arg::new("proxy")
.short('p')
Arg::with_name("proxy")
.short("p")
.long("proxy")
.takes_value(true)
.help("listen for incoming client connections on ip:port")
.default_value("127.0.0.1:4432"),
)
.arg(
Arg::new("mgmt")
.short('m')
Arg::with_name("mgmt")
.short("m")
.long("mgmt")
.takes_value(true)
.help("listen for management callback connection on ip:port")
.default_value("127.0.0.1:7000"),
)
.arg(
Arg::new("http")
.short('h')
Arg::with_name("http")
.short("h")
.long("http")
.takes_value(true)
.help("listen for incoming http connections (metrics, etc) on ip:port")
.default_value("127.0.0.1:7001"),
)
.arg(
Arg::new("uri")
.short('u')
Arg::with_name("uri")
.short("u")
.long("uri")
.takes_value(true)
.help("redirect unauthenticated users to given uri")
.default_value("http://localhost:3000/psql_session/"),
)
.arg(
Arg::new("auth-endpoint")
.short('a')
Arg::with_name("auth-endpoint")
.short("a")
.long("auth-endpoint")
.takes_value(true)
.help("API endpoint for authenticating users")
.default_value("http://localhost:3000/authenticate_proxy_request/"),
)
.arg(
Arg::new("ssl-key")
.short('k')
Arg::with_name("ssl-key")
.short("k")
.long("ssl-key")
.takes_value(true)
.help("path to SSL key for client postgres connections"),
)
.arg(
Arg::new("ssl-cert")
.short('c')
Arg::with_name("ssl-cert")
.short("c")
.long("ssl-cert")
.takes_value(true)
.help("path to SSL cert for client postgres connections"),
@@ -107,35 +112,19 @@ fn main() -> anyhow::Result<()> {
let http_listener = tcp_listener::bind(state.conf.http_address)?;
println!("Starting proxy on {}", state.conf.proxy_address);
let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?;
let proxy_listener = tokio::net::TcpListener::bind(state.conf.proxy_address).await?;
println!("Starting mgmt on {}", state.conf.mgmt_address);
let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?;
let threads = [
thread::Builder::new()
.name("Http thread".into())
.spawn(move || {
let router = http::make_router();
endpoint::serve_thread_main(
router,
http_listener,
std::future::pending(), // never shut down
)
})?,
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
thread::Builder::new()
.name("Listener thread".into())
.spawn(move || proxy::thread_main(state, pageserver_listener))?,
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(state, mgmt_listener))?,
];
let http = tokio::spawn(http::thread_main(http_listener));
let proxy = tokio::spawn(proxy::thread_main(state, proxy_listener));
let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(state, mgmt_listener));
for t in threads {
t.join().unwrap()?;
}
let _ = futures::future::try_join_all([http, proxy, mgmt])
.await?
.into_iter()
.collect::<Result<Vec<()>, _>>()?;
Ok(())
}
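One subtlety in the new main: the join expression handles two error layers. A standalone illustration of the same pattern (not part of this diff):

// Standalone illustration of the join pattern used in main above.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let a = tokio::spawn(async { Ok::<(), anyhow::Error>(()) });
    let b = tokio::spawn(async { Ok::<(), anyhow::Error>(()) });

    let _ = futures::future::try_join_all([a, b])
        .await? // outer `?`: a task panicked or was cancelled (JoinError)
        .into_iter()
        .collect::<Result<Vec<()>, _>>()?; // inner `?`: a task returned an error

    Ok(())
}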

32
proxy/src/mock.rs Normal file
View File

@@ -0,0 +1,32 @@
use bytes::Bytes;
use crate::{auth::{PlaintextStoredSecret, SecretStore, StoredSecret}, compute::ComputeProvider, cplane_api::ClientCredentials, db::DatabaseConnInfo};
pub struct MockConsole;
#[async_trait::async_trait]
impl SecretStore for MockConsole {
async fn get_stored_secret(&self, creds: &ClientCredentials) -> anyhow::Result<StoredSecret> {
let salt = [0; 4];
match (&creds.user[..], &creds.dbname[..]) {
("postgres", "postgres") => Ok(StoredSecret::PlaintextPassword(PlaintextStoredSecret {
salt,
hashed_salted_password: "md52fff09cd9def51601fc5445943b3a11f\0".into(),
compute_db_password: "postgres".into(),
})),
_ => unimplemented!()
}
}
}
#[async_trait::async_trait]
impl ComputeProvider for MockConsole {
async fn get_compute_node(&self, _creds: &ClientCredentials) -> anyhow::Result<DatabaseConnInfo> {
Ok(DatabaseConnInfo {
host: "127.0.0.1".into(),
port: 5432,
})
}
}
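A note on the hard-coded verifier above: postgres-style MD5 secrets are conventionally "md5" followed by hex(md5(password || username)); the trailing "\0" in the mock appears to mirror the null-terminated password message on the wire. A hedged sketch of producing such a verifier (assumes the `md5` crate, which is not a dependency shown in this diff):

// Hypothetical helper, assuming the `md5` crate: build a postgres-style
// MD5 verifier of the form "md5" + hex(md5(password || username)).
fn md5_verifier(user: &str, password: &str) -> String {
    let digest = md5::compute(format!("{}{}", password, user));
    format!("md5{:x}", digest)
}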

View File

@@ -1,294 +1,185 @@
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
use crate::auth::{self, StoredSecret, SecretStore};
use crate::cancellation::{self, CancelClosure};
use crate::compute::ComputeProvider;
use crate::cplane_api as cplane;
use crate::db::{AuthSecret, DatabaseAuthInfo};
use crate::mock::MockConsole;
use crate::state::SslConfig;
use crate::stream::{PqStream, Stream};
use crate::ProxyState;
use anyhow::{anyhow, bail, Context};
use anyhow::{bail, Context};
use lazy_static::lazy_static;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use std::cell::Cell;
use std::collections::HashMap;
use std::net::{SocketAddr, TcpStream};
use std::sync::Mutex;
use std::{io, thread};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
use zenith_utils::sock_split::{ReadStream, WriteStream};
struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: tokio_postgres::CancelToken,
}
impl CancelClosure {
async fn try_cancel_query(&self) {
if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await {
// NOTE ignoring the result because:
// 1. This is a best effort attempt, the database doesn't have to listen
// 2. Being opaque about errors here helps avoid leaking info to an unauthenticated user
let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await;
}
}
}
use zenith_utils::pq_proto::{BeMessage as Be, *};
lazy_static! {
// Enables serving CancelRequests
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new());
// Metrics
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
"Number of TCP client connections accepted."
).unwrap();
)
.unwrap();
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_closed"),
"Number of TCP client connections closed."
).unwrap();
static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_failed"),
"Number of TCP client connections that closed due to error."
).unwrap();
)
.unwrap();
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_bytes_proxied"),
"Number of bytes sent/received between any client and backend."
).unwrap();
)
.unwrap();
}
thread_local! {
// Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop.
static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None);
}
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(
pub async fn thread_main(
state: &'static ProxyState,
listener: std::net::TcpListener,
listener: tokio::net::TcpListener,
) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept()?;
let (socket, peer_addr) = listener.accept().await?;
println!("accepted connection from {}", peer_addr);
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
socket.set_nodelay(true).unwrap();
// TODO Use a threadpool instead. Maybe use tokio's threadpool by
// spawning a future into its runtime. Tokio's JoinError should
// allow us to handle cleanup properly even if the future panics.
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || {
if let Err(err) = proxy_conn_main(state, socket) {
NUM_CONNECTIONS_FAILED_COUNTER.inc();
println!("error: {}", err);
}
tokio::spawn(log_error(async {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
// Clean up CANCEL_MAP.
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
THREAD_CANCEL_KEY_DATA.with(|cell| {
if let Some(cancel_key_data) = cell.get() {
CANCEL_MAP.lock().unwrap().remove(&cancel_key_data);
};
});
})?;
let tls = state.conf.ssl_config.clone();
handle_client(socket, tls).await
}));
}
}
// TODO: clean up fields
struct ProxyConnection {
state: &'static ProxyState,
psql_session_id: String,
pgb: PostgresBackend,
async fn log_error<R, F>(future: F) -> F::Output
where
F: std::future::Future<Output = anyhow::Result<R>>,
{
future.await.map_err(|err| {
println!("error: {}", err.to_string());
err
})
}
pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let conn = ProxyConnection {
state,
psql_session_id: hex::encode(rand::random::<[u8; 8]>()),
pgb: PostgresBackend::new(
socket,
postgres_backend::AuthType::MD5,
state.conf.ssl_config.clone(),
false,
)?,
};
let (client, server) = match conn.handle_client()? {
Some(x) => x,
None => return Ok(()),
};
let server = zenith_utils::sock_split::BidiStream::from_tcp(server);
let client = match client {
Stream::Bidirectional(bidi_stream) => bidi_stream,
_ => panic!("invalid stream type"),
};
proxy(client.split(), server.split())
}
impl ProxyConnection {
/// Returns Ok(None) when connection was successfully closed.
fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> {
let mut authenticate = || {
let (username, dbname) = match self.handle_startup()? {
Some(x) => x,
None => return Ok(None),
};
// Both scenarios here should end up producing database credentials
if username.ends_with("@zenith") {
self.handle_existing_user(&username, &dbname).map(Some)
} else {
self.handle_new_user().map(Some)
}
};
let conn = match authenticate() {
Ok(Some(db_info)) => connect_to_db(db_info),
Ok(None) => return Ok(None),
Err(e) => {
// Report the error to the client
self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?;
bail!("failed to handle client: {:?}", e);
}
};
// We'll get rid of this once migration to async is complete
let (pg_version, db_stream) = {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?;
self.pgb
.write_message(&BeMessage::BackendKeyData(cancel_key_data))?;
let stream = stream.into_std()?;
stream.set_nonblocking(false)?;
(pg_version, stream)
};
// Let the client send new requests
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&pg_version),
))?
.write_message(&Be::ReadyForQuery)?;
Ok(Some((self.pgb.into_stream(), db_stream)))
async fn handle_client(
stream: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<SslConfig>,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
/// Returns Ok(None) when connection was successfully closed.
fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> {
let have_tls = self.pgb.tls_config.is_some();
let mut encrypted = false;
if let Some((stream, creds)) = handshake(stream, tls).await? {
cancellation::with_session(|session| async {
connect_client_to_db(stream, creds, session).await
})
.await?;
}
loop {
let msg = match self.pgb.read_message()? {
Some(Fe::StartupPacket(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
Ok(())
}
match msg {
FeStartupPacket::GssEncRequest => {
self.pgb.write_message(&Be::EncryptionResponse(false))?;
}
FeStartupPacket::SslRequest => {
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
if have_tls {
self.pgb.start_tls()?;
encrypted = true;
/// Handle a connection from one client.
/// To make testing easier, `stream` can be
/// any object satisfying the traits.
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<SslConfig>,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, cplane::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
println!("got message: {:?}", msg);
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
}
}
FeStartupPacket::StartupMessage { mut params, .. } => {
if have_tls && !encrypted {
bail!("must connect with TLS");
}
_ => bail!("protocol violation"),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
let mut get_param = |key| {
params
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!("protocol violation"),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
let msg = "connection is insecure (try using `sslmode=require`)";
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
}
return Ok(Some((get_param("user")?, get_param("database")?)));
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
if let Some(cancel_closure) = CANCEL_MAP.lock().unwrap().get(&cancel_key_data) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(cancel_closure.try_cancel_query());
}
return Ok(None);
}
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {
cancellation::cancel_session(cancel_key_data).await?;
break Ok(None);
}
}
}
}
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
let md5_salt = rand::random::<[u8; 4]>();
async fn connect_client_to_db(
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: cplane::ClientCredentials,
session: cancellation::Session,
) -> anyhow::Result<()> {
// Authenticate
// TODO use real console
let console = MockConsole {};
let stored_secret = console.get_stored_secret(&creds).await?;
let auth_secret = auth::authenticate(&mut client, stored_secret).await?;
let conn_info = console.get_compute_node(&creds).await?;
let db_auth_info = DatabaseAuthInfo {
conn_info,
creds,
auth_secret,
};
// Ask password
self.pgb
.write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
self.pgb.state = ProtoState::Authentication; // XXX
// Connect to db
let (mut db, version, cancel_closure) = connect_to_db(db_auth_info).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
// Check password
let msg = match self.pgb.read_message()? {
Some(Fe::PasswordMessage(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
// Report success to client
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let mut client = client.into_inner();
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
let db_info = cplane.authenticate_proxy_request(
user,
db,
md5_response,
&md5_salt,
&self.psql_session_id,
)?;
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
// First, register this session
let waiter = self.state.waiters.register(self.psql_session_id.clone());
// Give user a URL to spawn a new database
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))?;
// Wait for web console response
let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?;
self.pgb
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
}
Ok(())
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
@@ -306,84 +197,147 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String {
)
}
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
/// Connect to a corresponding compute node.
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
let socket_addr = db_info.socket_addr()?;
let mut socket = tokio::net::TcpStream::connect(socket_addr).await?;
let config = tokio_postgres::Config::from(db_info);
// NOTE We effectively ignore some ParameterStatus and NoticeResponse
// messages here. Not sure if that could break something.
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
db_info: DatabaseAuthInfo,
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
// TODO: establish a secure connection to the DB
let socket_addr = db_info.conn_info.socket_addr()?;
let mut socket = TcpStream::connect(socket_addr).await?;
// Save info for potentially cancelling the query later
let mut rng = StdRng::from_entropy();
let cancel_key_data = CancelKeyData {
// HACK We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
backend_pid: rng.gen(),
cancel_key: rng.gen(),
};
let cancel_closure = CancelClosure {
socket_addr,
cancel_token: client.cancel_token(),
};
CANCEL_MAP
.lock()
.unwrap()
.insert(cancel_key_data, cancel_closure);
THREAD_CANCEL_KEY_DATA.with(|cell| {
let prev_value = cell.replace(Some(cancel_key_data));
assert!(
prev_value.is_none(),
"THREAD_CANCEL_KEY_DATA was already set"
);
});
let (client, conn) = tokio_postgres::Config::from(db_info)
.connect_raw(&mut socket, NoTls)
.await?;
let version = conn.parameter("server_version").unwrap();
Ok((version.into(), socket, cancel_key_data))
let version = conn
.parameter("server_version")
.context("failed to fetch postgres server version")?
.into();
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
}
/// Concurrently proxy both directions of the client and server connections
fn proxy(
(client_read, client_write): (ReadStream, WriteStream),
(server_read, server_write): (ReadStream, WriteStream),
) -> anyhow::Result<()> {
fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
/// FlushWriter will make sure that every message is sent as soon as possible
struct FlushWriter<W>(W);
#[cfg(test)]
mod tests {
use super::*;
impl<W: io::Write> io::Write for FlushWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// `std::io::copy` is guaranteed to exit if we return an error,
// so we can afford to lose `res` in case `flush` fails
let res = self.0.write(buf);
if let Ok(count) = res {
NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64);
self.flush()?;
}
res
}
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<SslConfig>,
) -> anyhow::Result<()> {
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?;
let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
writer.shutdown(std::net::Shutdown::Both)?;
res
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)
.await?;
Ok(())
}
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
fn generate_certs(
hostname: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
do_proxy(server_read, client_write)?;
client_to_server_jh.join().unwrap()?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
rustls::PrivateKey(cert.serialize_private_key_der()),
))
}
Ok(())
#[tokio::test]
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let server_config = {
let (_ca, cert, key) = generate_certs("localhost")?;
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
Ok(())
}
#[tokio::test]
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (ca, cert, key) = generate_certs("localhost")?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
config
};
let mut mk = MakeRustlsConnect::new(client_config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
.connect_raw(server, tls)
.await?;
proxy.await?
}
#[tokio::test]
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Prefer)
.connect_raw(server, NoTls)
.await?;
proxy.await?
}
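proxy/src/cancellation.rs itself is not part of this diff; from the call sites above (with_session, cancel_session, Session::enable_cancellation, CancelClosure::new), its rough shape can be sketched as follows — a speculative outline inferred from usage, not the real module:

// Speculative outline of the cancellation module, inferred from call sites only.
pub struct Session { /* identifies one registered CANCEL_MAP slot */ }

impl Session {
    /// Publish a cancel closure under a fresh CancelKeyData and return the key,
    /// which the proxy forwards to the client in BackendKeyData.
    pub fn enable_cancellation(self, closure: CancelClosure) -> CancelKeyData {
        todo!("insert `closure` into a shared map under a random key")
    }
}

/// Run `f` with a fresh session; drop the map entry once the future finishes.
pub async fn with_session<F, Fut>(f: F) -> anyhow::Result<()>
where
    F: FnOnce(Session) -> Fut,
    Fut: std::future::Future<Output = anyhow::Result<()>>,
{
    todo!("register a session, await f(session), clean up on exit")
}

/// Look up the key a client sent in CancelRequest and fire its closure.
pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> {
    todo!("find the closure in the shared map and call try_cancel_query")
}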
}

166
proxy/src/stream.rs Normal file
View File

@@ -0,0 +1,166 @@
use bytes::BytesMut;
use pin_project_lite::pin_project;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
pin_project! {
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
#[pin]
stream: S,
buffer: BytesMut,
}
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
Self {
stream,
buffer: Default::default(),
}
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
}
/// Get a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is the first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
match FeStartupPacket::read_fut(&mut self.stream).await? {
Some(FeMessage::StartupPacket(packet)) => Ok(packet),
None => anyhow::bail!("connection is lost"),
other => anyhow::bail!("bad message type: {:?}", other),
}
}
pub async fn read_password_message(&mut self) -> anyhow::Result<bytes::Bytes> {
match FeMessage::read_fut(&mut self.stream).await? {
Some(FeMessage::PasswordMessage(msg)) => Ok(msg),
None => anyhow::bail!("connection is lost"),
other => anyhow::bail!("bad message type: {:?}", other),
}
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buffer, message)?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
self.stream.write_all(&self.buffer).await?;
self.buffer.clear();
self.stream.flush().await?;
Ok(self)
}
}
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
}
}
}
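To see how the two wrappers compose, a minimal sketch in the spirit of the handshake in proxy.rs (assuming the same imports as that module; error handling elided):

// Minimal sketch: wrap a raw stream, read the startup packet, greet the client.
async fn greet<S: AsyncRead + AsyncWrite + Unpin>(raw: S) -> anyhow::Result<()> {
    let mut stream = PqStream::new(Stream::from_raw(raw));
    let _startup = stream.read_startup_packet().await?;
    stream
        .write_message_noflush(&BeMessage::AuthenticationOk)?
        .write_message(&BeMessage::ReadyForQuery)
        .await?;
    Ok(())
}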

View File

@@ -2,7 +2,7 @@
name = "zenith"
version = "0.1.0"
description = ""
authors = []
authors = ["Dmitry Rodionov <dmitry@zenith.tech>"]
[tool.poetry.dependencies]
python = "^3.7"

View File

@@ -1,24 +1,27 @@
#!/bin/bash
# this is a shortcut script to avoid duplication in CI
set -eux -o pipefail
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo "Uploading perf report to zenith pg"
# ingest per-test result data into the zenith-backed postgres running in staging, to build grafana reports on that data
DATABASE_URL="$PERF_TEST_RESULT_CONNSTR" poetry run python "$SCRIPT_DIR"/ingest_perf_test_result.py --ingest "$REPORT_FROM"
git clone https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-perf-data.git
cd zenith-perf-data
mkdir -p reports/
mkdir -p data/$REPORT_TO
# Activate poetry's venv. Needed because git-upload does not run in the project dir (it uses a temp dir to store the repository),
# so poetry cannot find pyproject.toml there.
# shellcheck source=/dev/null
. "$(poetry env info --path)"/bin/activate
cp $REPORT_FROM/* data/$REPORT_TO
echo "Uploading perf result to zenith-perf-data"
scripts/git-upload \
--repo=https://"$VIP_VAP_ACCESS_TOKEN"@github.com/zenithdb/zenith-perf-data.git \
--message="add performance test result for $GITHUB_SHA zenith revision" \
--branch=master \
copy "$REPORT_FROM" "data/$REPORT_TO" `# COPY FROM TO_RELATIVE`\
--merge \
--run-cmd "python $SCRIPT_DIR/generate_perf_report_page.py --input-dir data/$REPORT_TO --out reports/$REPORT_TO.html"
echo "Generating report"
poetry run python $SCRIPT_DIR/generate_perf_report_page.py --input-dir data/$REPORT_TO --out reports/$REPORT_TO.html
echo "Uploading perf result"
git add data reports
git \
-c "user.name=vipvap" \
-c "user.email=vipvap@zenith.tech" \
commit \
--author="vipvap <vipvap@zenith.tech>" \
-m "add performance test result for $GITHUB_SHA zenith revision"
git push https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-perf-data.git master

View File

@@ -1,9 +1,7 @@
#!/usr/bin/env python3
from contextlib import contextmanager
import shlex
from tempfile import TemporaryDirectory
from distutils.dir_util import copy_tree
from pathlib import Path
import argparse
@@ -11,8 +9,6 @@ import os
import shutil
import subprocess
import sys
import textwrap
from typing import Optional
def absolute_path(path):
@@ -42,21 +38,13 @@ def run(cmd, *args, **kwargs):
class GitRepo:
def __init__(self, url, branch: Optional[str] = None):
def __init__(self, url):
self.url = url
self.cwd = TemporaryDirectory()
self.branch = branch
args = [
'git',
'clone',
'--single-branch',
]
if self.branch:
args.extend(['--branch', self.branch])
subprocess.check_call([
*args,
'git',
'clone',
str(url),
self.cwd.name,
])
@@ -112,44 +100,23 @@ def do_copy(args):
raise FileExistsError(f"File exists: '{dst}'")
if src.is_dir():
if not args.merge:
shutil.rmtree(dst, ignore_errors=True)
# distutils is deprecated, but this is a temporary workaround before python version bump
# here we need dirs_exist_ok=True from shutil.copytree which is available in python 3.8+
copy_tree(str(src), str(dst))
shutil.rmtree(dst, ignore_errors=True)
shutil.copytree(src, dst)
else:
shutil.copy(src, dst)
if args.run_cmd:
run(shlex.split(args.run_cmd))
def main():
parser = argparse.ArgumentParser(description='Git upload tool')
parser.add_argument('--repo', type=str, metavar='URL', required=True, help='git repo url')
parser.add_argument('--message', type=str, metavar='TEXT', help='commit message')
parser.add_argument('--branch', type=str, metavar='TEXT', help='target git repo branch')
commands = parser.add_subparsers(title='commands', dest='subparser_name')
p_copy = commands.add_parser(
'copy',
help='copy file into the repo',
formatter_class=argparse.RawTextHelpFormatter,
)
p_copy = commands.add_parser('copy', help='copy file into the repo')
p_copy.add_argument('src', type=absolute_path, help='source path')
p_copy.add_argument('dst', type=relative_path, help='relative dest path')
p_copy.add_argument('--forbid-overwrite', action='store_true', help='do not allow overwrites')
p_copy.add_argument(
'--merge',
action='store_true',
help='when copying a directory do not delete existing data, but add new files')
p_copy.add_argument('--run-cmd',
help=textwrap.dedent('''\
run arbitrary cmd on top of copied files,
example usage is static content generation
based on current repository state\
'''))
args = parser.parse_args()
@@ -160,7 +127,7 @@ def main():
action = commands.get(args.subparser_name)
if action:
message = args.message or 'update'
GitRepo(args.repo, args.branch).update(message, lambda: action(args))
GitRepo(args.repo).update(message, lambda: action(args))
else:
parser.print_usage()

View File

@@ -1,136 +0,0 @@
#!/usr/bin/env python3
import argparse
from contextlib import contextmanager
import json
import os
import psycopg2
import psycopg2.extras
from pathlib import Path
from datetime import datetime
CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS perf_test_results (
id SERIAL PRIMARY KEY,
suit TEXT,
revision CHAR(40),
platform TEXT,
metric_name TEXT,
metric_value NUMERIC,
metric_unit VARCHAR(10),
metric_report_type TEXT,
recorded_at_timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW()
)
"""
def err(msg):
print(f'error: {msg}')
exit(1)
@contextmanager
def get_connection_cursor():
connstr = os.getenv('DATABASE_URL')
if not connstr:
err('DATABASE_URL environment variable is not set')
with psycopg2.connect(connstr) as conn:
with conn.cursor() as cur:
yield cur
def create_table(cur):
cur.execute(CREATE_TABLE)
def ingest_perf_test_result(cursor, data_file: Path, recorded_at_timestamp: int) -> int:
run_data = json.loads(data_file.read_text())
revision = run_data['revision']
platform = run_data['platform']
run_result = run_data['result']
args_list = []
for suit_result in run_result:
suit = suit_result['suit']
total_duration = suit_result['total_duration']
suit_result['data'].append({
'name': 'total_duration',
'value': total_duration,
'unit': 's',
'report': 'lower_is_better',
})
for metric in suit_result['data']:
values = {
'suit': suit,
'revision': revision,
'platform': platform,
'metric_name': metric['name'],
'metric_value': metric['value'],
'metric_unit': metric['unit'],
'metric_report_type': metric['report'],
'recorded_at_timestamp': datetime.utcfromtimestamp(recorded_at_timestamp),
}
args_list.append(values)
psycopg2.extras.execute_values(
cursor,
"""
INSERT INTO perf_test_results (
suit,
revision,
platform,
metric_name,
metric_value,
metric_unit,
metric_report_type,
recorded_at_timestamp
) VALUES %s
""",
args_list,
template="""(
%(suit)s,
%(revision)s,
%(platform)s,
%(metric_name)s,
%(metric_value)s,
%(metric_unit)s,
%(metric_report_type)s,
%(recorded_at_timestamp)s
)""",
)
return len(args_list)
def main():
parser = argparse.ArgumentParser(description='Perf test result uploader. \
Database connection string should be provided via DATABASE_URL environment variable', )
parser.add_argument(
'--ingest',
type=Path,
help='Path to perf test result file, or directory with perf test result files')
parser.add_argument('--initdb', action='store_true', help='Initialize database')
args = parser.parse_args()
with get_connection_cursor() as cur:
if args.initdb:
create_table(cur)
if not args.ingest.exists():
err(f'ingest path {args.ingest} does not exist')
if args.ingest:
if args.ingest.is_dir():
for item in sorted(args.ingest.iterdir(), key=lambda x: int(x.name.split('_')[0])):
recorded_at_timestamp = int(item.name.split('_')[0])
ingested = ingest_perf_test_result(cur, item, recorded_at_timestamp)
print(f'Ingested {ingested} metric values from {item}')
else:
recorded_at_timestamp = int(args.ingest.name.split('_')[0])
ingested = ingest_perf_test_result(cur, args.ingest, recorded_at_timestamp)
print(f'Ingested {ingested} metric values from {args.ingest}')
if __name__ == '__main__':
main()

View File

@@ -1,10 +1,8 @@
from contextlib import closing
from uuid import UUID
import psycopg2.extras
import psycopg2.errors
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, Postgres
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
import time
def test_timeline_size(zenith_simple_env: ZenithEnv):
@@ -37,96 +35,3 @@ def test_timeline_size(zenith_simple_env: ZenithEnv):
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
# wait until write_lag is 0
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
started_at = time.time()
write_lag = 1
while write_lag > 0:
elapsed = time.time() - started_at
if elapsed > timeout:
raise RuntimeError("timed out waiting for pageserver to reach pg_current_wal_lsn()")
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
cur.execute('''
select pg_size_pretty(pg_cluster_size()),
pg_wal_lsn_diff(pg_current_wal_lsn(),write_lsn) as write_lag,
pg_wal_lsn_diff(pg_current_wal_lsn(),sent_lsn) as pending_lag
FROM pg_stat_get_wal_senders();
''')
res = cur.fetchone()
log.info(
f"pg_cluster_size = {res[0]}, write_lag = {res[1]}, pending_lag = {res[2]}")
write_lag = res[1]
time.sleep(polling_interval)
def test_timeline_size_quota(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
env.zenith_cli(["branch", "test_timeline_size_quota", "main"])
client = env.pageserver.http_client()
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size_quota")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
pgmain = env.postgres.create_start(
"test_timeline_size_quota",
# Set small limit for the test
config_lines=['zenith.max_cluster_size=30MB'],
)
log.info("postgres is running on 'test_timeline_size_quota' branch")
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
cur.execute("CREATE EXTENSION zenith") # TODO move it to zenith_fixtures?
cur.execute("CREATE TABLE foo (t text)")
wait_for_pageserver_catchup(pgmain)
# Insert many rows. This query must fail because of space limit
try:
cur.execute('''
INSERT INTO foo
SELECT 'long string to consume some space' || g
FROM generate_series(1, 100000) g
''')
wait_for_pageserver_catchup(pgmain)
cur.execute('''
INSERT INTO foo
SELECT 'long string to consume some space' || g
FROM generate_series(1, 500000) g
''')
# If we get here, the timeline size limit failed
log.error("Query unexpectedly succeeded")
assert False
except psycopg2.errors.DiskFull as err:
log.info(f"Query expectedly failed with: {err}")
# drop table to free space
cur.execute('DROP TABLE foo')
wait_for_pageserver_catchup(pgmain)
# create it again and insert some rows. This query must succeed
cur.execute("CREATE TABLE foo (t text)")
cur.execute('''
INSERT INTO foo
SELECT 'long string to consume some space' || g
FROM generate_series(1, 10000) g
''')
wait_for_pageserver_catchup(pgmain)
cur.execute("SELECT * from pg_size_pretty(pg_cluster_size())")
pg_cluster_size = cur.fetchone()
log.info(f"pg_cluster_size = {pg_cluster_size}")

View File

@@ -325,7 +325,7 @@ class ProposerPostgres(PgProtocol):
tenant_id: str,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')
super().__init__(host=listen_addr, port=port)
self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin

View File

@@ -1,9 +1,8 @@
import json
import uuid
import requests
from psycopg2.extensions import cursor as PgCursor
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder
from fixtures.zenith_fixtures import ZenithEnv
from typing import cast
pytest_plugins = ("fixtures.zenith_fixtures")
@@ -106,20 +105,3 @@ def test_cli_tenant_list(zenith_simple_env: ZenithEnv):
assert env.initial_tenant in tenants
assert tenant1 in tenants
assert tenant2 in tenants
def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
# Start with single sk
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
# Connect to sk port on v4 loopback
res = requests.get(f'http://127.0.0.1:{env.safekeepers[0].port.http}/v1/status')
assert res.ok
# FIXME Test setup is using localhost:xx in ps config.
# Perhaps consider switching test suite to v4 loopback.
# Connect to ps port on v4 loopback
# res = requests.get(f'http://127.0.0.1:{env.pageserver.service_port.http}/v1/status')
# assert res.ok

View File

@@ -1,188 +0,0 @@
import pytest
from contextlib import contextmanager
from abc import ABC, abstractmethod
from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
# Type-related stuff
from typing import Iterator
class PgCompare(ABC):
"""Common interface of all postgres implementations, useful for benchmarks.
This class is a helper class for the zenith_with_baseline fixture. See its documentation
for more details.
"""
@property
@abstractmethod
def pg(self) -> PgProtocol:
pass
@property
@abstractmethod
def pg_bin(self) -> PgBin:
pass
@abstractmethod
def flush(self) -> None:
pass
@abstractmethod
def report_peak_memory_use(self) -> None:
pass
@abstractmethod
def report_size(self) -> None:
pass
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name):
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name):
pass
class ZenithCompare(PgCompare):
"""PgCompare interface for the zenith stack."""
def __init__(self,
zenbenchmark: ZenithBenchmarker,
zenith_simple_env: ZenithEnv,
pg_bin: PgBin,
branch_name):
self.env = zenith_simple_env
self.zenbenchmark = zenbenchmark
self._pg_bin = pg_bin
# We only use one branch and one timeline
self.branch = branch_name
self.env.zenith_cli(["branch", self.branch, "empty"])
self._pg = self.env.postgres.create_start(self.branch)
self.timeline = self.pg.safe_psql("SHOW zenith.zenith_timeline")[0][0]
# Long-lived cursor, useful for flushing
self.psconn = self.env.pageserver.connect()
self.pscur = self.psconn.cursor()
@property
def pg(self):
return self._pg
@property
def pg_bin(self):
return self._pg_bin
def flush(self):
self.pscur.execute(f"do_gc {self.env.initial_tenant} {self.timeline} 0")
def report_peak_memory_use(self) -> None:
self.zenbenchmark.record("peak_mem",
self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
def report_size(self) -> None:
timeline_size = self.zenbenchmark.get_timeline_size(self.env.repo_dir,
self.env.initial_tenant,
self.timeline)
self.zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)
def record_pageserver_writes(self, out_name):
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
def record_duration(self, out_name):
return self.zenbenchmark.record_duration(out_name)
class VanillaCompare(PgCompare):
"""PgCompare interface for vanilla postgres."""
def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres):
self._pg = vanilla_pg
self.zenbenchmark = zenbenchmark
vanilla_pg.configure(['shared_buffers=1MB'])
vanilla_pg.start()
# Long-lived cursor, useful for flushing
self.conn = self.pg.connect()
self.cur = self.conn.cursor()
@property
def pg(self):
return self._pg
@property
def pg_bin(self):
return self._pg.pg_bin
def flush(self):
self.cur.execute("checkpoint")
def report_peak_memory_use(self) -> None:
pass # TODO find something
def report_size(self) -> None:
data_size = self.pg.get_subdir_size('base')
self.zenbenchmark.record('data_size',
data_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)
wal_size = self.pg.get_subdir_size('pg_wal')
self.zenbenchmark.record('wal_size',
wal_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)
@contextmanager
def record_pageserver_writes(self, out_name):
yield # Do nothing
def record_duration(self, out_name):
return self.zenbenchmark.record_duration(out_name)
@pytest.fixture(scope='function')
def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare:
branch_name = request.node.name
return ZenithCompare(zenbenchmark, zenith_simple_env, pg_bin, branch_name)
@pytest.fixture(scope='function')
def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
return VanillaCompare(zenbenchmark, vanilla_pg)
@pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"])
def zenith_with_baseline(request) -> PgCompare:
"""Parameterized fixture that helps compare zenith against vanilla postgres.
A test that uses this fixture turns into a parameterized test that runs against:
1. A vanilla postgres instance
2. A simple zenith env (see zenith_simple_env)
3. Possibly other postgres protocol implementations.
The main goal of this fixture is to make it easier for people to read and write
performance tests. Easy test writing leads to more tests.
Perfect encapsulation of the postgres implementations is **not** a goal because
it's impossible. Operational and configuration differences in the different
implementations sometimes matter, and the writer of the test should be mindful
of that.
If a test requires some one-off special implementation-specific logic, use of
isinstance(zenith_with_baseline, ZenithCompare) is encouraged. Though if that
implementation-specific logic is widely useful across multiple tests, it might
make sense to add methods to the PgCompare class.
"""
fixture = request.getfixturevalue(request.param)
if isinstance(fixture, PgCompare):
return fixture
else:
raise AssertionError(f"test error: fixture {request.param} is not PgCompare")

View File

@@ -184,16 +184,6 @@ def worker_base_port(worker_seq_no: int):
return BASE_PORT + worker_seq_no * WORKER_PORT_NUM
def get_dir_size(path: str) -> int:
"""Return size in bytes."""
totalbytes = 0
for root, dirs, files in os.walk(path):
for name in files:
totalbytes += os.path.getsize(os.path.join(root, name))
return totalbytes
def can_bind(host: str, port: int) -> bool:
"""
Check whether a host:port is available to bind for listening
@@ -240,7 +230,7 @@ class PgProtocol:
def __init__(self, host: str, port: int, username: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.username = username or "zenith_admin"
def connstr(self,
*,
@@ -252,15 +242,10 @@ class PgProtocol:
"""
username = username or self.username
res = f'host={self.host} port={self.port} dbname={dbname}'
if username:
res = f'{res} user={username}'
if password:
res = f'{res} password={password}'
return res
res = f'host={self.host} port={self.port} user={username} dbname={dbname}'
if not password:
return res
return f'{res} password={password}'
# autocommit=True here by default because that's what we need most of the time
def connect(self,
@@ -850,7 +835,7 @@ class ZenithPageserver(PgProtocol):
port: PageserverPort,
remote_storage: Optional[RemoteStorage] = None,
enable_auth=False):
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
super().__init__(host='localhost', port=port.pg)
self.env = env
self.running = False
self.service_port = port # do not shadow PgProtocol.port which is just int
@@ -988,54 +973,10 @@ def pg_bin(test_output_dir: str) -> PgBin:
return PgBin(test_output_dir)
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
super().__init__(host='localhost', port=port)
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
self.running = False
self.pg_bin.run_capture(['initdb', '-D', pgdatadir])
def configure(self, options: List[str]) -> None:
"""Append lines into postgresql.conf file."""
assert not self.running
with open(os.path.join(self.pgdatadir, 'postgresql.conf'), 'a') as conf_file:
conf_file.writelines(options)
def start(self) -> None:
assert not self.running
self.running = True
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'start'])
def stop(self) -> None:
assert self.running
self.running = False
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'stop'])
def get_subdir_size(self, subdir) -> int:
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(os.path.join(self.pgdatadir, subdir))
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if self.running:
self.stop()
@pytest.fixture(scope='function')
def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
pgdatadir = os.path.join(test_output_dir, "pgdata-vanilla")
pg_bin = PgBin(test_output_dir)
with VanillaPostgres(pgdatadir, pg_bin, 5432) as vanilla_pg:
yield vanilla_pg
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: str, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin')
super().__init__(host='localhost', port=port)
self.env = env
self.running = False

View File

@@ -2,13 +2,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
#
@@ -21,19 +16,47 @@ pytest_plugins = (
# 3. Disk space used
# 4. Peak memory usage
#
def test_bulk_insert(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_bulk_insert(zenith_simple_env: ZenithEnv, zenbenchmark: ZenithBenchmarker):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_bulk_insert", "empty"])
pg = env.postgres.create_start('test_bulk_insert')
log.info("postgres is running on 'test_bulk_insert' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
# Get the timeline ID of our branch. We need it for the 'do_gc' command
with closing(env.pg.connect()) as conn:
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
cur.execute("create table huge (i int, j int);")
# Run INSERT, recording the time and I/O it takes
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('insert'):
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('insert'):
cur.execute("insert into huge values (generate_series(1, 5000000), 0);")
env.flush()
env.report_peak_memory_use()
env.report_size()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# Record peak memory usage
zenbenchmark.record("peak_mem",
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir,
env.initial_tenant,
timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -2,15 +2,10 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from io import BufferedReader, RawIOBase
from itertools import repeat
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
class CopyTestData(RawIOBase):
@@ -47,41 +42,77 @@ def copy_test_data(rows: int):
#
# COPY performance tests.
#
def test_copy(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_copy(zenith_simple_env: ZenithEnv, zenbenchmark: ZenithBenchmarker):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_copy", "empty"])
pg = env.postgres.create_start('test_copy')
log.info("postgres is running on 'test_copy' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
# Get the timeline ID of our branch. We need it for the pageserver 'checkpoint' command
with closing(env.pg.connect()) as conn:
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
cur.execute("create table copytest (i int, t text);")
# Load data with COPY, recording the time and I/O it takes.
#
# Since there's no data in the table previously, this extends it.
with env.record_pageserver_writes('copy_extend_pageserver_writes'):
with env.record_duration('copy_extend'):
with zenbenchmark.record_pageserver_writes(env.pageserver,
'copy_extend_pageserver_writes'):
with zenbenchmark.record_duration('copy_extend'):
cur.copy_from(copy_test_data(1000000), 'copytest')
env.flush()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"checkpoint {env.initial_tenant} {timeline}")
# Delete most rows, and VACUUM to make the space available for reuse.
with env.record_pageserver_writes('delete_pageserver_writes'):
with env.record_duration('delete'):
with zenbenchmark.record_pageserver_writes(env.pageserver, 'delete_pageserver_writes'):
with zenbenchmark.record_duration('delete'):
cur.execute("delete from copytest where i % 100 <> 0;")
env.flush()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"checkpoint {env.initial_tenant} {timeline}")
with env.record_pageserver_writes('vacuum_pageserver_writes'):
with env.record_duration('vacuum'):
with zenbenchmark.record_pageserver_writes(env.pageserver, 'vacuum_pageserver_writes'):
with zenbenchmark.record_duration('vacuum'):
cur.execute("vacuum copytest")
env.flush()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"checkpoint {env.initial_tenant} {timeline}")
# Load data into the table again. This time, this will use the space free'd
# by the VACUUM.
#
# This will also clear all the VM bits.
with env.record_pageserver_writes('copy_reuse_pageserver_writes'):
with env.record_duration('copy_reuse'):
with zenbenchmark.record_pageserver_writes(env.pageserver,
'copy_reuse_pageserver_writes'):
with zenbenchmark.record_duration('copy_reuse'):
cur.copy_from(copy_test_data(1000000), 'copytest')
env.flush()
env.report_peak_memory_use()
env.report_size()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"checkpoint {env.initial_tenant} {timeline}")
# Record peak memory usage
zenbenchmark.record("peak_mem",
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir,
env.initial_tenant,
timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -2,14 +2,9 @@ import os
from contextlib import closing
from fixtures.benchmark_fixture import MetricReport
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
#
@@ -17,11 +12,24 @@ pytest_plugins = (
# As of this writing, we duplicate those giant WAL records for each page,
# which makes the delta layer about 32x larger than it needs to be.
#
def test_gist_buffering_build(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_gist_buffering_build(zenith_simple_env: ZenithEnv, zenbenchmark):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_gist_buffering_build", "empty"])
with closing(env.pg.connect()) as conn:
pg = env.postgres.create_start('test_gist_buffering_build')
log.info("postgres is running on 'test_gist_buffering_build' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
# Get the timeline ID of our branch. We need it for the 'do_gc' command
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
# Create test table.
cur.execute("create table gist_point_tbl(id int4, p point)")
@@ -30,12 +38,27 @@ def test_gist_buffering_build(zenith_with_baseline: PgCompare):
)
# Build the index.
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('build'):
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('build'):
cur.execute(
"create index gist_pointidx2 on gist_point_tbl using gist(p) with (buffering = on)"
)
env.flush()
env.report_peak_memory_use()
env.report_size()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 1000000")
# Record peak memory usage
zenbenchmark.record("peak_mem",
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir,
env.initial_tenant,
timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -1,16 +1,11 @@
from io import BytesIO
import asyncio
import asyncpg
from fixtures.zenith_fixtures import ZenithEnv, Postgres, PgProtocol
from fixtures.zenith_fixtures import ZenithEnv, Postgres
from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
async def repeat_bytes(buf, repetitions: int):
@@ -18,7 +13,7 @@ async def repeat_bytes(buf, repetitions: int):
yield buf
async def copy_test_data_to_table(pg: PgProtocol, worker_id: int, table_name: str):
async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str):
buf = BytesIO()
for i in range(1000):
buf.write(
@@ -31,7 +26,7 @@ async def copy_test_data_to_table(pg: PgProtocol, worker_id: int, table_name: str):
await pg_conn.copy_to_table(table_name, source=copy_input)
async def parallel_load_different_tables(pg: PgProtocol, n_parallel: int):
async def parallel_load_different_tables(pg: Postgres, n_parallel: int):
workers = []
for worker_id in range(n_parallel):
worker = copy_test_data_to_table(pg, worker_id, f'copytest_{worker_id}')
@@ -42,25 +37,54 @@ async def parallel_load_different_tables(pg: PgProtocol, n_parallel: int):
# Load 5 different tables in parallel with COPY TO
def test_parallel_copy_different_tables(zenith_with_baseline: PgCompare, n_parallel=5):
def test_parallel_copy_different_tables(zenith_simple_env: ZenithEnv,
zenbenchmark: ZenithBenchmarker,
n_parallel=5):
env = zenith_with_baseline
conn = env.pg.connect()
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_parallel_copy_different_tables", "empty"])
pg = env.postgres.create_start('test_parallel_copy_different_tables')
log.info("postgres is running on 'test_parallel_copy_different_tables' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
# Get the timeline ID of our branch. We need it for the 'do_gc' command
conn = pg.connect()
cur = conn.cursor()
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
for worker_id in range(n_parallel):
cur.execute(f'CREATE TABLE copytest_{worker_id} (i int, t text)')
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('load'):
asyncio.run(parallel_load_different_tables(env.pg, n_parallel))
env.flush()
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('load'):
asyncio.run(parallel_load_different_tables(pg, n_parallel))
env.report_peak_memory_use()
env.report_size()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# Record peak memory usage
zenbenchmark.record("peak_mem",
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir, env.initial_tenant, timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)
async def parallel_load_same_table(pg: PgProtocol, n_parallel: int):
async def parallel_load_same_table(pg: Postgres, n_parallel: int):
workers = []
for worker_id in range(n_parallel):
worker = copy_test_data_to_table(pg, worker_id, f'copytest')
@@ -71,17 +95,46 @@ async def parallel_load_same_table(pg: PgProtocol, n_parallel: int):
# Load data into one table with COPY TO from 5 parallel connections
def test_parallel_copy_same_table(zenith_with_baseline: PgCompare, n_parallel=5):
env = zenith_with_baseline
conn = env.pg.connect()
def test_parallel_copy_same_table(zenith_simple_env: ZenithEnv,
zenbenchmark: ZenithBenchmarker,
n_parallel=5):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_parallel_copy_same_table", "empty"])
pg = env.postgres.create_start('test_parallel_copy_same_table')
log.info("postgres is running on 'test_parallel_copy_same_table' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
# Get the timeline ID of our branch. We need it for the 'do_gc' command
conn = pg.connect()
cur = conn.cursor()
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
cur.execute(f'CREATE TABLE copytest (i int, t text)')
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('load'):
asyncio.run(parallel_load_same_table(env.pg, n_parallel))
env.flush()
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('load'):
asyncio.run(parallel_load_same_table(pg, n_parallel))
env.report_peak_memory_use()
env.report_size()
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# Record peak memory usage
zenbenchmark.record("peak_mem",
zenbenchmark.get_peak_mem(env.pageserver) / 1024,
'MB',
report=MetricReport.LOWER_IS_BETTER)
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir, env.initial_tenant, timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -1,15 +1,10 @@
from contextlib import closing
from fixtures.zenith_fixtures import PgBin, VanillaPostgres, ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.zenith_fixtures import PgBin, ZenithEnv
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
#
@@ -21,16 +16,47 @@ pytest_plugins = (
# 2. Time to run 5000 pgbench transactions
# 3. Disk space used
#
def test_pgbench(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin: PgBin, zenbenchmark: ZenithBenchmarker):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_pgbench_perf", "empty"])
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('init'):
env.pg_bin.run_capture(['pgbench', '-s5', '-i', env.pg.connstr()])
env.flush()
pg = env.postgres.create_start('test_pgbench_perf')
log.info("postgres is running on 'test_pgbench_perf' branch")
with env.record_duration('5000_xacts'):
env.pg_bin.run_capture(['pgbench', '-c1', '-t5000', env.pg.connstr()])
env.flush()
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
env.report_size()
# Get the timeline ID of our branch. We need it for the 'do_gc' command
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
connstr = pg.connstr()
# Initialize pgbench database, recording the time and I/O it takes
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('init'):
pg_bin.run_capture(['pgbench', '-s5', '-i', connstr])
# Flush the layers from memory to disk. This is included in the reported
# time and I/O
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# Run pgbench for 5000 transactions
with zenbenchmark.record_duration('5000_xacts'):
pg_bin.run_capture(['pgbench', '-c1', '-t5000', connstr])
# Flush the layers to disk again. This is *not* included in the reported time,
# though.
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir, env.initial_tenant, timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -7,19 +7,24 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
def test_small_seqscans(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_small_seqscans(zenith_simple_env: ZenithEnv, zenbenchmark: ZenithBenchmarker):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_small_seqscans", "empty"])
with closing(env.pg.connect()) as conn:
pg = env.postgres.create_start('test_small_seqscans')
log.info("postgres is running on 'test_small_seqscans' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute('create table t (i integer);')
cur.execute('insert into t values (generate_series(1,100000));')
@@ -33,6 +38,6 @@ def test_small_seqscans(zenith_with_baseline: PgCompare):
log.info(f"shared_buffers is {row[0]}, table size {row[1]}")
assert int(row[0]) < int(row[1])
with env.record_duration('run'):
with zenbenchmark.record_duration('run'):
for i in range(1000):
cur.execute('select count(*) from t;')

View File

@@ -14,23 +14,32 @@ import os
from contextlib import closing
from fixtures.benchmark_fixture import MetricReport
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
def test_write_amplification(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
def test_write_amplification(zenith_simple_env: ZenithEnv, zenbenchmark):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_write_amplification", "empty"])
with closing(env.pg.connect()) as conn:
pg = env.postgres.create_start('test_write_amplification')
log.info("postgres is running on 'test_write_amplification' branch")
# Open a connection directly to the page server that we'll use to force
# flushing the layers to disk
psconn = env.pageserver.connect()
pscur = psconn.cursor()
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
with env.record_pageserver_writes('pageserver_writes'):
with env.record_duration('run'):
# Get the timeline ID of our branch. We need it for the 'do_gc' command
cur.execute("SHOW zenith.zenith_timeline")
timeline = cur.fetchone()[0]
with zenbenchmark.record_pageserver_writes(env.pageserver, 'pageserver_writes'):
with zenbenchmark.record_duration('run'):
# NOTE: Because each iteration updates every table already created,
# the runtime and write amplification are O(n^2), where n is the
@@ -62,6 +71,13 @@ def test_write_amplification(zenith_with_baseline: PgCompare):
# slower, adding some delays in this loop. But forcing
# the checkpointing and GC makes the test go faster,
# with the same total I/O effect.
env.flush()
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
env.report_size()
# Report disk space used by the repository
timeline_size = zenbenchmark.get_timeline_size(env.repo_dir,
env.initial_tenant,
timeline)
zenbenchmark.record('size',
timeline_size / (1024 * 1024),
'MB',
report=MetricReport.LOWER_IS_BETTER)

View File

@@ -1,18 +1,22 @@
[package]
name = "walkeeper"
version = "0.1.0"
authors = ["Stas Kelvich <stas@zenith.tech>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
regex = "1.4.5"
bytes = "1.0.1"
byteorder = "1.4.3"
hyper = "0.14"
routerify = "2"
fs2 = "0.4.3"
lazy_static = "1.4.0"
serde_json = "1"
tracing = "0.1.27"
clap = "3.0"
clap = "2.33.0"
daemonize = "0.4.1"
rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }
tokio = { version = "1.11", features = ["macros"] }

View File

@@ -32,22 +32,22 @@ fn main() -> Result<()> {
.about("Store WAL stream to local file system and push it to WAL receivers")
.version(GIT_VERSION)
.arg(
Arg::new("datadir")
.short('D')
Arg::with_name("datadir")
.short("D")
.long("dir")
.takes_value(true)
.help("Path to the safekeeper data directory"),
)
.arg(
Arg::new("listen-pg")
.short('l')
Arg::with_name("listen-pg")
.short("l")
.long("listen-pg")
.alias("listen") // for compatibility
.takes_value(true)
.help(formatcp!("listen for incoming WAL data connections on ip:port (default: {DEFAULT_PG_LISTEN_ADDR})")),
)
.arg(
Arg::new("listen-http")
Arg::with_name("listen-http")
.long("listen-http")
.takes_value(true)
.help(formatcp!("http endpoint address for metrics on ip:port (default: {DEFAULT_HTTP_LISTEN_ADDR})")),
@@ -56,39 +56,39 @@ fn main() -> Result<()> {
// However because this argument is in use by console's e2e tests lets keep it for now and remove separately.
// So currently it is a noop.
.arg(
Arg::new("pageserver")
.short('p')
Arg::with_name("pageserver")
.short("p")
.long("pageserver")
.takes_value(true),
)
.arg(
Arg::new("ttl")
Arg::with_name("ttl")
.long("ttl")
.takes_value(true)
.help("interval for keeping WAL at safekeeper node, after which them will be uploaded to S3 and removed locally"),
)
.arg(
Arg::new("recall")
Arg::with_name("recall")
.long("recall")
.takes_value(true)
.help("Period for requestion pageserver to call for replication"),
)
.arg(
Arg::new("daemonize")
.short('d')
Arg::with_name("daemonize")
.short("d")
.long("daemonize")
.takes_value(false)
.help("Run in the background"),
)
.arg(
Arg::new("no-sync")
.short('n')
Arg::with_name("no-sync")
.short("n")
.long("no-sync")
.takes_value(false)
.help("Do not wait for changes to be written safely to disk"),
)
.arg(
Arg::new("dump-control-file")
Arg::with_name("dump-control-file")
.long("dump-control-file")
.takes_value(true)
.help("Dump control file at path specifed by this argument and exit"),

View File

@@ -33,7 +33,7 @@ async fn request_callback(
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("connection error: {}", e);
eprintln!("connection error: {}", e);
}
});

View File

@@ -1,9 +1,10 @@
use hyper::{Body, Request, Response, StatusCode};
use routerify::ext::RequestExt;
use routerify::RouterBuilder;
use serde::Serialize;
use serde::Serializer;
use std::fmt::Display;
use std::sync::Arc;
use zenith_utils::http::{RequestExt, RouterBuilder};
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, Context, Result};
use bytes::Bytes;
use bytes::BytesMut;
use tokio::sync::mpsc::UnboundedSender;
use tracing::*;
use crate::timeline::Timeline;
@@ -19,8 +18,10 @@ use crate::handler::SafekeeperPostgresHandler;
use crate::timeline::TimelineTools;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeMessage};
use zenith_utils::zid::ZTenantId;
use crate::callmemaybe::CallmeEvent;
use crate::callmemaybe::SubscriptionStateKey;
pub struct ReceiveWalConn<'pg> {
/// Postgres connection
@@ -81,23 +82,50 @@ impl<'pg> ReceiveWalConn<'pg> {
let mut msg = self
.read_msg()
.context("failed to receive proposer greeting")?;
let tenant_id: ZTenantId;
match msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
info!(
"start handshake with wal proposer {} sysid {} timeline {}",
self.peer_addr, greeting.system_id, greeting.tli,
);
tenant_id = greeting.tenant_id;
}
_ => bail!("unexpected message {:?} instead of greeting", msg),
}
// Register the connection and defer unregister.
spg.timeline
.get()
.on_compute_connect(self.pageserver_connstr.as_ref(), &spg.tx)?;
let _guard = ComputeConnectionGuard {
timeline: Arc::clone(spg.timeline.get()),
callmemaybe_tx: spg.tx.clone(),
// Incoming WAL stream resumed, so reset information about the timeline pause.
spg.timeline.get().continue_streaming();
// if requested, ask pageserver to fetch wal from us
// as long as this wal_stream is alive, callmemaybe thread
// will send requests to pageserver
let _guard = match self.pageserver_connstr {
Some(ref pageserver_connstr) => {
// Need to establish replication channel with page server.
// Since replication in postgres is initiated by the receiver,
// we use the callmemaybe mechanism.
let timeline_id = spg.timeline.get().zttid.timeline_id;
let subscription_key = SubscriptionStateKey::new(
tenant_id,
timeline_id,
pageserver_connstr.to_owned(),
);
spg.tx
.send(CallmeEvent::Subscribe(subscription_key))
.unwrap_or_else(|e| {
error!(
"failed to send Subscribe request to callmemaybe thread {}",
e
);
});
// create a guard to unsubscribe the callback when this wal_stream exits
Some(SendWalHandlerGuard {
timeline: Arc::clone(spg.timeline.get()),
})
}
None => None,
};
loop {
@@ -114,15 +142,12 @@ impl<'pg> ReceiveWalConn<'pg> {
}
}
struct ComputeConnectionGuard {
struct SendWalHandlerGuard {
timeline: Arc<Timeline>,
callmemaybe_tx: UnboundedSender<CallmeEvent>,
}
impl Drop for ComputeConnectionGuard {
impl Drop for SendWalHandlerGuard {
fn drop(&mut self) {
self.timeline
.on_compute_disconnect(&self.callmemaybe_tx)
.unwrap();
self.timeline.stop_streaming();
}
}
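Both the old ComputeConnectionGuard and the new SendWalHandlerGuard use the same RAII idiom: cleanup lives in Drop, so it runs on every exit path out of the WAL stream, including errors. A self-contained sketch of the pattern, with a hypothetical Event type standing in for CallmeEvent (assumes tokio with the "macros" feature enabled):

    use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

    // Hypothetical stand-in for CallmeEvent.
    #[derive(Debug)]
    enum Event {
        Unsubscribe(String),
    }

    struct SubscriptionGuard {
        key: String,
        tx: UnboundedSender<Event>,
    }

    impl Drop for SubscriptionGuard {
        fn drop(&mut self) {
            // Runs on every exit path, including early returns and panics.
            let _ = self.tx.send(Event::Unsubscribe(self.key.clone()));
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = unbounded_channel();
        {
            let _guard = SubscriptionGuard { key: "tenant/timeline".into(), tx };
            // ... stream WAL here; leaving this scope drops the guard ...
        }
        assert!(matches!(rx.recv().await, Some(Event::Unsubscribe(_))));
    }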

View File

@@ -454,24 +454,24 @@ struct SafeKeeperMetrics {
write_wal_seconds: Histogram,
}
struct SafeKeeperMetricsBuilder {
ztli: ZTimelineId,
flush_lsn: Lsn,
commit_lsn: Lsn,
}
impl SafeKeeperMetricsBuilder {
fn build(self) -> SafeKeeperMetrics {
let ztli_str = format!("{}", self.ztli);
let m = SafeKeeperMetrics {
impl SafeKeeperMetrics {
fn new(ztli: ZTimelineId) -> SafeKeeperMetrics {
let ztli_str = format!("{}", ztli);
SafeKeeperMetrics {
flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&[&ztli_str]),
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&[&ztli_str]),
write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&[&ztli_str]),
write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&[&ztli_str]),
};
m.flush_lsn.set(u64::from(self.flush_lsn) as f64);
m.commit_lsn.set(u64::from(self.commit_lsn) as f64);
m
}
}
fn new_noname() -> SafeKeeperMetrics {
SafeKeeperMetrics {
flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&["n/a"]),
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&["n/a"]),
write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&["n/a"]),
write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&["n/a"]),
}
}
}
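The with_label_values calls above carve a per-timeline child out of each shared metric vector; the "n/a" children are placeholders used before the timeline id is known. A minimal sketch of that prometheus pattern, using a hypothetical metric name rather than one registered by this file:

    use lazy_static::lazy_static;
    use prometheus::{register_gauge_vec, GaugeVec};

    lazy_static! {
        // One vector labeled by timeline id, registered once at startup.
        static ref EXAMPLE_LSN_GAUGE: GaugeVec = register_gauge_vec!(
            "example_flush_lsn",
            "Example per-timeline gauge",
            &["ztli"]
        )
        .unwrap();
    }

    fn main() {
        // Each distinct label value yields an independent child gauge.
        EXAMPLE_LSN_GAUGE.with_label_values(&["tli-1"]).set(42.0);
        // Placeholder child until the real timeline id arrives.
        EXAMPLE_LSN_GAUGE.with_label_values(&["n/a"]).set(0.0);
    }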
@@ -496,25 +496,10 @@ where
ST: Storage,
{
// constructor
pub fn new(
ztli: ZTimelineId,
flush_lsn: Lsn,
storage: ST,
state: SafeKeeperState,
) -> SafeKeeper<ST> {
if state.server.timeline_id != ZTimelineId::from([0u8; 16])
&& ztli != state.server.timeline_id
{
panic!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.server.timeline_id);
}
pub fn new(flush_lsn: Lsn, storage: ST, state: SafeKeeperState) -> SafeKeeper<ST> {
SafeKeeper {
flush_lsn,
metrics: SafeKeeperMetricsBuilder {
ztli,
flush_lsn,
commit_lsn: state.commit_lsn,
}
.build(),
metrics: SafeKeeperMetrics::new_noname(),
commit_lsn: state.commit_lsn,
truncate_lsn: state.truncate_lsn,
storage,
@@ -580,12 +565,7 @@ where
.persist(&self.s)
.context("failed to persist shared state")?;
self.metrics = SafeKeeperMetricsBuilder {
ztli: self.s.server.timeline_id,
flush_lsn: self.flush_lsn,
commit_lsn: self.commit_lsn,
}
.build();
self.metrics = SafeKeeperMetrics::new(self.s.server.timeline_id);
info!(
"processed greeting from proposer {:?}, sending term {:?}",
@@ -661,7 +641,6 @@ where
}
// update our end of WAL pointer
self.flush_lsn = msg.start_streaming_at;
self.metrics.flush_lsn.set(u64::from(self.flush_lsn) as f64);
// and now adopt term history from proposer
self.s.acceptor_state.term_history = msg.term_history.clone();
self.storage.persist(&self.s)?;
@@ -774,7 +753,7 @@ where
}
let resp = self.append_response();
trace!(
info!(
"processed AppendRequest of len {}, end_lsn={:?}, commit_lsn={:?}, truncate_lsn={:?}, resp {:?}",
msg.wal_data.len(),
msg.h.end_lsn,
@@ -815,8 +794,7 @@ mod tests {
let storage = InMemoryStorage {
persisted_state: SafeKeeperState::new(),
};
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
let mut sk = SafeKeeper::new(Lsn(0), storage, SafeKeeperState::new());
// check voting for 1 is ok
let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 });
@@ -831,7 +809,7 @@ mod tests {
let storage = InMemoryStorage {
persisted_state: state.clone(),
};
sk = SafeKeeper::new(ztli, Lsn(0), storage, state);
sk = SafeKeeper::new(Lsn(0), storage, state);
// and ensure voting second time for 1 is not ok
vote_resp = sk.process_msg(&vote_request);
@@ -846,8 +824,7 @@ mod tests {
let storage = InMemoryStorage {
persisted_state: SafeKeeperState::new(),
};
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
let mut sk = SafeKeeper::new(Lsn(0), storage, SafeKeeperState::new());
let mut ar_hdr = AppendRequestHeader {
term: 1,

View File

@@ -167,7 +167,7 @@ impl ReplicationConn {
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ZenithFeedback::parse(buf);
trace!("ZenithFeedback is {:?}", reply);
info!("ZenithFeedback is {:?}", reply);
// Only pageserver sends ZenithFeedback, so set the flag.
// This replica is the source of information to resend to compute.
state.zenith_feedback = Some(reply);
@@ -283,14 +283,12 @@ impl ReplicationConn {
if spg.appname == Some("wal_proposer_recovery".to_string()) {
None
} else {
let pageserver_connstr = pageserver_connstr.expect("there should be a pageserver connection string since this is not a wal_proposer_recovery");
let zttid = spg.timeline.get().zttid;
let pageserver_connstr = pageserver_connstr.clone().expect("there should be a pageserver connection string since this is not a wal_proposer_recovery");
let tenant_id = spg.ztenantid.unwrap();
let timeline_id = spg.timeline.get().zttid.timeline_id;
let tx_clone = spg.tx.clone();
let subscription_key = SubscriptionStateKey::new(
zttid.tenant_id,
zttid.timeline_id,
pageserver_connstr.clone(),
);
let subscription_key =
SubscriptionStateKey::new(tenant_id, timeline_id, pageserver_connstr.clone());
spg.tx
.send(CallmeEvent::Pause(subscription_key))
.unwrap_or_else(|e| {
@@ -300,8 +298,8 @@ impl ReplicationConn {
// create a guard to subscribe the callback again when this connection exits
Some(ReplicationStreamGuard {
tx: tx_clone,
tenant_id: zttid.tenant_id,
timeline_id: zttid.timeline_id,
tenant_id,
timeline_id,
pageserver_connstr,
})
}
@@ -326,10 +324,21 @@ impl ReplicationConn {
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if spg.timeline.get().check_deactivate(replica_id, &spg.tx)? {
// Shut down, timeline is suspended.
// Is it time to end streaming to this replica?
if spg.timeline.get().check_stop_streaming(replica_id) {
// this expect should never fail because in wal_proposer_recovery mode stop_pos is set
// and this code is not reachable
let pageserver_connstr = pageserver_connstr
.expect("there should be a pageserver connection string");
let tenant_id = spg.ztenantid.unwrap();
let timeline_id = spg.timeline.get().zttid.timeline_id;
let subscription_key =
SubscriptionStateKey::new(tenant_id, timeline_id, pageserver_connstr);
spg.tx
.send(CallmeEvent::Unsubscribe(subscription_key))
.unwrap_or_else(|e| {
error!("failed to send Pause request to callmemaybe thread {}", e);
});
// TODO create proper error type for this
bail!("end streaming to {:?}", spg.appname);
}
@@ -385,7 +394,7 @@ impl ReplicationConn {
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
info!("sent WAL up to {}", start_pos);
// Decide whether to reuse this file. If we don't set wal_file here
// a new file will be opened next time.

View File

@@ -12,14 +12,12 @@ use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;
use tracing::*;
use zenith_metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
use zenith_utils::bin_ser::LeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
use crate::safekeeper::{
AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo,
Storage, SK_FORMAT_VERSION, SK_MAGIC,
@@ -72,25 +70,18 @@ impl ReplicaState {
}
}
/// Shared state associated with database instance
/// Shared state associated with database instance (tenant)
struct SharedState {
/// Safekeeper object
sk: SafeKeeper<FileStorage>,
/// For receiving-sending wal cooperation
/// quorum commit LSN we've notified walsenders about
notified_commit_lsn: Lsn,
// Set stop_lsn to inform WAL senders that it's time to stop sending WAL,
// so that they send all WAL up to stop_lsn and can safely exit streaming connections.
stop_lsn: Option<Lsn>,
/// State of replicas
replicas: Vec<Option<ReplicaState>>,
/// Inactive clusters shouldn't occupy any resources, so timeline is
/// activated whenever there is a compute connection or pageserver is not
/// caught up (it must have the latest WAL for a new compute start) and suspended
/// otherwise.
///
/// TODO: it might be better to remove tli completely from GlobalTimelines
/// when tli is inactive instead of having this flag.
active: bool,
num_computes: u32,
pageserver_connstr: Option<String>,
}
// A named boolean.
@@ -111,125 +102,6 @@ lazy_static! {
}
impl SharedState {
/// Restore SharedState from control file.
/// If create=false and file doesn't exist, bails out.
fn create_restore(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<Self> {
let state = FileStorage::load_control_file_conf(conf, zttid, create)
.context("failed to load from control file")?;
let file_storage = FileStorage::new(zttid, conf);
let flush_lsn = if state.server.wal_seg_size != 0 {
let wal_dir = conf.timeline_dir(zttid);
Lsn(find_end_of_wal(
&wal_dir,
state.server.wal_seg_size as usize,
true,
state.wal_start_lsn,
)?
.0)
} else {
Lsn(0)
};
info!(
"timeline {} created or restored: flush_lsn={}, commit_lsn={}, truncate_lsn={}",
zttid.timeline_id, flush_lsn, state.commit_lsn, state.truncate_lsn,
);
if flush_lsn < state.commit_lsn || flush_lsn < state.truncate_lsn {
warn!("timeline {} potential data loss: flush_lsn by find_end_of_wal is less than either commit_lsn or truncate_lsn from control file", zttid.timeline_id);
}
Ok(Self {
notified_commit_lsn: Lsn(0),
sk: SafeKeeper::new(zttid.timeline_id, flush_lsn, file_storage, state),
replicas: Vec::new(),
active: false,
num_computes: 0,
pageserver_connstr: None,
})
}
/// Activate the timeline: start/change walsender (via callmemaybe).
fn activate(
&mut self,
zttid: &ZTenantTimelineId,
pageserver_connstr: Option<&String>,
callmemaybe_tx: &UnboundedSender<CallmeEvent>,
) -> Result<()> {
if let Some(ref pageserver_connstr) = self.pageserver_connstr {
// unsub old sub. xxx: callmemaybe is going out
let old_subscription_key = SubscriptionStateKey::new(
zttid.tenant_id,
zttid.timeline_id,
pageserver_connstr.to_owned(),
);
callmemaybe_tx
.send(CallmeEvent::Unsubscribe(old_subscription_key))
.unwrap_or_else(|e| {
error!("failed to send Pause request to callmemaybe thread {}", e);
});
}
if let Some(pageserver_connstr) = pageserver_connstr {
let subscription_key = SubscriptionStateKey::new(
zttid.tenant_id,
zttid.timeline_id,
pageserver_connstr.to_owned(),
);
// xx: sending to channel under lock is not very cool, but
// shouldn't be a problem here. If it is, we can grab a counter
// here and later augment channel messages with it.
callmemaybe_tx
.send(CallmeEvent::Subscribe(subscription_key))
.unwrap_or_else(|e| {
error!(
"failed to send Subscribe request to callmemaybe thread {}",
e
);
});
info!(
"timeline {} is subscribed to callmemaybe to {}",
zttid.timeline_id, pageserver_connstr
);
}
self.pageserver_connstr = pageserver_connstr.map(|c| c.to_owned());
self.active = true;
Ok(())
}
/// Deactivate the timeline: stop callmemaybe.
fn deactivate(
&mut self,
zttid: &ZTenantTimelineId,
callmemaybe_tx: &UnboundedSender<CallmeEvent>,
) -> Result<()> {
if self.active {
if let Some(ref pageserver_connstr) = self.pageserver_connstr {
let subscription_key = SubscriptionStateKey::new(
zttid.tenant_id,
zttid.timeline_id,
pageserver_connstr.to_owned(),
);
callmemaybe_tx
.send(CallmeEvent::Unsubscribe(subscription_key))
.unwrap_or_else(|e| {
error!(
"failed to send Unsubscribe request to callmemaybe thread {}",
e
);
});
info!(
"timeline {} is unsubscribed from callmemaybe to {}",
zttid.timeline_id,
self.pageserver_connstr.as_ref().unwrap()
);
}
self.active = false;
}
Ok(())
}
/// Get combined state of all alive replicas
pub fn get_replicas_state(&self) -> ReplicaState {
let mut acc = ReplicaState::new();
@@ -287,6 +159,37 @@ impl SharedState {
self.replicas.push(Some(state));
pos
}
/// Restore SharedState from control file.
/// If create=false and file doesn't exist, bails out.
fn create_restore(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<Self> {
let state = FileStorage::load_control_file_conf(conf, zttid, create)
.context("failed to load from control file")?;
let file_storage = FileStorage::new(zttid, conf);
let flush_lsn = if state.server.wal_seg_size != 0 {
let wal_dir = conf.timeline_dir(zttid);
find_end_of_wal(
&wal_dir,
state.server.wal_seg_size as usize,
true,
state.wal_start_lsn,
)?
.0
} else {
0
};
Ok(Self {
notified_commit_lsn: Lsn(0),
stop_lsn: None,
sk: SafeKeeper::new(Lsn(flush_lsn), file_storage, state),
replicas: Vec::new(),
})
}
}
/// Database instance (tenant)
@@ -306,67 +209,6 @@ impl Timeline {
}
}
/// Register compute connection, starting timeline-related activity if it is
/// not running yet.
/// Can fail only if channel to a static thread got closed, which is not normal at all.
pub fn on_compute_connect(
&self,
pageserver_connstr: Option<&String>,
callmemaybe_tx: &UnboundedSender<CallmeEvent>,
) -> Result<()> {
let mut shared_state = self.mutex.lock().unwrap();
shared_state.num_computes += 1;
// FIXME: currently we always adopt latest pageserver connstr, but we
// should have kind of generations assigned by compute to distinguish
// the latest one or even pass it through consensus to reliably deliver
// to all safekeepers.
shared_state.activate(&self.zttid, pageserver_connstr, callmemaybe_tx)?;
Ok(())
}
/// De-register compute connection, shutting down timeline activity if
/// pageserver doesn't need catchup.
/// Can fail only if channel to a static thread got closed, which is not normal at all.
pub fn on_compute_disconnect(
&self,
callmemaybe_tx: &UnboundedSender<CallmeEvent>,
) -> Result<()> {
let mut shared_state = self.mutex.lock().unwrap();
shared_state.num_computes -= 1;
// If there is no pageserver, can suspend right away; otherwise let
// walsender do that.
if shared_state.num_computes == 0 && shared_state.pageserver_connstr.is_none() {
shared_state.deactivate(&self.zttid, callmemaybe_tx)?;
}
Ok(())
}
/// Deactivate the tenant if there are no computes and the pageserver is caught up,
/// assuming the pageserver status is in replica_id.
/// Returns true if deactivated.
pub fn check_deactivate(
&self,
replica_id: usize,
callmemaybe_tx: &UnboundedSender<CallmeEvent>,
) -> Result<bool> {
let mut shared_state = self.mutex.lock().unwrap();
if !shared_state.active {
// already suspended
return Ok(true);
}
if shared_state.num_computes == 0 {
let replica_state = shared_state.replicas[replica_id].unwrap();
let deactivate = shared_state.notified_commit_lsn == Lsn(0) || // no data at all yet
(replica_state.last_received_lsn != Lsn::MAX && // Lsn::MAX means that we don't know the latest LSN yet.
replica_state.last_received_lsn >= shared_state.sk.commit_lsn);
if deactivate {
shared_state.deactivate(&self.zttid, callmemaybe_tx)?;
return Ok(true);
}
}
Ok(false)
}
/// Timed wait for an LSN to be committed.
///
/// Returns the last committed LSN, which will be at least
@@ -400,6 +242,54 @@ impl Timeline {
}
}
// Notify WAL senders that it's time to stop sending WAL
pub fn stop_streaming(&self) {
let mut shared_state = self.mutex.lock().unwrap();
// Ensure that safekeeper sends WAL up to the last known committed LSN.
// It guarantees that pageserver will receive all the latest data
// before walservice disconnects.
shared_state.stop_lsn = Some(shared_state.notified_commit_lsn);
trace!(
"Stopping WAL senders. stop_lsn: {}",
shared_state.notified_commit_lsn
);
}
// Reset stop_lsn notification,
// so that WAL senders will continue sending WAL
pub fn continue_streaming(&self) {
let mut shared_state = self.mutex.lock().unwrap();
shared_state.stop_lsn = None;
}
// Check if it's time to stop streaming to the given replica.
//
// Do not stop streaming until replica is caught up with the stop_lsn.
// This is not necessary for correctness, just an optimization to
// be able to remove WAL from the safekeeper and decrease the amount of work
// on the next start.
pub fn check_stop_streaming(&self, replica_id: usize) -> bool {
let shared_state = self.mutex.lock().unwrap();
// If stop_lsn is set, it's time to shutdown streaming.
if let Some(stop_lsn_request) = shared_state.stop_lsn {
let replica_state = shared_state.replicas[replica_id].unwrap();
// There is no data to stream, so other clauses don't matter.
if shared_state.notified_commit_lsn == Lsn(0) {
return true;
}
// Lsn::MAX means that we don't know the latest LSN yet.
// That may be a new replica, give it a chance to catch up.
if replica_state.last_received_lsn != Lsn::MAX
// If replica is fully caught up, disconnect it.
&& stop_lsn_request <= replica_state.last_received_lsn
{
return true;
}
}
false
}
/// Pass arrived message to the safekeeper.
pub fn process_msg(
&self,

View File

@@ -1,10 +1,13 @@
[package]
name = "zenith"
version = "0.1.0"
authors = ["Stas Kelvich <stas@zenith.tech>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
clap = "3.0"
clap = "2.33.0"
anyhow = "1.0"
serde_json = "1"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }

View File

@@ -1,5 +1,5 @@
use anyhow::{bail, Context, Result};
use clap::{App, AppSettings, Arg, ArgMatches};
use clap::{App, AppSettings, Arg, ArgMatches, SubCommand};
use control_plane::compute::ComputeControlPlane;
use control_plane::local_env;
use control_plane::local_env::LocalEnv;
@@ -67,47 +67,45 @@ struct BranchTreeEl {
// * Providing CLI api to the pageserver
// * TODO: export/import to/from usual postgres
fn main() -> Result<()> {
#[rustfmt::skip] // rustfmt squashes these into a single line otherwise
let pg_node_arg = Arg::new("node")
let pg_node_arg = Arg::with_name("node")
.index(1)
.help("Node name")
.required(true);
#[rustfmt::skip]
let safekeeper_node_arg = Arg::new("node")
let safekeeper_node_arg = Arg::with_name("node")
.index(1)
.help("Node name")
.required(false);
let timeline_arg = Arg::new("timeline")
let timeline_arg = Arg::with_name("timeline")
.index(2)
.help("Branch name or a point-in time specification")
.required(false);
let tenantid_arg = Arg::new("tenantid")
let tenantid_arg = Arg::with_name("tenantid")
.long("tenantid")
.help("Tenant id. Represented as a hexadecimal string 32 symbols length")
.takes_value(true)
.required(false);
let port_arg = Arg::new("port")
let port_arg = Arg::with_name("port")
.long("port")
.required(false)
.value_name("port");
let stop_mode_arg = Arg::new("stop-mode")
.short('m')
let stop_mode_arg = Arg::with_name("stop-mode")
.short("m")
.takes_value(true)
.possible_values(&["fast", "immediate"])
.help("If 'immediate', don't flush repository data at shutdown")
.required(false)
.value_name("stop-mode");
let pageserver_config_args = Arg::new("pageserver-config-override")
let pageserver_config_args = Arg::with_name("pageserver-config-override")
.long("pageserver-config-override")
.takes_value(true)
.number_of_values(1)
.multiple_occurrences(true)
.multiple(true)
.help("Additional pageserver's configuration options or overrides, refer to pageserver's 'config-override' CLI parameter docs for more")
.required(false);
@@ -115,88 +113,88 @@ fn main() -> Result<()> {
.setting(AppSettings::ArgRequiredElseHelp)
.version(GIT_VERSION)
.subcommand(
App::new("init")
SubCommand::with_name("init")
.about("Initialize a new Zenith repository")
.arg(pageserver_config_args.clone())
.arg(
Arg::new("config")
Arg::with_name("config")
.long("config")
.required(false)
.value_name("config"),
)
)
.subcommand(
App::new("branch")
SubCommand::with_name("branch")
.about("Create a new branch")
.arg(Arg::new("branchname").required(false).index(1))
.arg(Arg::new("start-point").required(false).index(2))
.arg(Arg::with_name("branchname").required(false).index(1))
.arg(Arg::with_name("start-point").required(false).index(2))
.arg(tenantid_arg.clone()),
).subcommand(
App::new("tenant")
SubCommand::with_name("tenant")
.setting(AppSettings::ArgRequiredElseHelp)
.about("Manage tenants")
.subcommand(App::new("list"))
.subcommand(App::new("create").arg(Arg::new("tenantid").required(false).index(1)))
.subcommand(SubCommand::with_name("list"))
.subcommand(SubCommand::with_name("create").arg(Arg::with_name("tenantid").required(false).index(1)))
)
.subcommand(
App::new("pageserver")
SubCommand::with_name("pageserver")
.setting(AppSettings::ArgRequiredElseHelp)
.about("Manage pageserver")
.subcommand(App::new("status"))
.subcommand(App::new("start").about("Start local pageserver").arg(pageserver_config_args.clone()))
.subcommand(App::new("stop").about("Stop local pageserver")
.subcommand(SubCommand::with_name("status"))
.subcommand(SubCommand::with_name("start").about("Start local pageserver").arg(pageserver_config_args.clone()))
.subcommand(SubCommand::with_name("stop").about("Stop local pageserver")
.arg(stop_mode_arg.clone()))
.subcommand(App::new("restart").about("Restart local pageserver").arg(pageserver_config_args.clone()))
.subcommand(SubCommand::with_name("restart").about("Restart local pageserver").arg(pageserver_config_args))
)
.subcommand(
App::new("safekeeper")
SubCommand::with_name("safekeeper")
.setting(AppSettings::ArgRequiredElseHelp)
.about("Manage safekeepers")
.subcommand(App::new("start")
.subcommand(SubCommand::with_name("start")
.about("Start local safekeeper")
.arg(safekeeper_node_arg.clone())
)
.subcommand(App::new("stop")
.subcommand(SubCommand::with_name("stop")
.about("Stop local safekeeper")
.arg(safekeeper_node_arg.clone())
.arg(stop_mode_arg.clone())
)
.subcommand(App::new("restart")
.subcommand(SubCommand::with_name("restart")
.about("Restart local safekeeper")
.arg(safekeeper_node_arg.clone())
.arg(stop_mode_arg.clone())
)
)
.subcommand(
App::new("pg")
SubCommand::with_name("pg")
.setting(AppSettings::ArgRequiredElseHelp)
.about("Manage postgres instances")
.subcommand(App::new("list").arg(tenantid_arg.clone()))
.subcommand(App::new("create")
.subcommand(SubCommand::with_name("list").arg(tenantid_arg.clone()))
.subcommand(SubCommand::with_name("create")
.about("Create a postgres compute node")
.arg(pg_node_arg.clone())
.arg(timeline_arg.clone())
.arg(tenantid_arg.clone())
.arg(port_arg.clone())
.arg(
Arg::new("config-only")
Arg::with_name("config-only")
.help("Don't do basebackup, create compute node with only config files")
.long("config-only")
.required(false)
))
.subcommand(App::new("start")
.subcommand(SubCommand::with_name("start")
.about("Start a postgres compute node.\n This command actually creates new node from scratch, but preserves existing config files")
.arg(pg_node_arg.clone())
.arg(timeline_arg.clone())
.arg(tenantid_arg.clone())
.arg(port_arg.clone()))
.subcommand(
App::new("stop")
SubCommand::with_name("stop")
.arg(pg_node_arg.clone())
.arg(timeline_arg.clone())
.arg(tenantid_arg.clone())
.arg(
Arg::new("destroy")
Arg::with_name("destroy")
.help("Also delete data directory (now optional, should be default in future)")
.long("destroy")
.required(false)
@@ -205,21 +203,18 @@ fn main() -> Result<()> {
)
.subcommand(
App::new("start")
SubCommand::with_name("start")
.about("Start page server and safekeepers")
.arg(pageserver_config_args)
)
.subcommand(
App::new("stop")
SubCommand::with_name("stop")
.about("Stop page server and safekeepers")
.arg(stop_mode_arg.clone())
)
.get_matches();
let (sub_name, sub_args) = match matches.subcommand() {
Some(subcommand_data) => subcommand_data,
None => bail!("no subcommand provided"),
};
let (sub_name, sub_args) = matches.subcommand();
let sub_args = sub_args.expect("no subcommand");
// Check for 'zenith init' command first.
let subcmd_result = if sub_name == "init" {
@@ -425,7 +420,7 @@ fn handle_init(init_match: &ArgMatches) -> Result<()> {
Ok(())
}
fn pageserver_config_overrides(init_match: &ArgMatches) -> Vec<&str> {
fn pageserver_config_overrides<'a>(init_match: &'a ArgMatches) -> Vec<&'a str> {
init_match
.values_of("pageserver-config-override")
.into_iter()
@@ -436,12 +431,12 @@ fn pageserver_config_overrides(init_match: &ArgMatches) -> Vec<&str> {
fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let pageserver = PageServerNode::from_env(env);
match tenant_match.subcommand() {
Some(("list", _)) => {
("list", Some(_)) => {
for t in pageserver.tenant_list()? {
println!("{} {}", t.id, t.state);
}
}
Some(("create", create_match)) => {
("create", Some(create_match)) => {
let tenantid = match create_match.value_of("tenantid") {
Some(tenantid) => ZTenantId::from_str(tenantid)?,
None => ZTenantId::generate(),
@@ -450,8 +445,10 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
pageserver.tenant_create(tenantid)?;
println!("tenant successfully created on the pageserver");
}
Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
None => bail!("no tenant subcommand provided"),
(sub_name, _) => {
bail!("Unexpected tenant subcommand '{}'", sub_name)
}
}
Ok(())
}
@@ -480,10 +477,8 @@ fn handle_branch(branch_match: &ArgMatches, env: &local_env::LocalEnv) -> Result
}
fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let (sub_name, sub_args) = match pg_match.subcommand() {
Some(pg_subcommand_data) => pg_subcommand_data,
None => bail!("no pg subcommand provided"),
};
let (sub_name, sub_args) = pg_match.subcommand();
let sub_args = sub_args.expect("no pg subcommand");
let mut cplane = ComputeControlPlane::load(env.clone())?;
@@ -594,14 +589,14 @@ fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul
let pageserver = PageServerNode::from_env(env);
match sub_match.subcommand() {
Some(("start", start_match)) => {
("start", Some(start_match)) => {
if let Err(e) = pageserver.start(&pageserver_config_overrides(start_match)) {
eprintln!("pageserver start failed: {}", e);
exit(1);
}
}
Some(("stop", stop_match)) => {
("stop", Some(stop_match)) => {
let immediate = stop_match.value_of("stop-mode") == Some("immediate");
if let Err(e) = pageserver.stop(immediate) {
@@ -610,7 +605,7 @@ fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul
}
}
Some(("restart", restart_match)) => {
("restart", Some(restart_match)) => {
//TODO what shutdown strategy should we use here?
if let Err(e) = pageserver.stop(false) {
eprintln!("pageserver stop failed: {}", e);
@@ -622,8 +617,8 @@ fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul
exit(1);
}
}
Some((sub_name, _)) => bail!("Unexpected pageserver subcommand '{}'", sub_name),
None => bail!("no pageserver subcommand provided"),
(sub_name, _) => bail!("Unexpected pageserver subcommand '{}'", sub_name),
}
Ok(())
}
@@ -637,10 +632,8 @@ fn get_safekeeper(env: &local_env::LocalEnv, name: &str) -> Result<SafekeeperNod
}
fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let (sub_name, sub_args) = match sub_match.subcommand() {
Some(safekeeper_command_data) => safekeeper_command_data,
None => bail!("no safekeeper subcommand provided"),
};
let (sub_name, sub_args) = sub_match.subcommand();
let sub_args = sub_args.expect("no safekeeper subcommand");
// All the commands take an optional safekeeper name argument
let node_name = sub_args.value_of("node").unwrap_or(DEFAULT_SAFEKEEPER_NAME);

View File

@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
prometheus = {version = "0.13", default_features=false} # removes protobuf dependency
prometheus = {version = "0.12", default_features=false} # removes protobuf dependency
libc = "0.2"
lazy_static = "1.4"
once_cell = "1.8.0"

View File

@@ -1,6 +1,7 @@
[package]
name = "zenith_utils"
version = "0.1.0"
authors = ["Eric Seppanen <eric@zenith.tech>"]
edition = "2021"
[dependencies]
@@ -12,11 +13,11 @@ lazy_static = "1.4.0"
pin-project-lite = "0.2.7"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
routerify = "3"
routerify = "2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
thiserror = "1.0"
tokio = { version = "1.11", features = ["macros"]}
tokio = "1.11"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
nix = "0.23.0"

View File

@@ -85,6 +85,7 @@ pub fn check_permission(claims: &Claims, tenantid: Option<ZTenantId>) -> Result<
}
}
#[derive(Debug)]
pub struct JwtAuth {
decoding_key: DecodingKey<'static>,
validation: Validation,
@@ -112,14 +113,6 @@ impl JwtAuth {
}
}
impl std::fmt::Debug for JwtAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JwtAuth")
.field("validation", &self.validation)
.finish()
}
}
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;

View File

@@ -2,7 +2,3 @@ pub mod endpoint;
pub mod error;
pub mod json;
pub mod request;
/// Current fast way to apply simple http routing in various Zenith binaries.
/// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed.
pub use routerify::{ext::RequestExt, RouterBuilder};
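For context, this re-export lets binaries wire up HTTP routes without a direct routerify dependency. A minimal sketch of the usage it enables (hypothetical endpoint, not one of the repository's routes; assumes hyper 0.14):

    use hyper::{Body, Request, Response};
    use routerify::{ext::RequestExt, Router};

    async fn get_timeline(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
        // RequestExt provides param() for path captures such as :timeline_id.
        let id = req.param("timeline_id").cloned().unwrap_or_default();
        Ok(Response::new(Body::from(format!("timeline {}", id))))
    }

    fn main() {
        // Router::builder() returns a RouterBuilder; the re-export exposes
        // that type so callers can pass builders around by name.
        let _router: Router<Body, hyper::Error> = Router::builder()
            .get("/v1/timeline/:timeline_id", get_timeline)
            .build()
            .unwrap();
    }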

View File

@@ -54,9 +54,6 @@ pub mod nonblock;
// Default signal handling
pub mod signals;
// Postgres checksum calculation
pub mod pg_checksum_page;
// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
//
// we have several cases:

View File

@@ -1,70 +0,0 @@
///
/// Port of Postgres pg_checksum_page
///
const BLCKSZ: usize = 8192;
const N_SUMS: usize = 32;
/* prime multiplier of FNV-1a hash */
const FNV_PRIME: u32 = 16777619;
/*
* Base offsets to initialize each of the parallel FNV hashes into a
* different initial state.
*/
const CHECKSUM_BASE_OFFSETS: [u32; N_SUMS] = [
0x5B1F36E9, 0xB8525960, 0x02AB50AA, 0x1DE66D2A, 0x79FF467A, 0x9BB9F8A3, 0x217E7CD2, 0x83E13D2C,
0xF8D4474F, 0xE39EB970, 0x42C6AE16, 0x993216FA, 0x7B093B5D, 0x98DAFF3C, 0xF718902A, 0x0B1C9CDB,
0xE58F764B, 0x187636BC, 0x5D7B3BB1, 0xE73DE7DE, 0x92BEC979, 0xCCA6C0B2, 0x304A0979, 0x85AA43D4,
0x783125BB, 0x6CA8EAA2, 0xE407EAC6, 0x4B5CFC3E, 0x9FBF8C76, 0x15CA20BE, 0xF2CA9FD3, 0x959BD756,
];
/*
* Calculate one round of the checksum.
*/
fn checksum_comp(checksum: u32, value: u32) -> u32 {
let tmp = checksum ^ value;
tmp.wrapping_mul(FNV_PRIME) ^ (tmp >> 17)
}
/*
* Compute the checksum for a Postgres page.
*
* The page must be adequately aligned (at least on a 4-byte boundary).
* Beware also that the checksum field of the page is transiently zeroed.
*
* The checksum includes the block number (to detect the case where a page is
* somehow moved to a different location), the page header (excluding the
* checksum itself), and the page data.
*/
pub fn pg_checksum_page(data: &[u8], blkno: u32) -> u16 {
let page = unsafe { std::mem::transmute::<&[u8], &[u32]>(data) };
let mut checksum: u32 = 0;
let mut sums = CHECKSUM_BASE_OFFSETS;
/* main checksum calculation */
for i in 0..(BLCKSZ / (4 * N_SUMS)) {
for j in 0..N_SUMS {
sums[j] = checksum_comp(sums[j], page[i * N_SUMS + j]);
}
}
/* finally add in two rounds of zeroes for additional mixing */
for _ in 0..2 {
for j in 0..N_SUMS {
sums[j] = checksum_comp(sums[j], 0);
}
}
/* xor fold partial checksums together */
for sum in sums {
checksum ^= sum;
}
/* Mix in the block number to detect transposed pages */
checksum ^= blkno;
/*
* Reduce to a uint16 (to fit in the pd_checksum field) with an offset of
* one. That avoids checksums of zero, which seems like a good idea.
*/
((checksum % 65535) + 1) as u16
}
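A short driver for the deleted helper, just to make the contract concrete (illustrative, not from the repo; the page is backed by u32s to satisfy the 4-byte alignment the function assumes):

    fn main() {
        let words = [0u32; 2048]; // 8192 bytes: an all-zeroes page
        let page =
            unsafe { std::slice::from_raw_parts(words.as_ptr() as *const u8, 8192) };
        let sum = pg_checksum_page(page, 0);
        // The final "% 65535 + 1" guarantees a non-zero checksum even for zero pages.
        assert_ne!(sum, 0);
        println!("checksum: {:#06x}", sum);
    }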
}

View File

@@ -13,7 +13,7 @@ use std::io::{self, Cursor};
use std::str;
use std::time::{Duration, SystemTime};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
use tracing::info;
pub type Oid = u32;
pub type SystemId = u64;
@@ -57,6 +57,16 @@ pub struct CancelKeyData {
pub cancel_key: i32,
}
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
CancelKeyData {
backend_pid: rng.gen(),
cancel_key: rng.gen(),
}
}
}
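Implementing Distribution<CancelKeyData> for Standard is what allows a plain rng.gen() to produce a random cancellation key. A one-line sketch of the call site (illustrative):

    use rand::Rng;

    fn random_cancel_key() -> CancelKeyData {
        // gen() draws from the Standard distribution, so the impl above applies.
        rand::thread_rng().gen()
    }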
#[derive(Debug)]
pub struct FeQueryMessage {
pub body: Bytes,
@@ -956,7 +966,7 @@ impl ZenithFeedback {
}
_ => {
let len = buf.get_i32();
warn!(
info!(
"ZenithFeedback parse. unknown key {} of len {}. Skip it.",
key, len
);
@@ -964,7 +974,7 @@ impl ZenithFeedback {
}
}
}
trace!("ZenithFeedback parsed is {:?}", zf);
info!("ZenithFeedback parsed is {:?}", zf);
zf
}
}

View File

@@ -196,7 +196,7 @@ pub mod opt_display_serde {
}
// A pair uniquely identifying Zenith instance.
#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ZTenantTimelineId {
pub tenant_id: ZTenantId,
pub timeline_id: ZTimelineId,