# Reconstructed from a whitespace-mangled "git format-patch" email:
#   commit 4a2a55d9b22e23b571f46cffccf56bc37aa74c1a ("Add standalone script"),
#   adding scripts/add_missing_rels.py.

import argparse
import os
import re
import shutil
import subprocess
import tempfile
from contextlib import closing
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union, cast

import asyncpg
import psycopg2
# BUG FIX: make_dsn / parse_dsn are used by PgProtocol.connstr/conn_options
# and `re` by connect_async, but none of them were imported.
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn

### utils copied from test fixtures

# Environment-variable mapping handed to subprocesses.
Env = Dict[str, str]

_global_counter = 0


def global_counter() -> int:
    """A really dumb global counter.

    This is useful for giving output files a unique number, so if we run the
    same command multiple times we can keep their output separate.
    """
    global _global_counter
    _global_counter += 1
    return _global_counter


def get_dir_size(path: str) -> int:
    """Return the total size, in bytes, of all regular files under `path`.

    NOTE(review): this helper is called by VanillaPostgres.get_subdir_size but
    was never copied over with the other fixtures, so that method raised
    NameError.  Reimplemented here.
    """
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total


def subprocess_capture(capture_dir: str, cmd: List[str], **kwargs: Any) -> str:
    """Run a process and capture its output.

    Output will go to files named "cmd_NNN.stdout" and "cmd_NNN.stderr"
    where "cmd" is the name of the program and NNN is an incrementing
    counter.

    If those files already exist, we will overwrite them.
    Returns basepath for files with captured output.
    """
    assert type(cmd) is list
    base = os.path.basename(cmd[0]) + '_{}'.format(global_counter())
    basepath = os.path.join(capture_dir, base)
    stdout_filename = basepath + '.stdout'
    stderr_filename = basepath + '.stderr'

    with open(stdout_filename, 'w') as stdout_f:
        with open(stderr_filename, 'w') as stderr_f:
            print('(capturing output to "{}.stdout")'.format(base))
            subprocess.run(cmd, **kwargs, stdout=stdout_f, stderr=stderr_f)

    return basepath


class PgBin:
    """A helper class for executing postgres binaries"""
    def __init__(self, log_dir: Path, pg_distrib_dir):
        self.log_dir = log_dir
        self.pg_bin_path = os.path.join(str(pg_distrib_dir), 'bin')
        self.env = os.environ.copy()
        # Make sure the binaries can find the matching shared libraries.
        self.env['LD_LIBRARY_PATH'] = os.path.join(str(pg_distrib_dir), 'lib')

    def _fixpath(self, command: List[str]):
        # Prepend the postgres bin directory unless the caller gave a path.
        if '/' not in command[0]:
            command[0] = os.path.join(self.pg_bin_path, command[0])

    def _build_env(self, env_add: Optional[Env]) -> Env:
        # Overlay caller-supplied variables on top of the base environment.
        if env_add is None:
            return self.env
        env = self.env.copy()
        env.update(env_add)
        return env

    def run(self, command: List[str], env: Optional[Env] = None, cwd: Optional[str] = None):
        """
        Run one of the postgres binaries.

        The command should be in list form, e.g. ['pgbench', '-p', '55432']

        All the necessary environment variables will be set.

        If the first argument (the command name) doesn't include a path (no '/'
        characters present), then it will be edited to include the correct path.

        If you want stdout/stderr captured to files, use `run_capture` instead.
        """
        self._fixpath(command)
        print('Running command "{}"'.format(' '.join(command)))
        env = self._build_env(env)
        subprocess.run(command, env=env, cwd=cwd, check=True)

    def run_capture(self,
                    command: List[str],
                    env: Optional[Env] = None,
                    cwd: Optional[str] = None,
                    **kwargs: Any) -> str:
        """
        Run one of the postgres binaries, with stderr and stdout redirected to a file.

        This is just like `run`, but for chatty programs. Returns basepath for files
        with captured output.
        """
        self._fixpath(command)
        print('Running command "{}"'.format(' '.join(command)))
        env = self._build_env(env)
        return subprocess_capture(str(self.log_dir),
                                  command,
                                  env=env,
                                  cwd=cwd,
                                  check=True,
                                  **kwargs)


class PgProtocol:
    """Reusable connection logic"""
    def __init__(self, **kwargs):
        # Default libpq connection parameters (host, port, dbname, ...).
        self.default_options = kwargs

    def connstr(self, **kwargs) -> str:
        """
        Build a libpq connection string for the Postgres instance.
        """
        return str(make_dsn(**self.conn_options(**kwargs)))

    def conn_options(self, **kwargs):
        """Merge default options, an optional 'dsn' kwarg, and explicit kwargs."""
        conn_options = self.default_options.copy()
        if 'dsn' in kwargs:
            conn_options.update(parse_dsn(kwargs['dsn']))
        conn_options.update(kwargs)

        # Individual statement timeout in seconds. 2 minutes should be
        # enough for our tests, but if you need a longer, you can
        # change it by calling "SET statement_timeout" after
        # connecting.
        if 'options' in conn_options:
            conn_options['options'] = "-cstatement_timeout=120s " + conn_options['options']
        else:
            conn_options['options'] = "-cstatement_timeout=120s"
        return conn_options

    # autocommit=True here by default because that's what we need most of the time
    def connect(self, autocommit=True, **kwargs) -> PgConnection:
        """
        Connect to the node.
        Returns psycopg2's connection object.
        This method passes all extra params to connstr.
        """
        conn = psycopg2.connect(**self.conn_options(**kwargs))

        # WARNING: this setting affects *all* tests!
        conn.autocommit = autocommit
        return conn

    async def connect_async(self, **kwargs) -> asyncpg.Connection:
        """
        Connect to the node from async python.
        Returns asyncpg's connection object.
        """
        # asyncpg takes slightly different options than psycopg2. Try
        # to convert the defaults from the psycopg2 format.

        # The psycopg2 option 'dbname' is called 'database' in asyncpg
        conn_options = self.conn_options(**kwargs)
        if 'dbname' in conn_options:
            conn_options['database'] = conn_options.pop('dbname')

        # Convert options='-ckey=val ...' to asyncpg's server_settings dict.
        if 'options' in conn_options:
            options = conn_options.pop('options')
            for match in re.finditer(r'-c(\w*)=(\w*)', options):
                key = match.group(1)
                val = match.group(2)
                # BUG FIX: the original tested for 'server_options' (a key that
                # can never exist) before updating 'server_settings', so each
                # option overwrote the dict instead of accumulating into it.
                conn_options.setdefault('server_settings', {})[key] = val
        return await asyncpg.connect(**conn_options)

    def safe_psql(self, query: str, **kwargs: Any) -> List[Tuple[Any, ...]]:
        """
        Execute query against the node and return all rows.
        This method passes all extra params to connstr.
        """
        return self.safe_psql_many([query], **kwargs)[0]

    def safe_psql_many(self, queries: List[str], **kwargs: Any) -> List[List[Tuple[Any, ...]]]:
        """
        Execute queries against the node and return all rows.
        This method passes all extra params to connstr.
        """
        result: List[List[Any]] = []
        with closing(self.connect(**kwargs)) as conn:
            with conn.cursor() as cur:
                for query in queries:
                    print(f"Executing query: {query}")
                    cur.execute(query)

                    if cur.description is None:
                        result.append([])  # query didn't return data
                    else:
                        result.append(cast(List[Any], cur.fetchall()))
        return result


class VanillaPostgres(PgProtocol):
    """A plain (non-neon) postgres instance running out of `pgdatadir`."""
    def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init=True):
        super().__init__(host='localhost', port=port, dbname='postgres')
        self.pgdatadir = pgdatadir
        self.pg_bin = pg_bin
        self.running = False
        if init:
            self.pg_bin.run_capture(['initdb', '-D', str(pgdatadir)])
        # NOTE(review): line structure was mangled in the patch; this matches
        # the test fixture this class was copied from (configure runs even
        # when init=False) — confirm against the original fixtures.
        self.configure([f"port = {port}\n"])

    def configure(self, options: List[str]):
        """Append lines into postgresql.conf file."""
        assert not self.running
        with open(os.path.join(self.pgdatadir, 'postgresql.conf'), 'a') as conf_file:
            conf_file.write("\n".join(options))

    def start(self, log_path: Optional[str] = None):
        """Start the server with pg_ctl; may only be called while stopped."""
        assert not self.running
        self.running = True

        if log_path is None:
            log_path = os.path.join(self.pgdatadir, "pg.log")

        self.pg_bin.run_capture(
            ['pg_ctl', '-w', '-D', str(self.pgdatadir), '-l', log_path, 'start'])

    def stop(self):
        """Stop the server with pg_ctl; may only be called while running."""
        assert self.running
        self.running = False
        self.pg_bin.run_capture(['pg_ctl', '-w', '-D', str(self.pgdatadir), 'stop'])

    def get_subdir_size(self, subdir) -> int:
        """Return size of pgdatadir subdirectory in bytes."""
        return get_dir_size(os.path.join(self.pgdatadir, subdir))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Make sure the server is stopped even if the body raised.
        if self.running:
            self.stop()


### actual code


def get_rel_paths(log_dir, pg_bin, base_tar):
    """Yield the relation file paths (tar-relative) that postgres expects.

    Unpacks `base_tar`, starts a throwaway postgres on it, and asks pg_class
    for every relation's file path in every connectable database.  template0
    cannot be connected to, so its paths are derived from a fresh copy of it.
    """
    with tempfile.TemporaryDirectory() as restored_dir:
        # Unpack the base tar
        subprocess_capture(log_dir, ["tar", "-xf", base_tar, "-C", restored_dir])

        port = "55439"  # Probably free
        with VanillaPostgres(restored_dir, pg_bin, port, init=False) as vanilla_pg:
            vanilla_pg.configure([f"port={port}"])
            vanilla_pg.start()

            # Create database based on template0 because we can't connect to template0
            query = "create database template0copy template template0"
            vanilla_pg.safe_psql(query, user="cloud_admin")
            vanilla_pg.safe_psql("CHECKPOINT", user="cloud_admin")

            # Get all databases
            query = "select oid, datname from pg_database"
            oid_dbname_pairs = vanilla_pg.safe_psql(query, user="cloud_admin")
            template0_oid = [
                oid
                for (oid, database) in oid_dbname_pairs
                if database == "template0"
            ][0]

            # Get rel paths for each database
            for oid, database in oid_dbname_pairs:
                if database == "template0":
                    # We can't connect to template0
                    continue

                query = "select relname, pg_relation_filepath(oid) from pg_class"
                result = vanilla_pg.safe_psql(query, user="cloud_admin", dbname=database)
                for relname, filepath in result:
                    if filepath is not None:
                        if database == "template0copy":
                            # Add all template0copy paths to template0
                            prefix = f"base/{oid}/"
                            if filepath.startswith(prefix):
                                suffix = filepath[len(prefix):]
                                yield f"base/{template0_oid}/{suffix}"
                            elif filepath.startswith("global"):
                                print(f"skipping {database} global file {filepath}")
                            else:
                                raise AssertionError
                        else:
                            yield filepath


def pack_base(log_dir, restored_dir, output_tar):
    """Tar up the contents of `restored_dir` and move the archive to `output_tar`.

    The tar is first built inside restored_dir (so member paths are relative),
    then moved out so it doesn't include itself.
    """
    tmp_tar_name = "tmp.tar"
    tmp_tar_path = os.path.join(restored_dir, tmp_tar_name)
    cmd = ["tar", "-cf", tmp_tar_name] + os.listdir(restored_dir)
    subprocess_capture(log_dir, cmd, cwd=restored_dir)
    shutil.move(tmp_tar_path, output_tar)


def get_files_in_tar(log_dir, tar):
    """Yield the path (relative to the tar root) of every file in `tar`."""
    with tempfile.TemporaryDirectory() as restored_dir:
        # Unpack the tar
        subprocess_capture(log_dir, ["tar", "-xf", tar, "-C", restored_dir])

        # Walk the unpacked tree and strip the temp-dir prefix.
        # (The original kept an unused `empty_files` list here — dead code
        # left over from `corrupt`; removed.)
        for root, _dirs, files in os.walk(restored_dir):
            for name in files:
                file_path = os.path.join(root, name)
                yield file_path[len(restored_dir) + 1:]
def corrupt(log_dir, base_tar, output_tar):
    """Remove all empty files and repackage. Return paths of files removed.

    Unpacks `base_tar`, deletes every zero-length file, writes the repacked
    archive to `output_tar`, and returns the deleted files' paths relative
    to the tar root.
    """
    with tempfile.TemporaryDirectory() as restored_dir:
        # Unpack the base tar
        subprocess_capture(log_dir, ["tar", "-xf", base_tar, "-C", restored_dir])

        # Find empty files
        empty_files = []
        for root, _dirs, files in os.walk(restored_dir):
            for name in files:
                file_path = os.path.join(root, name)
                if os.path.getsize(file_path) == 0:
                    empty_files.append(file_path)

        # Delete empty files (just to see if they get recreated)
        for empty_file in empty_files:
            os.remove(empty_file)

        # Repackage
        pack_base(log_dir, restored_dir, output_tar)

        # Return relative paths (strip the temp dir prefix and its slash)
        return {
            empty_file[len(restored_dir) + 1:]
            for empty_file in empty_files
        }


def touch_missing_rels(log_dir, corrupt_tar, output_tar, paths):
    """Unpack `corrupt_tar`, create empty files for any missing `paths`, repack.

    An empty relation file is a valid representation of an empty relation, so
    touching the missing paths restores a usable base backup in `output_tar`.
    """
    with tempfile.TemporaryDirectory() as restored_dir:
        # Unpack the base tar
        subprocess_capture(log_dir, ["tar", "-xf", corrupt_tar, "-C", restored_dir])

        # Touch files that don't exist
        for path in paths:
            absolute_path = os.path.join(restored_dir, path)
            if not os.path.exists(absolute_path):
                # BUG FIX: this message was missing the f-prefix, so it printed
                # the literal text "{absolute_path}" instead of the path.
                print(f"File {absolute_path} didn't exist. Creating..")
                Path(absolute_path).touch()

        # Repackage
        pack_base(log_dir, restored_dir, output_tar)


# TODO this test is not currently called.
# It needs any ordinary base.tar path as input.
def test_add_missing_rels(base_tar, test_output_dir=None, pg_bin=None):
    """End-to-end check of the reconstruction logic (not currently called).

    Corrupts `base_tar` by deleting its empty files, reconstructs the expected
    relation paths from the corrupted tar, touches them back, and asserts the
    result matches the original tar (modulo a couple of config files).

    NOTE(review): the original body referenced module-level names
    `test_output_dir` and `pg_bin` that are not defined anywhere in this file,
    so any call raised NameError.  They are now explicit parameters; callers
    must supply a scratch directory and a PgBin instance.
    """
    if test_output_dir is None or pg_bin is None:
        raise ValueError("test_add_missing_rels requires test_output_dir and pg_bin")

    output_tar = base_tar + ".fixed"

    # Create new base tar with missing empty files
    corrupt_tar = os.path.join(test_output_dir, "psql_2-corrupted.stdout")
    deleted_files = corrupt(test_output_dir, base_tar, corrupt_tar)
    assert len(set(get_files_in_tar(test_output_dir, base_tar)) -
               set(get_files_in_tar(test_output_dir, corrupt_tar))) > 0

    # Reconstruct paths from the corrupted tar, assert it covers everything important
    reconstructed_paths = set(get_rel_paths(test_output_dir, pg_bin, corrupt_tar))
    paths_missed = deleted_files - reconstructed_paths
    assert paths_missed.issubset({
        "postgresql.auto.conf",
        "pg_ident.conf",
    })

    # Recreate the correct tar by touching files, compare with original tar
    touch_missing_rels(test_output_dir, corrupt_tar, output_tar, reconstructed_paths)
    paths_missed = (set(get_files_in_tar(test_output_dir, base_tar)) -
                    set(get_files_in_tar(test_output_dir, output_tar)))
    assert paths_missed.issubset({
        "postgresql.auto.conf",
        "pg_ident.conf",
    })


# Example command:
# poetry run python scripts/add_missing_rels.py \
#     --base-tar /home/bojan/src/neondatabase/neon/test_output/test_import_from_pageserver/psql_2.stdout \
#     --output-tar output-base.tar \
#     --log-dir /home/bojan/tmp \
#     --pg-distrib-dir /home/bojan/src/neondatabase/neon/tmp_install/
# (FIX: the original example comment was missing the '\' after --log-dir,
# so the pasted command would have dropped the last argument.)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--base-tar',
        dest='base_tar',
        required=True,
        help='base.tar file to add missing rels to (file will not be modified)',
    )
    parser.add_argument(
        '--output-tar',
        dest='output_tar',
        required=True,
        help='path and name for the output base.tar file',
    )
    parser.add_argument(
        '--log-dir',
        dest='log_dir',
        required=True,
        help='directory to save log files in',
    )
    parser.add_argument(
        '--pg-distrib-dir',
        dest='pg_distrib_dir',
        required=True,
        help='directory where postgres is installed',
    )
    args = parser.parse_args()
    base_tar = args.base_tar
    output_tar = args.output_tar
    log_dir = args.log_dir
    pg_bin = PgBin(log_dir, args.pg_distrib_dir)

    # Compute every relation path postgres expects, then make sure each one
    # exists in the output tar (an empty file is a valid empty relation).
    reconstructed_paths = set(get_rel_paths(log_dir, pg_bin, base_tar))
    touch_missing_rels(log_dir, base_tar, output_tar, reconstructed_paths)