diff --git a/.circleci/ansible/deploy.yaml b/.circleci/ansible/deploy.yaml index 2112102aa7..508843812a 100644 --- a/.circleci/ansible/deploy.yaml +++ b/.circleci/ansible/deploy.yaml @@ -63,21 +63,18 @@ tags: - pageserver - # It seems that currently S3 integration does not play well - # even with fresh pageserver without a burden of old data. - # TODO: turn this back on once the issue is solved. - # - name: update remote storage (s3) config - # lineinfile: - # path: /storage/pageserver/data/pageserver.toml - # line: "{{ item }}" - # loop: - # - "[remote_storage]" - # - "bucket_name = '{{ bucket_name }}'" - # - "bucket_region = '{{ bucket_region }}'" - # - "prefix_in_bucket = '{{ inventory_hostname }}'" - # become: true - # tags: - # - pageserver + - name: update remote storage (s3) config + lineinfile: + path: /storage/pageserver/data/pageserver.toml + line: "{{ item }}" + loop: + - "[remote_storage]" + - "bucket_name = '{{ bucket_name }}'" + - "bucket_region = '{{ bucket_region }}'" + - "prefix_in_bucket = '{{ inventory_hostname }}'" + become: true + tags: + - pageserver - name: upload systemd service definition ansible.builtin.template: diff --git a/.circleci/ansible/staging.hosts b/.circleci/ansible/staging.hosts index f6b7bf009f..69f058c2b9 100644 --- a/.circleci/ansible/staging.hosts +++ b/.circleci/ansible/staging.hosts @@ -5,7 +5,6 @@ zenith-us-stage-ps-2 console_region_id=27 [safekeepers] zenith-us-stage-sk-1 console_region_id=27 zenith-us-stage-sk-2 console_region_id=27 -zenith-us-stage-sk-3 console_region_id=27 zenith-us-stage-sk-4 console_region_id=27 [storage:children] diff --git a/.circleci/config.yml b/.circleci/config.yml index e96964558b..f05e64072a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -405,7 +405,7 @@ jobs: - run: name: Build coverage report command: | - COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1 + COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1 scripts/coverage \ --dir=/tmp/zenith/coverage report \ @@ -416,8 +416,8 @@ jobs: name: Upload coverage report command: | LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME - REPORT_URL=https://zenithdb.github.io/zenith-coverage-data/$CIRCLE_SHA1 - COMMIT_URL=https://github.com/zenithdb/zenith/commit/$CIRCLE_SHA1 + REPORT_URL=https://neondatabase.github.io/zenith-coverage-data/$CIRCLE_SHA1 + COMMIT_URL=https://github.com/neondatabase/neon/commit/$CIRCLE_SHA1 scripts/git-upload \ --repo=https://$VIP_VAP_ACCESS_TOKEN@github.com/zenithdb/zenith-coverage-data.git \ @@ -593,7 +593,7 @@ jobs: name: Setup helm v3 command: | curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash - helm repo add zenithdb https://zenithdb.github.io/helm-charts + helm repo add zenithdb https://neondatabase.github.io/helm-charts - run: name: Re-deploy proxy command: | @@ -643,7 +643,7 @@ jobs: name: Setup helm v3 command: | curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash - helm repo add zenithdb https://zenithdb.github.io/helm-charts + helm repo add zenithdb https://neondatabase.github.io/helm-charts - run: name: Re-deploy proxy command: | @@ -672,7 +672,7 @@ jobs: --data \ "{ \"state\": \"pending\", - \"context\": \"zenith-remote-ci\", + \"context\": \"neon-cloud-e2e\", \"description\": \"[$REMOTE_REPO] Remote CI job is about to start\" }" - run: @@ -688,7 +688,7 @@ jobs: "{ \"ref\": \"main\", \"inputs\": { - \"ci_job_name\": \"zenith-remote-ci\", + \"ci_job_name\": \"neon-cloud-e2e\", \"commit_hash\": \"$CIRCLE_SHA1\", 
\"remote_repo\": \"$LOCAL_REPO\" } @@ -828,11 +828,11 @@ workflows: - remote-ci-trigger: # Context passes credentials for gh api context: CI_ACCESS_TOKEN - remote_repo: "zenithdb/console" + remote_repo: "neondatabase/cloud" requires: # XXX: Successful build doesn't mean everything is OK, but # the job to be triggered takes so much time to complete (~22 min) # that it's better not to wait for the commented-out steps - - build-zenith-debug + - build-zenith-release # - pg_regress-tests-release # - other-tests-release diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index 36df35297d..72041c9d02 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -26,7 +26,7 @@ jobs: runs-on: [self-hosted, zenith-benchmarker] env: - PG_BIN: "/usr/pgsql-13/bin" + POSTGRES_DISTRIB_DIR: "/usr/pgsql-13" steps: - name: Checkout zenith repo @@ -51,7 +51,7 @@ jobs: echo Poetry poetry --version echo Pgbench - $PG_BIN/pgbench --version + $POSTGRES_DISTRIB_DIR/bin/pgbench --version # FIXME cluster setup is skipped due to various changes in console API # for now pre created cluster is used. When API gain some stability @@ -66,7 +66,7 @@ jobs: echo "Starting cluster" # wake up the cluster - $PG_BIN/psql $BENCHMARK_CONNSTR -c "SELECT 1" + $POSTGRES_DISTRIB_DIR/bin/psql $BENCHMARK_CONNSTR -c "SELECT 1" - name: Run benchmark # pgbench is installed system wide from official repo @@ -83,8 +83,11 @@ jobs: # sudo yum install postgresql13-contrib # actual binaries are located in /usr/pgsql-13/bin/ env: - TEST_PG_BENCH_TRANSACTIONS_MATRIX: "5000,10000,20000" - TEST_PG_BENCH_SCALES_MATRIX: "10,15" + # The pgbench test runs two tests of given duration against each scale. + # So the total runtime with these parameters is 2 * 2 * 300 = 1200, or 20 minutes. + # Plus time needed to initialize the test databases. 
+ TEST_PG_BENCH_DURATIONS_MATRIX: "300" + TEST_PG_BENCH_SCALES_MATRIX: "10,100" PLATFORM: "zenith-staging" BENCHMARK_CONNSTR: "${{ secrets.BENCHMARK_STAGING_CONNSTR }}" REMOTE_ENV: "1" # indicate to test harness that we do not have zenith binaries locally diff --git a/Cargo.lock b/Cargo.lock index 05fdeff6ff..6a8528618a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,6 +361,7 @@ dependencies = [ "serde_json", "tar", "tokio", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "workspace_hack", ] @@ -1571,7 +1572,6 @@ dependencies = [ "tokio-util 0.7.0", "toml_edit", "tracing", - "tracing-futures", "url", "workspace_hack", "zenith_metrics", @@ -1952,12 +1952,15 @@ name = "proxy" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "base64 0.13.0", "bytes", "clap 3.1.8", "fail", "futures", "hashbrown", "hex", + "hmac 0.10.1", "hyper", "lazy_static", "md5", @@ -1966,10 +1969,13 @@ dependencies = [ "rand", "rcgen", "reqwest", + "routerify 2.2.0", + "rstest", "rustls 0.19.1", "scopeguard", "serde", "serde_json", + "sha2", "socket2", "thiserror", "tokio", @@ -2151,7 +2157,6 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls 0.23.2", - "tokio-util 0.6.9", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -2175,6 +2180,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "routerify" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6bb49594c791cadb5ccfa5f36d41b498d40482595c199d10cd318800280bd9" +dependencies = [ + "http", + "hyper", + "lazy_static", + "percent-encoding", + "regex", +] + [[package]] name = "routerify" version = "3.0.0" @@ -2188,6 +2206,19 @@ dependencies = [ "regex", ] +[[package]] +name = "rstest" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + [[package]] name = "rusoto_core" version = "0.47.0" @@ -3403,19 +3434,20 @@ dependencies = [ "anyhow", "bytes", "cc", + "chrono", "clap 2.34.0", "either", "hashbrown", + "indexmap", "libc", "log", "memchr", "num-integer", "num-traits", - "proc-macro2", - "quote", + "prost", + "rand", "regex", "regex-syntax", - "reqwest", "scopeguard", "serde", "syn", @@ -3495,7 +3527,7 @@ dependencies = [ "postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "rand", - "routerify", + "routerify 3.0.0", "rustls 0.19.1", "rustls-split", "serde", diff --git a/README.md b/README.md index c8acf526b9..f99785e683 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,22 @@ -# Zenith +# Neon -Zenith is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes. +Neon is a serverless open source alternative to AWS Aurora Postgres. It separates storage and compute and substitutes PostgreSQL storage layer by redistributing data across a cluster of nodes. + +The project used to be called "Zenith". Many of the commands and code comments +still refer to "zenith", but we are in the process of renaming things. ## Architecture overview -A Zenith installation consists of compute nodes and Zenith storage engine. 
+A Neon installation consists of compute nodes and Neon storage engine. -Compute nodes are stateless PostgreSQL nodes, backed by Zenith storage engine. +Compute nodes are stateless PostgreSQL nodes, backed by Neon storage engine. -Zenith storage engine consists of two major components: +Neon storage engine consists of two major components: - Pageserver. Scalable storage backend for compute nodes. - WAL service. The service that receives WAL from compute node and ensures that it is stored durably. Pageserver consists of: -- Repository - Zenith storage implementation. +- Repository - Neon storage implementation. - WAL receiver - service that receives WAL from WAL service and stores it in the repository. - Page service - service that communicates with compute nodes and responds with pages from the repository. - WAL redo - service that builds pages from base images and WAL records on Page service request. @@ -35,10 +38,10 @@ To run the `psql` client, install the `postgresql-client` package or modify `PAT To run the integration tests or Python scripts (not required to use the code), install Python (3.7 or higher), and install python3 packages using `./scripts/pysync` (requires poetry) in the project directory. -2. Build zenith and patched postgres +2. Build neon and patched postgres ```sh -git clone --recursive https://github.com/zenithdb/zenith.git -cd zenith +git clone --recursive https://github.com/neondatabase/neon.git +cd neon make -j5 ``` @@ -126,7 +129,7 @@ INSERT 0 1 ## Running tests ```sh -git clone --recursive https://github.com/zenithdb/zenith.git +git clone --recursive https://github.com/neondatabase/neon.git make # builds also postgres and installs it to ./tmp_install ./scripts/pytest ``` @@ -141,14 +144,14 @@ To view your `rustdoc` documentation in a browser, try running `cargo doc --no-d ### Postgres-specific terms -Due to Zenith's very close relation with PostgreSQL internals, there are numerous specific terms used. +Due to Neon's very close relation with PostgreSQL internals, there are numerous specific terms used. Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, while MiB would be technically more correct, it's inconsistent with what PostgreSQL code and its documentation use. 
To get more familiar with this aspect, refer to: -- [Zenith glossary](/docs/glossary.md) +- [Neon glossary](/docs/glossary.md) - [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html) -- Other PostgreSQL documentation and sources (Zenith fork sources can be found [here](https://github.com/zenithdb/postgres)) +- Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres)) ## Join the development diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 56047093f1..fc52ce4e83 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -17,4 +17,5 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1" tar = "0.4" tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] } +tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } workspace_hack = { version = "0.1", path = "../workspace_hack" } diff --git a/compute_tools/src/bin/zenith_ctl.rs b/compute_tools/src/bin/zenith_ctl.rs index 7eba4e6011..da238320a7 100644 --- a/compute_tools/src/bin/zenith_ctl.rs +++ b/compute_tools/src/bin/zenith_ctl.rs @@ -38,6 +38,7 @@ use clap::Arg; use log::info; use postgres::{Client, NoTls}; +use compute_tools::checker::create_writablity_check_data; use compute_tools::config; use compute_tools::http_api::launch_http_server; use compute_tools::logger::*; @@ -128,6 +129,7 @@ fn run_compute(state: &Arc>) -> Result { handle_roles(&read_state.spec, &mut client)?; handle_databases(&read_state.spec, &mut client)?; + create_writablity_check_data(&mut client)?; // 'Close' connection drop(client); diff --git a/compute_tools/src/checker.rs b/compute_tools/src/checker.rs new file mode 100644 index 0000000000..63da6ea23e --- /dev/null +++ b/compute_tools/src/checker.rs @@ -0,0 +1,46 @@ +use std::sync::{Arc, RwLock}; + +use anyhow::{anyhow, Result}; +use log::error; +use postgres::Client; +use tokio_postgres::NoTls; + +use crate::zenith::ComputeState; + +pub fn create_writablity_check_data(client: &mut Client) -> Result<()> { + let query = " + CREATE TABLE IF NOT EXISTS health_check ( + id serial primary key, + updated_at timestamptz default now() + ); + INSERT INTO health_check VALUES (1, now()) + ON CONFLICT (id) DO UPDATE + SET updated_at = now();"; + let result = client.simple_query(query)?; + if result.len() < 2 { + return Err(anyhow::format_err!("executed {} queries", result.len())); + } + Ok(()) +} + +pub async fn check_writability(state: &Arc>) -> Result<()> { + let connstr = state.read().unwrap().connstr.clone(); + let (client, connection) = tokio_postgres::connect(&connstr, NoTls).await?; + if client.is_closed() { + return Err(anyhow!("connection to postgres closed")); + } + tokio::spawn(async move { + if let Err(e) = connection.await { + error!("connection error: {}", e); + } + }); + + let result = client + .simple_query("UPDATE health_check SET updated_at = now() WHERE id = 1;") + .await?; + + if result.len() != 1 { + return Err(anyhow!("statement can't be executed")); + } + Ok(()) +} diff --git a/compute_tools/src/http_api.rs b/compute_tools/src/http_api.rs index 02fab08a6e..7e1a876044 100644 --- a/compute_tools/src/http_api.rs +++ b/compute_tools/src/http_api.rs @@ -11,7 +11,7 @@ use log::{error, info}; use crate::zenith::*; // Service function to handle all available routes. 
-fn routes(req: Request, state: Arc>) -> Response { +async fn routes(req: Request, state: Arc>) -> Response { match (req.method(), req.uri().path()) { // Timestamp of the last Postgres activity in the plain text. (&Method::GET, "/last_activity") => { @@ -29,6 +29,15 @@ fn routes(req: Request, state: Arc>) -> Response { + info!("serving /check_writability GET request"); + let res = crate::checker::check_writability(&state).await; + match res { + Ok(_) => Response::new(Body::from("true")), + Err(e) => Response::new(Body::from(e.to_string())), + } + } + // Return the `404 Not Found` for any other routes. _ => { let mut not_found = Response::new(Body::from("404 Not Found")); @@ -48,7 +57,7 @@ async fn serve(state: Arc>) { async move { Ok::<_, Infallible>(service_fn(move |req: Request| { let state = state.clone(); - async move { Ok::<_, Infallible>(routes(req, state)) } + async move { Ok::<_, Infallible>(routes(req, state).await) } })) } }); diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 592011d95e..ffb9700a49 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -2,6 +2,7 @@ //! Various tools and helpers to handle cluster / compute node (Postgres) //! configuration. //! +pub mod checker; pub mod config; pub mod http_api; #[macro_use] diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index de7ba30320..57c51cf892 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -37,7 +37,6 @@ toml_edit = { version = "0.13", features = ["easy"] } scopeguard = "1.1.0" const_format = "0.2.21" tracing = "0.1.27" -tracing-futures = "0.2" signal-hook = "0.3.10" url = "2" nix = "0.23" diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index dbd99bdd92..861ecdb805 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -36,8 +36,8 @@ pub mod defaults { // Target file size, when creating image and delta layers. // This parameter determines L1 layer file size. pub const DEFAULT_COMPACTION_TARGET_SIZE: u64 = 128 * 1024 * 1024; - pub const DEFAULT_COMPACTION_PERIOD: &str = "1 s"; + pub const DEFAULT_COMPACTION_THRESHOLD: usize = 10; pub const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024; pub const DEFAULT_GC_PERIOD: &str = "100 s"; @@ -65,6 +65,7 @@ pub mod defaults { #checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes #compaction_target_size = {DEFAULT_COMPACTION_TARGET_SIZE} # in bytes #compaction_period = '{DEFAULT_COMPACTION_PERIOD}' +#compaction_threshold = '{DEFAULT_COMPACTION_THRESHOLD}' #gc_period = '{DEFAULT_GC_PERIOD}' #gc_horizon = {DEFAULT_GC_HORIZON} @@ -107,6 +108,9 @@ pub struct PageServerConf { // How often to check if there's compaction work to be done. pub compaction_period: Duration, + // Level0 delta layer threshold for compaction. 
+ pub compaction_threshold: usize, + pub gc_horizon: u64, pub gc_period: Duration, @@ -164,6 +168,7 @@ struct PageServerConfigBuilder { compaction_target_size: BuilderValue, compaction_period: BuilderValue, + compaction_threshold: BuilderValue, gc_horizon: BuilderValue, gc_period: BuilderValue, @@ -202,6 +207,7 @@ impl Default for PageServerConfigBuilder { compaction_target_size: Set(DEFAULT_COMPACTION_TARGET_SIZE), compaction_period: Set(humantime::parse_duration(DEFAULT_COMPACTION_PERIOD) .expect("cannot parse default compaction period")), + compaction_threshold: Set(DEFAULT_COMPACTION_THRESHOLD), gc_horizon: Set(DEFAULT_GC_HORIZON), gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD) .expect("cannot parse default gc period")), @@ -246,6 +252,10 @@ impl PageServerConfigBuilder { self.compaction_period = BuilderValue::Set(compaction_period) } + pub fn compaction_threshold(&mut self, compaction_threshold: usize) { + self.compaction_threshold = BuilderValue::Set(compaction_threshold) + } + pub fn gc_horizon(&mut self, gc_horizon: u64) { self.gc_horizon = BuilderValue::Set(gc_horizon) } @@ -322,6 +332,9 @@ impl PageServerConfigBuilder { compaction_period: self .compaction_period .ok_or(anyhow::anyhow!("missing compaction_period"))?, + compaction_threshold: self + .compaction_threshold + .ok_or(anyhow::anyhow!("missing compaction_threshold"))?, gc_horizon: self .gc_horizon .ok_or(anyhow::anyhow!("missing gc_horizon"))?, @@ -465,6 +478,9 @@ impl PageServerConf { builder.compaction_target_size(parse_toml_u64(key, item)?) } "compaction_period" => builder.compaction_period(parse_toml_duration(key, item)?), + "compaction_threshold" => { + builder.compaction_threshold(parse_toml_u64(key, item)? as usize) + } "gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?), "gc_period" => builder.gc_period(parse_toml_duration(key, item)?), "wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?), @@ -603,6 +619,7 @@ impl PageServerConf { checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE, compaction_target_size: 4 * 1024 * 1024, compaction_period: Duration::from_secs(10), + compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD, gc_horizon: defaults::DEFAULT_GC_HORIZON, gc_period: Duration::from_secs(10), wait_lsn_timeout: Duration::from_secs(60), @@ -676,6 +693,7 @@ checkpoint_distance = 111 # in bytes compaction_target_size = 111 # in bytes compaction_period = '111 s' +compaction_threshold = 2 gc_period = '222 s' gc_horizon = 222 @@ -714,6 +732,7 @@ id = 10 checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE, compaction_target_size: defaults::DEFAULT_COMPACTION_TARGET_SIZE, compaction_period: humantime::parse_duration(defaults::DEFAULT_COMPACTION_PERIOD)?, + compaction_threshold: defaults::DEFAULT_COMPACTION_THRESHOLD, gc_horizon: defaults::DEFAULT_GC_HORIZON, gc_period: humantime::parse_duration(defaults::DEFAULT_GC_PERIOD)?, wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?, @@ -760,6 +779,7 @@ id = 10 checkpoint_distance: 111, compaction_target_size: 111, compaction_period: Duration::from_secs(111), + compaction_threshold: 2, gc_horizon: 222, gc_period: Duration::from_secs(222), wait_lsn_timeout: Duration::from_secs(111), diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index 5e93e3389b..36b081e400 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -49,7 +49,8 @@ use crate::CheckpointConfig; use crate::{ZTenantId, ZTimelineId}; use 
zenith_metrics::{ - register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec, IntGauge, IntGaugeVec, + register_histogram_vec, register_int_counter, register_int_gauge_vec, Histogram, HistogramVec, + IntCounter, IntGauge, IntGaugeVec, }; use zenith_utils::crashsafe_dir; use zenith_utils::lsn::{AtomicLsn, Lsn, RecordLsn}; @@ -109,6 +110,21 @@ lazy_static! { .expect("failed to define a metric"); } +// Metrics for cloud upload. These metrics reflect data uploaded to cloud storage, +// or in testing they estimate how much we would upload if we did. +lazy_static! { + static ref NUM_PERSISTENT_FILES_CREATED: IntCounter = register_int_counter!( + "pageserver_num_persistent_files_created", + "Number of files created that are meant to be uploaded to cloud storage", + ) + .expect("failed to define a metric"); + static ref PERSISTENT_BYTES_WRITTEN: IntCounter = register_int_counter!( + "pageserver_persistent_bytes_written", + "Total bytes written that are meant to be uploaded to cloud storage", + ) + .expect("failed to define a metric"); +} + /// Parts of the `.zenith/tenants//timelines/` directory prefix. pub const TIMELINES_SEGMENT_NAME: &str = "timelines"; @@ -193,7 +209,7 @@ impl Repository for LayeredRepository { Arc::clone(&self.walredo_mgr), self.upload_layers, ); - timeline.layers.lock().unwrap().next_open_layer_at = Some(initdb_lsn); + timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn); let timeline = Arc::new(timeline); let r = timelines.insert( @@ -725,7 +741,7 @@ pub struct LayeredTimeline { tenantid: ZTenantId, timelineid: ZTimelineId, - layers: Mutex, + layers: RwLock, last_freeze_at: AtomicLsn, @@ -997,7 +1013,7 @@ impl LayeredTimeline { conf, timelineid, tenantid, - layers: Mutex::new(LayerMap::default()), + layers: RwLock::new(LayerMap::default()), walredo_mgr, @@ -1040,7 +1056,7 @@ impl LayeredTimeline { /// Returns all timeline-related files that were found and loaded. /// fn load_layer_map(&self, disk_consistent_lsn: Lsn) -> anyhow::Result<()> { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); let mut num_layers = 0; // Scan timeline directory and create ImageFileName and DeltaFilename @@ -1194,7 +1210,7 @@ impl LayeredTimeline { continue; } - let layers = timeline.layers.lock().unwrap(); + let layers = timeline.layers.read().unwrap(); // Check the open and frozen in-memory layers first if let Some(open_layer) = &layers.open_layer { @@ -1276,7 +1292,7 @@ impl LayeredTimeline { /// Get a handle to the latest layer for appending. /// fn get_layer_for_write(&self, lsn: Lsn) -> anyhow::Result> { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); ensure!(lsn.is_aligned()); @@ -1347,7 +1363,7 @@ impl LayeredTimeline { } else { Some(self.write_lock.lock().unwrap()) }; - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); if let Some(open_layer) = &layers.open_layer { let open_layer_rc = Arc::clone(open_layer); // Does this layer need freezing? 
@@ -1412,7 +1428,7 @@ impl LayeredTimeline { let timer = self.flush_time_histo.start_timer(); loop { - let layers = self.layers.lock().unwrap(); + let layers = self.layers.read().unwrap(); if let Some(frozen_layer) = layers.frozen_layers.front() { let frozen_layer = Arc::clone(frozen_layer); drop(layers); // to allow concurrent reads and writes @@ -1456,7 +1472,7 @@ impl LayeredTimeline { // Finally, replace the frozen in-memory layer with the new on-disk layers { - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); let l = layers.frozen_layers.pop_front(); // Only one thread may call this function at a time (for this @@ -1524,6 +1540,10 @@ impl LayeredTimeline { &metadata, false, )?; + + NUM_PERSISTENT_FILES_CREATED.inc_by(1); + PERSISTENT_BYTES_WRITTEN.inc_by(new_delta_path.metadata()?.len()); + if self.upload_layers.load(atomic::Ordering::Relaxed) { schedule_timeline_checkpoint_upload( self.tenantid, @@ -1612,7 +1632,7 @@ impl LayeredTimeline { lsn: Lsn, threshold: usize, ) -> Result { - let layers = self.layers.lock().unwrap(); + let layers = self.layers.read().unwrap(); for part_range in &partition.ranges { let image_coverage = layers.image_coverage(part_range, lsn)?; @@ -1670,7 +1690,7 @@ impl LayeredTimeline { // FIXME: Do we need to do something to upload it to remote storage here? - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); layers.insert_historic(Arc::new(image_layer)); drop(layers); @@ -1678,15 +1698,13 @@ impl LayeredTimeline { } fn compact_level0(&self, target_file_size: u64) -> Result<()> { - let layers = self.layers.lock().unwrap(); - - // We compact or "shuffle" the level-0 delta layers when 10 have - // accumulated. - static COMPACT_THRESHOLD: usize = 10; + let layers = self.layers.read().unwrap(); let level0_deltas = layers.get_level0_deltas()?; - if level0_deltas.len() < COMPACT_THRESHOLD { + // We compact or "shuffle" the level-0 delta layers when they've + // accumulated over the compaction threshold. + if level0_deltas.len() < self.conf.compaction_threshold { return Ok(()); } drop(layers); @@ -1770,7 +1788,7 @@ impl LayeredTimeline { layer_paths.pop().unwrap(); } - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); for l in new_layers { layers.insert_historic(Arc::new(l)); } @@ -1852,7 +1870,7 @@ impl LayeredTimeline { // 2. it doesn't need to be retained for 'retain_lsns'; // 3. newer on-disk image layers cover the layer's whole key range // - let mut layers = self.layers.lock().unwrap(); + let mut layers = self.layers.write().unwrap(); 'outer: for l in layers.iter_historic_layers() { // This layer is in the process of being flushed to disk. 
// It will be swapped out of the layer map, replaced with diff --git a/pageserver/src/layered_repository/block_io.rs b/pageserver/src/layered_repository/block_io.rs index 2eba0aa403..d027b2f0e7 100644 --- a/pageserver/src/layered_repository/block_io.rs +++ b/pageserver/src/layered_repository/block_io.rs @@ -198,7 +198,6 @@ impl BlockWriter for BlockBuf { assert!(buf.len() == PAGE_SZ); let blknum = self.blocks.len(); self.blocks.push(buf); - tracing::info!("buffered block {}", blknum); Ok(blknum as u32) } } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index e7a4117b3e..c09b032e48 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -713,6 +713,26 @@ impl postgres_backend::Handler for PageServerHandler { Some(result.elapsed.as_millis().to_string().as_bytes()), ]))? .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with("compact ") { + // Run compaction immediately on given timeline. + // FIXME This is just for tests. Don't expect this to be exposed to + // the users or the api. + + // compact + let re = Regex::new(r"^compact ([[:xdigit:]]+)\s([[:xdigit:]]+)($|\s)?").unwrap(); + + let caps = re + .captures(query_string) + .with_context(|| format!("Invalid compact: '{}'", query_string))?; + + let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?; + let timelineid = ZTimelineId::from_str(caps.get(2).unwrap().as_str())?; + let timeline = tenant_mgr::get_timeline_for_tenant_load(tenantid, timelineid) + .context("Couldn't load timeline")?; + timeline.tline.compact()?; + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } else if query_string.starts_with("checkpoint ") { // Run checkpoint immediately on given timeline. diff --git a/pageserver/src/remote_storage/storage_sync/download.rs b/pageserver/src/remote_storage/storage_sync/download.rs index 773b4a12e5..e5aa74452b 100644 --- a/pageserver/src/remote_storage/storage_sync/download.rs +++ b/pageserver/src/remote_storage/storage_sync/download.rs @@ -1,6 +1,6 @@ //! Timeline synchrnonization logic to put files from archives on remote storage into pageserver's local directory. 
-use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc}; +use std::{collections::BTreeSet, path::PathBuf, sync::Arc}; use anyhow::{ensure, Context}; use tokio::fs; @@ -64,11 +64,16 @@ pub(super) async fn download_timeline< let remote_timeline = match index_read.timeline_entry(&sync_id) { None => { error!("Cannot download: no timeline is present in the index for given id"); + drop(index_read); return DownloadedTimeline::Abort; } Some(index_entry) => match index_entry.inner() { - TimelineIndexEntryInner::Full(remote_timeline) => Cow::Borrowed(remote_timeline), + TimelineIndexEntryInner::Full(remote_timeline) => { + let cloned = remote_timeline.clone(); + drop(index_read); + cloned + } TimelineIndexEntryInner::Description(_) => { // we do not check here for awaits_download because it is ok // to call this function while the download is in progress @@ -84,7 +89,7 @@ pub(super) async fn download_timeline< ) .await { - Ok(remote_timeline) => Cow::Owned(remote_timeline), + Ok(remote_timeline) => remote_timeline, Err(e) => { error!("Failed to download full timeline index: {:?}", e); diff --git a/pageserver/src/remote_storage/storage_sync/upload.rs b/pageserver/src/remote_storage/storage_sync/upload.rs index f955e04474..7b6d58a661 100644 --- a/pageserver/src/remote_storage/storage_sync/upload.rs +++ b/pageserver/src/remote_storage/storage_sync/upload.rs @@ -1,6 +1,6 @@ //! Timeline synchronization logic to compress and upload to the remote storage all new timeline files from the checkpoints. -use std::{borrow::Cow, collections::BTreeSet, path::PathBuf, sync::Arc}; +use std::{collections::BTreeSet, path::PathBuf, sync::Arc}; use tracing::{debug, error, warn}; @@ -46,13 +46,21 @@ pub(super) async fn upload_timeline_checkpoint< let index_read = index.read().await; let remote_timeline = match index_read.timeline_entry(&sync_id) { - None => None, + None => { + drop(index_read); + None + } Some(entry) => match entry.inner() { - TimelineIndexEntryInner::Full(remote_timeline) => Some(Cow::Borrowed(remote_timeline)), + TimelineIndexEntryInner::Full(remote_timeline) => { + let r = Some(remote_timeline.clone()); + drop(index_read); + r + } TimelineIndexEntryInner::Description(_) => { + drop(index_read); debug!("Found timeline description for the given ids, downloading the full index"); match fetch_full_index(remote_assets.as_ref(), &timeline_dir, sync_id).await { - Ok(remote_timeline) => Some(Cow::Owned(remote_timeline)), + Ok(remote_timeline) => Some(remote_timeline), Err(e) => { error!("Failed to download full timeline index: {:?}", e); sync_queue::push(SyncTask::new( @@ -82,7 +90,6 @@ pub(super) async fn upload_timeline_checkpoint< let already_uploaded_files = remote_timeline .map(|timeline| timeline.stored_files(&timeline_dir)) .unwrap_or_default(); - drop(index_read); match try_upload_checkpoint( config, diff --git a/pageserver/src/repository.rs b/pageserver/src/repository.rs index 02334d3229..eda9a3168d 100644 --- a/pageserver/src/repository.rs +++ b/pageserver/src/repository.rs @@ -252,8 +252,10 @@ pub trait Repository: Send + Sync { checkpoint_before_gc: bool, ) -> Result; - /// perform one compaction iteration. - /// this function is periodically called by compactor thread. + /// Perform one compaction iteration. + /// This function is periodically called by compactor thread. + /// Also it can be explicitly requested per timeline through page server + /// api's 'compact' command. 
fn compaction_iteration(&self) -> Result<()>; /// detaches locally available timeline by stopping all threads and removing all the data. diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index dc20695884..be03a2d4a9 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -5,12 +5,14 @@ edition = "2021" [dependencies] anyhow = "1.0" +base64 = "0.13.0" bytes = { version = "1.0.1", features = ['serde'] } clap = "3.0" fail = "0.5.0" futures = "0.3.13" hashbrown = "0.11.2" hex = "0.4.3" +hmac = "0.10.1" hyper = "0.14" lazy_static = "1.4.0" md5 = "0.7.0" @@ -18,12 +20,14 @@ parking_lot = "0.11.2" pin-project-lite = "0.2.7" rand = "0.8.3" reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] } +routerify = "2" rustls = "0.19.1" scopeguard = "1.1.0" serde = "1" serde_json = "1" +sha2 = "0.9.8" socket2 = "0.4.4" -thiserror = "1.0" +thiserror = "1.0.30" tokio = { version = "1.17", features = ["macros"] } tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } tokio-rustls = "0.22.0" @@ -33,5 +37,7 @@ zenith_metrics = { path = "../zenith_metrics" } workspace_hack = { version = "0.1", path = "../workspace_hack" } [dev-dependencies] -tokio-postgres-rustls = "0.8.0" +async-trait = "0.1" rcgen = "0.8.14" +rstest = "0.12" +tokio-postgres-rustls = "0.8.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index e8fe65c081..bda14d67a1 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,14 +1,24 @@ +mod credentials; + +#[cfg(test)] +mod flow; + use crate::compute::DatabaseInfo; use crate::config::ProxyConfig; use crate::cplane_api::{self, CPlaneApi}; use crate::error::UserFacingError; use crate::stream::PqStream; use crate::waiters; -use std::collections::HashMap; +use std::io; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; +pub use credentials::ClientCredentials; + +#[cfg(test)] +pub use flow::*; + /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { @@ -16,13 +26,17 @@ pub enum AuthErrorImpl { #[error(transparent)] Console(#[from] cplane_api::AuthError), + #[cfg(test)] + #[error(transparent)] + Sasl(#[from] crate::sasl::Error), + /// For passwords that couldn't be processed by [`parse_password`]. #[error("Malformed password message")] MalformedPassword, /// Errors produced by [`PqStream`]. #[error(transparent)] - Io(#[from] std::io::Error), + Io(#[from] io::Error), } impl AuthErrorImpl { @@ -67,70 +81,6 @@ impl UserFacingError for AuthError { } } -#[derive(Debug, Error)] -pub enum ClientCredsParseError { - #[error("Parameter `{0}` is missing in startup packet")] - MissingKey(&'static str), -} - -impl UserFacingError for ClientCredsParseError {} - -/// Various client credentials which we use for authentication. -#[derive(Debug, PartialEq, Eq)] -pub struct ClientCredentials { - pub user: String, - pub dbname: String, -} - -impl TryFrom> for ClientCredentials { - type Error = ClientCredsParseError; - - fn try_from(mut value: HashMap) -> Result { - let mut get_param = |key| { - value - .remove(key) - .ok_or(ClientCredsParseError::MissingKey(key)) - }; - - let user = get_param("user")?; - let db = get_param("database")?; - - Ok(Self { user, dbname: db }) - } -} - -impl ClientCredentials { - /// Use credentials to authenticate the user. 
- pub async fn authenticate( - self, - config: &ProxyConfig, - client: &mut PqStream, - ) -> Result { - fail::fail_point!("proxy-authenticate", |_| { - Err(AuthError::auth_failed("failpoint triggered")) - }); - - use crate::config::ClientAuthMethod::*; - use crate::config::RouterConfig::*; - match &config.router_config { - Static { host, port } => handle_static(host.clone(), *port, client, self).await, - Dynamic(Mixed) => { - if self.user.ends_with("@zenith") { - handle_existing_user(config, client, self).await - } else { - handle_new_user(config, client).await - } - } - Dynamic(Password) => handle_existing_user(config, client, self).await, - Dynamic(Link) => handle_new_user(config, client).await, - } - } -} - -fn new_psql_session_id() -> String { - hex::encode(rand::random::<[u8; 8]>()) -} - async fn handle_static( host: String, port: u16, @@ -169,7 +119,7 @@ async fn handle_existing_user( let md5_salt = rand::random(); client - .write_message(&Be::AuthenticationMD5Password(&md5_salt)) + .write_message(&Be::AuthenticationMD5Password(md5_salt)) .await?; // Read client's password hash @@ -213,6 +163,10 @@ async fn handle_new_user( Ok(db_info) } +fn new_psql_session_id() -> String { + hex::encode(rand::random::<[u8; 8]>()) +} + fn parse_password(bytes: &[u8]) -> Option<&str> { std::str::from_utf8(bytes).ok()?.strip_suffix('\0') } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs new file mode 100644 index 0000000000..7c8ba28622 --- /dev/null +++ b/proxy/src/auth/credentials.rs @@ -0,0 +1,70 @@ +//! User credentials used in authentication. + +use super::AuthError; +use crate::compute::DatabaseInfo; +use crate::config::ProxyConfig; +use crate::error::UserFacingError; +use crate::stream::PqStream; +use std::collections::HashMap; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(Debug, Error)] +pub enum ClientCredsParseError { + #[error("Parameter `{0}` is missing in startup packet")] + MissingKey(&'static str), +} + +impl UserFacingError for ClientCredsParseError {} + +/// Various client credentials which we use for authentication. +#[derive(Debug, PartialEq, Eq)] +pub struct ClientCredentials { + pub user: String, + pub dbname: String, +} + +impl TryFrom> for ClientCredentials { + type Error = ClientCredsParseError; + + fn try_from(mut value: HashMap) -> Result { + let mut get_param = |key| { + value + .remove(key) + .ok_or(ClientCredsParseError::MissingKey(key)) + }; + + let user = get_param("user")?; + let db = get_param("database")?; + + Ok(Self { user, dbname: db }) + } +} + +impl ClientCredentials { + /// Use credentials to authenticate the user. 
+ pub async fn authenticate( + self, + config: &ProxyConfig, + client: &mut PqStream, + ) -> Result { + fail::fail_point!("proxy-authenticate", |_| { + Err(AuthError::auth_failed("failpoint triggered")) + }); + + use crate::config::ClientAuthMethod::*; + use crate::config::RouterConfig::*; + match &config.router_config { + Static { host, port } => super::handle_static(host.clone(), *port, client, self).await, + Dynamic(Mixed) => { + if self.user.ends_with("@zenith") { + super::handle_existing_user(config, client, self).await + } else { + super::handle_new_user(config, client).await + } + } + Dynamic(Password) => super::handle_existing_user(config, client, self).await, + Dynamic(Link) => super::handle_new_user(config, client).await, + } + } +} diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs new file mode 100644 index 0000000000..0fafaa2f47 --- /dev/null +++ b/proxy/src/auth/flow.rs @@ -0,0 +1,102 @@ +//! Main authentication flow. + +use super::{AuthError, AuthErrorImpl}; +use crate::stream::PqStream; +use crate::{sasl, scram}; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; + +/// Every authentication selector is supposed to implement this trait. +pub trait AuthMethod { + /// Any authentication selector should provide initial backend message + /// containing auth method name and parameters, e.g. md5 salt. + fn first_message(&self) -> BeMessage<'_>; +} + +/// Initial state of [`AuthFlow`]. +pub struct Begin; + +/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. +pub struct Scram<'a>(pub &'a scram::ServerSecret); + +impl AuthMethod for Scram<'_> { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + } +} + +/// Use password-based auth in [`AuthFlow`]. +pub struct Md5( + /// Salt for client. + pub [u8; 4], +); + +impl AuthMethod for Md5 { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationMD5Password(self.0) + } +} + +/// This wrapper for [`PqStream`] performs client authentication. +#[must_use] +pub struct AuthFlow<'a, Stream, State> { + /// The underlying stream which implements libpq's protocol. + stream: &'a mut PqStream, + /// State might contain ancillary data (see [`AuthFlow::begin`]). + state: State, +} + +/// Initial state of the stream wrapper. +impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { + /// Create a new wrapper for client authentication. + pub fn new(stream: &'a mut PqStream) -> Self { + Self { + stream, + state: Begin, + } + } + + /// Move to the next step by sending auth method's name & params to client. + pub async fn begin(self, method: M) -> io::Result> { + self.stream.write_message(&method.first_message()).await?; + + Ok(AuthFlow { + stream: self.stream, + state: method, + }) + } +} + +/// Stream wrapper for handling simple MD5 password auth. +impl AuthFlow<'_, S, Md5> { + /// Perform user authentication. Raise an error in case authentication failed. + #[allow(unused)] + pub async fn authenticate(self) -> Result<(), AuthError> { + unimplemented!("MD5 auth flow is yet to be implemented"); + } +} + +/// Stream wrapper for handling [SCRAM](crate::scram) auth. +impl AuthFlow<'_, S, Scram<'_>> { + /// Perform user authentication. Raise an error in case authentication failed. + pub async fn authenticate(self) -> Result<(), AuthError> { + // Initial client message contains the chosen auth method's name. 
+ let msg = self.stream.read_password_message().await?; + let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; + + // Currently, the only supported SASL method is SCRAM. + if !scram::METHODS.contains(&sasl.method) { + return Err(AuthErrorImpl::auth_failed("method not supported").into()); + } + + let secret = self.state.0; + sasl::SaslStream::new(self.stream, sasl.message) + .authenticate(scram::Exchange::new(secret, rand::random, None)) + .await?; + + Ok(()) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 7b38721a88..1a91e6fbbc 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -1,19 +1,8 @@ -/// -/// Postgres protocol proxy/router. -/// -/// This service listens psql port and can check auth via external service -/// (control plane API in our case) and can create new databases and accounts -/// in somewhat transparent manner (again via communication with control plane API). -/// -use anyhow::{bail, Context}; -use clap::{Arg, Command}; -use config::ProxyConfig; -use futures::FutureExt; -use std::future::Future; -use tokio::{net::TcpListener, task::JoinError}; -use zenith_utils::GIT_VERSION; - -use crate::config::{ClientAuthMethod, RouterConfig}; +//! Postgres protocol proxy/router. +//! +//! This service listens psql port and can check auth via external service +//! (control plane API in our case) and can create new databases and accounts +//! in somewhat transparent manner (again via communication with control plane API). mod auth; mod cancellation; @@ -27,6 +16,24 @@ mod proxy; mod stream; mod waiters; +// Currently SCRAM is only used in tests +#[cfg(test)] +mod parse; +#[cfg(test)] +mod sasl; +#[cfg(test)] +mod scram; + +use anyhow::{bail, Context}; +use clap::{Arg, Command}; +use config::ProxyConfig; +use futures::FutureExt; +use std::future::Future; +use tokio::{net::TcpListener, task::JoinError}; +use zenith_utils::GIT_VERSION; + +use crate::config::{ClientAuthMethod, RouterConfig}; + /// Flattens `Result>` into `Result`. async fn flatten_err( f: impl Future, JoinError>>, diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs new file mode 100644 index 0000000000..8a05ff9c82 --- /dev/null +++ b/proxy/src/parse.rs @@ -0,0 +1,18 @@ +//! Small parsing helpers. + +use std::convert::TryInto; +use std::ffi::CStr; + +pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { + let pos = bytes.iter().position(|&x| x == 0)?; + let (cstr, other) = bytes.split_at(pos + 1); + // SAFETY: we've already checked that there's a terminator + Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other)) +} + +pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { + (bytes.len() >= N).then(|| { + let (head, tail) = bytes.split_at(N); + (head.try_into().unwrap(), tail) + }) +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 81581b5cf1..5b662f4c69 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -119,7 +119,6 @@ async fn handshake( // We can't perform TLS handshake without a config let enc = tls.is_some(); stream.write_message(&Be::EncryptionResponse(enc)).await?; - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. 
@@ -219,32 +218,14 @@ impl Client { #[cfg(test)] mod tests { use super::*; - - use tokio::io::DuplexStream; + use crate::{auth, scram}; + use async_trait::async_trait; + use rstest::rstest; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::MakeRustlsConnect; - async fn dummy_proxy( - client: impl AsyncRead + AsyncWrite + Unpin, - tls: Option, - ) -> anyhow::Result<()> { - let cancel_map = CancelMap::default(); - - // TODO: add some infra + tests for credentials - let (mut stream, _creds) = handshake(client, tls, &cancel_map) - .await? - .context("no stream")?; - - stream - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? - .write_message(&BeMessage::ReadyForQuery) - .await?; - - Ok(()) - } - + /// Generate a set of TLS certificates: CA + server. fn generate_certs( hostname: &str, ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> { @@ -262,19 +243,115 @@ mod tests { )) } + struct ClientConfig<'a> { + config: rustls::ClientConfig, + hostname: &'a str, + } + + impl ClientConfig<'_> { + fn make_tls_connect( + self, + ) -> anyhow::Result> { + let mut mk = MakeRustlsConnect::new(self.config); + let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; + Ok(tls) + } + } + + /// Generate TLS certificates and build rustls configs for client and server. + fn generate_tls_config( + hostname: &str, + ) -> anyhow::Result<(ClientConfig<'_>, Arc)> { + let (ca, cert, key) = generate_certs(hostname)?; + + let server_config = { + let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(vec![cert], key)?; + config.into() + }; + + let client_config = { + let mut config = rustls::ClientConfig::new(); + config.root_store.add(&ca)?; + ClientConfig { config, hostname } + }; + + Ok((client_config, server_config)) + } + + #[async_trait] + trait TestAuth: Sized { + async fn authenticate( + self, + _stream: &mut PqStream>, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + struct NoAuth; + impl TestAuth for NoAuth {} + + struct Scram(scram::ServerSecret); + + impl Scram { + fn new(password: &str) -> anyhow::Result { + let salt = rand::random::<[u8; 16]>(); + let secret = scram::ServerSecret::build(password, &salt, 256) + .context("failed to generate scram secret")?; + Ok(Scram(secret)) + } + + fn mock(user: &str) -> Self { + let salt = rand::random::<[u8; 32]>(); + Scram(scram::ServerSecret::mock(user, &salt)) + } + } + + #[async_trait] + impl TestAuth for Scram { + async fn authenticate( + self, + stream: &mut PqStream>, + ) -> anyhow::Result<()> { + auth::AuthFlow::new(stream) + .begin(auth::Scram(&self.0)) + .await? + .authenticate() + .await?; + + Ok(()) + } + } + + /// A dummy proxy impl which performs a handshake and reports auth success. + async fn dummy_proxy( + client: impl AsyncRead + AsyncWrite + Unpin + Send, + tls: Option, + auth: impl TestAuth + Send, + ) -> anyhow::Result<()> { + let cancel_map = CancelMap::default(); + let (mut stream, _creds) = handshake(client, tls, &cancel_map) + .await? + .context("handshake failed")?; + + auth.authenticate(&mut stream).await?; + + stream + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())? 
+ .write_message(&BeMessage::ReadyForQuery) + .await?; + + Ok(()) + } + #[tokio::test] async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let server_config = { - let (_ca, cert, key) = generate_certs("localhost")?; - - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); + let (_, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let client_err = tokio_postgres::Config::new() .user("john_doe") @@ -301,30 +378,14 @@ mod tests { async fn handshake_tls() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (ca, cert, key) = generate_certs("localhost")?; - - let server_config = { - let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new()); - config.set_single_cert(vec![cert], key)?; - config - }; - - let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into()))); - - let client_config = { - let mut config = rustls::ClientConfig::new(); - config.root_store.add(&ca)?; - config - }; - - let mut mk = MakeRustlsConnect::new(client_config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, "localhost")?; + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) - .connect_raw(server, tls) + .connect_raw(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -334,7 +395,7 @@ mod tests { async fn handshake_raw() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") @@ -350,7 +411,7 @@ mod tests { async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let proxy = tokio::spawn(dummy_proxy(client, None)); + let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); let client_err = tokio_postgres::Config::new() .ssl_mode(SslMode::Disable) @@ -391,4 +452,66 @@ mod tests { Ok(()) } + + #[rstest] + #[case("password_foo")] + #[case("pwd-bar")] + #[case("")] + #[tokio::test] + async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::new(password)?, + )); + + let (_client, _conn) = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(password) + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await?; + + proxy.await? 
+ } + + #[tokio::test] + async fn scram_auth_mock() -> anyhow::Result<()> { + let (client, server) = tokio::io::duplex(1024); + + let (client_config, server_config) = generate_tls_config("localhost")?; + let proxy = tokio::spawn(dummy_proxy( + client, + Some(server_config), + Scram::mock("user"), + )); + + use rand::{distributions::Alphanumeric, Rng}; + let password: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(rand::random::() as usize) + .map(char::from) + .collect(); + + let _client_err = tokio_postgres::Config::new() + .user("user") + .dbname("db") + .password(&password) // no password will match the mocked secret + .ssl_mode(SslMode::Require) + .connect_raw(server, client_config.make_tls_connect()?) + .await + .err() // -> Option + .context("client shouldn't be able to connect")?; + + let _server_err = proxy + .await? + .err() // -> Option + .context("server shouldn't accept client")?; + + Ok(()) + } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs new file mode 100644 index 0000000000..70a4d9946a --- /dev/null +++ b/proxy/src/sasl.rs @@ -0,0 +1,47 @@ +//! Simple Authentication and Security Layer. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod channel_binding; +mod messages; +mod stream; + +use std::io; +use thiserror::Error; + +pub use channel_binding::ChannelBinding; +pub use messages::FirstMessage; +pub use stream::SaslStream; + +/// Fine-grained auth errors help in writing tests. +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to authenticate client: {0}")] + AuthenticationFailed(&'static str), + + #[error("Channel binding failed: {0}")] + ChannelBindingFailed(&'static str), + + #[error("Unsupported channel binding method: {0}")] + ChannelBindingBadMethod(Box), + + #[error("Bad client message")] + BadClientMessage, + + #[error(transparent)] + Io(#[from] io::Error), +} + +/// A convenient result type for SASL exchange. +pub type Result = std::result::Result; + +/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. +pub trait Mechanism: Sized { + /// Produce a server challenge to be sent to the client. + /// This is how this method is called in PostgreSQL (`libpq/sasl.h`). + fn exchange(self, input: &str) -> Result<(Option, String)>; +} diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs new file mode 100644 index 0000000000..776adabe55 --- /dev/null +++ b/proxy/src/sasl/channel_binding.rs @@ -0,0 +1,85 @@ +//! Definition and parser for channel binding flag (a part of the `GS2` header). + +/// Channel binding flag (possibly with params). +#[derive(Debug, PartialEq, Eq)] +pub enum ChannelBinding { + /// Client doesn't support channel binding. + NotSupportedClient, + /// Client thinks server doesn't support channel binding. + NotSupportedServer, + /// Client wants to use this type of channel binding. 
+ Required(T), +} + +impl ChannelBinding { + pub fn and_then(self, f: impl FnOnce(T) -> Result) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => NotSupportedClient, + NotSupportedServer => NotSupportedServer, + Required(x) => Required(f(x)?), + }) + } +} + +impl<'a> ChannelBinding<&'a str> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + use ChannelBinding::*; + Some(match input { + "n" => NotSupportedClient, + "y" => NotSupportedServer, + other => Required(other.strip_prefix("p=")?), + }) + } +} + +impl ChannelBinding { + /// Encode channel binding data as base64 for subsequent checks. + pub fn encode( + &self, + get_cbind_data: impl FnOnce(&T) -> Result, + ) -> Result, E> { + use ChannelBinding::*; + Ok(match self { + NotSupportedClient => { + // base64::encode("n,,") + "biws".into() + } + NotSupportedServer => { + // base64::encode("y,,") + "eSws".into() + } + Required(mode) => { + let msg = format!( + "p={mode},,{data}", + mode = mode, + data = get_cbind_data(mode)? + ); + base64::encode(msg).into() + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn channel_binding_encode() -> anyhow::Result<()> { + use ChannelBinding::*; + + let cases = [ + (NotSupportedClient, base64::encode("n,,")), + (NotSupportedServer, base64::encode("y,,")), + (Required("foo"), base64::encode("p=foo,,bar")), + ]; + + for (cb, input) in cases { + assert_eq!(cb.encode(|_| anyhow::Ok("bar".to_owned()))?, input); + } + + Ok(()) + } +} diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs new file mode 100644 index 0000000000..b1ae8cc426 --- /dev/null +++ b/proxy/src/sasl/messages.rs @@ -0,0 +1,67 @@ +//! Definitions for SASL messages. + +use crate::parse::{split_at_const, split_cstr}; +use zenith_utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage}; + +/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage). +#[derive(Debug)] +pub struct FirstMessage<'a> { + /// Authentication method, e.g. `"SCRAM-SHA-256"`. + pub method: &'a str, + /// Initial client message. + pub message: &'a str, +} + +impl<'a> FirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(bytes: &'a [u8]) -> Option { + let (method_cstr, tail) = split_cstr(bytes)?; + let method = method_cstr.to_str().ok()?; + + let (len_bytes, bytes) = split_at_const(tail)?; + let len = u32::from_be_bytes(*len_bytes) as usize; + if len != bytes.len() { + return None; + } + + let message = std::str::from_utf8(bytes).ok()?; + Some(Self { method, message }) + } +} + +/// A single SASL message. +/// This struct is deliberately decoupled from lower-level +/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage). +#[derive(Debug)] +pub(super) enum ServerMessage { + /// We expect to see more steps. + Continue(T), + /// This is the final step. 
+ Final(T), +} + +impl<'a> ServerMessage<&'a str> { + pub(super) fn to_reply(&self) -> BeMessage<'a> { + use BeAuthenticationSaslMessage::*; + BeMessage::AuthenticationSasl(match self { + ServerMessage::Continue(s) => Continue(s.as_bytes()), + ServerMessage::Final(s) => Final(s.as_bytes()), + }) + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_sasl_first_message() { + let proto = "SCRAM-SHA-256"; + let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4"; + let sasl_len = (sasl.len() as u32).to_be_bytes(); + let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat(); + + let password = FirstMessage::parse(&bytes).unwrap(); + assert_eq!(password.method, proto); + assert_eq!(password.message, sasl); + } +} diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs new file mode 100644 index 0000000000..03649b8d11 --- /dev/null +++ b/proxy/src/sasl/stream.rs @@ -0,0 +1,70 @@ +//! Abstraction for the string-oriented SASL protocols. + +use super::{messages::ServerMessage, Mechanism}; +use crate::stream::PqStream; +use std::io; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Abstracts away all peculiarities of the libpq's protocol. +pub struct SaslStream<'a, S> { + /// The underlying stream. + stream: &'a mut PqStream, + /// Current password message we received from client. + current: bytes::Bytes, + /// First SASL message produced by client. + first: Option<&'a str>, +} + +impl<'a, S> SaslStream<'a, S> { + pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + Self { + stream, + current: bytes::Bytes::new(), + first: Some(first), + } + } +} + +impl SaslStream<'_, S> { + // Receive a new SASL message from the client. + async fn recv(&mut self) -> io::Result<&str> { + if let Some(first) = self.first.take() { + return Ok(first); + } + + self.current = self.stream.read_password_message().await?; + let s = std::str::from_utf8(&self.current) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + + Ok(s) + } +} + +impl SaslStream<'_, S> { + // Send a SASL message to the client. + async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { + self.stream.write_message(&msg.to_reply()).await?; + Ok(()) + } +} + +impl SaslStream<'_, S> { + /// Perform SASL message exchange according to the underlying algorithm + /// until user is either authenticated or denied access. + pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> { + loop { + let input = self.recv().await?; + let (moved, reply) = mechanism.exchange(input)?; + match moved { + Some(moved) => { + self.send(&ServerMessage::Continue(&reply)).await?; + mechanism = moved; + } + None => { + self.send(&ServerMessage::Final(&reply)).await?; + return Ok(()); + } + } + } + } +} diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs new file mode 100644 index 0000000000..f007f3e0b6 --- /dev/null +++ b/proxy/src/scram.rs @@ -0,0 +1,59 @@ +//! Salted Challenge Response Authentication Mechanism. +//! +//! RFC: . +//! +//! Reference implementation: +//! * +//! * + +mod exchange; +mod key; +mod messages; +mod password; +mod secret; +mod signature; + +pub use secret::*; + +pub use exchange::Exchange; +pub use secret::ServerSecret; + +use hmac::{Hmac, Mac, NewMac}; +use sha2::{Digest, Sha256}; + +// TODO: add SCRAM-SHA-256-PLUS +/// A list of supported SCRAM methods. 
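The `hmac_sha256` and `sha256` helpers defined in this module map directly onto Python's standard `hmac`/`hashlib`; the sketch below (names chosen to match, not part of the patch) is only meant to make the SCRAM arithmetic in the following files easy to reproduce by hand.

import hashlib, hmac

def hmac_sha256(key: bytes, *parts: bytes) -> bytes:
    """HMAC-SHA-256 over the concatenation of `parts`."""
    mac = hmac.new(key, digestmod=hashlib.sha256)
    for part in parts:
        mac.update(part)
    return mac.digest()

def sha256(*parts: bytes) -> bytes:
    """SHA-256 over the concatenation of `parts`."""
    h = hashlib.sha256()
    for part in parts:
        h.update(part)
    return h.digest()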
+pub const METHODS: &[&str] = &["SCRAM-SHA-256"]; + +/// Decode base64 into array without any heap allocations +fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { + let mut bytes = [0u8; N]; + + let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; + if size != N { + return None; + } + + Some(bytes) +} + +/// This function essentially is `Hmac(sha256, key, input)`. +/// Further reading: . +fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { + let mut mac = Hmac::::new_varkey(key).expect("bad key size"); + parts.into_iter().for_each(|s| mac.update(s)); + + // TODO: maybe newer `hmac` et al already migrated to regular arrays? + let mut result = [0u8; 32]; + result.copy_from_slice(mac.finalize().into_bytes().as_slice()); + result +} + +fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { + let mut hasher = Sha256::new(); + parts.into_iter().for_each(|s| hasher.update(s)); + + let mut result = [0u8; 32]; + result.copy_from_slice(hasher.finalize().as_slice()); + result +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs new file mode 100644 index 0000000000..5a986b965a --- /dev/null +++ b/proxy/src/scram/exchange.rs @@ -0,0 +1,134 @@ +//! Implementation of the SCRAM authentication algorithm. + +use super::messages::{ + ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, +}; +use super::secret::ServerSecret; +use super::signature::SignatureBuilder; +use crate::sasl::{self, ChannelBinding, Error as SaslError}; + +/// The only channel binding mode we currently support. +#[derive(Debug)] +struct TlsServerEndPoint; + +impl std::fmt::Display for TlsServerEndPoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "tls-server-end-point") + } +} + +impl std::str::FromStr for TlsServerEndPoint { + type Err = sasl::Error; + + fn from_str(s: &str) -> Result { + match s { + "tls-server-end-point" => Ok(TlsServerEndPoint), + _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())), + } + } +} + +#[derive(Debug)] +enum ExchangeState { + /// Waiting for [`ClientFirstMessage`]. + Initial, + /// Waiting for [`ClientFinalMessage`]. + SaltSent { + cbind_flag: ChannelBinding, + client_first_message_bare: String, + server_first_message: OwnedServerFirstMessage, + }, +} + +/// Server's side of SCRAM auth algorithm. 
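`Exchange` is a two-step state machine: after `client-first-message` it answers with `server-first-message` (`r=<combined nonce>,s=<salt>,i=<iterations>`), and after `client-final-message` it verifies channel binding, nonce and proof before replying with `v=<ServerSignature>`. A minimal Python sketch of the first reply, assuming an 18-byte server nonce as in `SCRAM_RAW_NONCE_LEN` (helper name is illustrative):

import base64, secrets

def server_first_message(client_nonce: str, salt_b64: str, iterations: int) -> str:
    """Step 1: what the server answers to client-first-message."""
    server_nonce = base64.b64encode(secrets.token_bytes(18)).decode()  # SCRAM_RAW_NONCE_LEN
    return f"r={client_nonce}{server_nonce},s={salt_b64},i={iterations}"

# Step 2, on client-final-message, then checks in order:
#   1. the c= value matches the channel binding the server computed itself,
#   2. the r= value equals the combined nonce sent in step 1,
#   3. SHA256(proof XOR ClientSignature) equals the StoredKey from the server secret,
# and, if all pass, replies with "v=<base64(ServerSignature)>".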
+#[derive(Debug)] +pub struct Exchange<'a> { + state: ExchangeState, + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, +} + +impl<'a> Exchange<'a> { + pub fn new( + secret: &'a ServerSecret, + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], + cert_digest: Option<&'a [u8]>, + ) -> Self { + Self { + state: ExchangeState::Initial, + secret, + nonce, + cert_digest, + } + } +} + +impl sasl::Mechanism for Exchange<'_> { + fn exchange(mut self, input: &str) -> sasl::Result<(Option, String)> { + use ExchangeState::*; + match &self.state { + Initial => { + let client_first_message = + ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let server_first_message = client_first_message.build_server_first_message( + &(self.nonce)(), + &self.secret.salt_base64, + self.secret.iterations, + ); + let msg = server_first_message.as_str().to_owned(); + + self.state = SaltSent { + cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?, + client_first_message_bare: client_first_message.bare.to_owned(), + server_first_message, + }; + + Ok((Some(self), msg)) + } + SaltSent { + cbind_flag, + client_first_message_bare, + server_first_message, + } => { + let client_final_message = + ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?; + + let channel_binding = cbind_flag.encode(|_| { + self.cert_digest + .map(base64::encode) + .ok_or(SaslError::ChannelBindingFailed("no cert digest provided")) + })?; + + // This might've been caused by a MITM attack + if client_final_message.channel_binding != channel_binding { + return Err(SaslError::ChannelBindingFailed("data mismatch")); + } + + if client_final_message.nonce != server_first_message.nonce() { + return Err(SaslError::AuthenticationFailed("bad nonce")); + } + + let signature_builder = SignatureBuilder { + client_first_message_bare, + server_first_message: server_first_message.as_str(), + client_final_message_without_proof: client_final_message.without_proof, + }; + + let client_key = signature_builder + .build(&self.secret.stored_key) + .derive_client_key(&client_final_message.proof); + + if client_key.sha256() != self.secret.stored_key { + return Err(SaslError::AuthenticationFailed("keys don't match")); + } + + let msg = client_final_message + .build_server_final_message(signature_builder, &self.secret.server_key); + + Ok((None, msg)) + } + } + } +} diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs new file mode 100644 index 0000000000..1c13471bc3 --- /dev/null +++ b/proxy/src/scram/key.rs @@ -0,0 +1,33 @@ +//! Tools for client/server/stored key management. + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_KEY_LEN: usize = 32; + +/// One of the keys derived from the [password](super::password::SaltedPassword). +/// We use the same structure for all keys, i.e. +/// `ClientKey`, `StoredKey`, and `ServerKey`. 
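For context, the three keys this type holds are all derived from the salted password as described in RFC 5802. A short Python sketch of the derivation (helper name is illustrative):

import hashlib, hmac

def derive_keys(salted_password: bytes):
    """SaltedPassword -> (ClientKey, StoredKey, ServerKey)."""
    client_key = hmac.new(salted_password, b"Client Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()   # what the server actually stores
    server_key = hmac.new(salted_password, b"Server Key", hashlib.sha256).digest()
    return client_key, stored_key, server_key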
+#[derive(Default, Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct ScramKey { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl ScramKey { + pub fn sha256(&self) -> Self { + super::sha256([self.as_ref()]).into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { + #[inline(always)] + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for ScramKey { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs new file mode 100644 index 0000000000..f6e6133adf --- /dev/null +++ b/proxy/src/scram/messages.rs @@ -0,0 +1,232 @@ +//! Definitions for SCRAM messages. + +use super::base64_decode_array; +use super::key::{ScramKey, SCRAM_KEY_LEN}; +use super::signature::SignatureBuilder; +use crate::sasl::ChannelBinding; +use std::fmt; +use std::ops::Range; + +/// Faithfully taken from PostgreSQL. +pub const SCRAM_RAW_NONCE_LEN: usize = 18; + +/// Although we ignore all extensions, we still have to validate the message. +fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { + for mut chars in parts.map(|s| s.chars()) { + let attr = chars.next()?; + if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) { + return None; + } + let eq = chars.next()?; + if eq != '=' { + return None; + } + } + + Some(()) +} + +#[derive(Debug)] +pub struct ClientFirstMessage<'a> { + /// `client-first-message-bare`. + pub bare: &'a str, + /// Channel binding mode. + pub cbind_flag: ChannelBinding<&'a str>, + /// (Client username)[]. + pub username: &'a str, + /// Client nonce. + pub nonce: &'a str, +} + +impl<'a> ClientFirstMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let mut parts = input.split(','); + + let cbind_flag = ChannelBinding::parse(parts.next()?)?; + + // PG doesn't support authorization identity, + // so we don't bother defining GS2 header type + let authzid = parts.next()?; + if !authzid.is_empty() { + return None; + } + + // Unfortunately, `parts.as_str()` is unstable + let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1; + let (_, bare) = input.split_at(pos); + + // In theory, these might be preceded by "reserved-mext" (i.e. "m=") + let username = parts.next()?.strip_prefix("n=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + Some(Self { + bare, + cbind_flag, + username, + nonce, + }) + } + + /// Build a response to [`ClientFirstMessage`]. + pub fn build_server_first_message( + &self, + nonce: &[u8; SCRAM_RAW_NONCE_LEN], + salt_base64: &str, + iterations: u32, + ) -> OwnedServerFirstMessage { + use std::fmt::Write; + + let mut message = String::new(); + write!(&mut message, "r={}", self.nonce).unwrap(); + base64::encode_config_buf(nonce, base64::STANDARD, &mut message); + let combined_nonce = 2..message.len(); + write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap(); + + // This design guarantees that it's impossible to create a + // server-first-message without receiving a client-first-message + OwnedServerFirstMessage { + message, + nonce: combined_nonce, + } + } +} + +#[derive(Debug)] +pub struct ClientFinalMessage<'a> { + /// `client-final-message-without-proof`. + pub without_proof: &'a str, + /// Channel binding data (base64). + pub channel_binding: &'a str, + /// Combined client & server nonce. + pub nonce: &'a str, + /// Client auth proof. 
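The `client-final-message` has the shape `c=<channel binding>,r=<combined nonce>,p=<proof>`, and splitting the proof off at the last comma is exactly what `ClientFinalMessage::parse` below does. A Python sketch using the same captured message as the tests at the end of this file (helper name is illustrative):

import base64

def parse_client_final(msg: str):
    """Split 'c=...,r=...,p=...' into (without_proof, cbind, nonce, proof)."""
    without_proof, proof_b64 = msg.rsplit(",", 1)
    attrs = dict(part.split("=", 1) for part in without_proof.split(","))
    proof = base64.b64decode(proof_b64[2:])  # drop the 'p=' prefix
    return without_proof, attrs["c"], attrs["r"], proof

msg = ("c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU,"
       "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=")
without_proof, cbind, nonce, proof = parse_client_final(msg)
assert without_proof == "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
assert len(proof) == 32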
+ pub proof: [u8; SCRAM_KEY_LEN], +} + +impl<'a> ClientFinalMessage<'a> { + // NB: FromStr doesn't work with lifetimes + pub fn parse(input: &'a str) -> Option { + let (without_proof, proof) = input.rsplit_once(',')?; + + let mut parts = without_proof.split(','); + let channel_binding = parts.next()?.strip_prefix("c=")?; + let nonce = parts.next()?.strip_prefix("r=")?; + + // Validate but ignore auth extensions + validate_sasl_extensions(parts)?; + + let proof = base64_decode_array(proof.strip_prefix("p=")?)?; + + Some(Self { + without_proof, + channel_binding, + nonce, + proof, + }) + } + + /// Build a response to [`ClientFinalMessage`]. + pub fn build_server_final_message( + &self, + signature_builder: SignatureBuilder, + server_key: &ScramKey, + ) -> String { + let mut buf = String::from("v="); + base64::encode_config_buf( + signature_builder.build(server_key), + base64::STANDARD, + &mut buf, + ); + + buf + } +} + +/// We need to keep a convenient representation of this +/// message for the next authentication step. +pub struct OwnedServerFirstMessage { + /// Owned `server-first-message`. + message: String, + /// Slice into `message`. + nonce: Range, +} + +impl OwnedServerFirstMessage { + /// Extract combined nonce from the message. + #[inline(always)] + pub fn nonce(&self) -> &str { + &self.message[self.nonce.clone()] + } + + /// Get reference to a text representation of the message. + #[inline(always)] + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl fmt::Debug for OwnedServerFirstMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerFirstMessage") + .field("message", &self.as_str()) + .field("nonce", &self.nonce()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_client_first_message() { + use ChannelBinding::*; + + // (Almost) real strings captured during debug sessions + let cases = [ + (NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + (NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"), + ( + Required("tls-server-end-point"), + "p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju", + ), + ]; + + for (cb, input) in cases { + let msg = ClientFirstMessage::parse(input).unwrap(); + + assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.username, "pepe"); + assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju"); + assert_eq!(msg.cbind_flag, cb); + } + } + + #[test] + fn parse_client_final_message() { + let input = [ + "c=eSws", + "r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU", + "p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=", + ] + .join(","); + + let msg = ClientFinalMessage::parse(&input).unwrap(); + assert_eq!( + msg.without_proof, + "c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + msg.nonce, + "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" + ); + assert_eq!( + base64::encode(msg.proof), + "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=" + ); + } +} diff --git a/proxy/src/scram/password.rs b/proxy/src/scram/password.rs new file mode 100644 index 0000000000..656780d853 --- /dev/null +++ b/proxy/src/scram/password.rs @@ -0,0 +1,48 @@ +//! Password hashing routines. + +use super::key::ScramKey; + +pub const SALTED_PASSWORD_LEN: usize = 32; + +/// Salted hashed password is essential for [key](super::key) derivation. +#[repr(transparent)] +pub struct SaltedPassword { + bytes: [u8; SALTED_PASSWORD_LEN], +} + +impl SaltedPassword { + /// See `scram-common.c : scram_SaltedPassword` for details. 
+ /// Further reading: (see `PBKDF2`). + pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword { + let one = 1_u32.to_be_bytes(); // magic + + let mut current = super::hmac_sha256(password, [salt, &one]); + let mut result = current; + for _ in 1..iterations { + current = super::hmac_sha256(password, [current.as_ref()]); + // TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094 + for (i, x) in current.iter().enumerate() { + result[i] ^= x; + } + } + + result.into() + } + + /// Derive `ClientKey` from a salted hashed password. + pub fn client_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into() + } + + /// Derive `ServerKey` from a salted hashed password. + pub fn server_key(&self) -> ScramKey { + super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into() + } +} + +impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword { + #[inline(always)] + fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self { + Self { bytes } + } +} diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs new file mode 100644 index 0000000000..e8d180bcdd --- /dev/null +++ b/proxy/src/scram/secret.rs @@ -0,0 +1,116 @@ +//! Tools for SCRAM server secret management. + +use super::base64_decode_array; +use super::key::ScramKey; + +/// Server secret is produced from [password](super::password::SaltedPassword) +/// and is used throughout the authentication process. +#[derive(Debug)] +pub struct ServerSecret { + /// Number of iterations for `PBKDF2` function. + pub iterations: u32, + /// Salt used to hash user's password. + pub salt_base64: String, + /// Hashed `ClientKey`. + pub stored_key: ScramKey, + /// Used by client to verify server's signature. + pub server_key: ScramKey, +} + +impl ServerSecret { + pub fn parse(input: &str) -> Option { + // SCRAM-SHA-256$:$: + let s = input.strip_prefix("SCRAM-SHA-256$")?; + let (params, keys) = s.split_once('$')?; + + let ((iterations, salt), (stored_key, server_key)) = + params.split_once(':').zip(keys.split_once(':'))?; + + let secret = ServerSecret { + iterations: iterations.parse().ok()?, + salt_base64: salt.to_owned(), + stored_key: base64_decode_array(stored_key)?.into(), + server_key: base64_decode_array(server_key)?.into(), + }; + + Some(secret) + } + + /// To avoid revealing information to an attacker, we use a + /// mocked server secret even if the user doesn't exist. + /// See `auth-scram.c : mock_scram_secret` for details. + pub fn mock(user: &str, nonce: &[u8; 32]) -> Self { + // Refer to `auth-scram.c : scram_mock_salt`. + let mocked_salt = super::sha256([user.as_bytes(), nonce]); + + Self { + iterations: 4096, + salt_base64: base64::encode(&mocked_salt), + stored_key: ScramKey::default(), + server_key: ScramKey::default(), + } + } + + /// Build a new server secret from the prerequisites. + /// XXX: We only use this function in tests. 
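For reference, the stored secret has the form `SCRAM-SHA-256$<iterations>:<salt>$<StoredKey>:<ServerKey>`, and the hand-rolled HMAC loop in `password.rs` is plain PBKDF2-HMAC-SHA-256, so the whole secret can be rebuilt with Python's standard library. A sketch (helper name is illustrative) that should agree with the `build_scram_secret` test below:

import base64, hashlib, hmac

def build_server_secret(password: bytes, salt: bytes, iterations: int = 4096) -> str:
    """Serialize the secret in the form that ServerSecret::parse reads back."""
    def b64(data: bytes) -> str:
        return base64.b64encode(data).decode()
    # password.rs performs this PBKDF2 loop by hand with hmac_sha256
    salted = hashlib.pbkdf2_hmac("sha256", password, salt, iterations)
    client_key = hmac.new(salted, b"Client Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    server_key = hmac.new(salted, b"Server Key", hashlib.sha256).digest()
    return f"SCRAM-SHA-256${iterations}:{b64(salt)}${b64(stored_key)}:{b64(server_key)}"

print(build_server_secret(b"password", b"salt"))  # keys should match build_scram_secret below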
+ #[cfg(test)] + pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option { + // TODO: implement proper password normalization required by the RFC + if !password.is_ascii() { + return None; + } + + let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations); + + Some(Self { + iterations, + salt_base64: base64::encode(&salt), + stored_key: password.client_key().sha256(), + server_key: password.server_key(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_scram_secret() { + let iterations = 4096; + let salt = "+/tQQax7twvwTj64mjBsxQ=="; + let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns="; + let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI="; + + let secret = format!( + "SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}", + iterations = iterations, + salt = salt, + stored_key = stored_key, + server_key = server_key, + ); + + let parsed = ServerSecret::parse(&secret).unwrap(); + assert_eq!(parsed.iterations, iterations); + assert_eq!(parsed.salt_base64, salt); + + assert_eq!(base64::encode(parsed.stored_key), stored_key); + assert_eq!(base64::encode(parsed.server_key), server_key); + } + + #[test] + fn build_scram_secret() { + let salt = b"salt"; + let secret = ServerSecret::build("password", salt, 4096).unwrap(); + assert_eq!(secret.iterations, 4096); + assert_eq!(secret.salt_base64, base64::encode(salt)); + assert_eq!( + base64::encode(secret.stored_key.as_ref()), + "lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ=" + ); + assert_eq!( + base64::encode(secret.server_key.as_ref()), + "ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw=" + ); + } +} diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs new file mode 100644 index 0000000000..1c2811d757 --- /dev/null +++ b/proxy/src/scram/signature.rs @@ -0,0 +1,66 @@ +//! Tools for client/server signature management. + +use super::key::{ScramKey, SCRAM_KEY_LEN}; + +/// A collection of message parts needed to derive the client's signature. +#[derive(Debug)] +pub struct SignatureBuilder<'a> { + pub client_first_message_bare: &'a str, + pub server_first_message: &'a str, + pub client_final_message_without_proof: &'a str, +} + +impl SignatureBuilder<'_> { + pub fn build(&self, key: &ScramKey) -> Signature { + let parts = [ + self.client_first_message_bare.as_bytes(), + b",", + self.server_first_message.as_bytes(), + b",", + self.client_final_message_without_proof.as_bytes(), + ]; + + super::hmac_sha256(key.as_ref(), parts).into() + } +} + +/// A computed value which, when xored with `ClientProof`, +/// produces `ClientKey` that we need for authentication. +#[derive(Debug)] +#[repr(transparent)] +pub struct Signature { + bytes: [u8; SCRAM_KEY_LEN], +} + +impl Signature { + /// Derive `ClientKey` from client's signature and proof. + pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + // This is how the proof is calculated: + // + // 1. sha256(ClientKey) -> StoredKey + // 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature + // 3. ClientKey ^ ClientSignature -> ClientProof + // + // Step 3 implies that we can restore ClientKey from the proof + // by xoring the latter with the ClientSignature. Afterwards we + // can check that the presumed ClientKey meets our expectations. 
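Spelled out in Python, the whole check reads as follows, where `auth_message` is `client-first-message-bare || "," || server-first-message || "," || client-final-message-without-proof` (helper name is illustrative):

import hashlib, hmac

def verify_client_proof(stored_key: bytes, auth_message: bytes, proof: bytes) -> bool:
    """Recover ClientKey from the proof and compare its hash against StoredKey."""
    client_signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
    client_key = bytes(p ^ s for p, s in zip(proof, client_signature))
    return hashlib.sha256(client_key).digest() == stored_key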
+ let mut signature = self.bytes; + for (i, x) in proof.iter().enumerate() { + signature[i] ^= x; + } + + signature.into() + } +} + +impl From<[u8; SCRAM_KEY_LEN]> for Signature { + fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self { + Self { bytes } + } +} + +impl AsRef<[u8]> for Signature { + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} diff --git a/test_runner/batch_others/test_createuser.py b/test_runner/batch_others/test_createuser.py index efb2af3f07..f4bbbc8a7a 100644 --- a/test_runner/batch_others/test_createuser.py +++ b/test_runner/batch_others/test_createuser.py @@ -28,4 +28,4 @@ def test_createuser(zenith_simple_env: ZenithEnv): pg2 = env.postgres.create_start('test_createuser2') # Test that you can connect to new branch as a new user - assert pg2.safe_psql('select current_user', username='testuser') == [('testuser', )] + assert pg2.safe_psql('select current_user', user='testuser') == [('testuser', )] diff --git a/test_runner/batch_others/test_parallel_copy.py b/test_runner/batch_others/test_parallel_copy.py index 4b7cc58d42..a44acecf21 100644 --- a/test_runner/batch_others/test_parallel_copy.py +++ b/test_runner/batch_others/test_parallel_copy.py @@ -19,6 +19,11 @@ async def copy_test_data_to_table(pg: Postgres, worker_id: int, table_name: str) copy_input = repeat_bytes(buf.read(), 5000) pg_conn = await pg.connect_async() + + # PgProtocol.connect_async sets statement_timeout to 2 minutes. + # That's not enough for this test, on a slow system in debug mode. + await pg_conn.execute("SET statement_timeout='300s'") + await pg_conn.copy_to_table(table_name, source=copy_input) diff --git a/test_runner/batch_others/test_pgbench.py b/test_runner/batch_others/test_pgbench.py deleted file mode 100644 index 09713023bc..0000000000 --- a/test_runner/batch_others/test_pgbench.py +++ /dev/null @@ -1,14 +0,0 @@ -from fixtures.zenith_fixtures import ZenithEnv -from fixtures.log_helper import log - - -def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin): - env = zenith_simple_env - env.zenith_cli.create_branch("test_pgbench", "empty") - pg = env.postgres.create_start('test_pgbench') - log.info("postgres is running on 'test_pgbench' branch") - - connstr = pg.connstr() - - pg_bin.run_capture(['pgbench', '-i', connstr]) - pg_bin.run_capture(['pgbench'] + '-c 10 -T 5 -P 1 -M prepared'.split() + [connstr]) diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index d2039f9758..a6f828f829 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -5,11 +5,14 @@ def test_proxy_select_1(static_proxy): static_proxy.safe_psql("select 1;") -@pytest.mark.xfail # Proxy eats the extra connection options +# Pass extra options to the server. +# +# Currently, proxy eats the extra connection options, so this fails. 
+# See https://github.com/neondatabase/neon/issues/1287 +@pytest.mark.xfail def test_proxy_options(static_proxy): - schema_name = "tmp_schema_1" - with static_proxy.connect(schema=schema_name) as conn: + with static_proxy.connect(options="-cproxytest.option=value") as conn: with conn.cursor() as cur: - cur.execute("SHOW search_path;") - search_path = cur.fetchall()[0][0] - assert schema_name == search_path + cur.execute("SHOW proxytest.option;") + value = cur.fetchall()[0][0] + assert value == 'value' diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py index 8f87ff041f..dffcd7cc61 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -379,7 +379,7 @@ class ProposerPostgres(PgProtocol): tenant_id: uuid.UUID, listen_addr: str, port: int): - super().__init__(host=listen_addr, port=port, username='zenith_admin') + super().__init__(host=listen_addr, port=port, user='zenith_admin', dbname='postgres') self.pgdata_dir: str = pgdata_dir self.pg_bin: PgBin = pg_bin diff --git a/test_runner/batch_pg_regress/test_isolation.py b/test_runner/batch_pg_regress/test_isolation.py index ddafc3815b..cde56d9b88 100644 --- a/test_runner/batch_pg_regress/test_isolation.py +++ b/test_runner/batch_pg_regress/test_isolation.py @@ -35,9 +35,9 @@ def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_pg_regress.py b/test_runner/batch_pg_regress/test_pg_regress.py index 5199f65216..07d2574f4a 100644 --- a/test_runner/batch_pg_regress/test_pg_regress.py +++ b/test_runner/batch_pg_regress/test_pg_regress.py @@ -35,9 +35,9 @@ def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, ] env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/batch_pg_regress/test_zenith_regress.py b/test_runner/batch_pg_regress/test_zenith_regress.py index 31d5b07093..2b57137d16 100644 --- a/test_runner/batch_pg_regress/test_zenith_regress.py +++ b/test_runner/batch_pg_regress/test_zenith_regress.py @@ -40,9 +40,9 @@ def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, c log.info(pg_regress_command) env_vars = { - 'PGPORT': str(pg.port), - 'PGUSER': pg.username, - 'PGHOST': pg.host, + 'PGPORT': str(pg.default_options['port']), + 'PGUSER': pg.default_options['user'], + 'PGHOST': pg.default_options['host'], } # Run the command. diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 11d37eb8f9..8b9aabfe43 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -17,7 +17,7 @@ import warnings from contextlib import contextmanager # Type-related stuff -from typing import Iterator +from typing import Iterator, Optional """ This file contains fixtures for micro-benchmarks. @@ -51,17 +51,12 @@ in the test initialization, or measure disk usage after the test query. 
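The pgbench parsing changes below cope with the output format change in pgbench 14. A condensed sketch of the tps extraction, with an explicit `in line` check on both patterns (helper name is illustrative):

def parse_tps(stdout: str) -> float:
    """Pick the 'tps = ...' line that excludes connection-establishment time."""
    for line in stdout.splitlines():
        if line.startswith("tps = ") and (
                "(excluding connections establishing)" in line       # pgbench <= 13
                or "(without initial connection time)" in line):     # pgbench >= 14
            return float(line.split()[2])
    raise ValueError("no tps line found in pgbench output")

assert parse_tps("tps = 50.264435 (excluding connections establishing)") == 50.264435
assert parse_tps("tps = 309.281539 (without initial connection time)") == 309.281539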
@dataclasses.dataclass class PgBenchRunResult: - scale: int number_of_clients: int number_of_threads: int number_of_transactions_actually_processed: int latency_average: float - latency_stddev: float - tps_including_connection_time: float - tps_excluding_connection_time: float - init_duration: float - init_start_timestamp: int - init_end_timestamp: int + latency_stddev: Optional[float] + tps: float run_duration: float run_start_timestamp: int run_end_timestamp: int @@ -69,56 +64,67 @@ class PgBenchRunResult: # TODO progress @classmethod - def parse_from_output( + def parse_from_stdout( cls, - out: 'subprocess.CompletedProcess[str]', - init_duration: float, - init_start_timestamp: int, - init_end_timestamp: int, + stdout: str, run_duration: float, run_start_timestamp: int, run_end_timestamp: int, ): - stdout_lines = out.stdout.splitlines() + stdout_lines = stdout.splitlines() + + latency_stddev = None + # we know significant parts of these values from test input # but to be precise take them from output - # scaling factor: 5 - assert "scaling factor" in stdout_lines[1] - scale = int(stdout_lines[1].split()[-1]) - # number of clients: 1 - assert "number of clients" in stdout_lines[3] - number_of_clients = int(stdout_lines[3].split()[-1]) - # number of threads: 1 - assert "number of threads" in stdout_lines[4] - number_of_threads = int(stdout_lines[4].split()[-1]) - # number of transactions actually processed: 1000/1000 - assert "number of transactions actually processed" in stdout_lines[6] - number_of_transactions_actually_processed = int(stdout_lines[6].split("/")[1]) - # latency average = 19.894 ms - assert "latency average" in stdout_lines[7] - latency_average = stdout_lines[7].split()[-2] - # latency stddev = 3.387 ms - assert "latency stddev" in stdout_lines[8] - latency_stddev = stdout_lines[8].split()[-2] - # tps = 50.219689 (including connections establishing) - assert "(including connections establishing)" in stdout_lines[9] - tps_including_connection_time = stdout_lines[9].split()[2] - # tps = 50.264435 (excluding connections establishing) - assert "(excluding connections establishing)" in stdout_lines[10] - tps_excluding_connection_time = stdout_lines[10].split()[2] + for line in stdout.splitlines(): + # scaling factor: 5 + if line.startswith("scaling factor:"): + scale = int(line.split()[-1]) + # number of clients: 1 + if line.startswith("number of clients: "): + number_of_clients = int(line.split()[-1]) + # number of threads: 1 + if line.startswith("number of threads: "): + number_of_threads = int(line.split()[-1]) + # number of transactions actually processed: 1000/1000 + # OR + # number of transactions actually processed: 1000 + if line.startswith("number of transactions actually processed"): + if "/" in line: + number_of_transactions_actually_processed = int(line.split("/")[1]) + else: + number_of_transactions_actually_processed = int(line.split()[-1]) + # latency average = 19.894 ms + if line.startswith("latency average"): + latency_average = float(line.split()[-2]) + # latency stddev = 3.387 ms + # (only printed with some options) + if line.startswith("latency stddev"): + latency_stddev = float(line.split()[-2]) + + # Get the TPS without initial connection time. 
The format + # of the tps lines changed in pgbench v14, but we accept + # either format: + # + # pgbench v13 and below: + # tps = 50.219689 (including connections establishing) + # tps = 50.264435 (excluding connections establishing) + # + # pgbench v14: + # initial connection time = 3.858 ms + # tps = 309.281539 (without initial connection time) + if (line.startswith("tps = ") and ("(excluding connections establishing)" in line + or "(without initial connection time)")): + tps = float(line.split()[2]) return cls( - scale=scale, number_of_clients=number_of_clients, number_of_threads=number_of_threads, number_of_transactions_actually_processed=number_of_transactions_actually_processed, - latency_average=float(latency_average), - latency_stddev=float(latency_stddev), - tps_including_connection_time=float(tps_including_connection_time), - tps_excluding_connection_time=float(tps_excluding_connection_time), - init_duration=init_duration, - init_start_timestamp=init_start_timestamp, - init_end_timestamp=init_end_timestamp, + latency_average=latency_average, + latency_stddev=latency_stddev, + tps=tps, run_duration=run_duration, run_start_timestamp=run_start_timestamp, run_end_timestamp=run_end_timestamp, @@ -187,60 +193,41 @@ class ZenithBenchmarker: report=MetricReport.LOWER_IS_BETTER, ) - def record_pg_bench_result(self, pg_bench_result: PgBenchRunResult): - self.record("scale", pg_bench_result.scale, '', MetricReport.TEST_PARAM) - self.record("number_of_clients", + def record_pg_bench_result(self, prefix: str, pg_bench_result: PgBenchRunResult): + self.record(f"{prefix}.number_of_clients", pg_bench_result.number_of_clients, '', MetricReport.TEST_PARAM) - self.record("number_of_threads", + self.record(f"{prefix}.number_of_threads", pg_bench_result.number_of_threads, '', MetricReport.TEST_PARAM) self.record( - "number_of_transactions_actually_processed", + f"{prefix}.number_of_transactions_actually_processed", pg_bench_result.number_of_transactions_actually_processed, '', # thats because this is predefined by test matrix and doesnt change across runs report=MetricReport.TEST_PARAM, ) - self.record("latency_average", + self.record(f"{prefix}.latency_average", pg_bench_result.latency_average, unit="ms", report=MetricReport.LOWER_IS_BETTER) - self.record("latency_stddev", - pg_bench_result.latency_stddev, - unit="ms", - report=MetricReport.LOWER_IS_BETTER) - self.record("tps_including_connection_time", - pg_bench_result.tps_including_connection_time, - '', - report=MetricReport.HIGHER_IS_BETTER) - self.record("tps_excluding_connection_time", - pg_bench_result.tps_excluding_connection_time, - '', - report=MetricReport.HIGHER_IS_BETTER) - self.record("init_duration", - pg_bench_result.init_duration, - unit="s", - report=MetricReport.LOWER_IS_BETTER) - self.record("init_start_timestamp", - pg_bench_result.init_start_timestamp, - '', - MetricReport.TEST_PARAM) - self.record("init_end_timestamp", - pg_bench_result.init_end_timestamp, - '', - MetricReport.TEST_PARAM) - self.record("run_duration", + if pg_bench_result.latency_stddev is not None: + self.record(f"{prefix}.latency_stddev", + pg_bench_result.latency_stddev, + unit="ms", + report=MetricReport.LOWER_IS_BETTER) + self.record(f"{prefix}.tps", pg_bench_result.tps, '', report=MetricReport.HIGHER_IS_BETTER) + self.record(f"{prefix}.run_duration", pg_bench_result.run_duration, unit="s", report=MetricReport.LOWER_IS_BETTER) - self.record("run_start_timestamp", + self.record(f"{prefix}.run_start_timestamp", pg_bench_result.run_start_timestamp, '', 
MetricReport.TEST_PARAM) - self.record("run_end_timestamp", + self.record(f"{prefix}.run_end_timestamp", pg_bench_result.run_end_timestamp, '', MetricReport.TEST_PARAM) @@ -259,10 +246,18 @@ class ZenithBenchmarker: """ Fetch the "cumulative # of bytes written" metric from the pageserver """ - # Fetch all the exposed prometheus metrics from page server - all_metrics = pageserver.http_client().get_metrics() - # Use a regular expression to extract the one we're interested in - # + metric_name = r'pageserver_disk_io_bytes{io_operation="write"}' + return self.get_int_counter_value(pageserver, metric_name) + + def get_peak_mem(self, pageserver) -> int: + """ + Fetch the "maxrss" metric from the pageserver + """ + metric_name = r'pageserver_maxrss_kb' + return self.get_int_counter_value(pageserver, metric_name) + + def get_int_counter_value(self, pageserver, metric_name) -> int: + """Fetch the value of given int counter from pageserver metrics.""" # TODO: If we start to collect more of the prometheus metrics in the # performance test suite like this, we should refactor this to load and # parse all the metrics into a more convenient structure in one go. @@ -270,20 +265,8 @@ class ZenithBenchmarker: # The metric should be an integer, as it's a number of bytes. But in general # all prometheus metrics are floats. So to be pedantic, read it as a float # and round to integer. - matches = re.search(r'^pageserver_disk_io_bytes{io_operation="write"} (\S+)$', - all_metrics, - re.MULTILINE) - assert matches - return int(round(float(matches.group(1)))) - - def get_peak_mem(self, pageserver) -> int: - """ - Fetch the "maxrss" metric from the pageserver - """ - # Fetch all the exposed prometheus metrics from page server all_metrics = pageserver.http_client().get_metrics() - # See comment in get_io_writes() - matches = re.search(r'^pageserver_maxrss_kb (\S+)$', all_metrics, re.MULTILINE) + matches = re.search(fr'^{metric_name} (\S+)$', all_metrics, re.MULTILINE) assert matches return int(round(float(matches.group(1)))) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 750b02c894..93912d2da7 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -2,7 +2,7 @@ import pytest from contextlib import contextmanager from abc import ABC, abstractmethod -from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, ZenithEnv +from fixtures.zenith_fixtures import PgBin, PgProtocol, VanillaPostgres, RemotePostgres, ZenithEnv from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker # Type-related stuff @@ -87,6 +87,9 @@ class ZenithCompare(PgCompare): def flush(self): self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0") + def compact(self): + self.pscur.execute(f"compact {self.env.initial_tenant.hex} {self.timeline}") + def report_peak_memory_use(self) -> None: self.zenbenchmark.record("peak_mem", self.zenbenchmark.get_peak_mem(self.env.pageserver) / 1024, @@ -102,6 +105,19 @@ class ZenithCompare(PgCompare): 'MB', report=MetricReport.LOWER_IS_BETTER) + total_files = self.zenbenchmark.get_int_counter_value( + self.env.pageserver, "pageserver_num_persistent_files_created") + total_bytes = self.zenbenchmark.get_int_counter_value( + self.env.pageserver, "pageserver_persistent_bytes_written") + self.zenbenchmark.record("data_uploaded", + total_bytes / (1024 * 1024), + "MB", + report=MetricReport.LOWER_IS_BETTER) + self.zenbenchmark.record("num_files_uploaded", + total_files, + 
"", + report=MetricReport.LOWER_IS_BETTER) + def record_pageserver_writes(self, out_name): return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) @@ -159,6 +175,48 @@ class VanillaCompare(PgCompare): return self.zenbenchmark.record_duration(out_name) +class RemoteCompare(PgCompare): + """PgCompare interface for a remote postgres instance.""" + def __init__(self, zenbenchmark, remote_pg: RemotePostgres): + self._pg = remote_pg + self._zenbenchmark = zenbenchmark + + # Long-lived cursor, useful for flushing + self.conn = self.pg.connect() + self.cur = self.conn.cursor() + + @property + def pg(self): + return self._pg + + @property + def zenbenchmark(self): + return self._zenbenchmark + + @property + def pg_bin(self): + return self._pg.pg_bin + + def flush(self): + # TODO: flush the remote pageserver + pass + + def report_peak_memory_use(self) -> None: + # TODO: get memory usage from remote pageserver + pass + + def report_size(self) -> None: + # TODO: get storage size from remote pageserver + pass + + @contextmanager + def record_pageserver_writes(self, out_name): + yield # Do nothing + + def record_duration(self, out_name): + return self.zenbenchmark.record_duration(out_name) + + @pytest.fixture(scope='function') def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCompare: branch_name = request.node.name @@ -170,6 +228,11 @@ def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare: return VanillaCompare(zenbenchmark, vanilla_pg) +@pytest.fixture(scope='function') +def remote_compare(zenbenchmark, remote_pg) -> RemoteCompare: + return RemoteCompare(zenbenchmark, remote_pg) + + @pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"]) def zenith_with_baseline(request) -> PgCompare: """Parameterized fixture that helps compare zenith against vanilla postgres. diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index ac895ddee7..0040494fa5 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -27,6 +27,7 @@ from dataclasses import dataclass # Type-related stuff from psycopg2.extensions import connection as PgConnection +from psycopg2.extensions import make_dsn, parse_dsn from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing_extensions import Literal @@ -122,6 +123,22 @@ def pytest_configure(config): top_output_dir = os.path.join(base_dir, DEFAULT_OUTPUT_DIR) mkdir_if_needed(top_output_dir) + # Find the postgres installation. + global pg_distrib_dir + env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR') + if env_postgres_bin: + pg_distrib_dir = env_postgres_bin + else: + pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR)) + log.info(f'pg_distrib_dir is {pg_distrib_dir}') + if os.getenv("REMOTE_ENV"): + # When testing against a remote server, we only need the client binary. 
+ if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/psql')): + raise Exception('psql not found at "{}"'.format(pg_distrib_dir)) + else: + if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')): + raise Exception('postgres not found at "{}"'.format(pg_distrib_dir)) + if os.getenv("REMOTE_ENV"): # we are in remote env and do not have zenith binaries locally # this is the case for benchmarks run on self-hosted runner @@ -137,17 +154,6 @@ def pytest_configure(config): if not os.path.exists(os.path.join(zenith_binpath, 'pageserver')): raise Exception('zenith binaries not found at "{}"'.format(zenith_binpath)) - # Find the postgres installation. - global pg_distrib_dir - env_postgres_bin = os.environ.get('POSTGRES_DISTRIB_DIR') - if env_postgres_bin: - pg_distrib_dir = env_postgres_bin - else: - pg_distrib_dir = os.path.normpath(os.path.join(base_dir, DEFAULT_POSTGRES_DIR)) - log.info(f'pg_distrib_dir is {pg_distrib_dir}') - if not os.path.exists(os.path.join(pg_distrib_dir, 'bin/postgres')): - raise Exception('postgres not found at "{}"'.format(pg_distrib_dir)) - def zenfixture(func: Fn) -> Fn: """ @@ -238,98 +244,69 @@ def port_distributor(worker_base_port): class PgProtocol: """ Reusable connection logic """ - def __init__(self, - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - dbname: Optional[str] = None, - schema: Optional[str] = None): - self.host = host - self.port = port - self.username = username - self.password = password - self.dbname = dbname - self.schema = schema + def __init__(self, **kwargs): + self.default_options = kwargs - def connstr(self, - *, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - statement_timeout_ms: Optional[int] = None) -> str: + def connstr(self, **kwargs) -> str: """ Build a libpq connection string for the Postgres instance. """ + return str(make_dsn(**self.conn_options(**kwargs))) - username = username or self.username - password = password or self.password - dbname = dbname or self.dbname or "postgres" - schema = schema or self.schema - res = f'host={self.host} port={self.port} dbname={dbname}' + def conn_options(self, **kwargs): + conn_options = self.default_options.copy() + if 'dsn' in kwargs: + conn_options.update(parse_dsn(kwargs['dsn'])) + conn_options.update(kwargs) - if username: - res = f'{res} user={username}' - - if password: - res = f'{res} password={password}' - - if schema: - res = f"{res} options='-c search_path={schema}'" - - if statement_timeout_ms: - res = f"{res} options='-c statement_timeout={statement_timeout_ms}'" - - return res + # Individual statement timeout in seconds. 2 minutes should be + # enough for our tests, but if you need a longer, you can + # change it by calling "SET statement_timeout" after + # connecting. 
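connect_async() further below has to translate these psycopg2-style keyword arguments for asyncpg: `dbname` becomes `database`, and `-ckey=value` pairs from `options` move into `server_settings`. A condensed sketch of that conversion (helper name is illustrative, not part of the patch):

import re

def to_asyncpg_options(conn_options: dict) -> dict:
    """Convert psycopg2-style connection options into asyncpg.connect() kwargs."""
    opts = dict(conn_options)
    if "dbname" in opts:                       # asyncpg calls it 'database'
        opts["database"] = opts.pop("dbname")
    if "options" in opts:                      # '-ckey=value' -> server_settings
        settings = opts.setdefault("server_settings", {})
        for key, val in re.findall(r"-c(\w+)=(\w+)", opts.pop("options")):
            settings[key] = val
    return opts

opts = to_asyncpg_options({"dbname": "postgres", "options": "-cstatement_timeout=120s"})
assert opts["database"] == "postgres"
assert opts["server_settings"] == {"statement_timeout": "120s"}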
+ if 'options' in conn_options: + conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options'] + else: + conn_options['options'] = "-cstatement_timeout=120s" + return conn_options # autocommit=True here by default because that's what we need most of the time - def connect( - self, - *, - autocommit=True, - dbname: Optional[str] = None, - schema: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - # individual statement timeout in seconds, 2 minutes should be enough for our tests - statement_timeout: Optional[int] = 120 - ) -> PgConnection: + def connect(self, autocommit=True, **kwargs) -> PgConnection: """ Connect to the node. Returns psycopg2's connection object. This method passes all extra params to connstr. """ + conn = psycopg2.connect(**self.conn_options(**kwargs)) - conn = psycopg2.connect( - self.connstr(dbname=dbname, - schema=schema, - username=username, - password=password, - statement_timeout_ms=statement_timeout * - 1000 if statement_timeout else None)) # WARNING: this setting affects *all* tests! conn.autocommit = autocommit return conn - async def connect_async(self, - *, - dbname: str = 'postgres', - username: Optional[str] = None, - password: Optional[str] = None) -> asyncpg.Connection: + async def connect_async(self, **kwargs) -> asyncpg.Connection: """ Connect to the node from async python. Returns asyncpg's connection object. """ - conn = await asyncpg.connect( - host=self.host, - port=self.port, - database=dbname, - user=username or self.username, - password=password, - ) - return conn + # asyncpg takes slightly different options than psycopg2. Try + # to convert the defaults from the psycopg2 format. + + # The psycopg2 option 'dbname' is called 'database' is asyncpg + conn_options = self.conn_options(**kwargs) + if 'dbname' in conn_options: + conn_options['database'] = conn_options.pop('dbname') + + # Convert options='-c=' to server_settings + if 'options' in conn_options: + options = conn_options.pop('options') + for match in re.finditer('-c(\w*)=(\w*)', options): + key = match.group(1) + val = match.group(2) + if 'server_options' in conn_options: + conn_options['server_settings'].update({key: val}) + else: + conn_options['server_settings'] = {key: val} + return await asyncpg.connect(**conn_options) def safe_psql(self, query: str, **kwargs: Any) -> List[Any]: """ @@ -1149,10 +1126,10 @@ class ZenithPageserver(PgProtocol): port: PageserverPort, remote_storage: Optional[RemoteStorage] = None, config_override: Optional[str] = None): - super().__init__(host='localhost', port=port.pg, username='zenith_admin') + super().__init__(host='localhost', port=port.pg, user='zenith_admin') self.env = env self.running = False - self.service_port = port # do not shadow PgProtocol.port which is just int + self.service_port = port self.remote_storage = remote_storage self.config_override = config_override @@ -1314,7 +1291,7 @@ def psbench_bin(test_output_dir): class VanillaPostgres(PgProtocol): def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int): - super().__init__(host='localhost', port=port) + super().__init__(host='localhost', port=port, dbname='postgres') self.pgdatadir = pgdatadir self.pg_bin = pg_bin self.running = False @@ -1356,10 +1333,57 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]: yield vanilla_pg +class RemotePostgres(PgProtocol): + def __init__(self, pg_bin: PgBin, remote_connstr: str): + super().__init__(**parse_dsn(remote_connstr)) + self.pg_bin = pg_bin + # The remote server is 
assumed to be running already + self.running = True + + def configure(self, options: List[str]): + raise Exception('cannot change configuration of remote Posgres instance') + + def start(self): + raise Exception('cannot start a remote Postgres instance') + + def stop(self): + raise Exception('cannot stop a remote Postgres instance') + + def get_subdir_size(self, subdir) -> int: + # TODO: Could use the server's Generic File Acccess functions if superuser. + # See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE + raise Exception('cannot get size of a Postgres instance') + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + # do nothing + pass + + +@pytest.fixture(scope='function') +def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]: + pg_bin = PgBin(test_output_dir) + + connstr = os.getenv("BENCHMARK_CONNSTR") + if connstr is None: + raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable") + + with RemotePostgres(pg_bin, connstr) as remote_pg: + yield remote_pg + + class ZenithProxy(PgProtocol): def __init__(self, port: int): - super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port) + super().__init__(host="127.0.0.1", + user="pytest", + password="pytest", + port=port, + dbname='postgres') self.http_port = 7001 + self.host = "127.0.0.1" + self.port = port self._popen: Optional[subprocess.Popen[bytes]] = None def start_static(self, addr="127.0.0.1:5432") -> None: @@ -1403,13 +1427,13 @@ def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]: class Postgres(PgProtocol): """ An object representing a running postgres daemon. """ def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int): - super().__init__(host='localhost', port=port, username='zenith_admin') - + super().__init__(host='localhost', port=port, user='zenith_admin', dbname='postgres') self.env = env self.running = False self.node_name: Optional[str] = None # dubious, see asserts below self.pgdata_dir: Optional[str] = None # Path to computenode PGDATA self.tenant_id = tenant_id + self.port = port # path to conf is /pgdatadirs/tenants///postgresql.conf def create( diff --git a/test_runner/performance/test_perf_pgbench.py b/test_runner/performance/test_perf_pgbench.py index 5ffce3c0be..d2de76913a 100644 --- a/test_runner/performance/test_perf_pgbench.py +++ b/test_runner/performance/test_perf_pgbench.py @@ -2,29 +2,113 @@ from contextlib import closing from fixtures.zenith_fixtures import PgBin, VanillaPostgres, ZenithEnv from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare -from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker +from fixtures.benchmark_fixture import PgBenchRunResult, MetricReport, ZenithBenchmarker from fixtures.log_helper import log +from pathlib import Path + +import pytest +from datetime import datetime +import calendar +import os +import timeit + + +def utc_now_timestamp() -> int: + return calendar.timegm(datetime.utcnow().utctimetuple()) + + +def init_pgbench(env: PgCompare, cmdline): + # calculate timestamps and durations separately + # timestamp is intended to be used for linking to grafana and logs + # duration is actually a metric and uses float instead of int for timestamp + init_start_timestamp = utc_now_timestamp() + t0 = timeit.default_timer() + with env.record_pageserver_writes('init.pageserver_writes'): + env.pg_bin.run_capture(cmdline) + env.flush() + init_duration = timeit.default_timer() - t0 + init_end_timestamp = 
utc_now_timestamp() + + env.zenbenchmark.record("init.duration", + init_duration, + unit="s", + report=MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("init.start_timestamp", + init_start_timestamp, + '', + MetricReport.TEST_PARAM) + env.zenbenchmark.record("init.end_timestamp", init_end_timestamp, '', MetricReport.TEST_PARAM) + + +def run_pgbench(env: PgCompare, prefix: str, cmdline): + with env.record_pageserver_writes(f'{prefix}.pageserver_writes'): + run_start_timestamp = utc_now_timestamp() + t0 = timeit.default_timer() + out = env.pg_bin.run_capture(cmdline, ) + run_duration = timeit.default_timer() - t0 + run_end_timestamp = utc_now_timestamp() + env.flush() + + stdout = Path(f"{out}.stdout").read_text() + + res = PgBenchRunResult.parse_from_stdout( + stdout=stdout, + run_duration=run_duration, + run_start_timestamp=run_start_timestamp, + run_end_timestamp=run_end_timestamp, + ) + env.zenbenchmark.record_pg_bench_result(prefix, res) + # -# Run a very short pgbench test. +# Initialize a pgbench database, and run pgbench against it. # -# Collects three metrics: +# This makes runs two different pgbench workloads against the same +# initialized database, and 'duration' is the time of each run. So +# the total runtime is 2 * duration, plus time needed to initialize +# the test database. # -# 1. Time to initialize the pgbench database (pgbench -s5 -i) -# 2. Time to run 5000 pgbench transactions -# 3. Disk space used -# -def test_pgbench(zenith_with_baseline: PgCompare): - env = zenith_with_baseline +# Currently, the # of connections is hardcoded at 4 +def run_test_pgbench(env: PgCompare, scale: int, duration: int): - with env.record_pageserver_writes('pageserver_writes'): - with env.record_duration('init'): - env.pg_bin.run_capture(['pgbench', '-s5', '-i', env.pg.connstr()]) - env.flush() + # Record the scale and initialize + env.zenbenchmark.record("scale", scale, '', MetricReport.TEST_PARAM) + init_pgbench(env, ['pgbench', f'-s{scale}', '-i', env.pg.connstr()]) - with env.record_duration('5000_xacts'): - env.pg_bin.run_capture(['pgbench', '-c1', '-t5000', env.pg.connstr()]) - env.flush() + # Run simple-update workload + run_pgbench(env, + "simple-update", + ['pgbench', '-n', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()]) + + # Run SELECT workload + run_pgbench(env, + "select-only", + ['pgbench', '-S', '-c4', f'-T{duration}', '-P2', '-Mprepared', env.pg.connstr()]) env.report_size() + + +def get_durations_matrix(): + durations = os.getenv("TEST_PG_BENCH_DURATIONS_MATRIX", default="45") + return list(map(int, durations.split(","))) + + +def get_scales_matrix(): + scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX", default="10") + return list(map(int, scales.split(","))) + + +# Run the pgbench tests against vanilla Postgres and zenith +@pytest.mark.parametrize("scale", get_scales_matrix()) +@pytest.mark.parametrize("duration", get_durations_matrix()) +def test_pgbench(zenith_with_baseline: PgCompare, scale: int, duration: int): + run_test_pgbench(zenith_with_baseline, scale, duration) + + +# Run the pgbench tests against an existing Postgres cluster +@pytest.mark.parametrize("scale", get_scales_matrix()) +@pytest.mark.parametrize("duration", get_durations_matrix()) +@pytest.mark.remote_cluster +def test_pgbench_remote(remote_compare: PgCompare, scale: int, duration: int): + run_test_pgbench(remote_compare, scale, duration) diff --git a/test_runner/performance/test_perf_pgbench_remote.py b/test_runner/performance/test_perf_pgbench_remote.py deleted file mode 100644 
index 28472a16c8..0000000000 --- a/test_runner/performance/test_perf_pgbench_remote.py +++ /dev/null @@ -1,124 +0,0 @@ -import dataclasses -import os -import subprocess -from typing import List -from fixtures.benchmark_fixture import PgBenchRunResult, ZenithBenchmarker -import pytest -from datetime import datetime -import calendar -import timeit -import os - - -def utc_now_timestamp() -> int: - return calendar.timegm(datetime.utcnow().utctimetuple()) - - -@dataclasses.dataclass -class PgBenchRunner: - connstr: str - scale: int - transactions: int - pgbench_bin_path: str = "pgbench" - - def invoke(self, args: List[str]) -> 'subprocess.CompletedProcess[str]': - res = subprocess.run([self.pgbench_bin_path, *args], text=True, capture_output=True) - - if res.returncode != 0: - raise RuntimeError(f"pgbench failed. stdout: {res.stdout} stderr: {res.stderr}") - return res - - def init(self, vacuum: bool = True) -> 'subprocess.CompletedProcess[str]': - args = [] - if not vacuum: - args.append("--no-vacuum") - args.extend([f"--scale={self.scale}", "--initialize", self.connstr]) - return self.invoke(args) - - def run(self, jobs: int = 1, clients: int = 1): - return self.invoke([ - f"--transactions={self.transactions}", - f"--jobs={jobs}", - f"--client={clients}", - "--progress=2", # print progress every two seconds - self.connstr, - ]) - - -@pytest.fixture -def connstr(): - res = os.getenv("BENCHMARK_CONNSTR") - if res is None: - raise ValueError("no connstr provided, use BENCHMARK_CONNSTR environment variable") - return res - - -def get_transactions_matrix(): - transactions = os.getenv("TEST_PG_BENCH_TRANSACTIONS_MATRIX") - if transactions is None: - return [10**4, 10**5] - return list(map(int, transactions.split(","))) - - -def get_scales_matrix(): - scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX") - if scales is None: - return [10, 20] - return list(map(int, scales.split(","))) - - -@pytest.mark.parametrize("scale", get_scales_matrix()) -@pytest.mark.parametrize("transactions", get_transactions_matrix()) -@pytest.mark.remote_cluster -def test_pg_bench_remote_cluster(zenbenchmark: ZenithBenchmarker, - connstr: str, - scale: int, - transactions: int): - """ - The best way is to run same pack of tests both, for local zenith - and against staging, but currently local tests heavily depend on - things available only locally e.g. zenith binaries, pageserver api, etc. - Also separate test allows to run pgbench workload against vanilla postgres - or other systems that support postgres protocol. - - Also now this is more of a liveness test because it stresses pageserver internals, - so we clearly see what goes wrong in more "real" environment. 
- """ - pg_bin = os.getenv("PG_BIN") - if pg_bin is not None: - pgbench_bin_path = os.path.join(pg_bin, "pgbench") - else: - pgbench_bin_path = "pgbench" - - runner = PgBenchRunner( - connstr=connstr, - scale=scale, - transactions=transactions, - pgbench_bin_path=pgbench_bin_path, - ) - # calculate timestamps and durations separately - # timestamp is intended to be used for linking to grafana and logs - # duration is actually a metric and uses float instead of int for timestamp - init_start_timestamp = utc_now_timestamp() - t0 = timeit.default_timer() - runner.init() - init_duration = timeit.default_timer() - t0 - init_end_timestamp = utc_now_timestamp() - - run_start_timestamp = utc_now_timestamp() - t0 = timeit.default_timer() - out = runner.run() # TODO handle failures - run_duration = timeit.default_timer() - t0 - run_end_timestamp = utc_now_timestamp() - - res = PgBenchRunResult.parse_from_output( - out=out, - init_duration=init_duration, - init_start_timestamp=init_start_timestamp, - init_end_timestamp=init_end_timestamp, - run_duration=run_duration, - run_start_timestamp=run_start_timestamp, - run_end_timestamp=run_end_timestamp, - ) - - zenbenchmark.record_pg_bench_result(res) diff --git a/walkeeper/src/control_file.rs b/walkeeper/src/control_file.rs index 8b4e618661..7cc53edeb0 100644 --- a/walkeeper/src/control_file.rs +++ b/walkeeper/src/control_file.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; +use std::ops::Deref; use std::path::{Path, PathBuf}; use tracing::*; @@ -37,8 +38,10 @@ lazy_static! { .expect("Failed to register safekeeper_persist_control_file_seconds histogram vec"); } -pub trait Storage { - /// Persist safekeeper state on disk. +/// Storage should keep actual state inside of it. It should implement Deref +/// trait to access state fields and have persist method for updating that state. +pub trait Storage: Deref { + /// Persist safekeeper state on disk and update internal state. fn persist(&mut self, s: &SafeKeeperState) -> Result<()>; } @@ -48,19 +51,47 @@ pub struct FileStorage { timeline_dir: PathBuf, conf: SafeKeeperConf, persist_control_file_seconds: Histogram, + + /// Last state persisted to disk. + state: SafeKeeperState, } impl FileStorage { - pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage { + pub fn restore_new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> Result { let timeline_dir = conf.timeline_dir(zttid); let tenant_id = zttid.tenant_id.to_string(); let timeline_id = zttid.timeline_id.to_string(); - FileStorage { + + let state = Self::load_control_file_conf(conf, zttid)?; + + Ok(FileStorage { timeline_dir, conf: conf.clone(), persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS .with_label_values(&[&tenant_id, &timeline_id]), - } + state, + }) + } + + pub fn create_new( + zttid: &ZTenantTimelineId, + conf: &SafeKeeperConf, + state: SafeKeeperState, + ) -> Result { + let timeline_dir = conf.timeline_dir(zttid); + let tenant_id = zttid.tenant_id.to_string(); + let timeline_id = zttid.timeline_id.to_string(); + + let mut store = FileStorage { + timeline_dir, + conf: conf.clone(), + persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS + .with_label_values(&[&tenant_id, &timeline_id]), + state: state.clone(), + }; + + store.persist(&state)?; + Ok(store) } // Check the magic/version in the on-disk data and deserialize it, if possible. 
@@ -141,6 +172,14 @@ impl FileStorage { } } +impl Deref for FileStorage { + type Target = SafeKeeperState; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + impl Storage for FileStorage { // persists state durably to underlying storage // for description see https://lwn.net/Articles/457667/ @@ -201,6 +240,9 @@ impl Storage for FileStorage { .and_then(|f| f.sync_all()) .context("failed to sync control file directory")?; } + + // update internal state + self.state = s.clone(); Ok(()) } } @@ -228,7 +270,7 @@ mod test { ) -> Result<(FileStorage, SafeKeeperState)> { fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir"); Ok(( - FileStorage::new(zttid, conf), + FileStorage::restore_new(zttid, conf)?, FileStorage::load_control_file_conf(conf, zttid)?, )) } @@ -239,8 +281,7 @@ mod test { ) -> Result<(FileStorage, SafeKeeperState)> { fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir"); let state = SafeKeeperState::empty(); - let mut storage = FileStorage::new(zttid, conf); - storage.persist(&state)?; + let storage = FileStorage::create_new(zttid, conf, state.clone())?; Ok((storage, state)) } diff --git a/walkeeper/src/safekeeper.rs b/walkeeper/src/safekeeper.rs index 1e23d87b34..22a8481e45 100644 --- a/walkeeper/src/safekeeper.rs +++ b/walkeeper/src/safekeeper.rs @@ -210,6 +210,7 @@ pub struct SafekeeperMemState { pub s3_wal_lsn: Lsn, // TODO: keep only persistent version pub peer_horizon_lsn: Lsn, pub remote_consistent_lsn: Lsn, + pub proposer_uuid: PgUuid, } impl SafeKeeperState { @@ -502,9 +503,8 @@ pub struct SafeKeeper { epoch_start_lsn: Lsn, pub inmem: SafekeeperMemState, // in memory part - pub s: SafeKeeperState, // persistent part + pub state: CTRL, // persistent state storage - pub control_store: CTRL, pub wal_store: WAL, } @@ -516,14 +516,14 @@ where // constructor pub fn new( ztli: ZTimelineId, - control_store: CTRL, + state: CTRL, mut wal_store: WAL, - state: SafeKeeperState, ) -> Result> { if state.timeline_id != ZTimelineId::from([0u8; 16]) && ztli != state.timeline_id { bail!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.timeline_id); } + // initialize wal_store, if state is already initialized wal_store.init_storage(&state)?; Ok(SafeKeeper { @@ -535,23 +535,25 @@ where s3_wal_lsn: state.s3_wal_lsn, peer_horizon_lsn: state.peer_horizon_lsn, remote_consistent_lsn: state.remote_consistent_lsn, + proposer_uuid: state.proposer_uuid, }, - s: state, - control_store, + state, wal_store, }) } /// Get history of term switches for the available WAL fn get_term_history(&self) -> TermHistory { - self.s + self.state .acceptor_state .term_history .up_to(self.wal_store.flush_lsn()) } pub fn get_epoch(&self) -> Term { - self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn()) + self.state + .acceptor_state + .get_epoch(self.wal_store.flush_lsn()) } /// Process message from proposer and possibly form reply. 
Concurrent @@ -587,46 +589,47 @@ where ); } /* Postgres upgrade is not treated as fatal error */ - if msg.pg_version != self.s.server.pg_version - && self.s.server.pg_version != UNKNOWN_SERVER_VERSION + if msg.pg_version != self.state.server.pg_version + && self.state.server.pg_version != UNKNOWN_SERVER_VERSION { info!( "incompatible server version {}, expected {}", - msg.pg_version, self.s.server.pg_version + msg.pg_version, self.state.server.pg_version ); } - if msg.tenant_id != self.s.tenant_id { + if msg.tenant_id != self.state.tenant_id { bail!( "invalid tenant ID, got {}, expected {}", msg.tenant_id, - self.s.tenant_id + self.state.tenant_id ); } - if msg.ztli != self.s.timeline_id { + if msg.ztli != self.state.timeline_id { bail!( "invalid timeline ID, got {}, expected {}", msg.ztli, - self.s.timeline_id + self.state.timeline_id ); } // set basic info about server, if not yet // TODO: verify that is doesn't change after - self.s.server.system_id = msg.system_id; - self.s.server.wal_seg_size = msg.wal_seg_size; - self.control_store - .persist(&self.s) - .context("failed to persist shared state")?; + { + let mut state = self.state.clone(); + state.server.system_id = msg.system_id; + state.server.wal_seg_size = msg.wal_seg_size; + self.state.persist(&state)?; + } // pass wal_seg_size to read WAL and find flush_lsn - self.wal_store.init_storage(&self.s)?; + self.wal_store.init_storage(&self.state)?; info!( "processed greeting from proposer {:?}, sending term {:?}", - msg.proposer_id, self.s.acceptor_state.term + msg.proposer_id, self.state.acceptor_state.term ); Ok(Some(AcceptorProposerMessage::Greeting(AcceptorGreeting { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, }))) } @@ -637,17 +640,19 @@ where ) -> Result> { // initialize with refusal let mut resp = VoteResponse { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, vote_given: false as u64, flush_lsn: self.wal_store.flush_lsn(), - truncate_lsn: self.s.peer_horizon_lsn, + truncate_lsn: self.state.peer_horizon_lsn, term_history: self.get_term_history(), }; - if self.s.acceptor_state.term < msg.term { - self.s.acceptor_state.term = msg.term; + if self.state.acceptor_state.term < msg.term { + let mut state = self.state.clone(); + state.acceptor_state.term = msg.term; // persist vote before sending it out - self.control_store.persist(&self.s)?; - resp.term = self.s.acceptor_state.term; + self.state.persist(&state)?; + + resp.term = self.state.acceptor_state.term; resp.vote_given = true as u64; } info!("processed VoteRequest for term {}: {:?}", msg.term, &resp); @@ -656,9 +661,10 @@ where /// Bump our term if received a note from elected proposer with higher one fn bump_if_higher(&mut self, term: Term) -> Result<()> { - if self.s.acceptor_state.term < term { - self.s.acceptor_state.term = term; - self.control_store.persist(&self.s)?; + if self.state.acceptor_state.term < term { + let mut state = self.state.clone(); + state.acceptor_state.term = term; + self.state.persist(&state)?; } Ok(()) } @@ -666,9 +672,9 @@ where /// Form AppendResponse from current state. 
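
The vote handling above always makes the bumped term durable before vote_given is sent, so a safekeeper that crashes and restarts can never hand out a second vote for the same term. The following is a small sketch of that ordering only; DiskState, VoteRequest and VoteResponse are simplified stand-ins, not the real protocol types.

/// Simplified stand-in for the persisted acceptor state.
#[derive(Clone, Default)]
struct DiskState {
    term: u64,
}

struct VoteRequest {
    term: u64,
}

struct VoteResponse {
    term: u64,
    vote_given: bool,
}

fn persist(durable: &mut DiskState, s: &DiskState) -> std::io::Result<()> {
    // Real code writes the control file and fsyncs it here.
    *durable = s.clone();
    Ok(())
}

fn handle_vote_request(
    durable: &mut DiskState,
    msg: &VoteRequest,
) -> std::io::Result<VoteResponse> {
    // Start from a refusal carrying our current term.
    let mut resp = VoteResponse { term: durable.term, vote_given: false };
    if durable.term < msg.term {
        let mut next = durable.clone();
        next.term = msg.term;
        // Persist the vote before sending it out.
        persist(durable, &next)?;
        resp.term = durable.term;
        resp.vote_given = true;
    }
    Ok(resp)
}

fn main() -> std::io::Result<()> {
    let mut durable = DiskState::default();

    let first = handle_vote_request(&mut durable, &VoteRequest { term: 1 })?;
    assert!(first.vote_given && first.term == 1);

    // A second request for the same term is refused, even "after a restart",
    // because the first vote was made durable before it went out.
    let second = handle_vote_request(&mut durable, &VoteRequest { term: 1 })?;
    assert!(!second.vote_given);
    Ok(())
}
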
fn append_response(&self) -> AppendResponse { let ar = AppendResponse { - term: self.s.acceptor_state.term, + term: self.state.acceptor_state.term, flush_lsn: self.wal_store.flush_lsn(), - commit_lsn: self.s.commit_lsn, + commit_lsn: self.state.commit_lsn, // will be filled by the upper code to avoid bothering safekeeper hs_feedback: HotStandbyFeedback::empty(), zenith_feedback: ZenithFeedback::empty(), @@ -681,7 +687,7 @@ where info!("received ProposerElected {:?}", msg); self.bump_if_higher(msg.term)?; // If our term is higher, ignore the message (next feedback will inform the compute) - if self.s.acceptor_state.term > msg.term { + if self.state.acceptor_state.term > msg.term { return Ok(None); } @@ -692,8 +698,11 @@ where self.wal_store.truncate_wal(msg.start_streaming_at)?; // and now adopt term history from proposer - self.s.acceptor_state.term_history = msg.term_history.clone(); - self.control_store.persist(&self.s)?; + { + let mut state = self.state.clone(); + state.acceptor_state.term_history = msg.term_history.clone(); + self.state.persist(&state)?; + } info!("start receiving WAL since {:?}", msg.start_streaming_at); @@ -715,13 +724,13 @@ where // Also note that commit_lsn can reach epoch_start_lsn earlier // that we receive new epoch_start_lsn, and we still need to sync // control file in this case. - if commit_lsn == self.epoch_start_lsn && self.s.commit_lsn != commit_lsn { + if commit_lsn == self.epoch_start_lsn && self.state.commit_lsn != commit_lsn { self.persist_control_file()?; } // We got our first commit_lsn, which means we should sync // everything to disk, to initialize the state. - if self.s.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) { + if self.state.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) { self.wal_store.flush_wal()?; self.persist_control_file()?; } @@ -731,10 +740,12 @@ where /// Persist in-memory state to the disk. fn persist_control_file(&mut self) -> Result<()> { - self.s.commit_lsn = self.inmem.commit_lsn; - self.s.peer_horizon_lsn = self.inmem.peer_horizon_lsn; + let mut state = self.state.clone(); - self.control_store.persist(&self.s) + state.commit_lsn = self.inmem.commit_lsn; + state.peer_horizon_lsn = self.inmem.peer_horizon_lsn; + state.proposer_uuid = self.inmem.proposer_uuid; + self.state.persist(&state) } /// Handle request to append WAL. @@ -744,13 +755,13 @@ where msg: &AppendRequest, require_flush: bool, ) -> Result> { - if self.s.acceptor_state.term < msg.h.term { + if self.state.acceptor_state.term < msg.h.term { bail!("got AppendRequest before ProposerElected"); } // If our term is higher, immediately refuse the message. - if self.s.acceptor_state.term > msg.h.term { - let resp = AppendResponse::term_only(self.s.acceptor_state.term); + if self.state.acceptor_state.term > msg.h.term { + let resp = AppendResponse::term_only(self.state.acceptor_state.term); return Ok(Some(AcceptorProposerMessage::AppendResponse(resp))); } @@ -758,8 +769,7 @@ where // processing the message. self.epoch_start_lsn = msg.h.epoch_start_lsn; - // TODO: don't update state without persisting to disk - self.s.proposer_uuid = msg.h.proposer_uuid; + self.inmem.proposer_uuid = msg.h.proposer_uuid; // do the job if !msg.wal_data.is_empty() { @@ -790,7 +800,7 @@ where // Update truncate and commit LSN in control file. // To avoid negative impact on performance of extra fsync, do it only // when truncate_lsn delta exceeds WAL segment size. 
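
The condition on the next lines only syncs the control file once the in-memory peer_horizon_lsn has advanced by at least one WAL segment past the persisted value, trading a bounded amount of re-replayable WAL for far fewer fsyncs. A toy sketch of that throttling follows; plain u64 values stand in for Lsn and for the persisted state, and the 16 MiB figure is simply PostgreSQL's default WAL segment size.

/// Simplified stand-in for the persisted vs. in-memory LSN bookkeeping.
struct ControlFile {
    persisted_peer_horizon_lsn: u64,
    syncs: u64,
}

impl ControlFile {
    /// Persist only when the in-memory horizon has advanced by at least one
    /// WAL segment, mirroring the check in handle_append_request.
    fn maybe_sync(&mut self, inmem_peer_horizon_lsn: u64, wal_seg_size: u64) {
        if self.persisted_peer_horizon_lsn + wal_seg_size < inmem_peer_horizon_lsn {
            // Real code clones the state, fills in the in-memory fields and
            // calls Storage::persist(), which fsyncs the control file.
            self.persisted_peer_horizon_lsn = inmem_peer_horizon_lsn;
            self.syncs += 1;
        }
    }
}

fn main() {
    const WAL_SEG_SIZE: u64 = 16 * 1024 * 1024; // default 16 MiB segments
    let mut cf = ControlFile { persisted_peer_horizon_lsn: 0, syncs: 0 };

    // Many small advances of the horizon cause no extra fsyncs...
    for lsn in (0..WAL_SEG_SIZE).step_by(8192) {
        cf.maybe_sync(lsn, WAL_SEG_SIZE);
    }
    assert_eq!(cf.syncs, 0);

    // ...and only crossing a full segment boundary triggers one.
    cf.maybe_sync(WAL_SEG_SIZE + 1, WAL_SEG_SIZE);
    assert_eq!(cf.syncs, 1);
}
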
- if self.s.peer_horizon_lsn + (self.s.server.wal_seg_size as u64) + if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64) < self.inmem.peer_horizon_lsn { self.persist_control_file()?; @@ -829,6 +839,8 @@ where #[cfg(test)] mod tests { + use std::ops::Deref; + use super::*; use crate::wal_storage::Storage; @@ -844,6 +856,14 @@ mod tests { } } + impl Deref for InMemoryState { + type Target = SafeKeeperState; + + fn deref(&self) -> &Self::Target { + &self.persisted_state + } + } + struct DummyWalStore { lsn: Lsn, } @@ -879,7 +899,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap(); // check voting for 1 is ok let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 }); @@ -890,11 +910,11 @@ mod tests { } // reboot... - let state = sk.control_store.persisted_state.clone(); + let state = sk.state.persisted_state.clone(); let storage = InMemoryState { - persisted_state: state.clone(), + persisted_state: state, }; - sk = SafeKeeper::new(ztli, storage, sk.wal_store, state).unwrap(); + sk = SafeKeeper::new(ztli, storage, sk.wal_store).unwrap(); // and ensure voting second time for 1 is not ok vote_resp = sk.process_msg(&vote_request); @@ -911,7 +931,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::empty()).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store).unwrap(); let mut ar_hdr = AppendRequestHeader { term: 1, diff --git a/walkeeper/src/timeline.rs b/walkeeper/src/timeline.rs index a76ef77615..a2941a9a5c 100644 --- a/walkeeper/src/timeline.rs +++ b/walkeeper/src/timeline.rs @@ -21,7 +21,6 @@ use crate::broker::SafekeeperInfo; use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey}; use crate::control_file; -use crate::control_file::Storage as cf_storage; use crate::safekeeper::{ AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, SafekeeperMemState, @@ -98,10 +97,9 @@ impl SharedState { peer_ids: Vec, ) -> Result { let state = SafeKeeperState::new(zttid, peer_ids); - let control_store = control_file::FileStorage::new(zttid, conf); + let control_store = control_file::FileStorage::create_new(zttid, conf, state)?; let wal_store = wal_storage::PhysicalStorage::new(zttid, conf); - let mut sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?; - sk.control_store.persist(&sk.s)?; + let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?; Ok(Self { notified_commit_lsn: Lsn(0), @@ -116,18 +114,14 @@ impl SharedState { /// Restore SharedState from control file. /// If file doesn't exist, bails out. 
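
FileStorage now has two entry points: create_new persists a freshly built state immediately (used when a timeline is created), while restore_new reloads the existing control file and bails out if it is missing (used by SharedState::restore just below). The sketch here only illustrates that split; the in-memory Option acts as a toy "disk", and the names and error handling are simplified stand-ins rather than the real API.

/// Simplified stand-in for SafeKeeperState.
#[derive(Clone)]
struct State {
    term: u64,
}

/// A toy "disk": None means no control file exists yet.
struct Disk {
    file: Option<State>,
}

struct Store<'a> {
    disk: &'a mut Disk,
    last_persisted: State,
}

impl<'a> Store<'a> {
    /// Like FileStorage::create_new: persist the given fresh state right away.
    fn create_new(disk: &'a mut Disk, state: State) -> Store<'a> {
        disk.file = Some(state.clone());
        Store { disk, last_persisted: state }
    }

    /// Like FileStorage::restore_new: load what was persisted earlier,
    /// failing if the control file does not exist.
    fn restore_new(disk: &'a mut Disk) -> Result<Store<'a>, &'static str> {
        let state = disk.file.clone().ok_or("control file not found")?;
        Ok(Store { disk, last_persisted: state })
    }

    /// Later updates overwrite the file and the cached copy.
    fn persist(&mut self, s: &State) {
        self.disk.file = Some(s.clone());
        self.last_persisted = s.clone();
    }
}

fn main() {
    let mut disk = Disk { file: None };
    assert!(Store::restore_new(&mut disk).is_err()); // nothing to restore yet

    {
        let mut created = Store::create_new(&mut disk, State { term: 5 });
        created.persist(&State { term: 6 });
    }

    // After a "restart", restore_new picks the last persisted state back up.
    let restored = Store::restore_new(&mut disk).unwrap();
    assert_eq!(restored.last_persisted.term, 6);
}
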
fn restore(conf: &SafeKeeperConf, zttid: &ZTenantTimelineId) -> Result { - let state = control_file::FileStorage::load_control_file_conf(conf, zttid) - .context("failed to load from control file")?; - - let control_store = control_file::FileStorage::new(zttid, conf); - + let control_store = control_file::FileStorage::restore_new(zttid, conf)?; let wal_store = wal_storage::PhysicalStorage::new(zttid, conf); info!("timeline {} restored", zttid.timeline_id); Ok(Self { notified_commit_lsn: Lsn(0), - sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state)?, + sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store)?, replicas: Vec::new(), active: false, num_computes: 0, @@ -419,7 +413,7 @@ impl Timeline { pub fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) { let shared_state = self.mutex.lock().unwrap(); - (shared_state.sk.inmem.clone(), shared_state.sk.s.clone()) + (shared_state.sk.inmem.clone(), shared_state.sk.state.clone()) } /// Prepare public safekeeper info for reporting. diff --git a/workspace_hack/.gitattributes b/workspace_hack/.gitattributes new file mode 100644 index 0000000000..3e9dba4b64 --- /dev/null +++ b/workspace_hack/.gitattributes @@ -0,0 +1,4 @@ +# Avoid putting conflict markers in the generated Cargo.toml file, since their presence breaks +# Cargo. +# Also do not check out the file as CRLF on Windows, as that's what hakari needs. +Cargo.toml merge=binary -crlf diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 6e6a0e09d7..84244b3363 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -16,32 +16,38 @@ publish = false [dependencies] anyhow = { version = "1", features = ["backtrace", "std"] } bytes = { version = "1", features = ["serde", "std"] } +chrono = { version = "0.4", features = ["clock", "libc", "oldtime", "serde", "std", "time", "winapi"] } clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] } either = { version = "1", features = ["use_std"] } hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] } +indexmap = { version = "1", default-features = false, features = ["std"] } libc = { version = "0.2", features = ["extra_traits", "std"] } log = { version = "0.4", default-features = false, features = ["serde", "std"] } memchr = { version = "2", features = ["std", "use_std"] } num-integer = { version = "0.1", default-features = false, features = ["std"] } num-traits = { version = "0.2", features = ["std"] } +prost = { version = "0.9", features = ["prost-derive", "std"] } +rand = { version = "0.8", features = ["alloc", "getrandom", "libc", "rand_chacha", "rand_hc", "small_rng", "std", "std_rng"] } regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } -reqwest = { version = "0.11", default-features = false, features = ["__rustls", "__tls", "blocking", "hyper-rustls", "json", "rustls", "rustls-pemfile", "rustls-tls", "rustls-tls-webpki-roots", "serde_json", "stream", "tokio-rustls", "tokio-util", "webpki-roots"] } scopeguard = { version = "1", features = ["use_std"] } serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] 
} -tokio = { version = "1", features = ["bytes", "fs", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "sync", "time", "tokio-macros"] } -tracing = { version = "0.1", features = ["attributes", "std", "tracing-attributes"] } +tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] } +tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] } tracing-core = { version = "0.1", features = ["lazy_static", "std"] } [build-dependencies] +anyhow = { version = "1", features = ["backtrace", "std"] } +bytes = { version = "1", features = ["serde", "std"] } cc = { version = "1", default-features = false, features = ["jobserver", "parallel"] } clap = { version = "2", features = ["ansi_term", "atty", "color", "strsim", "suggestions", "vec_map"] } either = { version = "1", features = ["use_std"] } +hashbrown = { version = "0.11", features = ["ahash", "inline-more", "raw"] } +indexmap = { version = "1", default-features = false, features = ["std"] } libc = { version = "0.2", features = ["extra_traits", "std"] } log = { version = "0.4", default-features = false, features = ["serde", "std"] } memchr = { version = "2", features = ["std", "use_std"] } -proc-macro2 = { version = "1", features = ["proc-macro"] } -quote = { version = "1", features = ["proc-macro"] } +prost = { version = "0.9", features = ["prost-derive", "std"] } regex = { version = "1", features = ["aho-corasick", "memchr", "perf", "perf-cache", "perf-dfa", "perf-inline", "perf-literal", "std", "unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } regex-syntax = { version = "0.6", features = ["unicode", "unicode-age", "unicode-bool", "unicode-case", "unicode-gencat", "unicode-perl", "unicode-script", "unicode-segment"] } serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] } diff --git a/workspace_hack/build.rs b/workspace_hack/build.rs new file mode 100644 index 0000000000..92518ef04c --- /dev/null +++ b/workspace_hack/build.rs @@ -0,0 +1,2 @@ +// A build script is required for cargo to consider build dependencies. 
+fn main() {} diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 83792f2aca..f984fb4417 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -375,9 +375,8 @@ impl PostgresBackend { } AuthType::MD5 => { rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt; self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, + self.md5_salt, ))?; self.state = ProtoState::Authentication; } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index cb69418c07..403e176b14 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -401,7 +401,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { #[derive(Debug)] pub enum BeMessage<'a> { AuthenticationOk, - AuthenticationMD5Password(&'a [u8; 4]), + AuthenticationMD5Password([u8; 4]), + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), BindComplete, @@ -429,6 +430,13 @@ pub enum BeMessage<'a> { KeepAlive(WalSndKeepAlive), } +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + #[derive(Debug)] pub enum BeParameterStatusMessage<'a> { Encoding(&'a str), @@ -611,6 +619,32 @@ impl<'a> BeMessage<'a> { .unwrap(); // write into BytesMut can't fail } + BeMessage::AuthenticationSasl(msg) => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + use BeAuthenticationSaslMessage::*; + match msg { + Methods(methods) => { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods.iter() { + write_cstr(method.as_bytes(), buf)?; + } + buf.put_u8(0); // zero terminator for the list + } + Continue(extra) => { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + } + Final(extra) => { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + } + } + Ok::<_, io::Error>(()) + }) + .unwrap() + } + BeMessage::BackendKeyData(key_data) => { buf.put_u8(b'K'); write_body(buf, |buf| {
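
The new AuthenticationSasl arm above emits the standard PostgreSQL 'R' authentication message, with body code 10 for the SASL mechanism list and 11/12 for SASLContinue and SASLFinal. Below is a sketch of the Methods case only, using a plain Vec<u8> instead of BytesMut so the byte layout is explicit; the helper name is made up for illustration and is not part of the patch.

/// Build an AuthenticationSASL ('R', code 10) message: the body is the i32
/// code followed by NUL-terminated mechanism names and a final zero byte;
/// the length word covers itself plus the body.
fn be_authentication_sasl(methods: &[&str]) -> Vec<u8> {
    let mut body = Vec::new();
    body.extend_from_slice(&10_i32.to_be_bytes()); // 10 = SASL mechanisms follow
    for method in methods {
        body.extend_from_slice(method.as_bytes());
        body.push(0); // each mechanism name is a C string
    }
    body.push(0); // zero terminator for the list

    let mut msg = Vec::new();
    msg.push(b'R'); // authentication message type
    msg.extend_from_slice(&((body.len() as i32 + 4).to_be_bytes())); // length includes itself
    msg.extend_from_slice(&body);
    msg
}

fn main() {
    let msg = be_authentication_sasl(&["SCRAM-SHA-256"]);
    assert_eq!(msg[0], b'R');
    // length = 4 (len) + 4 (code) + 14 ("SCRAM-SHA-256\0") + 1 (list terminator) = 23
    assert_eq!(i32::from_be_bytes([msg[1], msg[2], msg[3], msg[4]]), 23);
    assert_eq!(i32::from_be_bytes([msg[5], msg[6], msg[7], msg[8]]), 10);
}
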