From 2aceb6a3095bf0ee6cf7ef3ecc1bb182864abccb Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Mon, 23 May 2022 20:58:27 +0300 Subject: [PATCH 01/27] Fix garbage collection to not remove image layers that are still needed. The logic would incorrectly remove an image layer, if a new image layer existed, even though the older image layer was still needed by some delta layers after it. See example given in the comment this adds. Without this fix, I was getting a lot of "could not find data for key 010000000000000000000000000000000000" errors from GC, with the new test case being added in PR #1735. Fixes #707 --- pageserver/src/layered_repository.rs | 24 ++++++++++++------- .../src/layered_repository/layer_map.rs | 13 ++++------ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index fc4ab942f6..a83907430e 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -18,7 +18,7 @@ use itertools::Itertools; use lazy_static::lazy_static; use tracing::*; -use std::cmp::{max, Ordering}; +use std::cmp::{max, min, Ordering}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::collections::{BTreeSet, HashSet}; @@ -2165,7 +2165,7 @@ impl LayeredTimeline { let gc_info = self.gc_info.read().unwrap(); let retain_lsns = &gc_info.retain_lsns; - let cutoff = gc_info.cutoff; + let cutoff = min(gc_info.cutoff, disk_consistent_lsn); let pitr = gc_info.pitr; // Calculate pitr cutoff point. @@ -2294,12 +2294,20 @@ impl LayeredTimeline { // is 102, then it might not have been fully flushed to disk // before crash. // - // FIXME: This logic is wrong. See https://github.com/zenithdb/zenith/issues/707 - if !layers.newer_image_layer_exists( - &l.get_key_range(), - l.get_lsn_range().end, - disk_consistent_lsn + 1, - )? { + // For example, imagine that the following layers exist: + // + // 1000 - image (A) + // 1000-2000 - delta (B) + // 2000 - image (C) + // 2000-3000 - delta (D) + // 3000 - image (E) + // + // If GC horizon is at 2500, we can remove layers A and B, but + // we cannot remove C, even though it's older than 2500, because + // the delta layer 2000-3000 depends on it. + if !layers + .image_layer_exists(&l.get_key_range(), &(l.get_lsn_range().end..new_gc_cutoff))? + { debug!( "keeping {} because it is the latest layer", l.filename().display() diff --git a/pageserver/src/layered_repository/layer_map.rs b/pageserver/src/layered_repository/layer_map.rs index 7491294c03..f7f51bf21f 100644 --- a/pageserver/src/layered_repository/layer_map.rs +++ b/pageserver/src/layered_repository/layer_map.rs @@ -201,18 +201,14 @@ impl LayerMap { NUM_ONDISK_LAYERS.dec(); } - /// Is there a newer image layer for given key-range? + /// Is there a newer image layer for given key- and LSN-range? /// /// This is used for garbage collection, to determine if an old layer can /// be deleted. - /// We ignore layers newer than disk_consistent_lsn because they will be removed at restart - /// We also only look at historic layers - //#[allow(dead_code)] - pub fn newer_image_layer_exists( + pub fn image_layer_exists( &self, key_range: &Range, - lsn: Lsn, - disk_consistent_lsn: Lsn, + lsn_range: &Range, ) -> Result { let mut range_remain = key_range.clone(); @@ -225,8 +221,7 @@ impl LayerMap { let img_lsn = l.get_lsn_range().start; if !l.is_incremental() && l.get_key_range().contains(&range_remain.start) - && img_lsn > lsn - && img_lsn < disk_consistent_lsn + && lsn_range.contains(&img_lsn) { made_progress = true; let img_key_end = l.get_key_range().end; From 8346aa3a29daf6088689076d35a9c99df3c9e4ce Mon Sep 17 00:00:00 2001 From: KlimentSerafimov Date: Tue, 24 May 2022 04:55:38 -0400 Subject: [PATCH 02/27] Potential fix to #1626. Fixed typo is Makefile. (#1781) * Potential fix to #1626. Fixed typo is Makefile. * Completed fix to #1626. Summary: changed 'error' to 'bail' in start_pageserver and start_safekeeper. --- Makefile | 2 +- pageserver/src/bin/pageserver.rs | 2 +- safekeeper/src/bin/safekeeper.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 5eca7fb094..fdfc64f6fa 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ else ifeq ($(BUILD_TYPE),debug) PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend PG_CFLAGS = -O0 -g3 $(CFLAGS) else -$(error Bad build type `$(BUILD_TYPE)', see Makefile for options) + $(error Bad build type '$(BUILD_TYPE)', see Makefile for options) endif # macOS with brew-installed openssl requires explicit paths diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 00864056cb..ac90500b97 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -254,7 +254,7 @@ fn start_pageserver(conf: &'static PageServerConf, daemonize: bool) -> Result<() // Otherwise, the coverage data will be damaged. match daemonize.exit_action(|| exit_now(0)).start() { Ok(_) => info!("Success, daemonized"), - Err(err) => error!(%err, "could not daemonize"), + Err(err) => bail!("{err}. could not daemonize. bailing."), } } diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 61d2f558f2..a5ffc013e2 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -245,7 +245,7 @@ fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: b // Otherwise, the coverage data will be damaged. match daemonize.exit_action(|| exit_now(0)).start() { Ok(_) => info!("Success, daemonized"), - Err(e) => error!("Error, {}", e), + Err(err) => bail!("Error: {err}. could not daemonize. bailing."), } } From 541ec258758309b1ef98c24b5afe79169406d3b9 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Tue, 24 May 2022 17:56:37 +0300 Subject: [PATCH 03/27] Properly shutdown test mock S3 server --- .circleci/config.yml | 2 +- test_runner/fixtures/zenith_fixtures.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index eb2bf0172b..41f7693726 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -361,7 +361,7 @@ jobs: when: always command: | du -sh /tmp/test_output/* - find /tmp/test_output -type f ! -name "pg.log" ! -name "pageserver.log" ! -name "safekeeper.log" ! -name "etcd.log" ! -name "regression.diffs" ! -name "junit.xml" ! -name "*.filediff" ! -name "*.stdout" ! -name "*.stderr" ! -name "flamegraph.svg" ! -name "*.metrics" -delete + find /tmp/test_output -type f ! -name "*.log" ! -name "regression.diffs" ! -name "junit.xml" ! -name "*.filediff" ! -name "*.stdout" ! -name "*.stderr" ! -name "flamegraph.svg" ! -name "*.metrics" -delete du -sh /tmp/test_output/* - store_artifacts: path: /tmp/test_output diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index 17d932c968..8f9bf1c11b 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -393,7 +393,10 @@ class MockS3Server: ): self.port = port - self.subprocess = subprocess.Popen([f'poetry run moto_server s3 -p{port}'], shell=True) + # XXX: do not use `shell=True` or add `exec ` to the command here otherwise. + # We use `self.subprocess.kill()` to shut down the server, which would not "just" work in Linux + # if a process is started from the shell process. + self.subprocess = subprocess.Popen(['poetry', 'run', 'moto_server', 's3', f'-p{port}']) error = None try: return_code = self.subprocess.poll() @@ -403,7 +406,7 @@ class MockS3Server: error = f"expected mock s3 server to start but it failed with exception: {e}. stdout: '{self.subprocess.stdout}', stderr: '{self.subprocess.stderr}'" if error is not None: log.error(error) - self.subprocess.kill() + self.kill() raise RuntimeError("failed to start s3 mock server") def endpoint(self) -> str: From d32b491a5300d99c9e2d7811944160185e23730c Mon Sep 17 00:00:00 2001 From: Sergey Melnikov Date: Wed, 25 May 2022 11:31:10 +0400 Subject: [PATCH 04/27] Add zenith-us-stage-sk-6 to deploy (#1728) --- .circleci/ansible/staging.hosts | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/ansible/staging.hosts b/.circleci/ansible/staging.hosts index 8e89e843d9..d99ffa6dac 100644 --- a/.circleci/ansible/staging.hosts +++ b/.circleci/ansible/staging.hosts @@ -6,6 +6,7 @@ zenith-us-stage-ps-2 console_region_id=27 zenith-us-stage-sk-1 console_region_id=27 zenith-us-stage-sk-4 console_region_id=27 zenith-us-stage-sk-5 console_region_id=27 +zenith-us-stage-sk-6 console_region_id=27 [storage:children] pageservers From 2b265fd6dc38b58a684ee6d584714a87705936b1 Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Wed, 25 May 2022 14:16:44 +0400 Subject: [PATCH 05/27] Disable restart_after_crash in neon_local. It is pointless when basebackup is invalid. --- control_plane/src/compute.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 92d0e080d8..350cf74b7c 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -274,6 +274,8 @@ impl PostgresNode { conf.append("listen_addresses", &self.address.ip().to_string()); conf.append("port", &self.address.port().to_string()); conf.append("wal_keep_size", "0"); + // walproposer panics when basebackup is invalid, it is pointless to restart in this case. + conf.append("restart_after_crash", "off"); // Configure the node to fetch pages from pageserver let pageserver_connstr = { From 703f691df8fb82fdfd3d2febc892748eb7317126 Mon Sep 17 00:00:00 2001 From: Andrey Taranik Date: Wed, 25 May 2022 14:30:50 +0300 Subject: [PATCH 06/27] production inventory update (#1779) --- .circleci/ansible/production.hosts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.circleci/ansible/production.hosts b/.circleci/ansible/production.hosts index 2ed8f517f7..6cefd724d8 100644 --- a/.circleci/ansible/production.hosts +++ b/.circleci/ansible/production.hosts @@ -1,5 +1,6 @@ [pageservers] -zenith-1-ps-1 console_region_id=1 +#zenith-1-ps-1 console_region_id=1 +zenith-1-ps-2 console_region_id=1 [safekeepers] zenith-1-sk-1 console_region_id=1 @@ -15,4 +16,4 @@ console_mgmt_base_url = http://console-release.local bucket_name = zenith-storage-oregon bucket_region = us-west-2 etcd_endpoints = etcd-release.local:2379 -safekeeper_enable_s3_offload = true +safekeeper_enable_s3_offload = false From 6f1f33ef42a63c0047442e8057b9223793424edb Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 25 May 2022 14:33:06 +0300 Subject: [PATCH 07/27] Improve error messages on seccomp loading errors. Bump vendor/postgres for https://github.com/neondatabase/postgres/pull/166 --- vendor/postgres | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/postgres b/vendor/postgres index 79af2faf08..038b2b98e5 160000 --- a/vendor/postgres +++ b/vendor/postgres @@ -1 +1 @@ -Subproject commit 79af2faf08d9bec1b1664a72936727dcca36d253 +Subproject commit 038b2b98e5c3d6274cbd43e9b822cdd946cb8b91 From 9ab52e2186e9330d4098b27372d8a0a2d5f0ac1e Mon Sep 17 00:00:00 2001 From: Andrey Taranik Date: Wed, 25 May 2022 15:41:18 +0300 Subject: [PATCH 08/27] helm repository name fix for production proxy deploy (#1790) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 41f7693726..5346e35c01 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -685,7 +685,7 @@ jobs: name: Setup helm v3 command: | curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash - helm repo add zenithdb https://neondatabase.github.io/helm-charts + helm repo add neondatabase https://neondatabase.github.io/helm-charts - run: name: Re-deploy proxy command: | From 24d2313d0b8d1b6279f8a01376f55111427c9b19 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 25 May 2022 16:57:45 +0300 Subject: [PATCH 09/27] Set --quota-backend-bytes when launching etcd in tests. By default, etcd makes a huge 10 GB mmap() allocation when it starts up. It doesn't actually use that much memory, it's just address space, but it caused me grief when I tried to use 'rr' to debug a python test run. Apparently, when you replay the 'rr' trace, it does allocate memory for all that address space. The size of the initial mmap depends on the --quota-backend-bytes setting. Our etcd clusters are very small, so let's set --quota-backend-bytes to keep the virtual memory size small, to make debugging with 'rr' easier. See https://github.com/etcd-io/etcd/issues/7910 and https://github.com/etcd-io/etcd/commit/5e4b0081065925ab9d04009cd4fb559c4cceb304 --- control_plane/src/etcd.rs | 4 ++++ test_runner/fixtures/zenith_fixtures.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/control_plane/src/etcd.rs b/control_plane/src/etcd.rs index df657dd1be..bc39b7dea3 100644 --- a/control_plane/src/etcd.rs +++ b/control_plane/src/etcd.rs @@ -48,6 +48,10 @@ pub fn start_etcd_process(env: &local_env::LocalEnv) -> anyhow::Result<()> { format!("--data-dir={}", etcd_data_dir.display()), format!("--listen-client-urls={client_urls}"), format!("--advertise-client-urls={client_urls}"), + // Set --quota-backend-bytes to keep the etcd virtual memory + // size smaller. Our test etcd clusters are very small. + // See https://github.com/etcd-io/etcd/issues/7910 + "--quota-backend-bytes=100000000".to_string(), ]) .stdout(Stdio::from(etcd_stdout_file)) .stderr(Stdio::from(etcd_stderr_file)) diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index 8f9bf1c11b..7f5b2ad2aa 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -1893,7 +1893,11 @@ class Etcd: f"--data-dir={self.datadir}", f"--listen-client-urls={client_url}", f"--advertise-client-urls={client_url}", - f"--listen-peer-urls=http://127.0.0.1:{self.peer_port}" + f"--listen-peer-urls=http://127.0.0.1:{self.peer_port}", + # Set --quota-backend-bytes to keep the etcd virtual memory + # size smaller. Our test etcd clusters are very small. + # See https://github.com/etcd-io/etcd/issues/7910 + f"--quota-backend-bytes=100000000" ] self.handle = subprocess.Popen(args, stdout=log_file, stderr=log_file) From 7997fc2932465b1c8854a64c2c053041eacdf80a Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 25 May 2022 18:14:44 +0300 Subject: [PATCH 10/27] Fix error handling with 'basebackup' command. If the 'basebackup' command failed in the middle of building the tar archive, the client would not report the error, but would attempt to to start up postgres with the partial contents of the data directory. That fails because the control file is missing (it's added to the archive last, precisly to make sure that you cannot start postgres from a partial archive). But the client doesn't see the proper error message that caused the basebackup to fail in the server, which is confusing. Two issues conspired to cause that: 1. The tar::Builder object that we use in the pageserver to construct the tar stream has a Drop handler that automatically writes a valid end-of-archive marker on drop. Because of that, the resulting tarball looks complete, even if an error happens while we're building it. The pageserver does send an ErrorResponse after the seemingly-valid tarball, but: 2. The client stops reading the Copy stream, as soon as it sees the tar end-of-archive marker. Therefore, it doesn't read the ErrorResponse that comes after it. We have two clients that call 'basebackup', one in `control_plane` used by the `neon_local` binary, and another one in `compute_tools`. Both had the same issue. This PR fixes both issues, even though fixing either one would be enough to fix the problem at hand. The pageserver now doesn't send the end-of-archive marker on error, and the client now reads the copy stream to the end, even if it sees an end-of-archive marker. Fixes github issue #1715 In the passing, change Basebackup to use generic Write rather than 'dyn'. --- compute_tools/src/compute.rs | 8 +- control_plane/Cargo.toml | 2 +- control_plane/src/compute.rs | 9 +- pageserver/src/basebackup.rs | 90 +++++++++++++++++-- pageserver/src/page_service.rs | 3 +- .../batch_others/test_basebackup_error.py | 20 +++++ 6 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 test_runner/batch_others/test_basebackup_error.py diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index a8422fb2b2..fd60b80305 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -146,8 +146,14 @@ impl ComputeNode { _ => format!("basebackup {} {} {}", &self.tenant, &self.timeline, lsn), }; let copyreader = client.copy_out(basebackup_cmd.as_str())?; - let mut ar = tar::Archive::new(copyreader); + // Read the archive directly from the `CopyOutReader` + // + // Set `ignore_zeros` so that unpack() reads all the Copy data and + // doesn't stop at the end-of-archive marker. Otherwise, if the server + // sends an Error after finishing the tarball, we will not notice it. + let mut ar = tar::Archive::new(copyreader); + ar.set_ignore_zeros(true); ar.unpack(&self.pgdata)?; self.metrics.basebackup_ms.store( diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 41417aab9a..21311eea9a 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -tar = "0.4.33" +tar = "0.4.38" postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" } serde = { version = "1.0", features = ["derive"] } serde_with = "1.12.0" diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index 350cf74b7c..045acd7519 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -231,8 +231,13 @@ impl PostgresNode { .context("page server 'basebackup' command failed")?; // Read the archive directly from the `CopyOutReader` - tar::Archive::new(copyreader) - .unpack(&self.pgdata()) + // + // Set `ignore_zeros` so that unpack() reads all the Copy data and + // doesn't stop at the end-of-archive marker. Otherwise, if the server + // sends an Error after finishing the tarball, we will not notice it. + let mut ar = tar::Archive::new(copyreader); + ar.set_ignore_zeros(true); + ar.unpack(&self.pgdata()) .context("extracting base backup failed")?; Ok(()) diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 92d35130d8..46d824b2e2 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -10,8 +10,9 @@ //! This module is responsible for creation of such tarball //! from data stored in object storage. //! -use anyhow::{anyhow, ensure, Context, Result}; +use anyhow::{anyhow, bail, ensure, Context, Result}; use bytes::{BufMut, BytesMut}; +use fail::fail_point; use std::fmt::Write as FmtWrite; use std::io; use std::io::Write; @@ -30,11 +31,16 @@ use utils::lsn::Lsn; /// This is short-living object only for the time of tarball creation, /// created mostly to avoid passing a lot of parameters between various functions /// used for constructing tarball. -pub struct Basebackup<'a> { - ar: Builder<&'a mut dyn Write>, +pub struct Basebackup<'a, W> +where + W: Write, +{ + ar: Builder>, timeline: &'a Arc, pub lsn: Lsn, prev_record_lsn: Lsn, + + finished: bool, } // Create basebackup with non-rel data in it. Omit relational data. @@ -44,12 +50,15 @@ pub struct Basebackup<'a> { // * When working without safekeepers. In this situation it is important to match the lsn // we are taking basebackup on with the lsn that is used in pageserver's walreceiver // to start the replication. -impl<'a> Basebackup<'a> { +impl<'a, W> Basebackup<'a, W> +where + W: Write, +{ pub fn new( - write: &'a mut dyn Write, + write: W, timeline: &'a Arc, req_lsn: Option, - ) -> Result> { + ) -> Result> { // Compute postgres doesn't have any previous WAL files, but the first // record that it's going to write needs to include the LSN of the // previous record (xl_prev). We include prev_record_lsn in the @@ -90,14 +99,15 @@ impl<'a> Basebackup<'a> { ); Ok(Basebackup { - ar: Builder::new(write), + ar: Builder::new(AbortableWrite::new(write)), timeline, lsn: backup_lsn, prev_record_lsn: backup_prev, + finished: false, }) } - pub fn send_tarball(&mut self) -> anyhow::Result<()> { + pub fn send_tarball(mut self) -> anyhow::Result<()> { // Create pgdata subdirs structure for dir in pg_constants::PGDATA_SUBDIRS.iter() { let header = new_tar_header_dir(*dir)?; @@ -135,9 +145,14 @@ impl<'a> Basebackup<'a> { self.add_twophase_file(xid)?; } + fail_point!("basebackup-before-control-file", |_| { + bail!("failpoint basebackup-before-control-file") + }); + // Generate pg_control and bootstrap WAL segment. self.add_pgcontrol_file()?; self.ar.finish()?; + self.finished = true; debug!("all tarred up!"); Ok(()) } @@ -331,6 +346,19 @@ impl<'a> Basebackup<'a> { } } +impl<'a, W> Drop for Basebackup<'a, W> +where + W: Write, +{ + /// If the basebackup was not finished, prevent the Archive::drop() from + /// writing the end-of-archive marker. + fn drop(&mut self) { + if !self.finished { + self.ar.get_mut().abort(); + } + } +} + // // Create new tarball entry header // @@ -366,3 +394,49 @@ fn new_tar_header_dir(path: &str) -> anyhow::Result
{ header.set_cksum(); Ok(header) } + +/// A wrapper that passes through all data to the underlying Write, +/// until abort() is called. +/// +/// tar::Builder has an annoying habit of finishing the archive with +/// a valid tar end-of-archive marker (two 512-byte sectors of zeros), +/// even if an error occurs and we don't finish building the archive. +/// We'd rather abort writing the tarball immediately than construct +/// a seemingly valid but incomplete archive. This wrapper allows us +/// to swallow the end-of-archive marker that Builder::drop() emits, +/// without writing it to the underlying sink. +/// +struct AbortableWrite { + w: W, + aborted: bool, +} + +impl AbortableWrite { + pub fn new(w: W) -> Self { + AbortableWrite { w, aborted: false } + } + + pub fn abort(&mut self) { + self.aborted = true; + } +} + +impl Write for AbortableWrite +where + W: Write, +{ + fn write(&mut self, data: &[u8]) -> io::Result { + if self.aborted { + Ok(data.len()) + } else { + self.w.write(data) + } + } + fn flush(&mut self) -> io::Result<()> { + if self.aborted { + Ok(()) + } else { + self.w.flush() + } + } +} diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 03264c9782..f54cd550b3 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -593,7 +593,8 @@ impl PageServerHandler { /* Send a tarball of the latest layer on the timeline */ { let mut writer = CopyDataSink { pgb }; - let mut basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn)?; + + let basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn)?; span.record("lsn", &basebackup.lsn.to_string().as_str()); basebackup.send_tarball()?; } diff --git a/test_runner/batch_others/test_basebackup_error.py b/test_runner/batch_others/test_basebackup_error.py new file mode 100644 index 0000000000..4b8b8a746c --- /dev/null +++ b/test_runner/batch_others/test_basebackup_error.py @@ -0,0 +1,20 @@ +import pytest +from contextlib import closing + +from fixtures.zenith_fixtures import ZenithEnv +from fixtures.log_helper import log + + +# +# Test error handling, if the 'basebackup' command fails in the middle +# of building the tar archive. +# +def test_basebackup_error(zenith_simple_env: ZenithEnv): + env = zenith_simple_env + env.zenith_cli.create_branch("test_basebackup_error", "empty") + + # Introduce failpoint + env.pageserver.safe_psql(f"failpoints basebackup-before-control-file=return") + + with pytest.raises(Exception, match="basebackup-before-control-file"): + pg = env.postgres.create_start('test_basebackup_error') From c584d90bb96bb7bd390bc5345ec8f667e765c299 Mon Sep 17 00:00:00 2001 From: chaitanya sharma <86035+phoenix24@users.noreply.github.com> Date: Mon, 23 May 2022 15:52:21 +0000 Subject: [PATCH 11/27] initial commit, renamed znodeid to nodeid. --- control_plane/src/local_env.rs | 10 +++++----- control_plane/src/safekeeper.rs | 8 ++++---- libs/etcd_broker/src/lib.rs | 16 ++++++++-------- libs/utils/src/zid.rs | 4 ++-- neon_local/src/main.rs | 10 +++++----- pageserver/src/config.rs | 16 ++++++++-------- pageserver/src/http/models.rs | 4 ++-- safekeeper/src/bin/safekeeper.rs | 12 ++++++------ safekeeper/src/broker.rs | 4 ++-- safekeeper/src/http/models.rs | 4 ++-- safekeeper/src/http/routes.rs | 6 +++--- safekeeper/src/lib.rs | 6 +++--- safekeeper/src/safekeeper.rs | 18 +++++++++--------- safekeeper/src/timeline.rs | 10 +++++----- 14 files changed, 64 insertions(+), 64 deletions(-) diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index c73af7d338..015b33f591 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -15,7 +15,7 @@ use std::process::{Command, Stdio}; use utils::{ auth::{encode_from_key_file, Claims, Scope}, postgres_backend::AuthType, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, }; use crate::safekeeper::SafekeeperNode; @@ -136,7 +136,7 @@ impl EtcdBroker { #[serde(default)] pub struct PageServerConf { // node id - pub id: ZNodeId, + pub id: NodeId, // Pageserver connection settings pub listen_pg_addr: String, pub listen_http_addr: String, @@ -151,7 +151,7 @@ pub struct PageServerConf { impl Default for PageServerConf { fn default() -> Self { Self { - id: ZNodeId(0), + id: NodeId(0), listen_pg_addr: String::new(), listen_http_addr: String::new(), auth_type: AuthType::Trust, @@ -163,7 +163,7 @@ impl Default for PageServerConf { #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] #[serde(default)] pub struct SafekeeperConf { - pub id: ZNodeId, + pub id: NodeId, pub pg_port: u16, pub http_port: u16, pub sync: bool, @@ -172,7 +172,7 @@ pub struct SafekeeperConf { impl Default for SafekeeperConf { fn default() -> Self { Self { - id: ZNodeId(0), + id: NodeId(0), pg_port: 0, http_port: 0, sync: true, diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index d5b6251209..303d6850df 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -18,7 +18,7 @@ use thiserror::Error; use utils::{ connstring::connection_address, http::error::HttpErrorBody, - zid::{ZNodeId, ZTenantId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTimelineId}, }; use crate::local_env::{LocalEnv, SafekeeperConf}; @@ -65,7 +65,7 @@ impl ResponseErrorMessageExt for Response { // #[derive(Debug)] pub struct SafekeeperNode { - pub id: ZNodeId, + pub id: NodeId, pub conf: SafekeeperConf, @@ -100,7 +100,7 @@ impl SafekeeperNode { .unwrap() } - pub fn datadir_path_by_id(env: &LocalEnv, sk_id: ZNodeId) -> PathBuf { + pub fn datadir_path_by_id(env: &LocalEnv, sk_id: NodeId) -> PathBuf { env.safekeeper_data_dir(format!("sk{}", sk_id).as_ref()) } @@ -286,7 +286,7 @@ impl SafekeeperNode { &self, tenant_id: ZTenantId, timeline_id: ZTimelineId, - peer_ids: Vec, + peer_ids: Vec, ) -> Result<()> { Ok(self .http_request( diff --git a/libs/etcd_broker/src/lib.rs b/libs/etcd_broker/src/lib.rs index 76181f9ba1..271f657f43 100644 --- a/libs/etcd_broker/src/lib.rs +++ b/libs/etcd_broker/src/lib.rs @@ -16,7 +16,7 @@ use tokio::{sync::mpsc, task::JoinHandle}; use tracing::*; use utils::{ lsn::Lsn, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId}, }; /// Default value to use for prefixing to all etcd keys with. @@ -25,7 +25,7 @@ pub const DEFAULT_NEON_BROKER_ETCD_PREFIX: &str = "neon"; #[derive(Debug, Deserialize, Serialize)] struct SafekeeperTimeline { - safekeeper_id: ZNodeId, + safekeeper_id: NodeId, info: SkTimelineInfo, } @@ -71,7 +71,7 @@ pub enum BrokerError { /// A way to control the data retrieval from a certain subscription. pub struct SkTimelineSubscription { safekeeper_timeline_updates: - mpsc::UnboundedReceiver>>, + mpsc::UnboundedReceiver>>, kind: SkTimelineSubscriptionKind, watcher_handle: JoinHandle>, watcher: Watcher, @@ -81,7 +81,7 @@ impl SkTimelineSubscription { /// Asynchronously polls for more data from the subscription, suspending the current future if there's no data sent yet. pub async fn fetch_data( &mut self, - ) -> Option>> { + ) -> Option>> { self.safekeeper_timeline_updates.recv().await } @@ -221,7 +221,7 @@ pub async fn subscribe_to_safekeeper_timeline_updates( break; } - let mut timeline_updates: HashMap> = HashMap::new(); + let mut timeline_updates: HashMap> = HashMap::new(); // Keep track that the timeline data updates from etcd arrive in the right order. // https://etcd.io/docs/v3.5/learning/api_guarantees/#isolation-level-and-consistency-of-replicas // > etcd does not ensure linearizability for watch operations. Users are expected to verify the revision of watch responses to ensure correct ordering. @@ -299,18 +299,18 @@ fn parse_etcd_key_value( parse_capture(&caps, 1).map_err(BrokerError::ParsingError)?, parse_capture(&caps, 2).map_err(BrokerError::ParsingError)?, ), - ZNodeId(parse_capture(&caps, 3).map_err(BrokerError::ParsingError)?), + NodeId(parse_capture(&caps, 3).map_err(BrokerError::ParsingError)?), ), SubscriptionKind::Tenant(tenant_id) => ( ZTenantTimelineId::new( tenant_id, parse_capture(&caps, 1).map_err(BrokerError::ParsingError)?, ), - ZNodeId(parse_capture(&caps, 2).map_err(BrokerError::ParsingError)?), + NodeId(parse_capture(&caps, 2).map_err(BrokerError::ParsingError)?), ), SubscriptionKind::Timeline(zttid) => ( zttid, - ZNodeId(parse_capture(&caps, 1).map_err(BrokerError::ParsingError)?), + NodeId(parse_capture(&caps, 1).map_err(BrokerError::ParsingError)?), ), }; diff --git a/libs/utils/src/zid.rs b/libs/utils/src/zid.rs index 44d81cda50..02f781c49a 100644 --- a/libs/utils/src/zid.rs +++ b/libs/utils/src/zid.rs @@ -226,9 +226,9 @@ impl fmt::Display for ZTenantTimelineId { // by the console. #[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash, Debug, Serialize, Deserialize)] #[serde(transparent)] -pub struct ZNodeId(pub u64); +pub struct NodeId(pub u64); -impl fmt::Display for ZNodeId { +impl fmt::Display for NodeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } diff --git a/neon_local/src/main.rs b/neon_local/src/main.rs index f04af9cfdd..8d39fe5d0d 100644 --- a/neon_local/src/main.rs +++ b/neon_local/src/main.rs @@ -22,14 +22,14 @@ use utils::{ lsn::Lsn, postgres_backend::AuthType, project_git_version, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, }; use pageserver::timelines::TimelineInfo; // Default id of a safekeeper node, if not specified on the command line. -const DEFAULT_SAFEKEEPER_ID: ZNodeId = ZNodeId(1); -const DEFAULT_PAGESERVER_ID: ZNodeId = ZNodeId(1); +const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1); +const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1); const DEFAULT_BRANCH_NAME: &str = "main"; project_git_version!(GIT_VERSION); @@ -860,7 +860,7 @@ fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul Ok(()) } -fn get_safekeeper(env: &local_env::LocalEnv, id: ZNodeId) -> Result { +fn get_safekeeper(env: &local_env::LocalEnv, id: NodeId) -> Result { if let Some(node) = env.safekeepers.iter().find(|node| node.id == id) { Ok(SafekeeperNode::from_env(env, node)) } else { @@ -876,7 +876,7 @@ fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul // All the commands take an optional safekeeper name argument let sk_id = if let Some(id_str) = sub_args.value_of("id") { - ZNodeId(id_str.parse().context("while parsing safekeeper id")?) + NodeId(id_str.parse().context("while parsing safekeeper id")?) } else { DEFAULT_SAFEKEEPER_ID }; diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index a9215c0701..6c045d77ae 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -16,7 +16,7 @@ use toml_edit::{Document, Item}; use url::Url; use utils::{ postgres_backend::AuthType, - zid::{ZNodeId, ZTenantId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTimelineId}, }; use crate::layered_repository::TIMELINES_SEGMENT_NAME; @@ -78,7 +78,7 @@ pub mod defaults { pub struct PageServerConf { // Identifier of that particular pageserver so e g safekeepers // can safely distinguish different pageservers - pub id: ZNodeId, + pub id: NodeId, /// Example (default): 127.0.0.1:64000 pub listen_pg_addr: String, @@ -180,7 +180,7 @@ struct PageServerConfigBuilder { auth_validation_public_key_path: BuilderValue>, remote_storage_config: BuilderValue>, - id: BuilderValue, + id: BuilderValue, profiling: BuilderValue, broker_etcd_prefix: BuilderValue, @@ -276,7 +276,7 @@ impl PageServerConfigBuilder { self.broker_etcd_prefix = BuilderValue::Set(broker_etcd_prefix) } - pub fn id(&mut self, node_id: ZNodeId) { + pub fn id(&mut self, node_id: NodeId) { self.id = BuilderValue::Set(node_id) } @@ -399,7 +399,7 @@ impl PageServerConf { "tenant_config" => { t_conf = Self::parse_toml_tenant_conf(item)?; } - "id" => builder.id(ZNodeId(parse_toml_u64(key, item)?)), + "id" => builder.id(NodeId(parse_toml_u64(key, item)?)), "profiling" => builder.profiling(parse_toml_from_str(key, item)?), "broker_etcd_prefix" => builder.broker_etcd_prefix(parse_toml_string(key, item)?), "broker_endpoints" => builder.broker_endpoints( @@ -550,7 +550,7 @@ impl PageServerConf { #[cfg(test)] pub fn dummy_conf(repo_dir: PathBuf) -> Self { PageServerConf { - id: ZNodeId(0), + id: NodeId(0), wait_lsn_timeout: Duration::from_secs(60), wal_redo_timeout: Duration::from_secs(60), page_cache_size: defaults::DEFAULT_PAGE_CACHE_SIZE, @@ -693,7 +693,7 @@ id = 10 assert_eq!( parsed_config, PageServerConf { - id: ZNodeId(10), + id: NodeId(10), listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(), listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(), wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?, @@ -737,7 +737,7 @@ id = 10 assert_eq!( parsed_config, PageServerConf { - id: ZNodeId(10), + id: NodeId(10), listen_pg_addr: "127.0.0.1:64000".to_string(), listen_http_addr: "127.0.0.1:9898".to_string(), wait_lsn_timeout: Duration::from_secs(111), diff --git a/pageserver/src/http/models.rs b/pageserver/src/http/models.rs index e9aaa72416..e00ccda2a1 100644 --- a/pageserver/src/http/models.rs +++ b/pageserver/src/http/models.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use utils::{ lsn::Lsn, - zid::{ZNodeId, ZTenantId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTimelineId}, }; #[serde_as] @@ -42,7 +42,7 @@ pub struct TenantCreateResponse(#[serde_as(as = "DisplayFromStr")] pub ZTenantId #[derive(Serialize)] pub struct StatusResponse { - pub id: ZNodeId, + pub id: NodeId, } impl TenantCreateRequest { diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index a5ffc013e2..290b7c738a 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -24,7 +24,7 @@ use safekeeper::{broker, callmemaybe}; use safekeeper::{http, s3_offload}; use utils::{ http::endpoint, logging, project_git_version, shutdown::exit_now, signals, tcp_listener, - zid::ZNodeId, + zid::NodeId, }; const LOCK_FILE_NAME: &str = "safekeeper.lock"; @@ -167,7 +167,7 @@ fn main() -> anyhow::Result<()> { let mut given_id = None; if let Some(given_id_str) = arg_matches.value_of("id") { - given_id = Some(ZNodeId( + given_id = Some(NodeId( given_id_str .parse() .context("failed to parse safekeeper id")?, @@ -192,7 +192,7 @@ fn main() -> anyhow::Result<()> { start_safekeeper(conf, given_id, arg_matches.is_present("init")) } -fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: bool) -> Result<()> { +fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: bool) -> Result<()> { let log_file = logging::init("safekeeper.log", conf.daemonize)?; info!("version: {GIT_VERSION}"); @@ -345,14 +345,14 @@ fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: b } /// Determine safekeeper id and set it in config. -fn set_id(conf: &mut SafeKeeperConf, given_id: Option) -> Result<()> { +fn set_id(conf: &mut SafeKeeperConf, given_id: Option) -> Result<()> { let id_file_path = conf.workdir.join(ID_FILE_NAME); - let my_id: ZNodeId; + let my_id: NodeId; // If ID exists, read it in; otherwise set one passed match fs::read(&id_file_path) { Ok(id_serialized) => { - my_id = ZNodeId( + my_id = NodeId( std::str::from_utf8(&id_serialized) .context("failed to parse safekeeper id")? .parse() diff --git a/safekeeper/src/broker.rs b/safekeeper/src/broker.rs index d7217be20a..59d282d378 100644 --- a/safekeeper/src/broker.rs +++ b/safekeeper/src/broker.rs @@ -12,7 +12,7 @@ use tokio::{runtime, time::sleep}; use tracing::*; use crate::{timeline::GlobalTimelines, SafeKeeperConf}; -use utils::zid::{ZNodeId, ZTenantTimelineId}; +use utils::zid::{NodeId, ZTenantTimelineId}; const RETRY_INTERVAL_MSEC: u64 = 1000; const PUSH_INTERVAL_MSEC: u64 = 1000; @@ -36,7 +36,7 @@ pub fn thread_main(conf: SafeKeeperConf) { fn timeline_safekeeper_path( broker_etcd_prefix: String, zttid: ZTenantTimelineId, - sk_id: ZNodeId, + sk_id: NodeId, ) -> String { format!( "{}/{sk_id}", diff --git a/safekeeper/src/http/models.rs b/safekeeper/src/http/models.rs index ca18e64096..77efc0cc21 100644 --- a/safekeeper/src/http/models.rs +++ b/safekeeper/src/http/models.rs @@ -1,9 +1,9 @@ use serde::{Deserialize, Serialize}; -use utils::zid::{ZNodeId, ZTenantId, ZTimelineId}; +use utils::zid::{NodeId, ZTenantId, ZTimelineId}; #[derive(Serialize, Deserialize)] pub struct TimelineCreateRequest { pub tenant_id: ZTenantId, pub timeline_id: ZTimelineId, - pub peer_ids: Vec, + pub peer_ids: Vec, } diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index 62fbd2ff2f..3f6ade970d 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -20,14 +20,14 @@ use utils::{ RequestExt, RouterBuilder, }, lsn::Lsn, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, }; use super::models::TimelineCreateRequest; #[derive(Debug, Serialize)] struct SafekeeperStatus { - id: ZNodeId, + id: NodeId, } /// Healthcheck handler. @@ -178,7 +178,7 @@ async fn record_safekeeper_info(mut request: Request) -> Result, pub recall_period: Duration, - pub my_id: ZNodeId, + pub my_id: NodeId, pub broker_endpoints: Vec, pub broker_etcd_prefix: String, pub s3_offload_enabled: bool, @@ -79,7 +79,7 @@ impl Default for SafeKeeperConf { listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(), ttl: None, recall_period: defaults::DEFAULT_RECALL_PERIOD, - my_id: ZNodeId(0), + my_id: NodeId(0), broker_endpoints: Vec::new(), broker_etcd_prefix: etcd_broker::DEFAULT_NEON_BROKER_ETCD_PREFIX.to_string(), s3_offload_enabled: true, diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index fff1c269b6..b8b969929d 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -26,7 +26,7 @@ use utils::{ bin_ser::LeSer, lsn::Lsn, pq_proto::{SystemId, ZenithFeedback}, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId, ZTimelineId}, }; pub const SK_MAGIC: u32 = 0xcafeceefu32; @@ -164,7 +164,7 @@ impl PeerInfo { // vector-based node id -> peer state map with very limited functionality we // need/ #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Peers(pub Vec<(ZNodeId, PeerInfo)>); +pub struct Peers(pub Vec<(NodeId, PeerInfo)>); /// Persistent information stored on safekeeper node /// On disk data is prefixed by magic and format version and followed by checksum. @@ -224,7 +224,7 @@ pub struct SafekeeperMemState { } impl SafeKeeperState { - pub fn new(zttid: &ZTenantTimelineId, peers: Vec) -> SafeKeeperState { + pub fn new(zttid: &ZTenantTimelineId, peers: Vec) -> SafeKeeperState { SafeKeeperState { tenant_id: zttid.tenant_id, timeline_id: zttid.timeline_id, @@ -277,7 +277,7 @@ pub struct ProposerGreeting { #[derive(Debug, Serialize)] pub struct AcceptorGreeting { term: u64, - node_id: ZNodeId, + node_id: NodeId, } /// Vote request sent from proposer to safekeepers @@ -531,7 +531,7 @@ pub struct SafeKeeper { pub wal_store: WAL, - node_id: ZNodeId, // safekeeper's node id + node_id: NodeId, // safekeeper's node id } impl SafeKeeper @@ -544,7 +544,7 @@ where ztli: ZTimelineId, state: CTRL, mut wal_store: WAL, - node_id: ZNodeId, + node_id: NodeId, ) -> Result> { if state.timeline_id != ZTimelineId::from([0u8; 16]) && ztli != state.timeline_id { bail!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.timeline_id); @@ -1013,7 +1013,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, ZNodeId(0)).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store, NodeId(0)).unwrap(); // check voting for 1 is ok let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 }); @@ -1028,7 +1028,7 @@ mod tests { let storage = InMemoryState { persisted_state: state, }; - sk = SafeKeeper::new(ztli, storage, sk.wal_store, ZNodeId(0)).unwrap(); + sk = SafeKeeper::new(ztli, storage, sk.wal_store, NodeId(0)).unwrap(); // and ensure voting second time for 1 is not ok vote_resp = sk.process_msg(&vote_request); @@ -1045,7 +1045,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); - let mut sk = SafeKeeper::new(ztli, storage, wal_store, ZNodeId(0)).unwrap(); + let mut sk = SafeKeeper::new(ztli, storage, wal_store, NodeId(0)).unwrap(); let mut ar_hdr = AppendRequestHeader { term: 1, diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 2bb7771aac..0953439bd8 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -21,7 +21,7 @@ use tracing::*; use utils::{ lsn::Lsn, pq_proto::ZenithFeedback, - zid::{ZNodeId, ZTenantId, ZTenantTimelineId}, + zid::{NodeId, ZTenantId, ZTenantTimelineId}, }; use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey}; @@ -99,7 +99,7 @@ impl SharedState { fn create( conf: &SafeKeeperConf, zttid: &ZTenantTimelineId, - peer_ids: Vec, + peer_ids: Vec, ) -> Result { let state = SafeKeeperState::new(zttid, peer_ids); let control_store = control_file::FileStorage::create_new(zttid, conf, state)?; @@ -448,7 +448,7 @@ impl Timeline { } /// Update timeline state with peer safekeeper data. - pub fn record_safekeeper_info(&self, sk_info: &SkTimelineInfo, _sk_id: ZNodeId) -> Result<()> { + pub fn record_safekeeper_info(&self, sk_info: &SkTimelineInfo, _sk_id: NodeId) -> Result<()> { let mut shared_state = self.mutex.lock().unwrap(); shared_state.sk.record_safekeeper_info(sk_info)?; self.notify_wal_senders(&mut shared_state); @@ -551,7 +551,7 @@ impl GlobalTimelines { mut state: MutexGuard, conf: &SafeKeeperConf, zttid: ZTenantTimelineId, - peer_ids: Vec, + peer_ids: Vec, ) -> Result> { match state.timelines.get(&zttid) { Some(_) => bail!("timeline {} already exists", zttid), @@ -576,7 +576,7 @@ impl GlobalTimelines { pub fn create( conf: &SafeKeeperConf, zttid: ZTenantTimelineId, - peer_ids: Vec, + peer_ids: Vec, ) -> Result> { let state = TIMELINES_STATE.lock().unwrap(); GlobalTimelines::create_internal(state, conf, zttid, peer_ids) From 887b0e14d9285bdf64eab3e44eb7000cdb55b44b Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Wed, 25 May 2022 21:07:49 +0300 Subject: [PATCH 12/27] Run basic checks on PRs and pushes to main only --- .github/workflows/testing.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 79b2ba05d0..281c893403 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -1,8 +1,10 @@ name: Build and Test on: - pull_request: push: + branches: + - main + pull_request: jobs: regression-check: From 06f5e017a1b0d380e0e082e906cd52b7a885b100 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Wed, 25 May 2022 21:12:17 +0300 Subject: [PATCH 13/27] Move rustfmt check to GH Action --- .circleci/config.yml | 10 ---------- .github/workflows/testing.yml | 6 +++++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5346e35c01..624d367053 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,15 +11,6 @@ executors: - image: zimg/rust:1.58 jobs: - check-codestyle-rust: - executor: neon-xlarge-executor - steps: - - checkout - - run: - name: rustfmt - when: always - command: cargo fmt --all -- --check - # A job to build postgres build-postgres: executor: neon-xlarge-executor @@ -740,7 +731,6 @@ jobs: workflows: build_and_test: jobs: - - check-codestyle-rust - check-codestyle-python - build-postgres: name: build-postgres-<< matrix.build_type >> diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 281c893403..1ce1b64a49 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -25,13 +25,17 @@ jobs: submodules: true fetch-depth: 2 - - name: install rust toolchain ${{ matrix.rust_toolchain }} + - name: Install rust toolchain ${{ matrix.rust_toolchain }} uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: ${{ matrix.rust_toolchain }} + components: rustfmt, clippy override: true + - name: Check formatting + run: cargo fmt --all -- --check + - name: Install Ubuntu postgres dependencies if: matrix.os == 'ubuntu-latest' run: | From 5a5737278e637245d0b7b89a20b47040d2572a0e Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Wed, 25 May 2022 23:10:44 +0300 Subject: [PATCH 14/27] add simple metrics for remote storage operations track number of operations and number of their failures --- Cargo.lock | 2 + libs/remote_storage/Cargo.toml | 11 ++- libs/remote_storage/src/s3_bucket.rs | 109 +++++++++++++++++++++++++-- 3 files changed, 113 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6acad6dac8..840953f645 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2394,6 +2394,8 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "metrics", + "once_cell", "rusoto_core", "rusoto_s3", "serde", diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 291f6e50ac..5c62e28fda 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -5,14 +5,17 @@ edition = "2021" [dependencies] anyhow = { version = "1.0", features = ["backtrace"] } -tokio = { version = "1.17", features = ["sync", "macros", "fs", "io-util"] } -tokio-util = { version = "0.7", features = ["io"] } -tracing = "0.1.27" +async-trait = "0.1" + +metrics = { version = "0.1", path = "../metrics" } +once_cell = "1.8.0" rusoto_core = "0.48" rusoto_s3 = "0.48" serde = { version = "1.0", features = ["derive"] } serde_json = "1" -async-trait = "0.1" +tokio = { version = "1.17", features = ["sync", "macros", "fs", "io-util"] } +tokio-util = { version = "0.7", features = ["io"] } +tracing = "0.1.27" workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index 01aaf7ca7e..80d6966494 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -23,6 +23,71 @@ use crate::{strip_path_prefix, RemoteStorage, S3Config}; use super::StorageMetadata; +pub(super) mod metrics { + use metrics::{register_int_counter_vec, IntCounterVec}; + use once_cell::sync::Lazy; + + static S3_REQUESTS_COUNT: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "remote_storage_s3_requests_count", + "Number of s3 requests of particular type", + &["request_type"], + ) + .expect("failed to define a metric") + }); + + static S3_REQUESTS_FAIL_COUNT: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "remote_storage_s3_failures_count", + "Number of failed s3 requests of particular type", + &["request_type"], + ) + .expect("failed to define a metric") + }); + + pub fn inc_get_object() { + S3_REQUESTS_COUNT.with_label_values(&["get_object"]).inc(); + } + + pub fn inc_get_object_fail() { + S3_REQUESTS_FAIL_COUNT + .with_label_values(&["get_object"]) + .inc(); + } + + pub fn inc_put_object() { + S3_REQUESTS_COUNT.with_label_values(&["put_object"]).inc(); + } + + pub fn inc_put_object_fail() { + S3_REQUESTS_FAIL_COUNT + .with_label_values(&["put_object"]) + .inc(); + } + + pub fn inc_delete_object() { + S3_REQUESTS_COUNT + .with_label_values(&["delete_object"]) + .inc(); + } + + pub fn inc_delete_object_fail() { + S3_REQUESTS_FAIL_COUNT + .with_label_values(&["delete_object"]) + .inc(); + } + + pub fn inc_list_objects() { + S3_REQUESTS_COUNT.with_label_values(&["list_objects"]).inc(); + } + + pub fn inc_list_objects_fail() { + S3_REQUESTS_FAIL_COUNT + .with_label_values(&["list_objects"]) + .inc(); + } +} + const S3_PREFIX_SEPARATOR: char = '/'; #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] @@ -152,6 +217,9 @@ impl RemoteStorage for S3Bucket { .acquire() .await .context("Concurrency limiter semaphore got closed during S3 list")?; + + metrics::inc_list_objects(); + let fetch_response = self .client .list_objects_v2(ListObjectsV2Request { @@ -160,7 +228,11 @@ impl RemoteStorage for S3Bucket { continuation_token, ..ListObjectsV2Request::default() }) - .await?; + .await + .map_err(|e| { + metrics::inc_list_objects_fail(); + e + })?; document_keys.extend( fetch_response .contents @@ -190,6 +262,8 @@ impl RemoteStorage for S3Bucket { .acquire() .await .context("Concurrency limiter semaphore got closed during S3 upload")?; + + metrics::inc_put_object(); self.client .put_object(PutObjectRequest { body: Some(StreamingBody::new_with_size( @@ -201,7 +275,11 @@ impl RemoteStorage for S3Bucket { metadata: metadata.map(|m| m.0), ..PutObjectRequest::default() }) - .await?; + .await + .map_err(|e| { + metrics::inc_put_object_fail(); + e + })?; Ok(()) } @@ -215,6 +293,9 @@ impl RemoteStorage for S3Bucket { .acquire() .await .context("Concurrency limiter semaphore got closed during S3 download")?; + + metrics::inc_get_object(); + let object_output = self .client .get_object(GetObjectRequest { @@ -222,7 +303,11 @@ impl RemoteStorage for S3Bucket { key: from.key().to_owned(), ..GetObjectRequest::default() }) - .await?; + .await + .map_err(|e| { + metrics::inc_get_object_fail(); + e + })?; if let Some(body) = object_output.body { let mut from = io::BufReader::new(body.into_async_read()); @@ -251,6 +336,9 @@ impl RemoteStorage for S3Bucket { .acquire() .await .context("Concurrency limiter semaphore got closed during S3 range download")?; + + metrics::inc_get_object(); + let object_output = self .client .get_object(GetObjectRequest { @@ -259,7 +347,11 @@ impl RemoteStorage for S3Bucket { range, ..GetObjectRequest::default() }) - .await?; + .await + .map_err(|e| { + metrics::inc_get_object_fail(); + e + })?; if let Some(body) = object_output.body { let mut from = io::BufReader::new(body.into_async_read()); @@ -275,13 +367,20 @@ impl RemoteStorage for S3Bucket { .acquire() .await .context("Concurrency limiter semaphore got closed during S3 delete")?; + + metrics::inc_delete_object(); + self.client .delete_object(DeleteObjectRequest { bucket: self.bucket_name.clone(), key: path.key().to_owned(), ..DeleteObjectRequest::default() }) - .await?; + .await + .map_err(|e| { + metrics::inc_delete_object_fail(); + e + })?; Ok(()) } } From 38f2d165b778834d927ed6c549c3285ecfbbe576 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Thu, 26 May 2022 12:06:05 +0300 Subject: [PATCH 15/27] allow TLS 1.2 in proxy to be compatible with older client libraries --- proxy/src/config.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 077a07beb9..6f1b56bfe4 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -61,7 +61,8 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result Date: Thu, 19 May 2022 14:27:28 +0300 Subject: [PATCH 16/27] Initialize last_freeze_at with disk consistent LSN to avoid creation of small L0 delta layer on startup refer #1736 --- pageserver/src/layered_repository.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index a83907430e..d10c795214 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -1230,7 +1230,7 @@ impl LayeredTimeline { }), disk_consistent_lsn: AtomicLsn::new(metadata.disk_consistent_lsn().0), - last_freeze_at: AtomicLsn::new(0), + last_freeze_at: AtomicLsn::new(metadata.disk_consistent_lsn().0), ancestor_timeline: ancestor, ancestor_lsn: metadata.ancestor_lsn(), From 72a7220dc8c7a247ea411f3e381c8710f99617b7 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Thu, 26 May 2022 16:48:32 +0300 Subject: [PATCH 17/27] Tidy up some log messages * turn println into an info with proper message * rename new_local_timeline to load_local_timeline because it does not create new timeline, it registers timeline that exists on disk in pageserver in-memory structures --- pageserver/src/tenant_mgr.rs | 10 +++++----- pageserver/src/timelines.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pageserver/src/tenant_mgr.rs b/pageserver/src/tenant_mgr.rs index bbe66d7f80..bba67394c3 100644 --- a/pageserver/src/tenant_mgr.rs +++ b/pageserver/src/tenant_mgr.rs @@ -327,8 +327,8 @@ pub fn get_local_timeline_with_load( return Ok(Arc::clone(page_tline)); } - let page_tline = new_local_timeline(&tenant.repo, timeline_id) - .with_context(|| format!("Failed to create new local timeline for tenant {tenant_id}"))?; + let page_tline = load_local_timeline(&tenant.repo, timeline_id) + .with_context(|| format!("Failed to load local timeline for tenant {tenant_id}"))?; tenant .local_timelines .insert(timeline_id, Arc::clone(&page_tline)); @@ -365,7 +365,7 @@ pub fn detach_timeline( Ok(()) } -fn new_local_timeline( +fn load_local_timeline( repo: &RepositoryImpl, timeline_id: ZTimelineId, ) -> anyhow::Result>> { @@ -458,8 +458,8 @@ fn apply_timeline_remote_sync_status_updates( bail!("Local timeline {timeline_id} already registered") } Entry::Vacant(v) => { - v.insert(new_local_timeline(repo, timeline_id).with_context(|| { - format!("Failed to register new local timeline for tenant {tenant_id}") + v.insert(load_local_timeline(repo, timeline_id).with_context(|| { + format!("Failed to register add local timeline for tenant {tenant_id}") })?); } }, diff --git a/pageserver/src/timelines.rs b/pageserver/src/timelines.rs index eadf5bf4e0..408eca6501 100644 --- a/pageserver/src/timelines.rs +++ b/pageserver/src/timelines.rs @@ -302,8 +302,8 @@ fn bootstrap_timeline( import_datadir::import_timeline_from_postgres_datadir(&pgdata_path, &mut page_tline, lsn)?; page_tline.tline.checkpoint(CheckpointConfig::Forced)?; - println!( - "created initial timeline {} timeline.lsn {}", + info!( + "created root timeline {} timeline.lsn {}", tli, page_tline.tline.get_last_record_lsn() ); From 7d565aa4b93836127de209eca5ceb1a98167b4f7 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 26 May 2022 12:21:15 -0400 Subject: [PATCH 18/27] Reduce the logging level when PG client disconnected to `INFO` (#1713) Fixes #1683. --- pageserver/src/page_service.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index f54cd550b3..1c07b63072 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -305,7 +305,29 @@ fn page_service_conn_main( let mut conn_handler = PageServerHandler::new(conf, auth); let pgbackend = PostgresBackend::new(socket, auth_type, None, true)?; - pgbackend.run(&mut conn_handler) + match pgbackend.run(&mut conn_handler) { + Ok(()) => { + // we've been requested to shut down + Ok(()) + } + Err(err) => { + let root_cause_io_err_kind = err + .root_cause() + .downcast_ref::() + .map(|e| e.kind()); + + // `ConnectionReset` error happens when the Postgres client closes the connection. + // As this disconnection happens quite often and is expected, + // we decided to downgrade the logging level to `INFO`. + // See: https://github.com/neondatabase/neon/issues/1683. + if root_cause_io_err_kind == Some(io::ErrorKind::ConnectionReset) { + info!("Postgres client disconnected"); + Ok(()) + } else { + Err(err) + } + } + } } #[derive(Debug)] From 1d71949c51f06cd0eaf313f0ac595af3209ef57a Mon Sep 17 00:00:00 2001 From: bojanserafimov Date: Thu, 26 May 2022 14:59:03 -0400 Subject: [PATCH 19/27] Change proxy welcome message (#1808) Remove zenith sun and outdated instructions around .pgpass --- proxy/src/auth_backend/link.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/proxy/src/auth_backend/link.rs b/proxy/src/auth_backend/link.rs index 9bdb9e21c4..8e5fcb32a9 100644 --- a/proxy/src/auth_backend/link.rs +++ b/proxy/src/auth_backend/link.rs @@ -5,12 +5,9 @@ use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; fn hello_message(redirect_uri: &str, session_id: &str) -> String { format!( concat![ - "☀️ Welcome to Neon!\n", - "To proceed with database creation, open the following link:\n\n", + "Welcome to Neon!\n", + "Authenticate by visiting:\n", " {redirect_uri}{session_id}\n\n", - "It needs to be done once and we will send you '.pgpass' file,\n", - "which will allow you to access or create ", - "databases without opening your web browser." ], redirect_uri = redirect_uri, session_id = session_id, From 0e1bd57c533165dbe4bead8fa23baefa09c97b82 Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Wed, 27 Apr 2022 00:24:59 -0700 Subject: [PATCH 20/27] Add WAL offloading to s3 on safekeepers. Separate task is launched for each timeline and stopped when timeline doesn't need offloading. Decision who offloads is done through etcd leader election; currently there is no pre condition for participating, that's a TODO. neon_local and tests infrastructure for remote storage in safekeepers added, along with the test itself. ref #1009 Co-authored-by: Anton Shyrabokau --- .circleci/ansible/production.hosts | 1 - .circleci/ansible/staging.hosts | 1 - .circleci/ansible/systemd/safekeeper.service | 2 +- Cargo.lock | 9 +- control_plane/src/lib.rs | 9 + control_plane/src/local_env.rs | 5 + control_plane/src/safekeeper.rs | 10 +- control_plane/src/storage.rs | 11 +- libs/etcd_broker/src/lib.rs | 4 +- libs/remote_storage/Cargo.toml | 2 +- libs/remote_storage/src/lib.rs | 88 +++- libs/utils/src/lsn.rs | 9 + pageserver/src/config.rs | 87 +--- safekeeper/Cargo.toml | 4 + safekeeper/src/bin/safekeeper.rs | 72 +-- safekeeper/src/broker.rs | 129 +++++- safekeeper/src/control_file_upgrade.rs | 8 +- safekeeper/src/http/routes.rs | 18 +- safekeeper/src/lib.rs | 15 +- safekeeper/src/receive_wal.rs | 22 +- safekeeper/src/remove_wal.rs | 2 +- safekeeper/src/s3_offload.rs | 107 ----- safekeeper/src/safekeeper.rs | 69 ++- safekeeper/src/send_wal.rs | 2 +- safekeeper/src/timeline.rs | 307 +++++++++---- safekeeper/src/wal_backup.rs | 418 ++++++++++++++++++ test_runner/batch_others/test_wal_acceptor.py | 54 ++- test_runner/fixtures/zenith_fixtures.py | 110 +++-- 28 files changed, 1146 insertions(+), 429 deletions(-) delete mode 100644 safekeeper/src/s3_offload.rs create mode 100644 safekeeper/src/wal_backup.rs diff --git a/.circleci/ansible/production.hosts b/.circleci/ansible/production.hosts index 6cefd724d8..03c6cf57e0 100644 --- a/.circleci/ansible/production.hosts +++ b/.circleci/ansible/production.hosts @@ -16,4 +16,3 @@ console_mgmt_base_url = http://console-release.local bucket_name = zenith-storage-oregon bucket_region = us-west-2 etcd_endpoints = etcd-release.local:2379 -safekeeper_enable_s3_offload = false diff --git a/.circleci/ansible/staging.hosts b/.circleci/ansible/staging.hosts index d99ffa6dac..cf5b98eaa1 100644 --- a/.circleci/ansible/staging.hosts +++ b/.circleci/ansible/staging.hosts @@ -17,4 +17,3 @@ console_mgmt_base_url = http://console-staging.local bucket_name = zenith-staging-storage-us-east-1 bucket_region = us-east-1 etcd_endpoints = etcd-staging.local:2379 -safekeeper_enable_s3_offload = false diff --git a/.circleci/ansible/systemd/safekeeper.service b/.circleci/ansible/systemd/safekeeper.service index 55088db859..a6b443c3e7 100644 --- a/.circleci/ansible/systemd/safekeeper.service +++ b/.circleci/ansible/systemd/safekeeper.service @@ -6,7 +6,7 @@ After=network.target auditd.service Type=simple User=safekeeper Environment=RUST_BACKTRACE=1 ZENITH_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/lib -ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}.local:6500 --listen-http {{ inventory_hostname }}.local:7676 -p {{ first_pageserver }}:6400 -D /storage/safekeeper/data --broker-endpoints={{ etcd_endpoints }} --enable-s3-offload={{ safekeeper_enable_s3_offload }} +ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}.local:6500 --listen-http {{ inventory_hostname }}.local:7676 -p {{ first_pageserver }}:6400 -D /storage/safekeeper/data --broker-endpoints={{ etcd_endpoints }} --remote_storage='{bucket_name={{bucket_name}}, bucket_region={{bucket_region}}, prefix_in_bucket=wal}' ExecReload=/bin/kill -HUP $MAINPID KillMode=mixed KillSignal=SIGINT diff --git a/Cargo.lock b/Cargo.lock index 840953f645..e39375c221 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1722,9 +1722,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da32515d9f6e6e489d7bc9d84c71b060db7247dc035bbe44eac88cf87486d8d5" +checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" [[package]] name = "oorandom" @@ -2403,6 +2403,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util 0.7.0", + "toml_edit", "tracing", "workspace_hack", ] @@ -2654,6 +2655,7 @@ name = "safekeeper" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "byteorder", "bytes", "clap 3.0.14", @@ -2662,12 +2664,14 @@ dependencies = [ "daemonize", "etcd_broker", "fs2", + "futures", "git-version", "hex", "humantime", "hyper", "lazy_static", "metrics", + "once_cell", "postgres", "postgres-protocol", "postgres_ffi", @@ -2681,6 +2685,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-util 0.7.0", + "toml_edit", "tracing", "url", "utils", diff --git a/control_plane/src/lib.rs b/control_plane/src/lib.rs index c3469c3350..4dfca588ad 100644 --- a/control_plane/src/lib.rs +++ b/control_plane/src/lib.rs @@ -49,3 +49,12 @@ fn fill_rust_env_vars(cmd: &mut Command) -> &mut Command { cmd } } + +fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command { + for env_key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] { + if let Ok(value) = std::env::var(env_key) { + cmd = cmd.env(env_key, value); + } + } + cmd +} diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 015b33f591..2623f65242 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -167,6 +167,8 @@ pub struct SafekeeperConf { pub pg_port: u16, pub http_port: u16, pub sync: bool, + pub remote_storage: Option, + pub backup_threads: Option, } impl Default for SafekeeperConf { @@ -176,6 +178,8 @@ impl Default for SafekeeperConf { pg_port: 0, http_port: 0, sync: true, + remote_storage: None, + backup_threads: None, } } } @@ -377,6 +381,7 @@ impl LocalEnv { base_path != Path::new(""), "repository base path is missing" ); + ensure!( !base_path.exists(), "directory '{}' already exists. Perhaps already initialized?", diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index 303d6850df..972b6d48ae 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -23,7 +23,7 @@ use utils::{ use crate::local_env::{LocalEnv, SafekeeperConf}; use crate::storage::PageServerNode; -use crate::{fill_rust_env_vars, read_pidfile}; +use crate::{fill_aws_secrets_vars, fill_rust_env_vars, read_pidfile}; #[derive(Error, Debug)] pub enum SafekeeperHttpError { @@ -143,6 +143,14 @@ impl SafekeeperNode { if let Some(prefix) = self.env.etcd_broker.broker_etcd_prefix.as_deref() { cmd.args(&["--broker-etcd-prefix", prefix]); } + if let Some(threads) = self.conf.backup_threads { + cmd.args(&["--backup-threads", threads.to_string().as_ref()]); + } + if let Some(ref remote_storage) = self.conf.remote_storage { + cmd.args(&["--remote-storage", remote_storage]); + } + + fill_aws_secrets_vars(&mut cmd); if !cmd.status()?.success() { bail!( diff --git a/control_plane/src/storage.rs b/control_plane/src/storage.rs index 355c7c250d..24cdbce8f3 100644 --- a/control_plane/src/storage.rs +++ b/control_plane/src/storage.rs @@ -25,7 +25,7 @@ use utils::{ }; use crate::local_env::LocalEnv; -use crate::{fill_rust_env_vars, read_pidfile}; +use crate::{fill_aws_secrets_vars, fill_rust_env_vars, read_pidfile}; use pageserver::tenant_mgr::TenantInfo; #[derive(Error, Debug)] @@ -493,12 +493,3 @@ impl PageServerNode { Ok(timeline_info_response) } } - -fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command { - for env_key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] { - if let Ok(value) = std::env::var(env_key) { - cmd = cmd.env(env_key, value); - } - } - cmd -} diff --git a/libs/etcd_broker/src/lib.rs b/libs/etcd_broker/src/lib.rs index 271f657f43..7fe142502b 100644 --- a/libs/etcd_broker/src/lib.rs +++ b/libs/etcd_broker/src/lib.rs @@ -43,10 +43,10 @@ pub struct SkTimelineInfo { #[serde_as(as = "Option")] #[serde(default)] pub commit_lsn: Option, - /// LSN up to which safekeeper offloaded WAL to s3. + /// LSN up to which safekeeper has backed WAL. #[serde_as(as = "Option")] #[serde(default)] - pub s3_wal_lsn: Option, + pub backup_lsn: Option, /// LSN of last checkpoint uploaded by pageserver. #[serde_as(as = "Option")] #[serde(default)] diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 5c62e28fda..b11b3cf371 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" [dependencies] anyhow = { version = "1.0", features = ["backtrace"] } async-trait = "0.1" - metrics = { version = "0.1", path = "../metrics" } once_cell = "1.8.0" rusoto_core = "0.48" @@ -15,6 +14,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1" tokio = { version = "1.17", features = ["sync", "macros", "fs", "io-util"] } tokio-util = { version = "0.7", features = ["io"] } +toml_edit = { version = "0.13", features = ["easy"] } tracing = "0.1.27" workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 8092e4fc49..0889cb720c 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -16,8 +16,10 @@ use std::{ path::{Path, PathBuf}, }; -use anyhow::Context; +use anyhow::{bail, Context}; + use tokio::io; +use toml_edit::Item; use tracing::info; pub use self::{ @@ -203,6 +205,90 @@ pub fn path_with_suffix_extension(original_path: impl AsRef, suffix: &str) .with_extension(new_extension.as_ref()) } +impl RemoteStorageConfig { + pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result { + let local_path = toml.get("local_path"); + let bucket_name = toml.get("bucket_name"); + let bucket_region = toml.get("bucket_region"); + + let max_concurrent_syncs = NonZeroUsize::new( + parse_optional_integer("max_concurrent_syncs", toml)? + .unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS), + ) + .context("Failed to parse 'max_concurrent_syncs' as a positive integer")?; + + let max_sync_errors = NonZeroU32::new( + parse_optional_integer("max_sync_errors", toml)? + .unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS), + ) + .context("Failed to parse 'max_sync_errors' as a positive integer")?; + + let concurrency_limit = NonZeroUsize::new( + parse_optional_integer("concurrency_limit", toml)? + .unwrap_or(DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT), + ) + .context("Failed to parse 'concurrency_limit' as a positive integer")?; + + let storage = match (local_path, bucket_name, bucket_region) { + (None, None, None) => bail!("no 'local_path' nor 'bucket_name' option"), + (_, Some(_), None) => { + bail!("'bucket_region' option is mandatory if 'bucket_name' is given ") + } + (_, None, Some(_)) => { + bail!("'bucket_name' option is mandatory if 'bucket_region' is given ") + } + (None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config { + bucket_name: parse_toml_string("bucket_name", bucket_name)?, + bucket_region: parse_toml_string("bucket_region", bucket_region)?, + prefix_in_bucket: toml + .get("prefix_in_bucket") + .map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket)) + .transpose()?, + endpoint: toml + .get("endpoint") + .map(|endpoint| parse_toml_string("endpoint", endpoint)) + .transpose()?, + concurrency_limit, + }), + (Some(local_path), None, None) => RemoteStorageKind::LocalFs(PathBuf::from( + parse_toml_string("local_path", local_path)?, + )), + (Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"), + }; + + Ok(RemoteStorageConfig { + max_concurrent_syncs, + max_sync_errors, + storage, + }) + } +} + +// Helper functions to parse a toml Item +fn parse_optional_integer(name: &str, item: &toml_edit::Item) -> anyhow::Result> +where + I: TryFrom, + E: std::error::Error + Send + Sync + 'static, +{ + let toml_integer = match item.get(name) { + Some(item) => item + .as_integer() + .with_context(|| format!("configure option {name} is not an integer"))?, + None => return Ok(None), + }; + + I::try_from(toml_integer) + .map(Some) + .with_context(|| format!("configure option {name} is too large")) +} + +fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result { + let s = item + .as_str() + .with_context(|| format!("configure option {name} is not a string"))?; + Ok(s.to_string()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/libs/utils/src/lsn.rs b/libs/utils/src/lsn.rs index c09d8c67ce..3dab2a625c 100644 --- a/libs/utils/src/lsn.rs +++ b/libs/utils/src/lsn.rs @@ -26,6 +26,9 @@ impl Lsn { /// Maximum possible value for an LSN pub const MAX: Lsn = Lsn(u64::MAX); + /// Invalid value for InvalidXLogRecPtr, as defined in xlogdefs.h + pub const INVALID: Lsn = Lsn(0); + /// Subtract a number, returning None on overflow. pub fn checked_sub>(self, other: T) -> Option { let other: u64 = other.into(); @@ -103,6 +106,12 @@ impl Lsn { pub fn is_aligned(&self) -> bool { *self == self.align() } + + /// Return if the LSN is valid + /// mimics postgres XLogRecPtrIsInvalid macro + pub fn is_valid(self) -> bool { + self != Lsn::INVALID + } } impl From for Lsn { diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 6c045d77ae..dc9d7161a2 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -5,9 +5,9 @@ //! See also `settings.md` for better description on every parameter. use anyhow::{anyhow, bail, ensure, Context, Result}; -use remote_storage::{RemoteStorageConfig, RemoteStorageKind, S3Config}; +use remote_storage::RemoteStorageConfig; use std::env; -use std::num::{NonZeroU32, NonZeroUsize}; + use std::path::{Path, PathBuf}; use std::str::FromStr; use std::time::Duration; @@ -394,7 +394,7 @@ impl PageServerConf { )), "auth_type" => builder.auth_type(parse_toml_from_str(key, item)?), "remote_storage" => { - builder.remote_storage_config(Some(Self::parse_remote_storage_config(item)?)) + builder.remote_storage_config(Some(RemoteStorageConfig::from_toml(item)?)) } "tenant_config" => { t_conf = Self::parse_toml_tenant_conf(item)?; @@ -484,64 +484,6 @@ impl PageServerConf { Ok(t_conf) } - /// subroutine of parse_config(), to parse the `[remote_storage]` table. - fn parse_remote_storage_config(toml: &toml_edit::Item) -> anyhow::Result { - let local_path = toml.get("local_path"); - let bucket_name = toml.get("bucket_name"); - let bucket_region = toml.get("bucket_region"); - - let max_concurrent_syncs = NonZeroUsize::new( - parse_optional_integer("max_concurrent_syncs", toml)? - .unwrap_or(remote_storage::DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS), - ) - .context("Failed to parse 'max_concurrent_syncs' as a positive integer")?; - - let max_sync_errors = NonZeroU32::new( - parse_optional_integer("max_sync_errors", toml)? - .unwrap_or(remote_storage::DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS), - ) - .context("Failed to parse 'max_sync_errors' as a positive integer")?; - - let concurrency_limit = NonZeroUsize::new( - parse_optional_integer("concurrency_limit", toml)? - .unwrap_or(remote_storage::DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT), - ) - .context("Failed to parse 'concurrency_limit' as a positive integer")?; - - let storage = match (local_path, bucket_name, bucket_region) { - (None, None, None) => bail!("no 'local_path' nor 'bucket_name' option"), - (_, Some(_), None) => { - bail!("'bucket_region' option is mandatory if 'bucket_name' is given ") - } - (_, None, Some(_)) => { - bail!("'bucket_name' option is mandatory if 'bucket_region' is given ") - } - (None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config { - bucket_name: parse_toml_string("bucket_name", bucket_name)?, - bucket_region: parse_toml_string("bucket_region", bucket_region)?, - prefix_in_bucket: toml - .get("prefix_in_bucket") - .map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket)) - .transpose()?, - endpoint: toml - .get("endpoint") - .map(|endpoint| parse_toml_string("endpoint", endpoint)) - .transpose()?, - concurrency_limit, - }), - (Some(local_path), None, None) => RemoteStorageKind::LocalFs(PathBuf::from( - parse_toml_string("local_path", local_path)?, - )), - (Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"), - }; - - Ok(RemoteStorageConfig { - max_concurrent_syncs, - max_sync_errors, - storage, - }) - } - #[cfg(test)] pub fn test_repo_dir(test_name: &str) -> PathBuf { PathBuf::from(format!("../tmp_check/test_{test_name}")) @@ -592,23 +534,6 @@ fn parse_toml_u64(name: &str, item: &Item) -> Result { Ok(i as u64) } -fn parse_optional_integer(name: &str, item: &toml_edit::Item) -> anyhow::Result> -where - I: TryFrom, - E: std::error::Error + Send + Sync + 'static, -{ - let toml_integer = match item.get(name) { - Some(item) => item - .as_integer() - .with_context(|| format!("configure option {name} is not an integer"))?, - None => return Ok(None), - }; - - I::try_from(toml_integer) - .map(Some) - .with_context(|| format!("configure option {name} is too large")) -} - fn parse_toml_duration(name: &str, item: &Item) -> Result { let s = item .as_str() @@ -651,8 +576,12 @@ fn parse_toml_array(name: &str, item: &Item) -> anyhow::Result> { #[cfg(test)] mod tests { - use std::fs; + use std::{ + fs, + num::{NonZeroU32, NonZeroUsize}, + }; + use remote_storage::{RemoteStorageKind, S3Config}; use tempfile::{tempdir, TempDir}; use super::*; diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 417cf58cd5..373108c61b 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -30,6 +30,10 @@ const_format = "0.2.21" tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" } tokio-util = { version = "0.7", features = ["io"] } git-version = "0.3.5" +async-trait = "0.1" +once_cell = "1.10.0" +futures = "0.3.13" +toml_edit = { version = "0.13", features = ["easy"] } postgres_ffi = { path = "../libs/postgres_ffi" } metrics = { path = "../libs/metrics" } diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 290b7c738a..a7628482d9 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -6,22 +6,27 @@ use clap::{App, Arg}; use const_format::formatcp; use daemonize::Daemonize; use fs2::FileExt; +use remote_storage::RemoteStorageConfig; use std::fs::{self, File}; use std::io::{ErrorKind, Write}; use std::path::{Path, PathBuf}; use std::thread; use tokio::sync::mpsc; +use toml_edit::Document; use tracing::*; use url::{ParseError, Url}; use safekeeper::control_file::{self}; -use safekeeper::defaults::{DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_PG_LISTEN_ADDR}; +use safekeeper::defaults::{ + DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_PG_LISTEN_ADDR, DEFAULT_WAL_BACKUP_RUNTIME_THREADS, +}; +use safekeeper::http; use safekeeper::remove_wal; use safekeeper::timeline::GlobalTimelines; +use safekeeper::wal_backup; use safekeeper::wal_service; use safekeeper::SafeKeeperConf; use safekeeper::{broker, callmemaybe}; -use safekeeper::{http, s3_offload}; use utils::{ http::endpoint, logging, project_git_version, shutdown::exit_now, signals, tcp_listener, zid::NodeId, @@ -71,12 +76,6 @@ fn main() -> anyhow::Result<()> { .long("pageserver") .takes_value(true), ) - .arg( - Arg::new("ttl") - .long("ttl") - .takes_value(true) - .help("interval for keeping WAL at safekeeper node, after which them will be uploaded to S3 and removed locally"), - ) .arg( Arg::new("recall") .long("recall") @@ -118,12 +117,20 @@ fn main() -> anyhow::Result<()> { .help("a prefix to always use when polling/pusing data in etcd from this safekeeper"), ) .arg( - Arg::new("enable-s3-offload") - .long("enable-s3-offload") + Arg::new("wal-backup-threads").long("backup-threads").takes_value(true).help(formatcp!("number of threads for wal backup (default {DEFAULT_WAL_BACKUP_RUNTIME_THREADS}")), + ).arg( + Arg::new("remote-storage") + .long("remote-storage") + .takes_value(true) + .help("Remote storage configuration for WAL backup (offloading to s3) as TOML inline table, e.g. {\"max_concurrent_syncs\" = 17, \"max_sync_errors\": 13, \"bucket_name\": \"\", \"bucket_region\":\"\", \"concurrency_limit\": 119}.\nSafekeeper offloads WAL to [prefix_in_bucket/]//, mirroring structure on the file system.") + ) + .arg( + Arg::new("enable-wal-backup") + .long("enable-wal-backup") .takes_value(true) .default_value("true") .default_missing_value("true") - .help("Enable/disable s3 offloading. When disabled, safekeeper removes WAL ignoring s3 WAL horizon."), + .help("Enable/disable WAL backup to s3. When disabled, safekeeper removes WAL ignoring WAL backup horizon."), ) .get_matches(); @@ -157,10 +164,6 @@ fn main() -> anyhow::Result<()> { conf.listen_http_addr = addr.to_owned(); } - if let Some(ttl) = arg_matches.value_of("ttl") { - conf.ttl = Some(humantime::parse_duration(ttl)?); - } - if let Some(recall) = arg_matches.value_of("recall") { conf.recall_period = humantime::parse_duration(recall)?; } @@ -182,9 +185,21 @@ fn main() -> anyhow::Result<()> { conf.broker_etcd_prefix = prefix.to_string(); } + if let Some(backup_threads) = arg_matches.value_of("wal-backup-threads") { + conf.backup_runtime_threads = backup_threads + .parse() + .with_context(|| format!("Failed to parse backup threads {}", backup_threads))?; + } + if let Some(storage_conf) = arg_matches.value_of("remote-storage") { + // funny toml doesn't consider plain inline table as valid document, so wrap in a key to parse + let storage_conf_toml = format!("remote_storage = {}", storage_conf); + let parsed_toml = storage_conf_toml.parse::()?; // parse + let (_, storage_conf_parsed_toml) = parsed_toml.iter().next().unwrap(); // and strip key off again + conf.remote_storage = Some(RemoteStorageConfig::from_toml(storage_conf_parsed_toml)?); + } // Seems like there is no better way to accept bool values explicitly in clap. - conf.s3_offload_enabled = arg_matches - .value_of("enable-s3-offload") + conf.wal_backup_enabled = arg_matches + .value_of("enable-wal-backup") .unwrap() .parse() .context("failed to parse bool enable-s3-offload bool")?; @@ -252,7 +267,8 @@ fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: bo let signals = signals::install_shutdown_handlers()?; let mut threads = vec![]; let (callmemaybe_tx, callmemaybe_rx) = mpsc::unbounded_channel(); - GlobalTimelines::set_callmemaybe_tx(callmemaybe_tx); + let (wal_backup_launcher_tx, wal_backup_launcher_rx) = mpsc::channel(100); + GlobalTimelines::init(callmemaybe_tx, wal_backup_launcher_tx); let conf_ = conf.clone(); threads.push( @@ -270,17 +286,6 @@ fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: bo })?, ); - if conf.ttl.is_some() { - let conf_ = conf.clone(); - threads.push( - thread::Builder::new() - .name("S3 offload thread".into()) - .spawn(|| { - s3_offload::thread_main(conf_); - })?, - ); - } - let conf_cloned = conf.clone(); let safekeeper_thread = thread::Builder::new() .name("Safekeeper thread".into()) @@ -330,6 +335,15 @@ fn start_safekeeper(mut conf: SafeKeeperConf, given_id: Option, init: bo })?, ); + let conf_ = conf.clone(); + threads.push( + thread::Builder::new() + .name("wal backup launcher thread".into()) + .spawn(move || { + wal_backup::wal_backup_launcher_thread_main(conf_, wal_backup_launcher_rx); + })?, + ); + // TODO: put more thoughts into handling of failed threads // We probably should restart them. diff --git a/safekeeper/src/broker.rs b/safekeeper/src/broker.rs index 59d282d378..676719b60d 100644 --- a/safekeeper/src/broker.rs +++ b/safekeeper/src/broker.rs @@ -1,5 +1,6 @@ //! Communication with etcd, providing safekeeper peers and pageserver coordination. +use anyhow::anyhow; use anyhow::Context; use anyhow::Error; use anyhow::Result; @@ -7,9 +8,11 @@ use etcd_broker::Client; use etcd_broker::PutOptions; use etcd_broker::SkTimelineSubscriptionKind; use std::time::Duration; +use tokio::spawn; use tokio::task::JoinHandle; use tokio::{runtime, time::sleep}; use tracing::*; +use url::Url; use crate::{timeline::GlobalTimelines, SafeKeeperConf}; use utils::zid::{NodeId, ZTenantTimelineId}; @@ -44,6 +47,118 @@ fn timeline_safekeeper_path( ) } +pub struct Election { + pub election_name: String, + pub candidate_name: String, + pub broker_endpoints: Vec, +} + +impl Election { + pub fn new(election_name: String, candidate_name: String, broker_endpoints: Vec) -> Self { + Self { + election_name, + candidate_name, + broker_endpoints, + } + } +} + +pub struct ElectionLeader { + client: Client, + keep_alive: JoinHandle>, +} + +impl ElectionLeader { + pub async fn check_am_i( + &mut self, + election_name: String, + candidate_name: String, + ) -> Result { + let resp = self.client.leader(election_name).await?; + + let kv = resp.kv().ok_or(anyhow!("failed to get leader response"))?; + let leader = kv.value_str()?; + + Ok(leader == candidate_name) + } + + pub async fn give_up(self) { + // self.keep_alive.abort(); + // TODO: it'll be wise to resign here but it'll happen after lease expiration anyway + // should we await for keep alive termination? + let _ = self.keep_alive.await; + } +} + +pub async fn get_leader(req: &Election) -> Result { + let mut client = Client::connect(req.broker_endpoints.clone(), None) + .await + .context("Could not connect to etcd")?; + + let lease = client + .lease_grant(LEASE_TTL_SEC, None) + .await + .context("Could not acquire a lease"); + + let lease_id = lease.map(|l| l.id()).unwrap(); + + let keep_alive = spawn::<_>(lease_keep_alive(client.clone(), lease_id)); + + if let Err(e) = client + .campaign( + req.election_name.clone(), + req.candidate_name.clone(), + lease_id, + ) + .await + { + keep_alive.abort(); + let _ = keep_alive.await; + return Err(e.into()); + } + + Ok(ElectionLeader { client, keep_alive }) +} + +async fn lease_keep_alive(mut client: Client, lease_id: i64) -> Result<()> { + let (mut keeper, mut ka_stream) = client + .lease_keep_alive(lease_id) + .await + .context("failed to create keepalive stream")?; + + loop { + let push_interval = Duration::from_millis(PUSH_INTERVAL_MSEC); + + keeper + .keep_alive() + .await + .context("failed to send LeaseKeepAliveRequest")?; + + ka_stream + .message() + .await + .context("failed to receive LeaseKeepAliveResponse")?; + + sleep(push_interval).await; + } +} + +pub fn get_campaign_name( + election_name: String, + broker_prefix: String, + timeline_id: &ZTenantTimelineId, +) -> String { + return format!( + "{}/{}", + SkTimelineSubscriptionKind::timeline(broker_prefix, *timeline_id).watch_key(), + election_name + ); +} + +pub fn get_candiate_name(system_id: NodeId) -> String { + format!("id_{}", system_id) +} + /// Push once in a while data about all active timelines to the broker. async fn push_loop(conf: SafeKeeperConf) -> anyhow::Result<()> { let mut client = Client::connect(&conf.broker_endpoints, None).await?; @@ -59,7 +174,7 @@ async fn push_loop(conf: SafeKeeperConf) -> anyhow::Result<()> { // sensitive and there is no risk of deadlock as we don't await while // lock is held. for zttid in GlobalTimelines::get_active_timelines() { - if let Ok(tli) = GlobalTimelines::get(&conf, zttid, false) { + if let Some(tli) = GlobalTimelines::get_loaded(zttid) { let sk_info = tli.get_public_info(&conf)?; let put_opts = PutOptions::new().with_lease(lease.id()); client @@ -106,12 +221,13 @@ async fn pull_loop(conf: SafeKeeperConf) -> Result<()> { // note: there are blocking operations below, but it's considered fine for now if let Ok(tli) = GlobalTimelines::get(&conf, zttid, false) { for (safekeeper_id, info) in sk_info { - tli.record_safekeeper_info(&info, safekeeper_id)? + tli.record_safekeeper_info(&info, safekeeper_id).await? } } } } None => { + // XXX it means we lost connection with etcd, error is consumed inside sub object debug!("timeline updates sender closed, aborting the pull loop"); return Ok(()); } @@ -142,11 +258,12 @@ async fn main_loop(conf: SafeKeeperConf) { }, res = async { pull_handle.as_mut().unwrap().await }, if pull_handle.is_some() => { // was it panic or normal error? - let err = match res { - Ok(res_internal) => res_internal.unwrap_err(), - Err(err_outer) => err_outer.into(), + match res { + Ok(res_internal) => if let Err(err_inner) = res_internal { + warn!("pull task failed: {:?}", err_inner); + } + Err(err_outer) => { warn!("pull task panicked: {:?}", err_outer) } }; - warn!("pull task failed: {:?}", err); pull_handle = None; }, _ = ticker.tick() => { diff --git a/safekeeper/src/control_file_upgrade.rs b/safekeeper/src/control_file_upgrade.rs index 22716de1a0..8d36472540 100644 --- a/safekeeper/src/control_file_upgrade.rs +++ b/safekeeper/src/control_file_upgrade.rs @@ -165,7 +165,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result timeline_start_lsn: Lsn(0), local_start_lsn: Lsn(0), commit_lsn: oldstate.commit_lsn, - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn(0), peer_horizon_lsn: oldstate.truncate_lsn, remote_consistent_lsn: Lsn(0), peers: Peers(vec![]), @@ -188,7 +188,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result timeline_start_lsn: Lsn(0), local_start_lsn: Lsn(0), commit_lsn: oldstate.commit_lsn, - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn(0), peer_horizon_lsn: oldstate.truncate_lsn, remote_consistent_lsn: Lsn(0), peers: Peers(vec![]), @@ -211,7 +211,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result timeline_start_lsn: Lsn(0), local_start_lsn: Lsn(0), commit_lsn: oldstate.commit_lsn, - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn(0), peer_horizon_lsn: oldstate.truncate_lsn, remote_consistent_lsn: Lsn(0), peers: Peers(vec![]), @@ -234,7 +234,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result timeline_start_lsn: Lsn(0), local_start_lsn: Lsn(0), commit_lsn: oldstate.commit_lsn, - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn::INVALID, peer_horizon_lsn: oldstate.peer_horizon_lsn, remote_consistent_lsn: Lsn(0), peers: Peers(vec![]), diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index 3f6ade970d..b0197a9a2a 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -70,19 +70,19 @@ struct TimelineStatus { timeline_id: ZTimelineId, acceptor_state: AcceptorStateStatus, #[serde(serialize_with = "display_serialize")] + flush_lsn: Lsn, + #[serde(serialize_with = "display_serialize")] timeline_start_lsn: Lsn, #[serde(serialize_with = "display_serialize")] local_start_lsn: Lsn, #[serde(serialize_with = "display_serialize")] commit_lsn: Lsn, #[serde(serialize_with = "display_serialize")] - s3_wal_lsn: Lsn, + backup_lsn: Lsn, #[serde(serialize_with = "display_serialize")] peer_horizon_lsn: Lsn, #[serde(serialize_with = "display_serialize")] remote_consistent_lsn: Lsn, - #[serde(serialize_with = "display_serialize")] - flush_lsn: Lsn, } /// Report info about timeline. @@ -107,13 +107,13 @@ async fn timeline_status_handler(request: Request) -> Result) -> Result, pub recall_period: Duration, + pub remote_storage: Option, + pub backup_runtime_threads: usize, + pub wal_backup_enabled: bool, pub my_id: NodeId, pub broker_endpoints: Vec, pub broker_etcd_prefix: String, - pub s3_offload_enabled: bool, } impl SafeKeeperConf { @@ -77,12 +81,13 @@ impl Default for SafeKeeperConf { no_sync: false, listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(), listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(), - ttl: None, + remote_storage: None, recall_period: defaults::DEFAULT_RECALL_PERIOD, my_id: NodeId(0), broker_endpoints: Vec::new(), broker_etcd_prefix: etcd_broker::DEFAULT_NEON_BROKER_ETCD_PREFIX.to_string(), - s3_offload_enabled: true, + backup_runtime_threads: DEFAULT_WAL_BACKUP_RUNTIME_THREADS, + wal_backup_enabled: true, } } } diff --git a/safekeeper/src/receive_wal.rs b/safekeeper/src/receive_wal.rs index 0ef335c9ed..88b7816912 100644 --- a/safekeeper/src/receive_wal.rs +++ b/safekeeper/src/receive_wal.rs @@ -85,16 +85,10 @@ impl<'pg> ReceiveWalConn<'pg> { _ => bail!("unexpected message {:?} instead of greeting", next_msg), } - // Register the connection and defer unregister. - spg.timeline - .get() - .on_compute_connect(self.pageserver_connstr.as_ref())?; - let _guard = ComputeConnectionGuard { - timeline: Arc::clone(spg.timeline.get()), - }; - let mut next_msg = Some(next_msg); + let mut first_time_through = true; + let mut _guard: Option = None; loop { if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) { // poll AppendRequest's without blocking and write WAL to disk without flushing, @@ -122,6 +116,18 @@ impl<'pg> ReceiveWalConn<'pg> { self.write_msg(&reply)?; } } + if first_time_through { + // Register the connection and defer unregister. Do that only + // after processing first message, as it sets wal_seg_size, + // wanted by many. + spg.timeline + .get() + .on_compute_connect(self.pageserver_connstr.as_ref())?; + _guard = Some(ComputeConnectionGuard { + timeline: Arc::clone(spg.timeline.get()), + }); + first_time_through = false; + } // blocking wait for the next message if next_msg.is_none() { diff --git a/safekeeper/src/remove_wal.rs b/safekeeper/src/remove_wal.rs index 3278d51bd3..004c0243f9 100644 --- a/safekeeper/src/remove_wal.rs +++ b/safekeeper/src/remove_wal.rs @@ -12,7 +12,7 @@ pub fn thread_main(conf: SafeKeeperConf) { let active_tlis = GlobalTimelines::get_active_timelines(); for zttid in &active_tlis { if let Ok(tli) = GlobalTimelines::get(&conf, *zttid, false) { - if let Err(e) = tli.remove_old_wal(conf.s3_offload_enabled) { + if let Err(e) = tli.remove_old_wal(conf.wal_backup_enabled) { warn!( "failed to remove WAL for tenant {} timeline {}: {}", tli.zttid.tenant_id, tli.zttid.timeline_id, e diff --git a/safekeeper/src/s3_offload.rs b/safekeeper/src/s3_offload.rs deleted file mode 100644 index 2851c0b8a0..0000000000 --- a/safekeeper/src/s3_offload.rs +++ /dev/null @@ -1,107 +0,0 @@ -// -// Offload old WAL segments to S3 and remove them locally -// Needs `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables to be set -// if no IAM bucket access is used. -// - -use anyhow::{bail, Context}; -use postgres_ffi::xlog_utils::*; -use remote_storage::{ - GenericRemoteStorage, RemoteStorage, RemoteStorageConfig, S3Bucket, S3Config, S3ObjectKey, -}; -use std::collections::HashSet; -use std::env; -use std::num::{NonZeroU32, NonZeroUsize}; -use std::path::Path; -use std::time::SystemTime; -use tokio::fs::{self, File}; -use tokio::io::BufReader; -use tokio::runtime; -use tokio::time::sleep; -use tracing::*; -use walkdir::WalkDir; - -use crate::SafeKeeperConf; - -pub fn thread_main(conf: SafeKeeperConf) { - // Create a new thread pool - // - // FIXME: keep it single-threaded for now, make it easier to debug with gdb, - // and we're not concerned with performance yet. - //let runtime = runtime::Runtime::new().unwrap(); - let runtime = runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - info!("Starting S3 offload task"); - - runtime.block_on(async { - main_loop(&conf).await.unwrap(); - }); -} - -async fn offload_files( - remote_storage: &S3Bucket, - listing: &HashSet, - dir_path: &Path, - conf: &SafeKeeperConf, -) -> anyhow::Result { - let horizon = SystemTime::now() - conf.ttl.unwrap(); - let mut n: u64 = 0; - for entry in WalkDir::new(dir_path) { - let entry = entry?; - let path = entry.path(); - - if path.is_file() - && IsXLogFileName(entry.file_name().to_str().unwrap()) - && entry.metadata().unwrap().created().unwrap() <= horizon - { - let remote_path = remote_storage.remote_object_id(path)?; - if !listing.contains(&remote_path) { - let file = File::open(&path).await?; - let file_length = file.metadata().await?.len() as usize; - remote_storage - .upload(BufReader::new(file), file_length, &remote_path, None) - .await?; - - fs::remove_file(&path).await?; - n += 1; - } - } - } - Ok(n) -} - -async fn main_loop(conf: &SafeKeeperConf) -> anyhow::Result<()> { - let remote_storage = match GenericRemoteStorage::new( - conf.workdir.clone(), - &RemoteStorageConfig { - max_concurrent_syncs: NonZeroUsize::new(10).unwrap(), - max_sync_errors: NonZeroU32::new(1).unwrap(), - storage: remote_storage::RemoteStorageKind::AwsS3(S3Config { - bucket_name: "zenith-testbucket".to_string(), - bucket_region: env::var("S3_REGION").context("S3_REGION env var is not set")?, - prefix_in_bucket: Some("walarchive/".to_string()), - endpoint: Some(env::var("S3_ENDPOINT").context("S3_ENDPOINT env var is not set")?), - concurrency_limit: NonZeroUsize::new(20).unwrap(), - }), - }, - )? { - GenericRemoteStorage::Local(_) => { - bail!("Unexpected: got local storage for the remote config") - } - GenericRemoteStorage::S3(remote_storage) => remote_storage, - }; - - loop { - let listing = remote_storage - .list() - .await? - .into_iter() - .collect::>(); - let n = offload_files(&remote_storage, &listing, &conf.workdir, conf).await?; - info!("Offload {n} files to S3"); - sleep(conf.ttl.unwrap()).await; - } -} diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index b8b969929d..9a07127771 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -19,6 +19,7 @@ use lazy_static::lazy_static; use crate::control_file; use crate::send_wal::HotStandbyFeedback; + use crate::wal_storage; use metrics::{register_gauge_vec, Gauge, GaugeVec}; use postgres_ffi::xlog_utils::MAX_SEND_SIZE; @@ -141,7 +142,7 @@ pub struct ServerInfo { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PeerInfo { /// LSN up to which safekeeper offloaded WAL to s3. - s3_wal_lsn: Lsn, + backup_lsn: Lsn, /// Term of the last entry. term: Term, /// LSN of the last record. @@ -153,7 +154,7 @@ pub struct PeerInfo { impl PeerInfo { fn new() -> Self { Self { - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn::INVALID, term: INVALID_TERM, flush_lsn: Lsn(0), commit_lsn: Lsn(0), @@ -193,9 +194,9 @@ pub struct SafeKeeperState { /// Part of WAL acknowledged by quorum and available locally. Always points /// to record boundary. pub commit_lsn: Lsn, - /// First LSN not yet offloaded to s3. Useful to persist to avoid finding - /// out offloading progress on boot. - pub s3_wal_lsn: Lsn, + /// LSN that points to the end of the last backed up segment. Useful to + /// persist to avoid finding out offloading progress on boot. + pub backup_lsn: Lsn, /// Minimal LSN which may be needed for recovery of some safekeeper (end_lsn /// of last record streamed to everyone). Persisting it helps skipping /// recovery in walproposer, generally we compute it from peers. In @@ -217,7 +218,7 @@ pub struct SafeKeeperState { // are not flushed yet. pub struct SafekeeperMemState { pub commit_lsn: Lsn, - pub s3_wal_lsn: Lsn, // TODO: keep only persistent version + pub backup_lsn: Lsn, pub peer_horizon_lsn: Lsn, pub remote_consistent_lsn: Lsn, pub proposer_uuid: PgUuid, @@ -241,7 +242,7 @@ impl SafeKeeperState { timeline_start_lsn: Lsn(0), local_start_lsn: Lsn(0), commit_lsn: Lsn(0), - s3_wal_lsn: Lsn(0), + backup_lsn: Lsn::INVALID, peer_horizon_lsn: Lsn(0), remote_consistent_lsn: Lsn(0), peers: Peers(peers.iter().map(|p| (*p, PeerInfo::new())).collect()), @@ -559,7 +560,7 @@ where epoch_start_lsn: Lsn(0), inmem: SafekeeperMemState { commit_lsn: state.commit_lsn, - s3_wal_lsn: state.s3_wal_lsn, + backup_lsn: state.backup_lsn, peer_horizon_lsn: state.peer_horizon_lsn, remote_consistent_lsn: state.remote_consistent_lsn, proposer_uuid: state.proposer_uuid, @@ -649,7 +650,6 @@ where self.state.persist(&state)?; } - // pass wal_seg_size to read WAL and find flush_lsn self.wal_store.init_storage(&self.state)?; info!( @@ -764,6 +764,14 @@ where self.inmem.commit_lsn = commit_lsn; self.metrics.commit_lsn.set(self.inmem.commit_lsn.0 as f64); + // We got our first commit_lsn, which means we should sync + // everything to disk, to initialize the state. + if self.state.commit_lsn == Lsn::INVALID && commit_lsn != Lsn::INVALID { + self.inmem.backup_lsn = self.inmem.commit_lsn; // initialize backup_lsn + self.wal_store.flush_wal()?; + self.persist_control_file()?; + } + // If new commit_lsn reached epoch switch, force sync of control // file: walproposer in sync mode is very interested when this // happens. Note: this is for sync-safekeepers mode only, as @@ -775,22 +783,14 @@ where self.persist_control_file()?; } - // We got our first commit_lsn, which means we should sync - // everything to disk, to initialize the state. - if self.state.commit_lsn == Lsn(0) && commit_lsn > Lsn(0) { - self.wal_store.flush_wal()?; - self.persist_control_file()?; - } - Ok(()) } /// Persist in-memory state to the disk. fn persist_control_file(&mut self) -> Result<()> { let mut state = self.state.clone(); - state.commit_lsn = self.inmem.commit_lsn; - state.s3_wal_lsn = self.inmem.s3_wal_lsn; + state.backup_lsn = self.inmem.backup_lsn; state.peer_horizon_lsn = self.inmem.peer_horizon_lsn; state.remote_consistent_lsn = self.inmem.remote_consistent_lsn; state.proposer_uuid = self.inmem.proposer_uuid; @@ -898,11 +898,11 @@ where self.update_commit_lsn()?; } } - if let Some(s3_wal_lsn) = sk_info.s3_wal_lsn { - let new_s3_wal_lsn = max(s3_wal_lsn, self.inmem.s3_wal_lsn); + if let Some(backup_lsn) = sk_info.backup_lsn { + let new_backup_lsn = max(backup_lsn, self.inmem.backup_lsn); sync_control_file |= - self.state.s3_wal_lsn + (self.state.server.wal_seg_size as u64) < new_s3_wal_lsn; - self.inmem.s3_wal_lsn = new_s3_wal_lsn; + self.state.backup_lsn + (self.state.server.wal_seg_size as u64) < new_backup_lsn; + self.inmem.backup_lsn = new_backup_lsn; } if let Some(remote_consistent_lsn) = sk_info.remote_consistent_lsn { let new_remote_consistent_lsn = @@ -930,29 +930,23 @@ where /// offloading. /// While it is safe to use inmem values for determining horizon, /// we use persistent to make possible normal states less surprising. - pub fn get_horizon_segno(&self, s3_offload_enabled: bool) -> XLogSegNo { - let s3_offload_horizon = if s3_offload_enabled { - self.state.s3_wal_lsn - } else { - Lsn(u64::MAX) - }; - let horizon_lsn = min( - min( - self.state.remote_consistent_lsn, - self.state.peer_horizon_lsn, - ), - s3_offload_horizon, + pub fn get_horizon_segno(&self, wal_backup_enabled: bool) -> XLogSegNo { + let mut horizon_lsn = min( + self.state.remote_consistent_lsn, + self.state.peer_horizon_lsn, ); + if wal_backup_enabled { + horizon_lsn = min(horizon_lsn, self.state.backup_lsn); + } horizon_lsn.segment_number(self.state.server.wal_seg_size as usize) } } #[cfg(test)] mod tests { - use std::ops::Deref; - use super::*; use crate::wal_storage::Storage; + use std::ops::Deref; // fake storage for tests struct InMemoryState { @@ -1013,6 +1007,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); + let mut sk = SafeKeeper::new(ztli, storage, wal_store, NodeId(0)).unwrap(); // check voting for 1 is ok @@ -1028,6 +1023,7 @@ mod tests { let storage = InMemoryState { persisted_state: state, }; + sk = SafeKeeper::new(ztli, storage, sk.wal_store, NodeId(0)).unwrap(); // and ensure voting second time for 1 is not ok @@ -1045,6 +1041,7 @@ mod tests { }; let wal_store = DummyWalStore { lsn: Lsn(0) }; let ztli = ZTimelineId::from([0u8; 16]); + let mut sk = SafeKeeper::new(ztli, storage, wal_store, NodeId(0)).unwrap(); let mut ar_hdr = AppendRequestHeader { diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index d52dd6ea57..a89ed18071 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -315,7 +315,7 @@ impl ReplicationConn { } else { // TODO: also check once in a while whether we are walsender // to right pageserver. - if spg.timeline.get().check_deactivate(replica_id)? { + if spg.timeline.get().stop_walsender(replica_id)? { // Shut down, timeline is suspended. // TODO create proper error type for this bail!("end streaming to {:?}", spg.appname); diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 0953439bd8..74a61410fd 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -8,6 +8,7 @@ use lazy_static::lazy_static; use postgres_ffi::xlog_utils::XLogSegNo; use serde::Serialize; +use tokio::sync::watch; use std::cmp::{max, min}; use std::collections::HashMap; @@ -15,7 +16,7 @@ use std::fs::{self}; use std::sync::{Arc, Condvar, Mutex, MutexGuard}; use std::time::Duration; -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::{Sender, UnboundedSender}; use tracing::*; use utils::{ @@ -25,13 +26,13 @@ use utils::{ }; use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey}; - use crate::control_file; use crate::safekeeper::{ AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, SafekeeperMemState, }; use crate::send_wal::HotStandbyFeedback; + use crate::wal_storage; use crate::wal_storage::Storage as wal_storage_iface; use crate::SafeKeeperConf; @@ -81,10 +82,14 @@ struct SharedState { notified_commit_lsn: Lsn, /// State of replicas replicas: Vec>, - /// Inactive clusters shouldn't occupy any resources, so timeline is - /// activated whenever there is a compute connection or pageserver is not - /// caughtup (it must have latest WAL for new compute start) and suspended - /// otherwise. + /// True when WAL backup launcher oversees the timeline, making sure WAL is + /// offloaded, allows to bother launcher less. + wal_backup_active: bool, + /// True whenever there is at least some pending activity on timeline: live + /// compute connection, pageserver is not caughtup (it must have latest WAL + /// for new compute start) or WAL backuping is not finished. Practically it + /// means safekeepers broadcast info to peers about the timeline, old WAL is + /// trimmed. /// /// TODO: it might be better to remove tli completely from GlobalTimelines /// when tli is inactive instead of having this flag. @@ -103,6 +108,7 @@ impl SharedState { ) -> Result { let state = SafeKeeperState::new(zttid, peer_ids); let control_store = control_file::FileStorage::create_new(zttid, conf, state)?; + let wal_store = wal_storage::PhysicalStorage::new(zttid, conf); let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?; @@ -110,6 +116,7 @@ impl SharedState { notified_commit_lsn: Lsn(0), sk, replicas: Vec::new(), + wal_backup_active: false, active: false, num_computes: 0, pageserver_connstr: None, @@ -129,15 +136,62 @@ impl SharedState { notified_commit_lsn: Lsn(0), sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?, replicas: Vec::new(), + wal_backup_active: false, active: false, num_computes: 0, pageserver_connstr: None, last_removed_segno: 0, }) } + fn is_active(&self) -> bool { + self.is_wal_backup_required() + // FIXME: add tracking of relevant pageservers and check them here individually, + // otherwise migration won't work (we suspend too early). + || self.sk.inmem.remote_consistent_lsn <= self.sk.inmem.commit_lsn + } - /// Activate the timeline: start/change walsender (via callmemaybe). - fn activate( + /// Mark timeline active/inactive and return whether s3 offloading requires + /// start/stop action. + fn update_status(&mut self) -> bool { + self.active = self.is_active(); + self.is_wal_backup_action_pending() + } + + /// Should we run s3 offloading in current state? + fn is_wal_backup_required(&self) -> bool { + let seg_size = self.get_wal_seg_size(); + self.num_computes > 0 || + // Currently only the whole segment is offloaded, so compare segment numbers. + (self.sk.inmem.commit_lsn.segment_number(seg_size) > + self.sk.inmem.backup_lsn.segment_number(seg_size)) + } + + /// Is current state of s3 offloading is not what it ought to be? + fn is_wal_backup_action_pending(&self) -> bool { + let res = self.wal_backup_active != self.is_wal_backup_required(); + if res { + let action_pending = if self.is_wal_backup_required() { + "start" + } else { + "stop" + }; + trace!( + "timeline {} s3 offloading action {} pending: num_computes={}, commit_lsn={}, backup_lsn={}", + self.sk.state.timeline_id, action_pending, self.num_computes, self.sk.inmem.commit_lsn, self.sk.inmem.backup_lsn + ); + } + res + } + + /// Returns whether s3 offloading is required and sets current status as + /// matching. + fn wal_backup_attend(&mut self) -> bool { + self.wal_backup_active = self.is_wal_backup_required(); + self.wal_backup_active + } + + /// start/change walsender (via callmemaybe). + fn callmemaybe_sub( &mut self, zttid: &ZTenantTimelineId, pageserver_connstr: Option<&String>, @@ -179,42 +233,42 @@ impl SharedState { ); } self.pageserver_connstr = pageserver_connstr.map(|c| c.to_owned()); - self.active = true; Ok(()) } /// Deactivate the timeline: stop callmemaybe. - fn deactivate( + fn callmemaybe_unsub( &mut self, zttid: &ZTenantTimelineId, callmemaybe_tx: &UnboundedSender, ) -> Result<()> { - if self.active { - if let Some(ref pageserver_connstr) = self.pageserver_connstr { - let subscription_key = SubscriptionStateKey::new( - zttid.tenant_id, - zttid.timeline_id, - pageserver_connstr.to_owned(), - ); - callmemaybe_tx - .send(CallmeEvent::Unsubscribe(subscription_key)) - .unwrap_or_else(|e| { - error!( - "failed to send Unsubscribe request to callmemaybe thread {}", - e - ); - }); - info!( - "timeline {} is unsubscribed from callmemaybe to {}", - zttid.timeline_id, - self.pageserver_connstr.as_ref().unwrap() - ); - } - self.active = false; + if let Some(ref pageserver_connstr) = self.pageserver_connstr { + let subscription_key = SubscriptionStateKey::new( + zttid.tenant_id, + zttid.timeline_id, + pageserver_connstr.to_owned(), + ); + callmemaybe_tx + .send(CallmeEvent::Unsubscribe(subscription_key)) + .unwrap_or_else(|e| { + error!( + "failed to send Unsubscribe request to callmemaybe thread {}", + e + ); + }); + info!( + "timeline {} is unsubscribed from callmemaybe to {}", + zttid.timeline_id, + self.pageserver_connstr.as_ref().unwrap() + ); } Ok(()) } + fn get_wal_seg_size(&self) -> usize { + self.sk.state.server.wal_seg_size as usize + } + /// Get combined state of all alive replicas pub fn get_replicas_state(&self) -> ReplicaState { let mut acc = ReplicaState::new(); @@ -278,6 +332,13 @@ impl SharedState { pub struct Timeline { pub zttid: ZTenantTimelineId, pub callmemaybe_tx: UnboundedSender, + /// Sending here asks for wal backup launcher attention (start/stop + /// offloading). Sending zttid instead of concrete command allows to do + /// sending without timeline lock. + wal_backup_launcher_tx: Sender, + commit_lsn_watch_tx: watch::Sender, + /// For breeding receivers. + commit_lsn_watch_rx: watch::Receiver, mutex: Mutex, /// conditional variable used to notify wal senders cond: Condvar, @@ -287,11 +348,17 @@ impl Timeline { fn new( zttid: ZTenantTimelineId, callmemaybe_tx: UnboundedSender, + wal_backup_launcher_tx: Sender, shared_state: SharedState, ) -> Timeline { + let (commit_lsn_watch_tx, commit_lsn_watch_rx) = + watch::channel(shared_state.sk.inmem.commit_lsn); Timeline { zttid, callmemaybe_tx, + wal_backup_launcher_tx, + commit_lsn_watch_tx, + commit_lsn_watch_rx, mutex: Mutex::new(shared_state), cond: Condvar::new(), } @@ -301,13 +368,21 @@ impl Timeline { /// not running yet. /// Can fail only if channel to a static thread got closed, which is not normal at all. pub fn on_compute_connect(&self, pageserver_connstr: Option<&String>) -> Result<()> { - let mut shared_state = self.mutex.lock().unwrap(); - shared_state.num_computes += 1; - // FIXME: currently we always adopt latest pageserver connstr, but we - // should have kind of generations assigned by compute to distinguish - // the latest one or even pass it through consensus to reliably deliver - // to all safekeepers. - shared_state.activate(&self.zttid, pageserver_connstr, &self.callmemaybe_tx)?; + let is_wal_backup_action_pending: bool; + { + let mut shared_state = self.mutex.lock().unwrap(); + shared_state.num_computes += 1; + is_wal_backup_action_pending = shared_state.update_status(); + // FIXME: currently we always adopt latest pageserver connstr, but we + // should have kind of generations assigned by compute to distinguish + // the latest one or even pass it through consensus to reliably deliver + // to all safekeepers. + shared_state.callmemaybe_sub(&self.zttid, pageserver_connstr, &self.callmemaybe_tx)?; + } + // Wake up wal backup launcher, if offloading not started yet. + if is_wal_backup_action_pending { + self.wal_backup_launcher_tx.blocking_send(self.zttid)?; + } Ok(()) } @@ -315,38 +390,43 @@ impl Timeline { /// pageserver doesn't need catchup. /// Can fail only if channel to a static thread got closed, which is not normal at all. pub fn on_compute_disconnect(&self) -> Result<()> { - let mut shared_state = self.mutex.lock().unwrap(); - shared_state.num_computes -= 1; - // If there is no pageserver, can suspend right away; otherwise let - // walsender do that. - if shared_state.num_computes == 0 && shared_state.pageserver_connstr.is_none() { - shared_state.deactivate(&self.zttid, &self.callmemaybe_tx)?; + let is_wal_backup_action_pending: bool; + { + let mut shared_state = self.mutex.lock().unwrap(); + shared_state.num_computes -= 1; + is_wal_backup_action_pending = shared_state.update_status(); + } + // Wake up wal backup launcher, if it is time to stop the offloading. + if is_wal_backup_action_pending { + self.wal_backup_launcher_tx.blocking_send(self.zttid)?; } Ok(()) } - /// Deactivate tenant if there is no computes and pageserver is caughtup, - /// assuming the pageserver status is in replica_id. - /// Returns true if deactivated. - pub fn check_deactivate(&self, replica_id: usize) -> Result { + /// Whether we still need this walsender running? + /// TODO: check this pageserver is actually interested in this timeline. + pub fn stop_walsender(&self, replica_id: usize) -> Result { let mut shared_state = self.mutex.lock().unwrap(); - if !shared_state.active { - // already suspended - return Ok(true); - } if shared_state.num_computes == 0 { let replica_state = shared_state.replicas[replica_id].unwrap(); - let deactivate = shared_state.notified_commit_lsn == Lsn(0) || // no data at all yet - (replica_state.last_received_lsn != Lsn::MAX && // Lsn::MAX means that we don't know the latest LSN yet. - replica_state.last_received_lsn >= shared_state.sk.inmem.commit_lsn); - if deactivate { - shared_state.deactivate(&self.zttid, &self.callmemaybe_tx)?; + let stop = shared_state.notified_commit_lsn == Lsn(0) || // no data at all yet + (replica_state.remote_consistent_lsn != Lsn::MAX && // Lsn::MAX means that we don't know the latest LSN yet. + replica_state.remote_consistent_lsn >= shared_state.sk.inmem.commit_lsn); + if stop { + shared_state.callmemaybe_unsub(&self.zttid, &self.callmemaybe_tx)?; return Ok(true); } } Ok(false) } + /// Returns whether s3 offloading is required and sets current status as + /// matching it. + pub fn wal_backup_attend(&self) -> bool { + let mut shared_state = self.mutex.lock().unwrap(); + shared_state.wal_backup_attend() + } + /// Deactivates the timeline, assuming it is being deleted. /// Returns whether the timeline was already active. /// @@ -354,10 +434,14 @@ impl Timeline { /// will stop by themselves eventually (possibly with errors, but no panics). There should be no /// compute threads (as we're deleting the timeline), actually. Some WAL may be left unsent, but /// we're deleting the timeline anyway. - pub fn deactivate_for_delete(&self) -> Result { - let mut shared_state = self.mutex.lock().unwrap(); - let was_active = shared_state.active; - shared_state.deactivate(&self.zttid, &self.callmemaybe_tx)?; + pub async fn deactivate_for_delete(&self) -> Result { + let was_active: bool; + { + let mut shared_state = self.mutex.lock().unwrap(); + was_active = shared_state.active; + shared_state.callmemaybe_unsub(&self.zttid, &self.callmemaybe_tx)?; + } + self.wal_backup_launcher_tx.send(self.zttid).await?; Ok(was_active) } @@ -391,6 +475,7 @@ impl Timeline { } // Notify caught-up WAL senders about new WAL data received + // TODO: replace-unify it with commit_lsn_watch. fn notify_wal_senders(&self, shared_state: &mut MutexGuard) { if shared_state.notified_commit_lsn < shared_state.sk.inmem.commit_lsn { shared_state.notified_commit_lsn = shared_state.sk.inmem.commit_lsn; @@ -398,12 +483,17 @@ impl Timeline { } } + pub fn get_commit_lsn_watch_rx(&self) -> watch::Receiver { + self.commit_lsn_watch_rx.clone() + } + /// Pass arrived message to the safekeeper. pub fn process_msg( &self, msg: &ProposerAcceptorMessage, ) -> Result> { let mut rmsg: Option; + let commit_lsn: Lsn; { let mut shared_state = self.mutex.lock().unwrap(); rmsg = shared_state.sk.process_msg(msg)?; @@ -419,15 +509,31 @@ impl Timeline { // Ping wal sender that new data might be available. self.notify_wal_senders(&mut shared_state); + commit_lsn = shared_state.sk.inmem.commit_lsn; } + self.commit_lsn_watch_tx.send(commit_lsn)?; Ok(rmsg) } + pub fn get_wal_seg_size(&self) -> usize { + self.mutex.lock().unwrap().get_wal_seg_size() + } + pub fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) { let shared_state = self.mutex.lock().unwrap(); (shared_state.sk.inmem.clone(), shared_state.sk.state.clone()) } + pub fn get_wal_backup_lsn(&self) -> Lsn { + self.mutex.lock().unwrap().sk.inmem.backup_lsn + } + + pub fn set_wal_backup_lsn(&self, backup_lsn: Lsn) { + self.mutex.lock().unwrap().sk.inmem.backup_lsn = backup_lsn; + // we should check whether to shut down offloader, but this will be done + // soon by peer communication anyway. + } + /// Prepare public safekeeper info for reporting. pub fn get_public_info(&self, conf: &SafeKeeperConf) -> anyhow::Result { let shared_state = self.mutex.lock().unwrap(); @@ -436,7 +542,6 @@ impl Timeline { flush_lsn: Some(shared_state.sk.wal_store.flush_lsn()), // note: this value is not flushed to control file yet and can be lost commit_lsn: Some(shared_state.sk.inmem.commit_lsn), - s3_wal_lsn: Some(shared_state.sk.inmem.s3_wal_lsn), // TODO: rework feedbacks to avoid max here remote_consistent_lsn: Some(max( shared_state.get_replicas_state().remote_consistent_lsn, @@ -444,14 +549,35 @@ impl Timeline { )), peer_horizon_lsn: Some(shared_state.sk.inmem.peer_horizon_lsn), safekeeper_connection_string: Some(conf.listen_pg_addr.clone()), + backup_lsn: Some(shared_state.sk.inmem.backup_lsn), }) } /// Update timeline state with peer safekeeper data. - pub fn record_safekeeper_info(&self, sk_info: &SkTimelineInfo, _sk_id: NodeId) -> Result<()> { - let mut shared_state = self.mutex.lock().unwrap(); - shared_state.sk.record_safekeeper_info(sk_info)?; - self.notify_wal_senders(&mut shared_state); + pub async fn record_safekeeper_info( + &self, + sk_info: &SkTimelineInfo, + _sk_id: NodeId, + ) -> Result<()> { + let is_wal_backup_action_pending: bool; + let commit_lsn: Lsn; + { + let mut shared_state = self.mutex.lock().unwrap(); + // WAL seg size not initialized yet (no message from compute ever + // received), can't do much without it. + if shared_state.get_wal_seg_size() == 0 { + return Ok(()); + } + shared_state.sk.record_safekeeper_info(sk_info)?; + self.notify_wal_senders(&mut shared_state); + is_wal_backup_action_pending = shared_state.update_status(); + commit_lsn = shared_state.sk.inmem.commit_lsn; + } + self.commit_lsn_watch_tx.send(commit_lsn)?; + // Wake up wal backup launcher, if it is time to stop the offloading. + if is_wal_backup_action_pending { + self.wal_backup_launcher_tx.send(self.zttid).await?; + } Ok(()) } @@ -476,16 +602,16 @@ impl Timeline { shared_state.sk.wal_store.flush_lsn() } - pub fn remove_old_wal(&self, s3_offload_enabled: bool) -> Result<()> { + pub fn remove_old_wal(&self, wal_backup_enabled: bool) -> Result<()> { let horizon_segno: XLogSegNo; let remover: Box Result<(), anyhow::Error>>; { let shared_state = self.mutex.lock().unwrap(); // WAL seg size not initialized yet, no WAL exists. - if shared_state.sk.state.server.wal_seg_size == 0 { + if shared_state.get_wal_seg_size() == 0 { return Ok(()); } - horizon_segno = shared_state.sk.get_horizon_segno(s3_offload_enabled); + horizon_segno = shared_state.sk.get_horizon_segno(wal_backup_enabled); remover = shared_state.sk.wal_store.remove_up_to(); if horizon_segno <= 1 || horizon_segno <= shared_state.last_removed_segno { return Ok(()); @@ -522,12 +648,14 @@ impl TimelineTools for Option> { struct GlobalTimelinesState { timelines: HashMap>, callmemaybe_tx: Option>, + wal_backup_launcher_tx: Option>, } lazy_static! { static ref TIMELINES_STATE: Mutex = Mutex::new(GlobalTimelinesState { timelines: HashMap::new(), - callmemaybe_tx: None + callmemaybe_tx: None, + wal_backup_launcher_tx: None, }); } @@ -541,10 +669,15 @@ pub struct TimelineDeleteForceResult { pub struct GlobalTimelines; impl GlobalTimelines { - pub fn set_callmemaybe_tx(callmemaybe_tx: UnboundedSender) { + pub fn init( + callmemaybe_tx: UnboundedSender, + wal_backup_launcher_tx: Sender, + ) { let mut state = TIMELINES_STATE.lock().unwrap(); assert!(state.callmemaybe_tx.is_none()); state.callmemaybe_tx = Some(callmemaybe_tx); + assert!(state.wal_backup_launcher_tx.is_none()); + state.wal_backup_launcher_tx = Some(wal_backup_launcher_tx); } fn create_internal( @@ -559,12 +692,14 @@ impl GlobalTimelines { // TODO: check directory existence let dir = conf.timeline_dir(&zttid); fs::create_dir_all(dir)?; + let shared_state = SharedState::create(conf, &zttid, peer_ids) .context("failed to create shared state")?; let new_tli = Arc::new(Timeline::new( zttid, state.callmemaybe_tx.as_ref().unwrap().clone(), + state.wal_backup_launcher_tx.as_ref().unwrap().clone(), shared_state, )); state.timelines.insert(zttid, Arc::clone(&new_tli)); @@ -594,8 +729,7 @@ impl GlobalTimelines { match state.timelines.get(&zttid) { Some(result) => Ok(Arc::clone(result)), None => { - let shared_state = - SharedState::restore(conf, &zttid).context("failed to restore shared state"); + let shared_state = SharedState::restore(conf, &zttid); let shared_state = match shared_state { Ok(shared_state) => shared_state, @@ -617,6 +751,7 @@ impl GlobalTimelines { let new_tli = Arc::new(Timeline::new( zttid, state.callmemaybe_tx.as_ref().unwrap().clone(), + state.wal_backup_launcher_tx.as_ref().unwrap().clone(), shared_state, )); state.timelines.insert(zttid, Arc::clone(&new_tli)); @@ -625,6 +760,12 @@ impl GlobalTimelines { } } + /// Get loaded timeline, if it exists. + pub fn get_loaded(zttid: ZTenantTimelineId) -> Option> { + let state = TIMELINES_STATE.lock().unwrap(); + state.timelines.get(&zttid).map(Arc::clone) + } + /// Get ZTenantTimelineIDs of all active timelines. pub fn get_active_timelines() -> Vec { let state = TIMELINES_STATE.lock().unwrap(); @@ -665,22 +806,23 @@ impl GlobalTimelines { /// b) an HTTP GET request about the timeline is made and it's able to restore the current state, or /// c) an HTTP POST request for timeline creation is made after the timeline is already deleted. /// TODO: ensure all of the above never happens. - pub fn delete_force( + pub async fn delete_force( conf: &SafeKeeperConf, zttid: &ZTenantTimelineId, ) -> Result { info!("deleting timeline {}", zttid); - let was_active = match TIMELINES_STATE.lock().unwrap().timelines.remove(zttid) { - None => false, - Some(tli) => tli.deactivate_for_delete()?, - }; + let timeline = TIMELINES_STATE.lock().unwrap().timelines.remove(zttid); + let mut was_active = false; + if let Some(tli) = timeline { + was_active = tli.deactivate_for_delete().await?; + } GlobalTimelines::delete_force_internal(conf, zttid, was_active) } /// Deactivates and deletes all timelines for the tenant, see `delete()`. /// Returns map of all timelines which the tenant had, `true` if a timeline was active. /// There may be a race if new timelines are created simultaneously. - pub fn delete_force_all_for_tenant( + pub async fn delete_force_all_for_tenant( conf: &SafeKeeperConf, tenant_id: &ZTenantId, ) -> Result> { @@ -691,14 +833,15 @@ impl GlobalTimelines { let timelines = &mut TIMELINES_STATE.lock().unwrap().timelines; for (&zttid, tli) in timelines.iter() { if zttid.tenant_id == *tenant_id { - to_delete.insert(zttid, tli.deactivate_for_delete()?); + to_delete.insert(zttid, tli.clone()); } } // TODO: test that the correct subset of timelines is removed. It's complicated because they are implicitly created currently. timelines.retain(|zttid, _| !to_delete.contains_key(zttid)); } let mut deleted = HashMap::new(); - for (zttid, was_active) in to_delete { + for (zttid, timeline) in to_delete { + let was_active = timeline.deactivate_for_delete().await?; deleted.insert( zttid, GlobalTimelines::delete_force_internal(conf, &zttid, was_active)?, diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs new file mode 100644 index 0000000000..ef8ebe14e1 --- /dev/null +++ b/safekeeper/src/wal_backup.rs @@ -0,0 +1,418 @@ +use anyhow::{Context, Result}; +use tokio::task::JoinHandle; + +use std::cmp::min; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use postgres_ffi::xlog_utils::{XLogFileName, XLogSegNo, XLogSegNoOffsetToRecPtr, PG_TLI}; +use remote_storage::{GenericRemoteStorage, RemoteStorage}; +use tokio::fs::File; +use tokio::runtime::Builder; + +use tokio::select; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::sync::watch; +use tokio::time::sleep; +use tracing::*; + +use utils::{lsn::Lsn, zid::ZTenantTimelineId}; + +use crate::broker::{Election, ElectionLeader}; +use crate::timeline::{GlobalTimelines, Timeline}; +use crate::{broker, SafeKeeperConf}; + +use once_cell::sync::OnceCell; + +const BACKUP_ELECTION_NAME: &str = "WAL_BACKUP"; + +const BROKER_CONNECTION_RETRY_DELAY_MS: u64 = 1000; + +const UPLOAD_FAILURE_RETRY_MIN_MS: u64 = 10; +const UPLOAD_FAILURE_RETRY_MAX_MS: u64 = 5000; + +pub fn wal_backup_launcher_thread_main( + conf: SafeKeeperConf, + wal_backup_launcher_rx: Receiver, +) { + let rt = Builder::new_multi_thread() + .worker_threads(conf.backup_runtime_threads) + .enable_all() + .build() + .expect("failed to create wal backup runtime"); + + rt.block_on(async { + wal_backup_launcher_main_loop(conf, wal_backup_launcher_rx).await; + }); +} + +/// Check whether wal backup is required for timeline and mark that launcher is +/// aware of current status (if timeline exists). +fn is_wal_backup_required(zttid: ZTenantTimelineId) -> bool { + if let Some(tli) = GlobalTimelines::get_loaded(zttid) { + tli.wal_backup_attend() + } else { + false + } +} + +struct WalBackupTaskHandle { + shutdown_tx: Sender<()>, + handle: JoinHandle<()>, +} + +/// Sits on wal_backup_launcher_rx and starts/stops per timeline wal backup +/// tasks. Having this in separate task simplifies locking, allows to reap +/// panics and separate elections from offloading itself. +async fn wal_backup_launcher_main_loop( + conf: SafeKeeperConf, + mut wal_backup_launcher_rx: Receiver, +) { + info!( + "wal backup launcher started, remote config {:?}", + conf.remote_storage + ); + + let conf_ = conf.clone(); + REMOTE_STORAGE.get_or_init(|| { + conf_.remote_storage.as_ref().map(|c| { + GenericRemoteStorage::new(conf_.workdir, c).expect("failed to create remote storage") + }) + }); + + let mut tasks: HashMap = HashMap::new(); + + loop { + // channel is never expected to get closed + let zttid = wal_backup_launcher_rx.recv().await.unwrap(); + let is_wal_backup_required = is_wal_backup_required(zttid); + if conf.remote_storage.is_none() || !conf.wal_backup_enabled { + continue; /* just drain the channel and do nothing */ + } + // do we need to do anything at all? + if is_wal_backup_required != tasks.contains_key(&zttid) { + if is_wal_backup_required { + // need to start the task + info!("starting wal backup task for {}", zttid); + + // TODO: decide who should offload in launcher itself by simply checking current state + let election_name = broker::get_campaign_name( + BACKUP_ELECTION_NAME.to_string(), + conf.broker_etcd_prefix.clone(), + &zttid, + ); + let my_candidate_name = broker::get_candiate_name(conf.my_id); + let election = broker::Election::new( + election_name, + my_candidate_name, + conf.broker_endpoints.clone(), + ); + + let (shutdown_tx, shutdown_rx) = mpsc::channel(1); + let timeline_dir = conf.timeline_dir(&zttid); + + let handle = tokio::spawn( + backup_task_main(zttid, timeline_dir, shutdown_rx, election) + .instrument(info_span!("WAL backup", zttid = %zttid)), + ); + + tasks.insert( + zttid, + WalBackupTaskHandle { + shutdown_tx, + handle, + }, + ); + } else { + // need to stop the task + info!("stopping wal backup task for {}", zttid); + + let wb_handle = tasks.remove(&zttid).unwrap(); + // Tell the task to shutdown. Error means task exited earlier, that's ok. + let _ = wb_handle.shutdown_tx.send(()).await; + // Await the task itself. TODO: restart panicked tasks earlier. + // Hm, why I can't await on reference to handle? + if let Err(e) = wb_handle.handle.await { + warn!("WAL backup task for {} panicked: {}", zttid, e); + } + } + } + } +} + +struct WalBackupTask { + timeline: Arc, + timeline_dir: PathBuf, + wal_seg_size: usize, + commit_lsn_watch_rx: watch::Receiver, + leader: Option, + election: Election, +} + +/// Offload single timeline. +async fn backup_task_main( + zttid: ZTenantTimelineId, + timeline_dir: PathBuf, + mut shutdown_rx: Receiver<()>, + election: Election, +) { + info!("started"); + let timeline: Arc = if let Some(tli) = GlobalTimelines::get_loaded(zttid) { + tli + } else { + /* Timeline could get deleted while task was starting, just exit then. */ + info!("no timeline, exiting"); + return; + }; + + let mut wb = WalBackupTask { + wal_seg_size: timeline.get_wal_seg_size(), + commit_lsn_watch_rx: timeline.get_commit_lsn_watch_rx(), + timeline, + timeline_dir, + leader: None, + election, + }; + + // task is spinned up only when wal_seg_size already initialized + assert!(wb.wal_seg_size > 0); + + let mut canceled = false; + select! { + _ = wb.run() => {} + _ = shutdown_rx.recv() => { + canceled = true; + } + } + if let Some(l) = wb.leader { + l.give_up().await; + } + info!("task {}", if canceled { "canceled" } else { "terminated" }); +} + +impl WalBackupTask { + async fn run(&mut self) { + let mut backup_lsn = Lsn(0); + + // election loop + loop { + let mut retry_attempt = 0u32; + + if let Some(l) = self.leader.take() { + l.give_up().await; + } + + match broker::get_leader(&self.election).await { + Ok(l) => { + self.leader = Some(l); + } + Err(e) => { + error!("error during leader election {:?}", e); + sleep(Duration::from_millis(BROKER_CONNECTION_RETRY_DELAY_MS)).await; + continue; + } + } + + // offload loop + loop { + if retry_attempt == 0 { + // wait for new WAL to arrive + if let Err(e) = self.commit_lsn_watch_rx.changed().await { + // should never happen, as we hold Arc to timeline. + error!("commit_lsn watch shut down: {:?}", e); + return; + } + } else { + // or just sleep if we errored previously + let mut retry_delay = UPLOAD_FAILURE_RETRY_MAX_MS; + if let Some(backoff_delay) = + UPLOAD_FAILURE_RETRY_MIN_MS.checked_shl(retry_attempt) + { + retry_delay = min(retry_delay, backoff_delay); + } + sleep(Duration::from_millis(retry_delay)).await; + } + + let commit_lsn = *self.commit_lsn_watch_rx.borrow(); + assert!( + commit_lsn >= backup_lsn, + "backup lsn should never pass commit lsn" + ); + + if backup_lsn.segment_number(self.wal_seg_size) + == commit_lsn.segment_number(self.wal_seg_size) + { + continue; /* nothing to do, common case as we wake up on every commit_lsn bump */ + } + // Perhaps peers advanced the position, check shmem value. + backup_lsn = self.timeline.get_wal_backup_lsn(); + if backup_lsn.segment_number(self.wal_seg_size) + == commit_lsn.segment_number(self.wal_seg_size) + { + continue; + } + + if let Some(l) = self.leader.as_mut() { + // Optimization idea for later: + // Avoid checking election leader every time by returning current lease grant expiration time + // Re-check leadership only after expiration time, + // such approach woud reduce overhead on write-intensive workloads + + match l + .check_am_i( + self.election.election_name.clone(), + self.election.candidate_name.clone(), + ) + .await + { + Ok(leader) => { + if !leader { + info!("leader has changed"); + break; + } + } + Err(e) => { + warn!("error validating leader, {:?}", e); + break; + } + } + } + + match backup_lsn_range( + backup_lsn, + commit_lsn, + self.wal_seg_size, + &self.timeline_dir, + ) + .await + { + Ok(backup_lsn_result) => { + backup_lsn = backup_lsn_result; + self.timeline.set_wal_backup_lsn(backup_lsn_result); + retry_attempt = 0; + } + Err(e) => { + error!( + "failed while offloading range {}-{}: {:?}", + backup_lsn, commit_lsn, e + ); + + retry_attempt = min(retry_attempt + 1, u32::MAX); + } + } + } + } + } +} + +pub async fn backup_lsn_range( + start_lsn: Lsn, + end_lsn: Lsn, + wal_seg_size: usize, + timeline_dir: &Path, +) -> Result { + let mut res = start_lsn; + let segments = get_segments(start_lsn, end_lsn, wal_seg_size); + for s in &segments { + backup_single_segment(s, timeline_dir) + .await + .with_context(|| format!("offloading segno {}", s.seg_no))?; + + res = s.end_lsn; + } + info!( + "offloaded segnos {:?} up to {}, previous backup_lsn {}", + segments.iter().map(|&s| s.seg_no).collect::>(), + end_lsn, + start_lsn, + ); + Ok(res) +} + +async fn backup_single_segment(seg: &Segment, timeline_dir: &Path) -> Result<()> { + let segment_file_name = seg.file_path(timeline_dir)?; + + backup_object(&segment_file_name, seg.size()).await?; + debug!("Backup of {} done", segment_file_name.display()); + + Ok(()) +} + +#[derive(Debug, Copy, Clone)] +pub struct Segment { + seg_no: XLogSegNo, + start_lsn: Lsn, + end_lsn: Lsn, +} + +impl Segment { + pub fn new(seg_no: u64, start_lsn: Lsn, end_lsn: Lsn) -> Self { + Self { + seg_no, + start_lsn, + end_lsn, + } + } + + pub fn object_name(self) -> String { + XLogFileName(PG_TLI, self.seg_no, self.size()) + } + + pub fn file_path(self, timeline_dir: &Path) -> Result { + Ok(timeline_dir.join(self.object_name())) + } + + pub fn size(self) -> usize { + (u64::from(self.end_lsn) - u64::from(self.start_lsn)) as usize + } +} + +fn get_segments(start: Lsn, end: Lsn, seg_size: usize) -> Vec { + let first_seg = start.segment_number(seg_size); + let last_seg = end.segment_number(seg_size); + + let res: Vec = (first_seg..last_seg) + .map(|s| { + let start_lsn = XLogSegNoOffsetToRecPtr(s, 0, seg_size); + let end_lsn = XLogSegNoOffsetToRecPtr(s + 1, 0, seg_size); + Segment::new(s, Lsn::from(start_lsn), Lsn::from(end_lsn)) + }) + .collect(); + res +} + +static REMOTE_STORAGE: OnceCell> = OnceCell::new(); + +async fn backup_object(source_file: &Path, size: usize) -> Result<()> { + let storage = REMOTE_STORAGE.get().expect("failed to get remote storage"); + + let file = File::open(&source_file).await?; + + // Storage is initialized by launcher at ths point. + match storage.as_ref().unwrap() { + GenericRemoteStorage::Local(local_storage) => { + let destination = local_storage.remote_object_id(source_file)?; + + debug!( + "local upload about to start from {} to {}", + source_file.display(), + destination.display() + ); + local_storage.upload(file, size, &destination, None).await + } + GenericRemoteStorage::S3(s3_storage) => { + let s3key = s3_storage.remote_object_id(source_file)?; + + debug!( + "S3 upload about to start from {} to {:?}", + source_file.display(), + s3key + ); + s3_storage.upload(file, size, &s3key, None).await + } + }?; + + Ok(()) +} diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py index e1b7bd91ee..fc192c28e8 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -12,7 +12,7 @@ from contextlib import closing from dataclasses import dataclass, field from multiprocessing import Process, Value from pathlib import Path -from fixtures.zenith_fixtures import PgBin, Etcd, Postgres, Safekeeper, ZenithEnv, ZenithEnvBuilder, PortDistributor, SafekeeperPort, zenith_binpath, PgProtocol +from fixtures.zenith_fixtures import PgBin, Etcd, Postgres, RemoteStorageUsers, Safekeeper, ZenithEnv, ZenithEnvBuilder, PortDistributor, SafekeeperPort, zenith_binpath, PgProtocol from fixtures.utils import get_dir_size, lsn_to_hex, mkdir_if_needed, lsn_from_hex from fixtures.log_helper import log from typing import List, Optional, Any @@ -401,7 +401,7 @@ def test_wal_removal(zenith_env_builder: ZenithEnvBuilder): http_cli = env.safekeepers[0].http_client() # Pretend WAL is offloaded to s3. - http_cli.record_safekeeper_info(tenant_id, timeline_id, {'s3_wal_lsn': 'FFFFFFFF/FEFFFFFF'}) + http_cli.record_safekeeper_info(tenant_id, timeline_id, {'backup_lsn': 'FFFFFFFF/FEFFFFFF'}) # wait till first segment is removed on all safekeepers started_at = time.time() @@ -414,6 +414,56 @@ def test_wal_removal(zenith_env_builder: ZenithEnvBuilder): time.sleep(0.5) +@pytest.mark.parametrize('storage_type', ['mock_s3', 'local_fs']) +def test_wal_backup(zenith_env_builder: ZenithEnvBuilder, storage_type: str): + zenith_env_builder.num_safekeepers = 3 + if storage_type == 'local_fs': + zenith_env_builder.enable_local_fs_remote_storage() + elif storage_type == 'mock_s3': + zenith_env_builder.enable_s3_mock_remote_storage('test_safekeepers_wal_backup') + else: + raise RuntimeError(f'Unknown storage type: {storage_type}') + zenith_env_builder.remote_storage_users = RemoteStorageUsers.SAFEKEEPER + + env = zenith_env_builder.init_start() + + env.zenith_cli.create_branch('test_safekeepers_wal_backup') + pg = env.postgres.create_start('test_safekeepers_wal_backup') + + # learn zenith timeline from compute + tenant_id = pg.safe_psql("show zenith.zenith_tenant")[0][0] + timeline_id = pg.safe_psql("show zenith.zenith_timeline")[0][0] + + pg_conn = pg.connect() + cur = pg_conn.cursor() + cur.execute('create table t(key int, value text)') + + # Shut down subsequently each of safekeepers and fill a segment while sk is + # down; ensure segment gets offloaded by others. + offloaded_seg_end = ['0/2000000', '0/3000000', '0/4000000'] + for victim, seg_end in zip(env.safekeepers, offloaded_seg_end): + victim.stop() + # roughly fills one segment + cur.execute("insert into t select generate_series(1,250000), 'payload'") + live_sk = [sk for sk in env.safekeepers if sk != victim][0] + http_cli = live_sk.http_client() + + started_at = time.time() + while True: + tli_status = http_cli.timeline_status(tenant_id, timeline_id) + log.info(f"live sk status is {tli_status}") + + if lsn_from_hex(tli_status.backup_lsn) >= lsn_from_hex(seg_end): + break + elapsed = time.time() - started_at + if elapsed > 20: + raise RuntimeError( + f"timed out waiting {elapsed:.0f}s segment ending at {seg_end} get offloaded") + time.sleep(0.5) + + victim.start() + + class ProposerPostgres(PgProtocol): """Object for running postgres without ZenithEnv""" def __init__(self, diff --git a/test_runner/fixtures/zenith_fixtures.py b/test_runner/fixtures/zenith_fixtures.py index 7f5b2ad2aa..a2e8c82d30 100644 --- a/test_runner/fixtures/zenith_fixtures.py +++ b/test_runner/fixtures/zenith_fixtures.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import field +from enum import Flag, auto import textwrap from cached_property import cached_property import asyncpg @@ -421,10 +422,51 @@ class MockS3Server: def secret_key(self) -> str: return 'test' + def access_env_vars(self) -> Dict[Any, Any]: + return { + 'AWS_ACCESS_KEY_ID': self.access_key(), + 'AWS_SECRET_ACCESS_KEY': self.secret_key(), + } + def kill(self): self.subprocess.kill() +@dataclass +class LocalFsStorage: + local_path: Path + + +@dataclass +class S3Storage: + bucket_name: str + bucket_region: str + endpoint: Optional[str] + + +RemoteStorage = Union[LocalFsStorage, S3Storage] + + +# serialize as toml inline table +def remote_storage_to_toml_inline_table(remote_storage): + if isinstance(remote_storage, LocalFsStorage): + res = f"local_path='{remote_storage.local_path}'" + elif isinstance(remote_storage, S3Storage): + res = f"bucket_name='{remote_storage.bucket_name}', bucket_region='{remote_storage.bucket_region}'" + if remote_storage.endpoint is not None: + res += f", endpoint='{remote_storage.endpoint}'" + else: + raise Exception(f'Unknown storage configuration {remote_storage}') + else: + raise Exception("invalid remote storage type") + return f"{{{res}}}" + + +class RemoteStorageUsers(Flag): + PAGESERVER = auto() + SAFEKEEPER = auto() + + class ZenithEnvBuilder: """ Builder object to create a Zenith runtime environment @@ -440,6 +482,7 @@ class ZenithEnvBuilder: broker: Etcd, mock_s3_server: MockS3Server, remote_storage: Optional[RemoteStorage] = None, + remote_storage_users: RemoteStorageUsers = RemoteStorageUsers.PAGESERVER, pageserver_config_override: Optional[str] = None, num_safekeepers: int = 1, pageserver_auth_enabled: bool = False, @@ -449,6 +492,7 @@ class ZenithEnvBuilder: self.rust_log_override = rust_log_override self.port_distributor = port_distributor self.remote_storage = remote_storage + self.remote_storage_users = remote_storage_users self.broker = broker self.mock_s3_server = mock_s3_server self.pageserver_config_override = pageserver_config_override @@ -497,9 +541,9 @@ class ZenithEnvBuilder: aws_access_key_id=self.mock_s3_server.access_key(), aws_secret_access_key=self.mock_s3_server.secret_key(), ).create_bucket(Bucket=bucket_name) - self.remote_storage = S3Storage(bucket=bucket_name, + self.remote_storage = S3Storage(bucket_name=bucket_name, endpoint=mock_endpoint, - region=mock_region) + bucket_region=mock_region) def __enter__(self): return self @@ -557,6 +601,7 @@ class ZenithEnv: self.safekeepers: List[Safekeeper] = [] self.broker = config.broker self.remote_storage = config.remote_storage + self.remote_storage_users = config.remote_storage_users # generate initial tenant ID here instead of letting 'zenith init' generate it, # so that we don't need to dig it out of the config file afterwards. @@ -605,8 +650,12 @@ class ZenithEnv: id = {id} pg_port = {port.pg} http_port = {port.http} - sync = false # Disable fsyncs to make the tests go faster - """) + sync = false # Disable fsyncs to make the tests go faster""") + if bool(self.remote_storage_users + & RemoteStorageUsers.SAFEKEEPER) and self.remote_storage is not None: + toml += textwrap.dedent(f""" + remote_storage = "{remote_storage_to_toml_inline_table(self.remote_storage)}" + """) safekeeper = Safekeeper(env=self, id=id, port=port) self.safekeepers.append(safekeeper) @@ -638,7 +687,7 @@ def _shared_simple_env(request: Any, mock_s3_server: MockS3Server, default_broker: Etcd) -> Iterator[ZenithEnv]: """ - Internal fixture backing the `zenith_simple_env` fixture. If TEST_SHARED_FIXTURES + # Internal fixture backing the `zenith_simple_env` fixture. If TEST_SHARED_FIXTURES is set, this is shared by all tests using `zenith_simple_env`. """ @@ -822,20 +871,6 @@ class PageserverPort: http: int -@dataclass -class LocalFsStorage: - root: Path - - -@dataclass -class S3Storage: - bucket: str - region: str - endpoint: Optional[str] - - -RemoteStorage = Union[LocalFsStorage, S3Storage] - CREATE_TIMELINE_ID_EXTRACTOR = re.compile(r"^Created timeline '(?P[^']+)'", re.MULTILINE) CREATE_TIMELINE_ID_EXTRACTOR = re.compile(r"^Created timeline '(?P[^']+)'", @@ -998,6 +1033,7 @@ class ZenithCli: append_pageserver_param_overrides( params_to_update=cmd, remote_storage=self.env.remote_storage, + remote_storage_users=self.env.remote_storage_users, pageserver_config_override=self.env.pageserver.config_override) res = self.raw_cli(cmd) @@ -1022,14 +1058,10 @@ class ZenithCli: append_pageserver_param_overrides( params_to_update=start_args, remote_storage=self.env.remote_storage, + remote_storage_users=self.env.remote_storage_users, pageserver_config_override=self.env.pageserver.config_override) - s3_env_vars = None - if self.env.s3_mock_server: - s3_env_vars = { - 'AWS_ACCESS_KEY_ID': self.env.s3_mock_server.access_key(), - 'AWS_SECRET_ACCESS_KEY': self.env.s3_mock_server.secret_key(), - } + s3_env_vars = self.env.s3_mock_server.access_env_vars() if self.env.s3_mock_server else None return self.raw_cli(start_args, extra_env_vars=s3_env_vars) def pageserver_stop(self, immediate=False) -> 'subprocess.CompletedProcess[str]': @@ -1041,7 +1073,8 @@ class ZenithCli: return self.raw_cli(cmd) def safekeeper_start(self, id: int) -> 'subprocess.CompletedProcess[str]': - return self.raw_cli(['safekeeper', 'start', str(id)]) + s3_env_vars = self.env.s3_mock_server.access_env_vars() if self.env.s3_mock_server else None + return self.raw_cli(['safekeeper', 'start', str(id)], extra_env_vars=s3_env_vars) def safekeeper_stop(self, id: Optional[int] = None, @@ -1237,22 +1270,13 @@ class ZenithPageserver(PgProtocol): def append_pageserver_param_overrides( params_to_update: List[str], remote_storage: Optional[RemoteStorage], + remote_storage_users: RemoteStorageUsers, pageserver_config_override: Optional[str] = None, ): - if remote_storage is not None: - if isinstance(remote_storage, LocalFsStorage): - pageserver_storage_override = f"local_path='{remote_storage.root}'" - elif isinstance(remote_storage, S3Storage): - pageserver_storage_override = f"bucket_name='{remote_storage.bucket}',\ - bucket_region='{remote_storage.region}'" - - if remote_storage.endpoint is not None: - pageserver_storage_override += f",endpoint='{remote_storage.endpoint}'" - - else: - raise Exception(f'Unknown storage configuration {remote_storage}') + if bool(remote_storage_users & RemoteStorageUsers.PAGESERVER) and remote_storage is not None: + remote_storage_toml_table = remote_storage_to_toml_inline_table(remote_storage) params_to_update.append( - f'--pageserver-config-override=remote_storage={{{pageserver_storage_override}}}') + f'--pageserver-config-override=remote_storage={remote_storage_toml_table}') env_overrides = os.getenv('ZENITH_PAGESERVER_OVERRIDES') if env_overrides is not None: @@ -1786,8 +1810,9 @@ class Safekeeper: class SafekeeperTimelineStatus: acceptor_epoch: int flush_lsn: str - remote_consistent_lsn: str timeline_start_lsn: str + backup_lsn: str + remote_consistent_lsn: str @dataclass @@ -1812,8 +1837,9 @@ class SafekeeperHttpClient(requests.Session): resj = res.json() return SafekeeperTimelineStatus(acceptor_epoch=resj['acceptor_state']['epoch'], flush_lsn=resj['flush_lsn'], - remote_consistent_lsn=resj['remote_consistent_lsn'], - timeline_start_lsn=resj['timeline_start_lsn']) + timeline_start_lsn=resj['timeline_start_lsn'], + backup_lsn=resj['backup_lsn'], + remote_consistent_lsn=resj['remote_consistent_lsn']) def record_safekeeper_info(self, tenant_id: str, timeline_id: str, body): res = self.post( From 54b75248ff53cd3530916200d9156a491c16b8dd Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Fri, 27 May 2022 13:09:17 +0400 Subject: [PATCH 21/27] s3 WAL offloading staging review. - Uncomment accidently `self.keep_alive.abort()` commented line, due to this task never finished, which blocked launcher. - Mess up with initialization one more time, to fix offloader trying to back up segment 0. Now we initialize all required LSNs in handle_elected, where we learn start LSN for the first time. - Fix blind attempt to provide safekeeper service file with remote storage params. --- .circleci/ansible/systemd/safekeeper.service | 2 +- libs/utils/src/zid.rs | 2 +- safekeeper/src/broker.rs | 2 +- safekeeper/src/safekeeper.rs | 50 +++++++++----------- safekeeper/src/wal_backup.rs | 19 ++++---- 5 files changed, 35 insertions(+), 40 deletions(-) diff --git a/.circleci/ansible/systemd/safekeeper.service b/.circleci/ansible/systemd/safekeeper.service index a6b443c3e7..e4a395a60e 100644 --- a/.circleci/ansible/systemd/safekeeper.service +++ b/.circleci/ansible/systemd/safekeeper.service @@ -6,7 +6,7 @@ After=network.target auditd.service Type=simple User=safekeeper Environment=RUST_BACKTRACE=1 ZENITH_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/lib -ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}.local:6500 --listen-http {{ inventory_hostname }}.local:7676 -p {{ first_pageserver }}:6400 -D /storage/safekeeper/data --broker-endpoints={{ etcd_endpoints }} --remote_storage='{bucket_name={{bucket_name}}, bucket_region={{bucket_region}}, prefix_in_bucket=wal}' +ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}.local:6500 --listen-http {{ inventory_hostname }}.local:7676 -p {{ first_pageserver }}:6400 -D /storage/safekeeper/data --broker-endpoints={{ etcd_endpoints }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="wal"}' ExecReload=/bin/kill -HUP $MAINPID KillMode=mixed KillSignal=SIGINT diff --git a/libs/utils/src/zid.rs b/libs/utils/src/zid.rs index 02f781c49a..0ef174da4d 100644 --- a/libs/utils/src/zid.rs +++ b/libs/utils/src/zid.rs @@ -218,7 +218,7 @@ impl ZTenantTimelineId { impl fmt::Display for ZTenantTimelineId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}-{}", self.tenant_id, self.timeline_id) + write!(f, "{}/{}", self.tenant_id, self.timeline_id) } } diff --git a/safekeeper/src/broker.rs b/safekeeper/src/broker.rs index 676719b60d..5bcb197205 100644 --- a/safekeeper/src/broker.rs +++ b/safekeeper/src/broker.rs @@ -83,7 +83,7 @@ impl ElectionLeader { } pub async fn give_up(self) { - // self.keep_alive.abort(); + self.keep_alive.abort(); // TODO: it'll be wise to resign here but it'll happen after lease expiration anyway // should we await for keep alive termination? let _ = self.keep_alive.await; diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 9a07127771..0a7adb96b6 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -731,24 +731,36 @@ where { let mut state = self.state.clone(); - // Remeber point where WAL begins globally, if not yet. + // Here we learn initial LSN for the first time, set fields + // interested in that. + if state.timeline_start_lsn == Lsn(0) { + // Remember point where WAL begins globally. state.timeline_start_lsn = msg.timeline_start_lsn; info!( "setting timeline_start_lsn to {:?}", state.timeline_start_lsn ); - } - // Remember point where WAL begins locally, if not yet. (I doubt the - // second condition is ever possible) - if state.local_start_lsn == Lsn(0) || state.local_start_lsn >= msg.start_streaming_at { state.local_start_lsn = msg.start_streaming_at; info!("setting local_start_lsn to {:?}", state.local_start_lsn); } + // Initializing commit_lsn before acking first flushed record is + // important to let find_end_of_wal skip the whole in the beginning + // of the first segment. + // + // NB: on new clusters, this happens at the same time as + // timeline_start_lsn initialization, it is taken outside to provide + // upgrade. + self.global_commit_lsn = max(self.global_commit_lsn, state.timeline_start_lsn); + self.inmem.commit_lsn = max(self.inmem.commit_lsn, state.timeline_start_lsn); + self.metrics.commit_lsn.set(self.inmem.commit_lsn.0 as f64); + + // Initalizing backup_lsn is useful to avoid making backup think it should upload 0 segment. + self.inmem.backup_lsn = max(self.inmem.backup_lsn, state.timeline_start_lsn); state.acceptor_state.term_history = msg.term_history.clone(); - self.state.persist(&state)?; + self.persist_control_file(state)?; } info!("start receiving WAL since {:?}", msg.start_streaming_at); @@ -764,14 +776,6 @@ where self.inmem.commit_lsn = commit_lsn; self.metrics.commit_lsn.set(self.inmem.commit_lsn.0 as f64); - // We got our first commit_lsn, which means we should sync - // everything to disk, to initialize the state. - if self.state.commit_lsn == Lsn::INVALID && commit_lsn != Lsn::INVALID { - self.inmem.backup_lsn = self.inmem.commit_lsn; // initialize backup_lsn - self.wal_store.flush_wal()?; - self.persist_control_file()?; - } - // If new commit_lsn reached epoch switch, force sync of control // file: walproposer in sync mode is very interested when this // happens. Note: this is for sync-safekeepers mode only, as @@ -780,15 +784,14 @@ where // that we receive new epoch_start_lsn, and we still need to sync // control file in this case. if commit_lsn == self.epoch_start_lsn && self.state.commit_lsn != commit_lsn { - self.persist_control_file()?; + self.persist_control_file(self.state.clone())?; } Ok(()) } - /// Persist in-memory state to the disk. - fn persist_control_file(&mut self) -> Result<()> { - let mut state = self.state.clone(); + /// Persist in-memory state to the disk, taking other data from state. + fn persist_control_file(&mut self, mut state: SafeKeeperState) -> Result<()> { state.commit_lsn = self.inmem.commit_lsn; state.backup_lsn = self.inmem.backup_lsn; state.peer_horizon_lsn = self.inmem.peer_horizon_lsn; @@ -823,13 +826,6 @@ where // do the job if !msg.wal_data.is_empty() { self.wal_store.write_wal(msg.h.begin_lsn, &msg.wal_data)?; - - // If this was the first record we ever received, initialize - // commit_lsn to help find_end_of_wal skip the hole in the - // beginning. - if self.global_commit_lsn == Lsn(0) { - self.global_commit_lsn = msg.h.begin_lsn; - } } // flush wal to the disk, if required @@ -852,7 +848,7 @@ where if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64) < self.inmem.peer_horizon_lsn { - self.persist_control_file()?; + self.persist_control_file(self.state.clone())?; } trace!( @@ -920,7 +916,7 @@ where self.inmem.peer_horizon_lsn = new_peer_horizon_lsn; } if sync_control_file { - self.persist_control_file()?; + self.persist_control_file(self.state.clone())?; } Ok(()) } diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index ef8ebe14e1..83dc312d28 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -71,7 +71,7 @@ async fn wal_backup_launcher_main_loop( mut wal_backup_launcher_rx: Receiver, ) { info!( - "wal backup launcher started, remote config {:?}", + "WAL backup launcher: started, remote config {:?}", conf.remote_storage ); @@ -95,7 +95,7 @@ async fn wal_backup_launcher_main_loop( if is_wal_backup_required != tasks.contains_key(&zttid) { if is_wal_backup_required { // need to start the task - info!("starting wal backup task for {}", zttid); + info!("starting WAL backup task for {}", zttid); // TODO: decide who should offload in launcher itself by simply checking current state let election_name = broker::get_campaign_name( @@ -115,7 +115,7 @@ async fn wal_backup_launcher_main_loop( let handle = tokio::spawn( backup_task_main(zttid, timeline_dir, shutdown_rx, election) - .instrument(info_span!("WAL backup", zttid = %zttid)), + .instrument(info_span!("WAL backup task", zttid = %zttid)), ); tasks.insert( @@ -127,7 +127,7 @@ async fn wal_backup_launcher_main_loop( ); } else { // need to stop the task - info!("stopping wal backup task for {}", zttid); + info!("stopping WAL backup task for {}", zttid); let wb_handle = tasks.remove(&zttid).unwrap(); // Tell the task to shutdown. Error means task exited earlier, that's ok. @@ -236,20 +236,19 @@ impl WalBackupTask { } let commit_lsn = *self.commit_lsn_watch_rx.borrow(); - assert!( - commit_lsn >= backup_lsn, - "backup lsn should never pass commit lsn" - ); + // Note that backup_lsn can be higher than commit_lsn if we + // don't have much local WAL and others already uploaded + // segments we don't even have. if backup_lsn.segment_number(self.wal_seg_size) - == commit_lsn.segment_number(self.wal_seg_size) + >= commit_lsn.segment_number(self.wal_seg_size) { continue; /* nothing to do, common case as we wake up on every commit_lsn bump */ } // Perhaps peers advanced the position, check shmem value. backup_lsn = self.timeline.get_wal_backup_lsn(); if backup_lsn.segment_number(self.wal_seg_size) - == commit_lsn.segment_number(self.wal_seg_size) + >= commit_lsn.segment_number(self.wal_seg_size) { continue; } From 75f71a63801c687a8bebe6aea28d751da52ac677 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Fri, 27 May 2022 11:43:06 -0400 Subject: [PATCH 22/27] Handle broken timelines on startup (#1809) Resolve #1663. ## Changes - ignore a "broken" [1] timeline on page server startup - fix the race condition when creating multiple timelines in parallel for a tenant - added tests for the above changes [1]: a timeline is marked as "broken" if either - failed to load the timeline's metadata or - the timeline's disk consistent LSN is zero --- pageserver/src/layered_repository.rs | 2 +- pageserver/src/tenant_mgr.rs | 31 +++++++++++++++- pageserver/src/timelines.rs | 9 ++++- .../batch_others/test_broken_timeline.py | 37 ++++++++++++++++++- 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index d10c795214..0d7c6f54c8 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -2518,7 +2518,7 @@ fn rename_to_backup(path: PathBuf) -> anyhow::Result<()> { bail!("couldn't find an unused backup number for {:?}", path) } -fn load_metadata( +pub fn load_metadata( conf: &'static PageServerConf, timeline_id: ZTimelineId, tenant_id: ZTenantId, diff --git a/pageserver/src/tenant_mgr.rs b/pageserver/src/tenant_mgr.rs index bba67394c3..cc35d79d16 100644 --- a/pageserver/src/tenant_mgr.rs +++ b/pageserver/src/tenant_mgr.rs @@ -2,7 +2,7 @@ //! page server. use crate::config::PageServerConf; -use crate::layered_repository::LayeredRepository; +use crate::layered_repository::{load_metadata, LayeredRepository}; use crate::pgdatadir_mapping::DatadirTimeline; use crate::repository::{Repository, TimelineSyncStatusUpdate}; use crate::storage_sync::index::RemoteIndex; @@ -22,6 +22,7 @@ use std::collections::HashMap; use std::fmt; use std::sync::Arc; use tracing::*; +use utils::lsn::Lsn; use utils::zid::{ZTenantId, ZTimelineId}; @@ -399,6 +400,26 @@ pub fn list_tenants() -> Vec { .collect() } +/// Check if a given timeline is "broken" \[1\]. +/// The function returns an error if the timeline is "broken". +/// +/// \[1\]: it's not clear now how should we classify a timeline as broken. +/// A timeline is categorized as broken when any of following conditions is true: +/// - failed to load the timeline's metadata +/// - the timeline's disk consistent LSN is zero +fn check_broken_timeline(repo: &LayeredRepository, timeline_id: ZTimelineId) -> anyhow::Result<()> { + let metadata = load_metadata(repo.conf, timeline_id, repo.tenant_id()) + .context("failed to load metadata")?; + + // A timeline with zero disk consistent LSN can happen when the page server + // failed to checkpoint the timeline import data when creating that timeline. + if metadata.disk_consistent_lsn() == Lsn::INVALID { + bail!("Timeline {timeline_id} has a zero disk consistent LSN."); + } + + Ok(()) +} + fn init_local_repository( conf: &'static PageServerConf, tenant_id: ZTenantId, @@ -414,7 +435,13 @@ fn init_local_repository( match init_status { LocalTimelineInitStatus::LocallyComplete => { debug!("timeline {timeline_id} for tenant {tenant_id} is locally complete, registering it in repository"); - status_updates.insert(timeline_id, TimelineSyncStatusUpdate::Downloaded); + if let Err(err) = check_broken_timeline(&repo, timeline_id) { + info!( + "Found a broken timeline {timeline_id} (err={err:?}), skip registering it in repository" + ); + } else { + status_updates.insert(timeline_id, TimelineSyncStatusUpdate::Downloaded); + } } LocalTimelineInitStatus::NeedsSync => { debug!( diff --git a/pageserver/src/timelines.rs b/pageserver/src/timelines.rs index 408eca6501..9ab063107c 100644 --- a/pageserver/src/timelines.rs +++ b/pageserver/src/timelines.rs @@ -285,7 +285,9 @@ fn bootstrap_timeline( ) -> Result<()> { let _enter = info_span!("bootstrapping", timeline = %tli, tenant = %tenantid).entered(); - let initdb_path = conf.tenant_path(&tenantid).join("tmp"); + let initdb_path = conf + .tenant_path(&tenantid) + .join(format!("tmp-timeline-{}", tli)); // Init temporarily repo to get bootstrap data run_initdb(conf, &initdb_path)?; @@ -300,6 +302,11 @@ fn bootstrap_timeline( let timeline = repo.create_empty_timeline(tli, lsn)?; let mut page_tline: DatadirTimeline = DatadirTimeline::new(timeline, u64::MAX); import_datadir::import_timeline_from_postgres_datadir(&pgdata_path, &mut page_tline, lsn)?; + + fail::fail_point!("before-checkpoint-new-timeline", |_| { + bail!("failpoint before-checkpoint-new-timeline"); + }); + page_tline.tline.checkpoint(CheckpointConfig::Forced)?; info!( diff --git a/test_runner/batch_others/test_broken_timeline.py b/test_runner/batch_others/test_broken_timeline.py index 17eadb33b4..f0aa44e0a4 100644 --- a/test_runner/batch_others/test_broken_timeline.py +++ b/test_runner/batch_others/test_broken_timeline.py @@ -1,6 +1,7 @@ import pytest +import concurrent.futures from contextlib import closing -from fixtures.zenith_fixtures import ZenithEnvBuilder +from fixtures.zenith_fixtures import ZenithEnvBuilder, ZenithEnv from fixtures.log_helper import log import os @@ -78,3 +79,37 @@ def test_broken_timeline(zenith_env_builder: ZenithEnvBuilder): with pytest.raises(Exception, match="Cannot load local timeline") as err: pg.start() log.info(f'compute startup failed as expected: {err}') + + +def test_create_multiple_timelines_parallel(zenith_simple_env: ZenithEnv): + env = zenith_simple_env + + tenant_id, _ = env.zenith_cli.create_tenant() + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(env.zenith_cli.create_timeline, + f"test-create-multiple-timelines-{i}", + tenant_id) for i in range(4) + ] + for future in futures: + future.result() + + +def test_fix_broken_timelines_on_startup(zenith_simple_env: ZenithEnv): + env = zenith_simple_env + + tenant_id, _ = env.zenith_cli.create_tenant() + + # Introduce failpoint when creating a new timeline + env.pageserver.safe_psql(f"failpoints before-checkpoint-new-timeline=return") + with pytest.raises(Exception, match="before-checkpoint-new-timeline"): + _ = env.zenith_cli.create_timeline("test_fix_broken_timelines", tenant_id) + + # Restart the page server + env.zenith_cli.pageserver_stop(immediate=True) + env.zenith_cli.pageserver_start() + + # Check that the "broken" timeline is not loaded + timelines = env.zenith_cli.list_timelines(tenant_id) + assert len(timelines) == 1 From cb8bf1beb606fa97eeee0f038d28af4c7327af34 Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Fri, 27 May 2022 14:10:10 +0400 Subject: [PATCH 23/27] Prevent commit_lsn <= flush_lsn violation after a42eba3cd7. Nothing complained about that yet, but we definitely don't hold at least one assert, so let's keep it this way until better version. --- safekeeper/src/safekeeper.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 0a7adb96b6..c254f2c57c 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -576,13 +576,16 @@ where self.state .acceptor_state .term_history - .up_to(self.wal_store.flush_lsn()) + .up_to(self.flush_lsn()) } pub fn get_epoch(&self) -> Term { - self.state - .acceptor_state - .get_epoch(self.wal_store.flush_lsn()) + self.state.acceptor_state.get_epoch(self.flush_lsn()) + } + + /// wal_store wrapper avoiding commit_lsn <= flush_lsn violation when we don't have WAL yet. + fn flush_lsn(&self) -> Lsn { + max(self.wal_store.flush_lsn(), self.state.timeline_start_lsn) } /// Process message from proposer and possibly form reply. Concurrent @@ -671,7 +674,7 @@ where let mut resp = VoteResponse { term: self.state.acceptor_state.term, vote_given: false as u64, - flush_lsn: self.wal_store.flush_lsn(), + flush_lsn: self.flush_lsn(), truncate_lsn: self.state.peer_horizon_lsn, term_history: self.get_term_history(), timeline_start_lsn: self.state.timeline_start_lsn, @@ -703,7 +706,7 @@ where fn append_response(&self) -> AppendResponse { let ar = AppendResponse { term: self.state.acceptor_state.term, - flush_lsn: self.wal_store.flush_lsn(), + flush_lsn: self.flush_lsn(), commit_lsn: self.state.commit_lsn, // will be filled by the upper code to avoid bothering safekeeper hs_feedback: HotStandbyFeedback::empty(), @@ -770,7 +773,7 @@ where /// Advance commit_lsn taking into account what we have locally pub fn update_commit_lsn(&mut self) -> Result<()> { - let commit_lsn = min(self.global_commit_lsn, self.wal_store.flush_lsn()); + let commit_lsn = min(self.global_commit_lsn, self.flush_lsn()); assert!(commit_lsn >= self.inmem.commit_lsn); self.inmem.commit_lsn = commit_lsn; From 757746b5717eec6e0c338e41f19844ec077852e7 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Fri, 27 May 2022 13:33:53 -0400 Subject: [PATCH 24/27] Fix `test_pageserver_http_get_wal_receiver_success` flaky test. (#1786) Fixes #1768. ## Context Previously, to test `get_wal_receiver` API, we make run some DB transactions then call the API to check the latest message's LSN from the WAL receiver. However, this test won't work because it's not guaranteed that the WAL receiver will get the latest WAL from the postgres/safekeeper at the time of making the API call. This PR resolves the above issue by adding a "poll and wait" code that waits to retrieve the latest data from the WAL receiver. This PR also fixes a bug that tries to compare two hex LSNs, should convert to number before the comparison. See: https://github.com/neondatabase/neon/issues/1768#issuecomment-1133752122. --- .../batch_others/test_pageserver_api.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/test_runner/batch_others/test_pageserver_api.py b/test_runner/batch_others/test_pageserver_api.py index 7fe3b4dff5..2b0e5ae8bd 100644 --- a/test_runner/batch_others/test_pageserver_api.py +++ b/test_runner/batch_others/test_pageserver_api.py @@ -1,11 +1,14 @@ +from typing import Optional from uuid import uuid4, UUID import pytest +from fixtures.utils import lsn_from_hex from fixtures.zenith_fixtures import ( DEFAULT_BRANCH_NAME, ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient, ZenithPageserverApiException, + wait_until, ) @@ -73,18 +76,35 @@ def test_pageserver_http_get_wal_receiver_success(zenith_simple_env: ZenithEnv): tenant_id, timeline_id = env.zenith_cli.create_tenant() pg = env.postgres.create_start(DEFAULT_BRANCH_NAME, tenant_id=tenant_id) - res = client.wal_receiver_get(tenant_id, timeline_id) - assert list(res.keys()) == [ - "thread_id", - "wal_producer_connstr", - "last_received_msg_lsn", - "last_received_msg_ts", - ] + def expect_updated_msg_lsn(prev_msg_lsn: Optional[int]) -> int: + res = client.wal_receiver_get(tenant_id, timeline_id) - # make a DB modification then expect getting a new WAL receiver's data + # a successful `wal_receiver_get` response must contain the below fields + assert list(res.keys()) == [ + "thread_id", + "wal_producer_connstr", + "last_received_msg_lsn", + "last_received_msg_ts", + ] + + assert res["last_received_msg_lsn"] is not None, "the last received message's LSN is empty" + + last_msg_lsn = lsn_from_hex(res["last_received_msg_lsn"]) + assert prev_msg_lsn is None or prev_msg_lsn < last_msg_lsn, \ + f"the last received message's LSN {last_msg_lsn} hasn't been updated \ + compared to the previous message's LSN {prev_msg_lsn}" + + return last_msg_lsn + + # Wait to make sure that we get a latest WAL receiver data. + # We need to wait here because it's possible that we don't have access to + # the latest WAL during the time the `wal_receiver_get` API is called. + # See: https://github.com/neondatabase/neon/issues/1768. + lsn = wait_until(number_of_iterations=5, interval=1, func=lambda: expect_updated_msg_lsn(None)) + + # Make a DB modification then expect getting a new WAL receiver's data. pg.safe_psql("CREATE TABLE t(key int primary key, value text)") - res2 = client.wal_receiver_get(tenant_id, timeline_id) - assert res2["last_received_msg_lsn"] > res["last_received_msg_lsn"] + wait_until(number_of_iterations=5, interval=1, func=lambda: expect_updated_msg_lsn(lsn)) def test_pageserver_http_api_client(zenith_simple_env: ZenithEnv): From 5d813f97386b34f020c8051bf2c5a1b06dc4e408 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Wed, 18 May 2022 16:01:56 +0300 Subject: [PATCH 25/27] [proxy] Refactoring This patch attempts to fix some of the technical debt we had to introduce in previous patches. --- proxy/src/auth.rs | 67 ++--- proxy/src/auth/backend.rs | 109 ++++++++ proxy/src/auth/backend/console.rs | 225 ++++++++++++++++ .../backend}/legacy_console.rs | 32 ++- .../{auth_backend => auth/backend}/link.rs | 14 +- proxy/src/auth/backend/postgres.rs | 88 +++++++ proxy/src/auth/credentials.rs | 30 ++- proxy/src/auth/flow.rs | 6 +- proxy/src/auth_backend.rs | 31 --- proxy/src/auth_backend/console.rs | 243 ------------------ proxy/src/auth_backend/postgres.rs | 93 ------- proxy/src/compute.rs | 4 +- proxy/src/config.rs | 35 ++- proxy/src/main.rs | 2 +- proxy/src/mgmt.rs | 8 +- proxy/src/url.rs | 82 ++++++ 16 files changed, 599 insertions(+), 470 deletions(-) create mode 100644 proxy/src/auth/backend.rs create mode 100644 proxy/src/auth/backend/console.rs rename proxy/src/{auth_backend => auth/backend}/legacy_console.rs (90%) rename proxy/src/{auth_backend => auth/backend}/link.rs (75%) create mode 100644 proxy/src/auth/backend/postgres.rs delete mode 100644 proxy/src/auth_backend.rs delete mode 100644 proxy/src/auth_backend/console.rs delete mode 100644 proxy/src/auth_backend/postgres.rs create mode 100644 proxy/src/url.rs diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 2463f31645..082a7bcf20 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,56 +1,58 @@ -mod credentials; -mod flow; +//! Client authentication mechanisms. -use crate::auth_backend::{console, legacy_console, link, postgres}; -use crate::config::{AuthBackendType, ProxyConfig}; -use crate::error::UserFacingError; -use crate::stream::PqStream; -use crate::{auth_backend, compute, waiters}; -use console::ConsoleAuthError::SniMissing; +pub mod backend; +pub use backend::DatabaseInfo; + +mod credentials; +pub use credentials::ClientCredentials; + +mod flow; +pub use flow::*; + +use crate::{error::UserFacingError, waiters}; use std::io; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; -pub use credentials::ClientCredentials; -pub use flow::*; +/// Convenience wrapper for the authentication error. +pub type Result = std::result::Result; /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { /// Authentication error reported by the console. #[error(transparent)] - Console(#[from] auth_backend::AuthError), + Console(#[from] backend::AuthError), #[error(transparent)] - GetAuthInfo(#[from] auth_backend::console::ConsoleAuthError), + GetAuthInfo(#[from] backend::console::ConsoleAuthError), #[error(transparent)] Sasl(#[from] crate::sasl::Error), - /// For passwords that couldn't be processed by [`parse_password`]. + /// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`]. #[error("Malformed password message")] MalformedPassword, - /// Errors produced by [`PqStream`]. + /// Errors produced by [`crate::stream::PqStream`]. #[error(transparent)] Io(#[from] io::Error), } impl AuthErrorImpl { pub fn auth_failed(msg: impl Into) -> Self { - AuthErrorImpl::Console(auth_backend::AuthError::auth_failed(msg)) + Self::Console(backend::AuthError::auth_failed(msg)) } } impl From for AuthErrorImpl { fn from(e: waiters::RegisterError) -> Self { - AuthErrorImpl::Console(auth_backend::AuthError::from(e)) + Self::Console(backend::AuthError::from(e)) } } impl From for AuthErrorImpl { fn from(e: waiters::WaitError) -> Self { - AuthErrorImpl::Console(auth_backend::AuthError::from(e)) + Self::Console(backend::AuthError::from(e)) } } @@ -63,7 +65,7 @@ where AuthErrorImpl: From, { fn from(e: T) -> Self { - AuthError(Box::new(e.into())) + Self(Box::new(e.into())) } } @@ -72,34 +74,9 @@ impl UserFacingError for AuthError { use AuthErrorImpl::*; match self.0.as_ref() { Console(e) => e.to_string_client(), + GetAuthInfo(e) => e.to_string_client(), MalformedPassword => self.to_string(), - GetAuthInfo(e) if matches!(e, SniMissing) => e.to_string(), _ => "Internal error".to_string(), } } } - -async fn handle_user( - config: &ProxyConfig, - client: &mut PqStream, - creds: ClientCredentials, -) -> Result { - match config.auth_backend { - AuthBackendType::LegacyConsole => { - legacy_console::handle_user( - &config.auth_endpoint, - &config.auth_link_uri, - client, - &creds, - ) - .await - } - AuthBackendType::Console => { - console::handle_user(config.auth_endpoint.as_ref(), client, &creds).await - } - AuthBackendType::Postgres => { - postgres::handle_user(&config.auth_endpoint, client, &creds).await - } - AuthBackendType::Link => link::handle_user(config.auth_link_uri.as_ref(), client).await, - } -} diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs new file mode 100644 index 0000000000..1d41f7f932 --- /dev/null +++ b/proxy/src/auth/backend.rs @@ -0,0 +1,109 @@ +mod legacy_console; +mod link; +mod postgres; + +pub mod console; + +pub use legacy_console::{AuthError, AuthErrorImpl}; + +use super::ClientCredentials; +use crate::{ + compute, + config::{AuthBackendType, ProxyConfig}, + mgmt, + stream::PqStream, + waiters::{self, Waiter, Waiters}, +}; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncRead, AsyncWrite}; + +lazy_static! { + static ref CPLANE_WAITERS: Waiters = Default::default(); +} + +/// Give caller an opportunity to wait for the cloud's reply. +pub async fn with_waiter( + psql_session_id: impl Into, + action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, +) -> Result +where + R: std::future::Future>, + E: From, +{ + let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; + action(waiter).await +} + +pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> { + CPLANE_WAITERS.notify(psql_session_id, msg) +} + +/// Compute node connection params provided by the cloud. +/// Note how it implements serde traits, since we receive it over the wire. +#[derive(Serialize, Deserialize, Default)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + pub password: Option, +} + +// Manually implement debug to omit personal and sensitive info. +impl std::fmt::Debug for DatabaseInfo { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .finish() + } +} + +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config + } +} + +pub(super) async fn handle_user( + config: &ProxyConfig, + client: &mut PqStream, + creds: ClientCredentials, +) -> super::Result { + use AuthBackendType::*; + match config.auth_backend { + LegacyConsole => { + legacy_console::handle_user( + &config.auth_endpoint, + &config.auth_link_uri, + client, + &creds, + ) + .await + } + Console => { + console::Api::new(&config.auth_endpoint, &creds)? + .handle_user(client) + .await + } + Postgres => { + postgres::Api::new(&config.auth_endpoint, &creds)? + .handle_user(client) + .await + } + Link => link::handle_user(&config.auth_link_uri, client).await, + } +} diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs new file mode 100644 index 0000000000..252522affb --- /dev/null +++ b/proxy/src/auth/backend/console.rs @@ -0,0 +1,225 @@ +//! Cloud API V2. + +use crate::{ + auth::{self, AuthFlow, ClientCredentials, DatabaseInfo}, + compute, + error::UserFacingError, + scram, + stream::PqStream, + url::ApiUrl, +}; +use serde::{Deserialize, Serialize}; +use std::{future::Future, io}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; +use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; + +pub type Result = std::result::Result; + +#[derive(Debug, Error)] +pub enum ConsoleAuthError { + #[error(transparent)] + BadProjectName(#[from] auth::credentials::ProjectNameError), + + // We shouldn't include the actual secret here. + #[error("Bad authentication secret")] + BadSecret, + + #[error("Console responded with a malformed compute address: '{0}'")] + BadComputeAddress(String), + + #[error("Console responded with a malformed JSON: '{0}'")] + BadResponse(#[from] serde_json::Error), + + /// HTTP status (other than 200) returned by the console. + #[error("Console responded with an HTTP status: {0}")] + HttpStatus(reqwest::StatusCode), + + #[error(transparent)] + Io(#[from] std::io::Error), +} + +impl UserFacingError for ConsoleAuthError { + fn to_string_client(&self) -> String { + use ConsoleAuthError::*; + match self { + BadProjectName(e) => e.to_string_client(), + _ => "Internal error".to_string(), + } + } +} + +// TODO: convert into an enum with "error" +#[derive(Serialize, Deserialize, Debug)] +struct GetRoleSecretResponse { + role_secret: String, +} + +// TODO: convert into an enum with "error" +#[derive(Serialize, Deserialize, Debug)] +struct GetWakeComputeResponse { + address: String, +} + +/// Auth secret which is managed by the cloud. +pub enum AuthInfo { + /// Md5 hash of user's password. + Md5([u8; 16]), + + /// [SCRAM](crate::scram) authentication info. + Scram(scram::ServerSecret), +} + +#[must_use] +pub(super) struct Api<'a> { + endpoint: &'a ApiUrl, + creds: &'a ClientCredentials, + /// Cache project name, since we'll need it several times. + project: &'a str, +} + +impl<'a> Api<'a> { + /// Construct an API object containing the auth parameters. + pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result { + Ok(Self { + endpoint, + creds, + project: creds.project_name()?, + }) + } + + /// Authenticate the existing user or throw an error. + pub(super) async fn handle_user( + self, + client: &mut PqStream, + ) -> auth::Result { + handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + } + + async fn get_auth_info(&self) -> Result { + let mut url = self.endpoint.clone(); + url.path_segments_mut().push("proxy_get_role_secret"); + url.query_pairs_mut() + .append_pair("project", self.project) + .append_pair("role", &self.creds.user); + + // TODO: use a proper logger + println!("cplane request: {url}"); + + let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetRoleSecretResponse = + serde_json::from_str(&resp.text().await.map_err(io_error)?)?; + + scram::ServerSecret::parse(response.role_secret.as_str()) + .map(AuthInfo::Scram) + .ok_or(ConsoleAuthError::BadSecret) + } + + /// Wake up the compute node and return the corresponding connection info. + async fn wake_compute(&self) -> Result { + let mut url = self.endpoint.clone(); + url.path_segments_mut().push("proxy_wake_compute"); + url.query_pairs_mut().append_pair("project", self.project); + + // TODO: use a proper logger + println!("cplane request: {url}"); + + let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetWakeComputeResponse = + serde_json::from_str(&resp.text().await.map_err(io_error)?)?; + + let (host, port) = parse_host_port(&response.address) + .ok_or(ConsoleAuthError::BadComputeAddress(response.address))?; + + Ok(DatabaseInfo { + host, + port, + dbname: self.creds.dbname.to_owned(), + user: self.creds.user.to_owned(), + password: None, + }) + } +} + +/// Common logic for user handling in API V2. +/// We reuse this for a mock API implementation in [`super::postgres`]. +pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( + client: &mut PqStream, + endpoint: &'a Endpoint, + get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo, + wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute, +) -> auth::Result +where + GetAuthInfo: Future>, + WakeCompute: Future>, +{ + let auth_info = get_auth_info(endpoint).await?; + + let flow = AuthFlow::new(client); + let scram_keys = match auth_info { + AuthInfo::Md5(_) => { + // TODO: decide if we should support MD5 in api v2 + return Err(auth::AuthErrorImpl::auth_failed("MD5 is not supported").into()); + } + AuthInfo::Scram(secret) => { + let scram = auth::Scram(&secret); + Some(compute::ScramKeys { + client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + server_key: secret.server_key.as_bytes(), + }) + } + }; + + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + + Ok(compute::NodeInfo { + db_info: wake_compute(endpoint).await?, + scram_keys, + }) +} + +/// Upcast (almost) any error into an opaque [`io::Error`]. +pub(super) fn io_error(e: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} + +fn parse_host_port(input: &str) -> Option<(String, u16)> { + let (host, port) = input.split_once(':')?; + Some((host.to_owned(), port.parse().ok()?)) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + }))?; + + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + }))?; + + Ok(()) + } +} diff --git a/proxy/src/auth_backend/legacy_console.rs b/proxy/src/auth/backend/legacy_console.rs similarity index 90% rename from proxy/src/auth_backend/legacy_console.rs rename to proxy/src/auth/backend/legacy_console.rs index 29997d2389..467da63a98 100644 --- a/proxy/src/auth_backend/legacy_console.rs +++ b/proxy/src/auth/backend/legacy_console.rs @@ -1,20 +1,18 @@ //! Cloud API V1. -use super::console::DatabaseInfo; - -use crate::auth::ClientCredentials; -use crate::stream::PqStream; - -use crate::{compute, waiters}; +use super::DatabaseInfo; +use crate::{ + auth::{self, ClientCredentials}, + compute, + error::UserFacingError, + stream::PqStream, + waiters, +}; use serde::{Deserialize, Serialize}; - +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; -use thiserror::Error; - -use crate::error::UserFacingError; - #[derive(Debug, Error)] pub enum AuthErrorImpl { /// Authentication error reported by the console. @@ -45,7 +43,7 @@ pub struct AuthError(Box); impl AuthError { /// Smart constructor for authentication error reported by `mgmt`. pub fn auth_failed(msg: impl Into) -> Self { - AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into()))) + Self(Box::new(AuthErrorImpl::AuthFailed(msg.into()))) } } @@ -54,7 +52,7 @@ where AuthErrorImpl: From, { fn from(e: T) -> Self { - AuthError(Box::new(e.into())) + Self(Box::new(e.into())) } } @@ -120,7 +118,7 @@ async fn handle_existing_user( auth_endpoint: &reqwest::Url, client: &mut PqStream, creds: &ClientCredentials, -) -> Result { +) -> Result { let psql_session_id = super::link::new_psql_session_id(); let md5_salt = rand::random(); @@ -130,7 +128,7 @@ async fn handle_existing_user( // Read client's password hash let msg = client.read_password_message().await?; - let md5_response = parse_password(&msg).ok_or(crate::auth::AuthErrorImpl::MalformedPassword)?; + let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?; let db_info = authenticate_proxy_client( auth_endpoint, @@ -156,11 +154,11 @@ pub async fn handle_user( auth_link_uri: &reqwest::Url, client: &mut PqStream, creds: &ClientCredentials, -) -> Result { +) -> auth::Result { if creds.is_existing_user() { handle_existing_user(auth_endpoint, client, creds).await } else { - super::link::handle_user(auth_link_uri.as_ref(), client).await + super::link::handle_user(auth_link_uri, client).await } } diff --git a/proxy/src/auth_backend/link.rs b/proxy/src/auth/backend/link.rs similarity index 75% rename from proxy/src/auth_backend/link.rs rename to proxy/src/auth/backend/link.rs index 8e5fcb32a9..669c9e00e9 100644 --- a/proxy/src/auth_backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,4 +1,4 @@ -use crate::{compute, stream::PqStream}; +use crate::{auth, compute, stream::PqStream}; use tokio::io::{AsyncRead, AsyncWrite}; use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; @@ -19,13 +19,13 @@ pub fn new_psql_session_id() -> String { } pub async fn handle_user( - redirect_uri: &str, + redirect_uri: &reqwest::Url, client: &mut PqStream, -) -> Result { +) -> auth::Result { let psql_session_id = new_psql_session_id(); - let greeting = hello_message(redirect_uri, &psql_session_id); + let greeting = hello_message(redirect_uri.as_str(), &psql_session_id); - let db_info = crate::auth_backend::with_waiter(psql_session_id, |waiter| async { + let db_info = super::with_waiter(psql_session_id, |waiter| async { // Give user a URL to spawn a new database client .write_message_noflush(&Be::AuthenticationOk)? @@ -34,9 +34,7 @@ pub async fn handle_user( .await?; // Wait for web console response (see `mgmt`) - waiter - .await? - .map_err(crate::auth::AuthErrorImpl::auth_failed) + waiter.await?.map_err(auth::AuthErrorImpl::auth_failed) }) .await?; diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs new file mode 100644 index 0000000000..721b9db095 --- /dev/null +++ b/proxy/src/auth/backend/postgres.rs @@ -0,0 +1,88 @@ +//! Local mock of Cloud API V2. + +use crate::{ + auth::{ + self, + backend::console::{self, io_error, AuthInfo, Result}, + ClientCredentials, DatabaseInfo, + }, + compute, scram, + stream::PqStream, + url::ApiUrl, +}; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[must_use] +pub(super) struct Api<'a> { + endpoint: &'a ApiUrl, + creds: &'a ClientCredentials, +} + +impl<'a> Api<'a> { + /// Construct an API object containing the auth parameters. + pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result { + Ok(Self { endpoint, creds }) + } + + /// Authenticate the existing user or throw an error. + pub(super) async fn handle_user( + self, + client: &mut PqStream, + ) -> auth::Result { + // We reuse user handling logic from a production module. + console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await + } + + /// This implementation fetches the auth info from a local postgres instance. + async fn get_auth_info(&self) -> Result { + // Perhaps we could persist this connection, but then we'd have to + // write more code for reopening it if it got closed, which doesn't + // seem worth it. + let (client, connection) = + tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls) + .await + .map_err(io_error)?; + + tokio::spawn(connection); + let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; + let rows = client + .query(query, &[&self.creds.user]) + .await + .map_err(io_error)?; + + match &rows[..] { + // We can't get a secret if there's no such user. + [] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()), + + // We shouldn't get more than one row anyway. + [row, ..] => { + let entry = row.try_get(0).map_err(io_error)?; + scram::ServerSecret::parse(entry) + .map(AuthInfo::Scram) + .or_else(|| { + // It could be an md5 hash if it's not a SCRAM secret. + let text = entry.strip_prefix("md5")?; + Some(AuthInfo::Md5({ + let mut bytes = [0u8; 16]; + hex::decode_to_slice(text, &mut bytes).ok()?; + bytes + })) + }) + // Putting the secret into this message is a security hazard! + .ok_or(console::ConsoleAuthError::BadSecret) + } + } + } + + /// We don't need to wake anything locally, so we just return the connection info. + async fn wake_compute(&self) -> Result { + Ok(DatabaseInfo { + // TODO: handle that near CLI params parsing + host: self.endpoint.host_str().unwrap_or("localhost").to_owned(), + port: self.endpoint.port().unwrap_or(5432), + dbname: self.creds.dbname.to_owned(), + user: self.creds.user.to_owned(), + password: None, + }) + } +} diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 9d2272b5ad..467e7db282 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,6 +1,5 @@ //! User credentials used in authentication. -use super::AuthError; use crate::compute; use crate::config::ProxyConfig; use crate::error::UserFacingError; @@ -36,6 +35,27 @@ impl ClientCredentials { } } +#[derive(Debug, Error)] +pub enum ProjectNameError { + #[error("SNI is missing, please upgrade the postgres client library")] + Missing, + + #[error("SNI is malformed")] + Bad, +} + +impl UserFacingError for ProjectNameError {} + +impl ClientCredentials { + /// Determine project name from SNI. + pub fn project_name(&self) -> Result<&str, ProjectNameError> { + // Currently project name is passed as a top level domain + let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?; + let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?; + Ok(first) + } +} + impl TryFrom> for ClientCredentials { type Error = ClientCredsParseError; @@ -47,11 +67,11 @@ impl TryFrom> for ClientCredentials { }; let user = get_param("user")?; - let db = get_param("database")?; + let dbname = get_param("database")?; Ok(Self { user, - dbname: db, + dbname, sni_data: None, }) } @@ -63,8 +83,8 @@ impl ClientCredentials { self, config: &ProxyConfig, client: &mut PqStream, - ) -> Result { + ) -> super::Result { // This method is just a convenient facade for `handle_user` - super::handle_user(config, client, self).await + super::backend::handle_user(config, client, self).await } } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 3eed0f0a23..7efff13bfc 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,6 +1,6 @@ //! Main authentication flow. -use super::{AuthError, AuthErrorImpl}; +use super::AuthErrorImpl; use crate::stream::PqStream; use crate::{sasl, scram}; use std::io; @@ -32,7 +32,7 @@ impl AuthMethod for Scram<'_> { pub struct AuthFlow<'a, Stream, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream, - /// State might contain ancillary data (see [`AuthFlow::begin`]). + /// State might contain ancillary data (see [`Self::begin`]). state: State, } @@ -60,7 +60,7 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> Result { + pub async fn authenticate(self) -> super::Result { // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; diff --git a/proxy/src/auth_backend.rs b/proxy/src/auth_backend.rs deleted file mode 100644 index 54362bf719..0000000000 --- a/proxy/src/auth_backend.rs +++ /dev/null @@ -1,31 +0,0 @@ -pub mod console; -pub mod legacy_console; -pub mod link; -pub mod postgres; - -pub use legacy_console::{AuthError, AuthErrorImpl}; - -use crate::mgmt; -use crate::waiters::{self, Waiter, Waiters}; -use lazy_static::lazy_static; - -lazy_static! { - static ref CPLANE_WAITERS: Waiters = Default::default(); -} - -/// Give caller an opportunity to wait for the cloud's reply. -pub async fn with_waiter( - psql_session_id: impl Into, - action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, -) -> Result -where - R: std::future::Future>, - E: From, -{ - let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; - action(waiter).await -} - -pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> { - CPLANE_WAITERS.notify(psql_session_id, msg) -} diff --git a/proxy/src/auth_backend/console.rs b/proxy/src/auth_backend/console.rs deleted file mode 100644 index 41a822701f..0000000000 --- a/proxy/src/auth_backend/console.rs +++ /dev/null @@ -1,243 +0,0 @@ -//! Declaration of Cloud API V2. - -use crate::{ - auth::{self, AuthFlow}, - compute, scram, -}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -use crate::auth::ClientCredentials; -use crate::stream::PqStream; - -use tokio::io::{AsyncRead, AsyncWrite}; -use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; - -#[derive(Debug, Error)] -pub enum ConsoleAuthError { - // We shouldn't include the actual secret here. - #[error("Bad authentication secret")] - BadSecret, - - #[error("Bad client credentials: {0:?}")] - BadCredentials(crate::auth::ClientCredentials), - - #[error("SNI info is missing, please upgrade the postgres client library")] - SniMissing, - - #[error("Unexpected SNI content")] - SniWrong, - - #[error(transparent)] - BadUrl(#[from] url::ParseError), - - #[error(transparent)] - Io(#[from] std::io::Error), - - /// HTTP status (other than 200) returned by the console. - #[error("Console responded with an HTTP status: {0}")] - HttpStatus(reqwest::StatusCode), - - #[error(transparent)] - Transport(#[from] reqwest::Error), - - #[error("Console responded with a malformed JSON: '{0}'")] - MalformedResponse(#[from] serde_json::Error), - - #[error("Console responded with a malformed compute address: '{0}'")] - MalformedComputeAddress(String), -} - -#[derive(Serialize, Deserialize, Debug)] -struct GetRoleSecretResponse { - role_secret: String, -} - -#[derive(Serialize, Deserialize, Debug)] -struct GetWakeComputeResponse { - address: String, -} - -/// Auth secret which is managed by the cloud. -pub enum AuthInfo { - /// Md5 hash of user's password. - Md5([u8; 16]), - /// [SCRAM](crate::scram) authentication info. - Scram(scram::ServerSecret), -} - -/// Compute node connection params provided by the cloud. -/// Note how it implements serde traits, since we receive it over the wire. -#[derive(Serialize, Deserialize, Default)] -pub struct DatabaseInfo { - pub host: String, - pub port: u16, - pub dbname: String, - pub user: String, - - /// [Cloud API V1](super::legacy) returns cleartext password, - /// but [Cloud API V2](super::api) implements [SCRAM](crate::scram) - /// authentication, so we can leverage this method and cope without password. - pub password: Option, -} - -// Manually implement debug to omit personal and sensitive info. -impl std::fmt::Debug for DatabaseInfo { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct("DatabaseInfo") - .field("host", &self.host) - .field("port", &self.port) - .finish() - } -} - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } -} - -async fn get_auth_info( - auth_endpoint: &str, - user: &str, - cluster: &str, -) -> Result { - let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?; - - url.query_pairs_mut() - .append_pair("project", cluster) - .append_pair("role", user); - - // TODO: use a proper logger - println!("cplane request: {}", url); - - let resp = reqwest::get(url).await?; - if !resp.status().is_success() { - return Err(ConsoleAuthError::HttpStatus(resp.status())); - } - - let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?; - - scram::ServerSecret::parse(response.role_secret.as_str()) - .map(AuthInfo::Scram) - .ok_or(ConsoleAuthError::BadSecret) -} - -/// Wake up the compute node and return the corresponding connection info. -async fn wake_compute( - auth_endpoint: &str, - cluster: &str, -) -> Result<(String, u16), ConsoleAuthError> { - let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?; - url.query_pairs_mut().append_pair("project", cluster); - - // TODO: use a proper logger - println!("cplane request: {}", url); - - let resp = reqwest::get(url).await?; - if !resp.status().is_success() { - return Err(ConsoleAuthError::HttpStatus(resp.status())); - } - - let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?; - let (host, port) = response - .address - .split_once(':') - .ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; - let port: u16 = port - .parse() - .map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; - - Ok((host.to_string(), port)) -} - -pub async fn handle_user( - auth_endpoint: &str, - client: &mut PqStream, - creds: &ClientCredentials, -) -> Result { - // Determine cluster name from SNI. - let cluster = creds - .sni_data - .as_ref() - .ok_or(ConsoleAuthError::SniMissing)? - .split_once('.') - .ok_or(ConsoleAuthError::SniWrong)? - .0; - - let user = creds.user.as_str(); - - // Step 1: get the auth secret - let auth_info = get_auth_info(auth_endpoint, user, cluster).await?; - - let flow = AuthFlow::new(client); - let scram_keys = match auth_info { - AuthInfo::Md5(_) => { - // TODO: decide if we should support MD5 in api v2 - return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into()); - } - AuthInfo::Scram(secret) => { - let scram = auth::Scram(&secret); - Some(compute::ScramKeys { - client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), - server_key: secret.server_key.as_bytes(), - }) - } - }; - - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; - - // Step 2: wake compute - let (host, port) = wake_compute(auth_endpoint, cluster).await?; - - Ok(compute::NodeInfo { - db_info: DatabaseInfo { - host, - port, - dbname: creds.dbname.clone(), - user: creds.user.clone(), - password: None, - }, - scram_keys, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn parse_db_info() -> anyhow::Result<()> { - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "password": "password", - }))?; - - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - }))?; - - Ok(()) - } -} diff --git a/proxy/src/auth_backend/postgres.rs b/proxy/src/auth_backend/postgres.rs deleted file mode 100644 index 148c2a2518..0000000000 --- a/proxy/src/auth_backend/postgres.rs +++ /dev/null @@ -1,93 +0,0 @@ -//! Local mock of Cloud API V2. - -use super::console::{self, AuthInfo, DatabaseInfo}; -use crate::scram; -use crate::{auth::ClientCredentials, compute}; - -use crate::stream::PqStream; -use tokio::io::{AsyncRead, AsyncWrite}; -use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; - -async fn get_auth_info( - auth_endpoint: &str, - creds: &ClientCredentials, -) -> Result { - // We wrap `tokio_postgres::Error` because we don't want to infect the - // method's error type with a detail that's specific to debug mode only. - let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e); - - // Perhaps we could persist this connection, but then we'd have to - // write more code for reopening it if it got closed, which doesn't - // seem worth it. - let (client, connection) = tokio_postgres::connect(auth_endpoint, tokio_postgres::NoTls) - .await - .map_err(io_error)?; - - tokio::spawn(connection); - let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; - let rows = client - .query(query, &[&creds.user]) - .await - .map_err(io_error)?; - - match &rows[..] { - // We can't get a secret if there's no such user. - [] => Err(console::ConsoleAuthError::BadCredentials(creds.to_owned())), - // We shouldn't get more than one row anyway. - [row, ..] => { - let entry = row.try_get(0).map_err(io_error)?; - scram::ServerSecret::parse(entry) - .map(AuthInfo::Scram) - .or_else(|| { - // It could be an md5 hash if it's not a SCRAM secret. - let text = entry.strip_prefix("md5")?; - Some(AuthInfo::Md5({ - let mut bytes = [0u8; 16]; - hex::decode_to_slice(text, &mut bytes).ok()?; - bytes - })) - }) - // Putting the secret into this message is a security hazard! - .ok_or(console::ConsoleAuthError::BadSecret) - } - } -} - -pub async fn handle_user( - auth_endpoint: &reqwest::Url, - client: &mut PqStream, - creds: &ClientCredentials, -) -> Result { - let auth_info = get_auth_info(auth_endpoint.as_ref(), creds).await?; - - let flow = crate::auth::AuthFlow::new(client); - let scram_keys = match auth_info { - AuthInfo::Md5(_) => { - // TODO: decide if we should support MD5 in api v2 - return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into()); - } - AuthInfo::Scram(secret) => { - let scram = crate::auth::Scram(&secret); - Some(compute::ScramKeys { - client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), - server_key: secret.server_key.as_bytes(), - }) - } - }; - - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; - - Ok(compute::NodeInfo { - db_info: DatabaseInfo { - // TODO: handle that near CLI params parsing - host: auth_endpoint.host_str().unwrap_or("localhost").to_owned(), - port: auth_endpoint.port().unwrap_or(5432), - dbname: creds.dbname.to_owned(), - user: creds.user.to_owned(), - password: None, - }, - scram_keys, - }) -} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index c3c5ba47fb..cccd6e60d4 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,4 +1,4 @@ -use crate::auth_backend::console::DatabaseInfo; +use crate::auth::DatabaseInfo; use crate::cancellation::CancelClosure; use crate::error::UserFacingError; use std::io; @@ -37,7 +37,7 @@ pub struct NodeInfo { impl NodeInfo { async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> { - let host_port = format!("{}:{}", self.db_info.host, self.db_info.port); + let host_port = (self.db_info.host.as_str(), self.db_info.port); let socket = TcpStream::connect(host_port).await?; let socket_addr = socket.peer_addr()?; socket2::SockRef::from(&socket).set_keepalive(true)?; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6f1b56bfe4..a5cd17eb55 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,39 +1,38 @@ -use anyhow::{ensure, Context}; +use crate::url::ApiUrl; +use anyhow::{bail, ensure, Context}; use std::{str::FromStr, sync::Arc}; -#[non_exhaustive] pub enum AuthBackendType { + /// Legacy Cloud API (V1). LegacyConsole, - Console, - Postgres, + /// Authentication via a web browser. Link, + /// Current Cloud API (V2). + Console, + /// Local mock of Cloud API (V2). + Postgres, } impl FromStr for AuthBackendType { type Err = anyhow::Error; fn from_str(s: &str) -> anyhow::Result { - println!("ClientAuthMethod::from_str: '{}'", s); use AuthBackendType::*; - match s { - "legacy" => Ok(LegacyConsole), - "console" => Ok(Console), - "postgres" => Ok(Postgres), - "link" => Ok(Link), - _ => Err(anyhow::anyhow!("Invlid option for auth method")), - } + Ok(match s { + "legacy" => LegacyConsole, + "console" => Console, + "postgres" => Postgres, + "link" => Link, + _ => bail!("Invalid option `{s}` for auth method"), + }) } } pub struct ProxyConfig { - /// TLS configuration for the proxy. pub tls_config: Option, - pub auth_backend: AuthBackendType, - - pub auth_endpoint: reqwest::Url, - - pub auth_link_uri: reqwest::Url, + pub auth_endpoint: ApiUrl, + pub auth_link_uri: ApiUrl, } pub type TlsConfig = Arc; diff --git a/proxy/src/main.rs b/proxy/src/main.rs index b457d46824..672f24b6fb 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -5,7 +5,6 @@ //! in somewhat transparent manner (again via communication with control plane API). mod auth; -mod auth_backend; mod cancellation; mod compute; mod config; @@ -17,6 +16,7 @@ mod proxy; mod sasl; mod scram; mod stream; +mod url; mod waiters; use anyhow::{bail, Context}; diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 93618fff68..8737d170b1 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -1,4 +1,4 @@ -use crate::auth_backend; +use crate::auth; use anyhow::Context; use serde::Deserialize; use std::{ @@ -77,12 +77,12 @@ struct PsqlSessionResponse { #[derive(Deserialize)] enum PsqlSessionResult { - Success(auth_backend::console::DatabaseInfo), + Success(auth::DatabaseInfo), Failure(String), } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = Result; +pub type ComputeReady = Result; impl PsqlSessionResult { fn into_compute_ready(self) -> ComputeReady { @@ -113,7 +113,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R let resp: PsqlSessionResponse = serde_json::from_str(query_string)?; - match auth_backend::notify(&resp.session_id, resp.result.into_compute_ready()) { + match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/url.rs b/proxy/src/url.rs new file mode 100644 index 0000000000..76d6ad0e66 --- /dev/null +++ b/proxy/src/url.rs @@ -0,0 +1,82 @@ +use anyhow::bail; +use url::form_urlencoded::Serializer; + +/// A [url](url::Url) type with additional guarantees. +#[derive(Debug, Clone)] +pub struct ApiUrl(url::Url); + +impl ApiUrl { + /// Consume the wrapper and return inner [url](url::Url). + pub fn into_inner(self) -> url::Url { + self.0 + } + + /// See [`url::Url::query_pairs_mut`]. + pub fn query_pairs_mut(&mut self) -> Serializer<'_, url::UrlQuery<'_>> { + self.0.query_pairs_mut() + } + + /// See [`url::Url::path_segments_mut`]. + pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut { + // We've already verified that it works during construction. + self.0.path_segments_mut().expect("bad API url") + } +} + +/// This instance imposes additional requirements on the url. +impl std::str::FromStr for ApiUrl { + type Err = anyhow::Error; + + fn from_str(s: &str) -> anyhow::Result { + let mut url: url::Url = s.parse()?; + + // Make sure that we can build upon this URL. + if url.path_segments_mut().is_err() { + bail!("bad API url provided"); + } + + Ok(Self(url)) + } +} + +/// This instance is safe because it doesn't allow us to modify the object. +impl std::ops::Deref for ApiUrl { + type Target = url::Url; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::fmt::Display for ApiUrl { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bad_url() { + let url = "test:foobar"; + url.parse::().expect("unexpected parsing failure"); + let _ = url.parse::().expect_err("should not parse"); + } + + #[test] + fn good_url() { + let url = "test://foobar"; + let mut a = url.parse::().expect("unexpected parsing failure"); + let mut b = url.parse::().expect("unexpected parsing failure"); + + a.path_segments_mut().unwrap().push("method"); + a.query_pairs_mut().append_pair("key", "value"); + + b.path_segments_mut().push("method"); + b.query_pairs_mut().append_pair("key", "value"); + + assert_eq!(a, b.into_inner()); + } +} From b3ec6e0661e1f08beb1cd08b265cc64af0cd4035 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Thu, 26 May 2022 20:39:33 +0300 Subject: [PATCH 26/27] [proxy] Propagate SASL/SCRAM auth errors to the user This will replace the vague (and incorrect) "Internal error" with a nice and helpful authentication error, e.g. "password doesn't match". --- proxy/src/auth.rs | 1 + proxy/src/config.rs | 1 + proxy/src/main.rs | 1 + proxy/src/sasl.rs | 15 +++++++++++++++ proxy/src/scram/exchange.rs | 6 ++++-- 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 082a7bcf20..9bddd58fce 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -75,6 +75,7 @@ impl UserFacingError for AuthError { match self.0.as_ref() { Console(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), + Sasl(e) => e.to_string_client(), MalformedPassword => self.to_string(), _ => "Internal error".to_string(), } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index a5cd17eb55..4def11aefc 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -2,6 +2,7 @@ use crate::url::ApiUrl; use anyhow::{bail, ensure, Context}; use std::{str::FromStr, sync::Arc}; +#[derive(Debug)] pub enum AuthBackendType { /// Legacy Cloud API (V1). LegacyConsole, diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 672f24b6fb..b68b2440dd 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -126,6 +126,7 @@ async fn main() -> anyhow::Result<()> { })); println!("Version: {GIT_VERSION}"); + println!("Authentication backend: {:?}", config.auth_backend); // Check that we can bind to address before further initialization println!("Starting http on {}", http_address); diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index cd9032bfb9..689fca6049 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -10,6 +10,7 @@ mod channel_binding; mod messages; mod stream; +use crate::error::UserFacingError; use std::io; use thiserror::Error; @@ -36,6 +37,20 @@ pub enum Error { Io(#[from] io::Error), } +impl UserFacingError for Error { + fn to_string_client(&self) -> String { + use Error::*; + match self { + // This constructor contains the reason why auth has failed. + AuthenticationFailed(s) => s.to_string(), + // TODO: add support for channel binding + ChannelBindingFailed(_) => "channel binding is not supported yet".to_string(), + ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), + _ => "authentication protocol violation".to_string(), + } + } +} + /// A convenient result type for SASL exchange. pub type Result = std::result::Result; diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index cad77e15f5..fca5585b25 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -106,7 +106,9 @@ impl sasl::Mechanism for Exchange<'_> { } if client_final_message.nonce != server_first_message.nonce() { - return Err(SaslError::AuthenticationFailed("bad nonce")); + return Err(SaslError::AuthenticationFailed( + "combined nonce doesn't match", + )); } let signature_builder = SignatureBuilder { @@ -120,7 +122,7 @@ impl sasl::Mechanism for Exchange<'_> { .derive_client_key(&client_final_message.proof); if client_key.sha256() != self.secret.stored_key { - return Err(SaslError::AuthenticationFailed("keys don't match")); + return Err(SaslError::AuthenticationFailed("password doesn't match")); } let msg = client_final_message From c5f3c9bbc7e5debc5fe55dcbd781981882a12433 Mon Sep 17 00:00:00 2001 From: Kliment Serafimov Date: Thu, 2 Jun 2022 00:04:26 +0200 Subject: [PATCH 27/27] Merged changes. --- libs/utils/src/pq_proto.rs | 10 +- proxy/src/auth/credentials.rs | 34 +++- proxy/src/auth_backend/console.rs | 251 ++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 proxy/src/auth_backend/console.rs diff --git a/libs/utils/src/pq_proto.rs b/libs/utils/src/pq_proto.rs index ce86cf8c91..0ad8adb3a6 100644 --- a/libs/utils/src/pq_proto.rs +++ b/libs/utils/src/pq_proto.rs @@ -269,11 +269,15 @@ impl FeStartupPacket { .next() .context("expected even number of params in StartupMessage")?; if name == "options" { - // deprecated way of passing params as cmd line args - for cmdopt in value.split(' ') { - let nameval: Vec<&str> = cmdopt.split('=').collect(); + //parsing options arguments "..&options=:,.." + //extended example and set of options: + //https://github.com/neondatabase/neon/blob/main/docs/rfcs/016-connection-routing.md#connection-url + for cmdopt in value.split(',') { + let nameval: Vec<&str> = cmdopt.split(':').collect(); if nameval.len() == 2 { params.insert(nameval[0].to_string(), nameval[1].to_string()); + } else { + //todo: inform user / throw error message if options format is wrong. } } } else { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 467e7db282..2636c2237d 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -26,6 +26,10 @@ pub struct ClientCredentials { // New console API requires SNI info to determine the cluster name. // Other Auth backends don't need it. pub sni_data: Option, + + // cluster_option is passed as argument from options from url. + // To be used to determine cluster name in case sni_data is missing. + pub project_option: Option, } impl ClientCredentials { @@ -37,10 +41,10 @@ impl ClientCredentials { #[derive(Debug, Error)] pub enum ProjectNameError { - #[error("SNI is missing, please upgrade the postgres client library")] + #[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '..&options=project:..'.")] Missing, - #[error("SNI is malformed")] + #[error("SNI is malformed.")] Bad, } @@ -49,10 +53,22 @@ impl UserFacingError for ProjectNameError {} impl ClientCredentials { /// Determine project name from SNI. pub fn project_name(&self) -> Result<&str, ProjectNameError> { - // Currently project name is passed as a top level domain - let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?; - let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?; - Ok(first) + let ret = match &self.sni_data { + //if sni_data exists, use it to determine project name + Some(sni_data) => { + sni_data + .split_once('.') + .ok_or(ProjectNameError::Bad)? + .0 + } + //otherwise use project_option if it was manually set thought ..&options=project: parameter + None => self + .project_option + .as_ref() + .ok_or(ProjectNameError::Missing)? + .as_str(), + }; + Ok(ret) } } @@ -68,11 +84,17 @@ impl TryFrom> for ClientCredentials { let user = get_param("user")?; let dbname = get_param("database")?; + let project = get_param("project"); + let project_option = match project { + Ok(project) => Some(project), + Err(_) => None, + }; Ok(Self { user, dbname, sni_data: None, + project_option, }) } } diff --git a/proxy/src/auth_backend/console.rs b/proxy/src/auth_backend/console.rs new file mode 100644 index 0000000000..551bcb7bb6 --- /dev/null +++ b/proxy/src/auth_backend/console.rs @@ -0,0 +1,251 @@ +//! Declaration of Cloud API V2. + +use crate::{ + auth::{self, AuthFlow}, + compute, scram, +}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::auth::ClientCredentials; +use crate::stream::PqStream; + +use tokio::io::{AsyncRead, AsyncWrite}; +use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; + +#[derive(Debug, Error)] +pub enum ConsoleAuthError { + // We shouldn't include the actual secret here. + #[error("Bad authentication secret")] + BadSecret, + + #[error("Bad client credentials: {0:?}")] + BadCredentials(crate::auth::ClientCredentials), + + #[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass ..&options=cluster:.. parameter")] + SniMissingAndProjectNameMissing, + + #[error("Unexpected SNI content")] + SniWrong, + + #[error(transparent)] + BadUrl(#[from] url::ParseError), + + #[error(transparent)] + Io(#[from] std::io::Error), + + /// HTTP status (other than 200) returned by the console. + #[error("Console responded with an HTTP status: {0}")] + HttpStatus(reqwest::StatusCode), + + #[error(transparent)] + Transport(#[from] reqwest::Error), + + #[error("Console responded with a malformed JSON: '{0}'")] + MalformedResponse(#[from] serde_json::Error), + + #[error("Console responded with a malformed compute address: '{0}'")] + MalformedComputeAddress(String), +} + +#[derive(Serialize, Deserialize, Debug)] +struct GetRoleSecretResponse { + role_secret: String, +} + +#[derive(Serialize, Deserialize, Debug)] +struct GetWakeComputeResponse { + address: String, +} + +/// Auth secret which is managed by the cloud. +pub enum AuthInfo { + /// Md5 hash of user's password. + Md5([u8; 16]), + /// [SCRAM](crate::scram) authentication info. + Scram(scram::ServerSecret), +} + +/// Compute node connection params provided by the cloud. +/// Note how it implements serde traits, since we receive it over the wire. +#[derive(Serialize, Deserialize, Default)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + + /// [Cloud API V1](super::legacy) returns cleartext password, + /// but [Cloud API V2](super::api) implements [SCRAM](crate::scram) + /// authentication, so we can leverage this method and cope without password. + pub password: Option, +} + +// Manually implement debug to omit personal and sensitive info. +impl std::fmt::Debug for DatabaseInfo { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .finish() + } +} + +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config + } +} + +async fn get_auth_info( + auth_endpoint: &str, + user: &str, + cluster: &str, +) -> Result { + let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?; + + url.query_pairs_mut() + .append_pair("project", cluster) + .append_pair("role", user); + + // TODO: use a proper logger + println!("cplane request: {}", url); + + let resp = reqwest::get(url).await?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?; + + scram::ServerSecret::parse(response.role_secret.as_str()) + .map(AuthInfo::Scram) + .ok_or(ConsoleAuthError::BadSecret) +} + +/// Wake up the compute node and return the corresponding connection info. +async fn wake_compute( + auth_endpoint: &str, + cluster: &str, +) -> Result<(String, u16), ConsoleAuthError> { + let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?; + url.query_pairs_mut().append_pair("project", cluster); + + // TODO: use a proper logger + println!("cplane request: {}", url); + + let resp = reqwest::get(url).await?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?; + let (host, port) = response + .address + .split_once(':') + .ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; + let port: u16 = port + .parse() + .map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; + + Ok((host.to_string(), port)) +} + +pub async fn handle_user( + auth_endpoint: &str, + client: &mut PqStream, + creds: &ClientCredentials, +) -> Result { + // Determine cluster name from SNI (creds.sni_data) or from creds.cluster_option. + let cluster = match &creds.sni_data { + //if sni_data exists, use it + Some(sni_data) => { + sni_data + .split_once('.') + .ok_or(ConsoleAuthError::SniWrong)? + .0 + } + //otherwise use cluster_option if it was manually set thought ..&options=cluster: parameter + None => creds + .cluster_option + .as_ref() + .ok_or(ConsoleAuthError::SniMissingAndProjectNameMissing)? + .as_str(), + }; + + let user = creds.user.as_str(); + + // Step 1: get the auth secret + let auth_info = get_auth_info(auth_endpoint, user, cluster).await?; + + let flow = AuthFlow::new(client); + let scram_keys = match auth_info { + AuthInfo::Md5(_) => { + // TODO: decide if we should support MD5 in api v2 + return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into()); + } + AuthInfo::Scram(secret) => { + let scram = auth::Scram(&secret); + Some(compute::ScramKeys { + client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + server_key: secret.server_key.as_bytes(), + }) + } + }; + + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + + // Step 2: wake compute + let (host, port) = wake_compute(auth_endpoint, cluster).await?; + + Ok(compute::NodeInfo { + db_info: DatabaseInfo { + host, + port, + dbname: creds.dbname.clone(), + user: creds.user.clone(), + password: None, + }, + scram_keys, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + }))?; + + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + }))?; + + Ok(()) + } +}