From d11c9f9fcb950ac263b1fa16651976ca11e96edf Mon Sep 17 00:00:00 2001 From: Anastasia Lubennikova Date: Wed, 15 Jun 2022 18:16:04 +0300 Subject: [PATCH 01/11] Use random ports for the proxy and local pg in tests Fixes #1931 Author: Dmitry Ivanov --- test_runner/fixtures/neon_fixtures.py | 25 ++++++++++++++++--------- test_runner/fixtures/utils.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 4c0715bac3..167c3ff60a 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -29,7 +29,7 @@ from dataclasses import dataclass # Type-related stuff from psycopg2.extensions import connection as PgConnection from psycopg2.extensions import make_dsn, parse_dsn -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing_extensions import Literal import requests @@ -1379,6 +1379,7 @@ class VanillaPostgres(PgProtocol): self.pg_bin = pg_bin self.running = False self.pg_bin.run_capture(['initdb', '-D', pgdatadir]) + self.configure([f"port = {port}\n"]) def configure(self, options: List[str]): """Append lines into postgresql.conf file.""" @@ -1413,10 +1414,12 @@ class VanillaPostgres(PgProtocol): @pytest.fixture(scope='function') -def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]: +def vanilla_pg(test_output_dir: str, + port_distributor: PortDistributor) -> Iterator[VanillaPostgres]: pgdatadir = os.path.join(test_output_dir, "pgdata-vanilla") pg_bin = PgBin(test_output_dir) - with VanillaPostgres(pgdatadir, pg_bin, 5432) as vanilla_pg: + port = port_distributor.get_port() + with VanillaPostgres(pgdatadir, pg_bin, port) as vanilla_pg: yield vanilla_pg @@ -1462,7 +1465,7 @@ def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]: class NeonProxy(PgProtocol): - def __init__(self, port: int): + def __init__(self, port: int, pg_port: int): super().__init__(host="127.0.0.1", user="proxy_user", password="pytest2", @@ -1471,9 +1474,10 @@ class NeonProxy(PgProtocol): self.http_port = 7001 self.host = "127.0.0.1" self.port = port + self.pg_port = pg_port self._popen: Optional[subprocess.Popen[bytes]] = None - def start_static(self, addr="127.0.0.1:5432") -> None: + def start(self) -> None: assert self._popen is None # Start proxy @@ -1482,7 +1486,8 @@ class NeonProxy(PgProtocol): args.extend(["--http", f"{self.host}:{self.http_port}"]) args.extend(["--proxy", f"{self.host}:{self.port}"]) args.extend(["--auth-backend", "postgres"]) - args.extend(["--auth-endpoint", "postgres://proxy_auth:pytest1@localhost:5432/postgres"]) + args.extend( + ["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"]) self._popen = subprocess.Popen(args) self._wait_until_ready() @@ -1501,14 +1506,16 @@ class NeonProxy(PgProtocol): @pytest.fixture(scope='function') -def static_proxy(vanilla_pg) -> Iterator[NeonProxy]: +def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]: """Neon proxy that routes directly to vanilla postgres.""" vanilla_pg.start() vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser") vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'") - with NeonProxy(4432) as proxy: - proxy.start_static() + port = port_distributor.get_port() + pg_port = vanilla_pg.default_options['port'] + with NeonProxy(port, pg_port) as 
proxy: + proxy.start() yield proxy diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index ba9bc6e113..bfa57373b3 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -3,7 +3,7 @@ import shutil import subprocess from pathlib import Path -from typing import Any, List, Optional +from typing import Any, List from fixtures.log_helper import log From 36ee182d260dc01fd592e19a9928f10c3957cd05 Mon Sep 17 00:00:00 2001 From: Anastasia Lubennikova Date: Thu, 16 Jun 2022 14:07:11 +0300 Subject: [PATCH 02/11] Implement page servise 'fullbackup' endpoint (#1923) * Implement page servise 'fullbackup' endpoint that works like basebackup, but also sends relational files * Add test_runner/batch_others/test_fullbackup.py Co-authored-by: bojanserafimov --- pageserver/src/basebackup.rs | 80 ++++++++++++++++----- pageserver/src/page_service.rs | 31 +++++++- pageserver/src/reltag.rs | 26 ++++++- test_runner/batch_others/test_fullbackup.py | 73 +++++++++++++++++++ test_runner/fixtures/neon_fixtures.py | 5 +- 5 files changed, 193 insertions(+), 22 deletions(-) create mode 100644 test_runner/batch_others/test_fullbackup.py diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 46d824b2e2..44a6442522 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -13,6 +13,7 @@ use anyhow::{anyhow, bail, ensure, Context, Result}; use bytes::{BufMut, BytesMut}; use fail::fail_point; +use itertools::Itertools; use std::fmt::Write as FmtWrite; use std::io; use std::io::Write; @@ -21,7 +22,7 @@ use std::time::SystemTime; use tar::{Builder, EntryType, Header}; use tracing::*; -use crate::reltag::SlruKind; +use crate::reltag::{RelTag, SlruKind}; use crate::repository::Timeline; use crate::DatadirTimelineImpl; use postgres_ffi::xlog_utils::*; @@ -39,11 +40,12 @@ where timeline: &'a Arc, pub lsn: Lsn, prev_record_lsn: Lsn, - + full_backup: bool, finished: bool, } -// Create basebackup with non-rel data in it. Omit relational data. +// Create basebackup with non-rel data in it. +// Only include relational data if 'full_backup' is true. // // Currently we use empty lsn in two cases: // * During the basebackup right after timeline creation @@ -58,6 +60,7 @@ where write: W, timeline: &'a Arc, req_lsn: Option, + full_backup: bool, ) -> Result> { // Compute postgres doesn't have any previous WAL files, but the first // record that it's going to write needs to include the LSN of the @@ -94,8 +97,8 @@ where }; info!( - "taking basebackup lsn={}, prev_lsn={}", - backup_lsn, backup_prev + "taking basebackup lsn={}, prev_lsn={} (full_backup={})", + backup_lsn, backup_prev, full_backup ); Ok(Basebackup { @@ -103,6 +106,7 @@ where timeline, lsn: backup_lsn, prev_record_lsn: backup_prev, + full_backup, finished: false, }) } @@ -140,6 +144,13 @@ where // Create tablespace directories for ((spcnode, dbnode), has_relmap_file) in self.timeline.list_dbdirs(self.lsn)? { self.add_dbdir(spcnode, dbnode, has_relmap_file)?; + + // Gather and send relational files in each database if full backup is requested. + if self.full_backup { + for rel in self.timeline.list_rels(spcnode, dbnode, self.lsn)? { + self.add_rel(rel)?; + } + } } for xid in self.timeline.list_twophase_files(self.lsn)? 
{ self.add_twophase_file(xid)?; @@ -157,6 +168,38 @@ where Ok(()) } + fn add_rel(&mut self, tag: RelTag) -> anyhow::Result<()> { + let nblocks = self.timeline.get_rel_size(tag, self.lsn)?; + + // Function that adds relation segment data to archive + let mut add_file = |segment_index, data: &Vec| -> anyhow::Result<()> { + let file_name = tag.to_segfile_name(segment_index as u32); + let header = new_tar_header(&file_name, data.len() as u64)?; + self.ar.append(&header, data.as_slice())?; + Ok(()) + }; + + // If the relation is empty, create an empty file + if nblocks == 0 { + add_file(0, &vec![])?; + return Ok(()); + } + + // Add a file for each chunk of blocks (aka segment) + let chunks = (0..nblocks).chunks(pg_constants::RELSEG_SIZE as usize); + for (seg, blocks) in chunks.into_iter().enumerate() { + let mut segment_data: Vec = vec![]; + for blknum in blocks { + let img = self.timeline.get_rel_page_at_lsn(tag, blknum, self.lsn)?; + segment_data.extend_from_slice(&img[..]); + } + + add_file(seg, &segment_data)?; + } + + Ok(()) + } + // // Generate SLRU segment files from repository. // @@ -312,21 +355,24 @@ where pg_control.checkPointCopy = checkpoint; pg_control.state = pg_constants::DB_SHUTDOWNED; - // add zenith.signal file - let mut zenith_signal = String::new(); - if self.prev_record_lsn == Lsn(0) { - if self.lsn == self.timeline.tline.get_ancestor_lsn() { - write!(zenith_signal, "PREV LSN: none")?; + // Postgres doesn't recognize the zenith.signal file and doesn't need it. + if !self.full_backup { + // add zenith.signal file + let mut zenith_signal = String::new(); + if self.prev_record_lsn == Lsn(0) { + if self.lsn == self.timeline.tline.get_ancestor_lsn() { + write!(zenith_signal, "PREV LSN: none")?; + } else { + write!(zenith_signal, "PREV LSN: invalid")?; + } } else { - write!(zenith_signal, "PREV LSN: invalid")?; + write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)?; } - } else { - write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)?; + self.ar.append( + &new_tar_header("zenith.signal", zenith_signal.len() as u64)?, + zenith_signal.as_bytes(), + )?; } - self.ar.append( - &new_tar_header("zenith.signal", zenith_signal.len() as u64)?, - zenith_signal.as_bytes(), - )?; //send pg_control let pg_control_bytes = pg_control.encode(); diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 30f0d241d6..406228f034 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -596,6 +596,7 @@ impl PageServerHandler { timelineid: ZTimelineId, lsn: Option, tenantid: ZTenantId, + full_backup: bool, ) -> anyhow::Result<()> { let span = info_span!("basebackup", timeline = %timelineid, tenant = %tenantid, lsn = field::Empty); let _enter = span.enter(); @@ -618,7 +619,7 @@ impl PageServerHandler { { let mut writer = CopyDataSink { pgb }; - let basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn)?; + let basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn, full_backup)?; span.record("lsn", &basebackup.lsn.to_string().as_str()); basebackup.send_tarball()?; } @@ -721,7 +722,33 @@ impl postgres_backend::Handler for PageServerHandler { }; // Check that the timeline exists - self.handle_basebackup_request(pgb, timelineid, lsn, tenantid)?; + self.handle_basebackup_request(pgb, timelineid, lsn, tenantid, false)?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } + // same as basebackup, but result includes relational data as well + else if query_string.starts_with("fullbackup ") { + 
let (_, params_raw) = query_string.split_at("fullbackup ".len()); + let params = params_raw.split_whitespace().collect::>(); + + ensure!( + params.len() == 3, + "invalid param number for fullbackup command" + ); + + let tenantid = ZTenantId::from_str(params[0])?; + let timelineid = ZTimelineId::from_str(params[1])?; + + self.check_permission(Some(tenantid))?; + + // Lsn is required for fullbackup, because otherwise we would not know + // at which lsn to upload this backup. + // + // The caller is responsible for providing a valid lsn + // and using it in the subsequent import. + let lsn = Some(Lsn::from_str(params[2])?); + + // Check that the timeline exists + self.handle_basebackup_request(pgb, timelineid, lsn, tenantid, true)?; pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; } else if query_string.to_ascii_lowercase().starts_with("set ") { // important because psycopg2 executes "SET datestyle TO 'ISO'" diff --git a/pageserver/src/reltag.rs b/pageserver/src/reltag.rs index 18e26cc37a..fadd41f547 100644 --- a/pageserver/src/reltag.rs +++ b/pageserver/src/reltag.rs @@ -3,7 +3,7 @@ use std::cmp::Ordering; use std::fmt; use postgres_ffi::relfile_utils::forknumber_to_name; -use postgres_ffi::Oid; +use postgres_ffi::{pg_constants, Oid}; /// /// Relation data file segment id throughout the Postgres cluster. @@ -75,6 +75,30 @@ impl fmt::Display for RelTag { } } +impl RelTag { + pub fn to_segfile_name(&self, segno: u32) -> String { + let mut name = if self.spcnode == pg_constants::GLOBALTABLESPACE_OID { + "global/".to_string() + } else { + format!("base/{}/", self.dbnode) + }; + + name += &self.relnode.to_string(); + + if let Some(fork_name) = forknumber_to_name(self.forknum) { + name += "_"; + name += fork_name; + } + + if segno != 0 { + name += "."; + name += &segno.to_string(); + } + + name + } +} + /// /// Non-relation transaction status files (clog (a.k.a. 
pg_xact) and /// pg_multixact) in Postgres are handled by SLRU (Simple LRU) buffer, diff --git a/test_runner/batch_others/test_fullbackup.py b/test_runner/batch_others/test_fullbackup.py new file mode 100644 index 0000000000..e5d705beab --- /dev/null +++ b/test_runner/batch_others/test_fullbackup.py @@ -0,0 +1,73 @@ +import subprocess +from contextlib import closing + +import psycopg2.extras +import pytest +from fixtures.log_helper import log +from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PortDistributor, VanillaPostgres +from fixtures.neon_fixtures import pg_distrib_dir +import os +from fixtures.utils import mkdir_if_needed, subprocess_capture +import shutil +import getpass +import pwd + +num_rows = 1000 + + +# Ensure that regular postgres can start from fullbackup +def test_fullbackup(neon_env_builder: NeonEnvBuilder, + pg_bin: PgBin, + port_distributor: PortDistributor): + + neon_env_builder.num_safekeepers = 1 + env = neon_env_builder.init_start() + + env.neon_cli.create_branch('test_fullbackup') + pgmain = env.postgres.create_start('test_fullbackup') + log.info("postgres is running on 'test_fullbackup' branch") + + timeline = pgmain.safe_psql("SHOW neon.timeline_id")[0][0] + + with closing(pgmain.connect()) as conn: + with conn.cursor() as cur: + # data loading may take a while, so increase statement timeout + cur.execute("SET statement_timeout='300s'") + cur.execute(f'''CREATE TABLE tbl AS SELECT 'long string to consume some space' || g + from generate_series(1,{num_rows}) g''') + cur.execute("CHECKPOINT") + + cur.execute('SELECT pg_current_wal_insert_lsn()') + lsn = cur.fetchone()[0] + log.info(f"start_backup_lsn = {lsn}") + + # Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq. + # PgBin sets it automatically, but here we need to pipe psql output to the tar command. 
+ psql_env = {'LD_LIBRARY_PATH': os.path.join(str(pg_distrib_dir), 'lib')} + + # Get and unpack fullbackup from pageserver + restored_dir_path = os.path.join(env.repo_dir, "restored_datadir") + os.mkdir(restored_dir_path, 0o750) + query = f"fullbackup {env.initial_tenant.hex} {timeline} {lsn}" + cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query] + result_basepath = pg_bin.run_capture(cmd, env=psql_env) + tar_output_file = result_basepath + ".stdout" + subprocess_capture(str(env.repo_dir), ["tar", "-xf", tar_output_file, "-C", restored_dir_path]) + + # HACK + # fullbackup returns neon specific pg_control and first WAL segment + # use resetwal to overwrite it + pg_resetwal_path = os.path.join(pg_bin.pg_bin_path, 'pg_resetwal') + cmd = [pg_resetwal_path, "-D", restored_dir_path] + pg_bin.run_capture(cmd, env=psql_env) + + # Restore from the backup and find the data we inserted + port = port_distributor.get_port() + with VanillaPostgres(restored_dir_path, pg_bin, port, init=False) as vanilla_pg: + # TODO make port an optional argument + vanilla_pg.configure([ + f"port={port}", + ]) + vanilla_pg.start() + num_rows_found = vanilla_pg.safe_psql('select count(*) from tbl;', user="cloud_admin")[0][0] + assert num_rows == num_rows_found diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 167c3ff60a..fcefaad8fa 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1373,12 +1373,13 @@ def pg_bin(test_output_dir: str) -> PgBin: class VanillaPostgres(PgProtocol): - def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int): + def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int, init=True): super().__init__(host='localhost', port=port, dbname='postgres') self.pgdatadir = pgdatadir self.pg_bin = pg_bin self.running = False - self.pg_bin.run_capture(['initdb', '-D', pgdatadir]) + if init: + self.pg_bin.run_capture(['initdb', '-D', pgdatadir]) self.configure([f"port = {port}\n"]) def configure(self, options: List[str]): From 699f46cd84c23bdfbd382679e087a5b55da87eb6 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Fri, 17 Jun 2022 15:33:39 +0300 Subject: [PATCH 03/11] Download WAL from S3 if it's not available in safekeeper dir (#1932) `send_wal.rs` and `WalReader` are now async. `test_s3_wal_replay` checks that WAL can be replayed after offloaded. 
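
For reference, a minimal sketch of how a test opts the safekeepers into WAL offload, mirroring the fixture calls used by the new `test_s3_wal_replay` further down (the import path for `RemoteStorageUsers` is assumed):

```python
from fixtures.neon_fixtures import NeonEnvBuilder, RemoteStorageUsers  # assumed import location

def test_wal_offload_sketch(neon_env_builder: NeonEnvBuilder):
    # Three safekeepers, with local-fs acting as the "remote" storage for offloaded segments.
    neon_env_builder.num_safekeepers = 3
    neon_env_builder.enable_local_fs_remote_storage()
    neon_env_builder.remote_storage_users = RemoteStorageUsers.SAFEKEEPER
    env = neon_env_builder.init_start()
    env.neon_cli.create_branch('test_wal_offload_sketch')
    # ... write enough WAL to fill a segment, let the safekeepers offload and trim it;
    # streaming past the trimmed range is then served from remote storage.
```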
--- safekeeper/src/json_ctrl.rs | 2 +- safekeeper/src/send_wal.rs | 225 +++++++++++------- safekeeper/src/timeline.rs | 52 +--- safekeeper/src/wal_backup.rs | 51 +++- safekeeper/src/wal_storage.rs | 125 +++++++--- test_runner/batch_others/test_wal_acceptor.py | 98 +++++++- 6 files changed, 379 insertions(+), 174 deletions(-) diff --git a/safekeeper/src/json_ctrl.rs b/safekeeper/src/json_ctrl.rs index 43514997d4..97fb3654d2 100644 --- a/safekeeper/src/json_ctrl.rs +++ b/safekeeper/src/json_ctrl.rs @@ -124,7 +124,7 @@ fn send_proposer_elected(spg: &mut SafekeeperPostgresHandler, term: Term, lsn: L term, start_streaming_at: lsn, term_history: history, - timeline_start_lsn: Lsn(0), + timeline_start_lsn: lsn, }); spg.timeline.get().process_msg(&proposer_elected_request)?; diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 11e5b963c9..a6b9de2050 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -13,9 +13,11 @@ use serde::{Deserialize, Serialize}; use std::cmp::min; use std::net::Shutdown; use std::sync::Arc; -use std::thread::sleep; use std::time::Duration; use std::{str, thread}; + +use tokio::sync::watch::Receiver; +use tokio::time::timeout; use tracing::*; use utils::{ bin_ser::BeSer, @@ -191,100 +193,143 @@ impl ReplicationConn { } })?; - let mut wal_seg_size: usize; - loop { - wal_seg_size = spg.timeline.get().get_state().1.server.wal_seg_size as usize; - if wal_seg_size == 0 { - error!("Cannot start replication before connecting to wal_proposer"); - sleep(Duration::from_secs(1)); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + runtime.block_on(async move { + let (_, persisted_state) = spg.timeline.get().get_state(); + if persisted_state.server.wal_seg_size == 0 + || persisted_state.timeline_start_lsn == Lsn(0) + { + bail!("Cannot start replication before connecting to walproposer"); + } + + let wal_end = spg.timeline.get().get_end_of_wal(); + // Walproposer gets special handling: safekeeper must give proposer all + // local WAL till the end, whether committed or not (walproposer will + // hang otherwise). That's because walproposer runs the consensus and + // synchronizes safekeepers on the most advanced one. + // + // There is a small risk of this WAL getting concurrently garbaged if + // another compute rises which collects majority and starts fixing log + // on this safekeeper itself. That's ok as (old) proposer will never be + // able to commit such WAL. 
+ let stop_pos: Option = if spg.appname == Some("wal_proposer_recovery".to_string()) + { + Some(wal_end) } else { + None + }; + + info!("Start replication from {:?} till {:?}", start_pos, stop_pos); + + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse)?; + + let mut end_pos = Lsn(0); + + let mut wal_reader = WalReader::new( + spg.conf.timeline_dir(&spg.timeline.get().zttid), + &persisted_state, + start_pos, + spg.conf.wal_backup_enabled, + )?; + + // buffer for wal sending, limited by MAX_SEND_SIZE + let mut send_buf = vec![0u8; MAX_SEND_SIZE]; + + // watcher for commit_lsn updates + let mut commit_lsn_watch_rx = spg.timeline.get().get_commit_lsn_watch_rx(); + + loop { + if let Some(stop_pos) = stop_pos { + if start_pos >= stop_pos { + break; /* recovery finished */ + } + end_pos = stop_pos; + } else { + /* Wait until we have some data to stream */ + let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?; + + if let Some(lsn) = lsn { + end_pos = lsn; + } else { + // TODO: also check once in a while whether we are walsender + // to right pageserver. + if spg.timeline.get().stop_walsender(replica_id)? { + // Shut down, timeline is suspended. + // TODO create proper error type for this + bail!("end streaming to {:?}", spg.appname); + } + + // timeout expired: request pageserver status + pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + sent_ptr: end_pos.0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .context("Failed to send KeepAlive message")?; + continue; + } + } + + let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; + let send_size = min(send_size, send_buf.len()); + + let send_buf = &mut send_buf[..send_size]; + + // read wal into buffer + let send_size = wal_reader.read(send_buf).await?; + let send_buf = &send_buf[..send_size]; + + // Write some data to the network socket. + pgb.write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: start_pos.0, + wal_end: end_pos.0, + timestamp: get_current_timestamp(), + data: send_buf, + })) + .context("Failed to send XLogData")?; + + start_pos += send_size as u64; + trace!("sent WAL up to {}", start_pos); + } + + Ok(()) + }) + } +} + +const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); + +// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn. +async fn wait_for_lsn(rx: &mut Receiver, lsn: Lsn) -> Result> { + let commit_lsn: Lsn = *rx.borrow(); + if commit_lsn > lsn { + return Ok(Some(commit_lsn)); + } + + let res = timeout(POLL_STATE_TIMEOUT, async move { + let mut commit_lsn; + loop { + rx.changed().await?; + commit_lsn = *rx.borrow(); + if commit_lsn > lsn { break; } } - let wal_end = spg.timeline.get().get_end_of_wal(); - // Walproposer gets special handling: safekeeper must give proposer all - // local WAL till the end, whether committed or not (walproposer will - // hang otherwise). That's because walproposer runs the consensus and - // synchronizes safekeepers on the most advanced one. - // - // There is a small risk of this WAL getting concurrently garbaged if - // another compute rises which collects majority and starts fixing log - // on this safekeeper itself. That's ok as (old) proposer will never be - // able to commit such WAL. 
- let stop_pos: Option = if spg.appname == Some("wal_proposer_recovery".to_string()) { - Some(wal_end) - } else { - None - }; - info!("Start replication from {:?} till {:?}", start_pos, stop_pos); - // switch to copy - pgb.write_message(&BeMessage::CopyBothResponse)?; + Ok(commit_lsn) + }) + .await; - let mut end_pos = Lsn(0); - - let mut wal_reader = WalReader::new( - spg.conf.timeline_dir(&spg.timeline.get().zttid), - wal_seg_size, - start_pos, - ); - - // buffer for wal sending, limited by MAX_SEND_SIZE - let mut send_buf = vec![0u8; MAX_SEND_SIZE]; - - loop { - if let Some(stop_pos) = stop_pos { - if start_pos >= stop_pos { - break; /* recovery finished */ - } - end_pos = stop_pos; - } else { - /* Wait until we have some data to stream */ - let lsn = spg.timeline.get().wait_for_lsn(start_pos); - - if let Some(lsn) = lsn { - end_pos = lsn; - } else { - // TODO: also check once in a while whether we are walsender - // to right pageserver. - if spg.timeline.get().stop_walsender(replica_id)? { - // Shut down, timeline is suspended. - // TODO create proper error type for this - bail!("end streaming to {:?}", spg.appname); - } - - // timeout expired: request pageserver status - pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive { - sent_ptr: end_pos.0, - timestamp: get_current_timestamp(), - request_reply: true, - })) - .context("Failed to send KeepAlive message")?; - continue; - } - } - - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, send_buf.len()); - - let send_buf = &mut send_buf[..send_size]; - - // read wal into buffer - let send_size = wal_reader.read(send_buf)?; - let send_buf = &send_buf[..send_size]; - - // Write some data to the network socket. - pgb.write_message(&BeMessage::XLogData(XLogDataBody { - wal_start: start_pos.0, - wal_end: end_pos.0, - timestamp: get_current_timestamp(), - data: send_buf, - })) - .context("Failed to send XLogData")?; - - start_pos += send_size as u64; - trace!("sent WAL up to {}", start_pos); - } - Ok(()) + match res { + // success + Ok(Ok(commit_lsn)) => Ok(Some(commit_lsn)), + // error inside closure + Ok(Err(err)) => Err(err), + // timeout + Err(_) => Ok(None), } } diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 39f2593dbc..2e415a53d0 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -14,8 +14,8 @@ use std::cmp::{max, min}; use std::collections::HashMap; use std::fs::{self}; -use std::sync::{Arc, Condvar, Mutex, MutexGuard}; -use std::time::Duration; +use std::sync::{Arc, Mutex, MutexGuard}; + use tokio::sync::mpsc::Sender; use tracing::*; @@ -37,8 +37,6 @@ use crate::wal_storage; use crate::wal_storage::Storage as wal_storage_iface; use crate::SafeKeeperConf; -const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); - /// Replica status update + hot standby feedback #[derive(Debug, Clone, Copy)] pub struct ReplicaState { @@ -77,9 +75,6 @@ impl ReplicaState { struct SharedState { /// Safekeeper object sk: SafeKeeper, - /// For receiving-sending wal cooperation - /// quorum commit LSN we've notified walsenders about - notified_commit_lsn: Lsn, /// State of replicas replicas: Vec>, /// True when WAL backup launcher oversees the timeline, making sure WAL is @@ -112,7 +107,6 @@ impl SharedState { let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?; Ok(Self { - notified_commit_lsn: Lsn(0), sk, replicas: Vec::new(), wal_backup_active: false, @@ -131,7 +125,6 @@ impl SharedState { info!("timeline {} restored", 
zttid.timeline_id); Ok(Self { - notified_commit_lsn: Lsn(0), sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?, replicas: Vec::new(), wal_backup_active: false, @@ -271,8 +264,6 @@ pub struct Timeline { /// For breeding receivers. commit_lsn_watch_rx: watch::Receiver, mutex: Mutex, - /// conditional variable used to notify wal senders - cond: Condvar, } impl Timeline { @@ -289,7 +280,6 @@ impl Timeline { commit_lsn_watch_tx, commit_lsn_watch_rx, mutex: Mutex::new(shared_state), - cond: Condvar::new(), } } @@ -333,7 +323,7 @@ impl Timeline { let mut shared_state = self.mutex.lock().unwrap(); if shared_state.num_computes == 0 { let replica_state = shared_state.replicas[replica_id].unwrap(); - let stop = shared_state.notified_commit_lsn == Lsn(0) || // no data at all yet + let stop = shared_state.sk.inmem.commit_lsn == Lsn(0) || // no data at all yet (replica_state.remote_consistent_lsn != Lsn::MAX && // Lsn::MAX means that we don't know the latest LSN yet. replica_state.remote_consistent_lsn >= shared_state.sk.inmem.commit_lsn); if stop { @@ -405,39 +395,6 @@ impl Timeline { }) } - /// Timed wait for an LSN to be committed. - /// - /// Returns the last committed LSN, which will be at least - /// as high as the LSN waited for, or None if timeout expired. - /// - pub fn wait_for_lsn(&self, lsn: Lsn) -> Option { - let mut shared_state = self.mutex.lock().unwrap(); - loop { - let commit_lsn = shared_state.notified_commit_lsn; - // This must be `>`, not `>=`. - if commit_lsn > lsn { - return Some(commit_lsn); - } - let result = self - .cond - .wait_timeout(shared_state, POLL_STATE_TIMEOUT) - .unwrap(); - if result.1.timed_out() { - return None; - } - shared_state = result.0 - } - } - - // Notify caught-up WAL senders about new WAL data received - // TODO: replace-unify it with commit_lsn_watch. - fn notify_wal_senders(&self, shared_state: &mut MutexGuard) { - if shared_state.notified_commit_lsn < shared_state.sk.inmem.commit_lsn { - shared_state.notified_commit_lsn = shared_state.sk.inmem.commit_lsn; - self.cond.notify_all(); - } - } - pub fn get_commit_lsn_watch_rx(&self) -> watch::Receiver { self.commit_lsn_watch_rx.clone() } @@ -462,8 +419,6 @@ impl Timeline { } } - // Ping wal sender that new data might be available. 
- self.notify_wal_senders(&mut shared_state); commit_lsn = shared_state.sk.inmem.commit_lsn; } self.commit_lsn_watch_tx.send(commit_lsn)?; @@ -524,7 +479,6 @@ impl Timeline { return Ok(()); } shared_state.sk.record_safekeeper_info(sk_info)?; - self.notify_wal_senders(&mut shared_state); is_wal_backup_action_pending = shared_state.update_status(self.zttid); commit_lsn = shared_state.sk.inmem.commit_lsn; } diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index 1d7c8de3b8..8fada70e8b 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -2,6 +2,7 @@ use anyhow::{Context, Result}; use etcd_broker::subscription_key::{ NodeKind, OperationKind, SkOperationKind, SubscriptionKey, SubscriptionKind, }; +use tokio::io::AsyncRead; use tokio::task::JoinHandle; use std::cmp::min; @@ -10,7 +11,9 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; -use postgres_ffi::xlog_utils::{XLogFileName, XLogSegNo, XLogSegNoOffsetToRecPtr, PG_TLI}; +use postgres_ffi::xlog_utils::{ + XLogFileName, XLogSegNo, XLogSegNoOffsetToRecPtr, MAX_SEND_SIZE, PG_TLI, +}; use remote_storage::{GenericRemoteStorage, RemoteStorage}; use tokio::fs::File; use tokio::runtime::Builder; @@ -445,3 +448,49 @@ async fn backup_object(source_file: &Path, size: usize) -> Result<()> { Ok(()) } + +pub async fn read_object( + file_path: PathBuf, + offset: u64, +) -> (impl AsyncRead, JoinHandle>) { + let storage = REMOTE_STORAGE.get().expect("failed to get remote storage"); + + let (mut pipe_writer, pipe_reader) = tokio::io::duplex(MAX_SEND_SIZE); + + let copy_result = tokio::spawn(async move { + let res = match storage.as_ref().unwrap() { + GenericRemoteStorage::Local(local_storage) => { + let source = local_storage.remote_object_id(&file_path)?; + + info!( + "local download about to start from {} at offset {}", + source.display(), + offset + ); + local_storage + .download_byte_range(&source, offset, None, &mut pipe_writer) + .await + } + GenericRemoteStorage::S3(s3_storage) => { + let s3key = s3_storage.remote_object_id(&file_path)?; + + info!( + "S3 download about to start from {:?} at offset {}", + s3key, offset + ); + s3_storage + .download_byte_range(&s3key, offset, None, &mut pipe_writer) + .await + } + }; + + if let Err(e) = res { + error!("failed to download WAL segment from remote storage: {}", e); + Err(e) + } else { + Ok(()) + } + }); + + (pipe_reader, copy_result) +} diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index e3f1ce7333..5cfc96c84b 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -8,7 +8,9 @@ //! Note that last file has `.partial` suffix, that's different from postgres. use anyhow::{anyhow, bail, Context, Result}; -use std::io::{Read, Seek, SeekFrom}; +use std::io::{self, Seek, SeekFrom}; +use std::pin::Pin; +use tokio::io::AsyncRead; use lazy_static::lazy_static; use postgres_ffi::xlog_utils::{ @@ -26,6 +28,7 @@ use utils::{lsn::Lsn, zid::ZTenantTimelineId}; use crate::safekeeper::SafeKeeperState; +use crate::wal_backup::read_object; use crate::SafeKeeperConf; use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ}; @@ -33,6 +36,8 @@ use postgres_ffi::waldecoder::WalStreamDecoder; use metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS}; +use tokio::io::{AsyncReadExt, AsyncSeekExt}; + lazy_static! { // The prometheus crate does not support u64 yet, i64 only (see `IntGauge`). // i64 is faster than f64, so update to u64 when available. 
@@ -504,69 +509,125 @@ pub struct WalReader { timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn, - file: Option, + wal_segment: Option>>, + + enable_remote_read: bool, + // S3 will be used to read WAL if LSN is not available locally + local_start_lsn: Lsn, } impl WalReader { - pub fn new(timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn) -> Self { - Self { - timeline_dir, - wal_seg_size, - pos, - file: None, + pub fn new( + timeline_dir: PathBuf, + state: &SafeKeeperState, + start_pos: Lsn, + enable_remote_read: bool, + ) -> Result { + if start_pos < state.timeline_start_lsn { + bail!( + "Requested streaming from {}, which is before the start of the timeline {}", + start_pos, + state.timeline_start_lsn + ); } + + if state.server.wal_seg_size == 0 + || state.timeline_start_lsn == Lsn(0) + || state.local_start_lsn == Lsn(0) + { + bail!("state uninitialized, no data to read"); + } + + Ok(Self { + timeline_dir, + wal_seg_size: state.server.wal_seg_size as usize, + pos: start_pos, + wal_segment: None, + enable_remote_read, + local_start_lsn: state.local_start_lsn, + }) } - pub fn read(&mut self, buf: &mut [u8]) -> Result { - // Take the `File` from `wal_file`, or open a new file. - let mut file = match self.file.take() { - Some(file) => file, - None => { - // Open a new file. - let segno = self.pos.segment_number(self.wal_seg_size); - let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size); - let wal_file_path = self.timeline_dir.join(wal_file_name); - Self::open_wal_file(&wal_file_path)? - } + pub async fn read(&mut self, buf: &mut [u8]) -> Result { + let mut wal_segment = match self.wal_segment.take() { + Some(reader) => reader, + None => self.open_segment().await?, }; - let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize; - // How much to read and send in message? We cannot cross the WAL file // boundary, and we don't want send more than provided buffer. + let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize; let send_size = min(buf.len(), self.wal_seg_size - xlogoff); // Read some data from the file. let buf = &mut buf[0..send_size]; - file.seek(SeekFrom::Start(xlogoff as u64)) - .and_then(|_| file.read_exact(buf)) - .context("Failed to read data from WAL file")?; - + let send_size = wal_segment.read_exact(buf).await?; self.pos += send_size as u64; - // Decide whether to reuse this file. If we don't set wal_file here - // a new file will be opened next time. + // Decide whether to reuse this file. If we don't set wal_segment here + // a new reader will be opened next time. if self.pos.segment_offset(self.wal_seg_size) != 0 { - self.file = Some(file); + self.wal_segment = Some(wal_segment); } Ok(send_size) } + /// Open WAL segment at the current position of the reader. 
+ async fn open_segment(&self) -> Result>> { + let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize; + let segno = self.pos.segment_number(self.wal_seg_size); + let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size); + let wal_file_path = self.timeline_dir.join(wal_file_name); + + // Try to open local file, if we may have WAL locally + if self.pos >= self.local_start_lsn { + let res = Self::open_wal_file(&wal_file_path).await; + match res { + Ok(mut file) => { + file.seek(SeekFrom::Start(xlogoff as u64)).await?; + return Ok(Box::pin(file)); + } + Err(e) => { + let is_not_found = e.chain().any(|e| { + if let Some(e) = e.downcast_ref::() { + e.kind() == io::ErrorKind::NotFound + } else { + false + } + }); + if !is_not_found { + return Err(e); + } + // NotFound is expected, fall through to remote read + } + }; + } + + // Try to open remote file, if remote reads are enabled + if self.enable_remote_read { + let (reader, _) = read_object(wal_file_path, xlogoff as u64).await; + return Ok(Box::pin(reader)); + } + + bail!("WAL segment is not found") + } + /// Helper function for opening a wal file. - fn open_wal_file(wal_file_path: &Path) -> Result { + async fn open_wal_file(wal_file_path: &Path) -> Result { // First try to open the .partial file. let mut partial_path = wal_file_path.to_owned(); partial_path.set_extension("partial"); - if let Ok(opened_file) = File::open(&partial_path) { + if let Ok(opened_file) = tokio::fs::File::open(&partial_path).await { return Ok(opened_file); } // If that failed, try it without the .partial extension. - File::open(&wal_file_path) + tokio::fs::File::open(&wal_file_path) + .await .with_context(|| format!("Failed to open WAL file {:?}", wal_file_path)) .map_err(|e| { - error!("{}", e); + warn!("{}", e); e }) } diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py index e4970272d4..05827baf86 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -2,6 +2,7 @@ import pytest import random import time import os +import shutil import signal import subprocess import sys @@ -353,7 +354,7 @@ def test_broker(neon_env_builder: NeonEnvBuilder): @pytest.mark.parametrize('auth_enabled', [False, True]) def test_wal_removal(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): neon_env_builder.num_safekeepers = 2 - # to advance remote_consistent_llsn + # to advance remote_consistent_lsn neon_env_builder.enable_local_fs_remote_storage() neon_env_builder.auth_enabled = auth_enabled env = neon_env_builder.init_start() @@ -437,6 +438,26 @@ def wait_segment_offload(tenant_id, timeline_id, live_sk, seg_end): time.sleep(0.5) +def wait_wal_trim(tenant_id, timeline_id, sk, target_size): + started_at = time.time() + http_cli = sk.http_client() + while True: + tli_status = http_cli.timeline_status(tenant_id, timeline_id) + sk_wal_size = get_dir_size(os.path.join(sk.data_dir(), tenant_id, + timeline_id)) / 1024 / 1024 + log.info(f"Safekeeper id={sk.id} wal_size={sk_wal_size:.2f}MB status={tli_status}") + + if sk_wal_size <= target_size: + break + + elapsed = time.time() - started_at + if elapsed > 20: + raise RuntimeError( + f"timed out waiting {elapsed:.0f}s for sk_id={sk.id} to trim WAL to {target_size:.2f}MB, current size is {sk_wal_size:.2f}MB" + ) + time.sleep(0.5) + + @pytest.mark.parametrize('storage_type', ['mock_s3', 'local_fs']) def test_wal_backup(neon_env_builder: NeonEnvBuilder, storage_type: str): neon_env_builder.num_safekeepers = 3 @@ 
-485,6 +506,81 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder, storage_type: str): wait_segment_offload(tenant_id, timeline_id, env.safekeepers[1], '0/5000000') +@pytest.mark.parametrize('storage_type', ['mock_s3', 'local_fs']) +def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder, storage_type: str): + neon_env_builder.num_safekeepers = 3 + if storage_type == 'local_fs': + neon_env_builder.enable_local_fs_remote_storage() + elif storage_type == 'mock_s3': + neon_env_builder.enable_s3_mock_remote_storage('test_s3_wal_replay') + else: + raise RuntimeError(f'Unknown storage type: {storage_type}') + neon_env_builder.remote_storage_users = RemoteStorageUsers.SAFEKEEPER + + env = neon_env_builder.init_start() + env.neon_cli.create_branch('test_s3_wal_replay') + + env.pageserver.stop() + pageserver_tenants_dir = os.path.join(env.repo_dir, 'tenants') + pageserver_fresh_copy = os.path.join(env.repo_dir, 'tenants_fresh') + log.info(f"Creating a copy of pageserver in a fresh state at {pageserver_fresh_copy}") + shutil.copytree(pageserver_tenants_dir, pageserver_fresh_copy) + env.pageserver.start() + + pg = env.postgres.create_start('test_s3_wal_replay') + + # learn neon timeline from compute + tenant_id = pg.safe_psql("show neon.tenant_id")[0][0] + timeline_id = pg.safe_psql("show neon.timeline_id")[0][0] + + expected_sum = 0 + + with closing(pg.connect()) as conn: + with conn.cursor() as cur: + cur.execute("create table t(key int, value text)") + cur.execute("insert into t values (1, 'payload')") + expected_sum += 1 + + offloaded_seg_end = ['0/3000000'] + for seg_end in offloaded_seg_end: + # roughly fills two segments + cur.execute("insert into t select generate_series(1,500000), 'payload'") + expected_sum += 500000 * 500001 // 2 + + cur.execute("select sum(key) from t") + assert cur.fetchone()[0] == expected_sum + + for sk in env.safekeepers: + wait_segment_offload(tenant_id, timeline_id, sk, seg_end) + + # advance remote_consistent_lsn to trigger WAL trimming + # this LSN should be less than commit_lsn, so timeline will be active=true in safekeepers, to push etcd updates + env.safekeepers[0].http_client().record_safekeeper_info( + tenant_id, timeline_id, {'remote_consistent_lsn': offloaded_seg_end[-1]}) + + for sk in env.safekeepers: + # require WAL to be trimmed, so no more than one segment is left on disk + wait_wal_trim(tenant_id, timeline_id, sk, 16 * 1.5) + + # replace pageserver with a fresh copy + pg.stop_and_destroy() + env.pageserver.stop() + + log.info(f'Removing current pageserver state at {pageserver_tenants_dir}') + shutil.rmtree(pageserver_tenants_dir) + log.info(f'Copying fresh pageserver state from {pageserver_fresh_copy}') + shutil.move(pageserver_fresh_copy, pageserver_tenants_dir) + + # start everything, verify data + env.pageserver.start() + pg.create_start('test_s3_wal_replay') + + with closing(pg.connect()) as conn: + with conn.cursor() as cur: + cur.execute("select sum(key) from t") + assert cur.fetchone()[0] == expected_sum + + class ProposerPostgres(PgProtocol): """Object for running postgres without NeonEnv""" def __init__(self, From f862373ac0da301b906f6bbed9eea1c9f47bd0e4 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Fri, 17 Jun 2022 20:43:54 +0300 Subject: [PATCH 04/11] Fix WAL timeout in test_s3_wal_replay (#1953) --- test_runner/batch_others/test_wal_acceptor.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/test_runner/batch_others/test_wal_acceptor.py b/test_runner/batch_others/test_wal_acceptor.py 
index 05827baf86..2b93dd160a 100644 --- a/test_runner/batch_others/test_wal_acceptor.py +++ b/test_runner/batch_others/test_wal_acceptor.py @@ -562,6 +562,16 @@ def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder, storage_type: str): # require WAL to be trimmed, so no more than one segment is left on disk wait_wal_trim(tenant_id, timeline_id, sk, 16 * 1.5) + cur.execute('SELECT pg_current_wal_flush_lsn()') + last_lsn = cur.fetchone()[0] + + pageserver_lsn = env.pageserver.http_client().timeline_detail( + uuid.UUID(tenant_id), uuid.UUID((timeline_id)))["local"]["last_record_lsn"] + lag = lsn_from_hex(last_lsn) - lsn_from_hex(pageserver_lsn) + log.info( + f'Pageserver last_record_lsn={pageserver_lsn}; flush_lsn={last_lsn}; lag before replay is {lag / 1024}kb' + ) + # replace pageserver with a fresh copy pg.stop_and_destroy() env.pageserver.stop() @@ -571,8 +581,33 @@ def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder, storage_type: str): log.info(f'Copying fresh pageserver state from {pageserver_fresh_copy}') shutil.move(pageserver_fresh_copy, pageserver_tenants_dir) - # start everything, verify data + # start pageserver and wait for replay env.pageserver.start() + wait_lsn_timeout = 60 * 3 + started_at = time.time() + last_debug_print = 0.0 + + while True: + elapsed = time.time() - started_at + if elapsed > wait_lsn_timeout: + raise RuntimeError(f'Timed out waiting for WAL redo') + + pageserver_lsn = env.pageserver.http_client().timeline_detail( + uuid.UUID(tenant_id), uuid.UUID((timeline_id)))["local"]["last_record_lsn"] + lag = lsn_from_hex(last_lsn) - lsn_from_hex(pageserver_lsn) + + if time.time() > last_debug_print + 10 or lag <= 0: + last_debug_print = time.time() + log.info(f'Pageserver last_record_lsn={pageserver_lsn}; lag is {lag / 1024}kb') + + if lag <= 0: + break + + time.sleep(1) + + log.info(f'WAL redo took {elapsed} s') + + # verify data pg.create_start('test_s3_wal_replay') with closing(pg.connect()) as conn: From 83c7e6ce527f26129d0e49e6d11593e109b06bea Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Mon, 20 Jun 2022 15:28:43 +0300 Subject: [PATCH 05/11] Bump vendor/postgres. This brings in the change to not use a shared memory in the WAL redo process, to avoid running out of sysv shmem segments in the page server. Also, removal of callmemaybe bits. --- vendor/postgres | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/postgres b/vendor/postgres index 50b6edfbe0..7faa67c3ca 160000 --- a/vendor/postgres +++ b/vendor/postgres @@ -1 +1 @@ -Subproject commit 50b6edfbe0c3b171bd6d407652e1e31a4c97aa8b +Subproject commit 7faa67c3ca53fcce51ae8fedf6b1af3b8cefd3e2 From ec0064c4425b606389417bd7c64cf407b5556a1a Mon Sep 17 00:00:00 2001 From: "Joshua D. Drake" Date: Mon, 20 Jun 2022 07:05:10 -0700 Subject: [PATCH 06/11] Small README.md changes (#1957) * Update make instructions for release and debug build. Update PostgreSQL glossary to proper version (14) * Continued cleanup of build instructions including removal of redundancies --- README.md | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index de9070ac0f..f63c21459e 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Pageserver consists of: ## Running local installation -#### building on Linux +#### Installing dependencies on Linux 1. 
Install build dependencies and other useful packages * On Ubuntu or Debian this set of packages should be sufficient to build the code: @@ -49,14 +49,7 @@ dnf install flex bison readline-devel zlib-devel openssl-devel \ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ``` -3. Build neon and patched postgres -```sh -git clone --recursive https://github.com/neondatabase/neon.git -cd neon -make -j`nproc` -``` - -#### building on OSX (12.3.1) +#### Installing dependencies on OSX (12.3.1) 1. Install XCode and dependencies ``` xcode-select --install @@ -76,10 +69,19 @@ brew install libpq brew link --force libpq ``` -4. Build neon and patched postgres -```sh +#### Building on Linux and OSX + +1. Build neon and patched postgres +``` +# Note: The path to the neon sources can not contain a space. + git clone --recursive https://github.com/neondatabase/neon.git cd neon + +# The preferred and default is to make a debug build. This will create a +# demonstrably slower build than a release build. If you want to use a release +# build, utilize "`BUILD_TYPE=release make -j`nproc``" + make -j`nproc` ``` @@ -209,7 +211,7 @@ Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, wh To get more familiar with this aspect, refer to: - [Neon glossary](/docs/glossary.md) -- [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html) +- [PostgreSQL glossary](https://www.postgresql.org/docs/14/glossary.html) - Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres)) ## Join the development From 37465dafe3c34b586a88ba9ea40ca3de98994780 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 20 Jun 2022 11:40:55 -0400 Subject: [PATCH 07/11] Add wal backpressure tests (#1919) Resolves #1889. This PR adds new tests to measure the WAL backpressure's performance under different workloads. 
## Changes - add new performance tests in `test_wal_backpressure.py` - allow safekeeper's fsync to be configurable when running tests --- test_runner/fixtures/neon_fixtures.py | 5 +- .../performance/test_wal_backpressure.py | 264 ++++++++++++++++++ 2 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 test_runner/performance/test_wal_backpressure.py diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index fcefaad8fa..51afd3a03d 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -500,6 +500,8 @@ class NeonEnvBuilder: num_safekeepers: int = 1, # Use non-standard SK ids to check for various parsing bugs safekeepers_id_start: int = 0, + # fsync is disabled by default to make the tests go faster + safekeepers_enable_fsync: bool = False, auth_enabled: bool = False, rust_log_override: Optional[str] = None, default_branch_name=DEFAULT_BRANCH_NAME): @@ -513,6 +515,7 @@ class NeonEnvBuilder: self.pageserver_config_override = pageserver_config_override self.num_safekeepers = num_safekeepers self.safekeepers_id_start = safekeepers_id_start + self.safekeepers_enable_fsync = safekeepers_enable_fsync self.auth_enabled = auth_enabled self.default_branch_name = default_branch_name self.env: Optional[NeonEnv] = None @@ -666,7 +669,7 @@ class NeonEnv: id = {id} pg_port = {port.pg} http_port = {port.http} - sync = false # Disable fsyncs to make the tests go faster""") + sync = {'true' if config.safekeepers_enable_fsync else 'false'}""") if config.auth_enabled: toml += textwrap.dedent(f""" auth_enabled = true diff --git a/test_runner/performance/test_wal_backpressure.py b/test_runner/performance/test_wal_backpressure.py new file mode 100644 index 0000000000..873d1132a7 --- /dev/null +++ b/test_runner/performance/test_wal_backpressure.py @@ -0,0 +1,264 @@ +import statistics +import threading +import time +import timeit +from typing import Callable + +import pytest +from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker +from fixtures.compare_fixtures import NeonCompare, PgCompare, VanillaCompare +from fixtures.log_helper import log +from fixtures.neon_fixtures import DEFAULT_BRANCH_NAME, NeonEnvBuilder, PgBin +from fixtures.utils import lsn_from_hex + +from performance.test_perf_pgbench import (get_durations_matrix, get_scales_matrix) + + +@pytest.fixture(params=["vanilla", "neon_off", "neon_on"]) +# This fixture constructs multiple `PgCompare` interfaces using a builder pattern. +# The builder parameters are encoded in the fixture's param. +# For example, to build a `NeonCompare` interface, the corresponding fixture's param should have +# a format of `neon_{safekeepers_enable_fsync}`. +# Note that, here "_" is used to separate builder parameters. 
+def pg_compare(request) -> PgCompare: + x = request.param.split("_") + + if x[0] == "vanilla": + # `VanillaCompare` interface + fixture = request.getfixturevalue("vanilla_compare") + assert isinstance(fixture, VanillaCompare) + + return fixture + else: + assert len(x) == 2, f"request param ({request.param}) should have a format of \ + `neon_{{safekeepers_enable_fsync}}`" + + # `NeonCompare` interface + neon_env_builder = request.getfixturevalue("neon_env_builder") + assert isinstance(neon_env_builder, NeonEnvBuilder) + + zenbenchmark = request.getfixturevalue("zenbenchmark") + assert isinstance(zenbenchmark, NeonBenchmarker) + + pg_bin = request.getfixturevalue("pg_bin") + assert isinstance(pg_bin, PgBin) + + neon_env_builder.safekeepers_enable_fsync = x[1] == "on" + + env = neon_env_builder.init_start() + env.neon_cli.create_branch("empty", ancestor_branch_name=DEFAULT_BRANCH_NAME) + + branch_name = request.node.name + return NeonCompare(zenbenchmark, env, pg_bin, branch_name) + + +def start_heavy_write_workload(env: PgCompare, n_tables: int, scale: int, num_iters: int): + """Start an intensive write workload across multiple tables. + + ## Single table workload: + At each step, insert new `new_rows_each_update` rows. + The variable `new_rows_each_update` is equal to `scale * 100_000`. + The number of steps is determined by `num_iters` variable.""" + new_rows_each_update = scale * 100_000 + + def start_single_table_workload(table_id: int): + for _ in range(num_iters): + with env.pg.connect().cursor() as cur: + cur.execute( + f"INSERT INTO t{table_id} SELECT FROM generate_series(1,{new_rows_each_update})" + ) + + with env.record_duration("run_duration"): + threads = [ + threading.Thread(target=start_single_table_workload, args=(i, )) + for i in range(n_tables) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + +@pytest.mark.parametrize("n_tables", [5]) +@pytest.mark.parametrize("scale", get_scales_matrix(5)) +@pytest.mark.parametrize("num_iters", [10]) +def test_heavy_write_workload(pg_compare: PgCompare, n_tables: int, scale: int, num_iters: int): + env = pg_compare + + # Initializes test tables + with env.pg.connect().cursor() as cur: + for i in range(n_tables): + cur.execute( + f"CREATE TABLE t{i}(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')" + ) + cur.execute(f"INSERT INTO t{i} (key) VALUES (0)") + + workload_thread = threading.Thread(target=start_heavy_write_workload, + args=(env, n_tables, scale, num_iters)) + workload_thread.start() + + record_thread = threading.Thread(target=record_lsn_write_lag, + args=(env, lambda: workload_thread.is_alive())) + record_thread.start() + + record_read_latency(env, lambda: workload_thread.is_alive(), "SELECT * from t0 where key = 0") + workload_thread.join() + record_thread.join() + + +def start_pgbench_simple_update_workload(env: PgCompare, duration: int): + with env.record_duration("run_duration"): + env.pg_bin.run_capture([ + 'pgbench', + '-j10', + '-c10', + '-N', + f'-T{duration}', + '-Mprepared', + env.pg.connstr(options="-csynchronous_commit=off") + ]) + env.flush() + + +@pytest.mark.parametrize("scale", get_scales_matrix(100)) +@pytest.mark.parametrize("duration", get_durations_matrix()) +def test_pgbench_simple_update_workload(pg_compare: PgCompare, scale: int, duration: int): + env = pg_compare + + # initialize pgbench tables + env.pg_bin.run_capture(['pgbench', f'-s{scale}', '-i', env.pg.connstr()]) + env.flush() + + workload_thread = 
threading.Thread(target=start_pgbench_simple_update_workload, + args=(env, duration)) + workload_thread.start() + + record_thread = threading.Thread(target=record_lsn_write_lag, + args=(env, lambda: workload_thread.is_alive())) + record_thread.start() + + record_read_latency(env, + lambda: workload_thread.is_alive(), + "SELECT * from pgbench_accounts where aid = 1") + workload_thread.join() + record_thread.join() + + +def start_pgbench_intensive_initialization(env: PgCompare, scale: int): + with env.record_duration("run_duration"): + # Needs to increase the statement timeout (default: 120s) because the + # initialization step can be slow with a large scale. + env.pg_bin.run_capture([ + 'pgbench', + f'-s{scale}', + '-i', + '-Idtg', + env.pg.connstr(options='-cstatement_timeout=300s') + ]) + + +@pytest.mark.parametrize("scale", get_scales_matrix(1000)) +def test_pgbench_intensive_init_workload(pg_compare: PgCompare, scale: int): + env = pg_compare + with env.pg.connect().cursor() as cur: + cur.execute("CREATE TABLE foo as select generate_series(1,100000)") + + workload_thread = threading.Thread(target=start_pgbench_intensive_initialization, + args=(env, scale)) + workload_thread.start() + + record_thread = threading.Thread(target=record_lsn_write_lag, + args=(env, lambda: workload_thread.is_alive())) + record_thread.start() + + record_read_latency(env, lambda: workload_thread.is_alive(), "SELECT count(*) from foo") + workload_thread.join() + record_thread.join() + + +def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_interval: float = 1.0): + if not isinstance(env, NeonCompare): + return + + lsn_write_lags = [] + last_received_lsn = 0 + last_pg_flush_lsn = 0 + + with env.pg.connect().cursor() as cur: + cur.execute("CREATE EXTENSION neon") + + while run_cond(): + cur.execute(''' + select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn), + pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)), + pg_current_wal_flush_lsn(), + received_lsn + from backpressure_lsns(); + ''') + + res = cur.fetchone() + lsn_write_lags.append(res[0]) + + curr_received_lsn = lsn_from_hex(res[3]) + lsn_process_speed = (curr_received_lsn - last_received_lsn) / (1024**2) + last_received_lsn = curr_received_lsn + + curr_pg_flush_lsn = lsn_from_hex(res[2]) + lsn_produce_speed = (curr_pg_flush_lsn - last_pg_flush_lsn) / (1024**2) + last_pg_flush_lsn = curr_pg_flush_lsn + + log.info( + f"received_lsn_lag={res[1]}, pg_flush_lsn={res[2]}, received_lsn={res[3]}, lsn_process_speed={lsn_process_speed:.2f}MB/s, lsn_produce_speed={lsn_produce_speed:.2f}MB/s" + ) + + time.sleep(pool_interval) + + env.zenbenchmark.record("lsn_write_lag_max", + float(max(lsn_write_lags) / (1024**2)), + "MB", + MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("lsn_write_lag_avg", + float(statistics.mean(lsn_write_lags) / (1024**2)), + "MB", + MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("lsn_write_lag_stdev", + float(statistics.stdev(lsn_write_lags) / (1024**2)), + "MB", + MetricReport.LOWER_IS_BETTER) + + +def record_read_latency(env: PgCompare, + run_cond: Callable[[], bool], + read_query: str, + read_interval: float = 1.0): + read_latencies = [] + + with env.pg.connect().cursor() as cur: + while run_cond(): + try: + t1 = timeit.default_timer() + cur.execute(read_query) + t2 = timeit.default_timer() + + log.info( + f"Executed read query {read_query}, got {cur.fetchall()}, read time {t2-t1:.2f}s" + ) + read_latencies.append(t2 - t1) + except Exception as err: + log.error(f"Got 
error when executing the read query: {err}") + + time.sleep(read_interval) + + env.zenbenchmark.record("read_latency_max", + max(read_latencies), + 's', + MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("read_latency_avg", + statistics.mean(read_latencies), + 's', + MetricReport.LOWER_IS_BETTER) + env.zenbenchmark.record("read_latency_stdev", + statistics.stdev(read_latencies), + 's', + MetricReport.LOWER_IS_BETTER) From 6c4d6a218386b7e63890cbe06ad7fabeaff3f801 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Tue, 21 Jun 2022 02:02:24 +0300 Subject: [PATCH 08/11] Remove timeline_start_lsn check temporary. (#1964) --- safekeeper/src/send_wal.rs | 5 ++--- safekeeper/src/wal_storage.rs | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index a6b9de2050..7439d6a8f6 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -199,9 +199,8 @@ impl ReplicationConn { runtime.block_on(async move { let (_, persisted_state) = spg.timeline.get().get_state(); - if persisted_state.server.wal_seg_size == 0 - || persisted_state.timeline_start_lsn == Lsn(0) - { + // add persisted_state.timeline_start_lsn == Lsn(0) check + if persisted_state.server.wal_seg_size == 0 { bail!("Cannot start replication before connecting to walproposer"); } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 5cfc96c84b..5cb7a8c758 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -531,10 +531,8 @@ impl WalReader { ); } - if state.server.wal_seg_size == 0 - || state.timeline_start_lsn == Lsn(0) - || state.local_start_lsn == Lsn(0) - { + // TODO: add state.timeline_start_lsn == Lsn(0) check + if state.server.wal_seg_size == 0 || state.local_start_lsn == Lsn(0) { bail!("state uninitialized, no data to read"); } From 1ca28e6f3cc87840a28afc2a3cd4bcfb064de1c3 Mon Sep 17 00:00:00 2001 From: bojanserafimov Date: Tue, 21 Jun 2022 11:04:10 -0400 Subject: [PATCH 09/11] Import basebackup into pageserver (#1925) Allow importing basebackup taken from vanilla postgres or another pageserver via psql copy in protocol. 
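
For illustration, a minimal sketch of that copy-in flow from Python (psycopg2's `copy_expert` drives the COPY stream; the connection string, ids, and LSNs are supplied by the caller, and the authoritative plumbing is `PageServerNode::timeline_import` and the `neon_local timeline import` command in the diff below):

```python
import psycopg2

def import_basebackup(pageserver_connstr: str, tenant: str, timeline: str,
                      start_lsn: str, end_lsn: str, base_tar: str) -> None:
    """Stream a base.tar (taken with `pg_basebackup -F tar` or the pageserver
    'fullbackup' command) into the pageserver over the COPY-in protocol."""
    conn = psycopg2.connect(pageserver_connstr)
    conn.autocommit = True  # avoid an implicit BEGIN; the pageserver only handles its own simple commands
    try:
        with conn.cursor() as cur, open(base_tar, "rb") as f:
            cur.copy_expert(
                f"import basebackup {tenant} {timeline} {start_lsn} {end_lsn}", f)
    finally:
        conn.close()
```

Importing WAL on top of the base follows the same pattern with the `import wal` command.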
--- Cargo.lock | 1 + control_plane/src/storage.rs | 53 ++- neon_local/src/main.rs | 53 ++- pageserver/Cargo.toml | 1 + pageserver/src/basebackup.rs | 29 +- pageserver/src/import_datadir.rs | 493 +++++++++++++++--------- pageserver/src/layered_repository.rs | 16 +- pageserver/src/page_service.rs | 231 ++++++++++- pageserver/src/pgdatadir_mapping.rs | 10 +- test_runner/batch_others/test_import.py | 193 ++++++++++ test_runner/fixtures/neon_fixtures.py | 4 +- 11 files changed, 875 insertions(+), 209 deletions(-) create mode 100644 test_runner/batch_others/test_import.py diff --git a/Cargo.lock b/Cargo.lock index c615766eb8..dca525941d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1842,6 +1842,7 @@ dependencies = [ "tracing", "url", "utils", + "walkdir", "workspace_hack", ] diff --git a/control_plane/src/storage.rs b/control_plane/src/storage.rs index a8f21406fb..f1eaa99904 100644 --- a/control_plane/src/storage.rs +++ b/control_plane/src/storage.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; -use std::io::Write; +use std::fs::File; +use std::io::{BufReader, Write}; use std::net::TcpStream; use std::num::NonZeroU64; use std::path::PathBuf; @@ -527,4 +528,54 @@ impl PageServerNode { Ok(timeline_info_response) } + + /// Import a basebackup prepared using either: + /// a) `pg_basebackup -F tar`, or + /// b) The `fullbackup` pageserver endpoint + /// + /// # Arguments + /// * `tenant_id` - tenant to import into. Created if not exists + /// * `timeline_id` - id to assign to imported timeline + /// * `base` - (start lsn of basebackup, path to `base.tar` file) + /// * `pg_wal` - if there's any wal to import: (end lsn, path to `pg_wal.tar`) + pub fn timeline_import( + &self, + tenant_id: ZTenantId, + timeline_id: ZTimelineId, + base: (Lsn, PathBuf), + pg_wal: Option<(Lsn, PathBuf)>, + ) -> anyhow::Result<()> { + let mut client = self.pg_connection_config.connect(NoTls).unwrap(); + + // Init base reader + let (start_lsn, base_tarfile_path) = base; + let base_tarfile = File::open(base_tarfile_path)?; + let mut base_reader = BufReader::new(base_tarfile); + + // Init wal reader if necessary + let (end_lsn, wal_reader) = if let Some((end_lsn, wal_tarfile_path)) = pg_wal { + let wal_tarfile = File::open(wal_tarfile_path)?; + let wal_reader = BufReader::new(wal_tarfile); + (end_lsn, Some(wal_reader)) + } else { + (start_lsn, None) + }; + + // Import base + let import_cmd = + format!("import basebackup {tenant_id} {timeline_id} {start_lsn} {end_lsn}"); + let mut writer = client.copy_in(&import_cmd)?; + io::copy(&mut base_reader, &mut writer)?; + writer.finish()?; + + // Import wal if necessary + if let Some(mut wal_reader) = wal_reader { + let import_cmd = format!("import wal {tenant_id} {timeline_id} {start_lsn} {end_lsn}"); + let mut writer = client.copy_in(&import_cmd)?; + io::copy(&mut wal_reader, &mut writer)?; + writer.finish()?; + } + + Ok(()) + } } diff --git a/neon_local/src/main.rs b/neon_local/src/main.rs index 8d39fe5d0d..35e2d9c9e2 100644 --- a/neon_local/src/main.rs +++ b/neon_local/src/main.rs @@ -14,7 +14,7 @@ use safekeeper::defaults::{ DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, }; use std::collections::{BTreeSet, HashMap}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::process::exit; use std::str::FromStr; use utils::{ @@ -159,6 +159,20 @@ fn main() -> Result<()> { .about("Create a new blank timeline") .arg(tenant_id_arg.clone()) .arg(branch_name_arg.clone())) + .subcommand(App::new("import") + .about("Import timeline from basebackup directory") + 
.arg(tenant_id_arg.clone()) + .arg(timeline_id_arg.clone()) + .arg(Arg::new("node-name").long("node-name").takes_value(true) + .help("Name to assign to the imported timeline")) + .arg(Arg::new("base-tarfile").long("base-tarfile").takes_value(true) + .help("Basebackup tarfile to import")) + .arg(Arg::new("base-lsn").long("base-lsn").takes_value(true) + .help("Lsn the basebackup starts at")) + .arg(Arg::new("wal-tarfile").long("wal-tarfile").takes_value(true) + .help("Wal to add after base")) + .arg(Arg::new("end-lsn").long("end-lsn").takes_value(true) + .help("Lsn the basebackup ends at"))) ).subcommand( App::new("tenant") .setting(AppSettings::ArgRequiredElseHelp) @@ -613,6 +627,43 @@ fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::LocalEnv) - timeline.timeline_id, last_record_lsn, tenant_id, ); } + Some(("import", import_match)) => { + let tenant_id = get_tenant_id(import_match, env)?; + let timeline_id = parse_timeline_id(import_match)?.expect("No timeline id provided"); + let name = import_match + .value_of("node-name") + .ok_or_else(|| anyhow!("No node name provided"))?; + + // Parse base inputs + let base_tarfile = import_match + .value_of("base-tarfile") + .map(|s| PathBuf::from_str(s).unwrap()) + .ok_or_else(|| anyhow!("No base-tarfile provided"))?; + let base_lsn = Lsn::from_str( + import_match + .value_of("base-lsn") + .ok_or_else(|| anyhow!("No base-lsn provided"))?, + )?; + let base = (base_lsn, base_tarfile); + + // Parse pg_wal inputs + let wal_tarfile = import_match + .value_of("wal-tarfile") + .map(|s| PathBuf::from_str(s).unwrap()); + let end_lsn = import_match + .value_of("end-lsn") + .map(|s| Lsn::from_str(s).unwrap()); + // TODO validate both or none are provided + let pg_wal = end_lsn.zip(wal_tarfile); + + let mut cplane = ComputeControlPlane::load(env.clone())?; + println!("Importing timeline into pageserver ..."); + pageserver.timeline_import(tenant_id, timeline_id, base, pg_wal)?; + println!("Creating node for imported timeline ..."); + env.register_branch_mapping(name.to_string(), tenant_id, timeline_id)?; + cplane.new_node(tenant_id, name, timeline_id, None, None)?; + println!("Done"); + } Some(("branch", branch_match)) => { let tenant_id = get_tenant_id(branch_match, env)?; let new_branch_name = branch_match diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 298addb838..b7d97a67c0 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -61,6 +61,7 @@ utils = { path = "../libs/utils" } remote_storage = { path = "../libs/remote_storage" } workspace_hack = { version = "0.1", path = "../workspace_hack" } close_fds = "0.3.2" +walkdir = "2.3.2" [dev-dependencies] hex-literal = "0.3" diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 44a6442522..ed300b3360 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -112,6 +112,8 @@ where } pub fn send_tarball(mut self) -> anyhow::Result<()> { + // TODO include checksum + // Create pgdata subdirs structure for dir in pg_constants::PGDATA_SUBDIRS.iter() { let header = new_tar_header_dir(*dir)?; @@ -355,24 +357,21 @@ where pg_control.checkPointCopy = checkpoint; pg_control.state = pg_constants::DB_SHUTDOWNED; - // Postgres doesn't recognize the zenith.signal file and doesn't need it. 
- if !self.full_backup { - // add zenith.signal file - let mut zenith_signal = String::new(); - if self.prev_record_lsn == Lsn(0) { - if self.lsn == self.timeline.tline.get_ancestor_lsn() { - write!(zenith_signal, "PREV LSN: none")?; - } else { - write!(zenith_signal, "PREV LSN: invalid")?; - } + // add zenith.signal file + let mut zenith_signal = String::new(); + if self.prev_record_lsn == Lsn(0) { + if self.lsn == self.timeline.tline.get_ancestor_lsn() { + write!(zenith_signal, "PREV LSN: none")?; } else { - write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)?; + write!(zenith_signal, "PREV LSN: invalid")?; } - self.ar.append( - &new_tar_header("zenith.signal", zenith_signal.len() as u64)?, - zenith_signal.as_bytes(), - )?; + } else { + write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)?; } + self.ar.append( + &new_tar_header("zenith.signal", zenith_signal.len() as u64)?, + zenith_signal.as_bytes(), + )?; //send pg_control let pg_control_bytes = pg_control.encode(); diff --git a/pageserver/src/import_datadir.rs b/pageserver/src/import_datadir.rs index 703ee8f1b1..3ede949885 100644 --- a/pageserver/src/import_datadir.rs +++ b/pageserver/src/import_datadir.rs @@ -2,7 +2,6 @@ //! Import data and WAL from a PostgreSQL data directory and WAL segments into //! a zenith Timeline. //! -use std::fs; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; use std::path::{Path, PathBuf}; @@ -10,16 +9,18 @@ use std::path::{Path, PathBuf}; use anyhow::{bail, ensure, Context, Result}; use bytes::Bytes; use tracing::*; +use walkdir::WalkDir; use crate::pgdatadir_mapping::*; use crate::reltag::{RelTag, SlruKind}; use crate::repository::Repository; +use crate::repository::Timeline; use crate::walingest::WalIngest; use postgres_ffi::relfile_utils::*; use postgres_ffi::waldecoder::*; use postgres_ffi::xlog_utils::*; +use postgres_ffi::Oid; use postgres_ffi::{pg_constants, ControlFileData, DBState_DB_SHUTDOWNED}; -use postgres_ffi::{Oid, TransactionId}; use utils::lsn::Lsn; /// @@ -35,100 +36,29 @@ pub fn import_timeline_from_postgres_datadir( ) -> Result<()> { let mut pg_control: Option = None; + // TODO this shoud be start_lsn, which is not necessarily equal to end_lsn (aka lsn) + // Then fishing out pg_control would be unnecessary let mut modification = tline.begin_modification(lsn); modification.init_empty()?; - // Scan 'global' - let mut relfiles: Vec = Vec::new(); - for direntry in fs::read_dir(path.join("global"))? { - let direntry = direntry?; - match direntry.file_name().to_str() { - None => continue, + // Import all but pg_wal + let all_but_wal = WalkDir::new(path) + .into_iter() + .filter_entry(|entry| !entry.path().ends_with("pg_wal")); + for entry in all_but_wal { + let entry = entry?; + let metadata = entry.metadata().expect("error getting dir entry metadata"); + if metadata.is_file() { + let absolute_path = entry.path(); + let relative_path = absolute_path.strip_prefix(path)?; - Some("pg_control") => { - pg_control = Some(import_control_file(&mut modification, &direntry.path())?); - } - Some("pg_filenode.map") => { - import_relmap_file( - &mut modification, - pg_constants::GLOBALTABLESPACE_OID, - 0, - &direntry.path(), - )?; - } - - // Load any relation files into the page server (but only after the other files) - _ => relfiles.push(direntry.path()), - } - } - for relfile in relfiles { - import_relfile( - &mut modification, - &relfile, - pg_constants::GLOBALTABLESPACE_OID, - 0, - )?; - } - - // Scan 'base'. It contains database dirs, the database OID is the filename. 
- // E.g. 'base/12345', where 12345 is the database OID. - for direntry in fs::read_dir(path.join("base"))? { - let direntry = direntry?; - - //skip all temporary files - if direntry.file_name().to_string_lossy() == "pgsql_tmp" { - continue; - } - - let dboid = direntry.file_name().to_string_lossy().parse::()?; - - let mut relfiles: Vec = Vec::new(); - for direntry in fs::read_dir(direntry.path())? { - let direntry = direntry?; - match direntry.file_name().to_str() { - None => continue, - - Some("PG_VERSION") => { - //modification.put_dbdir_creation(pg_constants::DEFAULTTABLESPACE_OID, dboid)?; - } - Some("pg_filenode.map") => import_relmap_file( - &mut modification, - pg_constants::DEFAULTTABLESPACE_OID, - dboid, - &direntry.path(), - )?, - - // Load any relation files into the page server - _ => relfiles.push(direntry.path()), + let file = File::open(absolute_path)?; + let len = metadata.len() as usize; + if let Some(control_file) = import_file(&mut modification, relative_path, file, len)? { + pg_control = Some(control_file); } } - for relfile in relfiles { - import_relfile( - &mut modification, - &relfile, - pg_constants::DEFAULTTABLESPACE_OID, - dboid, - )?; - } } - for entry in fs::read_dir(path.join("pg_xact"))? { - let entry = entry?; - import_slru_file(&mut modification, SlruKind::Clog, &entry.path())?; - } - for entry in fs::read_dir(path.join("pg_multixact").join("members"))? { - let entry = entry?; - import_slru_file(&mut modification, SlruKind::MultiXactMembers, &entry.path())?; - } - for entry in fs::read_dir(path.join("pg_multixact").join("offsets"))? { - let entry = entry?; - import_slru_file(&mut modification, SlruKind::MultiXactOffsets, &entry.path())?; - } - for entry in fs::read_dir(path.join("pg_twophase"))? { - let entry = entry?; - let xid = u32::from_str_radix(&entry.path().to_string_lossy(), 16)?; - import_twophase_file(&mut modification, xid, &entry.path())?; - } - // TODO: Scan pg_tblspc // We're done importing all the data files. modification.commit()?; @@ -158,31 +88,30 @@ pub fn import_timeline_from_postgres_datadir( } // subroutine of import_timeline_from_postgres_datadir(), to load one relation file. -fn import_relfile( +fn import_rel( modification: &mut DatadirModification, path: &Path, spcoid: Oid, dboid: Oid, + mut reader: Reader, + len: usize, ) -> anyhow::Result<()> { // Does it look like a relation file? 
trace!("importing rel file {}", path.display()); - let (relnode, forknum, segno) = parse_relfilename(&path.file_name().unwrap().to_string_lossy()) - .map_err(|e| { - warn!("unrecognized file in postgres datadir: {:?} ({})", path, e); - e - })?; + let filename = &path + .file_name() + .expect("missing rel filename") + .to_string_lossy(); + let (relnode, forknum, segno) = parse_relfilename(filename).map_err(|e| { + warn!("unrecognized file in postgres datadir: {:?} ({})", path, e); + e + })?; - let mut file = File::open(path)?; let mut buf: [u8; 8192] = [0u8; 8192]; - let len = file.metadata().unwrap().len(); - ensure!(len % pg_constants::BLCKSZ as u64 == 0); - let nblocks = len / pg_constants::BLCKSZ as u64; - - if segno != 0 { - todo!(); - } + ensure!(len % pg_constants::BLCKSZ as usize == 0); + let nblocks = len / pg_constants::BLCKSZ as usize; let rel = RelTag { spcnode: spcoid, @@ -190,11 +119,22 @@ fn import_relfile( relnode, forknum, }; - modification.put_rel_creation(rel, nblocks as u32)?; let mut blknum: u32 = segno * (1024 * 1024 * 1024 / pg_constants::BLCKSZ as u32); + + // Call put_rel_creation for every segment of the relation, + // because there is no guarantee about the order in which we are processing segments. + // ignore "relation already exists" error + if let Err(e) = modification.put_rel_creation(rel, nblocks as u32) { + if e.to_string().contains("already exists") { + debug!("relation {} already exists. we must be extending it", rel); + } else { + return Err(e); + } + } + loop { - let r = file.read_exact(&mut buf); + let r = reader.read_exact(&mut buf); match r { Ok(_) => { modification.put_rel_page_image(rel, blknum, Bytes::copy_from_slice(&buf))?; @@ -204,7 +144,9 @@ fn import_relfile( Err(err) => match err.kind() { std::io::ErrorKind::UnexpectedEof => { // reached EOF. That's expected. - ensure!(blknum == nblocks as u32, "unexpected EOF"); + let relative_blknum = + blknum - segno * (1024 * 1024 * 1024 / pg_constants::BLCKSZ as u32); + ensure!(relative_blknum == nblocks as u32, "unexpected EOF"); break; } _ => { @@ -215,96 +157,43 @@ fn import_relfile( blknum += 1; } + // Update relation size + // + // If we process rel segments out of order, + // put_rel_extend will skip the update. + modification.put_rel_extend(rel, blknum)?; + Ok(()) } -/// Import a relmapper (pg_filenode.map) file into the repository -fn import_relmap_file( - modification: &mut DatadirModification, - spcnode: Oid, - dbnode: Oid, - path: &Path, -) -> Result<()> { - let mut file = File::open(path)?; - let mut buffer = Vec::new(); - // read the whole file - file.read_to_end(&mut buffer)?; - - trace!("importing relmap file {}", path.display()); - - modification.put_relmap_file(spcnode, dbnode, Bytes::copy_from_slice(&buffer[..]))?; - Ok(()) -} - -/// Import a twophase state file (pg_twophase/) into the repository -fn import_twophase_file( - modification: &mut DatadirModification, - xid: TransactionId, - path: &Path, -) -> Result<()> { - let mut file = File::open(path)?; - let mut buffer = Vec::new(); - // read the whole file - file.read_to_end(&mut buffer)?; - - trace!("importing non-rel file {}", path.display()); - - modification.put_twophase_file(xid, Bytes::copy_from_slice(&buffer[..]))?; - Ok(()) -} - -/// -/// Import pg_control file into the repository. -/// -/// The control file is imported as is, but we also extract the checkpoint record -/// from it and store it separated. 
-fn import_control_file( - modification: &mut DatadirModification, - path: &Path, -) -> Result { - let mut file = File::open(path)?; - let mut buffer = Vec::new(); - // read the whole file - file.read_to_end(&mut buffer)?; - - trace!("importing control file {}", path.display()); - - // Import it as ControlFile - modification.put_control_file(Bytes::copy_from_slice(&buffer[..]))?; - - // Extract the checkpoint record and import it separately. - let pg_control = ControlFileData::decode(&buffer)?; - let checkpoint_bytes = pg_control.checkPointCopy.encode()?; - modification.put_checkpoint(checkpoint_bytes)?; - - Ok(pg_control) -} - -/// /// Import an SLRU segment file /// -fn import_slru_file( +fn import_slru( modification: &mut DatadirModification, slru: SlruKind, path: &Path, + mut reader: Reader, + len: usize, ) -> Result<()> { trace!("importing slru file {}", path.display()); - let mut file = File::open(path)?; let mut buf: [u8; 8192] = [0u8; 8192]; - let segno = u32::from_str_radix(&path.file_name().unwrap().to_string_lossy(), 16)?; + let filename = &path + .file_name() + .expect("missing slru filename") + .to_string_lossy(); + let segno = u32::from_str_radix(filename, 16)?; - let len = file.metadata().unwrap().len(); - ensure!(len % pg_constants::BLCKSZ as u64 == 0); // we assume SLRU block size is the same as BLCKSZ - let nblocks = len / pg_constants::BLCKSZ as u64; + ensure!(len % pg_constants::BLCKSZ as usize == 0); // we assume SLRU block size is the same as BLCKSZ + let nblocks = len / pg_constants::BLCKSZ as usize; - ensure!(nblocks <= pg_constants::SLRU_PAGES_PER_SEGMENT as u64); + ensure!(nblocks <= pg_constants::SLRU_PAGES_PER_SEGMENT as usize); modification.put_slru_segment_creation(slru, segno, nblocks as u32)?; let mut rpageno = 0; loop { - let r = file.read_exact(&mut buf); + let r = reader.read_exact(&mut buf); match r { Ok(_) => { modification.put_slru_page_image( @@ -396,10 +285,258 @@ fn import_wal( } if last_lsn != startpoint { - debug!("reached end of WAL at {}", last_lsn); + info!("reached end of WAL at {}", last_lsn); } else { info!("no WAL to import at {}", last_lsn); } Ok(()) } + +pub fn import_basebackup_from_tar( + tline: &mut DatadirTimeline, + reader: Reader, + base_lsn: Lsn, +) -> Result<()> { + info!("importing base at {}", base_lsn); + let mut modification = tline.begin_modification(base_lsn); + modification.init_empty()?; + + let mut pg_control: Option = None; + + // Import base + for base_tar_entry in tar::Archive::new(reader).entries()? { + let entry = base_tar_entry?; + let header = entry.header(); + let len = header.entry_size()? as usize; + let file_path = header.path()?.into_owned(); + + match header.entry_type() { + tar::EntryType::Regular => { + if let Some(res) = import_file(&mut modification, file_path.as_ref(), entry, len)? { + // We found the pg_control file. + pg_control = Some(res); + } + } + tar::EntryType::Directory => { + debug!("directory {:?}", file_path); + } + _ => { + panic!("tar::EntryType::?? 
{}", file_path.display()); + } + } + } + + // sanity check: ensure that pg_control is loaded + let _pg_control = pg_control.context("pg_control file not found")?; + + modification.commit()?; + Ok(()) +} + +pub fn import_wal_from_tar( + tline: &mut DatadirTimeline, + reader: Reader, + start_lsn: Lsn, + end_lsn: Lsn, +) -> Result<()> { + // Set up walingest mutable state + let mut waldecoder = WalStreamDecoder::new(start_lsn); + let mut segno = start_lsn.segment_number(pg_constants::WAL_SEGMENT_SIZE); + let mut offset = start_lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE); + let mut last_lsn = start_lsn; + let mut walingest = WalIngest::new(tline, start_lsn)?; + + // Ingest wal until end_lsn + info!("importing wal until {}", end_lsn); + let mut pg_wal_tar = tar::Archive::new(reader); + let mut pg_wal_entries_iter = pg_wal_tar.entries()?; + while last_lsn <= end_lsn { + let bytes = { + let entry = pg_wal_entries_iter.next().expect("expected more wal")?; + let header = entry.header(); + let file_path = header.path()?.into_owned(); + + match header.entry_type() { + tar::EntryType::Regular => { + // FIXME: assume postgresql tli 1 for now + let expected_filename = XLogFileName(1, segno, pg_constants::WAL_SEGMENT_SIZE); + let file_name = file_path + .file_name() + .expect("missing wal filename") + .to_string_lossy(); + ensure!(expected_filename == file_name); + + debug!("processing wal file {:?}", file_path); + read_all_bytes(entry)? + } + tar::EntryType::Directory => { + debug!("directory {:?}", file_path); + continue; + } + _ => { + panic!("tar::EntryType::?? {}", file_path.display()); + } + } + }; + + waldecoder.feed_bytes(&bytes[offset..]); + + while last_lsn <= end_lsn { + if let Some((lsn, recdata)) = waldecoder.poll_decode()? { + walingest.ingest_record(tline, recdata, lsn)?; + last_lsn = lsn; + + debug!("imported record at {} (end {})", lsn, end_lsn); + } + } + + debug!("imported records up to {}", last_lsn); + segno += 1; + offset = 0; + } + + if last_lsn != start_lsn { + info!("reached end of WAL at {}", last_lsn); + } else { + info!("there was no WAL to import at {}", last_lsn); + } + + // Log any extra unused files + for e in &mut pg_wal_entries_iter { + let entry = e?; + let header = entry.header(); + let file_path = header.path()?.into_owned(); + info!("skipping {:?}", file_path); + } + + Ok(()) +} + +pub fn import_file( + modification: &mut DatadirModification, + file_path: &Path, + reader: Reader, + len: usize, +) -> Result> { + debug!("looking at {:?}", file_path); + + if file_path.starts_with("global") { + let spcnode = pg_constants::GLOBALTABLESPACE_OID; + let dbnode = 0; + + match file_path + .file_name() + .expect("missing filename") + .to_string_lossy() + .as_ref() + { + "pg_control" => { + let bytes = read_all_bytes(reader)?; + + // Extract the checkpoint record and import it separately. 
+ let pg_control = ControlFileData::decode(&bytes[..])?; + let checkpoint_bytes = pg_control.checkPointCopy.encode()?; + modification.put_checkpoint(checkpoint_bytes)?; + debug!("imported control file"); + + // Import it as ControlFile + modification.put_control_file(bytes)?; + return Ok(Some(pg_control)); + } + "pg_filenode.map" => { + let bytes = read_all_bytes(reader)?; + modification.put_relmap_file(spcnode, dbnode, bytes)?; + debug!("imported relmap file") + } + "PG_VERSION" => { + debug!("ignored"); + } + _ => { + import_rel(modification, file_path, spcnode, dbnode, reader, len)?; + debug!("imported rel creation"); + } + } + } else if file_path.starts_with("base") { + let spcnode = pg_constants::DEFAULTTABLESPACE_OID; + let dbnode: u32 = file_path + .iter() + .nth(1) + .expect("invalid file path, expected dbnode") + .to_string_lossy() + .parse()?; + + match file_path + .file_name() + .expect("missing base filename") + .to_string_lossy() + .as_ref() + { + "pg_filenode.map" => { + let bytes = read_all_bytes(reader)?; + modification.put_relmap_file(spcnode, dbnode, bytes)?; + debug!("imported relmap file") + } + "PG_VERSION" => { + debug!("ignored"); + } + _ => { + import_rel(modification, file_path, spcnode, dbnode, reader, len)?; + debug!("imported rel creation"); + } + } + } else if file_path.starts_with("pg_xact") { + let slru = SlruKind::Clog; + + import_slru(modification, slru, file_path, reader, len)?; + debug!("imported clog slru"); + } else if file_path.starts_with("pg_multixact/offsets") { + let slru = SlruKind::MultiXactOffsets; + + import_slru(modification, slru, file_path, reader, len)?; + debug!("imported multixact offsets slru"); + } else if file_path.starts_with("pg_multixact/members") { + let slru = SlruKind::MultiXactMembers; + + import_slru(modification, slru, file_path, reader, len)?; + debug!("imported multixact members slru"); + } else if file_path.starts_with("pg_twophase") { + let file_name = &file_path + .file_name() + .expect("missing twophase filename") + .to_string_lossy(); + let xid = u32::from_str_radix(file_name, 16)?; + + let bytes = read_all_bytes(reader)?; + modification.put_twophase_file(xid, Bytes::copy_from_slice(&bytes[..]))?; + debug!("imported twophase file"); + } else if file_path.starts_with("pg_wal") { + debug!("found wal file in base section. ignore it"); + } else if file_path.starts_with("zenith.signal") { + // Parse zenith signal file to set correct previous LSN + let bytes = read_all_bytes(reader)?; + // zenith.signal format is "PREV LSN: prev_lsn" + let zenith_signal = std::str::from_utf8(&bytes)?; + let zenith_signal = zenith_signal.split(':').collect::>(); + let prev_lsn = zenith_signal[1].trim().parse::()?; + + let writer = modification.tline.tline.writer(); + writer.finish_write(prev_lsn); + + debug!("imported zenith signal {}", prev_lsn); + } else if file_path.starts_with("pg_tblspc") { + // TODO Backups exported from neon won't have pg_tblspc, but we will need + // this to import arbitrary postgres databases. 
+ bail!("Importing pg_tblspc is not implemented"); + } else { + debug!("ignored"); + } + + Ok(None) +} + +fn read_all_bytes(mut reader: Reader) -> Result { + let mut buf: Vec = vec![]; + reader.read_to_end(&mut buf)?; + Ok(Bytes::copy_from_slice(&buf[..])) +} diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index 5c5b03268a..fdd03ecf8b 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -243,15 +243,15 @@ impl Repository for LayeredRepository { ); timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn); + // Insert if not exists let timeline = Arc::new(timeline); - let r = timelines.insert( - timelineid, - LayeredTimelineEntry::Loaded(Arc::clone(&timeline)), - ); - ensure!( - r.is_none(), - "assertion failure, inserted duplicate timeline" - ); + match timelines.entry(timelineid) { + Entry::Occupied(_) => bail!("Timeline already exists"), + Entry::Vacant(vacant) => { + vacant.insert(LayeredTimelineEntry::Loaded(Arc::clone(&timeline))) + } + }; + Ok(timeline) } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 406228f034..079f477f75 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -13,7 +13,7 @@ use anyhow::{bail, ensure, Context, Result}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use lazy_static::lazy_static; use regex::Regex; -use std::io; +use std::io::{self, Read}; use std::net::TcpListener; use std::str; use std::str::FromStr; @@ -29,6 +29,8 @@ use utils::{ use crate::basebackup; use crate::config::{PageServerConf, ProfilingConfig}; +use crate::import_datadir::{import_basebackup_from_tar, import_wal_from_tar}; +use crate::layered_repository::LayeredRepository; use crate::pgdatadir_mapping::{DatadirTimeline, LsnForTimestamp}; use crate::profiling::profpoint_start; use crate::reltag::RelTag; @@ -200,6 +202,96 @@ impl PagestreamBeMessage { } } +/// Implements Read for the server side of CopyIn +struct CopyInReader<'a> { + pgb: &'a mut PostgresBackend, + + /// Overflow buffer for bytes sent in CopyData messages + /// that the reader (caller of read) hasn't asked for yet. + /// TODO use BytesMut? + buf: Vec, + + /// Bytes before `buf_begin` are considered as dropped. + /// This allows us to implement O(1) pop_front on Vec. + /// The Vec won't grow large because we only add to it + /// when it's empty. 
+ buf_begin: usize, +} + +impl<'a> CopyInReader<'a> { + // NOTE: pgb should be in copy in state already + fn new(pgb: &'a mut PostgresBackend) -> Self { + Self { + pgb, + buf: Vec::<_>::new(), + buf_begin: 0, + } + } +} + +impl<'a> Drop for CopyInReader<'a> { + fn drop(&mut self) { + // Finalize copy protocol so that self.pgb can be reused + // TODO instead, maybe take ownership of pgb and give it back at the end + let mut buf: Vec = vec![]; + let _ = self.read_to_end(&mut buf); + } +} + +impl<'a> Read for CopyInReader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while !thread_mgr::is_shutdown_requested() { + // Return from buffer if nonempty + if self.buf_begin < self.buf.len() { + let bytes_to_read = std::cmp::min(buf.len(), self.buf.len() - self.buf_begin); + buf[..bytes_to_read].copy_from_slice(&self.buf[self.buf_begin..][..bytes_to_read]); + self.buf_begin += bytes_to_read; + return Ok(bytes_to_read); + } + + // Delete garbage + self.buf.clear(); + self.buf_begin = 0; + + // Wait for client to send CopyData bytes + match self.pgb.read_message() { + Ok(Some(message)) => { + let copy_data_bytes = match message { + FeMessage::CopyData(bytes) => bytes, + FeMessage::CopyDone => return Ok(0), + FeMessage::Sync => continue, + m => { + let msg = format!("unexpected message {:?}", m); + self.pgb.write_message(&BeMessage::ErrorResponse(&msg))?; + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + }; + + // Return as much as we can, saving the rest in self.buf + let mut reader = copy_data_bytes.reader(); + let bytes_read = reader.read(buf)?; + reader.read_to_end(&mut self.buf)?; + return Ok(bytes_read); + } + Ok(None) => { + let msg = "client closed connection"; + self.pgb.write_message(&BeMessage::ErrorResponse(msg))?; + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + Err(e) => { + if !is_socket_read_timed_out(&e) { + return Err(io::Error::new(io::ErrorKind::Other, e)); + } + } + } + } + + // Shutting down + let msg = "Importer thread was shut down"; + Err(io::Error::new(io::ErrorKind::Other, msg)) + } +} + /////////////////////////////////////////////////////////////////////////////// /// @@ -447,6 +539,98 @@ impl PageServerHandler { Ok(()) } + fn handle_import_basebackup( + &self, + pgb: &mut PostgresBackend, + tenant_id: ZTenantId, + timeline_id: ZTimelineId, + base_lsn: Lsn, + _end_lsn: Lsn, + ) -> anyhow::Result<()> { + thread_mgr::associate_with(Some(tenant_id), Some(timeline_id)); + let _enter = + info_span!("import basebackup", timeline = %timeline_id, tenant = %tenant_id).entered(); + + // Create empty timeline + info!("creating new timeline"); + let repo = tenant_mgr::get_repository_for_tenant(tenant_id)?; + let timeline = repo.create_empty_timeline(timeline_id, Lsn(0))?; + let repartition_distance = repo.get_checkpoint_distance(); + let mut datadir_timeline = + DatadirTimeline::::new(timeline, repartition_distance); + + // TODO mark timeline as not ready until it reaches end_lsn. + // We might have some wal to import as well, and we should prevent compute + // from connecting before that and writing conflicting wal. + // + // This is not relevant for pageserver->pageserver migrations, since there's + // no wal to import. But should be fixed if we want to import from postgres. + + // TODO leave clean state on error. For now you can use detach to clean + // up broken state from a failed import. 
+ + // Import basebackup provided via CopyData + info!("importing basebackup"); + pgb.write_message(&BeMessage::CopyInResponse)?; + let reader = CopyInReader::new(pgb); + import_basebackup_from_tar(&mut datadir_timeline, reader, base_lsn)?; + + // TODO check checksum + // Meanwhile you can verify client-side by taking fullbackup + // and checking that it matches in size with what was imported. + // It wouldn't work if base came from vanilla postgres though, + // since we discard some log files. + + // Flush data to disk, then upload to s3 + info!("flushing layers"); + datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?; + + info!("done"); + Ok(()) + } + + fn handle_import_wal( + &self, + pgb: &mut PostgresBackend, + tenant_id: ZTenantId, + timeline_id: ZTimelineId, + start_lsn: Lsn, + end_lsn: Lsn, + ) -> anyhow::Result<()> { + thread_mgr::associate_with(Some(tenant_id), Some(timeline_id)); + let _enter = + info_span!("import wal", timeline = %timeline_id, tenant = %tenant_id).entered(); + + let repo = tenant_mgr::get_repository_for_tenant(tenant_id)?; + let timeline = repo.get_timeline_load(timeline_id)?; + ensure!(timeline.get_last_record_lsn() == start_lsn); + + let repartition_distance = repo.get_checkpoint_distance(); + let mut datadir_timeline = + DatadirTimeline::::new(timeline, repartition_distance); + + // TODO leave clean state on error. For now you can use detach to clean + // up broken state from a failed import. + + // Import wal provided via CopyData + info!("importing wal"); + pgb.write_message(&BeMessage::CopyInResponse)?; + let reader = CopyInReader::new(pgb); + import_wal_from_tar(&mut datadir_timeline, reader, start_lsn, end_lsn)?; + + // TODO Does it make sense to overshoot? + ensure!(datadir_timeline.tline.get_last_record_lsn() >= end_lsn); + + // Flush data to disk, then upload to s3. No need for a forced checkpoint. + // We only want to persist the data, and it doesn't matter if it's in the + // shape of deltas or images. + info!("flushing layers"); + datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?; + + info!("done"); + Ok(()) + } + /// Helper function to handle the LSN from client request. /// /// Each GetPage (and Exists and Nblocks) request includes information about @@ -750,6 +934,51 @@ impl postgres_backend::Handler for PageServerHandler { // Check that the timeline exists self.handle_basebackup_request(pgb, timelineid, lsn, tenantid, true)?; pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with("import basebackup ") { + // Import the `base` section (everything but the wal) of a basebackup. + // Assumes the tenant already exists on this pageserver. + // + // Files are scheduled to be persisted to remote storage, and the + // caller should poll the http api to check when that is done. + // + // Example import command: + // 1. Get start/end LSN from backup_manifest file + // 2. 
Run: + // cat my_backup/base.tar | psql -h $PAGESERVER \ + // -c "import basebackup $TENANT $TIMELINE $START_LSN $END_LSN" + let (_, params_raw) = query_string.split_at("import basebackup ".len()); + let params = params_raw.split_whitespace().collect::>(); + ensure!(params.len() == 4); + let tenant = ZTenantId::from_str(params[0])?; + let timeline = ZTimelineId::from_str(params[1])?; + let base_lsn = Lsn::from_str(params[2])?; + let end_lsn = Lsn::from_str(params[3])?; + + self.check_permission(Some(tenant))?; + + match self.handle_import_basebackup(pgb, tenant, timeline, base_lsn, end_lsn) { + Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?, + Err(e) => pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?, + }; + } else if query_string.starts_with("import wal ") { + // Import the `pg_wal` section of a basebackup. + // + // Files are scheduled to be persisted to remote storage, and the + // caller should poll the http api to check when that is done. + let (_, params_raw) = query_string.split_at("import wal ".len()); + let params = params_raw.split_whitespace().collect::>(); + ensure!(params.len() == 4); + let tenant = ZTenantId::from_str(params[0])?; + let timeline = ZTimelineId::from_str(params[1])?; + let start_lsn = Lsn::from_str(params[2])?; + let end_lsn = Lsn::from_str(params[3])?; + + self.check_permission(Some(tenant))?; + + match self.handle_import_wal(pgb, tenant, timeline, start_lsn, end_lsn) { + Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?, + Err(e) => pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?, + }; } else if query_string.to_ascii_lowercase().starts_with("set ") { // important because psycopg2 executes "SET datestyle TO 'ISO'" // on connect diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 626ed1b0f1..59a53d68a1 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -749,6 +749,7 @@ impl<'a, R: Repository> DatadirModification<'a, R> { } /// Extend relation + /// If new size is smaller, do nothing. pub fn put_rel_extend(&mut self, rel: RelTag, nblocks: BlockNumber) -> Result<()> { ensure!(rel.relnode != 0, "invalid relnode"); @@ -756,10 +757,13 @@ impl<'a, R: Repository> DatadirModification<'a, R> { let size_key = rel_size_to_key(rel); let old_size = self.get(size_key)?.get_u32_le(); - let buf = nblocks.to_le_bytes(); - self.put(size_key, Value::Image(Bytes::from(buf.to_vec()))); + // only extend relation here. 
never decrease the size + if nblocks > old_size { + let buf = nblocks.to_le_bytes(); + self.put(size_key, Value::Image(Bytes::from(buf.to_vec()))); - self.pending_nblocks += nblocks as isize - old_size as isize; + self.pending_nblocks += nblocks as isize - old_size as isize; + } Ok(()) } diff --git a/test_runner/batch_others/test_import.py b/test_runner/batch_others/test_import.py new file mode 100644 index 0000000000..e478103313 --- /dev/null +++ b/test_runner/batch_others/test_import.py @@ -0,0 +1,193 @@ +import pytest +from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_upload, wait_for_last_record_lsn +from fixtures.utils import lsn_from_hex, lsn_to_hex +from uuid import UUID, uuid4 +import tarfile +import os +import shutil +from pathlib import Path +import json +from fixtures.utils import subprocess_capture +from fixtures.log_helper import log +from contextlib import closing +from fixtures.neon_fixtures import pg_distrib_dir + + +@pytest.mark.timeout(600) +def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_builder): + # Put data in vanilla pg + vanilla_pg.start() + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + vanilla_pg.safe_psql('''create table t as select 'long string to consume some space' || g + from generate_series(1,300000) g''') + assert vanilla_pg.safe_psql('select count(*) from t') == [(300000, )] + + # Take basebackup + basebackup_dir = os.path.join(test_output_dir, "basebackup") + base_tar = os.path.join(basebackup_dir, "base.tar") + wal_tar = os.path.join(basebackup_dir, "pg_wal.tar") + os.mkdir(basebackup_dir) + vanilla_pg.safe_psql("CHECKPOINT") + pg_bin.run([ + "pg_basebackup", + "-F", + "tar", + "-d", + vanilla_pg.connstr(), + "-D", + basebackup_dir, + ]) + + # Make corrupt base tar with missing pg_control + unpacked_base = os.path.join(basebackup_dir, "unpacked-base") + corrupt_base_tar = os.path.join(unpacked_base, "corrupt-base.tar") + os.mkdir(unpacked_base, 0o750) + subprocess_capture(str(test_output_dir), ["tar", "-xf", base_tar, "-C", unpacked_base]) + os.remove(os.path.join(unpacked_base, "global/pg_control")) + subprocess_capture(str(test_output_dir), + ["tar", "-cf", "corrupt-base.tar"] + os.listdir(unpacked_base), + cwd=unpacked_base) + + # Get start_lsn and end_lsn + with open(os.path.join(basebackup_dir, "backup_manifest")) as f: + manifest = json.load(f) + start_lsn = manifest["WAL-Ranges"][0]["Start-LSN"] + end_lsn = manifest["WAL-Ranges"][0]["End-LSN"] + + node_name = "import_from_vanilla" + tenant = uuid4() + timeline = uuid4() + + # Set up pageserver for import + neon_env_builder.enable_local_fs_remote_storage() + env = neon_env_builder.init_start() + env.pageserver.http_client().tenant_create(tenant) + + def import_tar(base, wal): + env.neon_cli.raw_cli([ + "timeline", + "import", + "--tenant-id", + tenant.hex, + "--timeline-id", + timeline.hex, + "--node-name", + node_name, + "--base-lsn", + start_lsn, + "--base-tarfile", + base, + "--end-lsn", + end_lsn, + "--wal-tarfile", + wal, + ]) + + # Importing corrupt backup fails + with pytest.raises(Exception): + import_tar(corrupt_base_tar, wal_tar) + + # Clean up + # TODO it should clean itself + client = env.pageserver.http_client() + client.timeline_detach(tenant, timeline) + + # Importing correct backup works + import_tar(base_tar, wal_tar) + + # Wait for data to land in s3 + wait_for_last_record_lsn(client, tenant, timeline, lsn_from_hex(end_lsn)) + wait_for_upload(client, tenant, timeline, lsn_from_hex(end_lsn)) + + # Check it 
worked + pg = env.postgres.create_start(node_name, tenant_id=tenant) + assert pg.safe_psql('select count(*) from t') == [(300000, )] + + +@pytest.mark.timeout(600) +def test_import_from_pageserver(test_output_dir, pg_bin, vanilla_pg, neon_env_builder): + + num_rows = 3000 + neon_env_builder.num_safekeepers = 1 + neon_env_builder.enable_local_fs_remote_storage() + env = neon_env_builder.init_start() + + env.neon_cli.create_branch('test_import_from_pageserver') + pgmain = env.postgres.create_start('test_import_from_pageserver') + log.info("postgres is running on 'test_import_from_pageserver' branch") + + timeline = pgmain.safe_psql("SHOW neon.timeline_id")[0][0] + + with closing(pgmain.connect()) as conn: + with conn.cursor() as cur: + # data loading may take a while, so increase statement timeout + cur.execute("SET statement_timeout='300s'") + cur.execute(f'''CREATE TABLE tbl AS SELECT 'long string to consume some space' || g + from generate_series(1,{num_rows}) g''') + cur.execute("CHECKPOINT") + + cur.execute('SELECT pg_current_wal_insert_lsn()') + lsn = cur.fetchone()[0] + log.info(f"start_backup_lsn = {lsn}") + + # Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq. + # PgBin sets it automatically, but here we need to pipe psql output to the tar command. + psql_env = {'LD_LIBRARY_PATH': os.path.join(str(pg_distrib_dir), 'lib')} + + # Get a fullbackup from pageserver + query = f"fullbackup { env.initial_tenant.hex} {timeline} {lsn}" + cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query] + result_basepath = pg_bin.run_capture(cmd, env=psql_env) + tar_output_file = result_basepath + ".stdout" + + # Stop the first pageserver instance, erase all its data + env.postgres.stop_all() + env.pageserver.stop() + + dir_to_clear = Path(env.repo_dir) / 'tenants' + shutil.rmtree(dir_to_clear) + os.mkdir(dir_to_clear) + + #start the pageserver again + env.pageserver.start() + + # Import using another tenantid, because we use the same pageserver. + # TODO Create another pageserver to maeke test more realistic. 
+ tenant = uuid4() + + # Import to pageserver + node_name = "import_from_pageserver" + client = env.pageserver.http_client() + client.tenant_create(tenant) + env.neon_cli.raw_cli([ + "timeline", + "import", + "--tenant-id", + tenant.hex, + "--timeline-id", + timeline, + "--node-name", + node_name, + "--base-lsn", + lsn, + "--base-tarfile", + os.path.join(tar_output_file), + ]) + + # Wait for data to land in s3 + wait_for_last_record_lsn(client, tenant, UUID(timeline), lsn_from_hex(lsn)) + wait_for_upload(client, tenant, UUID(timeline), lsn_from_hex(lsn)) + + # Check it worked + pg = env.postgres.create_start(node_name, tenant_id=tenant) + assert pg.safe_psql('select count(*) from tbl') == [(num_rows, )] + + # Take another fullbackup + query = f"fullbackup { tenant.hex} {timeline} {lsn}" + cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query] + result_basepath = pg_bin.run_capture(cmd, env=psql_env) + new_tar_output_file = result_basepath + ".stdout" + + # Check it's the same as the first fullbackup + # TODO pageserver should be checking checksum + assert os.path.getsize(tar_output_file) == os.path.getsize(new_tar_output_file) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 51afd3a03d..12edcb8792 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1398,12 +1398,12 @@ class VanillaPostgres(PgProtocol): if log_path is None: log_path = os.path.join(self.pgdatadir, "pg.log") - self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, '-l', log_path, 'start']) + self.pg_bin.run_capture(['pg_ctl', '-w', '-D', self.pgdatadir, '-l', log_path, 'start']) def stop(self): assert self.running self.running = False - self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'stop']) + self.pg_bin.run_capture(['pg_ctl', '-w', '-D', self.pgdatadir, 'stop']) def get_subdir_size(self, subdir) -> int: """Return size of pgdatadir subdirectory in bytes.""" From 6222a0012bf1a856149af618b67e7e362528bf36 Mon Sep 17 00:00:00 2001 From: Sergey Melnikov Date: Wed, 22 Jun 2022 11:40:59 +0300 Subject: [PATCH 10/11] Migrate from CircleCI to Github Actions: python codestyle, build and tests (#1647) Duplicate postgres and neon build and test jobs from CircleCI to Github actions. --- .../actions/run-python-test-set/action.yml | 119 ++++++++ .github/workflows/build_and_test.yml | 276 ++++++++++++++++++ 2 files changed, 395 insertions(+) create mode 100644 .github/actions/run-python-test-set/action.yml create mode 100644 .github/workflows/build_and_test.yml diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml new file mode 100644 index 0000000000..94fac2ee99 --- /dev/null +++ b/.github/actions/run-python-test-set/action.yml @@ -0,0 +1,119 @@ +name: 'Run python test' +description: 'Runs a Neon python test set, performing all the required preparations before' + +inputs: + # Select the type of Rust build. Must be "release" or "debug". + build_type: + required: true + rust_toolchain: + required: true + # This parameter is required, to prevent the mistake of running all tests in one job. + test_selection: + required: true + # Arbitrary parameters to pytest. 
For example "-s" to prevent capturing stdout/stderr + extra_params: + required: false + default: '' + needs_postgres_source: + required: false + default: 'false' + run_in_parallel: + required: false + default: 'true' + save_perf_report: + required: false + default: 'false' + +runs: + using: "composite" + steps: + - name: Get Neon artifact for restoration + uses: actions/download-artifact@v3 + with: + name: neon-${{ runner.os }}-${{ inputs.build_type }}-${{ inputs.rust_toolchain }}-artifact + path: ./neon-artifact/ + + - name: Extract Neon artifact + shell: bash -ex {0} + run: | + mkdir -p /tmp/neon/ + tar -xf ./neon-artifact/neon.tgz -C /tmp/neon/ + rm -rf ./neon-artifact/ + + - name: Checkout + if: inputs.needs_postgres_source == 'true' + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 1 + + - name: Cache poetry deps + id: cache_poetry + uses: actions/cache@v3 + with: + path: ~/.cache/pypoetry/virtualenvs + key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }} + + - name: Install Python deps + shell: bash -ex {0} + run: ./scripts/pysync + + - name: Run pytest + env: + ZENITH_BIN: /tmp/neon/bin + POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install + TEST_OUTPUT: /tmp/test_output + # this variable will be embedded in perf test report + # and is needed to distinguish different environments + PLATFORM: github-actions-selfhosted + shell: bash -ex {0} + run: | + PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)" + rm -rf $PERF_REPORT_DIR + + TEST_SELECTION="test_runner/${{ inputs.test_selection }}" + EXTRA_PARAMS="${{ inputs.extra_params }}" + if [ -z "$TEST_SELECTION" ]; then + echo "test_selection must be set" + exit 1 + fi + if [[ "${{ inputs.run_in_parallel }}" == "true" ]]; then + EXTRA_PARAMS="-n4 $EXTRA_PARAMS" + fi + if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then + if [[ "$GITHUB_REF" == "main" ]]; then + mkdir -p "$PERF_REPORT_DIR" + EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS" + fi + fi + + if [[ "${{ inputs.build_type }}" == "debug" ]]; then + cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run) + elif [[ "${{ inputs.build_type }}" == "release" ]]; then + cov_prefix=() + fi + + # Run the tests. + # + # The junit.xml file allows CircleCI to display more fine-grained test information + # in its "Tests" tab in the results page. 
+ # --verbose prints name of each test (helpful when there are + # multiple tests in one file) + # -rA prints summary in the end + # -n4 uses four processes to run tests via pytest-xdist + # -s is not used to prevent pytest from capturing output, because tests are running + # in parallel and logs are mixed between different tests + "${cov_prefix[@]}" ./scripts/pytest \ + --junitxml=$TEST_OUTPUT/junit.xml \ + --tb=short \ + --verbose \ + -m "not remote_cluster" \ + -rA $TEST_SELECTION $EXTRA_PARAMS + + if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then + if [[ "$GITHUB_REF" == "main" ]]; then + export REPORT_FROM="$PERF_REPORT_DIR" + export REPORT_TO=local + scripts/generate_and_push_perf_report.sh + fi + fi diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml new file mode 100644 index 0000000000..5f4dd754d2 --- /dev/null +++ b/.github/workflows/build_and_test.yml @@ -0,0 +1,276 @@ +name: build_and_test +on: [ push ] +defaults: + run: + shell: bash -ex {0} + +jobs: + build-postgres: + runs-on: [ self-hosted, Linux, k8s-runner ] + strategy: + matrix: + build_type: [ debug, release ] + rust_toolchain: [ 1.58 ] + + env: + BUILD_TYPE: ${{ matrix.build_type }} + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 1 + + - name: Set pg revision for caching + id: pg_ver + run: echo ::set-output name=pg_rev::$(git rev-parse HEAD:vendor/postgres) + + - name: Cache postgres build + id: cache_pg + uses: actions/cache@v3 + with: + path: tmp_install/ + key: v1-${{ runner.os }}-${{ matrix.build_type }}-pg-${{ steps.pg_ver.outputs.pg_rev }}-${{ hashFiles('Makefile') }} + + - name: Build postgres + if: steps.cache_pg.outputs.cache-hit != 'true' + run: COPT='-Werror' mold -run make postgres -j$(nproc) + + # actions/cache@v3 does not allow concurrently using the same cache across job steps, so use a separate cache + - name: Prepare postgres artifact + run: tar -C tmp_install/ -czf ./pg.tgz . 
+ - name: Upload postgres artifact + uses: actions/upload-artifact@v3 + with: + retention-days: 7 + if-no-files-found: error + name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact + path: ./pg.tgz + + + build-neon: + runs-on: [ self-hosted, Linux, k8s-runner ] + needs: [ build-postgres ] + strategy: + matrix: + build_type: [ debug, release ] + rust_toolchain: [ 1.58 ] + + env: + BUILD_TYPE: ${{ matrix.build_type }} + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 1 + + - name: Get postgres artifact for restoration + uses: actions/download-artifact@v3 + with: + name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact + path: ./postgres-artifact/ + - name: Extract postgres artifact + run: | + mkdir ./tmp_install/ + tar -xf ./postgres-artifact/pg.tgz -C ./tmp_install/ + rm -rf ./postgres-artifact/ + + - name: Cache cargo deps + id: cache_cargo + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry/ + ~/.cargo/git/ + target/ + key: v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }} + + - name: Run cargo build + run: | + if [[ $BUILD_TYPE == "debug" ]]; then + cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run) + CARGO_FLAGS= + elif [[ $BUILD_TYPE == "release" ]]; then + cov_prefix=() + CARGO_FLAGS="--release --features profiling" + fi + + export CACHEPOT_BUCKET=zenith-rust-cachepot + export RUSTC_WRAPPER=cachepot + export AWS_ACCESS_KEY_ID="${{ secrets.AWS_ACCESS_KEY_ID }}" + export AWS_SECRET_ACCESS_KEY="${{ secrets.AWS_SECRET_ACCESS_KEY }}" + export HOME=/home/runner + "${cov_prefix[@]}" mold -run cargo build $CARGO_FLAGS --features failpoints --bins --tests + cachepot -s + + - name: Run cargo test + run: | + export HOME=/home/runner + if [[ $BUILD_TYPE == "debug" ]]; then + cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run) + CARGO_FLAGS= + elif [[ $BUILD_TYPE == "release" ]]; then + cov_prefix=() + CARGO_FLAGS=--release + fi + + "${cov_prefix[@]}" cargo test $CARGO_FLAGS + + - name: Install rust binaries + run: | + export HOME=/home/runner + if [[ $BUILD_TYPE == "debug" ]]; then + cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run) + elif [[ $BUILD_TYPE == "release" ]]; then + cov_prefix=() + fi + + binaries=$( + "${cov_prefix[@]}" cargo metadata --format-version=1 --no-deps | + jq -r '.packages[].targets[] | select(.kind | index("bin")) | .name' + ) + + test_exe_paths=$( + "${cov_prefix[@]}" cargo test --message-format=json --no-run | + jq -r '.executable | select(. 
!= null)' + ) + + mkdir -p /tmp/neon/bin + mkdir -p /tmp/neon/test_bin + mkdir -p /tmp/neon/etc + + # Install target binaries + for bin in $binaries; do + SRC=target/$BUILD_TYPE/$bin + DST=/tmp/neon/bin/$bin + cp $SRC $DST + echo $DST >> /tmp/neon/etc/binaries.list + done + + # Install test executables (for code coverage) + if [[ $BUILD_TYPE == "debug" ]]; then + for bin in $test_exe_paths; do + SRC=$bin + DST=/tmp/neon/test_bin/$(basename $bin) + cp $SRC $DST + echo $DST >> /tmp/neon/etc/binaries.list + done + fi + + - name: Install postgres binaries + run: cp -a tmp_install /tmp/neon/pg_install + + - name: Merge coverage data + run: | + export HOME=/home/runner + # This will speed up workspace uploads + if [[ $BUILD_TYPE == "debug" ]]; then + scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage merge + fi + + - name: Prepare neon artifact + run: tar -C /tmp/neon/ -czf ./neon.tgz . + + - name: Upload neon binaries + uses: actions/upload-artifact@v3 + with: + retention-days: 7 + if-no-files-found: error + name: neon-${{ runner.os }}-${{ matrix.build_type }}-${{ matrix.rust_toolchain }}-artifact + path: ./neon.tgz + + check-codestyle-python: + runs-on: [ self-hosted, Linux, k8s-runner ] + strategy: + matrix: + rust_toolchain: [ 1.58 ] + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 1 + + - name: Cache poetry deps + id: cache_poetry + uses: actions/cache@v3 + with: + path: ~/.cache/pypoetry/virtualenvs + key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }} + + - name: Install Python deps + run: ./scripts/pysync + + - name: Run yapf to ensure code format + run: poetry run yapf --recursive --diff . + + - name: Run mypy to check types + run: poetry run mypy . + + pg_regress-tests: + runs-on: [ self-hosted, Linux, k8s-runner ] + needs: [ build-neon ] + strategy: + matrix: + build_type: [ debug, release ] + rust_toolchain: [ 1.58 ] + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 2 + + - name: Pytest regress tests + uses: ./.github/actions/run-python-test-set + with: + build_type: ${{ matrix.build_type }} + rust_toolchain: ${{ matrix.rust_toolchain }} + test_selection: batch_pg_regress + needs_postgres_source: true + + other-tests: + runs-on: [ self-hosted, Linux, k8s-runner ] + needs: [ build-neon ] + strategy: + matrix: + build_type: [ debug, release ] + rust_toolchain: [ 1.58 ] + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 2 + + - name: Pytest other tests + uses: ./.github/actions/run-python-test-set + with: + build_type: ${{ matrix.build_type }} + rust_toolchain: ${{ matrix.rust_toolchain }} + test_selection: batch_others + + benchmarks: + runs-on: [ self-hosted, Linux, k8s-runner ] + needs: [ build-neon ] + strategy: + matrix: + build_type: [ release ] + rust_toolchain: [ 1.58 ] + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 2 + + - name: Pytest benchmarks + uses: ./.github/actions/run-python-test-set + with: + build_type: ${{ matrix.build_type }} + rust_toolchain: ${{ matrix.rust_toolchain }} + test_selection: performance + run_in_parallel: false + # save_perf_report: true From d059e588a663d0b26992761a11b93acb11ff499d Mon Sep 17 00:00:00 2001 From: KlimentSerafimov Date: Wed, 22 Jun 2022 15:34:24 +0200 Subject: [PATCH 11/11] Added invariant check for project name. (#1921) Summary: Added invariant checking for project name. 
Refactored ClientCredentials and TlsConfig. * Added formatting invariant check for project name: **\forall c \in project_name . c \in [alnum] U {'-'}. ** sni_data == . * Added exhaustive tests for get_project_name. * Refactored TlsConfig to contain common_name : Option. * Refactored ClientCredentials construction to construct project_name directly. * Merged ProjectNameError into ClientCredsParseError. * Tweaked proxy tests to accommodate refactored ClientCredentials construction semantics. * [Pytests] Added project option argument to test_proxy_select_1. * Removed project param from Api since now it's contained in creds. * Refactored &Option -> Option<&str>. Co-authored-by: Dmitrii Ivanov . --- Cargo.lock | 119 +++++++++ proxy/Cargo.toml | 2 + proxy/src/auth/backend/console.rs | 15 +- proxy/src/auth/credentials.rs | 339 ++++++++++++++++++++----- proxy/src/config.rs | 42 ++- proxy/src/proxy.rs | 58 +++-- test_runner/batch_others/test_proxy.py | 2 +- 7 files changed, 473 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dca525941d..f4d3743676 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,6 +64,45 @@ dependencies = [ "nodrop", ] +[[package]] +name = "asn1-rs" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ff05a702273012438132f449575dbc804e27b2f3cbe3069aa237d26c98fa33" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror", + "time 0.3.9", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db8b7511298d5b7784b40b092d9e9dcd3a627a5707e4b5e507931ab0d44eeebf" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-stream" version = "0.3.3" @@ -712,6 +751,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" + [[package]] name = "debugid" version = "0.7.3" @@ -721,6 +766,20 @@ dependencies = [ "uuid", ] +[[package]] +name = "der-parser" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe398ac75057914d7d07307bf67dc7f3f574a26783b4fc7805a20ffa9f506e82" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "digest" version = "0.9.0" @@ -762,6 +821,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf95dc3f046b9da4f2d51833c0d3547d8564ef6910f5c1ed130306a75b92886" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.6.1" @@ -1731,6 +1801,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "oid-registry" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38e20717fa0541f39bd146692035c37bedfa532b3e5071b35761082407546b2a" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.10.0" @@ -2250,6 +2329,7 @@ dependencies = [ "url", 
"utils", "workspace_hack", + "x509-parser", ] [[package]] @@ -2621,6 +2701,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustls" version = "0.20.4" @@ -3060,6 +3149,18 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "tar" version = "0.4.38" @@ -3922,6 +4023,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "x509-parser" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb9bace5b5589ffead1afb76e43e34cff39cd0f3ce7e170ae0c29e53b88eb1c" +dependencies = [ + "asn1-rs", + "base64", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror", + "time 0.3.9", +] + [[package]] name = "xattr" version = "0.2.2" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 4e45698e3e..8c6036f87d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -39,6 +39,8 @@ utils = { path = "../libs/utils" } metrics = { path = "../libs/metrics" } workspace_hack = { version = "0.1", path = "../workspace_hack" } +x509-parser = "0.13.2" + [dev-dependencies] rcgen = "0.8.14" rstest = "0.12" diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 252522affb..93462086ea 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -19,7 +19,7 @@ pub type Result = std::result::Result; #[derive(Debug, Error)] pub enum ConsoleAuthError { #[error(transparent)] - BadProjectName(#[from] auth::credentials::ProjectNameError), + BadProjectName(#[from] auth::credentials::ClientCredsParseError), // We shouldn't include the actual secret here. #[error("Bad authentication secret")] @@ -74,18 +74,12 @@ pub enum AuthInfo { pub(super) struct Api<'a> { endpoint: &'a ApiUrl, creds: &'a ClientCredentials, - /// Cache project name, since we'll need it several times. - project: &'a str, } impl<'a> Api<'a> { /// Construct an API object containing the auth parameters. pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result { - Ok(Self { - endpoint, - creds, - project: creds.project_name()?, - }) + Ok(Self { endpoint, creds }) } /// Authenticate the existing user or throw an error. 
@@ -100,7 +94,7 @@ impl<'a> Api<'a> { let mut url = self.endpoint.clone(); url.path_segments_mut().push("proxy_get_role_secret"); url.query_pairs_mut() - .append_pair("project", self.project) + .append_pair("project", &self.creds.project_name) .append_pair("role", &self.creds.user); // TODO: use a proper logger @@ -123,7 +117,8 @@ impl<'a> Api<'a> { async fn wake_compute(&self) -> Result { let mut url = self.endpoint.clone(); url.path_segments_mut().push("proxy_wake_compute"); - url.query_pairs_mut().append_pair("project", self.project); + url.query_pairs_mut() + .append_pair("project", &self.creds.project_name); // TODO: use a proper logger println!("cplane request: {url}"); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 6521162b50..48dc8542ec 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -8,10 +8,32 @@ use std::collections::HashMap; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum ClientCredsParseError { - #[error("Parameter `{0}` is missing in startup packet")] + #[error("Parameter `{0}` is missing in startup packet.")] MissingKey(&'static str), + + #[error( + "Project name is not specified. \ + EITHER please upgrade the postgres client library (libpq) for SNI support \ + OR pass the project name as a parameter: '&options=project%3D'." + )] + MissingSNIAndProjectName, + + #[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")] + InconsistentProjectNameAndSNI(String, String), + + #[error("Common name is not set.")] + CommonNameNotSet, + + #[error( + "SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \ + SNI should be formatted as '.'." + )] + InconsistentCommonNameAndSNI(String, String), + + #[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")] + ProjectNameContainsIllegalChars(String), } impl UserFacingError for ClientCredsParseError {} @@ -22,15 +44,7 @@ impl UserFacingError for ClientCredsParseError {} pub struct ClientCredentials { pub user: String, pub dbname: String, - - // New console API requires SNI info to determine the cluster name. - // Other Auth backends don't need it. - pub sni_data: Option, - - // project_name is passed as argument from options from url. - // In case sni_data is missing: project_name is used to determine cluster name. - // In case sni_data is available: project_name and sni_data should match (otherwise throws an error). - pub project_name: Option, + pub project_name: String, } impl ClientCredentials { @@ -38,60 +52,14 @@ impl ClientCredentials { // This logic will likely change in the future. self.user.ends_with("@zenith") } -} -#[derive(Debug, Error)] -pub enum ProjectNameError { - #[error("SNI is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '...&options=project%3D...'.")] - Missing, - - #[error("SNI is malformed.")] - Bad, - - #[error("Inconsistent project name inferred from SNI and project option. String from SNI: '{0}', String from project option: '{1}'")] - Inconsistent(String, String), -} - -impl UserFacingError for ProjectNameError {} - -impl ClientCredentials { - /// Determine project name from SNI or from project_name parameter from options argument. 
- pub fn project_name(&self) -> Result<&str, ProjectNameError> { - // Checking that if both sni_data and project_name are set, then they should match - // otherwise, throws a ProjectNameError::Inconsistent error. - if let Some(sni_data) = &self.sni_data { - let project_name_from_sni_data = - sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0; - if let Some(project_name_from_options) = &self.project_name { - if !project_name_from_options.eq(project_name_from_sni_data) { - return Err(ProjectNameError::Inconsistent( - project_name_from_sni_data.to_string(), - project_name_from_options.to_string(), - )); - } - } - } - // determine the project name from self.sni_data if it exists, otherwise from self.project_name. - let ret = match &self.sni_data { - // if sni_data exists, use it to determine project name - Some(sni_data) => sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0, - // otherwise use project_option if it was manually set thought options parameter. - None => self - .project_name - .as_ref() - .ok_or(ProjectNameError::Missing)? - .as_str(), - }; - Ok(ret) - } -} - -impl TryFrom> for ClientCredentials { - type Error = ClientCredsParseError; - - fn try_from(mut value: HashMap) -> Result { + pub fn parse( + mut options: HashMap, + sni_data: Option<&str>, + common_name: Option<&str>, + ) -> Result { let mut get_param = |key| { - value + options .remove(key) .ok_or(ClientCredsParseError::MissingKey(key)) }; @@ -99,17 +67,15 @@ impl TryFrom> for ClientCredentials { let user = get_param("user")?; let dbname = get_param("database")?; let project_name = get_param("project").ok(); + let project_name = get_project_name(sni_data, common_name, project_name.as_deref())?; Ok(Self { user, dbname, - sni_data: None, project_name, }) } -} -impl ClientCredentials { /// Use credentials to authenticate the user. pub async fn authenticate( self, @@ -120,3 +86,244 @@ impl ClientCredentials { super::backend::handle_user(config, client, self).await } } + +/// Inferring project name from sni_data. +fn project_name_from_sni_data( + sni_data: &str, + common_name: &str, +) -> Result { + let common_name_with_dot = format!(".{common_name}"); + // check that ".{common_name_with_dot}" is the actual suffix in sni_data + if !sni_data.ends_with(&common_name_with_dot) { + return Err(ClientCredsParseError::InconsistentCommonNameAndSNI( + common_name.to_string(), + sni_data.to_string(), + )); + } + // return sni_data without the common name suffix. 
+ Ok(sni_data + .strip_suffix(&common_name_with_dot) + .unwrap() + .to_string()) +} + +#[cfg(test)] +mod tests_for_project_name_from_sni_data { + use super::*; + + #[test] + fn passing() { + let target_project_name = "my-project-123"; + let common_name = "localtest.me"; + let sni_data = format!("{target_project_name}.{common_name}"); + assert_eq!( + project_name_from_sni_data(&sni_data, common_name), + Ok(target_project_name.to_string()) + ); + } + + #[test] + fn throws_inconsistent_common_name_and_sni_data() { + let target_project_name = "my-project-123"; + let common_name = "localtest.me"; + let wrong_suffix = "wrongtest.me"; + assert_eq!(common_name.len(), wrong_suffix.len()); + let wrong_common_name = format!("wrong{wrong_suffix}"); + let sni_data = format!("{target_project_name}.{wrong_common_name}"); + assert_eq!( + project_name_from_sni_data(&sni_data, common_name), + Err(ClientCredsParseError::InconsistentCommonNameAndSNI( + common_name.to_string(), + sni_data + )) + ); + } +} + +/// Determine project name from SNI or from project_name parameter from options argument. +fn get_project_name( + sni_data: Option<&str>, + common_name: Option<&str>, + project_name: Option<&str>, +) -> Result { + // determine the project name from sni_data if it exists, otherwise from project_name. + let ret = match sni_data { + Some(sni_data) => { + let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?; + let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?; + // check invariant: project name from options and from sni should match + if let Some(project_name) = &project_name { + if !project_name_from_sni.eq(project_name) { + return Err(ClientCredsParseError::InconsistentProjectNameAndSNI( + project_name_from_sni, + project_name.to_string(), + )); + } + } + project_name_from_sni + } + None => project_name + .ok_or(ClientCredsParseError::MissingSNIAndProjectName)? + .to_string(), + }; + + // check formatting invariant: project name must contain only alphanumeric characters and hyphens. 
+ if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') { + return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret)); + } + + Ok(ret) +} + +#[cfg(test)] +mod tests_for_project_name_only { + use super::*; + + #[test] + fn passing_from_sni_data_only() { + let target_project_name = "my-project-123"; + let common_name = "localtest.me"; + let sni_data = format!("{target_project_name}.{common_name}"); + assert_eq!( + get_project_name(Some(&sni_data), Some(common_name), None), + Ok(target_project_name.to_string()) + ); + } + + #[test] + fn throws_project_name_contains_illegal_chars_from_sni_data_only() { + let project_name_prefix = "my-project"; + let project_name_suffix = "123"; + let common_name = "localtest.me"; + + for illegal_char_id in 0..256 { + let illegal_char = char::from_u32(illegal_char_id).unwrap(); + if !(illegal_char.is_alphanumeric() || illegal_char == '-') + && illegal_char.to_string().len() == 1 + { + let target_project_name = + format!("{project_name_prefix}{illegal_char}{project_name_suffix}"); + let sni_data = format!("{target_project_name}.{common_name}"); + assert_eq!( + get_project_name(Some(&sni_data), Some(common_name), None), + Err(ClientCredsParseError::ProjectNameContainsIllegalChars( + target_project_name + )) + ); + } + } + } + + #[test] + fn passing_from_project_name_only() { + let target_project_name = "my-project-123"; + let common_names = [Some("localtest.me"), None]; + for common_name in common_names { + assert_eq!( + get_project_name(None, common_name, Some(target_project_name)), + Ok(target_project_name.to_string()) + ); + } + } + + #[test] + fn throws_project_name_contains_illegal_chars_from_project_name_only() { + let project_name_prefix = "my-project"; + let project_name_suffix = "123"; + let common_names = [Some("localtest.me"), None]; + + for common_name in common_names { + for illegal_char_id in 0..256 { + let illegal_char: char = char::from_u32(illegal_char_id).unwrap(); + if !(illegal_char.is_alphanumeric() || illegal_char == '-') + && illegal_char.to_string().len() == 1 + { + let target_project_name = + format!("{project_name_prefix}{illegal_char}{project_name_suffix}"); + assert_eq!( + get_project_name(None, common_name, Some(&target_project_name)), + Err(ClientCredsParseError::ProjectNameContainsIllegalChars( + target_project_name + )) + ); + } + } + } + } + + #[test] + fn passing_from_sni_data_and_project_name() { + let target_project_name = "my-project-123"; + let common_name = "localtest.me"; + let sni_data = format!("{target_project_name}.{common_name}"); + assert_eq!( + get_project_name( + Some(&sni_data), + Some(common_name), + Some(target_project_name) + ), + Ok(target_project_name.to_string()) + ); + } + + #[test] + fn throws_inconsistent_project_name_and_sni() { + let project_name_param = "my-project-123"; + let wrong_project_name = "not-my-project-123"; + let common_name = "localtest.me"; + let sni_data = format!("{wrong_project_name}.{common_name}"); + assert_eq!( + get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)), + Err(ClientCredsParseError::InconsistentProjectNameAndSNI( + wrong_project_name.to_string(), + project_name_param.to_string() + )) + ); + } + + #[test] + fn throws_common_name_not_set() { + let target_project_name = "my-project-123"; + let wrong_project_name = "not-my-project-123"; + let common_name = "localtest.me"; + let sni_datas = [ + Some(format!("{wrong_project_name}.{common_name}")), + Some(format!("{target_project_name}.{common_name}")), + ]; + let project_names = 
[None, Some(target_project_name)]; + for sni_data in sni_datas { + for project_name_param in project_names { + assert_eq!( + get_project_name(sni_data.as_deref(), None, project_name_param), + Err(ClientCredsParseError::CommonNameNotSet) + ); + } + } + } + + #[test] + fn throws_inconsistent_common_name_and_sni_data() { + let target_project_name = "my-project-123"; + let wrong_project_name = "not-my-project-123"; + let common_name = "localtest.me"; + let wrong_suffix = "wrongtest.me"; + assert_eq!(common_name.len(), wrong_suffix.len()); + let wrong_common_name = format!("wrong{wrong_suffix}"); + let sni_datas = [ + Some(format!("{wrong_project_name}.{wrong_common_name}")), + Some(format!("{target_project_name}.{wrong_common_name}")), + ]; + let project_names = [None, Some(target_project_name)]; + for project_name_param in project_names { + for sni_data in &sni_datas { + assert_eq!( + get_project_name(sni_data.as_deref(), Some(common_name), project_name_param), + Err(ClientCredsParseError::InconsistentCommonNameAndSNI( + common_name.to_string(), + sni_data.clone().unwrap().to_string() + )) + ); + } + } + } +} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 4def11aefc..df3923de1a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -36,23 +36,35 @@ pub struct ProxyConfig { pub auth_link_uri: ApiUrl, } -pub type TlsConfig = Arc; +pub struct TlsConfig { + pub config: Arc, + pub common_name: Option, +} + +impl TlsConfig { + pub fn to_server_config(&self) -> Arc { + self.config.clone() + } +} /// Configure TLS for the main endpoint. pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) - .context("couldn't read TLS keys")?; + .context(format!("Failed to read TLS keys at '{key_path}'"))?; ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len()); keys.pop().map(rustls::PrivateKey).unwrap() }; + let cert_chain_bytes = std::fs::read(cert_path) + .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?; let cert_chain = { - let cert_chain_bytes = std::fs::read(cert_path).context("TLS cert file")?; rustls_pemfile::certs(&mut &cert_chain_bytes[..]) - .context("couldn't read TLS certificate chain")? + .context(format!( + "Failed to read TLS certificate chain from bytes from file at '{cert_path}'." + ))? .into_iter() .map(rustls::Certificate) .collect() @@ -64,7 +76,25 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result x, None => return Ok(()), // it's a cancellation request @@ -99,12 +99,14 @@ async fn handle_client( /// we also take an extra care of propagating only the select handshake errors to client. async fn handshake( stream: S, - mut tls: Option, + mut tls: Option<&TlsConfig>, cancel_map: &CancelMap, ) -> anyhow::Result>, auth::ClientCredentials)>> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); + let common_name = tls.and_then(|cfg| cfg.common_name.as_deref()); + let mut stream = PqStream::new(Stream::from_raw(stream)); loop { let msg = stream.read_startup_packet().await?; @@ -122,7 +124,9 @@ async fn handshake( if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. 
- stream = PqStream::new(stream.into_inner().upgrade(tls).await?); + stream = PqStream::new( + stream.into_inner().upgrade(tls.to_server_config()).await?, + ); } } _ => bail!(ERR_PROTO_VIOLATION), @@ -143,15 +147,16 @@ async fn handshake( stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; } - // Here and forth: `or_else` demands that we use a future here - let mut creds: auth::ClientCredentials = async { params.try_into() } - .or_else(|e| stream.throw_error(e)) - .await?; + // Get SNI info when available + let sni_data = match stream.get_ref() { + Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()), + _ => None, + }; - // Set SNI info when available - if let Stream::Tls { tls } = stream.get_ref() { - creds.sni_data = tls.get_ref().1.sni_hostname().map(|s| s.to_owned()); - } + // Construct credentials + let creds = + auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name); + let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?; break Ok(Some((stream, creds))); } @@ -264,12 +269,13 @@ mod tests { } /// Generate TLS certificates and build rustls configs for client and server. - fn generate_tls_config( - hostname: &str, - ) -> anyhow::Result<(ClientConfig<'_>, Arc)> { + fn generate_tls_config<'a>( + hostname: &'a str, + common_name: &'a str, + ) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> { let (ca, cert, key) = generate_certs(hostname)?; - let server_config = { + let tls_config = { let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() @@ -291,7 +297,12 @@ mod tests { ClientConfig { config, hostname } }; - Ok((client_config, server_config)) + let tls_config = TlsConfig { + config: tls_config, + common_name: Some(common_name.to_string()), + }; + + Ok((client_config, tls_config)) } #[async_trait] @@ -346,7 +357,7 @@ mod tests { auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let cancel_map = CancelMap::default(); - let (mut stream, _creds) = handshake(client, tls, &cancel_map) + let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map) .await? 
.context("handshake failed")?; @@ -365,7 +376,8 @@ mod tests { async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (_, server_config) = generate_tls_config("localhost")?; + let (_, server_config) = + generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let client_err = tokio_postgres::Config::new() @@ -393,7 +405,8 @@ mod tests { async fn handshake_tls() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = generate_tls_config("localhost")?; + let (client_config, server_config) = + generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); let (_client, _conn) = tokio_postgres::Config::new() @@ -415,6 +428,7 @@ mod tests { let (_client, _conn) = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") + .options("project=generic-project-name") .ssl_mode(SslMode::Prefer) .connect_raw(server, NoTls) .await?; @@ -476,7 +490,8 @@ mod tests { async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = generate_tls_config("localhost")?; + let (client_config, server_config) = + generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy( client, Some(server_config), @@ -498,7 +513,8 @@ mod tests { async fn scram_auth_mock() -> anyhow::Result<()> { let (client, server) = tokio::io::duplex(1024); - let (client_config, server_config) = generate_tls_config("localhost")?; + let (client_config, server_config) = + generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy( client, Some(server_config), diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index a6f828f829..ebeede8df7 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -2,7 +2,7 @@ import pytest def test_proxy_select_1(static_proxy): - static_proxy.safe_psql("select 1;") + static_proxy.safe_psql("select 1;", options="project=generic-project-name") # Pass extra options to the server.