From 6f7aeaa1c563906fe14a4f62870d9fcd62436e02 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Mon, 25 Nov 2024 09:01:05 +0000 Subject: [PATCH 01/34] test_runner: use LFC by default (#8613) ## Problem LFC is not enabled by default in tests, but it is enabled in production. This increases the risk of production-only errors that the routine test workflow cannot catch. However, enabling LFC for all the tests may overload the disk on our servers and increase the number of failures. So, we start by enabling LFC in one configuration to evaluate the possible risk. ## Summary of changes A new environment variable, `USE_LFC`, is introduced. Unless it is set to `false`, LFC is enabled by default in all the tests. In our workflow, we enable LFC for the PG17, release, x86-64 combination and disable it for all other combinations. --------- Co-authored-by: Alexey Masterov Co-authored-by: a-masterov <72613290+a-masterov@users.noreply.github.com> --- .github/workflows/_build-and-test-locally.yml | 9 +- .github/workflows/build_and_test.yml | 9 +- .../ingest_regress_test_result-new-format.py | 4 + test_runner/fixtures/neon_fixtures.py | 82 +++++++++++++++++-- test_runner/fixtures/parametrize.py | 1 + test_runner/fixtures/utils.py | 21 +++++ test_runner/fixtures/workload.py | 2 +- test_runner/regress/test_combocid.py | 18 +--- .../regress/test_explain_with_lfc_stats.py | 5 +- test_runner/regress/test_hot_standby.py | 2 +- test_runner/regress/test_lfc_resize.py | 15 ++-- .../test_lfc_working_set_approximation.py | 9 +- test_runner/regress/test_local_file_cache.py | 6 +- .../regress/test_logical_replication.py | 12 ++- test_runner/regress/test_oid_overflow.py | 2 +- test_runner/regress/test_read_validation.py | 2 +- test_runner/regress/test_readonly_node.py | 2 +- .../regress/test_timeline_detach_ancestor.py | 2 +- test_runner/regress/test_vm_bits.py | 2 +- test_runner/regress/test_wal_acceptor.py | 2 +- 20 files changed, 158 insertions(+), 49 deletions(-) diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 8e28049888..bdf7c07c6a 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -19,8 +19,8 @@ on: description: 'debug or release' required: true type: string - pg-versions: - description: 'a json array of postgres versions to run regression tests on' + test-cfg: + description: 'a json object of postgres versions and lfc states to run regression tests on' required: true type: string @@ -276,14 +276,14 @@ jobs: options: --init --shm-size=512mb --ulimit memlock=67108864:67108864 strategy: fail-fast: false - matrix: - pg_version: ${{ fromJson(inputs.pg-versions) }} + matrix: ${{ fromJSON(format('{{"include":{0}}}', inputs.test-cfg)) }} steps: - uses: actions/checkout@v4 with: submodules: true - name: Pytest regression tests + continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }} uses: ./.github/actions/run-python-test-set timeout-minutes: 60 with: @@ -300,6 +300,7 @@ jobs: CHECK_ONDISK_DATA_COMPATIBILITY: nonempty BUILD_TAG: ${{ inputs.build-tag }} PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring + USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }} # Temporary disable this step until we figure out why it's so flaky # Ref https://github.com/neondatabase/neon/issues/4540 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 89fd2d0d17..9830c2a0c9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@
-253,7 +253,14 @@ jobs: build-tag: ${{ needs.tag.outputs.build-tag }} build-type: ${{ matrix.build-type }} # Run tests on all Postgres versions in release builds and only on the latest version in debug builds - pg-versions: ${{ matrix.build-type == 'release' && '["v14", "v15", "v16", "v17"]' || '["v17"]' }} + # run without LFC on v17 release only + test-cfg: | + ${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"}, + {"pg_version":"v15", "lfc_state": "without-lfc"}, + {"pg_version":"v16", "lfc_state": "without-lfc"}, + {"pg_version":"v17", "lfc_state": "without-lfc"}, + {"pg_version":"v17", "lfc_state": "with-lfc"}]' + || '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }} secrets: inherit # Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking diff --git a/scripts/ingest_regress_test_result-new-format.py b/scripts/ingest_regress_test_result-new-format.py index c99cfa2b01..064c516718 100644 --- a/scripts/ingest_regress_test_result-new-format.py +++ b/scripts/ingest_regress_test_result-new-format.py @@ -31,6 +31,7 @@ CREATE TABLE IF NOT EXISTS results ( duration INT NOT NULL, flaky BOOLEAN NOT NULL, arch arch DEFAULT 'X64', + lfc BOOLEAN DEFAULT false NOT NULL, build_type TEXT NOT NULL, pg_version INT NOT NULL, run_id BIGINT NOT NULL, @@ -54,6 +55,7 @@ class Row: duration: int flaky: bool arch: str + lfc: bool build_type: str pg_version: int run_id: int @@ -132,6 +134,7 @@ def ingest_test_result( if p["name"].startswith("__") } arch = parameters.get("arch", "UNKNOWN").strip("'") + lfc = parameters.get("lfc", "False") == "True" build_type, pg_version, unparametrized_name = parse_test_name(test["name"]) labels = {label["name"]: label["value"] for label in test["labels"]} @@ -145,6 +148,7 @@ def ingest_test_result( duration=test["time"]["duration"], flaky=test["flaky"] or test["retriesStatusChange"], arch=arch, + lfc=lfc, build_type=build_type, pg_version=pg_version, run_id=run_id, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 78e2422171..07d442b4a6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -90,10 +90,12 @@ from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, COMPONENT_BINARIES, + USE_LFC, allure_add_grafana_links, assert_no_errors, get_dir_size, print_gc_result, + size_to_bytes, subprocess_capture, wait_until, ) @@ -3742,12 +3744,45 @@ class Endpoint(PgProtocol, LogUtils): self.pgdata_dir = self.env.repo_dir / path self.logfile = self.endpoint_path() / "compute.log" - config_lines = config_lines or [] - # set small 'max_replication_write_lag' to enable backpressure # and make tests more stable. 
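Note on the semantics of the switch used throughout this patch: as defined in `fixtures/utils.py` below, LFC stays on unless `USE_LFC` is exactly the string `false`; an unset variable, or any other value, enables it. A minimal sketch of the check:

```python
# Mirrors the check added to test_runner/fixtures/utils.py in this patch:
# LFC is enabled unless USE_LFC is set to exactly "false".
import os

USE_LFC = os.environ.get("USE_LFC") != "false"
```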
config_lines = ["max_replication_write_lag=15MB"] + config_lines + # Delete file cache if it exists (and we're recreating the endpoint) + if USE_LFC: + if (lfc_path := Path(self.lfc_path())).exists(): + lfc_path.unlink() + else: + lfc_path.parent.mkdir(parents=True, exist_ok=True) + for line in config_lines: + if ( + line.find("neon.max_file_cache_size") > -1 + or line.find("neon.file_cache_size_limit") > -1 + ): + m = re.search(r"=\s*(\S+)", line) + assert m is not None, f"malformed config line {line}" + size = m.group(1) + assert size_to_bytes(size) >= size_to_bytes( + "1MB" + ), "LFC size cannot be set less than 1MB" + # shared_buffers = 512kB to make postgres use LFC intensively + # neon.max_file_cache_size and neon.file_cache size limit are + # set to 1MB because small LFC is better for testing (helps to find more problems) + config_lines = [ + "shared_buffers = 512kB", + f"neon.file_cache_path = '{self.lfc_path()}'", + "neon.max_file_cache_size = 1MB", + "neon.file_cache_size_limit = 1MB", + ] + config_lines + else: + for line in config_lines: + assert ( + line.find("neon.max_file_cache_size") == -1 + ), "Setting LFC parameters is not allowed when LFC is disabled" + assert ( + line.find("neon.file_cache_size_limit") == -1 + ), "Setting LFC parameters is not allowed when LFC is disabled" + self.config(config_lines) return self @@ -3781,6 +3816,9 @@ class Endpoint(PgProtocol, LogUtils): basebackup_request_tries=basebackup_request_tries, ) self._running.release(1) + self.log_config_value("shared_buffers") + self.log_config_value("neon.max_file_cache_size") + self.log_config_value("neon.file_cache_size_limit") return self @@ -3806,6 +3844,10 @@ class Endpoint(PgProtocol, LogUtils): """Path to the postgresql.conf in the endpoint directory (not the one in pgdata)""" return self.endpoint_path() / "postgresql.conf" + def lfc_path(self) -> Path: + """Path to the lfc file""" + return self.endpoint_path() / "file_cache" / "file.cache" + def config(self, lines: list[str]) -> Self: """ Add lines to postgresql.conf. @@ -3984,16 +4026,46 @@ class Endpoint(PgProtocol, LogUtils): assert self.pgdata_dir is not None # please mypy return get_dir_size(self.pgdata_dir / "pg_wal") / 1024 / 1024 - def clear_shared_buffers(self, cursor: Any | None = None): + def clear_buffers(self, cursor: Any | None = None): """ Best-effort way to clear postgres buffers. Pinned buffers will not be 'cleared.' - - Might also clear LFC. 
+ It clears LFC as well by setting neon.file_cache_size_limit to 0 and then returning it to the previous value, + if LFC is enabled """ if cursor is not None: cursor.execute("select clear_buffer_cache()") + if not USE_LFC: + return + cursor.execute("SHOW neon.file_cache_size_limit") + res = cursor.fetchone() + assert res, "Cannot get neon.file_cache_size_limit" + file_cache_size_limit = res[0] + if file_cache_size_limit == 0: + return + cursor.execute("ALTER SYSTEM SET neon.file_cache_size_limit=0") + cursor.execute("SELECT pg_reload_conf()") + cursor.execute(f"ALTER SYSTEM SET neon.file_cache_size_limit='{file_cache_size_limit}'") + cursor.execute("SELECT pg_reload_conf()") else: self.safe_psql("select clear_buffer_cache()") + if not USE_LFC: + return + file_cache_size_limit = self.safe_psql_scalar( + "SHOW neon.file_cache_size_limit", log_query=False + ) + if file_cache_size_limit == 0: + return + self.safe_psql("ALTER SYSTEM SET neon.file_cache_size_limit=0") + self.safe_psql("SELECT pg_reload_conf()") + self.safe_psql(f"ALTER SYSTEM SET neon.file_cache_size_limit='{file_cache_size_limit}'") + self.safe_psql("SELECT pg_reload_conf()") + + def log_config_value(self, param): + """ + Writes the config value param to log + """ + res = self.safe_psql_scalar(f"SHOW {param}", log_query=False) + log.info("%s = %s", param, res) class EndpointFactory: diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 2c6adb8a33..f57c0f801f 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -116,5 +116,6 @@ def pytest_runtest_makereport(*args, **kwargs): }.get(os.uname().machine, "UNKNOWN") arch = os.getenv("RUNNER_ARCH", uname_m) allure.dynamic.parameter("__arch", arch) + allure.dynamic.parameter("__lfc", os.getenv("USE_LFC") != "false") yield diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 30720e648d..04e98fe494 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -57,6 +57,10 @@ VERSIONS_COMBINATIONS = ( ) # fmt: on +# If the environment variable USE_LFC is set and its value is "false", then LFC is disabled for tests. +# If it is not set or set to a value not equal to "false", LFC is enabled by default. +USE_LFC = os.environ.get("USE_LFC") != "false" + def subprocess_capture( capture_dir: Path, @@ -653,6 +657,23 @@ def allpairs_versions(): return {"argnames": "combination", "argvalues": tuple(argvalues), "ids": ids} +def size_to_bytes(hr_size: str) -> int: + """ + Gets human-readable size from postgresql.conf (e.g. 
512kB, 10MB) + returns size in bytes + """ + units = {"B": 1, "kB": 1024, "MB": 1024**2, "GB": 1024**3, "TB": 1024**4, "PB": 1024**5} + match = re.search(r"^\'?(\d+)\s*([kMGTP]?B)?\'?$", hr_size) + assert match is not None, f'"{hr_size}" is not a well-formatted human-readable size' + number, unit = match.groups() + + if unit: + amp = units[unit] + else: + amp = 8192 + return int(number) * amp + + def skip_on_postgres(version: PgVersion, reason: str): return pytest.mark.skipif( PgVersion(os.getenv("DEFAULT_PG_VERSION", PgVersion.DEFAULT)) is version, diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index 4c6b2b6b3e..1b8c9fef44 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -193,7 +193,7 @@ class Workload: def validate(self, pageserver_id: int | None = None): endpoint = self.endpoint(pageserver_id) - endpoint.clear_shared_buffers() + endpoint.clear_buffers() result = endpoint.safe_psql(f"SELECT COUNT(*) FROM {self.table}") log.info(f"validate({self.expect_rows}): {result}") diff --git a/test_runner/regress/test_combocid.py b/test_runner/regress/test_combocid.py index 57d5b2d8b3..2db16d9f64 100644 --- a/test_runner/regress/test_combocid.py +++ b/test_runner/regress/test_combocid.py @@ -5,12 +5,7 @@ from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): env = neon_env_builder.init_start() - endpoint = env.endpoints.create_start( - "main", - config_lines=[ - "shared_buffers='1MB'", - ], - ) + endpoint = env.endpoints.create_start("main") conn = endpoint.connect() cur = conn.cursor() @@ -36,7 +31,7 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): # Clear the cache, so that we exercise reconstructing the pages # from WAL - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. @@ -65,12 +60,7 @@ def test_combocid_lock(neon_env_builder: NeonEnvBuilder): def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() - endpoint = env.endpoints.create_start( - "main", - config_lines=[ - "shared_buffers='1MB'", - ], - ) + endpoint = env.endpoints.create_start("main") conn = endpoint.connect() cur = conn.cursor() @@ -98,7 +88,7 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): cur.execute("delete from t") # Clear the cache, so that we exercise reconstructing the pages # from WAL - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. 
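To make the conversion rules of the new `size_to_bytes` helper concrete, here is a small illustrative check; the expected values follow directly from the units table and the 8192-byte default shown above:

```python
# Illustrative expectations for fixtures.utils.size_to_bytes:
from fixtures.utils import size_to_bytes

assert size_to_bytes("512kB") == 512 * 1024
assert size_to_bytes("'1MB'") == 1024**2  # surrounding quotes are tolerated
assert size_to_bytes("10") == 10 * 8192   # unitless values count 8 kB pages
```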
diff --git a/test_runner/regress/test_explain_with_lfc_stats.py b/test_runner/regress/test_explain_with_lfc_stats.py index 2128bd93dd..382556fd7e 100644 --- a/test_runner/regress/test_explain_with_lfc_stats.py +++ b/test_runner/regress/test_explain_with_lfc_stats.py @@ -2,10 +2,13 @@ from __future__ import annotations from pathlib import Path +import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv +from fixtures.utils import USE_LFC +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_explain_with_lfc_stats(neon_simple_env: NeonEnv): env = neon_simple_env @@ -16,8 +19,6 @@ def test_explain_with_lfc_stats(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='128MB'", "neon.file_cache_size_limit='64MB'", ], diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index a906e7a243..0b1ac11c16 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -170,7 +170,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): # re-execute the query, it will make GetPage # requests. This does not clear the last-written LSN cache # so we still remember the LSNs of the pages. - secondary.clear_shared_buffers(cursor=s_cur) + secondary.clear_buffers(cursor=s_cur) if pause_apply: s_cur.execute("SELECT pg_wal_replay_pause()") diff --git a/test_runner/regress/test_lfc_resize.py b/test_runner/regress/test_lfc_resize.py index 3083128d87..377b0fb4d4 100644 --- a/test_runner/regress/test_lfc_resize.py +++ b/test_runner/regress/test_lfc_resize.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import random import re import subprocess @@ -10,20 +9,24 @@ import time import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, PgBin +from fixtures.utils import USE_LFC @pytest.mark.timeout(600) +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): """ Test resizing the Local File Cache """ env = neon_simple_env + cache_dir = env.repo_dir / "file_cache" + cache_dir.mkdir(exist_ok=True) + env.create_branch("test_lfc_resize") endpoint = env.endpoints.create_start( "main", config_lines=[ - "neon.file_cache_path='file.cache'", - "neon.max_file_cache_size=512MB", - "neon.file_cache_size_limit=512MB", + "neon.max_file_cache_size=1GB", + "neon.file_cache_size_limit=1GB", ], ) n_resize = 10 @@ -63,8 +66,8 @@ def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): cur.execute("select pg_reload_conf()") nretries = 10 while True: - lfc_file_path = f"{endpoint.pg_data_dir_path()}/file.cache" - lfc_file_size = os.path.getsize(lfc_file_path) + lfc_file_path = endpoint.lfc_path() + lfc_file_size = lfc_file_path.stat().st_size res = subprocess.run( ["ls", "-sk", lfc_file_path], check=True, text=True, capture_output=True ) diff --git a/test_runner/regress/test_lfc_working_set_approximation.py b/test_runner/regress/test_lfc_working_set_approximation.py index 36dfec969f..17068849d4 100644 --- a/test_runner/regress/test_lfc_working_set_approximation.py +++ b/test_runner/regress/test_lfc_working_set_approximation.py @@ -3,11 +3,13 @@ from __future__ import annotations import time from pathlib import Path +import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv -from 
fixtures.utils import query_scalar +from fixtures.utils import USE_LFC, query_scalar +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_lfc_working_set_approximation(neon_simple_env: NeonEnv): env = neon_simple_env @@ -18,8 +20,6 @@ def test_lfc_working_set_approximation(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='128MB'", "neon.file_cache_size_limit='64MB'", ], @@ -72,9 +72,10 @@ WITH (fillfactor='100'); # verify working set size after some index access of a few select pages only blocks = query_scalar(cur, "select approximate_working_set_size(true)") log.info(f"working set size after some index access of a few select pages only {blocks}") - assert blocks < 10 + assert blocks < 12 +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_sliding_working_set_approximation(neon_simple_env: NeonEnv): env = neon_simple_env diff --git a/test_runner/regress/test_local_file_cache.py b/test_runner/regress/test_local_file_cache.py index fbf018a167..94c630ffcf 100644 --- a/test_runner/regress/test_local_file_cache.py +++ b/test_runner/regress/test_local_file_cache.py @@ -6,10 +6,12 @@ import random import threading import time +import pytest from fixtures.neon_fixtures import NeonEnvBuilder -from fixtures.utils import query_scalar +from fixtures.utils import USE_LFC, query_scalar +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_local_file_cache_unlink(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() @@ -19,8 +21,6 @@ def test_local_file_cache_unlink(neon_env_builder: NeonEnvBuilder): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='64MB'", "neon.file_cache_size_limit='10MB'", ], diff --git a/test_runner/regress/test_logical_replication.py b/test_runner/regress/test_logical_replication.py index df83ca1c44..ba471b7147 100644 --- a/test_runner/regress/test_logical_replication.py +++ b/test_runner/regress/test_logical_replication.py @@ -12,7 +12,7 @@ from fixtures.neon_fixtures import ( logical_replication_sync, wait_for_last_flush_lsn, ) -from fixtures.utils import wait_until +from fixtures.utils import USE_LFC, wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import ( @@ -576,7 +576,15 @@ def test_subscriber_synchronous_commit(neon_simple_env: NeonEnv, vanilla_pg: Van # We want all data to fit into shared_buffers because later we stop # safekeeper and insert more; this shouldn't cause page requests as they # will be stuck. 
- sub = env.endpoints.create("subscriber", config_lines=["shared_buffers=128MB"]) + sub = env.endpoints.create( + "subscriber", + config_lines=[ + "neon.max_file_cache_size = 32MB", + "neon.file_cache_size_limit = 32MB", + ] + if USE_LFC + else [], + ) sub.start() with vanilla_pg.cursor() as pcur: diff --git a/test_runner/regress/test_oid_overflow.py b/test_runner/regress/test_oid_overflow.py index f69c1112c7..e2bde8be6f 100644 --- a/test_runner/regress/test_oid_overflow.py +++ b/test_runner/regress/test_oid_overflow.py @@ -39,7 +39,7 @@ def test_oid_overflow(neon_env_builder: NeonEnvBuilder): oid = cur.fetchall()[0][0] log.info(f"t2.relfilenode={oid}") - endpoint.clear_shared_buffers(cursor=cur) + endpoint.clear_buffers(cursor=cur) cur.execute("SELECT x from t1") assert cur.fetchone() == (1,) diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index 471a3b406a..70a7a675df 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -54,7 +54,7 @@ def test_read_validation(neon_simple_env: NeonEnv): log.info("Clear buffer cache to ensure no stale pages are brought into the cache") - endpoint.clear_shared_buffers(cursor=c) + endpoint.clear_buffers(cursor=c) cache_entries = query_scalar( c, f"select count(*) from pg_buffercache where relfilenode = {relfilenode}" diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index fcebf8d23a..70d558ac5a 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -230,7 +230,7 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): return offset # Insert some records on main branch - with env.endpoints.create_start("main") as ep_main: + with env.endpoints.create_start("main", config_lines=["shared_buffers=1MB"]) as ep_main: with ep_main.cursor() as cur: cur.execute("CREATE TABLE t0(v0 int primary key, v1 text)") lsn = Lsn(0) diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index cd4e0a5f3b..9c7e851ba8 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -416,7 +416,7 @@ def test_detached_receives_flushes_while_being_detached(neon_env_builder: NeonEn assert client.timeline_detail(env.initial_tenant, timeline_id)["ancestor_timeline_id"] is None - ep.clear_shared_buffers() + ep.clear_buffers() assert ep.safe_psql("SELECT count(*) FROM foo;")[0][0] == rows assert ep.safe_psql("SELECT SUM(LENGTH(aux)) FROM foo")[0][0] != 0 ep.stop() diff --git a/test_runner/regress/test_vm_bits.py b/test_runner/regress/test_vm_bits.py index d4c2ca7e07..f93fc6bd8b 100644 --- a/test_runner/regress/test_vm_bits.py +++ b/test_runner/regress/test_vm_bits.py @@ -63,7 +63,7 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv): # Clear the buffer cache, to force the VM page to be re-fetched from # the page server - endpoint.clear_shared_buffers(cursor=cur) + endpoint.clear_buffers(cursor=cur) # Check that an index-only scan doesn't see the deleted row. 
If the # clearing of the VM bit was not replayed correctly, this would incorrectly diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 405f15e488..8fa33b81a9 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2446,7 +2446,7 @@ def test_broker_discovery(neon_env_builder: NeonEnvBuilder): # generate some data to commit WAL on safekeepers endpoint.safe_psql("insert into t select generate_series(1,100), 'action'") # clear the buffers - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # read data to fetch pages from pageserver endpoint.safe_psql("select sum(i) from t") From 0d1e82f0a7112478d3681bbba1035cbbe2f4b408 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Mon, 25 Nov 2024 11:59:49 +0100 Subject: [PATCH 02/34] Bump futures-* crates, drop unused license, hide duplicate crate warnings (#9858) * The futures-util crate we use was yanked. Bump it and its siblings to new patch release. https://github.com/rust-lang/futures-rs/releases/tag/0.3.31 * cargo-deny: Drop an unused license. * cargo-deny: Don't warn about duplicate crate. Duplicate crates are unavoidable and the noise just hides real warnings. --- Cargo.lock | 28 ++++++++++++++-------------- deny.toml | 3 +-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 665aa4aecc..5be8b97815 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2174,9 +2174,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2184,9 +2184,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" @@ -2201,9 +2201,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -2222,9 +2222,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -2233,15 +2233,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = 
"f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2251,9 +2251,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", diff --git a/deny.toml b/deny.toml index 8bf643f4ba..7a1eecac99 100644 --- a/deny.toml +++ b/deny.toml @@ -33,7 +33,6 @@ reason = "the marvin attack only affects private key decryption, not public key [licenses] allow = [ "Apache-2.0", - "Artistic-2.0", "BSD-2-Clause", "BSD-3-Clause", "CC0-1.0", @@ -67,7 +66,7 @@ registries = [] # More documentation about the 'bans' section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html [bans] -multiple-versions = "warn" +multiple-versions = "allow" wildcards = "allow" highlight = "all" workspace-default-features = "allow" From 6f6749c4a96673430fa6ecce26ca756ead7e8a46 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 25 Nov 2024 12:01:30 +0000 Subject: [PATCH 03/34] chore: update rustls (#9871) --- Cargo.lock | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5be8b97815..98d2e0864a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4182,7 +4182,7 @@ dependencies = [ "bytes", "once_cell", "pq_proto", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pemfile 2.1.1", "serde", "thiserror", @@ -4518,7 +4518,7 @@ dependencies = [ "rsa", "rstest", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", "scopeguard", @@ -5231,9 +5231,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", @@ -5948,7 +5948,7 @@ dependencies = [ "once_cell", "parking_lot 0.12.1", "prost", - "rustls 0.23.16", + "rustls 0.23.18", "tokio", "tonic", "tonic-build", @@ -6031,7 +6031,7 @@ dependencies = [ "postgres_ffi", "remote_storage", "reqwest 0.12.4", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-native-certs 0.8.0", "serde", "serde_json", @@ -6493,7 +6493,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab" dependencies = [ "ring", - "rustls 0.23.16", + "rustls 0.23.18", "tokio", "tokio-postgres", "tokio-rustls 0.26.0", @@ -6527,7 +6527,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -6936,7 +6936,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pki-types", "url", "webpki-roots 0.26.1", @@ -7598,7 +7598,7 @@ dependencies = [ "regex-automata 0.4.3", "regex-syntax 0.8.2", "reqwest 0.12.4", - "rustls 0.23.16", + "rustls 0.23.18", "scopeguard", "serde", "serde_json", From 4630b709627ffea01dd6863f0db84ec4c693bcd6 Mon Sep 17 00:00:00 2001 From: 
"Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Mon, 25 Nov 2024 09:25:18 -0500 Subject: [PATCH 04/34] fix(pageserver): ensure all layers are flushed before measuring RSS (#9861) ## Problem close https://github.com/neondatabase/neon/issues/9761 The test assumed that no new L0 layers are flushed throughout the process, which is not true. ## Summary of changes Fix the test case `test_compaction_l0_memory` by flushing in-memory layers before compaction. Signed-off-by: Alex Chi Z --- test_runner/performance/test_compaction.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test_runner/performance/test_compaction.py b/test_runner/performance/test_compaction.py index 8868dddf39..0cd1080fa7 100644 --- a/test_runner/performance/test_compaction.py +++ b/test_runner/performance/test_compaction.py @@ -103,6 +103,9 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): cur.execute(f"update tbl{i} set j = {j};") wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id) + pageserver_http.timeline_checkpoint( + tenant_id, timeline_id, compact=False + ) # ^1: flush all in-memory layers endpoint.stop() # Check we have generated the L0 stack we expected @@ -118,7 +121,9 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): return v * 1024 before = rss_hwm() - pageserver_http.timeline_compact(tenant_id, timeline_id) + pageserver_http.timeline_compact( + tenant_id, timeline_id + ) # ^1: we must ensure during this process no new L0 layers are flushed after = rss_hwm() log.info(f"RSS across compaction: {before} -> {after} (grew {after - before})") @@ -137,7 +142,7 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): # To be fixed in https://github.com/neondatabase/neon/issues/8184, after which # this memory estimate can be revised far downwards to something that doesn't scale # linearly with the layer sizes. 
- MEMORY_ESTIMATE = (initial_l0s_size - final_l0s_size) * 1.5 + MEMORY_ESTIMATE = (initial_l0s_size - final_l0s_size) * 1.25 # If we find that compaction is using more memory, this may indicate a regression assert compaction_mapped_rss < MEMORY_ESTIMATE From 3d380acbd1dd878b908df860a541ce77bd4506f3 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Mon, 25 Nov 2024 14:43:32 +0000 Subject: [PATCH 05/34] Bump default Debian version to Bookworm everywhere (#9863) ## Problem We have a couple of CI workflows that still run on Debian Bullseye, and the default Debian version in images is Bullseye as well (we explicitly set building on Bookworm) ## Summary of changes - Run `pgbench-pgvector` on Bookworm (fix a couple of packages) - Run `trigger_bench_on_ec2_machine_in_eu_central_1` on Bookworm - Change default `DEBIAN_VERSION` in Dockerfiles to Bookworm - Make `pinned` docker tag an alias to `pinned-bookworm` --- .github/workflows/benchmarking.yml | 14 +++++++------- .github/workflows/build-build-tools-image.yml | 2 +- .github/workflows/periodic_pagebench.yml | 2 +- .github/workflows/pin-build-tools-image.yml | 2 +- Dockerfile | 2 +- build-tools.Dockerfile | 2 +- compute/compute-node.Dockerfile | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index 2ad1ee0a42..ea8fee80c2 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -541,7 +541,7 @@ jobs: runs-on: ${{ matrix.RUNNER }} container: - image: neondatabase/build-tools:pinned + image: neondatabase/build-tools:pinned-bookworm credentials: username: ${{ secrets.NEON_DOCKERHUB_USERNAME }} password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }} @@ -558,12 +558,12 @@ jobs: arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g') cd /home/nonroot - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg110+1_${arch}.deb" - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb" - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg110+1_${arch}.deb" - dpkg -x libpq5_17.2-1.pgdg110+1_${arch}.deb pg - dpkg -x postgresql-16_16.6-1.pgdg110+1_${arch}.deb pg - dpkg -x postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb pg + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg120+1_${arch}.deb" + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb" + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg120+1_${arch}.deb" + dpkg -x libpq5_17.2-1.pgdg120+1_${arch}.deb pg + dpkg -x postgresql-16_16.6-1.pgdg120+1_${arch}.deb pg + dpkg -x postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb pg mkdir -p /tmp/neon/pg_install/v16/bin ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 9e7be76901..93da86a353 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -117,7 +117,7 @@ jobs: - name: Create multi-arch image env: - DEFAULT_DEBIAN_VERSION: bullseye + DEFAULT_DEBIAN_VERSION: bookworm IMAGE_TAG: ${{ needs.check-image.outputs.tag }} run: | for debian_version in bullseye bookworm; do diff --git 
a/.github/workflows/periodic_pagebench.yml b/.github/workflows/periodic_pagebench.yml index 1cce348ae2..6b98bc873f 100644 --- a/.github/workflows/periodic_pagebench.yml +++ b/.github/workflows/periodic_pagebench.yml @@ -29,7 +29,7 @@ jobs: trigger_bench_on_ec2_machine_in_eu_central_1: runs-on: [ self-hosted, small ] container: - image: neondatabase/build-tools:pinned + image: neondatabase/build-tools:pinned-bookworm credentials: username: ${{ secrets.NEON_DOCKERHUB_USERNAME }} password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }} diff --git a/.github/workflows/pin-build-tools-image.yml b/.github/workflows/pin-build-tools-image.yml index c196d07d3e..5b43d97de6 100644 --- a/.github/workflows/pin-build-tools-image.yml +++ b/.github/workflows/pin-build-tools-image.yml @@ -94,7 +94,7 @@ jobs: - name: Tag build-tools with `${{ env.TO_TAG }}` in Docker Hub, ECR, and ACR env: - DEFAULT_DEBIAN_VERSION: bullseye + DEFAULT_DEBIAN_VERSION: bookworm run: | for debian_version in bullseye bookworm; do tags=() diff --git a/Dockerfile b/Dockerfile index 785dd4598e..e888efbae2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ ARG IMAGE=build-tools ARG TAG=pinned ARG DEFAULT_PG_VERSION=17 ARG STABLE_PG_VERSION=16 -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim # Build Postgres diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 24e5bbf46f..4f491afec5 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -1,4 +1,4 @@ -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm FROM debian:bookworm-slim AS pgcopydb_builder ARG DEBIAN_VERSION diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 7c21c67a0a..2fcd9985bc 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -3,7 +3,7 @@ ARG REPOSITORY=neondatabase ARG IMAGE=build-tools ARG TAG=pinned ARG BUILD_TAG -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim ######################################################################################### From 77630e5408c0771b0e1db020f8d81a7be9728391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Mon, 25 Nov 2024 15:59:12 +0100 Subject: [PATCH 06/34] Address beta clippy lint needless_lifetimes (#9877) The 1.82.0 version of Rust will be stable soon, let's get the clippy lint fixes in before the compiler version upgrade. 
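For example, `impl<'a> Debug for DS<'a>` can be spelled `impl Debug for DS<'_>`, which is exactly the shape of the changes below.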
--- libs/vm_monitor/src/cgroup.rs | 4 ++-- .../virtual_file/owned_buffers_io/aligned_buffer/slice.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/vm_monitor/src/cgroup.rs b/libs/vm_monitor/src/cgroup.rs index 3223765016..1d70cedcf9 100644 --- a/libs/vm_monitor/src/cgroup.rs +++ b/libs/vm_monitor/src/cgroup.rs @@ -218,7 +218,7 @@ impl MemoryStatus { fn debug_slice(slice: &[Self]) -> impl '_ + Debug { struct DS<'a>(&'a [MemoryStatus]); - impl<'a> Debug for DS<'a> { + impl Debug for DS<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("[MemoryStatus]") .field( @@ -233,7 +233,7 @@ impl MemoryStatus { struct Fields<'a, F>(&'a [MemoryStatus], F); - impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> { + impl<F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'_, F> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_list().entries(self.0.iter().map(&self.1)).finish() } diff --git a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs index 6cecf34c1c..1952b82578 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs @@ -19,7 +19,7 @@ impl<'a, const N: usize, const A: usize> AlignedSlice<'a, N, ConstAlign<A>> { } } -impl<'a, const N: usize, A: Alignment> Deref for AlignedSlice<'a, N, A> { +impl<const N: usize, A: Alignment> Deref for AlignedSlice<'_, N, A> { type Target = [u8; N]; fn deref(&self) -> &Self::Target { @@ -27,13 +27,13 @@ impl<'a, const N: usize, A: Alignment> Deref for AlignedSlice<'a, N, A> { } } -impl<'a, const N: usize, A: Alignment> DerefMut for AlignedSlice<'a, N, A> { +impl<const N: usize, A: Alignment> DerefMut for AlignedSlice<'_, N, A> { fn deref_mut(&mut self) -> &mut Self::Target { self.buf } } -impl<'a, const N: usize, A: Alignment> AsRef<[u8; N]> for AlignedSlice<'a, N, A> { +impl<const N: usize, A: Alignment> AsRef<[u8; N]> for AlignedSlice<'_, N, A> { fn as_ref(&self) -> &[u8; N] { self.buf } From 441612c1ce11f92d0fad1226be0e42f5500c2c46 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Mon, 25 Nov 2024 17:21:52 +0200 Subject: [PATCH 07/34] Prefetch on macos (#9875) ## Problem Prefetch is disabled on macOS because `posix_fadvise` is not available. But Neon's prefetch does not use this function, and for testing on macOS it is very convenient to have prefetch available. ## Summary of changes Define `USE_PREFETCH` in Makefile. --------- Co-authored-by: Konstantin Knizhnik --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 8e3b755112..dc67b87239 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ ifeq ($(UNAME_S),Linux) # Seccomp BPF is only available for Linux PG_CONFIGURE_OPTS += --with-libseccomp else ifeq ($(UNAME_S),Darwin) + PG_CFLAGS += -DUSE_PREFETCH ifndef DISABLE_HOMEBREW # macOS with brew-installed openssl requires explicit paths # It can be configured with OPENSSL_PREFIX variable From 5c2356988e8beea0418e1292da2adf4b988340ad Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 25 Nov 2024 16:52:39 +0100 Subject: [PATCH 08/34] page_service: add benchmark for batching (#9820) This PR adds two benchmarks to demonstrate the effect of server-side getpage request batching added in https://github.com/neondatabase/neon/pull/9321. For the CPU usage, I found that the `prometheus` crate's built-in CPU usage metric accounts the seconds at integer granularity. That's not enough if you reduce the target benchmark runtime for local iteration.
So, add a new `libmetrics` metric and report that. The benchmarks are disabled because [on our benchmark nodes, timer resolution isn't high enough](https://neondb.slack.com/archives/C059ZC138NR/p1732264223207449). They work (no statement about quality) on my bare-metal devbox. They will be refined and enabled once we find a fix. Candidates at time of writing are: - https://github.com/neondatabase/neon/pull/9822 - https://github.com/neondatabase/neon/pull/9851 Refs: - Epic: https://github.com/neondatabase/neon/issues/9376 - Extracted from https://github.com/neondatabase/neon/pull/9792 --- libs/metrics/src/more_process_metrics.rs | 40 ++- test_runner/performance/README.md | 3 +- .../test_pageserver_getpage_merge.py | 307 ++++++++++++++++++ 3 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 test_runner/performance/pageserver/test_pageserver_getpage_merge.py diff --git a/libs/metrics/src/more_process_metrics.rs b/libs/metrics/src/more_process_metrics.rs index 920724fdec..13a745e031 100644 --- a/libs/metrics/src/more_process_metrics.rs +++ b/libs/metrics/src/more_process_metrics.rs @@ -2,14 +2,28 @@ // This module has heavy inspiration from the prometheus crate's `process_collector.rs`. +use once_cell::sync::Lazy; +use prometheus::Gauge; + use crate::UIntGauge; pub struct Collector { descs: Vec<prometheus::core::Desc>, vmlck: crate::UIntGauge, + cpu_seconds_highres: Gauge, } -const NMETRICS: usize = 1; +const NMETRICS: usize = 2; + +static CLK_TCK_F64: Lazy<f64> = Lazy::new(|| { + let long = unsafe { libc::sysconf(libc::_SC_CLK_TCK) }; + if long == -1 { + panic!("sysconf(_SC_CLK_TCK) failed"); + } + let convertible_to_f64: i32 = + i32::try_from(long).expect("sysconf(_SC_CLK_TCK) is larger than i32"); + convertible_to_f64 as f64 +}); impl prometheus::core::Collector for Collector { fn desc(&self) -> Vec<&prometheus::core::Desc> { @@ -27,6 +41,12 @@ impl prometheus::core::Collector for Collector { mfs.extend(self.vmlck.collect()) } } + if let Ok(stat) = myself.stat() { + let cpu_seconds = stat.utime + stat.stime; + self.cpu_seconds_highres + .set(cpu_seconds as f64 / *CLK_TCK_F64); + mfs.extend(self.cpu_seconds_highres.collect()); + } mfs } } @@ -43,7 +63,23 @@ impl Collector { .cloned(), ); - Self { descs, vmlck } + let cpu_seconds_highres = Gauge::new( + "libmetrics_process_cpu_seconds_highres", + "Total user and system CPU time spent in seconds.\ Sub-second resolution, hence better than `process_cpu_seconds_total`.", + ) + .unwrap(); + descs.extend( + prometheus::core::Collector::desc(&cpu_seconds_highres) + .into_iter() + .cloned(), + ); + + Self { + descs, + vmlck, + cpu_seconds_highres, + } } } diff --git a/test_runner/performance/README.md b/test_runner/performance/README.md index 70d75a6dcf..85096d3770 100644 --- a/test_runner/performance/README.md +++ b/test_runner/performance/README.md @@ -15,6 +15,7 @@ Some handy pytest flags for local development: - `-k` selects a test to run - `--timeout=0` disables our default timeout of 300s (see `setup.cfg`) - `--preserve-database-files` to skip cleanup +- `--out-dir` to produce a JSON with the recorded test metrics # What performance tests do we have and how we run them @@ -36,6 +37,6 @@ All tests run only once. Usually to obtain more consistent performance numbers, ## Results collection -Local test results for main branch, and results of daily performance tests, are stored in a neon project deployed in production environment. There is a Grafana dashboard that visualizes the results.
Here is the [dashboard](https://observer.zenith.tech/d/DGKBm9Jnz/perf-test-results?orgId=1). The main problem with it is the unavailability to point at particular commit, though the data for that is available in the database. Needs some tweaking from someone who knows Grafana tricks. +Local test results for main branch, and results of daily performance tests, are stored in a [neon project](https://console.neon.tech/app/projects/withered-sky-69117821) deployed in production environment. There is a Grafana dashboard that visualizes the results. Here is the [dashboard](https://observer.zenith.tech/d/DGKBm9Jnz/perf-test-results?orgId=1). The main problem with it is the unavailability to point at particular commit, though the data for that is available in the database. Needs some tweaking from someone who knows Grafana tricks. There is also an inconsistency in test naming. Test name should be the same across platforms, and results can be differentiated by the platform field. But currently, platform is sometimes included in test name because of the way how parametrization works in pytest. I.e. there is a platform switch in the dashboard with neon-local-ci and neon-staging variants. I.e. some tests under neon-local-ci value for a platform switch are displayed as `Test test_runner/performance/test_bulk_insert.py::test_bulk_insert[vanilla]` and `Test test_runner/performance/test_bulk_insert.py::test_bulk_insert[neon]` which is highly confusing. diff --git a/test_runner/performance/pageserver/test_pageserver_getpage_merge.py b/test_runner/performance/pageserver/test_pageserver_getpage_merge.py new file mode 100644 index 0000000000..34cce9900b --- /dev/null +++ b/test_runner/performance/pageserver/test_pageserver_getpage_merge.py @@ -0,0 +1,307 @@ +import dataclasses +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker +from fixtures.log_helper import log +from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, wait_for_last_flush_lsn +from fixtures.utils import humantime_to_ms + +TARGET_RUNTIME = 60 + + +@pytest.mark.skip("See https://github.com/neondatabase/neon/pull/9820#issue-2675856095") +@pytest.mark.parametrize( + "tablesize_mib, batch_timeout, target_runtime, effective_io_concurrency, readhead_buffer_size, name", + [ + # the next 4 cases demonstrate how not-batchable workloads suffer from batching timeout + (50, None, TARGET_RUNTIME, 1, 128, "not batchable no batching"), + (50, "10us", TARGET_RUNTIME, 1, 128, "not batchable 10us timeout"), + (50, "1ms", TARGET_RUNTIME, 1, 128, "not batchable 1ms timeout"), + # the next 4 cases demonstrate how batchable workloads benefit from batching + (50, None, TARGET_RUNTIME, 100, 128, "batchable no batching"), + (50, "10us", TARGET_RUNTIME, 100, 128, "batchable 10us timeout"), + (50, "100us", TARGET_RUNTIME, 100, 128, "batchable 100us timeout"), + (50, "1ms", TARGET_RUNTIME, 100, 128, "batchable 1ms timeout"), + ], +) +def test_getpage_merge_smoke( + neon_env_builder: NeonEnvBuilder, + zenbenchmark: NeonBenchmarker, + tablesize_mib: int, + batch_timeout: str | None, + target_runtime: int, + effective_io_concurrency: int, + readhead_buffer_size: int, + name: str, +): + """ + Do a bunch of sequential scans and ensure that the pageserver does some merging. 
+ """ + + # + # record perf-related parameters as metrics to simplify processing of results + # + params: dict[str, tuple[float | int, dict[str, Any]]] = {} + + params.update( + { + "tablesize_mib": (tablesize_mib, {"unit": "MiB"}), + "batch_timeout": ( + -1 if batch_timeout is None else 1e3 * humantime_to_ms(batch_timeout), + {"unit": "us"}, + ), + # target_runtime is just a polite ask to the workload to run for this long + "effective_io_concurrency": (effective_io_concurrency, {}), + "readhead_buffer_size": (readhead_buffer_size, {}), + # name is not a metric + } + ) + + log.info("params: %s", params) + + for param, (value, kwargs) in params.items(): + zenbenchmark.record( + param, + metric_value=value, + unit=kwargs.pop("unit", ""), + report=MetricReport.TEST_PARAM, + **kwargs, + ) + + # + # Setup + # + + env = neon_env_builder.init_start() + ps_http = env.pageserver.http_client() + endpoint = env.endpoints.create_start("main") + conn = endpoint.connect() + cur = conn.cursor() + + cur.execute("SET max_parallel_workers_per_gather=0") # disable parallel backends + cur.execute(f"SET effective_io_concurrency={effective_io_concurrency}") + cur.execute( + f"SET neon.readahead_buffer_size={readhead_buffer_size}" + ) # this is the current default value, but let's hard-code that + + cur.execute("CREATE EXTENSION IF NOT EXISTS neon;") + cur.execute("CREATE EXTENSION IF NOT EXISTS neon_test_utils;") + + log.info("Filling the table") + cur.execute("CREATE TABLE t (data char(1000)) with (fillfactor=10)") + tablesize = tablesize_mib * 1024 * 1024 + npages = tablesize // (8 * 1024) + cur.execute("INSERT INTO t SELECT generate_series(1, %s)", (npages,)) + # TODO: can we force postgres to do sequential scans? + + # + # Run the workload, collect `Metrics` before and after, calculate difference, normalize. 
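As a side note on the parameter recording above: `batch_timeout` is stored in microseconds, with -1 standing in for the no-batching baseline. A sketch of the expected values, assuming `humantime_to_ms` parses humantime strings into fractional milliseconds:

```python
# Hypothetical check of the batch_timeout normalization, assuming
# humantime_to_ms("100us") == 0.1 (i.e. milliseconds as a float).
from fixtures.utils import humantime_to_ms

def batch_timeout_us(batch_timeout: str | None) -> float:
    return -1 if batch_timeout is None else 1e3 * humantime_to_ms(batch_timeout)

assert batch_timeout_us(None) == -1
assert batch_timeout_us("100us") == 100.0
assert batch_timeout_us("1ms") == 1000.0
```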
+ # + + @dataclass + class Metrics: + time: float + pageserver_getpage_count: float + pageserver_vectored_get_count: float + compute_getpage_count: float + pageserver_cpu_seconds_total: float + + def __sub__(self, other: "Metrics") -> "Metrics": + return Metrics( + time=self.time - other.time, + pageserver_getpage_count=self.pageserver_getpage_count + - other.pageserver_getpage_count, + pageserver_vectored_get_count=self.pageserver_vectored_get_count + - other.pageserver_vectored_get_count, + compute_getpage_count=self.compute_getpage_count - other.compute_getpage_count, + pageserver_cpu_seconds_total=self.pageserver_cpu_seconds_total + - other.pageserver_cpu_seconds_total, + ) + + def normalize(self, by) -> "Metrics": + return Metrics( + time=self.time / by, + pageserver_getpage_count=self.pageserver_getpage_count / by, + pageserver_vectored_get_count=self.pageserver_vectored_get_count / by, + compute_getpage_count=self.compute_getpage_count / by, + pageserver_cpu_seconds_total=self.pageserver_cpu_seconds_total / by, + ) + + def get_metrics() -> Metrics: + with conn.cursor() as cur: + cur.execute( + "select value from neon_perf_counters where metric='getpage_wait_seconds_count';" + ) + compute_getpage_count = cur.fetchall()[0][0] + pageserver_metrics = ps_http.get_metrics() + return Metrics( + time=time.time(), + pageserver_getpage_count=pageserver_metrics.query_one( + "pageserver_smgr_query_seconds_count", {"smgr_query_type": "get_page_at_lsn"} + ).value, + pageserver_vectored_get_count=pageserver_metrics.query_one( + "pageserver_get_vectored_seconds_count", {"task_kind": "PageRequestHandler"} + ).value, + compute_getpage_count=compute_getpage_count, + pageserver_cpu_seconds_total=pageserver_metrics.query_one( + "libmetrics_process_cpu_seconds_highres" + ).value, + ) + + def workload() -> Metrics: + start = time.time() + iters = 0 + while time.time() - start < target_runtime or iters < 2: + log.info("Seqscan %d", iters) + if iters == 1: + # round zero for warming up + before = get_metrics() + cur.execute( + "select clear_buffer_cache()" + ) # TODO: what about LFC? 
doesn't matter right now because LFC isn't enabled by default in tests + cur.execute("select sum(data::bigint) from t") + assert cur.fetchall()[0][0] == npages * (npages + 1) // 2 + iters += 1 + after = get_metrics() + return (after - before).normalize(iters - 1) + + env.pageserver.patch_config_toml_nonrecursive({"server_side_batch_timeout": batch_timeout}) + env.pageserver.restart() + metrics = workload() + + log.info("Results: %s", metrics) + + # + # Sanity-checks on the collected data + # + # assert that getpage counts roughly match between compute and ps + assert metrics.pageserver_getpage_count == pytest.approx( + metrics.compute_getpage_count, rel=0.01 + ) + + # + # Record the results + # + + for metric, value in dataclasses.asdict(metrics).items(): + zenbenchmark.record(f"counters.{metric}", value, unit="", report=MetricReport.TEST_PARAM) + + zenbenchmark.record( + "perfmetric.batching_factor", + metrics.pageserver_getpage_count / metrics.pageserver_vectored_get_count, + unit="", + report=MetricReport.HIGHER_IS_BETTER, + ) + + +@pytest.mark.skip("See https://github.com/neondatabase/neon/pull/9820#issue-2675856095") +@pytest.mark.parametrize( + "batch_timeout", [None, "10us", "20us", "50us", "100us", "200us", "500us", "1ms"] +) +def test_timer_precision( + neon_env_builder: NeonEnvBuilder, + zenbenchmark: NeonBenchmarker, + pg_bin: PgBin, + batch_timeout: str | None, +): + """ + Determine the batching timeout precision (mean latency) and tail latency impact. + + The baseline is `None`; an ideal batching timeout implementation would increase + the mean latency by exactly `batch_timeout`. + + That is not the case with the current implementation, will be addressed in future changes. + """ + + # + # Setup + # + + def patch_ps_config(ps_config): + ps_config["server_side_batch_timeout"] = batch_timeout + + neon_env_builder.pageserver_config_override = patch_ps_config + + env = neon_env_builder.init_start() + endpoint = env.endpoints.create_start("main") + conn = endpoint.connect() + cur = conn.cursor() + + cur.execute("SET max_parallel_workers_per_gather=0") # disable parallel backends + cur.execute("SET effective_io_concurrency=1") + + cur.execute("CREATE EXTENSION IF NOT EXISTS neon;") + cur.execute("CREATE EXTENSION IF NOT EXISTS neon_test_utils;") + + log.info("Filling the table") + cur.execute("CREATE TABLE t (data char(1000)) with (fillfactor=10)") + tablesize = 50 * 1024 * 1024 + npages = tablesize // (8 * 1024) + cur.execute("INSERT INTO t SELECT generate_series(1, %s)", (npages,)) + # TODO: can we force postgres to do sequential scans? 
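To make the docstring's notion of an ideal implementation concrete with hypothetical numbers: if the `batch_timeout=None` baseline reports a `latency_mean` of 0.20 ms, a perfect `100us` batching timeout would report roughly 0.30 ms; any larger shift in the mean, or inflation of the tail percentiles beyond that constant, points at implementation overhead.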
+ + cur.close() + conn.close() + + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + endpoint.stop() + + for sk in env.safekeepers: + sk.stop() + + # + # Run single-threaded pagebench (TODO: dedup with other benchmark code) + # + + env.pageserver.allowed_errors.append( + # https://github.com/neondatabase/neon/issues/6925 + r".*query handler for.*pagestream.*failed: unexpected message: CopyFail during COPY.*" + ) + + ps_http = env.pageserver.http_client() + + cmd = [ + str(env.neon_binpath / "pagebench"), + "get-page-latest-lsn", + "--mgmt-api-endpoint", + ps_http.base_url, + "--page-service-connstring", + env.pageserver.connstr(password=None), + "--num-clients", + "1", + "--runtime", + "10s", + ] + log.info(f"command: {' '.join(cmd)}") + basepath = pg_bin.run_capture(cmd, with_command_header=False) + results_path = Path(basepath + ".stdout") + log.info(f"Benchmark results at: {results_path}") + + with open(results_path) as f: + results = json.load(f) + log.info(f"Results:\n{json.dumps(results, sort_keys=True, indent=2)}") + + total = results["total"] + + metric = "latency_mean" + zenbenchmark.record( + metric, + metric_value=humantime_to_ms(total[metric]), + unit="ms", + report=MetricReport.LOWER_IS_BETTER, + ) + + metric = "latency_percentiles" + for k, v in total[metric].items(): + zenbenchmark.record( + f"{metric}.{k}", + metric_value=humantime_to_ms(v), + unit="ms", + report=MetricReport.LOWER_IS_BETTER, + ) From 7a2f0ed8d452cd05bb1b9a85f1cda6e00ead1f85 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Mon, 25 Nov 2024 17:29:28 +0000 Subject: [PATCH 09/34] safekeeper: lift decoding and interpretation of WAL to the safekeeper (#9746) ## Problem For any given tenant shard, pageservers receive all of the tenant's WAL from the safekeeper. This soft-blocks us from using larger shard counts due to bandwidth concerns and CPU overhead of filtering out the records. ## Summary of changes This PR lifts the decoding and interpretation of WAL from the pageserver into the safekeeper. A customised PG replication protocol is used where instead of sending raw WAL, the safekeeper sends filtered, interpreted records. The receiver drives the protocol selection, so, on the pageserver side, usage of the new protocol is gated by a new pageserver config: `wal_receiver_protocol`. More granularly the changes are: 1. Optionally inject the protocol and shard identity into the arguments used for starting replication 2. On the safekeeper side, implement a new wal sending primitive which decodes and interprets records before sending them over 3. On the pageserver side, implement the ingestion of this new replication message type. It's very similar to what we already have for raw wal (minus decoding and interpreting). ## Notes * This PR currently uses my [branch of rust-postgres](https://github.com/neondatabase/rust-postgres/tree/vlad/interpreted-wal-record-replication-support) which includes the deserialization logic for the new replication message type. PR for that is open [here](https://github.com/neondatabase/rust-postgres/pull/32). * This PR contains changes for both pageservers and safekeepers. It's safe to merge because the new protocol is disabled by default on the pageserver side. We can gradually start enabling it in subsequent releases. 
* CI tests are running on https://github.com/neondatabase/neon/pull/9747 ## Links Related: https://github.com/neondatabase/neon/issues/9336 Epic: https://github.com/neondatabase/neon/issues/9329 --- Cargo.lock | 11 +- libs/pageserver_api/src/config.rs | 7 +- libs/pq_proto/src/lib.rs | 36 +++++ libs/utils/Cargo.toml | 1 + libs/utils/src/postgres_client.rs | 95 +++++++++-- libs/wal_decoder/src/models.rs | 12 ++ libs/wal_decoder/src/serialized_batch.rs | 15 +- pageserver/src/bin/pageserver.rs | 1 + pageserver/src/config.rs | 5 + pageserver/src/pgdatadir_mapping.rs | 7 +- pageserver/src/tenant/timeline.rs | 3 +- pageserver/src/tenant/timeline/walreceiver.rs | 2 + .../walreceiver/connection_manager.rs | 40 +++-- .../walreceiver/walreceiver_connection.rs | 149 ++++++++++++++++-- safekeeper/Cargo.toml | 2 + safekeeper/src/handler.rs | 79 ++++++++++ safekeeper/src/lib.rs | 2 + safekeeper/src/recovery.rs | 13 +- safekeeper/src/send_interpreted_wal.rs | 121 ++++++++++++++ safekeeper/src/send_wal.rs | 122 +++++++++++--- safekeeper/src/wal_reader_stream.rs | 149 ++++++++++++++++++ .../performance/test_sharded_ingest.py | 50 +++++- test_runner/regress/test_compaction.py | 7 +- test_runner/regress/test_crafted_wal_end.py | 9 +- test_runner/regress/test_subxacts.py | 12 +- .../regress/test_wal_acceptor_async.py | 6 +- 26 files changed, 870 insertions(+), 86 deletions(-) create mode 100644 safekeeper/src/send_interpreted_wal.rs create mode 100644 safekeeper/src/wal_reader_stream.rs diff --git a/Cargo.lock b/Cargo.lock index 98d2e0864a..c1a14210de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4133,7 +4133,7 @@ dependencies = [ [[package]] name = "postgres" version = "0.19.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" dependencies = [ "bytes", "fallible-iterator", @@ -4146,7 +4146,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" dependencies = [ "base64 0.20.0", "byteorder", @@ -4165,7 +4165,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" dependencies = [ "bytes", "fallible-iterator", @@ -5364,6 +5364,7 @@ dependencies = [ "itertools 0.10.5", "metrics", "once_cell", + "pageserver_api", "parking_lot 0.12.1", "postgres", "postgres-protocol", @@ -5395,6 +5396,7 @@ dependencies = [ "tracing-subscriber", "url", "utils", + "wal_decoder", "walproposer", "workspace_hack", ] @@ -6466,7 +6468,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" dependencies = [ "async-trait", "byteorder", @@ -7021,6 +7023,7 @@ dependencies = [ "serde_assert", "serde_json", "serde_path_to_error", + "serde_with", "signal-hook", "strum", "strum_macros", diff --git 
a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 7666728427..0abca5cdc2 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -18,7 +18,7 @@ use std::{ str::FromStr, time::Duration, }; -use utils::logging::LogFormat; +use utils::{logging::LogFormat, postgres_client::PostgresClientProtocol}; use crate::models::ImageCompressionAlgorithm; use crate::models::LsnLease; @@ -120,6 +120,7 @@ pub struct ConfigToml { pub no_sync: Option, #[serde(with = "humantime_serde")] pub server_side_batch_timeout: Option, + pub wal_receiver_protocol: PostgresClientProtocol, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -330,6 +331,9 @@ pub mod defaults { pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; pub const DEFAULT_SERVER_SIDE_BATCH_TIMEOUT: Option<&str> = None; + + pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = + utils::postgres_client::PostgresClientProtocol::Vanilla; } impl Default for ConfigToml { @@ -418,6 +422,7 @@ impl Default for ConfigToml { .map(|duration| humantime::parse_duration(duration).unwrap()), tenant_config: TenantConfigToml::default(), no_sync: None, + wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, } } } diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 6c40968496..b7871ab01f 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -562,6 +562,9 @@ pub enum BeMessage<'a> { options: &'a [&'a str], }, KeepAlive(WalSndKeepAlive), + /// Batch of interpreted, shard filtered WAL records, + /// ready for the pageserver to ingest + InterpretedWalRecords(InterpretedWalRecordsBody<'a>), } /// Common shorthands. @@ -672,6 +675,25 @@ pub struct WalSndKeepAlive { pub request_reply: bool, } +/// Batch of interpreted WAL records used in the interpreted +/// safekeeper to pageserver protocol. +/// +/// Note that the pageserver uses the RawInterpretedWalRecordsBody +/// counterpart of this from the neondatabase/rust-postgres repo. +/// If you're changing this struct, you likely need to change its +/// twin as well. +#[derive(Debug)] +pub struct InterpretedWalRecordsBody<'a> { + /// End of raw WAL in [`Self::data`] + pub streaming_lsn: u64, + /// Current end of WAL on the server + pub commit_lsn: u64, + /// Start LSN of the next record in PG WAL. + /// Is 0 if the portion of PG WAL did not contain any records. + pub next_record_lsn: u64, + pub data: &'a [u8], +} + pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]); // single text column @@ -996,6 +1018,20 @@ impl BeMessage<'_> { Ok(()) })? } + + BeMessage::InterpretedWalRecords(rec) => { + // We use the COPY_DATA_TAG for our custom message + // since this tag is interpreted as raw bytes. 
+ buf.put_u8(b'd'); + write_body(buf, |buf| { + buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol + // dependency + buf.put_u64(rec.streaming_lsn); + buf.put_u64(rec.commit_lsn); + buf.put_u64(rec.next_record_lsn); + buf.put_slice(rec.data); + }); + } } Ok(()) } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 4aad0aee2c..f440b81d8f 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -33,6 +33,7 @@ pprof.workspace = true regex.workspace = true routerify.workspace = true serde.workspace = true +serde_with.workspace = true serde_json.workspace = true signal-hook.workspace = true thiserror.workspace = true diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index dba74f5b0b..3073bbde4c 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -7,29 +7,94 @@ use postgres_connection::{parse_host_port, PgConnectionConfig}; use crate::id::TenantTimelineId; +/// Postgres client protocol types +#[derive( + Copy, + Clone, + PartialEq, + Eq, + strum_macros::EnumString, + strum_macros::Display, + serde_with::DeserializeFromStr, + serde_with::SerializeDisplay, + Debug, +)] +#[strum(serialize_all = "kebab-case")] +#[repr(u8)] +pub enum PostgresClientProtocol { + /// Usual Postgres replication protocol + Vanilla, + /// Custom shard-aware protocol that replicates interpreted records. + /// Used to send wal from safekeeper to pageserver. + Interpreted, +} + +impl TryFrom for PostgresClientProtocol { + type Error = u8; + + fn try_from(value: u8) -> Result { + Ok(match value { + v if v == (PostgresClientProtocol::Vanilla as u8) => PostgresClientProtocol::Vanilla, + v if v == (PostgresClientProtocol::Interpreted as u8) => { + PostgresClientProtocol::Interpreted + } + x => return Err(x), + }) + } +} + +pub struct ConnectionConfigArgs<'a> { + pub protocol: PostgresClientProtocol, + + pub ttid: TenantTimelineId, + pub shard_number: Option, + pub shard_count: Option, + pub shard_stripe_size: Option, + + pub listen_pg_addr_str: &'a str, + + pub auth_token: Option<&'a str>, + pub availability_zone: Option<&'a str>, +} + +impl<'a> ConnectionConfigArgs<'a> { + fn options(&'a self) -> Vec { + let mut options = vec![ + "-c".to_owned(), + format!("timeline_id={}", self.ttid.timeline_id), + format!("tenant_id={}", self.ttid.tenant_id), + format!("protocol={}", self.protocol as u8), + ]; + + if self.shard_number.is_some() { + assert!(self.shard_count.is_some()); + assert!(self.shard_stripe_size.is_some()); + + options.push(format!("shard_count={}", self.shard_count.unwrap())); + options.push(format!("shard_number={}", self.shard_number.unwrap())); + options.push(format!( + "shard_stripe_size={}", + self.shard_stripe_size.unwrap() + )); + } + + options + } +} + /// Create client config for fetching WAL from safekeeper on particular timeline. /// listen_pg_addr_str is in form host:\[port\]. 
pub fn wal_stream_connection_config( - TenantTimelineId { - tenant_id, - timeline_id, - }: TenantTimelineId, - listen_pg_addr_str: &str, - auth_token: Option<&str>, - availability_zone: Option<&str>, + args: ConnectionConfigArgs, ) -> anyhow::Result { let (host, port) = - parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?; + parse_host_port(args.listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?; let port = port.unwrap_or(5432); let mut connstr = PgConnectionConfig::new_host_port(host, port) - .extend_options([ - "-c".to_owned(), - format!("timeline_id={}", timeline_id), - format!("tenant_id={}", tenant_id), - ]) - .set_password(auth_token.map(|s| s.to_owned())); + .extend_options(args.options()) + .set_password(args.auth_token.map(|s| s.to_owned())); - if let Some(availability_zone) = availability_zone { + if let Some(availability_zone) = args.availability_zone { connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]); } diff --git a/libs/wal_decoder/src/models.rs b/libs/wal_decoder/src/models.rs index c69f8c869a..7ac425cb5f 100644 --- a/libs/wal_decoder/src/models.rs +++ b/libs/wal_decoder/src/models.rs @@ -65,6 +65,18 @@ pub struct InterpretedWalRecord { pub xid: TransactionId, } +impl InterpretedWalRecord { + /// Checks if the WAL record is empty + /// + /// An empty interpreted WAL record has no data or metadata and does not have to be sent to the + /// pageserver. + pub fn is_empty(&self) -> bool { + self.batch.is_empty() + && self.metadata_record.is_none() + && matches!(self.flush_uncommitted, FlushUncommittedRecords::No) + } +} + /// The interpreted part of the Postgres WAL record which requires metadata /// writes to the underlying storage engine. #[derive(Serialize, Deserialize)] diff --git a/libs/wal_decoder/src/serialized_batch.rs b/libs/wal_decoder/src/serialized_batch.rs index 9c0708ebbe..41294da7a0 100644 --- a/libs/wal_decoder/src/serialized_batch.rs +++ b/libs/wal_decoder/src/serialized_batch.rs @@ -496,11 +496,16 @@ impl SerializedValueBatch { } } - /// Checks if the batch is empty - /// - /// A batch is empty when it contains no serialized values. - /// Note that it may still contain observed values. + /// Checks if the batch contains any serialized or observed values pub fn is_empty(&self) -> bool { + !self.has_data() && self.metadata.is_empty() + } + + /// Checks if the batch contains data + /// + /// Note that if this returns false, it may still contain observed values or + /// a metadata record. + pub fn has_data(&self) -> bool { let empty = self.raw.is_empty(); if cfg!(debug_assertions) && empty { @@ -510,7 +515,7 @@ impl SerializedValueBatch { .all(|meta| matches!(meta, ValueMeta::Observed(_)))); } - empty + !empty } /// Returns the number of values serialized in the batch diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 033a9a4619..a8c2c2e992 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -126,6 +126,7 @@ fn main() -> anyhow::Result<()> { // after setting up logging, log the effective IO engine choice and read path implementations info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode"); + info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol"); // The tenants directory contains all the pageserver local disk state. // Create if not exists and make sure all the contents are durable before proceeding. 
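To make the reworked connection-config API concrete, here is a minimal sketch of a caller building a safekeeper connection string for the interpreted protocol. The helper function and the shard values are illustrative assumptions, not code from this patch; the real call sites are the pageserver connection manager and the safekeeper recovery path further down.

```rust
// Hypothetical caller of the new API; shard values are placeholders.
use utils::id::TenantTimelineId;
use utils::postgres_client::{
    wal_stream_connection_config, ConnectionConfigArgs, PostgresClientProtocol,
};

fn interpreted_connstr_for_shard(
    ttid: TenantTimelineId,
    safekeeper_addr: &str,
) -> anyhow::Result<postgres_connection::PgConnectionConfig> {
    let args = ConnectionConfigArgs {
        protocol: PostgresClientProtocol::Interpreted,
        ttid,
        // The interpreted protocol carries the shard identity; the vanilla
        // protocol passes None for all three shard fields.
        shard_number: Some(0),
        shard_count: Some(8),
        shard_stripe_size: Some(32768),
        listen_pg_addr_str: safekeeper_addr,
        auth_token: None,
        availability_zone: None,
    };
    // Yields the host/port config plus the `-c timeline_id=... tenant_id=...
    // protocol=... shard_*=...` options assembled by ConnectionConfigArgs::options().
    wal_stream_connection_config(args)
}
```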
diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 59ea6fb941..2cf237e72b 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,6 +14,7 @@ use remote_storage::{RemotePath, RemoteStorageConfig}; use std::env; use storage_broker::Uri; use utils::logging::SecretString; +use utils::postgres_client::PostgresClientProtocol; use once_cell::sync::OnceCell; use reqwest::Url; @@ -190,6 +191,8 @@ pub struct PageServerConf { /// Maximum amount of time for which a get page request request /// might be held up for request merging. pub server_side_batch_timeout: Option, + + pub wal_receiver_protocol: PostgresClientProtocol, } /// Token for authentication to safekeepers @@ -350,6 +353,7 @@ impl PageServerConf { server_side_batch_timeout, tenant_config, no_sync, + wal_receiver_protocol, } = config_toml; let mut conf = PageServerConf { @@ -393,6 +397,7 @@ impl PageServerConf { import_pgdata_upcall_api, import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from), import_pgdata_aws_endpoint_url, + wal_receiver_protocol, // ------------------------------------------------------------ // fields that require additional validation or custom handling diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index f4f184be5a..c491bfe650 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -1229,10 +1229,9 @@ impl<'a> DatadirModification<'a> { } pub(crate) fn has_dirty_data(&self) -> bool { - !self - .pending_data_batch + self.pending_data_batch .as_ref() - .map_or(true, |b| b.is_empty()) + .map_or(false, |b| b.has_data()) } /// Set the current lsn @@ -1408,7 +1407,7 @@ impl<'a> DatadirModification<'a> { Some(pending_batch) => { pending_batch.extend(batch); } - None if !batch.is_empty() => { + None if batch.has_data() => { self.pending_data_batch = Some(batch); } None => { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 4881be33a6..f6a06e73a7 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -2470,6 +2470,7 @@ impl Timeline { *guard = Some(WalReceiver::start( Arc::clone(self), WalReceiverConf { + protocol: self.conf.wal_receiver_protocol, wal_connect_timeout, lagging_wal_timeout, max_lsn_wal_lag, @@ -5896,7 +5897,7 @@ impl<'a> TimelineWriter<'a> { batch: SerializedValueBatch, ctx: &RequestContext, ) -> anyhow::Result<()> { - if batch.is_empty() { + if !batch.has_data() { return Ok(()); } diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 4a3a5c621b..f831f5e48a 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -38,6 +38,7 @@ use storage_broker::BrokerClientChannel; use tokio::sync::watch; use tokio_util::sync::CancellationToken; use tracing::*; +use utils::postgres_client::PostgresClientProtocol; use self::connection_manager::ConnectionManagerStatus; @@ -45,6 +46,7 @@ use super::Timeline; #[derive(Clone)] pub struct WalReceiverConf { + pub protocol: PostgresClientProtocol, /// The timeout on the connection to safekeeper for WAL streaming. pub wal_connect_timeout: Duration, /// The timeout to use to determine when the current connection is "stale" and reconnect to the other one. 
diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index de50f217d8..7a64703a30 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -36,7 +36,9 @@ use postgres_connection::PgConnectionConfig; use utils::backoff::{ exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, }; -use utils::postgres_client::wal_stream_connection_config; +use utils::postgres_client::{ + wal_stream_connection_config, ConnectionConfigArgs, PostgresClientProtocol, +}; use utils::{ id::{NodeId, TenantTimelineId}, lsn::Lsn, @@ -984,15 +986,33 @@ impl ConnectionManagerState { if info.safekeeper_connstr.is_empty() { return None; // no connection string, ignore sk } - match wal_stream_connection_config( - self.id, - info.safekeeper_connstr.as_ref(), - match &self.conf.auth_token { - None => None, - Some(x) => Some(x), + + let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol { + PostgresClientProtocol::Vanilla => { + (None, None, None) }, - self.conf.availability_zone.as_deref(), - ) { + PostgresClientProtocol::Interpreted => { + let shard_identity = self.timeline.get_shard_identity(); + ( + Some(shard_identity.number.0), + Some(shard_identity.count.0), + Some(shard_identity.stripe_size.0), + ) + } + }; + + let connection_conf_args = ConnectionConfigArgs { + protocol: self.conf.protocol, + ttid: self.id, + shard_number, + shard_count, + shard_stripe_size, + listen_pg_addr_str: info.safekeeper_connstr.as_ref(), + auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()), + availability_zone: self.conf.availability_zone.as_deref() + }; + + match wal_stream_connection_config(connection_conf_args) { Ok(connstr) => Some((*sk_id, info, connstr)), Err(e) => { error!("Failed to create wal receiver connection string from broker data of safekeeper node {}: {e:#}", sk_id); @@ -1096,6 +1116,7 @@ impl ReconnectReason { mod tests { use super::*; use crate::tenant::harness::{TenantHarness, TIMELINE_ID}; + use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL; use url::Host; fn dummy_broker_sk_timeline( @@ -1532,6 +1553,7 @@ mod tests { timeline, cancel: CancellationToken::new(), conf: WalReceiverConf { + protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, wal_connect_timeout: Duration::from_secs(1), lagging_wal_timeout: Duration::from_secs(1), max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 6ac6920d47..1a0e66ceb3 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -36,7 +36,7 @@ use crate::{ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; -use utils::{id::NodeId, lsn::Lsn}; +use utils::{bin_ser::BeSer, id::NodeId, lsn::Lsn}; use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError}; /// Status of the connection. 
@@ -291,6 +291,15 @@ pub(super) async fn handle_walreceiver_connection( connection_status.latest_connection_update = now; connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end())); } + ReplicationMessage::RawInterpretedWalRecords(raw) => { + connection_status.latest_connection_update = now; + if !raw.data().is_empty() { + connection_status.latest_wal_update = now; + } + + connection_status.commit_lsn = Some(Lsn::from(raw.commit_lsn())); + connection_status.streaming_lsn = Some(Lsn::from(raw.streaming_lsn())); + } &_ => {} }; if let Err(e) = events_sender.send(TaskStateUpdate::Progress(connection_status)) { @@ -298,7 +307,130 @@ pub(super) async fn handle_walreceiver_connection( return Ok(()); } + async fn commit( + modification: &mut DatadirModification<'_>, + uncommitted: &mut u64, + filtered: &mut u64, + ctx: &RequestContext, + ) -> anyhow::Result<()> { + WAL_INGEST + .records_committed + .inc_by(*uncommitted - *filtered); + modification.commit(ctx).await?; + *uncommitted = 0; + *filtered = 0; + Ok(()) + } + let status_update = match replication_message { + ReplicationMessage::RawInterpretedWalRecords(raw) => { + WAL_INGEST.bytes_received.inc_by(raw.data().len() as u64); + + let mut uncommitted_records = 0; + let mut filtered_records = 0; + + // This is the end LSN of the raw WAL from which the records + // were interpreted. + let streaming_lsn = Lsn::from(raw.streaming_lsn()); + tracing::debug!( + "Received WAL up to {streaming_lsn} with next_record_lsn={}", + Lsn(raw.next_record_lsn().unwrap_or(0)) + ); + + let records = Vec::::des(raw.data()).with_context(|| { + anyhow::anyhow!( + "Failed to deserialize interpreted records ending at LSN {streaming_lsn}" + ) + })?; + + // We start the modification at 0 because each interpreted record + // advances it to its end LSN. 0 is just an initialization placeholder. + let mut modification = timeline.begin_modification(Lsn(0)); + + for interpreted in records { + if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) + && uncommitted_records > 0 + { + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + + let next_record_lsn = interpreted.next_record_lsn; + let ingested = walingest + .ingest_record(interpreted, &mut modification, &ctx) + .await + .with_context(|| format!("could not ingest record at {next_record_lsn}"))?; + + if !ingested { + tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}"); + WAL_INGEST.records_filtered.inc(); + filtered_records += 1; + } + + uncommitted_records += 1; + + // FIXME: this cannot be made pausable_failpoint without fixing the + // failpoint library; in tests, the added amount of debugging will cause us + // to timeout the tests. + fail_point!("walreceiver-after-ingest"); + + // Commit every ingest_batch_size records. Even if we filtered out + // all records, we still need to call commit to advance the LSN. + if uncommitted_records >= ingest_batch_size + || modification.approx_pending_bytes() + > DatadirModification::MAX_PENDING_BYTES + { + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + } + + // Records might have been filtered out on the safekeeper side, but we still + // need to advance last record LSN on all shards. If we've not ingested the latest + // record, then set the LSN of the modification past it. This way all shards + // advance their last record LSN at the same time. 
+ let needs_last_record_lsn_advance = match raw.next_record_lsn().map(Lsn::from) { + Some(lsn) if lsn > modification.get_lsn() => { + modification.set_lsn(lsn).unwrap(); + true + } + _ => false, + }; + + if uncommitted_records > 0 || needs_last_record_lsn_advance { + // Commit any uncommitted records + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + + if !caught_up && streaming_lsn >= end_of_wal { + info!("caught up at LSN {streaming_lsn}"); + caught_up = true; + } + + tracing::debug!( + "Ingested WAL up to {streaming_lsn}. Last record LSN is {}", + timeline.get_last_record_lsn() + ); + + Some(streaming_lsn) + } + ReplicationMessage::XLogData(xlog_data) => { // Pass the WAL data to the decoder, and see if we can decode // more records as a result. @@ -316,21 +448,6 @@ pub(super) async fn handle_walreceiver_connection( let mut uncommitted_records = 0; let mut filtered_records = 0; - async fn commit( - modification: &mut DatadirModification<'_>, - uncommitted: &mut u64, - filtered: &mut u64, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - WAL_INGEST - .records_committed - .inc_by(*uncommitted - *filtered); - modification.commit(ctx).await?; - *uncommitted = 0; - *filtered = 0; - Ok(()) - } - while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? { // It is important to deal with the aligned records as lsn in getPage@LSN is // aligned and can be several bytes bigger. Without this alignment we are diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index ab77b63d54..635a9222e1 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -28,6 +28,7 @@ hyper0.workspace = true futures.workspace = true once_cell.workspace = true parking_lot.workspace = true +pageserver_api.workspace = true postgres.workspace = true postgres-protocol.workspace = true pprof.workspace = true @@ -58,6 +59,7 @@ sd-notify.workspace = true storage_broker.workspace = true tokio-stream.workspace = true utils.workspace = true +wal_decoder.workspace = true workspace_hack.workspace = true diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 3f00b69cde..cec7c3c7ee 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -2,11 +2,15 @@ //! protocol commands. use anyhow::Context; +use pageserver_api::models::ShardParameters; +use pageserver_api::shard::{ShardIdentity, ShardStripeSize}; use std::future::Future; use std::str::{self, FromStr}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, info_span, Instrument}; +use utils::postgres_client::PostgresClientProtocol; +use utils::shard::{ShardCount, ShardNumber}; use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; @@ -35,6 +39,8 @@ pub struct SafekeeperPostgresHandler { pub tenant_id: Option, pub timeline_id: Option, pub ttid: TenantTimelineId, + pub shard: Option, + pub protocol: Option, /// Unique connection id is logged in spans for observability. pub conn_id: ConnectionId, /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured. @@ -107,11 +113,28 @@ impl postgres_backend::Handler ) -> Result<(), QueryError> { if let FeStartupPacket::StartupMessage { params, .. 
} = sm { if let Some(options) = params.options_raw() { + let mut shard_count: Option = None; + let mut shard_number: Option = None; + let mut shard_stripe_size: Option = None; + for opt in options { // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy, // remove these after the PR gets deployed: // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064 match opt.split_once('=') { + Some(("protocol", value)) => { + let raw_value = value + .parse::() + .with_context(|| format!("Failed to parse {value} as protocol"))?; + + self.protocol = Some( + PostgresClientProtocol::try_from(raw_value).map_err(|_| { + QueryError::Other(anyhow::anyhow!( + "Unexpected client protocol type: {raw_value}" + )) + })?, + ); + } Some(("ztenantid", value)) | Some(("tenant_id", value)) => { self.tenant_id = Some(value.parse().with_context(|| { format!("Failed to parse {value} as tenant id") @@ -127,9 +150,54 @@ impl postgres_backend::Handler metrics.set_client_az(client_az) } } + Some(("shard_count", value)) => { + shard_count = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard count") + })?); + } + Some(("shard_number", value)) => { + shard_number = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard number") + })?); + } + Some(("shard_stripe_size", value)) => { + shard_stripe_size = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard stripe size") + })?); + } _ => continue, } } + + match self.protocol() { + PostgresClientProtocol::Vanilla => { + if shard_count.is_some() + || shard_number.is_some() + || shard_stripe_size.is_some() + { + return Err(QueryError::Other(anyhow::anyhow!( + "Shard params specified for vanilla protocol" + ))); + } + } + PostgresClientProtocol::Interpreted => { + match (shard_count, shard_number, shard_stripe_size) { + (Some(count), Some(number), Some(stripe_size)) => { + let params = ShardParameters { + count: ShardCount(count), + stripe_size: ShardStripeSize(stripe_size), + }; + self.shard = + Some(ShardIdentity::from_params(ShardNumber(number), ¶ms)); + } + _ => { + return Err(QueryError::Other(anyhow::anyhow!( + "Shard params were not specified" + ))); + } + } + } + } } if let Some(app_name) = params.get("application_name") { @@ -150,6 +218,11 @@ impl postgres_backend::Handler tracing::field::debug(self.appname.clone()), ); + if let Some(shard) = self.shard.as_ref() { + tracing::Span::current() + .record("shard", tracing::field::display(shard.shard_slug())); + } + Ok(()) } else { Err(QueryError::Other(anyhow::anyhow!( @@ -258,6 +331,8 @@ impl SafekeeperPostgresHandler { tenant_id: None, timeline_id: None, ttid: TenantTimelineId::empty(), + shard: None, + protocol: None, conn_id, claims: None, auth, @@ -265,6 +340,10 @@ impl SafekeeperPostgresHandler { } } + pub fn protocol(&self) -> PostgresClientProtocol { + self.protocol.unwrap_or(PostgresClientProtocol::Vanilla) + } + // when accessing management api supply None as an argument // when using to authorize tenant pass corresponding tenant id fn check_permission(&self, tenant_id: Option) -> Result<(), QueryError> { diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 6d68b6b59b..abe6e00a66 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -29,6 +29,7 @@ pub mod receive_wal; pub mod recovery; pub mod remove_wal; pub mod safekeeper; +pub mod send_interpreted_wal; pub mod send_wal; pub mod state; pub mod timeline; @@ -38,6 +39,7 @@ pub mod timeline_manager; pub mod timelines_set; 
pub mod wal_backup; pub mod wal_backup_partial; +pub mod wal_reader_stream; pub mod wal_service; pub mod wal_storage; diff --git a/safekeeper/src/recovery.rs b/safekeeper/src/recovery.rs index 9c4149d8f1..7b87166aa0 100644 --- a/safekeeper/src/recovery.rs +++ b/safekeeper/src/recovery.rs @@ -17,6 +17,7 @@ use tokio::{ use tokio_postgres::replication::ReplicationStream; use tokio_postgres::types::PgLsn; use tracing::*; +use utils::postgres_client::{ConnectionConfigArgs, PostgresClientProtocol}; use utils::{id::NodeId, lsn::Lsn, postgres_client::wal_stream_connection_config}; use crate::receive_wal::{WalAcceptor, REPLY_QUEUE_SIZE}; @@ -325,7 +326,17 @@ async fn recovery_stream( conf: &SafeKeeperConf, ) -> anyhow::Result { // TODO: pass auth token - let cfg = wal_stream_connection_config(tli.ttid, &donor.pg_connstr, None, None)?; + let connection_conf_args = ConnectionConfigArgs { + protocol: PostgresClientProtocol::Vanilla, + ttid: tli.ttid, + shard_number: None, + shard_count: None, + shard_stripe_size: None, + listen_pg_addr_str: &donor.pg_connstr, + auth_token: None, + availability_zone: None, + }; + let cfg = wal_stream_connection_config(connection_conf_args)?; let mut cfg = cfg.to_tokio_postgres_config(); // It will make safekeeper give out not committed WAL (up to flush_lsn). cfg.application_name(&format!("safekeeper_{}", conf.my_id)); diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs new file mode 100644 index 0000000000..cf0ee276e9 --- /dev/null +++ b/safekeeper/src/send_interpreted_wal.rs @@ -0,0 +1,121 @@ +use std::time::Duration; + +use anyhow::Context; +use futures::StreamExt; +use pageserver_api::shard::ShardIdentity; +use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend}; +use postgres_ffi::MAX_SEND_SIZE; +use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; +use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::MissedTickBehavior; +use utils::bin_ser::BeSer; +use utils::lsn::Lsn; +use wal_decoder::models::InterpretedWalRecord; + +use crate::send_wal::EndWatchView; +use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; + +/// Shard-aware interpreted record sender. +/// This is used for sending WAL to the pageserver. Said WAL +/// is pre-interpreted and filtered for the shard. +pub(crate) struct InterpretedWalSender<'a, IO> { + pub(crate) pgb: &'a mut PostgresBackend, + pub(crate) wal_stream_builder: WalReaderStreamBuilder, + pub(crate) end_watch_view: EndWatchView, + pub(crate) shard: ShardIdentity, + pub(crate) pg_version: u32, + pub(crate) appname: Option, +} + +impl InterpretedWalSender<'_, IO> { + /// Send interpreted WAL to a receiver. + /// Stops when an error occurs or the receiver is caught up and there's no active compute. + /// + /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? + /// convenience. + pub(crate) async fn run(self) -> Result<(), CopyStreamHandlerEnd> { + let mut wal_position = self.wal_stream_builder.start_pos(); + let mut wal_decoder = + WalStreamDecoder::new(self.wal_stream_builder.start_pos(), self.pg_version); + + let stream = self.wal_stream_builder.build(MAX_SEND_SIZE).await?; + let mut stream = std::pin::pin!(stream); + + let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1)); + keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + keepalive_ticker.reset(); + + loop { + tokio::select! 
{ + // Get some WAL from the stream and then: decode, interpret and send it + wal = stream.next() => { + let WalBytes { wal, wal_start_lsn: _, wal_end_lsn, available_wal_end_lsn } = match wal { + Some(some) => some?, + None => { break; } + }; + + wal_position = wal_end_lsn; + wal_decoder.feed_bytes(&wal); + + let mut records = Vec::new(); + let mut max_next_record_lsn = None; + while let Some((next_record_lsn, recdata)) = wal_decoder + .poll_decode() + .with_context(|| "Failed to decode WAL")? + { + assert!(next_record_lsn.is_aligned()); + max_next_record_lsn = Some(next_record_lsn); + + // Deserialize and interpret WAL record + let interpreted = InterpretedWalRecord::from_bytes_filtered( + recdata, + &self.shard, + next_record_lsn, + self.pg_version, + ) + .with_context(|| "Failed to interpret WAL")?; + + if !interpreted.is_empty() { + records.push(interpreted); + } + } + + let mut buf = Vec::new(); + records + .ser_into(&mut buf) + .with_context(|| "Failed to serialize interpreted WAL")?; + + // Reset the keep alive ticker since we are sending something + // over the wire now. + keepalive_ticker.reset(); + + self.pgb + .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody { + streaming_lsn: wal_end_lsn.0, + commit_lsn: available_wal_end_lsn.0, + next_record_lsn: max_next_record_lsn.unwrap_or(Lsn::INVALID).0, + data: buf.as_slice(), + })).await?; + } + + // Send a periodic keep alive when the connection has been idle for a while. + _ = keepalive_ticker.tick() => { + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + wal_end: self.end_watch_view.get().0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .await?; + } + } + } + + // The loop above ends when the receiver is caught up and there's no more WAL to send. + Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, wal_position, + ))) + } +} diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index aa65ec851b..1acfcad418 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,12 +5,15 @@ use crate::handler::SafekeeperPostgresHandler; use crate::metrics::RECEIVED_PS_FEEDBACKS; use crate::receive_wal::WalReceivers; use crate::safekeeper::{Term, TermLsn}; +use crate::send_interpreted_wal::InterpretedWalSender; use crate::timeline::WalResidentTimeline; +use crate::wal_reader_stream::WalReaderStreamBuilder; use crate::wal_service::ConnectionId; use crate::wal_storage::WalReader; use crate::GlobalTimelines; use anyhow::{bail, Context as AnyhowContext}; use bytes::Bytes; +use futures::future::Either; use parking_lot::Mutex; use postgres_backend::PostgresBackend; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; @@ -22,6 +25,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use utils::failpoint_support; use utils::id::TenantTimelineId; use utils::pageserver_feedback::PageserverFeedback; +use utils::postgres_client::PostgresClientProtocol; use std::cmp::{max, min}; use std::net::SocketAddr; @@ -226,7 +230,7 @@ impl WalSenders { /// Get remote_consistent_lsn reported by the pageserver. Returns None if /// client is not pageserver. 
- fn get_ws_remote_consistent_lsn(self: &Arc, id: WalSenderId) -> Option { + pub fn get_ws_remote_consistent_lsn(self: &Arc, id: WalSenderId) -> Option { let shared = self.mutex.lock(); let slot = shared.get_slot(id); match slot.feedback { @@ -370,6 +374,16 @@ pub struct WalSenderGuard { walsenders: Arc, } +impl WalSenderGuard { + pub fn id(&self) -> WalSenderId { + self.id + } + + pub fn walsenders(&self) -> &Arc { + &self.walsenders + } +} + impl Drop for WalSenderGuard { fn drop(&mut self) { self.walsenders.unregister(self.id); @@ -440,11 +454,12 @@ impl SafekeeperPostgresHandler { } info!( - "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}", + "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={}", start_pos, end_pos, matches!(end_watch, EndWatch::Flush(_)), - appname + appname, + self.protocol(), ); // switch to copy @@ -456,21 +471,51 @@ impl SafekeeperPostgresHandler { // not synchronized with sends, so this avoids deadlocks. let reader = pgb.split().context("START_REPLICATION split")?; + let send_fut = match self.protocol() { + PostgresClientProtocol::Vanilla => { + let sender = WalSender { + pgb, + // should succeed since we're already holding another guard + tli: tli.wal_residence_guard().await?, + appname, + start_pos, + end_pos, + term, + end_watch, + ws_guard: ws_guard.clone(), + wal_reader, + send_buf: vec![0u8; MAX_SEND_SIZE], + }; + + Either::Left(sender.run()) + } + PostgresClientProtocol::Interpreted => { + let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; + let end_watch_view = end_watch.view(); + let wal_stream_builder = WalReaderStreamBuilder { + tli: tli.wal_residence_guard().await?, + start_pos, + end_pos, + term, + end_watch, + wal_sender_guard: ws_guard.clone(), + }; + + let sender = InterpretedWalSender { + pgb, + wal_stream_builder, + end_watch_view, + shard: self.shard.unwrap(), + pg_version, + appname, + }; + + Either::Right(sender.run()) + } + }; + let tli_cancel = tli.cancel.clone(); - let mut sender = WalSender { - pgb, - // should succeed since we're already holding another guard - tli: tli.wal_residence_guard().await?, - appname, - start_pos, - end_pos, - term, - end_watch, - ws_guard: ws_guard.clone(), - wal_reader, - send_buf: vec![0u8; MAX_SEND_SIZE], - }; let mut reply_reader = ReplyReader { reader, ws_guard: ws_guard.clone(), @@ -479,7 +524,7 @@ impl SafekeeperPostgresHandler { let res = tokio::select! { // todo: add read|write .context to these errors - r = sender.run() => r, + r = send_fut => r, r = reply_reader.run() => r, _ = tli_cancel.cancelled() => { return Err(CopyStreamHandlerEnd::Cancelled); @@ -504,16 +549,22 @@ impl SafekeeperPostgresHandler { } } +/// TODO(vlad): maybe lift this instead /// Walsender streams either up to commit_lsn (normally) or flush_lsn in the /// given term (recovery by walproposer or peer safekeeper). -enum EndWatch { +#[derive(Clone)] +pub(crate) enum EndWatch { Commit(Receiver), Flush(Receiver), } impl EndWatch { + pub(crate) fn view(&self) -> EndWatchView { + EndWatchView(self.clone()) + } + /// Get current end of WAL. - fn get(&self) -> Lsn { + pub(crate) fn get(&self) -> Lsn { match self { EndWatch::Commit(r) => *r.borrow(), EndWatch::Flush(r) => r.borrow().lsn, @@ -521,15 +572,44 @@ impl EndWatch { } /// Wait for the update. 
- async fn changed(&mut self) -> anyhow::Result<()> { + pub(crate) async fn changed(&mut self) -> anyhow::Result<()> { match self { EndWatch::Commit(r) => r.changed().await?, EndWatch::Flush(r) => r.changed().await?, } Ok(()) } + + pub(crate) async fn wait_for_lsn( + &mut self, + lsn: Lsn, + client_term: Option, + ) -> anyhow::Result { + loop { + let end_pos = self.get(); + if end_pos > lsn { + return Ok(end_pos); + } + if let EndWatch::Flush(rx) = &self { + let curr_term = rx.borrow().term; + if let Some(client_term) = client_term { + if curr_term != client_term { + bail!("term changed: requested {}, now {}", client_term, curr_term); + } + } + } + self.changed().await?; + } + } } +pub(crate) struct EndWatchView(EndWatch); + +impl EndWatchView { + pub(crate) fn get(&self) -> Lsn { + self.0.get() + } +} /// A half driving sending WAL. struct WalSender<'a, IO> { pgb: &'a mut PostgresBackend, @@ -566,7 +646,7 @@ impl WalSender<'_, IO> { /// /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? /// convenience. - async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> { loop { // Wait for the next portion if it is not there yet, or just // update our end of WAL available for sending value, we diff --git a/safekeeper/src/wal_reader_stream.rs b/safekeeper/src/wal_reader_stream.rs new file mode 100644 index 0000000000..f8c0c502cd --- /dev/null +++ b/safekeeper/src/wal_reader_stream.rs @@ -0,0 +1,149 @@ +use std::sync::Arc; + +use async_stream::try_stream; +use bytes::Bytes; +use futures::Stream; +use postgres_backend::CopyStreamHandlerEnd; +use std::time::Duration; +use tokio::time::timeout; +use utils::lsn::Lsn; + +use crate::{ + safekeeper::Term, + send_wal::{EndWatch, WalSenderGuard}, + timeline::WalResidentTimeline, +}; + +pub(crate) struct WalReaderStreamBuilder { + pub(crate) tli: WalResidentTimeline, + pub(crate) start_pos: Lsn, + pub(crate) end_pos: Lsn, + pub(crate) term: Option, + pub(crate) end_watch: EndWatch, + pub(crate) wal_sender_guard: Arc, +} + +impl WalReaderStreamBuilder { + pub(crate) fn start_pos(&self) -> Lsn { + self.start_pos + } +} + +pub(crate) struct WalBytes { + /// Raw PG WAL + pub(crate) wal: Bytes, + /// Start LSN of [`Self::wal`] + #[allow(dead_code)] + pub(crate) wal_start_lsn: Lsn, + /// End LSN of [`Self::wal`] + pub(crate) wal_end_lsn: Lsn, + /// End LSN of WAL available on the safekeeper. + /// + /// For pagservers this will be commit LSN, + /// while for the compute it will be the flush LSN. + pub(crate) available_wal_end_lsn: Lsn, +} + +impl WalReaderStreamBuilder { + /// Builds a stream of Postgres WAL starting from [`Self::start_pos`]. + /// The stream terminates when the receiver (pageserver) is fully caught up + /// and there's no active computes. + pub(crate) async fn build( + self, + buffer_size: usize, + ) -> anyhow::Result>> { + // TODO(vlad): The code below duplicates functionality from [`crate::send_wal`]. + // We can make the raw WAL sender use this stream too and remove the duplication. + let Self { + tli, + mut start_pos, + mut end_pos, + term, + mut end_watch, + wal_sender_guard, + } = self; + let mut wal_reader = tli.get_walreader(start_pos).await?; + let mut buffer = vec![0; buffer_size]; + + const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); + + Ok(try_stream! 
{ + loop { + let have_something_to_send = end_pos > start_pos; + + if !have_something_to_send { + // wait for lsn + let res = timeout(POLL_STATE_TIMEOUT, end_watch.wait_for_lsn(start_pos, term)).await; + match res { + Ok(ok) => { + end_pos = ok?; + }, + Err(_) => { + if let EndWatch::Commit(_) = end_watch { + if let Some(remote_consistent_lsn) = wal_sender_guard + .walsenders() + .get_ws_remote_consistent_lsn(wal_sender_guard.id()) + { + if tli.should_walsender_stop(remote_consistent_lsn).await { + // Stop streaming if the receivers are caught up and + // there's no active compute. This causes the loop in + // [`crate::send_interpreted_wal::InterpretedWalSender::run`] + // to exit and terminate the WAL stream. + return; + } + } + } + + continue; + } + } + } + + + assert!( + end_pos > start_pos, + "nothing to send after waiting for WAL" + ); + + // try to send as much as available, capped by the buffer size + let mut chunk_end_pos = start_pos + buffer_size as u64; + // if we went behind available WAL, back off + if chunk_end_pos >= end_pos { + chunk_end_pos = end_pos; + } else { + // If sending not up to end pos, round down to page boundary to + // avoid breaking WAL record not at page boundary, as protocol + // demands. See walsender.c (XLogSendPhysical). + chunk_end_pos = chunk_end_pos + .checked_sub(chunk_end_pos.block_offset()) + .unwrap(); + } + let send_size = (chunk_end_pos.0 - start_pos.0) as usize; + let buffer = &mut buffer[..send_size]; + let send_size: usize; + { + // If uncommitted part is being pulled, check that the term is + // still the expected one. + let _term_guard = if let Some(t) = term { + Some(tli.acquire_term(t).await?) + } else { + None + }; + // Read WAL into buffer. send_size can be additionally capped to + // segment boundary here. + send_size = wal_reader.read(buffer).await? + }; + let wal = Bytes::copy_from_slice(&buffer[..send_size]); + + yield WalBytes { + wal, + wal_start_lsn: start_pos, + wal_end_lsn: start_pos + send_size as u64, + available_wal_end_lsn: end_pos + }; + + start_pos += send_size as u64; + } + }) + } +} diff --git a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index 77e8f2cf17..e965aae5a0 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -15,16 +15,21 @@ from fixtures.neon_fixtures import ( @pytest.mark.timeout(600) @pytest.mark.parametrize("shard_count", [1, 8, 32]) +@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) def test_sharded_ingest( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, shard_count: int, + wal_receiver_protocol: str, ): """ Benchmarks sharded ingestion throughput, by ingesting a large amount of WAL into a Safekeeper and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case (shard_count=1) to the sharded case indicates the overhead of sharding. """ + neon_env_builder.pageserver_config_override = ( + f"wal_receiver_protocol = '{wal_receiver_protocol}'" + ) ROW_COUNT = 100_000_000 # about 7 GB of WAL @@ -50,7 +55,6 @@ def test_sharded_ingest( # Start the endpoint. endpoint = env.endpoints.create_start("main", tenant_id=tenant_id) start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) - # Ingest data and measure WAL volume and duration. 
with closing(endpoint.connect()) as conn: with conn.cursor() as cur: @@ -68,4 +72,48 @@ def test_sharded_ingest( wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024)) zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM) + total_ingested = 0 + total_records_received = 0 + ingested_by_ps = [] + for pageserver in env.pageservers: + ingested = pageserver.http_client().get_metric_value( + "pageserver_wal_ingest_bytes_received_total" + ) + records_received = pageserver.http_client().get_metric_value( + "pageserver_wal_ingest_records_received_total" + ) + + if ingested is None: + ingested = 0 + + if records_received is None: + records_received = 0 + + ingested_by_ps.append( + ( + pageserver.id, + { + "ingested": ingested, + "records_received": records_received, + }, + ) + ) + + total_ingested += int(ingested) + total_records_received += int(records_received) + + total_ingested_mb = total_ingested / (1024 * 1024) + zenbenchmark.record("wal_ingested", total_ingested_mb, "MB", MetricReport.LOWER_IS_BETTER) + zenbenchmark.record( + "records_received", total_records_received, "records", MetricReport.LOWER_IS_BETTER + ) + + ingested_by_ps.sort(key=lambda x: x[0]) + for _, stats in ingested_by_ps: + for k in stats: + if k != "records_received": + stats[k] /= 1024**2 + + log.info(f"WAL ingested by each pageserver {ingested_by_ps}") + assert tenant_get_shards(env, tenant_id) == shards, "shards moved" diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index f71e05924a..79fd256304 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -27,7 +27,8 @@ AGGRESIVE_COMPACTION_TENANT_CONF = { @skip_in_debug_build("only run with release build") -def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: str): """ This is a smoke test that compaction kicks in. The workload repeatedly churns a small number of rows and manually instructs the pageserver to run compaction @@ -38,8 +39,8 @@ def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder): # Effectively disable the page cache to rely only on image layers # to shorten reads. 
- neon_env_builder.pageserver_config_override = """ -page_cache_size=10 + neon_env_builder.pageserver_config_override = f""" +page_cache_size=10; wal_receiver_protocol='{wal_receiver_protocol}' """ env = neon_env_builder.init_start(initial_tenant_conf=AGGRESIVE_COMPACTION_TENANT_CONF) diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 23c6fa3a5a..70e71d99cd 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -19,7 +19,14 @@ from fixtures.neon_fixtures import NeonEnvBuilder "wal_record_crossing_segment_followed_by_small_one", ], ) -def test_crafted_wal_end(neon_env_builder: NeonEnvBuilder, wal_type: str): +@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +def test_crafted_wal_end( + neon_env_builder: NeonEnvBuilder, wal_type: str, wal_receiver_protocol: str +): + neon_env_builder.pageserver_config_override = ( + f"wal_receiver_protocol = '{wal_receiver_protocol}'" + ) + env = neon_env_builder.init_start() env.create_branch("test_crafted_wal_end") env.pageserver.allowed_errors.extend( diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index 7a46f0140c..1d86c353be 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,6 +1,7 @@ from __future__ import annotations -from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content +import pytest +from fixtures.neon_fixtures import NeonEnvBuilder, check_restored_datadir_content # Test subtransactions @@ -9,8 +10,13 @@ from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content # maintained in the pageserver, so subtransactions are not very exciting for # Neon. They are included in the commit record though and updated in the # CLOG. -def test_subxacts(neon_simple_env: NeonEnv, test_output_dir): - env = neon_simple_env +@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir, wal_receiver_protocol): + neon_env_builder.pageserver_config_override = ( + f"wal_receiver_protocol = '{wal_receiver_protocol}'" + ) + + env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") pg_conn = endpoint.connect() diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index 18408b0619..094b10b576 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -622,8 +622,12 @@ async def run_segment_init_failure(env: NeonEnv): # Test (injected) failure during WAL segment init. 
# https://github.com/neondatabase/neon/issues/6401 # https://github.com/neondatabase/neon/issues/6402 -def test_segment_init_failure(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +def test_segment_init_failure(neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: str): neon_env_builder.num_safekeepers = 1 + neon_env_builder.pageserver_config_override = ( + f"wal_receiver_protocol = '{wal_receiver_protocol}'" + ) env = neon_env_builder.init_start() asyncio.run(run_segment_init_failure(env)) From 87e4dd23a1d92ad99665b109ba58b7050a934b4d Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Mon, 25 Nov 2024 18:53:26 +0100 Subject: [PATCH 10/34] proxy: Demote all cplane error replies to info log level (#9880) ## Problem The vast majority of the error/warn logs from cplane are about time or data transfer quotas exceeded or endpoint-not-found errors and not operational errors in proxy or cplane. ## Summary of changes * Demote cplane error replies to info level. * Raise other errors from warn back to error. --- proxy/src/proxy/wake_compute.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 8a672d48dc..4e9206feff 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,9 @@ -use tracing::{error, info, warn}; +use tracing::{error, info}; use super::connect_compute::ComputeConnectBackend; use crate::config::RetryConfig; use crate::context::RequestContext; -use crate::control_plane::errors::WakeComputeError; +use crate::control_plane::errors::{ControlPlaneError, WakeComputeError}; use crate::control_plane::CachedNodeInfo; use crate::error::ReportableError; use crate::metrics::{ @@ -11,6 +11,18 @@ use crate::metrics::{ }; use crate::proxy::retry::{retry_after, should_retry}; +// Use macro to retain original callsite. +macro_rules! log_wake_compute_error { + (error = ?$error:expr, $num_retries:expr, retriable = $retriable:literal) => { + match $error { + WakeComputeError::ControlPlane(ControlPlaneError::Message(_)) => { + info!(error = ?$error, num_retries = $num_retries, retriable = $retriable, "couldn't wake compute node") + } + _ => error!(error = ?$error, num_retries = $num_retries, retriable = $retriable, "couldn't wake compute node"), + } + }; +} + pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestContext, @@ -20,7 +32,7 @@ pub(crate) async fn wake_compute( loop { match api.wake_compute(ctx).await { Err(e) if !should_retry(&e, *num_retries, config) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); + log_wake_compute_error!(error = ?e, num_retries, retriable = false); report_error(&e, false); Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { @@ -32,7 +44,7 @@ pub(crate) async fn wake_compute( return Err(e); } Err(e) => { - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); + log_wake_compute_error!(error = ?e, num_retries, retriable = true); report_error(&e, true); } Ok(n) => { From 7404887b810403768a950006335877aa6a966c8d Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Mon, 25 Nov 2024 20:35:32 +0100 Subject: [PATCH 11/34] proxy: Demote errors from cplane request routines to debug (#9886) ## Problem Any errors from these async blocks are unconditionally logged at error level even though we already handle such errors based on context. 
## Summary of changes * Log raw errors from creating and executing cplane requests at debug level. * Inline macro calls to retain the correct callsite. --- proxy/src/control_plane/client/mock.rs | 2 +- proxy/src/control_plane/client/neon.rs | 13 ++++++------- proxy/src/error.rs | 6 ------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 500acad50f..9537d717a1 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -114,7 +114,7 @@ impl MockControlPlane { Ok((secret, allowed_ips)) } - .map_err(crate::error::log_error::) + .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}")) .instrument(info_span!("postgres", url = self.endpoint.as_str())) .await?; Ok(AuthInfo { diff --git a/proxy/src/control_plane/client/neon.rs b/proxy/src/control_plane/client/neon.rs index 757ea6720a..2cad981d01 100644 --- a/proxy/src/control_plane/client/neon.rs +++ b/proxy/src/control_plane/client/neon.rs @@ -134,8 +134,8 @@ impl NeonControlPlaneClient { project_id: body.project_id, }) } - .map_err(crate::error::log_error) - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_get_auth_info")) .await } @@ -193,8 +193,8 @@ impl NeonControlPlaneClient { Ok(rules) } - .map_err(crate::error::log_error) - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_get_endpoint_jwks")) .await } @@ -252,9 +252,8 @@ impl NeonControlPlaneClient { Ok(node) } - .map_err(crate::error::log_error) - // TODO: redo this span stuff - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_wake_compute")) .await } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 7b693a7418..2221aac407 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -10,12 +10,6 @@ pub(crate) fn io_error(e: impl Into>) -> io::Err io::Error::new(io::ErrorKind::Other, e) } -/// A small combinator for pluggable error logging. -pub(crate) fn log_error(e: E) -> E { - tracing::error!("{e}"); - e -} - /// Marks errors that may be safely shown to a client. /// This trait can be seen as a specialized version of [`ToString`]. /// From a74ab9338d85ae0fcac1fa965c1119a4d74c98df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Mon, 25 Nov 2024 21:23:42 +0100 Subject: [PATCH 12/34] fast_import: remove hardcoding of pg_version (#9878) Before, we hardcoded the pg_version to 140000, while the code expected version numbers like 14. Now we use an enum, and code from `extension_server.rs` to auto-detect the correct version. The enum helps when we add support for a version: enums ensure that compilation fails if one forgets to put the version to one of the `match` locations. 
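As a small sketch of the property described above (names follow the patch below; the commented-out `V18` variant and the `major_number` helper are purely illustrative): because the version is a fieldless enum rather than a bare number, every `match` over it must stay exhaustive, so adding a new version becomes a compile-time checklist instead of a silently stale constant.

```rust
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PostgresMajorVersion {
    V14,
    V15,
    V16,
    V17,
    // V18, // uncommenting this forces every match below to be updated
}

// Hypothetical helper with the same shape as the match added in fast_import.rs.
fn major_number(v: PostgresMajorVersion) -> u32 {
    match v {
        PostgresMajorVersion::V14 => 14,
        PostgresMajorVersion::V15 => 15,
        PostgresMajorVersion::V16 => 16,
        PostgresMajorVersion::V17 => 17,
    }
}
```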
cc https://github.com/neondatabase/neon/pull/9218 --- compute_tools/src/bin/compute_ctl.rs | 4 +- compute_tools/src/bin/fast_import.rs | 9 ++++- compute_tools/src/extension_server.rs | 54 ++++++++++++++++++--------- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 4689cc2b83..6b670de2ea 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -58,7 +58,7 @@ use compute_tools::compute::{ forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID, }; use compute_tools::configurator::launch_configurator; -use compute_tools::extension_server::get_pg_version; +use compute_tools::extension_server::get_pg_version_string; use compute_tools::http::api::launch_http_server; use compute_tools::logger::*; use compute_tools::monitor::launch_monitor; @@ -326,7 +326,7 @@ fn wait_spec( connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?, pgdata: pgdata.to_string(), pgbin: pgbin.to_string(), - pgversion: get_pg_version(pgbin), + pgversion: get_pg_version_string(pgbin), live_config_allowed, state: Mutex::new(new_state), state_changed: Condvar::new(), diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 3b0b990df2..6716cc6234 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -29,6 +29,7 @@ use anyhow::Context; use aws_config::BehaviorVersion; use camino::{Utf8Path, Utf8PathBuf}; use clap::Parser; +use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion}; use nix::unistd::Pid; use tracing::{info, info_span, warn, Instrument}; use utils::fs_ext::is_directory_empty; @@ -131,11 +132,17 @@ pub(crate) async fn main() -> anyhow::Result<()> { // // Initialize pgdata // + let pg_version = match get_pg_version(pg_bin_dir.as_str()) { + PostgresMajorVersion::V14 => 14, + PostgresMajorVersion::V15 => 15, + PostgresMajorVersion::V16 => 16, + PostgresMajorVersion::V17 => 17, + }; let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { superuser, locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded, - pg_version: 140000, // XXX: this shouldn't be hard-coded but derived from which compute image we're running in + pg_version, initdb_bin: pg_bin_dir.join("initdb").as_ref(), library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local. pgdata: &pgdata_dir, diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index da2d107b54..f13b2308e7 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -103,14 +103,33 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String { .to_string() } -pub fn get_pg_version(pgbin: &str) -> String { +pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion { // pg_config --version returns a (platform specific) human readable string // such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc. 
let human_version = get_pg_config("--version", pgbin); - parse_pg_version(&human_version).to_string() + parse_pg_version(&human_version) } -fn parse_pg_version(human_version: &str) -> &str { +pub fn get_pg_version_string(pgbin: &str) -> String { + match get_pg_version(pgbin) { + PostgresMajorVersion::V14 => "v14", + PostgresMajorVersion::V15 => "v15", + PostgresMajorVersion::V16 => "v16", + PostgresMajorVersion::V17 => "v17", + } + .to_owned() +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PostgresMajorVersion { + V14, + V15, + V16, + V17, +} + +fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { + use PostgresMajorVersion::*; // Normal releases have version strings like "PostgreSQL 15.4". But there // are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL // 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version @@ -121,10 +140,10 @@ fn parse_pg_version(human_version: &str) -> &str { .captures(human_version) { Some(captures) if captures.len() == 2 => match &captures["major"] { - "14" => return "v14", - "15" => return "v15", - "16" => return "v16", - "17" => return "v17", + "14" => return V14, + "15" => return V15, + "16" => return V16, + "17" => return V17, _ => {} }, _ => {} @@ -263,24 +282,25 @@ mod tests { #[test] fn test_parse_pg_version() { - assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15"); - assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15"); + use super::PostgresMajorVersion::*; + assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15); + assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15); assert_eq!( parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"), - "v15" + V15 ); - assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14"); - assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14"); + assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14); + assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14); assert_eq!( parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"), - "v14" + V14 ); - assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16"); + assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16); } #[test] From 96a1b71c84965782bb10c9fb591ff3fe43b1f8c5 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 25 Nov 2024 21:32:53 +0000 Subject: [PATCH 13/34] chore(proxy): discard request context span during passthrough (#9882) ## Problem The RequestContext::span shouldn't live for the entire postgres connection, only the handshake. ## Summary of changes * Slight refactor to the RequestContext to discard the span upon handshake completion. 
* Make sure the temporary future for the handshake is dropped (not bound to a variable) * Runs our nightly fmt script --- proxy/src/cancellation.rs | 6 ++-- proxy/src/console_redirect_proxy.rs | 38 ++++++++++---------- proxy/src/context/mod.rs | 23 +++++++++---- proxy/src/proxy/mod.rs | 42 +++++++++++------------ proxy/src/proxy/passthrough.rs | 3 +- proxy/src/redis/cancellation_publisher.rs | 2 +- 6 files changed, 59 insertions(+), 55 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 4b72a66e63..74415f1ffe 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,7 +1,8 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use dashmap::DashMap; +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use pq_proto::CancelKeyData; use thiserror::Error; use tokio::net::TcpStream; @@ -17,9 +18,6 @@ use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::cancellation_publisher::{ CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, }; -use std::net::IpAddr; - -use ipnet::{IpNet, Ipv4Net, Ipv6Net}; pub type CancelMap = Arc>>; pub type CancellationHandlerMain = CancellationHandler>>>; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index fbd0c8e5c5..b910b524b1 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, Instrument}; @@ -88,40 +88,37 @@ pub async fn task_main( crate::metrics::Protocol::Tcp, &config.region, ); - let span = ctx.span(); - let startup = Box::pin( - handle_client( - config, - backend, - &ctx, - cancellation_handler, - socket, - conn_gauge, - ) - .instrument(span.clone()), - ); - let res = startup.await; + let res = handle_client( + config, + backend, + &ctx, + cancellation_handler, + socket, + conn_gauge, + ) + .instrument(ctx.span()) + .boxed() + .await; match res { Err(e) => { - // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - error!(parent: &span, "per-client task finished with an error: {e:#}"); + error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); } Ok(Some(p)) => { ctx.set_success(); - ctx.log_connect(); - match p.proxy_pass().instrument(span.clone()).await { + let _disconnect = ctx.log_connect(); + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}"); } } } @@ -219,6 +216,7 @@ pub(crate) async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, + session_id: ctx.session_id(), _req: request_gauge, _conn: conn_gauge, _cancel: session, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 6d2d2d51ce..4ec04deb25 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -272,11 +272,14 @@ impl RequestContext { this.success = true; } - pub fn log_connect(&self) { - self.0 - .try_lock() - .expect("should not 
deadlock") - .log_connect(); + pub fn log_connect(self) -> DisconnectLogger { + let mut this = self.0.into_inner(); + this.log_connect(); + + // close current span. + this.span = Span::none(); + + DisconnectLogger(this) } pub(crate) fn protocol(&self) -> Protocol { @@ -434,8 +437,14 @@ impl Drop for RequestContextInner { fn drop(&mut self) { if self.sender.is_some() { self.log_connect(); - } else { - self.log_disconnect(); } } } + +pub struct DisconnectLogger(RequestContextInner); + +impl Drop for DisconnectLogger { + fn drop(&mut self) { + self.0.log_disconnect(); + } +} diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 5d9468d89a..7fe67e43de 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,7 +10,7 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource}; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; @@ -123,42 +123,39 @@ pub async fn task_main( crate::metrics::Protocol::Tcp, &config.region, ); - let span = ctx.span(); - let startup = Box::pin( - handle_client( - config, - auth_backend, - &ctx, - cancellation_handler, - socket, - ClientMode::Tcp, - endpoint_rate_limiter2, - conn_gauge, - ) - .instrument(span.clone()), - ); - let res = startup.await; + let res = handle_client( + config, + auth_backend, + &ctx, + cancellation_handler, + socket, + ClientMode::Tcp, + endpoint_rate_limiter2, + conn_gauge, + ) + .instrument(ctx.span()) + .boxed() + .await; match res { Err(e) => { - // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - warn!(parent: &span, "per-client task finished with an error: {e:#}"); + warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); } Ok(Some(p)) => { ctx.set_success(); - ctx.log_connect(); - match p.proxy_pass().instrument(span.clone()).await { + let _disconnect = ctx.log_connect(); + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}"); } } } @@ -352,6 +349,7 @@ pub(crate) async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, + session_id: ctx.session_id(), _req: request_gauge, _conn: conn_gauge, _cancel: session, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 5e07c8eeae..dcaa81e5cd 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -59,6 +59,7 @@ pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, pub(crate) compute: PostgresConnection, pub(crate) aux: MetricsAuxInfo, + pub(crate) session_id: uuid::Uuid, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, @@ -69,7 +70,7 @@ impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { - tracing::warn!(?err, "could not cancel the query in the 
database"); + tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } res } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 633a2f1b81..228dbb7f64 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,6 +1,6 @@ +use core::net::IpAddr; use std::sync::Arc; -use core::net::IpAddr; use pq_proto::CancelKeyData; use redis::AsyncCommands; use tokio::sync::Mutex; From 13feda0669d65cbac4b2103952caba1a9db1342e Mon Sep 17 00:00:00 2001 From: Peter Bendel Date: Tue, 26 Nov 2024 12:46:58 +0100 Subject: [PATCH 14/34] track how much time the flush loop is stalled waiting for uploads (#9885) ## Problem We don't know how much time PS is losing during ingest when waiting for remote storage uploads in the flush frozen layer loop. Also we don't know how many remote storage requests get an permit without waiting (not throttled by remote_storage concurrency_limit). ## Summary of changes - Add a metric that accumulates the time waited per shard/PS - in [remote storage semaphore wait seconds](https://neonprod.grafana.net/d/febd9732-9bcf-4992-a821-49b1f6b02724/remote-storage?orgId=1&var-datasource=HUNg6jvVk&var-instance=pageserver-26.us-east-2.aws.neon.build&var-instance=pageserver-27.us-east-2.aws.neon.build&var-instance=pageserver-28.us-east-2.aws.neon.build&var-instance=pageserver-29.us-east-2.aws.neon.build&var-instance=pageserver-30.us-east-2.aws.neon.build&var-instance=pageserver-31.us-east-2.aws.neon.build&var-instance=pageserver-36.us-east-2.aws.neon.build&var-instance=pageserver-37.us-east-2.aws.neon.build&var-instance=pageserver-38.us-east-2.aws.neon.build&var-instance=pageserver-39.us-east-2.aws.neon.build&var-instance=pageserver-40.us-east-2.aws.neon.build&var-instance=pageserver-41.us-east-2.aws.neon.build&var-request_type=put_object&from=1731961336340&to=1731964762933&viewPanel=3) add a first bucket with 100 microseconds to count requests that do not need to wait on semaphore Update: created a new version that uses a Gauge (one increasing value per PS/shard) instead of histogram as suggested by review --- libs/remote_storage/src/metrics.rs | 4 +++- pageserver/src/metrics.rs | 25 ++++++++++++++++++++++++- pageserver/src/tenant/timeline.rs | 5 ++++- test_runner/fixtures/metrics.py | 1 + 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/libs/remote_storage/src/metrics.rs b/libs/remote_storage/src/metrics.rs index f1aa4c433b..48c121fbc8 100644 --- a/libs/remote_storage/src/metrics.rs +++ b/libs/remote_storage/src/metrics.rs @@ -176,7 +176,9 @@ pub(crate) struct BucketMetrics { impl Default for BucketMetrics { fn default() -> Self { - let buckets = [0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]; + // first bucket 100 microseconds to count requests that do not need to wait at all + // and get a permit immediately + let buckets = [0.0001, 0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]; let req_seconds = register_histogram_vec!( "remote_storage_s3_request_seconds", diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3cdc2a761e..5ce3ae6cf7 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -3,7 +3,7 @@ use metrics::{ register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec, register_int_counter, register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec, - Counter, CounterVec, GaugeVec, 
Histogram, HistogramVec, IntCounter, IntCounterPair, + Counter, CounterVec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, UIntGauge, UIntGaugeVec, }; use once_cell::sync::Lazy; @@ -457,6 +457,15 @@ pub(crate) static WAIT_LSN_TIME: Lazy = Lazy::new(|| { .expect("failed to define a metric") }); +static FLUSH_WAIT_UPLOAD_TIME: Lazy = Lazy::new(|| { + register_gauge_vec!( + "pageserver_flush_wait_upload_seconds", + "Time spent waiting for preceding uploads during layer flush", + &["tenant_id", "shard_id", "timeline_id"] + ) + .expect("failed to define a metric") +}); + static LAST_RECORD_LSN: Lazy = Lazy::new(|| { register_int_gauge_vec!( "pageserver_last_record_lsn", @@ -2336,6 +2345,7 @@ pub(crate) struct TimelineMetrics { shard_id: String, timeline_id: String, pub flush_time_histo: StorageTimeMetrics, + pub flush_wait_upload_time_gauge: Gauge, pub compact_time_histo: StorageTimeMetrics, pub create_images_time_histo: StorageTimeMetrics, pub logical_size_histo: StorageTimeMetrics, @@ -2379,6 +2389,9 @@ impl TimelineMetrics { &shard_id, &timeline_id, ); + let flush_wait_upload_time_gauge = FLUSH_WAIT_UPLOAD_TIME + .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) + .unwrap(); let compact_time_histo = StorageTimeMetrics::new( StorageTimeOperation::Compact, &tenant_id, @@ -2516,6 +2529,7 @@ impl TimelineMetrics { shard_id, timeline_id, flush_time_histo, + flush_wait_upload_time_gauge, compact_time_histo, create_images_time_histo, logical_size_histo, @@ -2563,6 +2577,14 @@ impl TimelineMetrics { self.resident_physical_size_gauge.get() } + pub(crate) fn flush_wait_upload_time_gauge_add(&self, duration: f64) { + self.flush_wait_upload_time_gauge.add(duration); + crate::metrics::FLUSH_WAIT_UPLOAD_TIME + .get_metric_with_label_values(&[&self.tenant_id, &self.shard_id, &self.timeline_id]) + .unwrap() + .add(duration); + } + pub(crate) fn shutdown(&self) { let was_shutdown = self .shutdown @@ -2579,6 +2601,7 @@ impl TimelineMetrics { let timeline_id = &self.timeline_id; let shard_id = &self.shard_id; let _ = LAST_RECORD_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]); + let _ = FLUSH_WAIT_UPLOAD_TIME.remove_label_values(&[tenant_id, shard_id, timeline_id]); let _ = STANDBY_HORIZON.remove_label_values(&[tenant_id, shard_id, timeline_id]); { RESIDENT_PHYSICAL_SIZE_GLOBAL.sub(self.resident_physical_size_get()); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index f6a06e73a7..c1ff0f426d 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -3830,7 +3830,8 @@ impl Timeline { }; // Backpressure mechanism: wait with continuation of the flush loop until we have uploaded all layer files. - // This makes us refuse ingest until the new layers have been persisted to the remote. + // This makes us refuse ingest until the new layers have been persisted to the remote + let start = Instant::now(); self.remote_client .wait_completion() .await @@ -3843,6 +3844,8 @@ impl Timeline { FlushLayerError::Other(anyhow!(e).into()) } })?; + let duration = start.elapsed().as_secs_f64(); + self.metrics.flush_wait_upload_time_gauge_add(duration); // FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`, // a compaction can delete the file and then it won't be available for uploads any more. 
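Distilled from the hunk above, the pattern is an ever-increasing gauge of seconds updated around the awaited upload barrier. A minimal self-contained sketch (assuming the plain `prometheus` and `tokio` crates; `example_flush_wait_upload_seconds` and the helper are illustrative, not the patch's API):

```rust
use std::time::Instant;

use prometheus::{register_gauge, Gauge};

// Time an awaited operation and accumulate the stall into a gauge of seconds.
async fn wait_with_stall_accounting<F: std::future::Future>(
    fut: F,
    stall_seconds: &Gauge,
) -> F::Output {
    let start = Instant::now();
    let out = fut.await; // e.g. waiting for preceding layer-file uploads
    stall_seconds.add(start.elapsed().as_secs_f64());
    out
}

fn make_gauge() -> Gauge {
    register_gauge!(
        "example_flush_wait_upload_seconds",
        "Cumulative seconds spent waiting for uploads (sketch)"
    )
    .expect("failed to register gauge")
}

#[tokio::main]
async fn main() {
    let gauge = make_gauge();
    let sleep = tokio::time::sleep(std::time::Duration::from_millis(10));
    wait_with_stall_accounting(sleep, &gauge).await;
    println!("stalled for {:.4}s so far", gauge.get());
}
```

Because the exported value only ever increases, `rate()` over it in PromQL behaves like a counter rate and should approximate the fraction of wall-clock time the flush loop spends stalled, per shard.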
diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index 330f007a77..3f90c233a6 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -168,6 +168,7 @@ PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = ( "pageserver_evictions_with_low_residence_duration_total", "pageserver_aux_file_estimated_size", "pageserver_valid_lsn_lease_count", + "pageserver_flush_wait_upload_seconds", counter("pageserver_tenant_throttling_count_accounted_start"), counter("pageserver_tenant_throttling_count_accounted_finish"), counter("pageserver_tenant_throttling_wait_usecs_sum"), From 2b788cb53f53606b0e56df540b762e853f7bc41b Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 26 Nov 2024 11:49:37 -0600 Subject: [PATCH 15/34] Bump neon.logical_replication_max_snap_files default to 10000 (#9896) This bump comes from a recommendation from Chi. Signed-off-by: Tristan Partin --- pgxn/neon/logical_replication_monitor.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgxn/neon/logical_replication_monitor.c b/pgxn/neon/logical_replication_monitor.c index 1badbbed21..5eee5a1679 100644 --- a/pgxn/neon/logical_replication_monitor.c +++ b/pgxn/neon/logical_replication_monitor.c @@ -20,7 +20,7 @@ #define LS_MONITOR_CHECK_INTERVAL 10000 /* ms */ -static int logical_replication_max_snap_files = 300; +static int logical_replication_max_snap_files = 10000; /* * According to Chi (shyzh), the pageserver _should_ be good with 10 MB worth of @@ -184,7 +184,7 @@ InitLogicalReplicationMonitor(void) "Maximum allowed logical replication .snap files. When exceeded, slots are dropped until the limit is met. -1 disables the limit.", NULL, &logical_replication_max_snap_files, - 300, -1, INT_MAX, + 10000, -1, INT_MAX, PGC_SIGHUP, 0, NULL, NULL, NULL); From 277c33ba3f47e88f7c032ce90d87b09df0c0e92c Mon Sep 17 00:00:00 2001 From: Peter Bendel Date: Wed, 27 Nov 2024 11:09:01 +0100 Subject: [PATCH 16/34] ingest benchmark: after effective_io_concurrency = 100 we can increase compute side parallelism (#9904) ## Problem The ingest benchmark tests project migration to Neon, involving these steps: - COPY relation data - create indexes - create constraints. Previously we used only 4 copy jobs, 4 create index jobs and 7 maintenance workers.
After increasing effective_io_concurrency on compute we see that we can sustain more parallelism in the ingest bench ## Summary of changes Increase copy jobs to 8, create index jobs to 8 and maintenance workers to 16 --- .../performance/test_perf_ingest_using_pgcopydb.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test_runner/performance/test_perf_ingest_using_pgcopydb.py b/test_runner/performance/test_perf_ingest_using_pgcopydb.py index 2f4574ba88..37f2e9db50 100644 --- a/test_runner/performance/test_perf_ingest_using_pgcopydb.py +++ b/test_runner/performance/test_perf_ingest_using_pgcopydb.py @@ -60,13 +60,13 @@ def build_pgcopydb_command(pgcopydb_filter_file: Path, test_output_dir: Path): "--no-acl", "--skip-db-properties", "--table-jobs", - "4", + "8", "--index-jobs", - "4", + "8", "--restore-jobs", - "4", + "8", "--split-tables-larger-than", - "10GB", + "5GB", "--skip-extensions", "--use-copy-binary", "--filters", @@ -136,7 +136,7 @@ def run_command_and_log_output(command, log_file_path: Path): "LD_LIBRARY_PATH": f"{os.getenv('PGCOPYDB_LIB_PATH')}:{os.getenv('PG_16_LIB_PATH')}", "PGCOPYDB_SOURCE_PGURI": cast(str, os.getenv("BENCHMARK_INGEST_SOURCE_CONNSTR")), "PGCOPYDB_TARGET_PGURI": cast(str, os.getenv("BENCHMARK_INGEST_TARGET_CONNSTR")), - "PGOPTIONS": "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7", + "PGOPTIONS": "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=16", } # Combine the current environment with custom variables env = os.environ.copy() From 7b41ee872eff41f6a0d427e86f6cd3e9563c6fee Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Wed, 27 Nov 2024 10:42:26 +0000 Subject: [PATCH 17/34] CI(pre-merge-checks): build only one build-tools-image (#9718) ## Problem The `pre-merge-checks` workflow relies on the build-tools image. If changes to the `build-tools` image have been merged into the main branch since the last CI run for a PR (with other changes to the `build-tools`), the image will be rebuilt during the merge queue run. Otherwise, cached images are used. Rebuilding the image adds approximately 10 minutes on x86-64 and 20 minutes on arm64 to the process. 
## Summary of changes - parametrise `build-build-tools-image` job with arch and Debian version - Run `pre-merge-checks` only on Debian 12 x86-64 image --- .github/workflows/build-build-tools-image.yml | 73 +++++++++++++------ .github/workflows/pre-merge-checks.yml | 9 ++- 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 93da86a353..0a7f0cd7a0 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -2,6 +2,17 @@ name: Build build-tools image on: workflow_call: + inputs: + archs: + description: "Json array of architectures to build" + # Default values are set in `check-image` job, `set-variables` step + type: string + required: false + debians: + description: "Json array of Debian versions to build" + # Default values are set in `check-image` job, `set-variables` step + type: string + required: false outputs: image-tag: description: "build-tools tag" @@ -32,25 +43,37 @@ jobs: check-image: runs-on: ubuntu-22.04 outputs: - tag: ${{ steps.get-build-tools-tag.outputs.image-tag }} - found: ${{ steps.check-image.outputs.found }} + archs: ${{ steps.set-variables.outputs.archs }} + debians: ${{ steps.set-variables.outputs.debians }} + tag: ${{ steps.set-variables.outputs.image-tag }} + everything: ${{ steps.set-more-variables.outputs.everything }} + found: ${{ steps.set-more-variables.outputs.found }} steps: - uses: actions/checkout@v4 - - name: Get build-tools image tag for the current commit - id: get-build-tools-tag + - name: Set variables + id: set-variables env: + ARCHS: ${{ inputs.archs || '["x64","arm64"]' }} + DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }} IMAGE_TAG: | ${{ hashFiles('build-tools.Dockerfile', '.github/workflows/build-build-tools-image.yml') }} run: | - echo "image-tag=${IMAGE_TAG}" | tee -a $GITHUB_OUTPUT + echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT} + echo "debians=${DEBIANS}" | tee -a ${GITHUB_OUTPUT} + echo "image-tag=${IMAGE_TAG}" | tee -a ${GITHUB_OUTPUT} - - name: Check if such tag found in the registry - id: check-image + - name: Set more variables + id: set-more-variables env: - IMAGE_TAG: ${{ steps.get-build-tools-tag.outputs.image-tag }} + IMAGE_TAG: ${{ steps.set-variables.outputs.image-tag }} + EVERYTHING: | + ${{ contains(fromJson(steps.set-variables.outputs.archs), 'x64') && + contains(fromJson(steps.set-variables.outputs.archs), 'arm64') && + contains(fromJson(steps.set-variables.outputs.debians), 'bullseye') && + contains(fromJson(steps.set-variables.outputs.debians), 'bookworm') }} run: | if docker manifest inspect neondatabase/build-tools:${IMAGE_TAG}; then found=true @@ -58,8 +81,8 @@ jobs: found=false fi - echo "found=${found}" | tee -a $GITHUB_OUTPUT - + echo "everything=${EVERYTHING}" | tee -a ${GITHUB_OUTPUT} + echo "found=${found}" | tee -a ${GITHUB_OUTPUT} build-image: needs: [ check-image ] @@ -67,8 +90,8 @@ jobs: strategy: matrix: - debian-version: [ bullseye, bookworm ] - arch: [ x64, arm64 ] + arch: ${{ fromJson(needs.check-image.outputs.archs) }} + debian: ${{ fromJson(needs.check-image.outputs.debians) }} runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }} @@ -99,11 +122,11 @@ jobs: push: true pull: true build-args: | - DEBIAN_VERSION=${{ matrix.debian-version }} - cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian-version }}-${{ matrix.arch }} - cache-to: ${{ 
github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian-version, matrix.arch) || '' }} + DEBIAN_VERSION=${{ matrix.debian }} + cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian }}-${{ matrix.arch }} + cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian, matrix.arch) || '' }} tags: | - neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian-version }}-${{ matrix.arch }} + neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian }}-${{ matrix.arch }} merge-images: needs: [ check-image, build-image ] @@ -118,15 +141,21 @@ jobs: - name: Create multi-arch image env: DEFAULT_DEBIAN_VERSION: bookworm + ARCHS: ${{ join(fromJson(needs.check-image.outputs.archs), ' ') }} + DEBIANS: ${{ join(fromJson(needs.check-image.outputs.debians), ' ') }} + EVERYTHING: ${{ needs.check-image.outputs.everything }} IMAGE_TAG: ${{ needs.check-image.outputs.tag }} run: | - for debian_version in bullseye bookworm; do - tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}") - if [ "${debian_version}" == "${DEFAULT_DEBIAN_VERSION}" ]; then + for debian in ${DEBIANS}; do + tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian}") + + if [ "${EVERYTHING}" == "true" ] && [ "${debian}" == "${DEFAULT_DEBIAN_VERSION}" ]; then tags+=("-t" "neondatabase/build-tools:${IMAGE_TAG}") fi - docker buildx imagetools create "${tags[@]}" \ - neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64 \ - neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64 + for arch in ${ARCHS}; do + tags+=("neondatabase/build-tools:${IMAGE_TAG}-${debian}-${arch}") + done + + docker buildx imagetools create "${tags[@]}" done diff --git a/.github/workflows/pre-merge-checks.yml b/.github/workflows/pre-merge-checks.yml index e1cec6d33d..d2f9d8a666 100644 --- a/.github/workflows/pre-merge-checks.yml +++ b/.github/workflows/pre-merge-checks.yml @@ -23,6 +23,8 @@ jobs: id: python-src with: files: | + .github/workflows/_check-codestyle-python.yml + .github/workflows/build-build-tools-image.yml .github/workflows/pre-merge-checks.yml **/**.py poetry.lock @@ -38,6 +40,10 @@ jobs: if: needs.get-changed-files.outputs.python-changed == 'true' needs: [ get-changed-files ] uses: ./.github/workflows/build-build-tools-image.yml + with: + # Build only one combination to save time + archs: '["x64"]' + debians: '["bookworm"]' secrets: inherit check-codestyle-python: @@ -45,7 +51,8 @@ jobs: needs: [ get-changed-files, build-build-tools-image ] uses: ./.github/workflows/_check-codestyle-python.yml with: - build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm + # `-bookworm-x64` suffix should match the combination in `build-build-tools-image` + build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm-x64 secrets: inherit # To get items from the merge queue merged into main we need to satisfy "Status checks that are required". From 9e0148de11feefae7402bdc655ff6bf4ace8bc1f Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Wed, 27 Nov 2024 12:12:21 +0000 Subject: [PATCH 18/34] safekeeper: use protobuf for sending compressed records to pageserver (#9821) ## Problem https://github.com/neondatabase/neon/pull/9746 lifted decoding and interpretation of WAL to the safekeeper. 
This reduced the ingested amount on the pageservers by around 10x for a tenant with 8 shards, but doubled the ingested amount for single sharded tenants. Also, https://github.com/neondatabase/neon/pull/9746 uses bincode which doesn't support schema evolution. Technically the schema can be evolved, but it's very cumbersome. ## Summary of changes This patch set addresses both problems by adding protobuf support for the interpreted wal records and adding compression support. Compressed protobuf reduced the ingested amount by 100x on the 32 shards `test_sharded_ingest` case (compared to non-interpreted proto). For the 1 shard case the reduction is 5x. Sister change to `rust-postgres` is [here](https://github.com/neondatabase/rust-postgres/pull/33). ## Links Related: https://github.com/neondatabase/neon/issues/9336 Epic: https://github.com/neondatabase/neon/issues/9329 --- Cargo.lock | 14 +- libs/pageserver_api/src/key.rs | 12 + libs/pq_proto/src/lib.rs | 4 - libs/utils/src/postgres_client.rs | 54 ++- libs/wal_decoder/Cargo.toml | 8 + libs/wal_decoder/build.rs | 11 + libs/wal_decoder/proto/interpreted_wal.proto | 43 +++ libs/wal_decoder/src/lib.rs | 1 + libs/wal_decoder/src/models.rs | 20 + libs/wal_decoder/src/wire_format.rs | 356 ++++++++++++++++++ .../walreceiver/connection_manager.rs | 4 +- .../walreceiver/walreceiver_connection.rs | 52 ++- safekeeper/src/handler.rs | 17 +- safekeeper/src/send_interpreted_wal.rs | 53 ++- safekeeper/src/send_wal.rs | 9 +- test_runner/fixtures/neon_fixtures.py | 36 ++ .../performance/test_sharded_ingest.py | 47 ++- test_runner/regress/test_compaction.py | 16 +- test_runner/regress/test_crafted_wal_end.py | 15 +- test_runner/regress/test_subxacts.py | 15 +- .../regress/test_wal_acceptor_async.py | 21 +- 21 files changed, 702 insertions(+), 106 deletions(-) create mode 100644 libs/wal_decoder/build.rs create mode 100644 libs/wal_decoder/proto/interpreted_wal.proto create mode 100644 libs/wal_decoder/src/wire_format.rs diff --git a/Cargo.lock b/Cargo.lock index c1a14210de..43a46fb1eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4133,7 +4133,7 @@ dependencies = [ [[package]] name = "postgres" version = "0.19.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "bytes", "fallible-iterator", @@ -4146,7 +4146,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "base64 0.20.0", "byteorder", @@ -4165,7 +4165,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "bytes", "fallible-iterator", @@ -6468,7 +6468,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" 
dependencies = [ "async-trait", "byteorder", @@ -7120,10 +7120,16 @@ name = "wal_decoder" version = "0.1.0" dependencies = [ "anyhow", + "async-compression", "bytes", "pageserver_api", "postgres_ffi", + "prost", "serde", + "thiserror", + "tokio", + "tonic", + "tonic-build", "tracing", "utils", "workspace_hack", diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 4505101ea6..523d143381 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -229,6 +229,18 @@ impl Key { } } +impl CompactKey { + pub fn raw(&self) -> i128 { + self.0 + } +} + +impl From for CompactKey { + fn from(value: i128) -> Self { + Self(value) + } +} + impl fmt::Display for Key { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index b7871ab01f..4b0331999d 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -688,9 +688,6 @@ pub struct InterpretedWalRecordsBody<'a> { pub streaming_lsn: u64, /// Current end of WAL on the server pub commit_lsn: u64, - /// Start LSN of the next record in PG WAL. - /// Is 0 if the portion of PG WAL did not contain any records. - pub next_record_lsn: u64, pub data: &'a [u8], } @@ -1028,7 +1025,6 @@ impl BeMessage<'_> { // dependency buf.put_u64(rec.streaming_lsn); buf.put_u64(rec.commit_lsn); - buf.put_u64(rec.next_record_lsn); buf.put_slice(rec.data); }); } diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index 3073bbde4c..a62568202b 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -7,40 +7,31 @@ use postgres_connection::{parse_host_port, PgConnectionConfig}; use crate::id::TenantTimelineId; -/// Postgres client protocol types -#[derive( - Copy, - Clone, - PartialEq, - Eq, - strum_macros::EnumString, - strum_macros::Display, - serde_with::DeserializeFromStr, - serde_with::SerializeDisplay, - Debug, -)] -#[strum(serialize_all = "kebab-case")] -#[repr(u8)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum InterpretedFormat { + Bincode, + Protobuf, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum Compression { + Zstd { level: i8 }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "args")] +#[serde(rename_all = "kebab-case")] pub enum PostgresClientProtocol { /// Usual Postgres replication protocol Vanilla, /// Custom shard-aware protocol that replicates interpreted records. /// Used to send wal from safekeeper to pageserver. 
- Interpreted, -} - -impl TryFrom for PostgresClientProtocol { - type Error = u8; - - fn try_from(value: u8) -> Result { - Ok(match value { - v if v == (PostgresClientProtocol::Vanilla as u8) => PostgresClientProtocol::Vanilla, - v if v == (PostgresClientProtocol::Interpreted as u8) => { - PostgresClientProtocol::Interpreted - } - x => return Err(x), - }) - } + Interpreted { + format: InterpretedFormat, + compression: Option, + }, } pub struct ConnectionConfigArgs<'a> { @@ -63,7 +54,10 @@ impl<'a> ConnectionConfigArgs<'a> { "-c".to_owned(), format!("timeline_id={}", self.ttid.timeline_id), format!("tenant_id={}", self.ttid.tenant_id), - format!("protocol={}", self.protocol as u8), + format!( + "protocol={}", + serde_json::to_string(&self.protocol).unwrap() + ), ]; if self.shard_number.is_some() { diff --git a/libs/wal_decoder/Cargo.toml b/libs/wal_decoder/Cargo.toml index c8c0f4c990..8fac4e38ca 100644 --- a/libs/wal_decoder/Cargo.toml +++ b/libs/wal_decoder/Cargo.toml @@ -8,11 +8,19 @@ license.workspace = true testing = ["pageserver_api/testing"] [dependencies] +async-compression.workspace = true anyhow.workspace = true bytes.workspace = true pageserver_api.workspace = true +prost.workspace = true postgres_ffi.workspace = true serde.workspace = true +thiserror.workspace = true +tokio = { workspace = true, features = ["io-util"] } +tonic.workspace = true tracing.workspace = true utils.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } + +[build-dependencies] +tonic-build.workspace = true diff --git a/libs/wal_decoder/build.rs b/libs/wal_decoder/build.rs new file mode 100644 index 0000000000..d5b7ad02ad --- /dev/null +++ b/libs/wal_decoder/build.rs @@ -0,0 +1,11 @@ +fn main() -> Result<(), Box> { + // Generate rust code from .proto protobuf. + // + // Note: we previously tried to use deterministic location at proto/ for + // easy location, but apparently interference with cachepot sometimes fails + // the build then. Anyway, per cargo docs build script shouldn't output to + // anywhere but $OUT_DIR. 
+ tonic_build::compile_protos("proto/interpreted_wal.proto") + .unwrap_or_else(|e| panic!("failed to compile protos {:?}", e)); + Ok(()) +} diff --git a/libs/wal_decoder/proto/interpreted_wal.proto b/libs/wal_decoder/proto/interpreted_wal.proto new file mode 100644 index 0000000000..0393392c1a --- /dev/null +++ b/libs/wal_decoder/proto/interpreted_wal.proto @@ -0,0 +1,43 @@ +syntax = "proto3"; + +package interpreted_wal; + +message InterpretedWalRecords { + repeated InterpretedWalRecord records = 1; + optional uint64 next_record_lsn = 2; +} + +message InterpretedWalRecord { + optional bytes metadata_record = 1; + SerializedValueBatch batch = 2; + uint64 next_record_lsn = 3; + bool flush_uncommitted = 4; + uint32 xid = 5; +} + +message SerializedValueBatch { + bytes raw = 1; + repeated ValueMeta metadata = 2; + uint64 max_lsn = 3; + uint64 len = 4; +} + +enum ValueMetaType { + Serialized = 0; + Observed = 1; +} + +message ValueMeta { + ValueMetaType type = 1; + CompactKey key = 2; + uint64 lsn = 3; + optional uint64 batch_offset = 4; + optional uint64 len = 5; + optional bool will_init = 6; +} + +message CompactKey { + int64 high = 1; + int64 low = 2; +} + diff --git a/libs/wal_decoder/src/lib.rs b/libs/wal_decoder/src/lib.rs index a8a26956e6..96b717021f 100644 --- a/libs/wal_decoder/src/lib.rs +++ b/libs/wal_decoder/src/lib.rs @@ -1,3 +1,4 @@ pub mod decoder; pub mod models; pub mod serialized_batch; +pub mod wire_format; diff --git a/libs/wal_decoder/src/models.rs b/libs/wal_decoder/src/models.rs index 7ac425cb5f..af22de5d95 100644 --- a/libs/wal_decoder/src/models.rs +++ b/libs/wal_decoder/src/models.rs @@ -37,12 +37,32 @@ use utils::lsn::Lsn; use crate::serialized_batch::SerializedValueBatch; +// Code generated by protobuf. +pub mod proto { + // Tonic does derives as `#[derive(Clone, PartialEq, ::prost::Message)]` + // we don't use these types for anything but broker data transmission, + // so it's ok to ignore this one. + #![allow(clippy::derive_partial_eq_without_eq)] + // The generated ValueMeta has a `len` method generate for its `len` field. + #![allow(clippy::len_without_is_empty)] + tonic::include_proto!("interpreted_wal"); +} + #[derive(Serialize, Deserialize)] pub enum FlushUncommittedRecords { Yes, No, } +/// A batch of interpreted WAL records +#[derive(Serialize, Deserialize)] +pub struct InterpretedWalRecords { + pub records: Vec, + // Start LSN of the next record after the batch. + // Note that said record may not belong to the current shard. 
+ pub next_record_lsn: Option, +} + /// An interpreted Postgres WAL record, ready to be handled by the pageserver #[derive(Serialize, Deserialize)] pub struct InterpretedWalRecord { diff --git a/libs/wal_decoder/src/wire_format.rs b/libs/wal_decoder/src/wire_format.rs new file mode 100644 index 0000000000..5a343054c3 --- /dev/null +++ b/libs/wal_decoder/src/wire_format.rs @@ -0,0 +1,356 @@ +use bytes::{BufMut, Bytes, BytesMut}; +use pageserver_api::key::CompactKey; +use prost::{DecodeError, EncodeError, Message}; +use tokio::io::AsyncWriteExt; +use utils::bin_ser::{BeSer, DeserializeError, SerializeError}; +use utils::lsn::Lsn; +use utils::postgres_client::{Compression, InterpretedFormat}; + +use crate::models::{ + FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord, +}; + +use crate::serialized_batch::{ + ObservedValueMeta, SerializedValueBatch, SerializedValueMeta, ValueMeta, +}; + +use crate::models::proto; + +#[derive(Debug, thiserror::Error)] +pub enum ToWireFormatError { + #[error("{0}")] + Bincode(#[from] SerializeError), + #[error("{0}")] + Protobuf(#[from] ProtobufSerializeError), + #[error("{0}")] + Compression(#[from] std::io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtobufSerializeError { + #[error("{0}")] + MetadataRecord(#[from] SerializeError), + #[error("{0}")] + Encode(#[from] EncodeError), +} + +#[derive(Debug, thiserror::Error)] +pub enum FromWireFormatError { + #[error("{0}")] + Bincode(#[from] DeserializeError), + #[error("{0}")] + Protobuf(#[from] ProtobufDeserializeError), + #[error("{0}")] + Decompress(#[from] std::io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtobufDeserializeError { + #[error("{0}")] + Transcode(#[from] TranscodeError), + #[error("{0}")] + Decode(#[from] DecodeError), +} + +#[derive(Debug, thiserror::Error)] +pub enum TranscodeError { + #[error("{0}")] + BadInput(String), + #[error("{0}")] + MetadataRecord(#[from] DeserializeError), +} + +pub trait ToWireFormat { + fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; +} + +pub trait FromWireFormat { + type T; + fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; +} + +impl ToWireFormat for InterpretedWalRecords { + async fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> Result { + use async_compression::tokio::write::ZstdEncoder; + use async_compression::Level; + + let encode_res: Result = match format { + InterpretedFormat::Bincode => { + let buf = BytesMut::new(); + let mut buf = buf.writer(); + self.ser_into(&mut buf)?; + Ok(buf.into_inner().freeze()) + } + InterpretedFormat::Protobuf => { + let proto: proto::InterpretedWalRecords = self.try_into()?; + let mut buf = BytesMut::new(); + proto + .encode(&mut buf) + .map_err(|e| ToWireFormatError::Protobuf(e.into()))?; + + Ok(buf.freeze()) + } + }; + + let buf = encode_res?; + let compressed_buf = match compression { + Some(Compression::Zstd { level }) => { + let mut encoder = ZstdEncoder::with_quality( + Vec::with_capacity(buf.len() / 4), + Level::Precise(level as i32), + ); + encoder.write_all(&buf).await?; + encoder.shutdown().await?; + Bytes::from(encoder.into_inner()) + } + None => buf, + }; + + Ok(compressed_buf) + } +} + +impl FromWireFormat for InterpretedWalRecords { + type T = Self; + + async fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> Result { + let decompressed_buf = 
match compression { + Some(Compression::Zstd { .. }) => { + use async_compression::tokio::write::ZstdDecoder; + let mut decoded_buf = Vec::with_capacity(buf.len()); + let mut decoder = ZstdDecoder::new(&mut decoded_buf); + decoder.write_all(buf).await?; + decoder.flush().await?; + Bytes::from(decoded_buf) + } + None => buf.clone(), + }; + + match format { + InterpretedFormat::Bincode => { + InterpretedWalRecords::des(&decompressed_buf).map_err(FromWireFormatError::Bincode) + } + InterpretedFormat::Protobuf => { + let proto = proto::InterpretedWalRecords::decode(decompressed_buf) + .map_err(|e| FromWireFormatError::Protobuf(e.into()))?; + InterpretedWalRecords::try_from(proto) + .map_err(|e| FromWireFormatError::Protobuf(e.into())) + } + } + } +} + +impl TryFrom for proto::InterpretedWalRecords { + type Error = SerializeError; + + fn try_from(value: InterpretedWalRecords) -> Result { + let records = value + .records + .into_iter() + .map(proto::InterpretedWalRecord::try_from) + .collect::, _>>()?; + Ok(proto::InterpretedWalRecords { + records, + next_record_lsn: value.next_record_lsn.map(|l| l.0), + }) + } +} + +impl TryFrom for proto::InterpretedWalRecord { + type Error = SerializeError; + + fn try_from(value: InterpretedWalRecord) -> Result { + let metadata_record = value + .metadata_record + .map(|meta_rec| -> Result, Self::Error> { + let mut buf = Vec::new(); + meta_rec.ser_into(&mut buf)?; + Ok(buf) + }) + .transpose()?; + + Ok(proto::InterpretedWalRecord { + metadata_record, + batch: Some(proto::SerializedValueBatch::from(value.batch)), + next_record_lsn: value.next_record_lsn.0, + flush_uncommitted: matches!(value.flush_uncommitted, FlushUncommittedRecords::Yes), + xid: value.xid, + }) + } +} + +impl From for proto::SerializedValueBatch { + fn from(value: SerializedValueBatch) -> Self { + proto::SerializedValueBatch { + raw: value.raw, + metadata: value + .metadata + .into_iter() + .map(proto::ValueMeta::from) + .collect(), + max_lsn: value.max_lsn.0, + len: value.len as u64, + } + } +} + +impl From for proto::ValueMeta { + fn from(value: ValueMeta) -> Self { + match value { + ValueMeta::Observed(obs) => proto::ValueMeta { + r#type: proto::ValueMetaType::Observed.into(), + key: Some(proto::CompactKey::from(obs.key)), + lsn: obs.lsn.0, + batch_offset: None, + len: None, + will_init: None, + }, + ValueMeta::Serialized(ser) => proto::ValueMeta { + r#type: proto::ValueMetaType::Serialized.into(), + key: Some(proto::CompactKey::from(ser.key)), + lsn: ser.lsn.0, + batch_offset: Some(ser.batch_offset), + len: Some(ser.len as u64), + will_init: Some(ser.will_init), + }, + } + } +} + +impl From for proto::CompactKey { + fn from(value: CompactKey) -> Self { + proto::CompactKey { + high: (value.raw() >> 64) as i64, + low: value.raw() as i64, + } + } +} + +impl TryFrom for InterpretedWalRecords { + type Error = TranscodeError; + + fn try_from(value: proto::InterpretedWalRecords) -> Result { + let records = value + .records + .into_iter() + .map(InterpretedWalRecord::try_from) + .collect::>()?; + + Ok(InterpretedWalRecords { + records, + next_record_lsn: value.next_record_lsn.map(Lsn::from), + }) + } +} + +impl TryFrom for InterpretedWalRecord { + type Error = TranscodeError; + + fn try_from(value: proto::InterpretedWalRecord) -> Result { + let metadata_record = value + .metadata_record + .map(|mrec| -> Result<_, DeserializeError> { MetadataRecord::des(&mrec) }) + .transpose()?; + + let batch = { + let batch = value.batch.ok_or_else(|| { + TranscodeError::BadInput("InterpretedWalRecord::batch 
missing".to_string()) + })?; + + SerializedValueBatch::try_from(batch)? + }; + + Ok(InterpretedWalRecord { + metadata_record, + batch, + next_record_lsn: Lsn(value.next_record_lsn), + flush_uncommitted: if value.flush_uncommitted { + FlushUncommittedRecords::Yes + } else { + FlushUncommittedRecords::No + }, + xid: value.xid, + }) + } +} + +impl TryFrom for SerializedValueBatch { + type Error = TranscodeError; + + fn try_from(value: proto::SerializedValueBatch) -> Result { + let metadata = value + .metadata + .into_iter() + .map(ValueMeta::try_from) + .collect::, _>>()?; + + Ok(SerializedValueBatch { + raw: value.raw, + metadata, + max_lsn: Lsn(value.max_lsn), + len: value.len as usize, + }) + } +} + +impl TryFrom for ValueMeta { + type Error = TranscodeError; + + fn try_from(value: proto::ValueMeta) -> Result { + match proto::ValueMetaType::try_from(value.r#type) { + Ok(proto::ValueMetaType::Serialized) => { + Ok(ValueMeta::Serialized(SerializedValueMeta { + key: value + .key + .ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::key missing".to_string()) + })? + .into(), + lsn: Lsn(value.lsn), + batch_offset: value.batch_offset.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::batch_offset missing".to_string()) + })?, + len: value.len.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::len missing".to_string()) + })? as usize, + will_init: value.will_init.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::will_init missing".to_string()) + })?, + })) + } + Ok(proto::ValueMetaType::Observed) => Ok(ValueMeta::Observed(ObservedValueMeta { + key: value + .key + .ok_or_else(|| TranscodeError::BadInput("ValueMeta::key missing".to_string()))? + .into(), + lsn: Lsn(value.lsn), + })), + Err(_) => Err(TranscodeError::BadInput(format!( + "Unexpected ValueMeta::type {}", + value.r#type + ))), + } + } +} + +impl From for CompactKey { + fn from(value: proto::CompactKey) -> Self { + (((value.high as i128) << 64) | (value.low as i128)).into() + } +} diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index 7a64703a30..583d6309ab 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -535,6 +535,7 @@ impl ConnectionManagerState { let node_id = new_sk.safekeeper_id; let connect_timeout = self.conf.wal_connect_timeout; let ingest_batch_size = self.conf.ingest_batch_size; + let protocol = self.conf.protocol; let timeline = Arc::clone(&self.timeline); let ctx = ctx.detached_child( TaskKind::WalReceiverConnectionHandler, @@ -548,6 +549,7 @@ impl ConnectionManagerState { let res = super::walreceiver_connection::handle_walreceiver_connection( timeline, + protocol, new_sk.wal_source_connconf, events_sender, cancellation.clone(), @@ -991,7 +993,7 @@ impl ConnectionManagerState { PostgresClientProtocol::Vanilla => { (None, None, None) }, - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { .. 
} => { let shard_identity = self.timeline.get_shard_identity(); ( Some(shard_identity.number.0), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 1a0e66ceb3..31cf1b6307 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time}; use tokio_postgres::{replication::ReplicationStream, Client}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn, Instrument}; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord}; +use wal_decoder::{ + models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords}, + wire_format::FromWireFormat, +}; use super::TaskStateUpdate; use crate::{ @@ -36,7 +39,7 @@ use crate::{ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; -use utils::{bin_ser::BeSer, id::NodeId, lsn::Lsn}; +use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol}; use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError}; /// Status of the connection. @@ -109,6 +112,7 @@ impl From for WalReceiverError { #[allow(clippy::too_many_arguments)] pub(super) async fn handle_walreceiver_connection( timeline: Arc, + protocol: PostgresClientProtocol, wal_source_connconf: PgConnectionConfig, events_sender: watch::Sender>, cancellation: CancellationToken, @@ -260,6 +264,14 @@ pub(super) async fn handle_walreceiver_connection( let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?; + let interpreted_proto_config = match protocol { + PostgresClientProtocol::Vanilla => None, + PostgresClientProtocol::Interpreted { + format, + compression, + } => Some((format, compression)), + }; + while let Some(replication_message) = { select! { _ = cancellation.cancelled() => { @@ -332,16 +344,26 @@ pub(super) async fn handle_walreceiver_connection( // This is the end LSN of the raw WAL from which the records // were interpreted. let streaming_lsn = Lsn::from(raw.streaming_lsn()); - tracing::debug!( - "Received WAL up to {streaming_lsn} with next_record_lsn={}", - Lsn(raw.next_record_lsn().unwrap_or(0)) - ); - let records = Vec::::des(raw.data()).with_context(|| { - anyhow::anyhow!( + let (format, compression) = interpreted_proto_config.unwrap(); + let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression) + .await + .with_context(|| { + anyhow::anyhow!( "Failed to deserialize interpreted records ending at LSN {streaming_lsn}" ) - })?; + })?; + + let InterpretedWalRecords { + records, + next_record_lsn, + } = batch; + + tracing::debug!( + "Received WAL up to {} with next_record_lsn={:?}", + streaming_lsn, + next_record_lsn + ); // We start the modification at 0 because each interpreted record // advances it to its end LSN. 0 is just an initialization placeholder. 
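// Illustrative sketch (not part of this patch): under the interpreted
// protocol, the receive path in this file now boils down to
//
//     let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression).await?;
//     for interpreted in batch.records {
//         walingest.ingest_record(interpreted, &mut modification, &ctx).await?;
//     }
//
// after which `batch.next_record_lsn` is used to advance the shard's
// last-record LSN, even when every record in the range was filtered out
// for this shard.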
@@ -360,14 +382,18 @@ pub(super) async fn handle_walreceiver_connection( .await?; } - let next_record_lsn = interpreted.next_record_lsn; + let local_next_record_lsn = interpreted.next_record_lsn; let ingested = walingest .ingest_record(interpreted, &mut modification, &ctx) .await - .with_context(|| format!("could not ingest record at {next_record_lsn}"))?; + .with_context(|| { + format!("could not ingest record at {local_next_record_lsn}") + })?; if !ingested { - tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}"); + tracing::debug!( + "ingest: filtered out record @ LSN {local_next_record_lsn}" + ); WAL_INGEST.records_filtered.inc(); filtered_records += 1; } @@ -399,7 +425,7 @@ pub(super) async fn handle_walreceiver_connection( // need to advance last record LSN on all shards. If we've not ingested the latest // record, then set the LSN of the modification past it. This way all shards // advance their last record LSN at the same time. - let needs_last_record_lsn_advance = match raw.next_record_lsn().map(Lsn::from) { + let needs_last_record_lsn_advance = match next_record_lsn.map(Lsn::from) { Some(lsn) if lsn > modification.get_lsn() => { modification.set_lsn(lsn).unwrap(); true diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index cec7c3c7ee..22f33b17e0 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -123,17 +123,10 @@ impl postgres_backend::Handler // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064 match opt.split_once('=') { Some(("protocol", value)) => { - let raw_value = value - .parse::() - .with_context(|| format!("Failed to parse {value} as protocol"))?; - - self.protocol = Some( - PostgresClientProtocol::try_from(raw_value).map_err(|_| { - QueryError::Other(anyhow::anyhow!( - "Unexpected client protocol type: {raw_value}" - )) - })?, - ); + self.protocol = + Some(serde_json::from_str(value).with_context(|| { + format!("Failed to parse {value} as protocol") + })?); } Some(("ztenantid", value)) | Some(("tenant_id", value)) => { self.tenant_id = Some(value.parse().with_context(|| { @@ -180,7 +173,7 @@ impl postgres_backend::Handler ))); } } - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { .. } => { match (shard_count, shard_number, shard_stripe_size) { (Some(count), Some(number), Some(stripe_size)) => { let params = ShardParameters { diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index cf0ee276e9..2589030422 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -9,9 +9,11 @@ use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::MissedTickBehavior; -use utils::bin_ser::BeSer; use utils::lsn::Lsn; -use wal_decoder::models::InterpretedWalRecord; +use utils::postgres_client::Compression; +use utils::postgres_client::InterpretedFormat; +use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords}; +use wal_decoder::wire_format::ToWireFormat; use crate::send_wal::EndWatchView; use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; @@ -20,6 +22,8 @@ use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; /// This is used for sending WAL to the pageserver. Said WAL /// is pre-interpreted and filtered for the shard. 
pub(crate) struct InterpretedWalSender<'a, IO> { + pub(crate) format: InterpretedFormat, + pub(crate) compression: Option, pub(crate) pgb: &'a mut PostgresBackend, pub(crate) wal_stream_builder: WalReaderStreamBuilder, pub(crate) end_watch_view: EndWatchView, @@ -28,6 +32,12 @@ pub(crate) struct InterpretedWalSender<'a, IO> { pub(crate) appname: Option, } +struct Batch { + wal_end_lsn: Lsn, + available_wal_end_lsn: Lsn, + records: InterpretedWalRecords, +} + impl InterpretedWalSender<'_, IO> { /// Send interpreted WAL to a receiver. /// Stops when an error occurs or the receiver is caught up and there's no active compute. @@ -46,10 +56,13 @@ impl InterpretedWalSender<'_, IO> { keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); keepalive_ticker.reset(); + let (tx, mut rx) = tokio::sync::mpsc::channel::(2); + loop { tokio::select! { - // Get some WAL from the stream and then: decode, interpret and send it - wal = stream.next() => { + // Get some WAL from the stream and then: decode, interpret and push it down the + // pipeline. + wal = stream.next(), if tx.capacity() > 0 => { let WalBytes { wal, wal_start_lsn: _, wal_end_lsn, available_wal_end_lsn } = match wal { Some(some) => some?, None => { break; } @@ -81,10 +94,26 @@ impl InterpretedWalSender<'_, IO> { } } - let mut buf = Vec::new(); - records - .ser_into(&mut buf) - .with_context(|| "Failed to serialize interpreted WAL")?; + let batch = InterpretedWalRecords { + records, + next_record_lsn: max_next_record_lsn + }; + + tx.send(Batch {wal_end_lsn, available_wal_end_lsn, records: batch}).await.unwrap(); + }, + // For a previously interpreted batch, serialize it and push it down the wire. + batch = rx.recv() => { + let batch = match batch { + Some(b) => b, + None => { break; } + }; + + let buf = batch + .records + .to_wire(self.format, self.compression) + .await + .with_context(|| "Failed to serialize interpreted WAL") + .map_err(CopyStreamHandlerEnd::from)?; // Reset the keep alive ticker since we are sending something // over the wire now. @@ -92,13 +121,11 @@ impl InterpretedWalSender<'_, IO> { self.pgb .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody { - streaming_lsn: wal_end_lsn.0, - commit_lsn: available_wal_end_lsn.0, - next_record_lsn: max_next_record_lsn.unwrap_or(Lsn::INVALID).0, - data: buf.as_slice(), + streaming_lsn: batch.wal_end_lsn.0, + commit_lsn: batch.available_wal_end_lsn.0, + data: &buf, })).await?; } - // Send a periodic keep alive when the connection has been idle for a while. 
_ = keepalive_ticker.tick() => { self.pgb diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 1acfcad418..225b7f4c05 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -454,7 +454,7 @@ impl SafekeeperPostgresHandler { } info!( - "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={}", + "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={:?}", start_pos, end_pos, matches!(end_watch, EndWatch::Flush(_)), @@ -489,7 +489,10 @@ impl SafekeeperPostgresHandler { Either::Left(sender.run()) } - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { + format, + compression, + } => { let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; let end_watch_view = end_watch.view(); let wal_stream_builder = WalReaderStreamBuilder { @@ -502,6 +505,8 @@ impl SafekeeperPostgresHandler { }; let sender = InterpretedWalSender { + format, + compression, pgb, wal_stream_builder, end_watch_view, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 07d442b4a6..a45a311dc2 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -310,6 +310,31 @@ class PgProtocol: return self.safe_psql(query, log_query=log_query)[0][0] +class PageserverWalReceiverProtocol(StrEnum): + VANILLA = "vanilla" + INTERPRETED = "interpreted" + + @staticmethod + def to_config_key_value(proto) -> tuple[str, dict[str, Any]]: + if proto == PageserverWalReceiverProtocol.VANILLA: + return ( + "wal_receiver_protocol", + { + "type": "vanilla", + }, + ) + elif proto == PageserverWalReceiverProtocol.INTERPRETED: + return ( + "wal_receiver_protocol", + { + "type": "interpreted", + "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, + }, + ) + else: + raise ValueError(f"Unknown protocol type: {proto}") + + class NeonEnvBuilder: """ Builder object to create a Neon runtime environment @@ -356,6 +381,7 @@ class NeonEnvBuilder: safekeeper_extra_opts: list[str] | None = None, storage_controller_port_override: int | None = None, pageserver_virtual_file_io_mode: str | None = None, + pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -409,6 +435,8 @@ class NeonEnvBuilder: self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode + self.pageserver_wal_receiver_protocol = pageserver_wal_receiver_protocol + assert test_name.startswith( "test_" ), "Unexpectedly instantiated from outside a test function" @@ -1023,6 +1051,7 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode + self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol # Create the neon_local's `NeonLocalInitConf` cfg: dict[str, Any] = { @@ -1092,6 +1121,13 @@ class NeonEnv: if self.pageserver_virtual_file_io_mode is not None: ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode + if self.pageserver_wal_receiver_protocol is not None: + key, value = PageserverWalReceiverProtocol.to_config_key_value( + self.pageserver_wal_receiver_protocol + ) + if key not in ps_cfg: + ps_cfg[key] = value + # Create a corresponding NeonPageserver object self.pageservers.append( NeonPageserver(self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"]) diff --git 
a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index e965aae5a0..4c21e799c8 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -15,7 +15,14 @@ from fixtures.neon_fixtures import ( @pytest.mark.timeout(600) @pytest.mark.parametrize("shard_count", [1, 8, 32]) -@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +@pytest.mark.parametrize( + "wal_receiver_protocol", + [ + "vanilla", + "interpreted-bincode-compressed", + "interpreted-protobuf-compressed", + ], +) def test_sharded_ingest( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, @@ -27,14 +34,42 @@ def test_sharded_ingest( and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case (shard_count=1) to the sharded case indicates the overhead of sharding. """ - neon_env_builder.pageserver_config_override = ( - f"wal_receiver_protocol = '{wal_receiver_protocol}'" - ) - ROW_COUNT = 100_000_000 # about 7 GB of WAL neon_env_builder.num_pageservers = shard_count - env = neon_env_builder.init_start() + env = neon_env_builder.init_configs() + + for ps in env.pageservers: + if wal_receiver_protocol == "vanilla": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "vanilla", + } + } + ) + elif wal_receiver_protocol == "interpreted-bincode-compressed": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "interpreted", + "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, + } + } + ) + elif wal_receiver_protocol == "interpreted-protobuf-compressed": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "interpreted", + "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, + } + } + ) + else: + raise AssertionError("Test must use explicit wal receiver protocol config") + + env.start() # Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure # the storage controller doesn't mess with shard placements. diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 79fd256304..302a8fd0d1 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -8,6 +8,7 @@ import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, + PageserverWalReceiverProtocol, generate_uploads_and_deletions, ) from fixtures.pageserver.http import PageserverApiException @@ -27,8 +28,13 @@ AGGRESIVE_COMPACTION_TENANT_CONF = { @skip_in_debug_build("only run with release build") -@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) -def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: str): +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_pageserver_compaction_smoke( + neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol +): """ This is a smoke test that compaction kicks in. The workload repeatedly churns a small number of rows and manually instructs the pageserver to run compaction @@ -37,10 +43,12 @@ def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder, wal_recei observed bounds. 
""" + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol + # Effectively disable the page cache to rely only on image layers # to shorten reads. - neon_env_builder.pageserver_config_override = f""" -page_cache_size=10; wal_receiver_protocol='{wal_receiver_protocol}' + neon_env_builder.pageserver_config_override = """ +page_cache_size=10 """ env = neon_env_builder.init_start(initial_tenant_conf=AGGRESIVE_COMPACTION_TENANT_CONF) diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 70e71d99cd..6b9dcbba07 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -3,7 +3,7 @@ from __future__ import annotations import pytest from fixtures.log_helper import log from fixtures.neon_cli import WalCraft -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol # Restart nodes with WAL end having specially crafted shape, like last record # crossing segment boundary, to test decoding issues. @@ -19,13 +19,16 @@ from fixtures.neon_fixtures import NeonEnvBuilder "wal_record_crossing_segment_followed_by_small_one", ], ) -@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) def test_crafted_wal_end( - neon_env_builder: NeonEnvBuilder, wal_type: str, wal_receiver_protocol: str + neon_env_builder: NeonEnvBuilder, + wal_type: str, + wal_receiver_protocol: PageserverWalReceiverProtocol, ): - neon_env_builder.pageserver_config_override = ( - f"wal_receiver_protocol = '{wal_receiver_protocol}'" - ) + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() env.create_branch("test_crafted_wal_end") diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index 1d86c353be..b235da0bc7 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,7 +1,11 @@ from __future__ import annotations import pytest -from fixtures.neon_fixtures import NeonEnvBuilder, check_restored_datadir_content +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + PageserverWalReceiverProtocol, + check_restored_datadir_content, +) # Test subtransactions @@ -10,11 +14,12 @@ from fixtures.neon_fixtures import NeonEnvBuilder, check_restored_datadir_conten # maintained in the pageserver, so subtransactions are not very exciting for # Neon. They are included in the commit record though and updated in the # CLOG. 
-@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir, wal_receiver_protocol): - neon_env_builder.pageserver_config_override = ( - f"wal_receiver_protocol = '{wal_receiver_protocol}'" - ) + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index 094b10b576..b32b028fa1 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -11,7 +11,13 @@ import pytest import toml from fixtures.common_types import Lsn, TenantId, TimelineId from fixtures.log_helper import getLogger -from fixtures.neon_fixtures import Endpoint, NeonEnv, NeonEnvBuilder, Safekeeper +from fixtures.neon_fixtures import ( + Endpoint, + NeonEnv, + NeonEnvBuilder, + PageserverWalReceiverProtocol, + Safekeeper, +) from fixtures.remote_storage import RemoteStorageKind from fixtures.utils import skip_in_debug_build @@ -622,12 +628,15 @@ async def run_segment_init_failure(env: NeonEnv): # Test (injected) failure during WAL segment init. # https://github.com/neondatabase/neon/issues/6401 # https://github.com/neondatabase/neon/issues/6402 -@pytest.mark.parametrize("wal_receiver_protocol", ["vanilla", "interpreted"]) -def test_segment_init_failure(neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: str): +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_segment_init_failure( + neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol +): neon_env_builder.num_safekeepers = 1 - neon_env_builder.pageserver_config_override = ( - f"wal_receiver_protocol = '{wal_receiver_protocol}'" - ) + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() asyncio.run(run_segment_init_failure(env)) From 8fdf786217170192d383211f6e3fe0283ce5036d Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Wed, 27 Nov 2024 13:46:23 +0000 Subject: [PATCH 19/34] pageserver: add tenant config override for wal receiver proto (#9888) ## Problem Can't change protocol at tenant granularity. ## Summary of changes Add tenant config level override for wal receiver protocol. 
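
The override value is a JSON document with a tagged-enum shape (see the `test_fully_custom_config` change below). As a rough sketch of how such a value parses — assuming serde's adjacently tagged representation; the real type is `utils::postgres_client::PostgresClientProtocol` and its exact serde attributes may differ:

```rust
use serde::Deserialize;

// Illustrative only: a serde enum shaped like the accepted override values.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", content = "args", rename_all = "lowercase")]
enum WalReceiverProtocol {
    Vanilla,
    Interpreted {
        format: String,
        compression: Option<serde_json::Value>,
    },
}

fn main() -> anyhow::Result<()> {
    let raw = r#"{"type": "interpreted", "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}}"#;
    let proto: WalReceiverProtocol = serde_json::from_str(raw)?;
    println!("{proto:?}");
    Ok(())
}
```
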
## Links Related: https://github.com/neondatabase/neon/issues/9336 Epic: https://github.com/neondatabase/neon/issues/9329 --- control_plane/src/pageserver.rs | 5 +++++ libs/pageserver_api/src/config.rs | 3 +++ libs/pageserver_api/src/models.rs | 2 ++ pageserver/src/tenant.rs | 1 + pageserver/src/tenant/config.rs | 8 ++++++++ pageserver/src/tenant/timeline.rs | 18 +++++++++++++++++- .../regress/test_attach_tenant_config.py | 4 ++++ 7 files changed, 40 insertions(+), 1 deletion(-) diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index ae5e22ddc6..1d1455b95b 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -415,6 +415,11 @@ impl PageServerNode { .map(|x| x.parse::()) .transpose() .context("Failed to parse 'timeline_offloading' as bool")?, + wal_receiver_protocol_override: settings + .remove("wal_receiver_protocol_override") + .map(serde_json::from_str) + .transpose() + .context("parse `wal_receiver_protocol_override` from json")?, }; if !settings.is_empty() { bail!("Unrecognized tenant settings: {settings:?}") diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 0abca5cdc2..721d97404b 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -278,6 +278,8 @@ pub struct TenantConfigToml { /// Enable auto-offloading of timelines. /// (either this flag or the pageserver-global one need to be set) pub timeline_offloading: bool, + + pub wal_receiver_protocol_override: Option, } pub mod defaults { @@ -510,6 +512,7 @@ impl Default for TenantConfigToml { lsn_lease_length: LsnLease::DEFAULT_LENGTH, lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS, timeline_offloading: false, + wal_receiver_protocol_override: None, } } } diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 1b86bfd91a..42c5d10c05 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -23,6 +23,7 @@ use utils::{ completion, id::{NodeId, TenantId, TimelineId}, lsn::Lsn, + postgres_client::PostgresClientProtocol, serde_system_time, }; @@ -352,6 +353,7 @@ pub struct TenantConfig { pub lsn_lease_length: Option, pub lsn_lease_length_for_ts: Option, pub timeline_offloading: Option, + pub wal_receiver_protocol_override: Option, } /// The policy for the aux file storage. 
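
A note on the parsing pattern used in the `control_plane` change above: an optional, JSON-valued setting is handled with `Option::map` plus `transpose()`, turning `Option<Result<T>>` into `Result<Option<T>>` so a parse error is propagated only when a value was actually supplied. A tiny self-contained illustration (the `parse_opt` helper is hypothetical, not part of the codebase):

```rust
// Parse an optional JSON value, propagating a parse error if one is present.
fn parse_opt(raw: Option<&str>) -> anyhow::Result<Option<u32>> {
    raw.map(serde_json::from_str).transpose().map_err(Into::into)
}

fn main() -> anyhow::Result<()> {
    assert_eq!(parse_opt(None)?, None);
    assert_eq!(parse_opt(Some("42"))?, Some(42));
    Ok(())
}
```
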
diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 0214ee68fa..bddcb534a1 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -5344,6 +5344,7 @@ pub(crate) mod harness { lsn_lease_length: Some(tenant_conf.lsn_lease_length), lsn_lease_length_for_ts: Some(tenant_conf.lsn_lease_length_for_ts), timeline_offloading: Some(tenant_conf.timeline_offloading), + wal_receiver_protocol_override: tenant_conf.wal_receiver_protocol_override, } } } diff --git a/pageserver/src/tenant/config.rs b/pageserver/src/tenant/config.rs index 4d6176bfd9..5d3ac5a8e3 100644 --- a/pageserver/src/tenant/config.rs +++ b/pageserver/src/tenant/config.rs @@ -19,6 +19,7 @@ use serde_json::Value; use std::num::NonZeroU64; use std::time::Duration; use utils::generation::Generation; +use utils::postgres_client::PostgresClientProtocol; #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) enum AttachmentMode { @@ -353,6 +354,9 @@ pub struct TenantConfOpt { #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub timeline_offloading: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub wal_receiver_protocol_override: Option, } impl TenantConfOpt { @@ -418,6 +422,9 @@ impl TenantConfOpt { timeline_offloading: self .lazy_slru_download .unwrap_or(global_conf.timeline_offloading), + wal_receiver_protocol_override: self + .wal_receiver_protocol_override + .or(global_conf.wal_receiver_protocol_override), } } } @@ -472,6 +479,7 @@ impl From for models::TenantConfig { lsn_lease_length: value.lsn_lease_length.map(humantime), lsn_lease_length_for_ts: value.lsn_lease_length_for_ts.map(humantime), timeline_offloading: value.timeline_offloading, + wal_receiver_protocol_override: value.wal_receiver_protocol_override, } } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index c1ff0f426d..afd4664d01 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -50,6 +50,7 @@ use tokio_util::sync::CancellationToken; use tracing::*; use utils::{ fs_ext, pausable_failpoint, + postgres_client::PostgresClientProtocol, sync::gate::{Gate, GateGuard}, }; use wal_decoder::serialized_batch::SerializedValueBatch; @@ -2178,6 +2179,21 @@ impl Timeline { ) } + /// Resolve the effective WAL receiver protocol to use for this tenant. + /// + /// Priority order is: + /// 1. Tenant config override + /// 2. Default value for tenant config override + /// 3. Pageserver config override + /// 4. Pageserver config default + pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol { + let tenant_conf = self.tenant_conf.load().tenant_conf.clone(); + tenant_conf + .wal_receiver_protocol_override + .or(self.conf.default_tenant_conf.wal_receiver_protocol_override) + .unwrap_or(self.conf.wal_receiver_protocol) + } + pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) { // NB: Most tenant conf options are read by background loops, so, // changes will automatically be picked up. 
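
The priority order documented on `resolve_wal_receiver_protocol` above falls out of plain `Option` chaining. A toy illustration, with strings standing in for the real `PostgresClientProtocol` values:

```rust
// Tenant override wins, then the default tenant override, then the
// pageserver-wide setting.
fn resolve(
    tenant_override: Option<&str>,
    default_tenant_override: Option<&str>,
    pageserver_default: &str,
) -> String {
    tenant_override
        .or(default_tenant_override)
        .unwrap_or(pageserver_default)
        .to_string()
}

fn main() {
    assert_eq!(resolve(None, None, "vanilla"), "vanilla");
    assert_eq!(resolve(Some("interpreted"), None, "vanilla"), "interpreted");
}
```
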
@@ -2470,7 +2486,7 @@ impl Timeline { *guard = Some(WalReceiver::start( Arc::clone(self), WalReceiverConf { - protocol: self.conf.wal_receiver_protocol, + protocol: self.resolve_wal_receiver_protocol(), wal_connect_timeout, lagging_wal_timeout, max_lsn_wal_lag, diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index 5744c445f6..670c2698f5 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -174,6 +174,10 @@ def test_fully_custom_config(positive_env: NeonEnv): "lsn_lease_length": "1m", "lsn_lease_length_for_ts": "5s", "timeline_offloading": True, + "wal_receiver_protocol_override": { + "type": "interpreted", + "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, + }, } vps_http = env.storage_controller.pageserver_api() From e4f437a354cc42bcbb081f72dffa8987932459f3 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Wed, 27 Nov 2024 14:54:14 +0100 Subject: [PATCH 20/34] pageserver: add relsize cache metrics (#9890) ## Problem We don't have any observability for the relation size cache. We have seen cache misses cause significant performance impact with high relation counts. Touches #9855. ## Summary of changes Adds the following metrics: * `pageserver_relsize_cache_entries` * `pageserver_relsize_cache_hits` * `pageserver_relsize_cache_misses` * `pageserver_relsize_cache_misses_old` --- pageserver/src/metrics.rs | 29 +++++++++++++++++++++++++++++ pageserver/src/pgdatadir_mapping.rs | 15 +++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 5ce3ae6cf7..78a157f51e 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -662,6 +662,35 @@ pub(crate) static COMPRESSION_IMAGE_OUTPUT_BYTES: Lazy = Lazy::new(| .expect("failed to define a metric") }); +pub(crate) static RELSIZE_CACHE_ENTRIES: Lazy = Lazy::new(|| { + register_uint_gauge!( + "pageserver_relsize_cache_entries", + "Number of entries in the relation size cache", + ) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_HITS: Lazy = Lazy::new(|| { + register_int_counter!("pageserver_relsize_cache_hits", "Relation size cache hits",) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_MISSES: Lazy = Lazy::new(|| { + register_int_counter!( + "pageserver_relsize_cache_misses", + "Relation size cache misses", + ) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_MISSES_OLD: Lazy = Lazy::new(|| { + register_int_counter!( + "pageserver_relsize_cache_misses_old", + "Relation size cache misses where the lookup LSN is older than the last relation update" + ) + .expect("failed to define a metric") +}); + pub(crate) mod initial_logical_size { use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec}; use once_cell::sync::Lazy; diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c491bfe650..4f42427276 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -10,6 +10,9 @@ use super::tenant::{PageReconstructError, Timeline}; use crate::aux_file; use crate::context::RequestContext; use crate::keyspace::{KeySpace, KeySpaceAccum}; +use crate::metrics::{ + RELSIZE_CACHE_ENTRIES, RELSIZE_CACHE_HITS, RELSIZE_CACHE_MISSES, RELSIZE_CACHE_MISSES_OLD, +}; use crate::span::{ debug_assert_current_span_has_tenant_and_timeline_id, 
debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id, @@ -1129,9 +1132,12 @@ impl Timeline { let rel_size_cache = self.rel_size_cache.read().unwrap(); if let Some((cached_lsn, nblocks)) = rel_size_cache.map.get(tag) { if lsn >= *cached_lsn { + RELSIZE_CACHE_HITS.inc(); return Some(*nblocks); } + RELSIZE_CACHE_MISSES_OLD.inc(); } + RELSIZE_CACHE_MISSES.inc(); None } @@ -1156,6 +1162,7 @@ impl Timeline { } hash_map::Entry::Vacant(entry) => { entry.insert((lsn, nblocks)); + RELSIZE_CACHE_ENTRIES.inc(); } } } @@ -1163,13 +1170,17 @@ impl Timeline { /// Store cached relation size pub fn set_cached_rel_size(&self, tag: RelTag, lsn: Lsn, nblocks: BlockNumber) { let mut rel_size_cache = self.rel_size_cache.write().unwrap(); - rel_size_cache.map.insert(tag, (lsn, nblocks)); + if rel_size_cache.map.insert(tag, (lsn, nblocks)).is_none() { + RELSIZE_CACHE_ENTRIES.inc(); + } } /// Remove cached relation size pub fn remove_cached_rel_size(&self, tag: &RelTag) { let mut rel_size_cache = self.rel_size_cache.write().unwrap(); - rel_size_cache.map.remove(tag); + if rel_size_cache.map.remove(tag).is_some() { + RELSIZE_CACHE_ENTRIES.dec(); + } } } From 23f5a2714631e6e211a5afac23c21076399e3c8a Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:07:39 -0500 Subject: [PATCH 21/34] fix(storage-scrubber): valid layermap error degrades to warning (#9902) Valid layer assumption is a necessary condition for a layer map to be valid. It's a stronger check imposed by gc-compaction than the actual valid layermap definition. Actually, the system can work as long as there are no overlapping layer maps. Therefore, we degrade that into a warning. Signed-off-by: Alex Chi Z --- storage_scrubber/src/checks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/storage_scrubber/src/checks.rs b/storage_scrubber/src/checks.rs index 525f412b56..8d855d263c 100644 --- a/storage_scrubber/src/checks.rs +++ b/storage_scrubber/src/checks.rs @@ -128,7 +128,7 @@ pub(crate) async fn branch_cleanup_and_check_errors( let layer_names = index_part.layer_metadata.keys().cloned().collect_vec(); if let Some(err) = check_valid_layermap(&layer_names) { - result.errors.push(format!( + result.warnings.push(format!( "index_part.json contains invalid layer map structure: {err}" )); } From cc37fa0f33df21cb0adfff922e199b2ef1d30207 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Wed, 27 Nov 2024 18:16:41 +0100 Subject: [PATCH 22/34] pageserver: add metrics for unknown `ClearVmBits` pages (#9911) ## Problem When ingesting implicit `ClearVmBits` operations, we silently drop the writes if the relation or page is unknown. There are implicit assumptions around VM pages wrt. explicit/implicit updates, sharding, and relation sizes, which can possibly drop writes incorrectly. Adding a few metrics will allow us to investigate further and tighten up the logic. Touches #9855. ## Summary of changes Add a `pageserver_wal_ingest_clear_vm_bits_unknown` metric to record dropped `ClearVmBits` writes. Also add comments clarifying the behavior of relation sizes on non-zero shards. 
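
The new counter is labelled by which lookup failed (`relation`, `old_page`, or `new_page`). A self-contained sketch using the `prometheus` crate directly — the pageserver registers metrics through its own `metrics` wrapper, so the registration details here are illustrative only:

```rust
use once_cell::sync::Lazy;
use prometheus::{register_int_counter_vec, IntCounterVec};

static CLEAR_VM_BITS_UNKNOWN: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "pageserver_wal_ingest_clear_vm_bits_unknown",
        "Number of ignored ClearVmBits operations due to unknown pages/relations",
        &["entity"]
    )
    .expect("failed to define a metric")
});

fn main() {
    // Bump the counter for each kind of dropped ClearVmBits write.
    CLEAR_VM_BITS_UNKNOWN.with_label_values(&["relation"]).inc();
    CLEAR_VM_BITS_UNKNOWN.with_label_values(&["old_page"]).inc();
}
```
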
--- pageserver/src/metrics.rs | 7 +++++ pageserver/src/pgdatadir_mapping.rs | 17 ++++++++-- pageserver/src/walingest.rs | 49 +++++++++++++++++++++-------- 3 files changed, 57 insertions(+), 16 deletions(-) diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 78a157f51e..86be97587f 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2144,6 +2144,7 @@ pub(crate) struct WalIngestMetrics { pub(crate) records_committed: IntCounter, pub(crate) records_filtered: IntCounter, pub(crate) gap_blocks_zeroed_on_rel_extend: IntCounter, + pub(crate) clear_vm_bits_unknown: IntCounterVec, } pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| WalIngestMetrics { @@ -2172,6 +2173,12 @@ pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| WalIngestMet "Total number of zero gap blocks written on relation extends" ) .expect("failed to define a metric"), + clear_vm_bits_unknown: register_int_counter_vec!( + "pageserver_wal_ingest_clear_vm_bits_unknown", + "Number of ignored ClearVmBits operations due to unknown pages/relations", + &["entity"], + ) + .expect("failed to define a metric"), }); pub(crate) static WAL_REDO_TIME: Lazy = Lazy::new(|| { diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 4f42427276..d48a1ba117 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -392,7 +392,9 @@ impl Timeline { result } - // Get size of a database in blocks + /// Get size of a database in blocks. This is only accurate on shard 0. It will undercount on + /// other shards, by only accounting for relations the shard has pages for, and only accounting + /// for pages up to the highest page number it has stored. pub(crate) async fn get_db_size( &self, spcnode: Oid, @@ -411,7 +413,10 @@ impl Timeline { Ok(total_blocks) } - /// Get size of a relation file + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// This is only accurate on shard 0. On other shards, it will return the size up to the highest + /// page number stored in the shard. pub(crate) async fn get_rel_size( &self, tag: RelTag, @@ -447,7 +452,10 @@ impl Timeline { Ok(nblocks) } - /// Does relation exist? + /// Does the relation exist? + /// + /// Only shard 0 has a full view of the relations. Other shards only know about relations that + /// the shard stores pages for. pub(crate) async fn get_rel_exists( &self, tag: RelTag, @@ -481,6 +489,9 @@ impl Timeline { /// Get a list of all existing relations in given tablespace and database. /// + /// Only shard 0 has a full view of the relations. Other shards only know about relations that + /// the shard stores pages for. + /// /// # Cancel-Safety /// /// This method is cancellation-safe. diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index ad6ccbc854..d568da596a 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -334,14 +334,32 @@ impl WalIngest { // replaying it would fail to find the previous image of the page, because // it doesn't exist. So check if the VM page(s) exist, and skip the WAL // record if it doesn't. - let vm_size = get_relsize(modification, vm_rel, ctx).await?; + // + // TODO: analyze the metrics and tighten this up accordingly. This logic + // implicitly assumes that VM pages see explicit WAL writes before + // implicit ClearVmBits, and will otherwise silently drop updates. + let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? 
else { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["relation"]) + .inc(); + return Ok(()); + }; if let Some(blknum) = new_vm_blk { if blknum >= vm_size { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["new_page"]) + .inc(); new_vm_blk = None; } } if let Some(blknum) = old_vm_blk { if blknum >= vm_size { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["old_page"]) + .inc(); old_vm_blk = None; } } @@ -572,7 +590,8 @@ impl WalIngest { modification.put_rel_page_image_zero(rel, fsm_physical_page_no)?; fsm_physical_page_no += 1; } - let nblocks = get_relsize(modification, rel, ctx).await?; + // TODO: re-examine the None case here wrt. sharding; should we error? + let nblocks = get_relsize(modification, rel, ctx).await?.unwrap_or(0); if nblocks > fsm_physical_page_no { // check if something to do: FSM is larger than truncate position self.put_rel_truncation(modification, rel, fsm_physical_page_no, ctx) @@ -612,7 +631,8 @@ impl WalIngest { )?; vm_page_no += 1; } - let nblocks = get_relsize(modification, rel, ctx).await?; + // TODO: re-examine the None case here wrt. sharding; should we error? + let nblocks = get_relsize(modification, rel, ctx).await?.unwrap_or(0); if nblocks > vm_page_no { // check if something to do: VM is larger than truncate position self.put_rel_truncation(modification, rel, vm_page_no, ctx) @@ -1430,24 +1450,27 @@ impl WalIngest { } } +/// Returns the size of the relation as of this modification, or None if the relation doesn't exist. +/// +/// This is only accurate on shard 0. On other shards, it will return the size up to the highest +/// page number stored in the shard, or None if the shard does not have any pages for it. async fn get_relsize( modification: &DatadirModification<'_>, rel: RelTag, ctx: &RequestContext, -) -> Result { - let nblocks = if !modification +) -> Result, PageReconstructError> { + if !modification .tline .get_rel_exists(rel, Version::Modified(modification), ctx) .await? { - 0 - } else { - modification - .tline - .get_rel_size(rel, Version::Modified(modification), ctx) - .await? - }; - Ok(nblocks) + return Ok(None); + } + modification + .tline + .get_rel_size(rel, Version::Modified(modification), ctx) + .await + .map(Some) } #[allow(clippy::bool_assert_comparison)] From 5c41707beefa99b53548530698dc2970dd876024 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Wed, 27 Nov 2024 19:05:46 +0100 Subject: [PATCH 23/34] proxy: promote two logs to error, fix multiline log (#9913) * Promote two logs from mpsc send errors to error level. The channels are unbounded and there shouldn't be errors. * Fix one multiline log from anyhow::Error. Use Debug instead of Display. --- proxy/src/context/mod.rs | 18 +++++++++++------- proxy/src/context/parquet.rs | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 4ec04deb25..5c19a23e36 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -8,7 +8,7 @@ use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; -use tracing::{debug, info_span, Span}; +use tracing::{debug, error, info_span, Span}; use try_lock::TryLock; use uuid::Uuid; @@ -415,9 +415,11 @@ impl RequestContextInner { }); } if let Some(tx) = self.sender.take() { - tx.send(RequestData::from(&*self)) - .inspect_err(|e| debug!("tx send failed: {e}")) - .ok(); + // If type changes, this error handling needs to be updated. 
+ let tx: mpsc::UnboundedSender = tx; + if let Err(e) = tx.send(RequestData::from(&*self)) { + error!("log_connect channel send failed: {e}"); + } } } @@ -426,9 +428,11 @@ impl RequestContextInner { // Here we log the length of the session. self.disconnect_timestamp = Some(Utc::now()); if let Some(tx) = self.disconnect_sender.take() { - tx.send(RequestData::from(&*self)) - .inspect_err(|e| debug!("tx send failed: {e}")) - .ok(); + // If type changes, this error handling needs to be updated. + let tx: mpsc::UnboundedSender = tx; + if let Err(e) = tx.send(RequestData::from(&*self)) { + error!("log_disconnect channel send failed: {e}"); + } } } } diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 9bf3a275bb..e328c6de79 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -398,7 +398,7 @@ async fn upload_parquet( .err(); if let Some(err) = maybe_err { - tracing::warn!(%id, %err, "failed to upload request data"); + tracing::error!(%id, error = ?err, "failed to upload request data"); } Ok(buffer.writer()) From 9e3cb75bc785a87967f5a8c0f866f65808680b2e Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:30:54 -0500 Subject: [PATCH 24/34] fix(pageserver): flush deletion queue in `reload` shutdown mode (#9884) ## Problem close https://github.com/neondatabase/neon/issues/9859 ## Summary of changes Ensure that the deletion queue gets fully flushed (i.e., the deletion lists get applied) during a graceful shutdown. It is still possible that an incomplete shutdown would leave deletion list behind and cause race upon the next startup, but we assume this will unlikely happen, and even if it happened, the pageserver should already be at a tainted state and the tenant should be moved to a new tenant with a new generation number. --------- Signed-off-by: Alex Chi Z --- pageserver/src/deletion_queue.rs | 74 +++++++++++------------ pageserver/src/tenant.rs | 12 ++++ pageserver/src/tenant/mgr.rs | 2 +- pageserver/src/tenant/timeline.rs | 11 ++-- pageserver/src/tenant/timeline/offload.rs | 2 +- 5 files changed, 57 insertions(+), 44 deletions(-) diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index 37fa300467..e74c8ecf5a 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -1144,18 +1144,24 @@ pub(crate) mod mock { rx: tokio::sync::mpsc::UnboundedReceiver, executor_rx: tokio::sync::mpsc::Receiver, cancel: CancellationToken, + executed: Arc, } impl ConsumerState { - async fn consume(&mut self, remote_storage: &GenericRemoteStorage) -> usize { - let mut executed = 0; - + async fn consume(&mut self, remote_storage: &GenericRemoteStorage) { info!("Executing all pending deletions"); // Transform all executor messages to generic frontend messages - while let Ok(msg) = self.executor_rx.try_recv() { + loop { + use either::Either; + let msg = tokio::select! { + left = self.executor_rx.recv() => Either::Left(left), + right = self.rx.recv() => Either::Right(right), + }; match msg { - DeleterMessage::Delete(objects) => { + Either::Left(None) => break, + Either::Right(None) => break, + Either::Left(Some(DeleterMessage::Delete(objects))) => { for path in objects { match remote_storage.delete(&path, &self.cancel).await { Ok(_) => { @@ -1165,18 +1171,13 @@ pub(crate) mod mock { error!("Failed to delete {path}, leaking object! 
({e})"); } } - executed += 1; + self.executed.fetch_add(1, Ordering::Relaxed); } } - DeleterMessage::Flush(flush_op) => { + Either::Left(Some(DeleterMessage::Flush(flush_op))) => { flush_op.notify(); } - } - } - - while let Ok(msg) = self.rx.try_recv() { - match msg { - ListWriterQueueMessage::Delete(op) => { + Either::Right(Some(ListWriterQueueMessage::Delete(op))) => { let mut objects = op.objects; for (layer, meta) in op.layers { objects.push(remote_layer_path( @@ -1198,33 +1199,27 @@ pub(crate) mod mock { error!("Failed to delete {path}, leaking object! ({e})"); } } - executed += 1; + self.executed.fetch_add(1, Ordering::Relaxed); } } - ListWriterQueueMessage::Flush(op) => { + Either::Right(Some(ListWriterQueueMessage::Flush(op))) => { op.notify(); } - ListWriterQueueMessage::FlushExecute(op) => { + Either::Right(Some(ListWriterQueueMessage::FlushExecute(op))) => { // We have already executed all prior deletions because mock does them inline op.notify(); } - ListWriterQueueMessage::Recover(_) => { + Either::Right(Some(ListWriterQueueMessage::Recover(_))) => { // no-op in mock } } - info!("All pending deletions have been executed"); } - - executed } } pub struct MockDeletionQueue { tx: tokio::sync::mpsc::UnboundedSender, executor_tx: tokio::sync::mpsc::Sender, - executed: Arc, - remote_storage: Option, - consumer: std::sync::Mutex, lsn_table: Arc>, } @@ -1235,29 +1230,34 @@ pub(crate) mod mock { let executed = Arc::new(AtomicUsize::new(0)); + let mut consumer = ConsumerState { + rx, + executor_rx, + cancel: CancellationToken::new(), + executed: executed.clone(), + }; + + tokio::spawn(async move { + if let Some(remote_storage) = &remote_storage { + consumer.consume(remote_storage).await; + } + }); + Self { tx, executor_tx, - executed, - remote_storage, - consumer: std::sync::Mutex::new(ConsumerState { - rx, - executor_rx, - cancel: CancellationToken::new(), - }), lsn_table: Arc::new(std::sync::RwLock::new(VisibleLsnUpdates::new())), } } #[allow(clippy::await_holding_lock)] pub async fn pump(&self) { - if let Some(remote_storage) = &self.remote_storage { - // Permit holding mutex across await, because this is only ever - // called once at a time in tests. - let mut locked = self.consumer.lock().unwrap(); - let count = locked.consume(remote_storage).await; - self.executed.fetch_add(count, Ordering::Relaxed); - } + let (tx, rx) = tokio::sync::oneshot::channel(); + self.executor_tx + .send(DeleterMessage::Flush(FlushOp { tx })) + .await + .expect("Failed to send flush message"); + rx.await.ok(); } pub(crate) fn new_client(&self) -> DeletionQueueClient { diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index bddcb534a1..339a3ca1bb 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -3215,6 +3215,18 @@ impl Tenant { } } + if let ShutdownMode::Reload = shutdown_mode { + tracing::info!("Flushing deletion queue"); + if let Err(e) = self.deletion_queue_client.flush().await { + match e { + DeletionQueueError::ShuttingDown => { + // This is the only error we expect for now. In the future, if more error + // variants are added, we should handle them here. + } + } + } + } + // We cancel the Tenant's cancellation token _after_ the timelines have all shut down. This permits // them to continue to do work during their shutdown methods, e.g. flushing data. 
tracing::debug!("Cancelling CancellationToken"); diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 92b2200542..eb8191e43e 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -1960,7 +1960,7 @@ impl TenantManager { attempt.before_reset_tenant(); let (_guard, progress) = utils::completion::channel(); - match tenant.shutdown(progress, ShutdownMode::Flush).await { + match tenant.shutdown(progress, ShutdownMode::Reload).await { Ok(()) => { slot_guard.drop_old_value().expect("it was just shutdown"); } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index afd4664d01..730477a7f4 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -894,10 +894,11 @@ pub(crate) enum ShutdownMode { /// While we are flushing, we continue to accept read I/O for LSNs ingested before /// the call to [`Timeline::shutdown`]. FreezeAndFlush, - /// Only flush the layers to the remote storage without freezing any open layers. This is the - /// mode used by ancestor detach and any other operations that reloads a tenant but not increasing - /// the generation number. - Flush, + /// Only flush the layers to the remote storage without freezing any open layers. Flush the deletion + /// queue. This is the mode used by ancestor detach and any other operations that reloads a tenant + /// but not increasing the generation number. Note that this mode cannot be used at tenant shutdown, + /// as flushing the deletion queue at that time will cause shutdown-in-progress errors. + Reload, /// Shut down immediately, without waiting for any open layers to flush. Hard, } @@ -1818,7 +1819,7 @@ impl Timeline { } } - if let ShutdownMode::Flush = mode { + if let ShutdownMode::Reload = mode { // drain the upload queue self.remote_client.shutdown().await; if !self.remote_client.no_pending_work() { diff --git a/pageserver/src/tenant/timeline/offload.rs b/pageserver/src/tenant/timeline/offload.rs index 3595d743bc..3bfbfb5061 100644 --- a/pageserver/src/tenant/timeline/offload.rs +++ b/pageserver/src/tenant/timeline/offload.rs @@ -58,7 +58,7 @@ pub(crate) async fn offload_timeline( } // Now that the Timeline is in Stopping state, request all the related tasks to shut down. - timeline.shutdown(super::ShutdownMode::Flush).await; + timeline.shutdown(super::ShutdownMode::Reload).await; // TODO extend guard mechanism above with method // to make deletions possible while offloading is in progress From da1daa2426aa592f5d57a13c4b09b5d21bcbeaf7 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Wed, 27 Nov 2024 20:44:24 +0100 Subject: [PATCH 25/34] pageserver: only apply `ClearVmBits` on relevant shards (#9895) # Problem VM (visibility map) pages are stored and managed as any regular relation page, in the VM fork of the main relation. They are also sharded like other pages. Regular WAL writes to the VM pages (typically performed by vacuum) are routed to the correct shard as usual. However, VM pages are also updated via `ClearVmBits` metadata records emitted when main relation pages are updated. These metadata records were sent to all shards, like other metadata records. This had the following effects: * On shards responsible for VM pages, the `ClearVmBits` applies as expected. * On shard 0, which knows about the VM relation and its size but doesn't necessarily have any VM pages, the `ClearVmBits` writes may have been applied without also having applied the explicit WAL writes to VM pages. 
* If VM pages are spread across multiple shards (unlikely with 256MB stripe size), all shards may have applied `ClearVmBits` if the pages fall within their local view of the relation size, even for pages they do not own. * On other shards, this caused a relation size cache miss and a DbDir and RelDir lookup before dropping the `ClearVmBits`. With many relations, this could cause significant CPU overhead. This is not believed to be a correctness problem, but this will be verified in #9914. Resolves #9855. # Changes Route `ClearVmBits` metadata records only to the shards responsible for the VM pages. Verification of the current VM handling and cleanup of incomplete VM pages on shard 0 (and potentially elsewhere) is left as follow-up work. --- libs/wal_decoder/src/decoder.rs | 71 +++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 17 deletions(-) diff --git a/libs/wal_decoder/src/decoder.rs b/libs/wal_decoder/src/decoder.rs index 1895f25bfc..36c4b19266 100644 --- a/libs/wal_decoder/src/decoder.rs +++ b/libs/wal_decoder/src/decoder.rs @@ -4,6 +4,7 @@ use crate::models::*; use crate::serialized_batch::SerializedValueBatch; use bytes::{Buf, Bytes}; +use pageserver_api::key::rel_block_to_key; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; use postgres_ffi::pg_constants; @@ -32,7 +33,8 @@ impl InterpretedWalRecord { FlushUncommittedRecords::No }; - let metadata_record = MetadataRecord::from_decoded(&decoded, next_record_lsn, pg_version)?; + let metadata_record = + MetadataRecord::from_decoded_filtered(&decoded, shard, next_record_lsn, pg_version)?; let batch = SerializedValueBatch::from_decoded_filtered( decoded, shard, @@ -51,8 +53,13 @@ impl InterpretedWalRecord { } impl MetadataRecord { - fn from_decoded( + /// Builds a metadata record for this WAL record, if any. + /// + /// Only metadata records relevant for the given shard are emitted. Currently, most metadata + /// records are broadcast to all shards for simplicity, but this should be improved. + fn from_decoded_filtered( decoded: &DecodedWALRecord, + shard: &ShardIdentity, next_record_lsn: Lsn, pg_version: u32, ) -> anyhow::Result> { @@ -61,26 +68,27 @@ impl MetadataRecord { let mut buf = decoded.record.clone(); buf.advance(decoded.main_data_offset); - match decoded.xl_rmid { + // First, generate metadata records from the decoded WAL record. + let mut metadata_record = match decoded.xl_rmid { pg_constants::RM_HEAP_ID | pg_constants::RM_HEAP2_ID => { - Self::decode_heapam_record(&mut buf, decoded, pg_version) + Self::decode_heapam_record(&mut buf, decoded, pg_version)? 
} - pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version), + pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version)?, // Handle other special record types - pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded), - pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version), + pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded)?, + pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version)?, pg_constants::RM_TBLSPC_ID => { tracing::trace!("XLOG_TBLSPC_CREATE/DROP is not handled yet"); - Ok(None) + None } - pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version), + pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version)?, pg_constants::RM_XACT_ID => { - Self::decode_xact_record(&mut buf, decoded, next_record_lsn) + Self::decode_xact_record(&mut buf, decoded, next_record_lsn)? } pg_constants::RM_MULTIXACT_ID => { - Self::decode_multixact_record(&mut buf, decoded, pg_version) + Self::decode_multixact_record(&mut buf, decoded, pg_version)? } - pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded), + pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded)?, // This is an odd duck. It needs to go to all shards. // Since it uses the checkpoint image (that's initialized from CHECKPOINT_KEY // in WalIngest::new), we have to send the whole DecodedWalRecord::record to @@ -89,19 +97,48 @@ impl MetadataRecord { // Alternatively, one can make the checkpoint part of the subscription protocol // to the pageserver. This should work fine, but can be done at a later point. pg_constants::RM_XLOG_ID => { - Self::decode_xlog_record(&mut buf, decoded, next_record_lsn) + Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)? } pg_constants::RM_LOGICALMSG_ID => { - Self::decode_logical_message_record(&mut buf, decoded) + Self::decode_logical_message_record(&mut buf, decoded)? } - pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded), - pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded), + pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded)?, + pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded)?, _unexpected => { // TODO: consider failing here instead of blindly doing something without // understanding the protocol - Ok(None) + None + } + }; + + // Next, filter the metadata record by shard. + + // Route VM page updates to the shards that own them. VM pages are stored in the VM fork + // of the main relation. These are sharded and managed just like regular relation pages. + // See: https://github.com/neondatabase/neon/issues/9855 + if let Some( + MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits)) + | MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)), + ) = metadata_record + { + let is_local_vm_page = |heap_blk| { + let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk); + shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk)) + }; + // Send the old and new VM page updates to their respective shards. + clear_vm_bits.old_heap_blkno = clear_vm_bits + .old_heap_blkno + .filter(|&blkno| is_local_vm_page(blkno)); + clear_vm_bits.new_heap_blkno = clear_vm_bits + .new_heap_blkno + .filter(|&blkno| is_local_vm_page(blkno)); + // If neither VM page belongs to this shard, discard the record. 
+ if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none() { + metadata_record = None } } + + Ok(metadata_record) } fn decode_heapam_record( From 8173dc600ad68872f4e488c753f59b8a1e2093aa Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Thu, 28 Nov 2024 08:32:22 +0200 Subject: [PATCH 26/34] proxy: spawn cancellation checks in the background (#9918) ## Problem For cancellation, a connection is open during all the cancel checks. ## Summary of changes Spawn cancellation checks in the background, and close connection immediately. Use task_tracker for cancellation checks. --- proxy/src/cancellation.rs | 15 ++++++++----- proxy/src/console_redirect_proxy.rs | 35 +++++++++++++++++++++-------- proxy/src/proxy/mod.rs | 35 +++++++++++++++++++++-------- proxy/src/redis/notifications.rs | 2 +- proxy/src/serverless/mod.rs | 9 ++++++++ proxy/src/serverless/websocket.rs | 3 +++ 6 files changed, 75 insertions(+), 24 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 74415f1ffe..91e198bf88 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -99,16 +99,17 @@ impl CancellationHandler
{ /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. /// check_allowed - if true, check if the IP is allowed to cancel the query + /// return Result primarily for tests pub(crate) async fn cancel_session( &self, key: CancelKeyData, session_id: Uuid, - peer_addr: &IpAddr, + peer_addr: IpAddr, check_allowed: bool, ) -> Result<(), CancelError> { // TODO: check for unspecified address is only for backward compatibility, should be removed if !peer_addr.is_unspecified() { - let subnet_key = match *peer_addr { + let subnet_key = match peer_addr { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; @@ -141,9 +142,11 @@ impl CancellationHandler
{ return Ok(()); } - match self.client.try_publish(key, session_id, *peer_addr).await { + match self.client.try_publish(key, session_id, peer_addr).await { Ok(()) => {} // do nothing Err(e) => { + // log it here since cancel_session could be spawned in a task + tracing::error!("failed to publish cancellation key: {key}, error: {e}"); return Err(CancelError::IO(std::io::Error::new( std::io::ErrorKind::Other, e.to_string(), @@ -154,8 +157,10 @@ impl CancellationHandler
{ }; if check_allowed - && !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice()) + && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice()) { + // log it here since cancel_session could be spawned in a task + tracing::warn!("IP is not allowed to cancel the query: {key}"); return Err(CancelError::IpNotAllowed); } @@ -306,7 +311,7 @@ mod tests { cancel_key: 0, }, Uuid::new_v4(), - &("127.0.0.1".parse().unwrap()), + "127.0.0.1".parse().unwrap(), true, ) .await diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index b910b524b1..8f78df1964 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -35,6 +35,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -48,6 +49,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); @@ -96,6 +98,7 @@ pub async fn task_main( cancellation_handler, socket, conn_gauge, + cancellations, ) .instrument(ctx.span()) .boxed() @@ -127,10 +130,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; Ok(()) } @@ -142,6 +147,7 @@ pub(crate) async fn handle_client( cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -161,15 +167,26 @@ pub(crate) async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session( - cancel_key_data, - ctx.session_id(), - &ctx.peer_addr(), - config.authentication_config.ip_allowlist_check_enabled, - ) - .await - .map(|()| None)?) 
+ // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 7fe67e43de..956036d29d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -69,6 +69,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -82,6 +83,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); @@ -133,6 +135,7 @@ pub async fn task_main( ClientMode::Tcp, endpoint_rate_limiter2, conn_gauge, + cancellations, ) .instrument(ctx.span()) .boxed() @@ -164,10 +167,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; Ok(()) } @@ -250,6 +255,7 @@ pub(crate) async fn handle_client( mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -270,15 +276,26 @@ pub(crate) async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session( - cancel_key_data, - ctx.session_id(), - &ctx.peer_addr(), - config.authentication_config.ip_allowlist_check_enabled, - ) - .await - .map(|()| None)?) 
+ // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 65008ae943..9ac07b7e90 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -149,7 +149,7 @@ impl MessageHandler { .cancel_session( cancel_session.cancel_key_data, uuid::Uuid::nil(), - &peer_addr, + peer_addr, cancel_session.peer_addr.is_some(), ) .await diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 77025f419d..80b42f9e55 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -132,6 +132,7 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; if let Err(e) = conn.set_nodelay(true) { @@ -160,6 +161,7 @@ pub async fn task_main( let connections2 = connections.clone(); let cancellation_handler = cancellation_handler.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellations = cancellations.clone(); connections.spawn( async move { let conn_token2 = conn_token.clone(); @@ -188,6 +190,7 @@ pub async fn task_main( config, backend, connections2, + cancellations, cancellation_handler, endpoint_rate_limiter, conn_token, @@ -313,6 +316,7 @@ async fn connection_handler( config: &'static ProxyConfig, backend: Arc, connections: TaskTracker, + cancellations: TaskTracker, cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, @@ -353,6 +357,7 @@ async fn connection_handler( // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. // By spawning the future, we ensure it never gets cancelled until it decides to. + let cancellations = cancellations.clone(); let handler = connections.spawn( request_handler( req, @@ -364,6 +369,7 @@ async fn connection_handler( conn_info2.clone(), http_request_token, endpoint_rate_limiter.clone(), + cancellations, ) .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), @@ -411,6 +417,7 @@ async fn request_handler( // used to cancel in-flight HTTP requests. 
not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellations: TaskTracker, ) -> Result>, ApiError> { let host = request .headers() @@ -436,6 +443,7 @@ async fn request_handler( let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; + let cancellations = cancellations.clone(); ws_connections.spawn( async move { if let Err(e) = websocket::serve_websocket( @@ -446,6 +454,7 @@ async fn request_handler( cancellation_handler, endpoint_rate_limiter, host, + cancellations, ) .await { diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 4088fea835..bdb83fe6be 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -123,6 +123,7 @@ impl AsyncBufRead for WebSocketRw { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn serve_websocket( config: &'static ProxyConfig, auth_backend: &'static crate::auth::Backend<'static, ()>, @@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket( cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> anyhow::Result<()> { let websocket = websocket.await?; let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); @@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket( ClientMode::Websockets { hostname }, endpoint_rate_limiter, conn_gauge, + cancellations, )) .await; From e82f7f0dfc1571ddbbb4ff37c1c94579a7101834 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Thu, 28 Nov 2024 10:11:08 +0000 Subject: [PATCH 27/34] remote_storage/abs: count 404 and 304 for get as ok for metrics (#9912) ## Problem We currently see elevated levels of errors for GetBlob requests. This is because 404 and 304 are counted as errors for metric reporting. ## Summary of Changes Bring the implementation in line with the S3 client and treat 404 and 304 responses as ok for metric purposes. Related: https://github.com/neondatabase/cloud/issues/20666 --- libs/remote_storage/src/azure_blob.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index ae0a94295c..840917ef68 100644 --- a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -220,6 +220,11 @@ impl AzureBlobStorage { let started_at = ScopeGuard::into_inner(started_at); let outcome = match &download { Ok(_) => AttemptOutcome::Ok, + // At this level in the stack 404 and 304 responses do not indicate an error. + // There's expected cases when a blob may not exist or hasn't been modified since + // the last get (e.g. probing for timeline indices and heatmap downloads). + // Callers should handle errors if they are unexpected. + Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok, Err(_) => AttemptOutcome::Err, }; crate::metrics::BUCKET_METRICS From 70780e310c9640650eeb8b5cb0838bebd1c6c0ff Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Thu, 28 Nov 2024 16:48:18 +0100 Subject: [PATCH 28/34] Makefile: build pg_visibility (#9922) Build the `pg_visibility` extension for use with `neon_local`. This is useful to inspect the visibility map for debugging. Touches #9914. 
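As a minimal sketch of how the visibility map can then be inspected from a regress test (the table, fixture usage, and assertions are illustrative, not part of this patch):

```python
from fixtures.neon_fixtures import NeonEnv


def test_visibility_map_sketch(neon_simple_env: NeonEnv):
    env = neon_simple_env
    endpoint = env.endpoints.create_start("main")
    with endpoint.cursor() as cur:
        cur.execute("CREATE EXTENSION pg_visibility")
        cur.execute("CREATE TABLE t AS SELECT generate_series(1, 1000) AS id")
        # VACUUM sets all-visible bits in the visibility map.
        cur.execute("VACUUM t")
        # pg_visibility_map() returns one row per block with the
        # (all_visible, all_frozen) flags read from the VM.
        cur.execute("SELECT * FROM pg_visibility_map('t')")
        for blkno, all_visible, _all_frozen in cur.fetchall():
            # On an otherwise idle cluster all blocks should be all-visible.
            assert all_visible, f"block {blkno} is not all-visible after VACUUM"
```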
--- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index dc67b87239..9cffc74508 100644 --- a/Makefile +++ b/Makefile @@ -147,6 +147,8 @@ postgres-%: postgres-configure-% \ $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install +@echo "Compiling pg_buffercache $*" $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install + +@echo "Compiling pg_visibility $*" + $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install +@echo "Compiling pageinspect $*" $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install +@echo "Compiling amcheck $*" From eb5d832e6fea0e1c3c14b9e6024fce916c3f1c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Thu, 28 Nov 2024 16:49:30 +0100 Subject: [PATCH 29/34] Update rust to 1.83.0, also update cargo adjacent tools (#9926) We keep the practice of keeping the compiler up to date, pointing to the latest release. This is done by many other projects in the Rust ecosystem as well. [Release notes](https://releases.rs/docs/1.83.0/). Also update `cargo-hakari`, `cargo-deny`, `cargo-hack` and `cargo-nextest` to their latest versions. Prior update was in #9445. --- build-tools.Dockerfile | 18 +++++++++--------- rust-toolchain.toml | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 4f491afec5..2671702697 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -57,9 +57,9 @@ RUN mkdir -p /pgcopydb/bin && \ mkdir -p /pgcopydb/lib && \ chmod -R 755 /pgcopydb && \ chown -R nonroot:nonroot /pgcopydb - -COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb -COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5 + +COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb +COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5 # System deps # @@ -258,14 +258,14 @@ WORKDIR /home/nonroot # Rust # Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`) -ENV RUSTC_VERSION=1.82.0 +ENV RUSTC_VERSION=1.83.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" ARG RUSTFILT_VERSION=0.2.1 -ARG CARGO_HAKARI_VERSION=0.9.30 -ARG CARGO_DENY_VERSION=0.16.1 -ARG CARGO_HACK_VERSION=0.6.31 -ARG CARGO_NEXTEST_VERSION=0.9.72 +ARG CARGO_HAKARI_VERSION=0.9.33 +ARG CARGO_DENY_VERSION=0.16.2 +ARG CARGO_HACK_VERSION=0.6.33 +ARG CARGO_NEXTEST_VERSION=0.9.85 RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \ chmod +x rustup-init && \ ./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \ @@ -289,7 +289,7 @@ RUN whoami \ && cargo --version --verbose \ && rustup --version --verbose \ && rustc --version --verbose \ - && clang --version + && clang --version RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \ LD_LIBRARY_PATH=/pgcopydb/lib /pgcopydb/bin/pgcopydb --version; \ diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 92b7929c7f..f0661a32e0 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.82.0" +channel = "1.83.0" profile = "default" # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. 
# https://rust-lang.github.io/rustup/concepts/profiles.html

From eb520a14ce12dc16f33f39964632982f6c14b9f3 Mon Sep 17 00:00:00 2001
From: Vlad Lazar
Date: Thu, 28 Nov 2024 17:38:47 +0000
Subject: [PATCH 30/34] pageserver: return correct LSN for interpreted proto keep alive responses (#9928)

## Problem

For the interpreted proto the pageserver is not returning the correct LSN
in replies to keep alive requests. This is because the interpreted protocol
arm was not updating `last_rec_lsn`.

## Summary of changes

* Return correct LSN in keep-alive responses
* Fix shard field in wal sender traces
---
 .../tenant/timeline/walreceiver/walreceiver_connection.rs | 4 ++++
 safekeeper/src/handler.rs                                 | 5 +++--
 safekeeper/src/wal_service.rs                             | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs
index 31cf1b6307..d90ffbfa2c 100644
--- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs
+++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs
@@ -454,6 +454,10 @@ pub(super) async fn handle_walreceiver_connection(
                 timeline.get_last_record_lsn()
             );
 
+            if let Some(lsn) = next_record_lsn {
+                last_rec_lsn = lsn;
+            }
+
             Some(streaming_lsn)
         }
 
diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs
index 22f33b17e0..8dd2929a03 100644
--- a/safekeeper/src/handler.rs
+++ b/safekeeper/src/handler.rs
@@ -212,8 +212,9 @@ impl postgres_backend::Handler
         );
 
         if let Some(shard) = self.shard.as_ref() {
-            tracing::Span::current()
-                .record("shard", tracing::field::display(shard.shard_slug()));
+            if let Some(slug) = shard.shard_slug().strip_prefix("-") {
+                tracing::Span::current().record("shard", tracing::field::display(slug));
+            }
         }
 
         Ok(())
diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs
index 1ab54d4cce..5248d545db 100644
--- a/safekeeper/src/wal_service.rs
+++ b/safekeeper/src/wal_service.rs
@@ -44,7 +44,7 @@ pub async fn task_main(
                     error!("connection handler exited: {}", err);
                 }
             }
-            .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty)),
+            .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty, shard = field::Empty)),
         );
     }
 }

From e04dd3be0b6bd6702fa6e3301c9b7202d72ccc1c Mon Sep 17 00:00:00 2001
From: Alexander Bayandin
Date: Thu, 28 Nov 2024 19:02:57 +0000
Subject: [PATCH 31/34] test_runner: rerun all failed tests (#9917)

## Problem

Currently, we rerun only known flaky tests. This approach was chosen to
reduce the number of tests that go unnoticed (by forcing people to take a
look at failed tests and rerun the job manually), but it has some drawbacks:
- In PRs, people tend to push new changes without checking failed tests
  (that's ok)
- In the main, tests are just restarted without checking (understandable)
- Parametrised tests become flaky one by one, i.e. if `test[1]` is flaky,
  `test[2]` is not marked as flaky automatically (which may or may not be
  the case).

I suggest rerunning all failed tests to increase the stability of GitHub
jobs and using the Grafana Dashboard with flaky tests for deeper analysis.
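As a minimal illustration of the rerun semantics relied on in the summary below (the test and counter names are illustrative, not from this PR): with `pytest --reruns 2`, a failed test gets one original run plus up to two reruns in the same session, and only the final attempt determines the reported status:

```python
ATTEMPTS = []  # module-level state persists across reruns within one session


def test_rerun_sketch():
    ATTEMPTS.append(object())
    # Fails on attempts 1 and 2; with `--reruns 2` the third attempt
    # (the second rerun) passes, so the test is reported as passed.
    assert len(ATTEMPTS) >= 3
```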
## Summary of changes - Rerun all failed tests twice at max --- .../actions/run-python-test-set/action.yml | 17 +- .github/workflows/_build-and-test-locally.yml | 2 +- poetry.lock | 12 +- pyproject.toml | 2 +- scripts/flaky_tests.py | 147 ------------------ test_runner/conftest.py | 2 +- test_runner/fixtures/flaky.py | 78 ---------- test_runner/fixtures/paths.py | 2 +- test_runner/fixtures/reruns.py | 31 ++++ 9 files changed, 46 insertions(+), 247 deletions(-) delete mode 100755 scripts/flaky_tests.py delete mode 100644 test_runner/fixtures/flaky.py create mode 100644 test_runner/fixtures/reruns.py diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 275f161019..1159627302 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -36,8 +36,8 @@ inputs: description: 'Region name for real s3 tests' required: false default: '' - rerun_flaky: - description: 'Whether to rerun flaky tests' + rerun_failed: + description: 'Whether to rerun failed tests' required: false default: 'false' pg_version: @@ -108,7 +108,7 @@ runs: COMPATIBILITY_SNAPSHOT_DIR: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }} ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'backward compatibility breakage') ALLOW_FORWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'forward compatibility breakage') - RERUN_FLAKY: ${{ inputs.rerun_flaky }} + RERUN_FAILED: ${{ inputs.rerun_failed }} PG_VERSION: ${{ inputs.pg_version }} shell: bash -euxo pipefail {0} run: | @@ -154,15 +154,8 @@ runs: EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS" fi - if [ "${RERUN_FLAKY}" == "true" ]; then - mkdir -p $TEST_OUTPUT - poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \ - --days 7 \ - --output "$TEST_OUTPUT/flaky.json" \ - --pg-version "${DEFAULT_PG_VERSION}" \ - --build-type "${BUILD_TYPE}" - - EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS" + if [ "${RERUN_FAILED}" == "true" ]; then + EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS" fi # We use pytest-split plugin to run benchmarks in parallel on different CI runners diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index bdf7c07c6a..42c32a23e3 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -293,7 +293,7 @@ jobs: run_with_real_s3: true real_s3_bucket: neon-github-ci-tests real_s3_region: eu-central-1 - rerun_flaky: true + rerun_failed: true pg_version: ${{ matrix.pg_version }} env: TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }} diff --git a/poetry.lock b/poetry.lock index e2fca7be47..59ae5cf1ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2563,18 +2563,18 @@ pytest = "*" [[package]] name = "pytest-rerunfailures" -version = "13.0" +version = "15.0" description = "pytest plugin to re-run tests to eliminate flaky failures" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "pytest-rerunfailures-13.0.tar.gz", hash = "sha256:e132dbe420bc476f544b96e7036edd0a69707574209b6677263c950d19b09199"}, - {file = "pytest_rerunfailures-13.0-py3-none-any.whl", hash = "sha256:34919cb3fcb1f8e5d4b940aa75ccdea9661bade925091873b7c6fa5548333069"}, + {file = "pytest-rerunfailures-15.0.tar.gz", hash = "sha256:2d9ac7baf59f4c13ac730b47f6fa80e755d1ba0581da45ce30b72fb3542b4474"}, + {file = 
"pytest_rerunfailures-15.0-py3-none-any.whl", hash = "sha256:dd150c4795c229ef44320adc9a0c0532c51b78bb7a6843a8c53556b9a611df1a"}, ] [package.dependencies] packaging = ">=17.1" -pytest = ">=7" +pytest = ">=7.4,<8.2.2 || >8.2.2" [[package]] name = "pytest-split" @@ -3524,4 +3524,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "21debe1116843e5d14bdf37d6e265c68c63a98a64ba04ec8b8a02af2e8d9f486" +content-hash = "426c385df93f578ba3537c40a269535e27fbcca1978b3cf266096ecbc298c6a9" diff --git a/pyproject.toml b/pyproject.toml index ccd3ab1864..01d15ee6bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ types-psutil = "^5.9.5.12" types-toml = "^0.10.8.6" pytest-httpserver = "^1.0.8" aiohttp = "3.10.11" -pytest-rerunfailures = "^13.0" +pytest-rerunfailures = "^15.0" types-pytest-lazy-fixture = "^0.6.3.3" pytest-split = "^0.8.1" zstandard = "^0.21.0" diff --git a/scripts/flaky_tests.py b/scripts/flaky_tests.py deleted file mode 100755 index 3fb668ed2d..0000000000 --- a/scripts/flaky_tests.py +++ /dev/null @@ -1,147 +0,0 @@ -#! /usr/bin/env python3 - -from __future__ import annotations - -import argparse -import json -import logging -import os -from collections import defaultdict -from typing import TYPE_CHECKING - -import psycopg2 -import psycopg2.extras -import toml - -if TYPE_CHECKING: - from typing import Any - -FLAKY_TESTS_QUERY = """ - SELECT - DISTINCT parent_suite, suite, name - FROM results - WHERE - started_at > CURRENT_DATE - INTERVAL '%s' day - AND ( - (status IN ('failed', 'broken') AND reference = 'refs/heads/main') - OR flaky - ) - ; -""" - - -def main(args: argparse.Namespace): - connstr = args.connstr - interval_days = args.days - output = args.output - - build_type = args.build_type - pg_version = args.pg_version - - res: defaultdict[str, defaultdict[str, dict[str, bool]]] - res = defaultdict(lambda: defaultdict(dict)) - - try: - logging.info("connecting to the database...") - with psycopg2.connect(connstr, connect_timeout=30) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - logging.info("fetching flaky tests...") - cur.execute(FLAKY_TESTS_QUERY, (interval_days,)) - rows = cur.fetchall() - except psycopg2.OperationalError as exc: - logging.error("cannot fetch flaky tests from the DB due to an error", exc) - rows = [] - - # If a test run has non-default PAGESERVER_VIRTUAL_FILE_IO_ENGINE (i.e. 
not empty, not tokio-epoll-uring), - # use it to parametrize test name along with build_type and pg_version - # - # See test_runner/fixtures/parametrize.py for details - if (io_engine := os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "")) not in ( - "", - "tokio-epoll-uring", - ): - pageserver_virtual_file_io_engine_parameter = f"-{io_engine}" - else: - pageserver_virtual_file_io_engine_parameter = "" - - # re-use existing records of flaky tests from before parametrization by compaction_algorithm - def get_pageserver_default_tenant_config_compaction_algorithm() -> dict[str, Any] | None: - """Duplicated from parametrize.py""" - toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM") - if toml_table is None: - return None - v = toml.loads(toml_table) - assert isinstance(v, dict) - return v - - pageserver_default_tenant_config_compaction_algorithm_parameter = "" - if ( - explicit_default := get_pageserver_default_tenant_config_compaction_algorithm() - ) is not None: - pageserver_default_tenant_config_compaction_algorithm_parameter = ( - f"-{explicit_default['kind']}" - ) - - for row in rows: - # We don't want to automatically rerun tests in a performance suite - if row["parent_suite"] != "test_runner.regress": - continue - - if row["name"].endswith("]"): - parametrized_test = row["name"].replace( - "[", - f"[{build_type}-pg{pg_version}{pageserver_virtual_file_io_engine_parameter}{pageserver_default_tenant_config_compaction_algorithm_parameter}-", - ) - else: - parametrized_test = f"{row['name']}[{build_type}-pg{pg_version}{pageserver_virtual_file_io_engine_parameter}{pageserver_default_tenant_config_compaction_algorithm_parameter}]" - - res[row["parent_suite"]][row["suite"]][parametrized_test] = True - - logging.info( - f"\t{row['parent_suite'].replace('.', '/')}/{row['suite']}.py::{parametrized_test}" - ) - - logging.info(f"saving results to {output.name}") - json.dump(res, output, indent=2) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Detect flaky tests in the last N days") - parser.add_argument( - "--output", - type=argparse.FileType("w"), - default="flaky.json", - help="path to output json file (default: flaky.json)", - ) - parser.add_argument( - "--days", - required=False, - default=10, - type=int, - help="how many days to look back for flaky tests (default: 10)", - ) - parser.add_argument( - "--build-type", - required=True, - type=str, - help="for which build type to create list of flaky tests (debug or release)", - ) - parser.add_argument( - "--pg-version", - required=True, - type=int, - help="for which Postgres version to create list of flaky tests (14, 15, etc.)", - ) - parser.add_argument( - "connstr", - help="connection string to the test results database", - ) - args = parser.parse_args() - - level = logging.INFO - logging.basicConfig( - format="%(message)s", - level=level, - ) - - main(args) diff --git a/test_runner/conftest.py b/test_runner/conftest.py index 84eda52d33..887bfef478 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -13,5 +13,5 @@ pytest_plugins = ( "fixtures.pg_stats", "fixtures.compare_fixtures", "fixtures.slow", - "fixtures.flaky", + "fixtures.reruns", ) diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py deleted file mode 100644 index 01634a29c5..0000000000 --- a/test_runner/fixtures/flaky.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import MutableMapping -from pathlib import Path -from typing import 
TYPE_CHECKING, cast - -import pytest -from _pytest.config import Config -from _pytest.config.argparsing import Parser -from allure_commons.types import LabelType -from allure_pytest.utils import allure_name, allure_suite_labels - -from fixtures.log_helper import log - -if TYPE_CHECKING: - from collections.abc import MutableMapping - from typing import Any - - -""" -The plugin reruns flaky tests. -It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` - -Note: the logic of getting flaky tests is extracted to a separate script to avoid running it for each of N xdist workers -""" - - -def pytest_addoption(parser: Parser): - parser.addoption( - "--flaky-tests-json", - action="store", - type=Path, - help="Path to json file with flaky tests generated by scripts/flaky_tests.py", - ) - - -def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]): - if not config.getoption("--flaky-tests-json"): - return - - # Any error with getting flaky tests aren't critical, so just do not rerun any tests - flaky_json = config.getoption("--flaky-tests-json") - if not flaky_json.exists(): - return - - content = flaky_json.read_text() - try: - flaky_tests = json.loads(content) - except ValueError: - log.error(f"Can't parse {content} as json") - return - - for item in items: - # Use the same logic for constructing test name as Allure does (we store allure-provided data in DB) - # Ref https://github.com/allure-framework/allure-python/blob/2.13.1/allure-pytest/src/listener.py#L98-L100 - allure_labels = dict(allure_suite_labels(item)) - parent_suite = str(allure_labels.get(LabelType.PARENT_SUITE)) - suite = str(allure_labels.get(LabelType.SUITE)) - params = item.callspec.params if hasattr(item, "callspec") else {} - name = allure_name(item, params) - - if flaky_tests.get(parent_suite, {}).get(suite, {}).get(name, False): - # Rerun 3 times = 1 original run + 2 reruns - log.info(f"Marking {item.nodeid} as flaky. It will be rerun up to 3 times") - item.add_marker(pytest.mark.flaky(reruns=2)) - - # pytest-rerunfailures is not compatible with pytest-timeout (timeout is not set for reruns), - # we can workaround it by setting `timeout_func_only` to True[1]. - # Unfortunately, setting `timeout_func_only = True` globally in pytest.ini is broken[2], - # but we still can do it using pytest marker. - # - # - [1] https://github.com/pytest-dev/pytest-rerunfailures/issues/99 - # - [2] https://github.com/pytest-dev/pytest-timeout/issues/142 - timeout_marker = item.get_closest_marker("timeout") - if timeout_marker is not None: - kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs) - kwargs["func_only"] = True diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py index 1c71abea19..80777d65e9 100644 --- a/test_runner/fixtures/paths.py +++ b/test_runner/fixtures/paths.py @@ -30,7 +30,7 @@ def get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str | No test_name = request.node.name test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}" - # We rerun flaky tests multiple times, use a separate directory for each run. + # We rerun failed tests multiple times, use a separate directory for each run. 
if (suffix := getattr(request.node, "execution_count", None)) is not None: test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" diff --git a/test_runner/fixtures/reruns.py b/test_runner/fixtures/reruns.py new file mode 100644 index 0000000000..f2a25ae8f6 --- /dev/null +++ b/test_runner/fixtures/reruns.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from collections.abc import MutableMapping +from typing import TYPE_CHECKING, cast + +import pytest + +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import Any + + from _pytest.config import Config + + +def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]): + # pytest-rerunfailures is not compatible with pytest-timeout (timeout is not set for reruns), + # we can workaround it by setting `timeout_func_only` to True[1]. + # Unfortunately, setting `timeout_func_only = True` globally in pytest.ini is broken[2], + # but we still can do it using pytest marker. + # + # - [1] https://github.com/pytest-dev/pytest-rerunfailures/issues/99 + # - [2] https://github.com/pytest-dev/pytest-timeout/issues/142 + + if not config.getoption("--reruns"): + return + + for item in items: + timeout_marker = item.get_closest_marker("timeout") + if timeout_marker is not None: + kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs) + kwargs["func_only"] = True From 42fb3c4d30bf93ad0ad85bbd636a4262d205f673 Mon Sep 17 00:00:00 2001 From: Alexey Kondratov Date: Thu, 28 Nov 2024 22:38:30 +0100 Subject: [PATCH 32/34] fix(compute_ctl): Allow usage of DB names with whitespaces (#9919) ## Problem We used `set_path()` to replace the database name in the connection string. It automatically does url-safe encoding if the path is not already encoded, but it does it as per the URL standard, which assumes that tabs can be safely removed from the path without changing the meaning of the URL. See, e.g., https://url.spec.whatwg.org/#concept-basic-url-parser. It also breaks for DBs with properly %-encoded names, like with `%20`, as they are kept intact, but actually should be escaped. Yet, this is not true for Postgres, where it's completely valid to have trailing tabs in the database name. I think this is the PR that caused this regression https://github.com/neondatabase/neon/pull/9717, as it switched from `postgres::config::Config` back to `set_path()`. This was fixed a while ago already [1], btw, I just haven't added a test to catch this regression back then :( ## Summary of changes This commit changes the code back to use `postgres/tokio_postgres::Config` everywhere. While on it, also do some changes around, as I had to touch this code: 1. Bump some logging from `debug` to `info` in the spec apply path. We do not use `debug` in prod, and it was tricky to understand what was going on with this bug in prod. 2. Refactor configuration concurrency calculation code so it was reusable. Yet, still keep `1` in the case of reconfiguration. The database can be actively used at this moment, so we cannot guarantee that there will be enough spare connection slots, and the underlying code won't handle connection errors properly. 3. Simplify the installed extensions code. It was spawning a blocking task inside async function, which doesn't make much sense. Instead, just have a main sync function and call it with `spawn_blocking` in the API code -- the only place we need it to be async. 4. Add regression python test to cover this and related problems in the future. 
Also, add more extensive testing of schema dump and DBs and roles listing API.

[1]: https://github.com/neondatabase/neon/commit/4d1e48f3b9a4b7064787513fd2c455f0001f6e18
[2]: https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org

Resolves neondatabase/cloud#20869
---
 compute_tools/src/catalog.rs                |  39 ++++-
 compute_tools/src/compute.rs                | 153 +++++++++++---------
 compute_tools/src/http/api.rs               |   7 +-
 compute_tools/src/installed_extensions.rs   | 105 +++++---------
 compute_tools/src/pg_helpers.rs             |  11 ++
 test_runner/fixtures/endpoint/http.py       |   6 +-
 test_runner/fixtures/neon_fixtures.py       |  29 ++++
 test_runner/regress/test_compute_catalog.py | 111 +++++++++++++-
 8 files changed, 318 insertions(+), 143 deletions(-)

diff --git a/compute_tools/src/catalog.rs b/compute_tools/src/catalog.rs
index 2f6f82dd39..08ae8bf44d 100644
--- a/compute_tools/src/catalog.rs
+++ b/compute_tools/src/catalog.rs
@@ -1,4 +1,3 @@
-use compute_api::responses::CatalogObjects;
 use futures::Stream;
 use postgres::NoTls;
 use std::{path::Path, process::Stdio, result::Result, sync::Arc};
@@ -13,7 +12,8 @@ use tokio_util::codec::{BytesCodec, FramedRead};
 use tracing::warn;
 
 use crate::compute::ComputeNode;
-use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async};
+use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
+use compute_api::responses::CatalogObjects;
 
 pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
     let connstr = compute.connstr.clone();
@@ -43,6 +43,8 @@ pub enum SchemaDumpError {
     DatabaseDoesNotExist,
     #[error("Failed to execute pg_dump.")]
     IO(#[from] std::io::Error),
+    #[error("Unexpected error.")]
+    Unexpected,
 }
 
 // It uses the pg_dump utility to dump the schema of the specified database.
@@ -60,11 +62,38 @@ pub async fn get_database_schema(
     let pgbin = &compute.pgbin;
     let basepath = Path::new(pgbin).parent().unwrap();
     let pgdump = basepath.join("pg_dump");
-    let mut connstr = compute.connstr.clone();
-    connstr.set_path(dbname);
+
+    // Replace the DB in the connection string and disassemble it into parts.
+    // This is the only option to handle DBs with special characters.
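+    // Illustration (added note, not part of the original patch): `Url::set_path()`
+    // is unsafe here because the URL parser keeps an existing `%20` intact
+    // (treating it as already encoded) and may strip characters such as tabs
+    // from the path, while Postgres accepts `%20`, tabs, and trailing
+    // whitespace as literal characters of a database name. Parsing into a
+    // `postgres::config::Config` and setting the name via `dbname` avoids
+    // any URL re-encoding.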
+ let conf = + postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?; + let host = conf + .get_hosts() + .first() + .ok_or(SchemaDumpError::Unexpected)?; + let host = match host { + tokio_postgres::config::Host::Tcp(ip) => ip.to_string(), + #[cfg(unix)] + tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(), + }; + let port = conf + .get_ports() + .first() + .ok_or(SchemaDumpError::Unexpected)?; + let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?; + let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?; + let mut cmd = Command::new(pgdump) + // XXX: this seems to be the only option to deal with DBs with `=` in the name + // See + .env("PGDATABASE", dbname) + .arg("--host") + .arg(host) + .arg("--port") + .arg(port.to_string()) + .arg("--username") + .arg(user) .arg("--schema-only") - .arg(connstr.as_str()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .kill_on_drop(true) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 4f67425ba8..1a026a4014 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -34,9 +34,8 @@ use utils::measured_stream::MeasuredReader; use nix::sys::signal::{kill, Signal}; use remote_storage::{DownloadError, RemotePath}; use tokio::spawn; -use url::Url; -use crate::installed_extensions::get_installed_extensions_sync; +use crate::installed_extensions::get_installed_extensions; use crate::local_proxy; use crate::pg_helpers::*; use crate::spec::*; @@ -816,30 +815,32 @@ impl ComputeNode { Ok(()) } - async fn get_maintenance_client(url: &Url) -> Result { - let mut connstr = url.clone(); + async fn get_maintenance_client( + conf: &tokio_postgres::Config, + ) -> Result { + let mut conf = conf.clone(); - connstr - .query_pairs_mut() - .append_pair("application_name", "apply_config"); + conf.application_name("apply_config"); - let (client, conn) = match tokio_postgres::connect(connstr.as_str(), NoTls).await { + let (client, conn) = match conf.connect(NoTls).await { + // If connection fails, it may be the old node with `zenith_admin` superuser. + // + // In this case we need to connect with old `zenith_admin` name + // and create new user. We cannot simply rename connected user, + // but we can create a new one and grant it all privileges. Err(e) => match e.code() { Some(&SqlState::INVALID_PASSWORD) | Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => { - // connect with zenith_admin if cloud_admin could not authenticate + // Connect with zenith_admin if cloud_admin could not authenticate info!( "cannot connect to postgres: {}, retrying with `zenith_admin` username", e ); - let mut zenith_admin_connstr = connstr.clone(); - - zenith_admin_connstr - .set_username("zenith_admin") - .map_err(|_| anyhow::anyhow!("invalid connstr"))?; + let mut zenith_admin_conf = postgres::config::Config::from(conf.clone()); + zenith_admin_conf.user("zenith_admin"); let mut client = - Client::connect(zenith_admin_connstr.as_str(), NoTls) + zenith_admin_conf.connect(NoTls) .context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?; // Disable forwarding so that users don't get a cloud_admin role @@ -853,8 +854,8 @@ impl ComputeNode { drop(client); - // reconnect with connstring with expected name - tokio_postgres::connect(connstr.as_str(), NoTls).await? + // Reconnect with connstring with expected name + conf.connect(NoTls).await? 
} _ => return Err(e.into()), }, @@ -885,7 +886,7 @@ impl ComputeNode { pub fn apply_spec_sql( &self, spec: Arc, - url: Arc, + conf: Arc, concurrency: usize, ) -> Result<()> { let rt = tokio::runtime::Builder::new_multi_thread() @@ -897,7 +898,7 @@ impl ComputeNode { rt.block_on(async { // Proceed with post-startup configuration. Note, that order of operations is important. - let client = Self::get_maintenance_client(&url).await?; + let client = Self::get_maintenance_client(&conf).await?; let spec = spec.clone(); let databases = get_existing_dbs_async(&client).await?; @@ -931,7 +932,7 @@ impl ComputeNode { RenameAndDeleteDatabases, CreateAndAlterDatabases, ] { - debug!("Applying phase {:?}", &phase); + info!("Applying phase {:?}", &phase); apply_operations( spec.clone(), ctx.clone(), @@ -942,6 +943,7 @@ impl ComputeNode { .await?; } + info!("Applying RunInEachDatabase phase"); let concurrency_token = Arc::new(tokio::sync::Semaphore::new(concurrency)); let db_processes = spec @@ -955,7 +957,7 @@ impl ComputeNode { let spec = spec.clone(); let ctx = ctx.clone(); let jwks_roles = jwks_roles.clone(); - let mut url = url.as_ref().clone(); + let mut conf = conf.as_ref().clone(); let concurrency_token = concurrency_token.clone(); let db = db.clone(); @@ -964,14 +966,14 @@ impl ComputeNode { match &db { DB::SystemDB => {} DB::UserDB(db) => { - url.set_path(db.name.as_str()); + conf.dbname(db.name.as_str()); } } - let url = Arc::new(url); + let conf = Arc::new(conf); let fut = Self::apply_spec_sql_db( spec.clone(), - url, + conf, ctx.clone(), jwks_roles.clone(), concurrency_token.clone(), @@ -1017,7 +1019,7 @@ impl ComputeNode { /// semaphore. The caller has to make sure the semaphore isn't exhausted. async fn apply_spec_sql_db( spec: Arc, - url: Arc, + conf: Arc, ctx: Arc>, jwks_roles: Arc>, concurrency_token: Arc, @@ -1046,7 +1048,7 @@ impl ComputeNode { // that database. || async { if client_conn.is_none() { - let db_client = Self::get_maintenance_client(&url).await?; + let db_client = Self::get_maintenance_client(&conf).await?; client_conn.replace(db_client); } let client = client_conn.as_ref().unwrap(); @@ -1061,34 +1063,16 @@ impl ComputeNode { Ok::<(), anyhow::Error>(()) } - /// Do initial configuration of the already started Postgres. - #[instrument(skip_all)] - pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { - // If connection fails, - // it may be the old node with `zenith_admin` superuser. - // - // In this case we need to connect with old `zenith_admin` name - // and create new user. We cannot simply rename connected user, - // but we can create a new one and grant it all privileges. - let mut url = self.connstr.clone(); - url.query_pairs_mut() - .append_pair("application_name", "apply_config"); - - let url = Arc::new(url); - let spec = Arc::new( - compute_state - .pspec - .as_ref() - .expect("spec must be set") - .spec - .clone(), - ); - - // Choose how many concurrent connections to use for applying the spec changes. - // If the cluster is not currently Running we don't have to deal with user connections, + /// Choose how many concurrent connections to use for applying the spec changes. + pub fn max_service_connections( + &self, + compute_state: &ComputeState, + spec: &ComputeSpec, + ) -> usize { + // If the cluster is in Init state we don't have to deal with user connections, // and can thus use all `max_connections` connection slots. However, that's generally not // very efficient, so we generally still limit it to a smaller number. 
- let max_concurrent_connections = if compute_state.status != ComputeStatus::Running { + if compute_state.status == ComputeStatus::Init { // If the settings contain 'max_connections', use that as template if let Some(config) = spec.cluster.settings.find("max_connections") { config.parse::().ok() @@ -1144,10 +1128,29 @@ impl ComputeNode { .map(|val| if val > 1 { val - 1 } else { 1 }) .last() .unwrap_or(3) - }; + } + } + + /// Do initial configuration of the already started Postgres. + #[instrument(skip_all)] + pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { + let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap(); + conf.application_name("apply_config"); + + let conf = Arc::new(conf); + let spec = Arc::new( + compute_state + .pspec + .as_ref() + .expect("spec must be set") + .spec + .clone(), + ); + + let max_concurrent_connections = self.max_service_connections(compute_state, &spec); // Merge-apply spec & changes to PostgreSQL state. - self.apply_spec_sql(spec.clone(), url.clone(), max_concurrent_connections)?; + self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?; if let Some(ref local_proxy) = &spec.clone().local_proxy_config { info!("configuring local_proxy"); @@ -1156,12 +1159,11 @@ impl ComputeNode { // Run migrations separately to not hold up cold starts thread::spawn(move || { - let mut connstr = url.as_ref().clone(); - connstr - .query_pairs_mut() - .append_pair("application_name", "migrations"); + let conf = conf.as_ref().clone(); + let mut conf = postgres::config::Config::from(conf); + conf.application_name("migrations"); - let mut client = Client::connect(connstr.as_str(), NoTls)?; + let mut client = conf.connect(NoTls)?; handle_migrations(&mut client).context("apply_config handle_migrations") }); @@ -1222,21 +1224,28 @@ impl ComputeNode { let pgdata_path = Path::new(&self.pgdata); let postgresql_conf_path = pgdata_path.join("postgresql.conf"); config::write_postgres_conf(&postgresql_conf_path, &spec, None)?; - // temporarily reset max_cluster_size in config + + // TODO(ololobus): We need a concurrency during reconfiguration as well, + // but DB is already running and used by user. We can easily get out of + // `max_connections` limit, and the current code won't handle that. + // let compute_state = self.state.lock().unwrap().clone(); + // let max_concurrent_connections = self.max_service_connections(&compute_state, &spec); + let max_concurrent_connections = 1; + + // Temporarily reset max_cluster_size in config // to avoid the possibility of hitting the limit, while we are reconfiguring: - // creating new extensions, roles, etc... + // creating new extensions, roles, etc. 
config::with_compute_ctl_tmp_override(pgdata_path, "neon.max_cluster_size=-1", || { self.pg_reload_conf()?; if spec.mode == ComputeMode::Primary { - let mut url = self.connstr.clone(); - url.query_pairs_mut() - .append_pair("application_name", "apply_config"); - let url = Arc::new(url); + let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap(); + conf.application_name("apply_config"); + let conf = Arc::new(conf); let spec = Arc::new(spec.clone()); - self.apply_spec_sql(spec, url, 1)?; + self.apply_spec_sql(spec, conf, max_concurrent_connections)?; } Ok(()) @@ -1362,7 +1371,17 @@ impl ComputeNode { let connstr = self.connstr.clone(); thread::spawn(move || { - get_installed_extensions_sync(connstr).context("get_installed_extensions") + let res = get_installed_extensions(&connstr); + match res { + Ok(extensions) => { + info!( + "[NEON_EXT_STAT] {}", + serde_json::to_string(&extensions) + .expect("failed to serialize extensions list") + ); + } + Err(err) => error!("could not get installed extensions: {err:?}"), + } }); } diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 8a047634df..a6c6cff20a 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -296,7 +296,12 @@ async fn routes(req: Request, compute: &Arc) -> Response render_json(Body::from(serde_json::to_string(&res).unwrap())), Err(e) => render_json_error( diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index 79d8b2ca04..f473c29a55 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -2,17 +2,16 @@ use compute_api::responses::{InstalledExtension, InstalledExtensions}; use metrics::proto::MetricFamily; use std::collections::HashMap; use std::collections::HashSet; -use tracing::info; -use url::Url; use anyhow::Result; use postgres::{Client, NoTls}; -use tokio::task; use metrics::core::Collector; use metrics::{register_uint_gauge_vec, UIntGaugeVec}; use once_cell::sync::Lazy; +use crate::pg_helpers::postgres_conf_for_db; + /// We don't reuse get_existing_dbs() just for code clarity /// and to make database listing query here more explicit. /// @@ -42,75 +41,51 @@ fn list_dbs(client: &mut Client) -> Result> { /// /// Same extension can be installed in multiple databases with different versions, /// we only keep the highest and lowest version across all databases. -pub async fn get_installed_extensions(connstr: Url) -> Result { - let mut connstr = connstr.clone(); +pub fn get_installed_extensions(connstr: &url::Url) -> Result { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + let databases: Vec = list_dbs(&mut client)?; - task::spawn_blocking(move || { - let mut client = Client::connect(connstr.as_str(), NoTls)?; - let databases: Vec = list_dbs(&mut client)?; + let mut extensions_map: HashMap = HashMap::new(); + for db in databases.iter() { + let config = postgres_conf_for_db(connstr, db)?; + let mut db_client = config.connect(NoTls)?; + let extensions: Vec<(String, String)> = db_client + .query( + "SELECT extname, extversion FROM pg_catalog.pg_extension;", + &[], + )? 
+ .iter() + .map(|row| (row.get("extname"), row.get("extversion"))) + .collect(); - let mut extensions_map: HashMap = HashMap::new(); - for db in databases.iter() { - connstr.set_path(db); - let mut db_client = Client::connect(connstr.as_str(), NoTls)?; - let extensions: Vec<(String, String)> = db_client - .query( - "SELECT extname, extversion FROM pg_catalog.pg_extension;", - &[], - )? - .iter() - .map(|row| (row.get("extname"), row.get("extversion"))) - .collect(); + for (extname, v) in extensions.iter() { + let version = v.to_string(); - for (extname, v) in extensions.iter() { - let version = v.to_string(); + // increment the number of databases where the version of extension is installed + INSTALLED_EXTENSIONS + .with_label_values(&[extname, &version]) + .inc(); - // increment the number of databases where the version of extension is installed - INSTALLED_EXTENSIONS - .with_label_values(&[extname, &version]) - .inc(); - - extensions_map - .entry(extname.to_string()) - .and_modify(|e| { - e.versions.insert(version.clone()); - // count the number of databases where the extension is installed - e.n_databases += 1; - }) - .or_insert(InstalledExtension { - extname: extname.to_string(), - versions: HashSet::from([version.clone()]), - n_databases: 1, - }); - } + extensions_map + .entry(extname.to_string()) + .and_modify(|e| { + e.versions.insert(version.clone()); + // count the number of databases where the extension is installed + e.n_databases += 1; + }) + .or_insert(InstalledExtension { + extname: extname.to_string(), + versions: HashSet::from([version.clone()]), + n_databases: 1, + }); } + } - let res = InstalledExtensions { - extensions: extensions_map.values().cloned().collect(), - }; + let res = InstalledExtensions { + extensions: extensions_map.values().cloned().collect(), + }; - Ok(res) - }) - .await? -} - -// Gather info about installed extensions -pub fn get_installed_extensions_sync(connstr: Url) -> Result<()> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create runtime"); - let result = rt - .block_on(crate::installed_extensions::get_installed_extensions( - connstr, - )) - .expect("failed to get installed extensions"); - - info!( - "[NEON_EXT_STAT] {}", - serde_json::to_string(&result).expect("failed to serialize extensions list") - ); - Ok(()) + Ok(res) } static INSTALLED_EXTENSIONS: Lazy = Lazy::new(|| { diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 4a1e5ee0e8..e03b410699 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -6,6 +6,7 @@ use std::io::{BufRead, BufReader}; use std::os::unix::fs::PermissionsExt; use std::path::Path; use std::process::Child; +use std::str::FromStr; use std::thread::JoinHandle; use std::time::{Duration, Instant}; @@ -13,8 +14,10 @@ use anyhow::{bail, Result}; use futures::StreamExt; use ini::Ini; use notify::{RecursiveMode, Watcher}; +use postgres::config::Config; use tokio::io::AsyncBufReadExt; use tokio::time::timeout; +use tokio_postgres; use tokio_postgres::NoTls; use tracing::{debug, error, info, instrument}; @@ -542,3 +545,11 @@ async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Resu Ok(()) } + +/// `Postgres::config::Config` handles database names with whitespaces +/// and special characters properly. 
+pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result<Config> {
+    let mut conf = Config::from_str(connstr.as_str())?;
+    conf.dbname(dbname);
+    Ok(conf)
+}
diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py
index db3723b7cc..1cd9158c68 100644
--- a/test_runner/fixtures/endpoint/http.py
+++ b/test_runner/fixtures/endpoint/http.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import urllib.parse
+
 import requests
 from requests.adapters import HTTPAdapter
 
@@ -20,7 +22,9 @@ class EndpointHttpClient(requests.Session):
         return res.json()
 
     def database_schema(self, database: str):
-        res = self.get(f"http://localhost:{self.port}/database_schema?database={database}")
+        res = self.get(
+            f"http://localhost:{self.port}/database_schema?database={urllib.parse.quote(database, safe='')}"
+        )
         res.raise_for_status()
         return res.text
 
diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index a45a311dc2..1f4d2aa5ec 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -3934,6 +3934,35 @@ class Endpoint(PgProtocol, LogUtils):
             log.info(json.dumps(dict(data_dict, **kwargs)))
             json.dump(dict(data_dict, **kwargs), file, indent=4)
 
+    def respec_deep(self, **kwargs: Any) -> None:
+        """
+        Update the endpoint.json file taking into account nested keys.
+        It does one level deep update. Should be enough for most cases.
+        Distinct method from respec() to not break existing functionality.
+        NOTE: This method updates the spec.json file, not endpoint.json.
+        We need it because neon_local also writes to spec.json, so intended
+        use-case is i) start endpoint with some config, ii) respec_deep(),
+        iii) call reconfigure() to apply the changes.
+        """
+        config_path = os.path.join(self.endpoint_path(), "spec.json")
+        with open(config_path) as f:
+            data_dict: dict[str, Any] = json.load(f)
+
+        log.info("Current compute spec: %s", json.dumps(data_dict, indent=4))
+
+        for key, value in kwargs.items():
+            if isinstance(value, dict):
+                if key not in data_dict:
+                    data_dict[key] = value
+                else:
+                    data_dict[key] = {**data_dict[key], **value}
+            else:
+                data_dict[key] = value
+
+        with open(config_path, "w") as file:
+            log.info("Updating compute spec to: %s", json.dumps(data_dict, indent=4))
+            json.dump(data_dict, file, indent=4)
+
     # Please note: Migrations only run if pg_skip_catalog_updates is false
     def wait_for_migrations(self, num_migrations: int = 11):
         with self.cursor() as cur:
diff --git a/test_runner/regress/test_compute_catalog.py b/test_runner/regress/test_compute_catalog.py
index d43c71ceac..b3719a45ed 100644
--- a/test_runner/regress/test_compute_catalog.py
+++ b/test_runner/regress/test_compute_catalog.py
@@ -3,13 +3,60 @@ from __future__ import annotations
 import requests
 from fixtures.neon_fixtures import NeonEnv
 
+TEST_DB_NAMES = [
+    {
+        "name": "neondb",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "db with spaces",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "db with%20spaces ",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "db with whitespaces ",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "injective db with spaces'; SELECT pg_sleep(10);",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "db with #pound-sign and &ersands=true",
+        "owner": "cloud_admin",
+    },
+    {
+        "name": "db with emoji 🌍",
+        "owner": "cloud_admin",
+    },
+]
+
 
 def test_compute_catalog(neon_simple_env: NeonEnv):
+    """
+    Create a bunch of databases with tricky names and test that we can list them
+    and dump via API.
+ """ env = neon_simple_env - endpoint = env.endpoints.create_start("main", config_lines=["log_min_messages=debug1"]) - client = endpoint.http_client() + endpoint = env.endpoints.create_start("main") + # Update the spec.json file to include new databases + # and reconfigure the endpoint to create some test databases. + endpoint.respec_deep( + **{ + "skip_pg_catalog_updates": False, + "cluster": { + "databases": TEST_DB_NAMES, + }, + } + ) + endpoint.reconfigure() + + client = endpoint.http_client() objects = client.dbs_and_roles() # Assert that 'cloud_admin' role exists in the 'roles' list @@ -22,9 +69,24 @@ def test_compute_catalog(neon_simple_env: NeonEnv): db["name"] == "postgres" for db in objects["databases"] ), "The 'postgres' database is missing" - ddl = client.database_schema(database="postgres") + # Check other databases + for test_db in TEST_DB_NAMES: + db = next((db for db in objects["databases"] if db["name"] == test_db["name"]), None) + assert db is not None, f"The '{test_db['name']}' database is missing" + assert ( + db["owner"] == test_db["owner"] + ), f"The '{test_db['name']}' database has incorrect owner" - assert "-- PostgreSQL database dump" in ddl + ddl = client.database_schema(database=test_db["name"]) + + # Check that it looks like a valid PostgreSQL dump + assert "-- PostgreSQL database dump" in ddl + + # Check that it doesn't contain health_check and migration traces. + # They are only created in system `postgres` database, so by checking + # that we ensure that we dump right databases. + assert "health_check" not in ddl, f"The '{test_db['name']}' database contains health_check" + assert "migration" not in ddl, f"The '{test_db['name']}' database contains migrations data" try: client.database_schema(database="nonexistentdb") @@ -33,3 +95,44 @@ def test_compute_catalog(neon_simple_env: NeonEnv): assert ( e.response.status_code == 404 ), f"Expected 404 status code, but got {e.response.status_code}" + + +def test_compute_create_databases(neon_simple_env: NeonEnv): + """ + Test that compute_ctl can create and work with databases with special + characters (whitespaces, %, tabs, etc.) in the name. + """ + env = neon_simple_env + + # Create and start endpoint so that neon_local put all the generated + # stuff into the spec.json file. + endpoint = env.endpoints.create_start("main") + + # Update the spec.json file to include new databases + # and reconfigure the endpoint to apply the changes. + endpoint.respec_deep( + **{ + "skip_pg_catalog_updates": False, + "cluster": { + "databases": TEST_DB_NAMES, + }, + } + ) + endpoint.reconfigure() + + for db in TEST_DB_NAMES: + # Check that database has a correct name in the system catalog + with endpoint.cursor() as cursor: + cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (db["name"],)) + catalog_db = cursor.fetchone() + assert catalog_db is not None + assert len(catalog_db) == 1 + assert catalog_db[0] == db["name"] + + # Check that we can connect to this database without any issues + with endpoint.cursor(dbname=db["name"]) as cursor: + cursor.execute("SELECT * FROM current_database()") + curr_db = cursor.fetchone() + assert curr_db is not None + assert len(curr_db) == 1 + assert curr_db[0] == db["name"] From 3ffe6de0b9a4f49cf18f6a2ebf0fc2c6274dfccd Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Fri, 29 Nov 2024 10:40:08 +0100 Subject: [PATCH 33/34] test_runner/performance: add logical message ingest benchmark (#9749) Adds a benchmark for logical message WAL ingestion throughput end-to-end. 
Logical messages are essentially noops, and thus ignored by the Pageserver. Example results from my MacBook, with fsync enabled: ``` postgres_ingest: 14.445 s safekeeper_ingest: 29.948 s pageserver_ingest: 30.013 s pageserver_recover_ingest: 8.633 s wal_written: 10,340 MB message_count: 1310720 messages postgres_throughput: 715 MB/s safekeeper_throughput: 345 MB/s pageserver_throughput: 344 MB/s pageserver_recover_throughput: 1197 MB/s ``` See https://github.com/neondatabase/neon/issues/9642#issuecomment-2475995205 for running analysis. Touches #9642. --- test_runner/fixtures/neon_fixtures.py | 31 ++++++ .../test_ingest_logical_message.py | 101 ++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 test_runner/performance/test_ingest_logical_message.py diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 1f4d2aa5ec..e3c88e9965 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4404,6 +4404,10 @@ class Safekeeper(LogUtils): log.info(f"sk {self.id} flush LSN: {flush_lsn}") return flush_lsn + def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn: + timeline_status = self.http_client().timeline_status(tenant_id, timeline_id) + return timeline_status.commit_lsn + def pull_timeline( self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId ) -> dict[str, Any]: @@ -4949,6 +4953,33 @@ def wait_for_last_flush_lsn( return min(results) +def wait_for_commit_lsn( + env: NeonEnv, + tenant: TenantId, + timeline: TimelineId, + lsn: Lsn, +) -> Lsn: + # TODO: it would be better to poll this in the compute, but there's no API for it. See: + # https://github.com/neondatabase/neon/issues/9758 + "Wait for the given LSN to be committed on any Safekeeper" + + max_commit_lsn = Lsn(0) + for i in range(1000): + for sk in env.safekeepers: + commit_lsn = sk.get_commit_lsn(tenant, timeline) + if commit_lsn >= lsn: + log.info(f"{tenant}/{timeline} at commit_lsn {commit_lsn}") + return commit_lsn + max_commit_lsn = max(max_commit_lsn, commit_lsn) + + if i % 10 == 0: + log.info( + f"{tenant}/{timeline} waiting for commit_lsn to reach {lsn}, now {max_commit_lsn}" + ) + time.sleep(0.1) + raise Exception(f"timed out while waiting for commit_lsn to reach {lsn}, was {max_commit_lsn}") + + def flush_ep_to_pageserver( env: NeonEnv, ep: Endpoint, diff --git a/test_runner/performance/test_ingest_logical_message.py b/test_runner/performance/test_ingest_logical_message.py new file mode 100644 index 0000000000..d3118eb15a --- /dev/null +++ b/test_runner/performance/test_ingest_logical_message.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import pytest +from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker +from fixtures.common_types import Lsn +from fixtures.log_helper import log +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + wait_for_commit_lsn, + wait_for_last_flush_lsn, +) +from fixtures.pageserver.utils import wait_for_last_record_lsn + + +@pytest.mark.timeout(600) +@pytest.mark.parametrize("size", [1024, 8192, 131072]) +@pytest.mark.parametrize("fsync", [True, False], ids=["fsync", "nofsync"]) +def test_ingest_logical_message( + request: pytest.FixtureRequest, + neon_env_builder: NeonEnvBuilder, + zenbenchmark: NeonBenchmarker, + fsync: bool, + size: int, +): + """ + Benchmarks ingestion of 10 GB of logical message WAL. These are essentially noops, and don't + incur any pageserver writes. 
+ """ + + VOLUME = 10 * 1024**3 + count = VOLUME // size + + neon_env_builder.safekeepers_enable_fsync = fsync + + env = neon_env_builder.init_start() + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + f"fsync = {fsync}", + # Disable backpressure. We don't want to block on pageserver. + "max_replication_apply_lag = 0", + "max_replication_flush_lag = 0", + "max_replication_write_lag = 0", + ], + ) + client = env.pageserver.http_client() + + # Wait for the timeline to be propagated to the pageserver. + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Ingest data and measure durations. + start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + with endpoint.cursor() as cur: + cur.execute("set statement_timeout = 0") + + # Postgres will return once the logical messages have been written to its local WAL, without + # waiting for Safekeeper commit. We measure ingestion time both for Postgres, Safekeeper, + # and Pageserver to detect bottlenecks. + log.info("Ingesting data") + with zenbenchmark.record_duration("pageserver_ingest"): + with zenbenchmark.record_duration("safekeeper_ingest"): + with zenbenchmark.record_duration("postgres_ingest"): + cur.execute(f""" + select pg_logical_emit_message(false, '', repeat('x', {size})) + from generate_series(1, {count}) + """) + + end_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + # Wait for Safekeeper. + log.info("Waiting for Safekeeper to catch up") + wait_for_commit_lsn(env, env.initial_tenant, env.initial_timeline, end_lsn) + + # Wait for Pageserver. + log.info("Waiting for Pageserver to catch up") + wait_for_last_record_lsn(client, env.initial_tenant, env.initial_timeline, end_lsn) + + # Now that all data is ingested, delete and recreate the tenant in the pageserver. This will + # reingest all the WAL from the safekeeper without any other constraints. This gives us a + # baseline of how fast the pageserver can ingest this WAL in isolation. + status = env.storage_controller.inspect(tenant_shard_id=env.initial_tenant) + assert status is not None + + client.tenant_delete(env.initial_tenant) + env.pageserver.tenant_create(tenant_id=env.initial_tenant, generation=status[0]) + + with zenbenchmark.record_duration("pageserver_recover_ingest"): + log.info("Recovering WAL into pageserver") + client.timeline_create(env.pg_version, env.initial_tenant, env.initial_timeline) + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Emit metrics. + wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024)) + zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM) + zenbenchmark.record("message_count", count, "messages", MetricReport.TEST_PARAM) + + props = {p["name"]: p["value"] for _, p in request.node.user_properties} + for name in ("postgres", "safekeeper", "pageserver", "pageserver_recover"): + throughput = int(wal_written_mb / props[f"{name}_ingest"]) + zenbenchmark.record(f"{name}_throughput", throughput, "MB/s", MetricReport.HIGHER_IS_BETTER) From 1d642d6a57dd1cd1645a34aba5a2dd6e06a6c651 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 29 Nov 2024 11:08:01 +0000 Subject: [PATCH 34/34] chore(proxy): vendor a subset of rust-postgres (#9930) Our rust-postgres fork is getting messy. Mostly because proxy wants more control over the raw protocol than tokio-postgres provides. As such, it's diverging more and more. 
Storage and compute also make use of rust-postgres, but in a more normal
fashion, so they don't need our crazy changes.

Idea:
* proxy maintains its own subset
* other teams use a minimal patch set against upstream rust-postgres

Reviewing this code will be difficult. To implement it, I

1. Copied tokio-postgres, postgres-protocol and postgres-types from
   https://github.com/neondatabase/rust-postgres/tree/00940fcdb57a8e99e805297b75839e7c4c7b1796
2. Updated their package names with the `2` suffix to make them compile in
   the workspace.
3. Updated proxy to use those packages
4. Copied in the code from tokio-postgres-rustls 0.13 (with some patches
   applied: https://github.com/jbg/tokio-postgres-rustls/pull/32
   https://github.com/jbg/tokio-postgres-rustls/pull/33)
5. Removed as much dead code as I could find in the vendored libraries
6. Updated the tokio-postgres-rustls code to use our existing channel
   binding implementation
---
 .config/hakari.toml                           |    3 +
 Cargo.lock                                    |   56 +-
 Cargo.toml                                    |    3 +
 libs/proxy/README.md                          |    6 +
 libs/proxy/postgres-protocol2/Cargo.toml      |   21 +
 .../src/authentication/mod.rs                 |   37 +
 .../src/authentication/sasl.rs                |  516 +++++
 .../postgres-protocol2/src/escape/mod.rs      |   93 +
 .../postgres-protocol2/src/escape/test.rs     |   17 +
 libs/proxy/postgres-protocol2/src/lib.rs      |   78 +
 .../postgres-protocol2/src/message/backend.rs |  766 ++++++++
 .../src/message/frontend.rs                   |  297 +++
 .../postgres-protocol2/src/message/mod.rs     |    8 +
 .../postgres-protocol2/src/password/mod.rs    |  107 ++
 .../postgres-protocol2/src/password/test.rs   |   19 +
 .../proxy/postgres-protocol2/src/types/mod.rs |  294 +++
 .../postgres-protocol2/src/types/test.rs      |   87 +
 libs/proxy/postgres-types2/Cargo.toml         |   10 +
 libs/proxy/postgres-types2/src/lib.rs         |  477 +++++
 libs/proxy/postgres-types2/src/private.rs     |   34 +
 libs/proxy/postgres-types2/src/type_gen.rs    | 1524 +++++++++++++++
 libs/proxy/tokio-postgres2/Cargo.toml         |   21 +
 .../proxy/tokio-postgres2/src/cancel_query.rs |   40 +
 .../tokio-postgres2/src/cancel_query_raw.rs   |   29 +
 .../proxy/tokio-postgres2/src/cancel_token.rs |   62 +
 libs/proxy/tokio-postgres2/src/client.rs      |  439 +++++
 libs/proxy/tokio-postgres2/src/codec.rs       |  109 ++
 libs/proxy/tokio-postgres2/src/config.rs      |  897 +++++++++
 libs/proxy/tokio-postgres2/src/connect.rs     |  112 ++
 libs/proxy/tokio-postgres2/src/connect_raw.rs |  359 ++++
 .../tokio-postgres2/src/connect_socket.rs     |   65 +
 libs/proxy/tokio-postgres2/src/connect_tls.rs |   48 +
 libs/proxy/tokio-postgres2/src/connection.rs  |  323 ++++
 libs/proxy/tokio-postgres2/src/error/mod.rs   |  501 +++++
 .../tokio-postgres2/src/error/sqlstate.rs     | 1670 +++++++++++++++++
 .../tokio-postgres2/src/generic_client.rs     |   64 +
 libs/proxy/tokio-postgres2/src/lib.rs         |  148 ++
 .../tokio-postgres2/src/maybe_tls_stream.rs   |   77 +
 libs/proxy/tokio-postgres2/src/prepare.rs     |  262 +++
 libs/proxy/tokio-postgres2/src/query.rs       |  340 ++++
 libs/proxy/tokio-postgres2/src/row.rs         |  300 +++
 .../proxy/tokio-postgres2/src/simple_query.rs |  142 ++
 libs/proxy/tokio-postgres2/src/statement.rs   |  157 ++
 libs/proxy/tokio-postgres2/src/tls.rs         |  162 ++
 .../proxy/tokio-postgres2/src/to_statement.rs |   57 +
 libs/proxy/tokio-postgres2/src/transaction.rs |   74 +
 .../src/transaction_builder.rs                |  113 ++
 libs/proxy/tokio-postgres2/src/types.rs       |    6 +
 proxy/Cargo.toml                              |    6 +-
 proxy/src/compute.rs                          |    5 +-
 proxy/src/context/mod.rs                      |    1 +
 proxy/src/lib.rs                              |    1 +
 proxy/src/postgres_rustls/mod.rs              |  158 ++
 proxy/src/proxy/tests/mod.rs                  |    2 +-
 proxy/src/serverless/backend.rs               |    2 +-
 proxy/src/serverless/conn_pool.rs             |    5 +-
 proxy/src/serverless/local_conn_pool.rs       |   11 +-
 workspace_hack/Cargo.toml                     |    4 +-
 58 files changed, 11199 insertions(+), 26 deletions(-)
 create mode 100644 libs/proxy/README.md
 create mode 100644 libs/proxy/postgres-protocol2/Cargo.toml
 create mode 100644 libs/proxy/postgres-protocol2/src/authentication/mod.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/authentication/sasl.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/escape/mod.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/escape/test.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/lib.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/message/backend.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/message/frontend.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/message/mod.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/password/mod.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/password/test.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/types/mod.rs
 create mode 100644 libs/proxy/postgres-protocol2/src/types/test.rs
 create mode 100644 libs/proxy/postgres-types2/Cargo.toml
 create mode 100644 libs/proxy/postgres-types2/src/lib.rs
 create mode 100644 libs/proxy/postgres-types2/src/private.rs
 create mode 100644 libs/proxy/postgres-types2/src/type_gen.rs
 create mode 100644 libs/proxy/tokio-postgres2/Cargo.toml
 create mode 100644 libs/proxy/tokio-postgres2/src/cancel_query.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/cancel_query_raw.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/cancel_token.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/client.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/codec.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/config.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/connect.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/connect_raw.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/connect_socket.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/connect_tls.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/connection.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/error/mod.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/error/sqlstate.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/generic_client.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/lib.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/prepare.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/query.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/row.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/simple_query.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/statement.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/tls.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/to_statement.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/transaction.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/transaction_builder.rs
 create mode 100644 libs/proxy/tokio-postgres2/src/types.rs
 create mode 100644 proxy/src/postgres_rustls/mod.rs

diff --git a/.config/hakari.toml b/.config/hakari.toml
index b5990d090e..3b6d9d8822 100644
--- a/.config/hakari.toml
+++ b/.config/hakari.toml
@@ -46,6 +46,9 @@ workspace-members = [
     "utils",
     "wal_craft",
     "walproposer",
+    "postgres-protocol2",
+    "postgres-types2",
+    "tokio-postgres2",
 ]
 
 # Write out exact versions rather than a semver range. (Defaults to false.)
diff --git a/Cargo.lock b/Cargo.lock
index 43a46fb1eb..f05c6311dd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4162,6 +4162,23 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "postgres-protocol2"
+version = "0.1.0"
+dependencies = [
+ "base64 0.20.0",
+ "byteorder",
+ "bytes",
+ "fallible-iterator",
+ "hmac",
+ "md-5",
+ "memchr",
+ "rand 0.8.5",
+ "sha2",
+ "stringprep",
+ "tokio",
+]
+
 [[package]]
 name = "postgres-types"
 version = "0.2.4"
@@ -4170,8 +4187,15 @@ dependencies = [
  "bytes",
  "fallible-iterator",
  "postgres-protocol",
- "serde",
- "serde_json",
+]
+
+[[package]]
+name = "postgres-types2"
+version = "0.1.0"
+dependencies = [
+ "bytes",
+ "fallible-iterator",
+ "postgres-protocol2",
 ]
 
 [[package]]
@@ -4501,7 +4525,7 @@ dependencies = [
  "parquet_derive",
  "pbkdf2",
  "pin-project-lite",
- "postgres-protocol",
+ "postgres-protocol2",
  "postgres_backend",
  "pq_proto",
  "prometheus",
@@ -4536,8 +4560,7 @@ dependencies = [
  "tikv-jemalloc-ctl",
  "tikv-jemallocator",
  "tokio",
- "tokio-postgres",
- "tokio-postgres-rustls",
+ "tokio-postgres2",
  "tokio-rustls 0.26.0",
 "tokio-tungstenite",
 "tokio-util",
@@ -6421,6 +6444,7 @@ dependencies = [
  "libc",
  "mio",
  "num_cpus",
+ "parking_lot 0.12.1",
  "pin-project-lite",
  "signal-hook-registry",
  "socket2",
@@ -6502,6 +6526,26 @@ dependencies = [
  "x509-certificate",
 ]
 
+[[package]]
+name = "tokio-postgres2"
+version = "0.1.0"
+dependencies = [
+ "async-trait",
+ "byteorder",
+ "bytes",
+ "fallible-iterator",
+ "futures-util",
+ "log",
+ "parking_lot 0.12.1",
+ "percent-encoding",
+ "phf",
+ "pin-project-lite",
+ "postgres-protocol2",
+ "postgres-types2",
+ "tokio",
+ "tokio-util",
+]
+
 [[package]]
 name = "tokio-rustls"
 version = "0.24.0"
@@ -7597,7 +7641,6 @@ dependencies = [
  "num-traits",
  "once_cell",
  "parquet",
- "postgres-types",
  "prettyplease",
  "proc-macro2",
  "prost",
@@ -7622,7 +7665,6 @@ dependencies = [
  "time",
  "time-macros",
  "tokio",
- "tokio-postgres",
  "tokio-rustls 0.26.0",
  "tokio-stream",
  "tokio-util",
diff --git a/Cargo.toml b/Cargo.toml
index e3dc5b97f8..742201d0f5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,9 @@ members = [
     "libs/walproposer",
     "libs/wal_decoder",
     "libs/postgres_initdb",
+    "libs/proxy/postgres-protocol2",
+    "libs/proxy/postgres-types2",
+    "libs/proxy/tokio-postgres2",
 ]
 
 [workspace.package]
diff --git a/libs/proxy/README.md b/libs/proxy/README.md
new file mode 100644
index 0000000000..2ae6210e46
--- /dev/null
+++ b/libs/proxy/README.md
@@ -0,0 +1,6 @@
+This directory contains libraries that are specific to proxy.
+
+Currently, it contains a significant fork/refactoring of rust-postgres that no longer reflects the API
+of the original library. Since the divergence was so significant, it made sense to promote it to its own set of libraries.
+
+Proxy needs unique access to the protocol, which explains why such heavy modifications were necessary.
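[Editor's note: for orientation while reviewing, here is a minimal sketch of the client-side SCRAM flow that the vendored `postgres-protocol2` crate exposes, based on the `ScramSha256` API added in this patch. The I/O plumbing is elided; the hypothetical `server_first`/`server_final` parameters stand in for the payloads of the backend's `AuthenticationSASLContinue`/`AuthenticationSASLFinal` messages.]

```rust
use postgres_protocol2::authentication::sasl::{ChannelBinding, ScramSha256};

// A sketch, not part of the patch: drive one SCRAM-SHA-256 exchange.
// `server_first`/`server_final` are assumed to be the raw payloads of the
// backend's AuthenticationSASLContinue/AuthenticationSASLFinal messages.
async fn scram_exchange(
    password: &[u8],
    server_first: &[u8],
    server_final: &[u8],
) -> std::io::Result<()> {
    let mut scram = ScramSha256::new(password, ChannelBinding::unsupported());
    // scram.message() is the client-first message, sent in a SASLInitialResponse.
    let _client_first = scram.message().to_vec();
    scram.update(server_first).await?;
    // scram.message() now holds the client-final message, sent in a SASLResponse.
    let _client_final = scram.message().to_vec();
    // Verifies the server signature; authentication succeeded only on Ok(()).
    scram.finish(server_final)
}
```

Note that `update` is async because the PBKDF2-style key derivation (`hi` below) cooperatively yields to the runtime every 1024 iterations.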
diff --git a/libs/proxy/postgres-protocol2/Cargo.toml b/libs/proxy/postgres-protocol2/Cargo.toml new file mode 100644 index 0000000000..284a632954 --- /dev/null +++ b/libs/proxy/postgres-protocol2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "postgres-protocol2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +base64 = "0.20" +byteorder.workspace = true +bytes.workspace = true +fallible-iterator.workspace = true +hmac.workspace = true +md-5 = "0.10" +memchr = "2.0" +rand.workspace = true +sha2.workspace = true +stringprep = "0.1" +tokio = { workspace = true, features = ["rt"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } diff --git a/libs/proxy/postgres-protocol2/src/authentication/mod.rs b/libs/proxy/postgres-protocol2/src/authentication/mod.rs new file mode 100644 index 0000000000..71afa4b9b6 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/mod.rs @@ -0,0 +1,37 @@ +//! Authentication protocol support. +use md5::{Digest, Md5}; + +pub mod sasl; + +/// Hashes authentication information in a way suitable for use in response +/// to an `AuthenticationMd5Password` message. +/// +/// The resulting string should be sent back to the database in a +/// `PasswordMessage` message. +#[inline] +pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String { + let mut md5 = Md5::new(); + md5.update(password); + md5.update(username); + let output = md5.finalize_reset(); + md5.update(format!("{:x}", output)); + md5.update(salt); + format!("md5{:x}", md5.finalize()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn md5() { + let username = b"md5_user"; + let password = b"password"; + let salt = [0x2a, 0x3d, 0x8f, 0xe0]; + + assert_eq!( + md5_hash(username, password, salt), + "md562af4dd09bbb41884907a838a3233294" + ); + } +} diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs new file mode 100644 index 0000000000..19aa3c1e9a --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -0,0 +1,516 @@ +//! SASL-based authentication support. + +use hmac::{Hmac, Mac}; +use rand::{self, Rng}; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; +use std::fmt::Write; +use std::io; +use std::iter; +use std::mem; +use std::str; +use tokio::task::yield_now; + +const NONCE_LENGTH: usize = 24; + +/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism. +pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism. +pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; + +// since postgres passwords are not required to exclude saslprep-prohibited +// characters or even be valid UTF8, we run saslprep if possible and otherwise +// return the raw password. 
+fn normalize(pass: &[u8]) -> Vec<u8> {
+    let pass = match str::from_utf8(pass) {
+        Ok(pass) => pass,
+        Err(_) => return pass.to_vec(),
+    };
+
+    match stringprep::saslprep(pass) {
+        Ok(pass) => pass.into_owned().into_bytes(),
+        Err(_) => pass.as_bytes().to_vec(),
+    }
+}
+
+pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
+    let mut hmac =
+        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
+    hmac.update(salt);
+    hmac.update(&[0, 0, 0, 1]);
+    let mut prev = hmac.finalize().into_bytes();
+
+    let mut hi = prev;
+
+    for i in 1..iterations {
+        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
+        hmac.update(&prev);
+        prev = hmac.finalize().into_bytes();
+
+        for (hi, prev) in hi.iter_mut().zip(prev) {
+            *hi ^= prev;
+        }
+        // yield every ~250us
+        // hopefully reduces tail latencies
+        if i % 1024 == 0 {
+            yield_now().await
+        }
+    }
+
+    hi.into()
+}
+
+enum ChannelBindingInner {
+    Unrequested,
+    Unsupported,
+    TlsServerEndPoint(Vec<u8>),
+}
+
+/// The channel binding configuration for a SCRAM authentication exchange.
+pub struct ChannelBinding(ChannelBindingInner);
+
+impl ChannelBinding {
+    /// The server did not request channel binding.
+    pub fn unrequested() -> ChannelBinding {
+        ChannelBinding(ChannelBindingInner::Unrequested)
+    }
+
+    /// The server requested channel binding but the client is unable to provide it.
+    pub fn unsupported() -> ChannelBinding {
+        ChannelBinding(ChannelBindingInner::Unsupported)
+    }
+
+    /// The server requested channel binding and the client will use the `tls-server-end-point`
+    /// method.
+    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
+        ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
+    }
+
+    fn gs2_header(&self) -> &'static str {
+        match self.0 {
+            ChannelBindingInner::Unrequested => "y,,",
+            ChannelBindingInner::Unsupported => "n,,",
+            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
+        }
+    }
+
+    fn cbind_data(&self) -> &[u8] {
+        match self.0 {
+            ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
+            ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
+        }
+    }
+}
+
+/// A pair of keys for the SCRAM-SHA-256 mechanism.
+/// See <https://datatracker.ietf.org/doc/html/rfc5802> for details.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ScramKeys<const N: usize> {
+    /// Used by server to authenticate client.
+    pub client_key: [u8; N],
+    /// Used by client to verify server's signature.
+    pub server_key: [u8; N],
+}
+
+/// Password or keys which were derived from it.
+enum Credentials<const N: usize> {
+    /// A regular password as a vector of bytes.
+    Password(Vec<u8>),
+    /// A precomputed pair of keys.
+    Keys(Box<ScramKeys<N>>),
+}
+
+enum State {
+    Update {
+        nonce: String,
+        password: Credentials<32>,
+        channel_binding: ChannelBinding,
+    },
+    Finish {
+        server_key: [u8; 32],
+        auth_message: String,
+    },
+    Done,
+}
+
+/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
+/// process.
+///
+/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
+/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
+///
+/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
+/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
+///
+/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
+/// passed to the `update()` method, after which the buffer returned by the `message()` method
+/// should be sent to the backend in a `SASLResponse` message.
+///
+/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
+/// to the `finish()` method, after which the authentication process is complete.
+pub struct ScramSha256 {
+    message: String,
+    state: State,
+}
+
+fn nonce() -> String {
+    // rand 0.5's ThreadRng is cryptographically secure
+    let mut rng = rand::thread_rng();
+    (0..NONCE_LENGTH)
+        .map(|_| {
+            let mut v = rng.gen_range(0x21u8..0x7e);
+            if v == 0x2c {
+                v = 0x7e
+            }
+            v as char
+        })
+        .collect()
+}
+
+impl ScramSha256 {
+    /// Constructs a new instance which will use the provided password for authentication.
+    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
+        let password = Credentials::Password(normalize(password));
+        ScramSha256::new_inner(password, channel_binding, nonce())
+    }
+
+    /// Constructs a new instance which will use the provided key pair for authentication.
+    pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
+        let password = Credentials::Keys(keys.into());
+        ScramSha256::new_inner(password, channel_binding, nonce())
+    }
+
+    fn new_inner(
+        password: Credentials<32>,
+        channel_binding: ChannelBinding,
+        nonce: String,
+    ) -> ScramSha256 {
+        ScramSha256 {
+            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
+            state: State::Update {
+                nonce,
+                password,
+                channel_binding,
+            },
+        }
+    }
+
+    /// Returns the message which should be sent to the backend in an `SASLResponse` message.
+    pub fn message(&self) -> &[u8] {
+        if let State::Done = self.state {
+            panic!("invalid SCRAM state");
+        }
+        self.message.as_bytes()
+    }
+
+    /// Updates the state machine with the response from the backend.
+    ///
+    /// This should be called when an `AuthenticationSASLContinue` message is received.
+    pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
+        let (client_nonce, password, channel_binding) =
+            match mem::replace(&mut self.state, State::Done) {
+                State::Update {
+                    nonce,
+                    password,
+                    channel_binding,
+                } => (nonce, password, channel_binding),
+                _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
+            };
+
+        let message =
+            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
+
+        let parsed = Parser::new(message).server_first_message()?;
+
+        if !parsed.nonce.starts_with(&client_nonce) {
+            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
+        }
+
+        let (client_key, server_key) = match password {
+            Credentials::Password(password) => {
+                let salt = match base64::decode(parsed.salt) {
+                    Ok(salt) => salt,
+                    Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
+                };
+
+                let salted_password = hi(&password, &salt, parsed.iteration_count).await;
+
+                let make_key = |name| {
+                    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
+                        .expect("HMAC is able to accept all key sizes");
+                    hmac.update(name);
+
+                    let mut key = [0u8; 32];
+                    key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
+                    key
+                };
+
+                (make_key(b"Client Key"), make_key(b"Server Key"))
+            }
+            Credentials::Keys(keys) => (keys.client_key, keys.server_key),
+        };
+
+        let mut hash = Sha256::default();
+        hash.update(client_key);
+        let stored_key = hash.finalize_fixed();
+
+        let mut cbind_input = vec![];
+        cbind_input.extend(channel_binding.gs2_header().as_bytes());
+        cbind_input.extend(channel_binding.cbind_data());
+        let cbind_input = base64::encode(&cbind_input);
+
+        self.message.clear();
+        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
+
+        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
+
+        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
+            .expect("HMAC is able to accept all key sizes");
+        hmac.update(auth_message.as_bytes());
+        let client_signature = hmac.finalize().into_bytes();
+
+        let mut client_proof = client_key;
+        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
+            *proof ^= signature;
+        }
+
+        write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
+
+        self.state = State::Finish {
+            server_key,
+            auth_message,
+        };
+        Ok(())
+    }
+
+    /// Finalizes the authentication process.
+    ///
+    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
+    /// Authentication has only succeeded if this method returns `Ok(())`.
+    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
+        let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
+            State::Finish {
+                server_key,
+                auth_message,
+            } => (server_key, auth_message),
+            _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
+        };
+
+        let message =
+            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
+
+        let parsed = Parser::new(message).server_final_message()?;
+
+        let verifier = match parsed {
+            ServerFinalMessage::Error(e) => {
+                return Err(io::Error::new(
+                    io::ErrorKind::Other,
+                    format!("SCRAM error: {}", e),
+                ));
+            }
+            ServerFinalMessage::Verifier(verifier) => verifier,
+        };
+
+        let verifier = match base64::decode(verifier) {
+            Ok(verifier) => verifier,
+            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
+        };
+
+        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
+            .expect("HMAC is able to accept all key sizes");
+        hmac.update(auth_message.as_bytes());
+        hmac.verify_slice(&verifier)
+            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
+    }
+}
+
+struct Parser<'a> {
+    s: &'a str,
+    it: iter::Peekable<str::CharIndices<'a>>,
+}
+
+impl<'a> Parser<'a> {
+    fn new(s: &'a str) -> Parser<'a> {
+        Parser {
+            s,
+            it: s.char_indices().peekable(),
+        }
+    }
+
+    fn eat(&mut self, target: char) -> io::Result<()> {
+        match self.it.next() {
+            Some((_, c)) if c == target => Ok(()),
+            Some((i, c)) => {
+                let m = format!(
+                    "unexpected character at byte {}: expected `{}` but got `{}",
+                    i, target, c
+                );
+                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
+            }
+            None => Err(io::Error::new(
+                io::ErrorKind::UnexpectedEof,
+                "unexpected EOF",
+            )),
+        }
+    }
+
+    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
+    where
+        F: Fn(char) -> bool,
+    {
+        let start = match self.it.peek() {
+            Some(&(i, _)) => i,
+            None => return Ok(""),
+        };
+
+        loop {
+            match self.it.peek() {
+                Some(&(_, c)) if f(c) => {
+                    self.it.next();
+                }
+                Some(&(i, _)) => return Ok(&self.s[start..i]),
+                None => return Ok(&self.s[start..]),
+            }
+        }
+    }
+
+    fn printable(&mut self) -> io::Result<&'a str> {
+        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
+    }
+
+    fn nonce(&mut self) -> io::Result<&'a str> {
+        self.eat('r')?;
+        self.eat('=')?;
+        self.printable()
+    }
+
+    fn base64(&mut self) -> io::Result<&'a str> {
+        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
+    }
+
+    fn salt(&mut self) -> io::Result<&'a str> {
+        self.eat('s')?;
+        self.eat('=')?;
+        self.base64()
+    }
+
+    fn posit_number(&mut self) -> io::Result<u32> {
+        let n = self.take_while(|c| c.is_ascii_digit())?;
+        n.parse()
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
+    }
+
+    fn iteration_count(&mut self) -> io::Result<u32> {
+        self.eat('i')?;
+        self.eat('=')?;
+        self.posit_number()
+    }
+
+    fn eof(&mut self) -> io::Result<()> {
+        match self.it.peek() {
+            Some(&(i, _)) => Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("unexpected trailing data at byte {}", i),
+            )),
+            None => Ok(()),
+        }
+    }
+
+    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
+        let nonce = self.nonce()?;
+        self.eat(',')?;
+        let salt = self.salt()?;
+        self.eat(',')?;
+        let iteration_count = self.iteration_count()?;
+        self.eof()?;
+
+        Ok(ServerFirstMessage {
+            nonce,
+            salt,
+            iteration_count,
+        })
+    }
+
+    fn value(&mut self) -> io::Result<&'a str> {
+        self.take_while(|c| matches!(c, '\0' | '=' | ','))
+    }
+
+    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
+        match self.it.peek() {
+            Some(&(_, 'e')) => {}
+            _ => return Ok(None),
+        }
+
+        self.eat('e')?;
+        self.eat('=')?;
+        self.value().map(Some)
+    }
+
+    fn verifier(&mut self) -> io::Result<&'a str> {
+        self.eat('v')?;
+        self.eat('=')?;
+        self.base64()
+    }
+
+    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
+        let message = match self.server_error()? {
+            Some(error) => ServerFinalMessage::Error(error),
+            None => ServerFinalMessage::Verifier(self.verifier()?),
+        };
+        self.eof()?;
+        Ok(message)
+    }
+}
+
+struct ServerFirstMessage<'a> {
+    nonce: &'a str,
+    salt: &'a str,
+    iteration_count: u32,
+}
+
+enum ServerFinalMessage<'a> {
+    Error(&'a str),
+    Verifier(&'a str),
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn parse_server_first_message() {
+        let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
+        let message = Parser::new(message).server_first_message().unwrap();
+        assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
+        assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
+        assert_eq!(message.iteration_count, 4096);
+    }
+
+    // recorded auth exchange from psql
+    #[tokio::test]
+    async fn exchange() {
+        let password = "foobar";
+        let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
+
+        let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
+        let server_first =
+            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
+             =4096";
+        let client_final =
+            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
+             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
+        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
+
+        let mut scram = ScramSha256::new_inner(
+            Credentials::Password(normalize(password.as_bytes())),
+            ChannelBinding::unsupported(),
+            nonce.to_string(),
+        );
+        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
+
+        scram.update(server_first.as_bytes()).await.unwrap();
+        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
+
+        scram.finish(server_final.as_bytes()).unwrap();
+    }
+}
diff --git a/libs/proxy/postgres-protocol2/src/escape/mod.rs b/libs/proxy/postgres-protocol2/src/escape/mod.rs
new file mode 100644
index 0000000000..0ba7efdcac
--- /dev/null
+++ b/libs/proxy/postgres-protocol2/src/escape/mod.rs
@@ -0,0 +1,93 @@
+//! Provides functions for escaping literals and identifiers for use
+//! in SQL queries.
+//!
+//! Prefer parameterized queries where possible. Do not escape
+//! parameters in a parameterized query.
+
+#[cfg(test)]
+mod test;
+
+/// Escape a literal and surround result with single quotes. Not
+/// recommended in most cases.
+///
+/// If input contains backslashes, result will be of the form `
+/// E'...'` so it is safe to use regardless of the setting of
+/// standard_conforming_strings.
+pub fn escape_literal(input: &str) -> String {
+    escape_internal(input, false)
+}
+
+/// Escape an identifier and surround result with double quotes.
+pub fn escape_identifier(input: &str) -> String {
+    escape_internal(input, true)
+}
+
+// Translation of PostgreSQL libpq's PQescapeInternal(). Does not
+// require a connection because input string is known to be valid
+// UTF-8.
+//
+// Escape arbitrary strings. If as_ident is true, we escape the
+// result as an identifier; if false, as a literal. The result is
+// returned in a newly allocated buffer. If we fail due to an
+// encoding violation or out of memory condition, we return NULL,
+// storing an error message into conn.
+fn escape_internal(input: &str, as_ident: bool) -> String { + let mut num_backslashes = 0; + let mut num_quotes = 0; + let quote_char = if as_ident { '"' } else { '\'' }; + + // Scan the string for characters that must be escaped. + for ch in input.chars() { + if ch == quote_char { + num_quotes += 1; + } else if ch == '\\' { + num_backslashes += 1; + } + } + + // Allocate output String. + let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL + if !as_ident && num_backslashes > 0 { + result_size += num_backslashes + 2; + } + + let mut output = String::with_capacity(result_size); + + // If we are escaping a literal that contains backslashes, we use + // the escape string syntax so that the result is correct under + // either value of standard_conforming_strings. We also emit a + // leading space in this case, to guard against the possibility + // that the result might be interpolated immediately following an + // identifier. + if !as_ident && num_backslashes > 0 { + output.push(' '); + output.push('E'); + } + + // Opening quote. + output.push(quote_char); + + // Use fast path if possible. + // + // We've already verified that the input string is well-formed in + // the current encoding. If it contains no quotes and, in the + // case of literal-escaping, no backslashes, then we can just copy + // it directly to the output buffer, adding the necessary quotes. + // + // If not, we must rescan the input and process each character + // individually. + if num_quotes == 0 && (num_backslashes == 0 || as_ident) { + output.push_str(input); + } else { + for ch in input.chars() { + if ch == quote_char || (!as_ident && ch == '\\') { + output.push(ch); + } + output.push(ch); + } + } + + output.push(quote_char); + + output +} diff --git a/libs/proxy/postgres-protocol2/src/escape/test.rs b/libs/proxy/postgres-protocol2/src/escape/test.rs new file mode 100644 index 0000000000..4816a103b7 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/escape/test.rs @@ -0,0 +1,17 @@ +use crate::escape::{escape_identifier, escape_literal}; + +#[test] +fn test_escape_idenifier() { + assert_eq!(escape_identifier("foo"), String::from("\"foo\"")); + assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\"")); + assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\"")); + assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\"")); +} + +#[test] +fn test_escape_literal() { + assert_eq!(escape_literal("foo"), String::from("'foo'")); + assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'")); + assert_eq!(escape_literal("f'oo"), String::from("'f''oo'")); + assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'")); +} diff --git a/libs/proxy/postgres-protocol2/src/lib.rs b/libs/proxy/postgres-protocol2/src/lib.rs new file mode 100644 index 0000000000..947f2f835d --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/lib.rs @@ -0,0 +1,78 @@ +//! Low level Postgres protocol APIs. +//! +//! This crate implements the low level components of Postgres's communication +//! protocol, including message and value serialization and deserialization. +//! It is designed to be used as a building block by higher level APIs such as +//! `rust-postgres`, and should not typically be used directly. +//! +//! # Note +//! +//! This library assumes that the `client_encoding` backend parameter has been +//! set to `UTF8`. It will most likely not behave properly if that is not the case. 
+#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
+#![warn(missing_docs, rust_2018_idioms, clippy::all)]
+
+use byteorder::{BigEndian, ByteOrder};
+use bytes::{BufMut, BytesMut};
+use std::io;
+
+pub mod authentication;
+pub mod escape;
+pub mod message;
+pub mod password;
+pub mod types;
+
+/// A Postgres OID.
+pub type Oid = u32;
+
+/// A Postgres Log Sequence Number (LSN).
+pub type Lsn = u64;
+
+/// An enum indicating if a value is `NULL` or not.
+pub enum IsNull {
+    /// The value is `NULL`.
+    Yes,
+    /// The value is not `NULL`.
+    No,
+}
+
+fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
+where
+    F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
+    E: From<io::Error>,
+{
+    let base = buf.len();
+    buf.put_i32(0);
+    let size = match serializer(buf)? {
+        IsNull::No => i32::from_usize(buf.len() - base - 4)?,
+        IsNull::Yes => -1,
+    };
+    BigEndian::write_i32(&mut buf[base..], size);
+
+    Ok(())
+}
+
+trait FromUsize: Sized {
+    fn from_usize(x: usize) -> io::Result<Self>;
+}
+
+macro_rules! from_usize {
+    ($t:ty) => {
+        impl FromUsize for $t {
+            #[inline]
+            fn from_usize(x: usize) -> io::Result<$t> {
+                if x > <$t>::MAX as usize {
+                    Err(io::Error::new(
+                        io::ErrorKind::InvalidInput,
+                        "value too large to transmit",
+                    ))
+                } else {
+                    Ok(x as $t)
+                }
+            }
+        }
+    };
+}
+
+from_usize!(i16);
+from_usize!(i32);
diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs
new file mode 100644
index 0000000000..356d142f3f
--- /dev/null
+++ b/libs/proxy/postgres-protocol2/src/message/backend.rs
@@ -0,0 +1,766 @@
+#![allow(missing_docs)]
+
+use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
+use bytes::{Bytes, BytesMut};
+use fallible_iterator::FallibleIterator;
+use memchr::memchr;
+use std::cmp;
+use std::io::{self, Read};
+use std::ops::Range;
+use std::str;
+
+use crate::Oid;
+
+// top-level message tags
+const PARSE_COMPLETE_TAG: u8 = b'1';
+const BIND_COMPLETE_TAG: u8 = b'2';
+const CLOSE_COMPLETE_TAG: u8 = b'3';
+pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
+const COPY_DONE_TAG: u8 = b'c';
+const COMMAND_COMPLETE_TAG: u8 = b'C';
+const COPY_DATA_TAG: u8 = b'd';
+const DATA_ROW_TAG: u8 = b'D';
+const ERROR_RESPONSE_TAG: u8 = b'E';
+const COPY_IN_RESPONSE_TAG: u8 = b'G';
+const COPY_OUT_RESPONSE_TAG: u8 = b'H';
+const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
+const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
+const BACKEND_KEY_DATA_TAG: u8 = b'K';
+pub const NO_DATA_TAG: u8 = b'n';
+pub const NOTICE_RESPONSE_TAG: u8 = b'N';
+const AUTHENTICATION_TAG: u8 = b'R';
+const PORTAL_SUSPENDED_TAG: u8 = b's';
+pub const PARAMETER_STATUS_TAG: u8 = b'S';
+const PARAMETER_DESCRIPTION_TAG: u8 = b't';
+const ROW_DESCRIPTION_TAG: u8 = b'T';
+pub const READY_FOR_QUERY_TAG: u8 = b'Z';
+
+#[derive(Debug, Copy, Clone)]
+pub struct Header {
+    tag: u8,
+    len: i32,
+}
+
+#[allow(clippy::len_without_is_empty)]
+impl Header {
+    #[inline]
+    pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
+        if buf.len() < 5 {
+            return Ok(None);
+        }
+
+        let tag = buf[0];
+        let len = BigEndian::read_i32(&buf[1..]);
+
+        if len < 4 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "invalid message length: header length < 4",
+            ));
+        }
+
+        Ok(Some(Header { tag, len }))
+    }
+
+    #[inline]
+    pub fn tag(self) -> u8 {
+        self.tag
+    }
+
+    #[inline]
+    pub fn len(self) -> i32 {
+        self.len
+    }
+}
+
+/// An enum representing Postgres backend messages.
+#[non_exhaustive] +pub enum Message { + AuthenticationCleartextPassword, + AuthenticationGss, + AuthenticationKerberosV5, + AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationOk, + AuthenticationScmCredential, + AuthenticationSspi, + AuthenticationGssContinue, + AuthenticationSasl(AuthenticationSaslBody), + AuthenticationSaslContinue(AuthenticationSaslContinueBody), + AuthenticationSaslFinal(AuthenticationSaslFinalBody), + BackendKeyData(BackendKeyDataBody), + BindComplete, + CloseComplete, + CommandComplete(CommandCompleteBody), + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + DataRow(DataRowBody), + EmptyQueryResponse, + ErrorResponse(ErrorResponseBody), + NoData, + NoticeResponse(NoticeResponseBody), + NotificationResponse(NotificationResponseBody), + ParameterDescription(ParameterDescriptionBody), + ParameterStatus(ParameterStatusBody), + ParseComplete, + PortalSuspended, + ReadyForQuery(ReadyForQueryBody), + RowDescription(RowDescriptionBody), +} + +impl Message { + #[inline] + pub fn parse(buf: &mut BytesMut) -> io::Result> { + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parsing u32", + )); + } + + let total_len = len as usize + 1; + if buf.len() < total_len { + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let mut buf = Buffer { + bytes: buf.split_to(total_len).freeze(), + idx: 5, + }; + + let message = match tag { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { + let process_id = buf.read_i32::()?; + let channel = buf.read_cstr()?; + let message = buf.read_cstr()?; + Message::NotificationResponse(NotificationResponseBody { + process_id, + channel, + message, + }) + } + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { + let tag = buf.read_cstr()?; + Message::CommandComplete(CommandCompleteBody { tag }) + } + COPY_DATA_TAG => Message::CopyData, + DATA_ROW_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::DataRow(DataRowBody { storage, len }) + } + ERROR_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::ErrorResponse(ErrorResponseBody { storage }) + } + COPY_IN_RESPONSE_TAG => Message::CopyInResponse, + COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse, + COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse, + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { + let process_id = buf.read_i32::()?; + let secret_key = buf.read_i32::()?; + Message::BackendKeyData(BackendKeyDataBody { + process_id, + secret_key, + }) + } + NO_DATA_TAG => Message::NoData, + NOTICE_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::NoticeResponse(NoticeResponseBody { storage }) + } + AUTHENTICATION_TAG => match buf.read_i32::()? 
{ + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, + 5 => { + let mut salt = [0; 4]; + buf.read_exact(&mut salt)?; + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) + } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, + 8 => Message::AuthenticationGssContinue, + 9 => Message::AuthenticationSspi, + 10 => { + let storage = buf.read_all(); + Message::AuthenticationSasl(AuthenticationSaslBody(storage)) + } + 11 => { + let storage = buf.read_all(); + Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) + } + 12 => { + let storage = buf.read_all(); + Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{}`", tag), + )); + } + }, + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, + PARAMETER_STATUS_TAG => { + let name = buf.read_cstr()?; + let value = buf.read_cstr()?; + Message::ParameterStatus(ParameterStatusBody { name, value }) + } + PARAMETER_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::ParameterDescription(ParameterDescriptionBody { storage, len }) + } + ROW_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::RowDescription(RowDescriptionBody { storage, len }) + } + READY_FOR_QUERY_TAG => { + let status = buf.read_u8()?; + Message::ReadyForQuery(ReadyForQueryBody { status }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown message tag `{}`", tag), + )); + } + }; + + if !buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: expected buffer to be empty", + )); + } + + Ok(Some(message)) + } +} + +struct Buffer { + bytes: Bytes, + idx: usize, +} + +impl Buffer { + #[inline] + fn slice(&self) -> &[u8] { + &self.bytes[self.idx..] 
+ } + + #[inline] + fn is_empty(&self) -> bool { + self.slice().is_empty() + } + + #[inline] + fn read_cstr(&mut self) -> io::Result { + match memchr(0, self.slice()) { + Some(pos) => { + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + #[inline] + fn read_all(&mut self) -> Bytes { + let buf = self.bytes.slice(self.idx..); + self.idx = self.bytes.len(); + buf + } +} + +impl Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = { + let slice = self.slice(); + let len = cmp::min(slice.len(), buf.len()); + buf[..len].copy_from_slice(&slice[..len]); + len + }; + self.idx += len; + Ok(len) + } +} + +pub struct AuthenticationMd5PasswordBody { + salt: [u8; 4], +} + +impl AuthenticationMd5PasswordBody { + #[inline] + pub fn salt(&self) -> [u8; 4] { + self.salt + } +} + +pub struct AuthenticationSaslBody(Bytes); + +impl AuthenticationSaslBody { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> FallibleIterator for SaslMechanisms<'a> { + type Item = &'a str; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + let value_end = find_null(self.0, 0)?; + if value_end == 0 { + if self.0.len() != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: expected to be at end of iterator for sasl", + )); + } + Ok(None) + } else { + let value = get_str(&self.0[..value_end])?; + self.0 = &self.0[value_end + 1..]; + Ok(Some(value)) + } + } +} + +pub struct AuthenticationSaslContinueBody(Bytes); + +impl AuthenticationSaslContinueBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct AuthenticationSaslFinalBody(Bytes); + +impl AuthenticationSaslFinalBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct BackendKeyDataBody { + process_id: i32, + secret_key: i32, +} + +impl BackendKeyDataBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn secret_key(&self) -> i32 { + self.secret_key + } +} + +pub struct CommandCompleteBody { + tag: Bytes, +} + +impl CommandCompleteBody { + #[inline] + pub fn tag(&self) -> io::Result<&str> { + get_str(&self.tag) + } +} + +#[derive(Debug)] +pub struct DataRowBody { + storage: Bytes, + len: u16, +} + +impl DataRowBody { + #[inline] + pub fn ranges(&self) -> DataRowRanges<'_> { + DataRowRanges { + buf: &self.storage, + len: self.storage.len(), + remaining: self.len, + } + } + + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.storage + } +} + +pub struct DataRowRanges<'a> { + buf: &'a [u8], + len: usize, + remaining: u16, +} + +impl FallibleIterator for DataRowRanges<'_> { + type Item = Option>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: datarowrange is not empty", + )); + } + } + + self.remaining -= 1; + let len = self.buf.read_i32::()?; + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )); + } + let base = self.len - self.buf.len(); + self.buf = &self.buf[len..]; + 
Ok(Some(Some(base..base + len))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ErrorResponseBody { + storage: Bytes, +} + +impl ErrorResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct ErrorFields<'a> { + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ErrorFields<'a> { + type Item = ErrorField<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + let type_ = self.buf.read_u8()?; + if type_ == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: error fields is not drained", + )); + } + } + + let value_end = find_null(self.buf, 0)?; + let value = get_str(&self.buf[..value_end])?; + self.buf = &self.buf[value_end + 1..]; + + Ok(Some(ErrorField { type_, value })) + } +} + +pub struct ErrorField<'a> { + type_: u8, + value: &'a str, +} + +impl ErrorField<'_> { + #[inline] + pub fn type_(&self) -> u8 { + self.type_ + } + + #[inline] + pub fn value(&self) -> &str { + self.value + } +} + +pub struct NoticeResponseBody { + storage: Bytes, +} + +impl NoticeResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct NotificationResponseBody { + process_id: i32, + channel: Bytes, + message: Bytes, +} + +impl NotificationResponseBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn channel(&self) -> io::Result<&str> { + get_str(&self.channel) + } + + #[inline] + pub fn message(&self) -> io::Result<&str> { + get_str(&self.message) + } +} + +pub struct ParameterDescriptionBody { + storage: Bytes, + len: u16, +} + +impl ParameterDescriptionBody { + #[inline] + pub fn parameters(&self) -> Parameters<'_> { + Parameters { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Parameters<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl FallibleIterator for Parameters<'_> { + type Item = Oid; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parameters is not drained", + )); + } + } + + self.remaining -= 1; + self.buf.read_u32::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ParameterStatusBody { + name: Bytes, + value: Bytes, +} + +impl ParameterStatusBody { + #[inline] + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + pub fn value(&self) -> io::Result<&str> { + get_str(&self.value) + } +} + +pub struct ReadyForQueryBody { + status: u8, +} + +impl ReadyForQueryBody { + #[inline] + pub fn status(&self) -> u8 { + self.status + } +} + +pub struct RowDescriptionBody { + storage: Bytes, + len: u16, +} + +impl RowDescriptionBody { + #[inline] + pub fn fields(&self) -> Fields<'_> { + Fields { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Fields<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Fields<'a> { + type Item = Field<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return 
Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: field is not drained", + )); + } + } + + self.remaining -= 1; + let name_end = find_null(self.buf, 0)?; + let name = get_str(&self.buf[..name_end])?; + self.buf = &self.buf[name_end + 1..]; + let table_oid = self.buf.read_u32::()?; + let column_id = self.buf.read_i16::()?; + let type_oid = self.buf.read_u32::()?; + let type_size = self.buf.read_i16::()?; + let type_modifier = self.buf.read_i32::()?; + let format = self.buf.read_i16::()?; + + Ok(Some(Field { + name, + table_oid, + column_id, + type_oid, + type_size, + type_modifier, + format, + })) + } +} + +pub struct Field<'a> { + name: &'a str, + table_oid: Oid, + column_id: i16, + type_oid: Oid, + type_size: i16, + type_modifier: i32, + format: i16, +} + +impl<'a> Field<'a> { + #[inline] + pub fn name(&self) -> &'a str { + self.name + } + + #[inline] + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + #[inline] + pub fn column_id(&self) -> i16 { + self.column_id + } + + #[inline] + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + #[inline] + pub fn type_size(&self) -> i16 { + self.type_size + } + + #[inline] + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } + + #[inline] + pub fn format(&self) -> i16 { + self.format + } +} + +#[inline] +fn find_null(buf: &[u8], start: usize) -> io::Result { + match memchr(0, &buf[start..]) { + Some(pos) => Ok(pos + start), + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs new file mode 100644 index 0000000000..5d0a8ff8c8 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -0,0 +1,297 @@ +//! Frontend message serialization. 
+#![allow(missing_docs)] + +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, BufMut, BytesMut}; +use std::convert::TryFrom; +use std::error::Error; +use std::io; +use std::marker; + +use crate::{write_nullable, FromUsize, IsNull, Oid}; + +#[inline] +fn write_body(buf: &mut BytesMut, f: F) -> Result<(), E> +where + F: FnOnce(&mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 4]); + + f(buf)?; + + let size = i32::from_usize(buf.len() - base)?; + BigEndian::write_i32(&mut buf[base..], size); + Ok(()) +} + +pub enum BindError { + Conversion(Box), + Serialization(io::Error), +} + +impl From> for BindError { + #[inline] + fn from(e: Box) -> BindError { + BindError::Conversion(e) + } +} + +impl From for BindError { + #[inline] + fn from(e: io::Error) -> BindError { + BindError::Serialization(e) + } +} + +#[inline] +pub fn bind( + portal: &str, + statement: &str, + formats: I, + values: J, + mut serializer: F, + result_formats: K, + buf: &mut BytesMut, +) -> Result<(), BindError> +where + I: IntoIterator, + J: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result>, + K: IntoIterator, +{ + buf.put_u8(b'B'); + + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + write_cstr(statement.as_bytes(), buf)?; + write_counted( + formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + write_counted( + values, + |v, buf| write_nullable(|buf| serializer(v, buf), buf), + buf, + )?; + write_counted( + result_formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + + Ok(()) + }) +} + +#[inline] +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E> +where + I: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 2]); + let mut count = 0; + for item in items { + serializer(item, buf)?; + count += 1; + } + let count = i16::from_usize(count)?; + BigEndian::write_i16(&mut buf[base..], count); + + Ok(()) +} + +#[inline] +pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_102); + buf.put_i32(process_id); + buf.put_i32(secret_key); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'C'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: T) -> io::Result> { + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.put_u8(b'd'); + out.put_i32(self.len); + out.put(self.buf); + } +} + +#[inline] +pub fn copy_done(buf: &mut BytesMut) { + buf.put_u8(b'c'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'f'); + write_body(buf, |buf| write_cstr(message.as_bytes(), buf)) +} + +#[inline] +pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'D'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +#[inline] +pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> 
io::Result<()> { + buf.put_u8(b'E'); + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + buf.put_i32(max_rows); + Ok(()) + }) +} + +#[inline] +pub fn parse(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + buf.put_u8(b'P'); + write_body(buf, |buf| { + write_cstr(name.as_bytes(), buf)?; + write_cstr(query.as_bytes(), buf)?; + write_counted( + param_types, + |t, buf| { + buf.put_u32(t); + Ok::<_, io::Error>(()) + }, + buf, + )?; + Ok(()) + }) +} + +#[inline] +pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| write_cstr(password, buf)) +} + +#[inline] +pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'Q'); + write_body(buf, |buf| write_cstr(query.as_bytes(), buf)) +} + +#[inline] +pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + write_cstr(mechanism.as_bytes(), buf)?; + let len = i32::from_usize(data.len())?; + buf.put_i32(len); + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn ssl_request(buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_103); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + write_body(buf, |buf| { + // postgres protocol version 3.0(196608) in bigger-endian + buf.put_i32(0x00_03_00_00); + for (key, value) in parameters { + write_cstr(key.as_bytes(), buf)?; + write_cstr(value.as_bytes(), buf)?; + } + buf.put_u8(0); + Ok(()) + }) +} + +#[inline] +pub fn sync(buf: &mut BytesMut) { + buf.put_u8(b'S'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn terminate(buf: &mut BytesMut) { + buf.put_u8(b'X'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + if s.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); + } + buf.put_slice(s); + buf.put_u8(0); + Ok(()) +} diff --git a/libs/proxy/postgres-protocol2/src/message/mod.rs b/libs/proxy/postgres-protocol2/src/message/mod.rs new file mode 100644 index 0000000000..9e5d997548 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/mod.rs @@ -0,0 +1,8 @@ +//! Postgres message protocol support. +//! +//! See [Postgres's documentation][docs] for more information on message flow. +//! +//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html + +pub mod backend; +pub mod frontend; diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs new file mode 100644 index 0000000000..e669e80f3f --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -0,0 +1,107 @@ +//! Functions to encrypt a password in the client. +//! +//! This is intended to be used by client applications that wish to +//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password +//! need not be sent in cleartext if it is encrypted on the client +//! side. This is good because it ensures the cleartext password won't +//! end up in logs pg_stat displays, etc. 
+
+use crate::authentication::sasl;
+use hmac::{Hmac, Mac};
+use md5::Md5;
+use rand::RngCore;
+use sha2::digest::FixedOutput;
+use sha2::{Digest, Sha256};
+
+#[cfg(test)]
+mod test;
+
+const SCRAM_DEFAULT_ITERATIONS: u32 = 4096;
+const SCRAM_DEFAULT_SALT_LEN: usize = 16;
+
+/// Hash password using SCRAM-SHA-256 with a randomly-generated
+/// salt.
+///
+/// The client may assume the returned string doesn't contain any
+/// special characters that would require escaping in an SQL command.
+pub async fn scram_sha_256(password: &[u8]) -> String {
+    let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
+    let mut rng = rand::thread_rng();
+    rng.fill_bytes(&mut salt);
+    scram_sha_256_salt(password, salt).await
+}
+
+// Internal implementation of scram_sha_256 with a caller-provided
+// salt. This is useful for testing.
+pub(crate) async fn scram_sha_256_salt(
+    password: &[u8],
+    salt: [u8; SCRAM_DEFAULT_SALT_LEN],
+) -> String {
+    // Prepare the password, per [RFC
+    // 4013](https://tools.ietf.org/html/rfc4013), if possible.
+    //
+    // Postgres treats passwords as byte strings (without embedded NUL
+    // bytes), but SASL expects passwords to be valid UTF-8.
+    //
+    // Follow the behavior of libpq's PQencryptPasswordConn(), and
+    // also the backend. If the password is not valid UTF-8, or if it
+    // contains prohibited characters (such as non-ASCII whitespace),
+    // just skip the SASLprep step and use the original byte
+    // sequence.
+    let prepared: Vec<u8> = match std::str::from_utf8(password) {
+        Ok(password_str) => {
+            match stringprep::saslprep(password_str) {
+                Ok(p) => p.into_owned().into_bytes(),
+                // contains invalid characters; skip saslprep
+                Err(_) => Vec::from(password),
+            }
+        }
+        // not valid UTF-8; skip saslprep
+        Err(_) => Vec::from(password),
+    };
+
+    // salt password
+    let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await;
+
+    // client key
+    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
+        .expect("HMAC is able to accept all key sizes");
+    hmac.update(b"Client Key");
+    let client_key = hmac.finalize().into_bytes();
+
+    // stored key
+    let mut hash = Sha256::default();
+    hash.update(client_key.as_slice());
+    let stored_key = hash.finalize_fixed();
+
+    // server key
+    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
+        .expect("HMAC is able to accept all key sizes");
+    hmac.update(b"Server Key");
+    let server_key = hmac.finalize().into_bytes();
+
+    format!(
+        "SCRAM-SHA-256${}:{}${}:{}",
+        SCRAM_DEFAULT_ITERATIONS,
+        base64::encode(salt),
+        base64::encode(stored_key),
+        base64::encode(server_key)
+    )
+}
+
+/// **Not recommended, as MD5 is not considered to be secure.**
+///
+/// Hash password using MD5 with the username as the salt.
+///
+/// The client may assume the returned string doesn't contain any
+/// special characters that would require escaping.
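+///
+/// For example (the expected value here comes from this module's unit
+/// test, not from a specification excerpt):
+///
+/// ```ignore
+/// assert_eq!(md5(b"secret", "foo"), "md54ab2c5d00339c4b2a4e921d2dc4edec7");
+/// ```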
+pub fn md5(password: &[u8], username: &str) -> String {
+    // salt password with username
+    let mut salted_password = Vec::from(password);
+    salted_password.extend_from_slice(username.as_bytes());
+
+    let mut hash = Md5::new();
+    hash.update(&salted_password);
+    let digest = hash.finalize();
+    format!("md5{:x}", digest)
+}
diff --git a/libs/proxy/postgres-protocol2/src/password/test.rs b/libs/proxy/postgres-protocol2/src/password/test.rs
new file mode 100644
index 0000000000..c9d340f09d
--- /dev/null
+++ b/libs/proxy/postgres-protocol2/src/password/test.rs
@@ -0,0 +1,19 @@
+use crate::password;
+
+#[tokio::test]
+async fn test_encrypt_scram_sha_256() {
+    // Specify the salt to make the test deterministic. Any bytes will do.
+    let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
+    assert_eq!(
+        password::scram_sha_256_salt(b"secret", salt).await,
+        "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
+    );
+}
+
+#[test]
+fn test_encrypt_md5() {
+    assert_eq!(
+        password::md5(b"secret", "foo"),
+        "md54ab2c5d00339c4b2a4e921d2dc4edec7"
+    );
+}
diff --git a/libs/proxy/postgres-protocol2/src/types/mod.rs b/libs/proxy/postgres-protocol2/src/types/mod.rs
new file mode 100644
index 0000000000..78131c05bf
--- /dev/null
+++ b/libs/proxy/postgres-protocol2/src/types/mod.rs
@@ -0,0 +1,294 @@
+//! Conversions to and from Postgres's binary format for various types.
+use byteorder::{BigEndian, ReadBytesExt};
+use bytes::{BufMut, BytesMut};
+use fallible_iterator::FallibleIterator;
+use std::boxed::Box as StdBox;
+use std::error::Error;
+use std::str;
+
+use crate::Oid;
+
+#[cfg(test)]
+mod test;
+
+/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
+#[inline]
+pub fn text_to_sql(v: &str, buf: &mut BytesMut) {
+    buf.put_slice(v.as_bytes());
+}
+
+/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
+#[inline]
+pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
+    Ok(str::from_utf8(buf)?)
+}
+
+/// Deserializes a `"char"` value.
+#[inline]
+pub fn char_from_sql(mut buf: &[u8]) -> Result<i8, StdBox<dyn Error + Sync + Send>> {
+    let v = buf.read_i8()?;
+    if !buf.is_empty() {
+        return Err("invalid buffer size".into());
+    }
+    Ok(v)
+}
+
+/// Serializes an `OID` value.
+#[inline]
+pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) {
+    buf.put_u32(v);
+}
+
+/// Deserializes an `OID` value.
+#[inline]
+pub fn oid_from_sql(mut buf: &[u8]) -> Result<Oid, StdBox<dyn Error + Sync + Send>> {
+    let v = buf.read_u32::<BigEndian>()?;
+    if !buf.is_empty() {
+        return Err("invalid buffer size".into());
+    }
+    Ok(v)
+}
+
+/// A fallible iterator over `HSTORE` entries.
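+///
+/// Each entry on the wire is an `i32` key length and the key bytes, then an
+/// `i32` value length and the value bytes, with a value length of -1
+/// encoding SQL `NULL`; `remaining` is assumed to hold the leading entry
+/// count already read by the caller.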
+pub struct HstoreEntries<'a> {
+    remaining: i32,
+    buf: &'a [u8],
+}
+
+impl<'a> FallibleIterator for HstoreEntries<'a> {
+    type Item = (&'a str, Option<&'a str>);
+    type Error = StdBox<dyn Error + Sync + Send>;
+
+    #[inline]
+    #[allow(clippy::type_complexity)]
+    fn next(
+        &mut self,
+    ) -> Result<Option<(&'a str, Option<&'a str>)>, StdBox<dyn Error + Sync + Send>> {
+        if self.remaining == 0 {
+            if !self.buf.is_empty() {
+                return Err("invalid buffer size".into());
+            }
+            return Ok(None);
+        }
+
+        self.remaining -= 1;
+
+        let key_len = self.buf.read_i32::<BigEndian>()?;
+        if key_len < 0 {
+            return Err("invalid key length".into());
+        }
+        let (key, buf) = self.buf.split_at(key_len as usize);
+        let key = str::from_utf8(key)?;
+        self.buf = buf;
+
+        let value_len = self.buf.read_i32::<BigEndian>()?;
+        let value = if value_len < 0 {
+            None
+        } else {
+            let (value, buf) = self.buf.split_at(value_len as usize);
+            let value = str::from_utf8(value)?;
+            self.buf = buf;
+            Some(value)
+        };
+
+        Ok(Some((key, value)))
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.remaining as usize;
+        (len, Some(len))
+    }
+}
+
+/// Deserializes an array value.
+#[inline]
+pub fn array_from_sql(mut buf: &[u8]) -> Result<Array<'_>, StdBox<dyn Error + Sync + Send>> {
+    let dimensions = buf.read_i32::<BigEndian>()?;
+    if dimensions < 0 {
+        return Err("invalid dimension count".into());
+    }
+
+    let mut r = buf;
+    let mut elements = 1i32;
+    for _ in 0..dimensions {
+        let len = r.read_i32::<BigEndian>()?;
+        if len < 0 {
+            return Err("invalid dimension size".into());
+        }
+        let _lower_bound = r.read_i32::<BigEndian>()?;
+        elements = match elements.checked_mul(len) {
+            Some(elements) => elements,
+            None => return Err("too many array elements".into()),
+        };
+    }
+
+    if dimensions == 0 {
+        elements = 0;
+    }
+
+    Ok(Array {
+        dimensions,
+        elements,
+        buf,
+    })
+}
+
+/// A Postgres array.
+pub struct Array<'a> {
+    dimensions: i32,
+    elements: i32,
+    buf: &'a [u8],
+}
+
+impl<'a> Array<'a> {
+    /// Returns an iterator over the dimensions of the array.
+    #[inline]
+    pub fn dimensions(&self) -> ArrayDimensions<'a> {
+        ArrayDimensions(&self.buf[..self.dimensions as usize * 8])
+    }
+
+    /// Returns an iterator over the values of the array.
+    #[inline]
+    pub fn values(&self) -> ArrayValues<'a> {
+        ArrayValues {
+            remaining: self.elements,
+            buf: &self.buf[self.dimensions as usize * 8..],
+        }
+    }
+}
+
+/// An iterator over the dimensions of an array.
+pub struct ArrayDimensions<'a>(&'a [u8]);
+
+impl FallibleIterator for ArrayDimensions<'_> {
+    type Item = ArrayDimension;
+    type Error = StdBox<dyn Error + Sync + Send>;
+
+    #[inline]
+    fn next(&mut self) -> Result<Option<ArrayDimension>, StdBox<dyn Error + Sync + Send>> {
+        if self.0.is_empty() {
+            return Ok(None);
+        }
+
+        let len = self.0.read_i32::<BigEndian>()?;
+        let lower_bound = self.0.read_i32::<BigEndian>()?;
+
+        Ok(Some(ArrayDimension { len, lower_bound }))
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.0.len() / 8;
+        (len, Some(len))
+    }
+}
+
+/// Information about a dimension of an array.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub struct ArrayDimension {
+    /// The length of this dimension.
+    pub len: i32,
+
+    /// The base value used to index into this dimension.
+    pub lower_bound: i32,
+}
+
+/// An iterator over the values of an array, in row-major order.
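+///
+/// Each element is an `i32` length followed by that many bytes; a length of
+/// -1 encodes SQL `NULL` and is yielded as `None`.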
+pub struct ArrayValues<'a> {
+    remaining: i32,
+    buf: &'a [u8],
+}
+
+impl<'a> FallibleIterator for ArrayValues<'a> {
+    type Item = Option<&'a [u8]>;
+    type Error = StdBox<dyn Error + Sync + Send>;
+
+    #[inline]
+    fn next(&mut self) -> Result<Option<Option<&'a [u8]>>, StdBox<dyn Error + Sync + Send>> {
+        if self.remaining == 0 {
+            if !self.buf.is_empty() {
+                return Err("invalid message length: arrayvalue not drained".into());
+            }
+            return Ok(None);
+        }
+        self.remaining -= 1;
+
+        let len = self.buf.read_i32::<BigEndian>()?;
+        let val = if len < 0 {
+            None
+        } else {
+            if self.buf.len() < len as usize {
+                return Err("invalid value length".into());
+            }
+
+            let (val, buf) = self.buf.split_at(len as usize);
+            self.buf = buf;
+            Some(val)
+        };
+
+        Ok(Some(val))
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.remaining as usize;
+        (len, Some(len))
+    }
+}
+
+/// Serializes a Postgres ltree string
+#[inline]
+pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
+    // A version number is prepended to an ltree string per spec
+    buf.put_u8(1);
+    // Append the rest of the query
+    buf.put_slice(v.as_bytes());
+}
+
+/// Deserialize a Postgres ltree string
+#[inline]
+pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
+    match buf {
+        // Remove the version number from the front of the ltree per spec
+        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
+        _ => Err("ltree version 1 only supported".into()),
+    }
+}
+
+/// Serializes a Postgres lquery string
+#[inline]
+pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
+    // A version number is prepended to an lquery string per spec
+    buf.put_u8(1);
+    // Append the rest of the query
+    buf.put_slice(v.as_bytes());
+}
+
+/// Deserialize a Postgres lquery string
+#[inline]
+pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
+    match buf {
+        // Remove the version number from the front of the lquery per spec
+        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
+        _ => Err("lquery version 1 only supported".into()),
+    }
+}
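+
+// All three label-tree types here share the same version-1 framing: a
+// single 0x01 byte followed by the UTF-8 payload, so e.g. "A.B.C" is
+// serialized as [0x01, b'A', b'.', b'B', b'.', b'C'].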
+
+/// Serializes a Postgres ltxtquery string
+#[inline]
+pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
+    // A version number is prepended to an ltxtquery string per spec
+    buf.put_u8(1);
+    // Append the rest of the query
+    buf.put_slice(v.as_bytes());
+}
+
+/// Deserialize a Postgres ltxtquery string
+#[inline]
+pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
+    match buf {
+        // Remove the version number from the front of the ltxtquery per spec
+        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
+        _ => Err("ltxtquery version 1 only supported".into()),
+    }
+}
diff --git a/libs/proxy/postgres-protocol2/src/types/test.rs b/libs/proxy/postgres-protocol2/src/types/test.rs
new file mode 100644
index 0000000000..96cc055bc3
--- /dev/null
+++ b/libs/proxy/postgres-protocol2/src/types/test.rs
@@ -0,0 +1,87 @@
+use bytes::{Buf, BytesMut};
+
+use super::*;
+
+#[test]
+fn ltree_sql() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    let mut buf = BytesMut::new();
+
+    ltree_to_sql("A.B.C", &mut buf);
+
+    assert_eq!(query.as_slice(), buf.chunk());
+}
+
+#[test]
+fn ltree_str() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    assert!(ltree_from_sql(query.as_slice()).is_ok())
+}
+
+#[test]
+fn ltree_wrong_version() {
+    let mut query = vec![2u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    assert!(ltree_from_sql(query.as_slice()).is_err())
+}
+
+#[test]
+fn lquery_sql() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    let mut buf = BytesMut::new();
+
+    lquery_to_sql("A.B.C", &mut buf);
+
+    assert_eq!(query.as_slice(), buf.chunk());
+}
+
+#[test]
+fn lquery_str() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    assert!(lquery_from_sql(query.as_slice()).is_ok())
+}
+
+#[test]
+fn lquery_wrong_version() {
+    let mut query = vec![2u8];
+    query.extend_from_slice("A.B.C".as_bytes());
+
+    assert!(lquery_from_sql(query.as_slice()).is_err())
+}
+
+#[test]
+fn ltxtquery_sql() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("a & b*".as_bytes());
+
+    let mut buf = BytesMut::new();
+
+    ltxtquery_to_sql("a & b*", &mut buf);
+
+    assert_eq!(query.as_slice(), buf.chunk());
+}
+
+#[test]
+fn ltxtquery_str() {
+    let mut query = vec![1u8];
+    query.extend_from_slice("a & b*".as_bytes());
+
+    assert!(ltxtquery_from_sql(query.as_slice()).is_ok())
+}
+
+#[test]
+fn ltxtquery_wrong_version() {
+    let mut query = vec![2u8];
+    query.extend_from_slice("a & b*".as_bytes());
+
+    assert!(ltxtquery_from_sql(query.as_slice()).is_err())
+}
diff --git a/libs/proxy/postgres-types2/Cargo.toml b/libs/proxy/postgres-types2/Cargo.toml
new file mode 100644
index 0000000000..58cfb5571f
--- /dev/null
+++ b/libs/proxy/postgres-types2/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "postgres-types2"
+version = "0.1.0"
+edition = "2018"
+license = "MIT/Apache-2.0"
+
+[dependencies]
+bytes.workspace = true
+fallible-iterator.workspace = true
+postgres-protocol2 = { path = "../postgres-protocol2" }
diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs
new file mode 100644
index 0000000000..18ba032151
--- /dev/null
+++ b/libs/proxy/postgres-types2/src/lib.rs
@@ -0,0 +1,477 @@
+//! Conversions to and from Postgres types.
+//!
+//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it
+//! unless you want to define your own `ToSql` or `FromSql` definitions.
+#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
+#![warn(clippy::all, rust_2018_idioms, missing_docs)]
+
+use fallible_iterator::FallibleIterator;
+use postgres_protocol2::types;
+use std::any::type_name;
+use std::error::Error;
+use std::fmt;
+use std::sync::Arc;
+
+use crate::type_gen::{Inner, Other};
+
+#[doc(inline)]
+pub use postgres_protocol2::Oid;
+
+use bytes::BytesMut;
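+
+// A sketch of the kind of user-defined conversion this crate enables
+// (`MyText` is a hypothetical wrapper type; it simply mirrors the built-in
+// `String` impl further down):
+//
+//     #[derive(Debug)]
+//     struct MyText(String);
+//
+//     impl<'a> FromSql<'a> for MyText {
+//         fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
+//             <&str as FromSql>::from_sql(ty, raw).map(|s| MyText(s.to_string()))
+//         }
+//
+//         fn accepts(ty: &Type) -> bool {
+//             <&str as FromSql>::accepts(ty)
+//         }
+//     }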
+
+/// Generates a simple implementation of `ToSql::accepts` which accepts the
+/// types passed to it.
+macro_rules! accepts {
+    ($($expected:ident),+) => (
+        fn accepts(ty: &$crate::Type) -> bool {
+            matches!(*ty, $($crate::Type::$expected)|+)
+        }
+    )
+}
+
+/// Generates an implementation of `ToSql::to_sql_checked`.
+///
+/// All `ToSql` implementations should use this macro.
+macro_rules! to_sql_checked {
+    () => {
+        fn to_sql_checked(
+            &self,
+            ty: &$crate::Type,
+            out: &mut $crate::private::BytesMut,
+        ) -> ::std::result::Result<
+            $crate::IsNull,
+            Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
+        > {
+            $crate::__to_sql_checked(self, ty, out)
+        }
+    };
+}
+
+// WARNING: this function is not considered part of this crate's public API.
+// It is subject to change at any time.
+#[doc(hidden)]
+pub fn __to_sql_checked<T>(
+    v: &T,
+    ty: &Type,
+    out: &mut BytesMut,
+) -> Result<IsNull, Box<dyn Error + Sync + Send>>
+where
+    T: ToSql,
+{
+    if !T::accepts(ty) {
+        return Err(Box::new(WrongType::new::<T>(ty.clone())));
+    }
+    v.to_sql(ty, out)
+}
+
+// mod pg_lsn;
+#[doc(hidden)]
+pub mod private;
+// mod special;
+mod type_gen;
+
+/// A Postgres type.
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub struct Type(Inner);
+
+impl fmt::Debug for Type {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt::Debug::fmt(&self.0, fmt)
+    }
+}
+
+impl fmt::Display for Type {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self.schema() {
+            "public" | "pg_catalog" => {}
+            schema => write!(fmt, "{}.", schema)?,
+        }
+        fmt.write_str(self.name())
+    }
+}
+
+impl Type {
+    /// Creates a new `Type`.
+    pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type {
+        Type(Inner::Other(Arc::new(Other {
+            name,
+            oid,
+            kind,
+            schema,
+        })))
+    }
+
+    /// Returns the `Type` corresponding to the provided `Oid` if it
+    /// corresponds to a built-in type.
+    pub fn from_oid(oid: Oid) -> Option<Type> {
+        Inner::from_oid(oid).map(Type)
+    }
+
+    /// Returns the OID of the `Type`.
+    pub fn oid(&self) -> Oid {
+        self.0.oid()
+    }
+
+    /// Returns the kind of this type.
+    pub fn kind(&self) -> &Kind {
+        self.0.kind()
+    }
+
+    /// Returns the schema of this type.
+    pub fn schema(&self) -> &str {
+        match self.0 {
+            Inner::Other(ref u) => &u.schema,
+            _ => "pg_catalog",
+        }
+    }
+
+    /// Returns the name of this type.
+    pub fn name(&self) -> &str {
+        self.0.name()
+    }
+}
+
+/// Represents the kind of a Postgres type.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[non_exhaustive]
+pub enum Kind {
+    /// A simple type like `VARCHAR` or `INTEGER`.
+    Simple,
+    /// An enumerated type along with its variants.
+    Enum(Vec<String>),
+    /// A pseudo-type.
+    Pseudo,
+    /// An array type along with the type of its elements.
+    Array(Type),
+    /// A range type along with the type of its elements.
+    Range(Type),
+    /// A multirange type along with the type of its elements.
+    Multirange(Type),
+    /// A domain type along with its underlying type.
+    Domain(Type),
+    /// A composite type along with information about its fields.
+    Composite(Vec<Field>),
+}
+
+/// Information about a field of a composite type.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct Field {
+    name: String,
+    type_: Type,
+}
+
+impl Field {
+    /// Creates a new `Field`.
+    pub fn new(name: String, type_: Type) -> Field {
+        Field { name, type_ }
+    }
+
+    /// Returns the name of the field.
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Returns the type of the field.
+    pub fn type_(&self) -> &Type {
+        &self.type_
+    }
+}
+
+/// An error indicating that a `NULL` Postgres value was passed to a `FromSql`
+/// implementation that does not support `NULL` values.
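+///
+/// For example, decoding a SQL `NULL` into `String` fails with `WasNull`,
+/// while decoding it into `Option<String>` succeeds with `None` (see the
+/// `FromSql` impl for `Option<T>` below).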
+#[derive(Debug, Clone, Copy)]
+pub struct WasNull;
+
+impl fmt::Display for WasNull {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt.write_str("a Postgres value was `NULL`")
+    }
+}
+
+impl Error for WasNull {}
+
+/// An error indicating that a conversion was attempted between incompatible
+/// Rust and Postgres types.
+#[derive(Debug)]
+pub struct WrongType {
+    postgres: Type,
+    rust: &'static str,
+}
+
+impl fmt::Display for WrongType {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            fmt,
+            "cannot convert between the Rust type `{}` and the Postgres type `{}`",
+            self.rust, self.postgres,
+        )
+    }
+}
+
+impl Error for WrongType {}
+
+impl WrongType {
+    /// Creates a new `WrongType` error.
+    pub fn new<T>(ty: Type) -> WrongType {
+        WrongType {
+            postgres: ty,
+            rust: type_name::<T>(),
+        }
+    }
+}
+
+/// An error indicating that an `as_text` conversion was attempted on a binary
+/// result.
+#[derive(Debug)]
+pub struct WrongFormat {}
+
+impl Error for WrongFormat {}
+
+impl fmt::Display for WrongFormat {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            fmt,
+            "cannot read column as text while it is in binary format"
+        )
+    }
+}
+
+/// A trait for types that can be created from a Postgres value.
+pub trait FromSql<'a>: Sized {
+    /// Creates a new value of this type from a buffer of data of the specified
+    /// Postgres `Type` in its binary format.
+    ///
+    /// The caller of this method is responsible for ensuring that this type
+    /// is compatible with the Postgres `Type`.
+    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>>;
+
+    /// Creates a new value of this type from a `NULL` SQL value.
+    ///
+    /// The caller of this method is responsible for ensuring that this type
+    /// is compatible with the Postgres `Type`.
+    ///
+    /// The default implementation returns `Err(Box::new(WasNull))`.
+    #[allow(unused_variables)]
+    fn from_sql_null(ty: &Type) -> Result<Self, Box<dyn Error + Sync + Send>> {
+        Err(Box::new(WasNull))
+    }
+
+    /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the
+    /// value of `raw`.
+    fn from_sql_nullable(
+        ty: &Type,
+        raw: Option<&'a [u8]>,
+    ) -> Result<Self, Box<dyn Error + Sync + Send>> {
+        match raw {
+            Some(raw) => Self::from_sql(ty, raw),
+            None => Self::from_sql_null(ty),
+        }
+    }
+
+    /// Determines if a value of this type can be created from the specified
+    /// Postgres `Type`.
+    fn accepts(ty: &Type) -> bool;
+}
+
+/// A trait for types which can be created from a Postgres value without borrowing any data.
+///
+/// This is primarily useful for trait bounds on functions.
+pub trait FromSqlOwned: for<'a> FromSql<'a> {}
+
+impl<T> FromSqlOwned for T where T: for<'a> FromSql<'a> {}
+
+impl<'a, T: FromSql<'a>> FromSql<'a> for Option<T> {
+    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
+        <T as FromSql>::from_sql(ty, raw).map(Some)
+    }
+
+    fn from_sql_null(_: &Type) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
+        Ok(None)
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        <T as FromSql>::accepts(ty)
+    }
+}
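+
+// One-dimensional arrays decode element-wise, so e.g. a Postgres `text[]`
+// value of '{a,b,NULL}' can decode as a `Vec<Option<String>>`; arrays with
+// more than one dimension are rejected by the impl below.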
+
+impl<'a, T: FromSql<'a>> FromSql<'a> for Vec<T> {
+    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Vec<T>, Box<dyn Error + Sync + Send>> {
+        let member_type = match *ty.kind() {
+            Kind::Array(ref member) => member,
+            _ => panic!("expected array type"),
+        };
+
+        let array = types::array_from_sql(raw)?;
+        if array.dimensions().count()? > 1 {
+            return Err("array contains too many dimensions".into());
+        }
+
+        array
+            .values()
+            .map(|v| T::from_sql_nullable(member_type, v))
+            .collect()
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        match *ty.kind() {
+            Kind::Array(ref inner) => T::accepts(inner),
+            _ => false,
+        }
+    }
+}
+
+impl<'a> FromSql<'a> for String {
+    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
+        <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string)
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        <&str as FromSql>::accepts(ty)
+    }
+}
+
+impl<'a> FromSql<'a> for &'a str {
+    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
+        match *ty {
+            ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw),
+            ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw),
+            ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw),
+            _ => types::text_from_sql(raw),
+        }
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        match *ty {
+            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
+            ref ty
+                if (ty.name() == "citext"
+                    || ty.name() == "ltree"
+                    || ty.name() == "lquery"
+                    || ty.name() == "ltxtquery") =>
+            {
+                true
+            }
+            _ => false,
+        }
+    }
+}
+
+macro_rules! simple_from {
+    ($t:ty, $f:ident, $($expected:ident),+) => {
+        impl<'a> FromSql<'a> for $t {
+            fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box<dyn Error + Sync + Send>> {
+                types::$f(raw)
+            }
+
+            accepts!($($expected),+);
+        }
+    }
+}
+
+simple_from!(i8, char_from_sql, CHAR);
+simple_from!(u32, oid_from_sql, OID);
+
+/// An enum representing the nullability of a Postgres value.
+pub enum IsNull {
+    /// The value is NULL.
+    Yes,
+    /// The value is not NULL.
+    No,
+}
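+
+// A sketch of a hand-written `ToSql` impl wired through the macros above
+// (`Label` is a hypothetical wrapper type; compare the `&str` impl below):
+//
+//     #[derive(Debug)]
+//     struct Label(String);
+//
+//     impl ToSql for Label {
+//         fn to_sql(&self, _ty: &Type, out: &mut BytesMut)
+//             -> Result<IsNull, Box<dyn Error + Sync + Send>>
+//         {
+//             types::text_to_sql(&self.0, out);
+//             Ok(IsNull::No)
+//         }
+//
+//         accepts!(VARCHAR, TEXT);
+//
+//         to_sql_checked!();
+//     }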
+
+/// A trait for types that can be converted into Postgres values.
+pub trait ToSql: fmt::Debug {
+    /// Converts the value of `self` into the binary format of the specified
+    /// Postgres `Type`, appending it to `out`.
+    ///
+    /// The caller of this method is responsible for ensuring that this type
+    /// is compatible with the Postgres `Type`.
+    ///
+    /// The return value indicates if this value should be represented as
+    /// `NULL`. If this is the case, implementations **must not** write
+    /// anything to `out`.
+    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
+    where
+        Self: Sized;
+
+    /// Determines if a value of this type can be converted to the specified
+    /// Postgres `Type`.
+    fn accepts(ty: &Type) -> bool
+    where
+        Self: Sized;
+
+    /// An adaptor method used internally by Rust-Postgres.
+    ///
+    /// *All* implementations of this method should be generated by the
+    /// `to_sql_checked!()` macro.
+    fn to_sql_checked(
+        &self,
+        ty: &Type,
+        out: &mut BytesMut,
+    ) -> Result<IsNull, Box<dyn Error + Sync + Send>>;
+
+    /// Specify the encode format
+    fn encode_format(&self, _ty: &Type) -> Format {
+        Format::Binary
+    }
+}
+
+/// Supported Postgres message format types
+///
+/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum Format {
+    /// Text format (UTF-8)
+    Text,
+    /// Compact, typed binary format
+    Binary,
+}
+
+impl ToSql for &str {
+    fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
+        match *ty {
+            ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
+            ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
+            ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
+            _ => types::text_to_sql(self, w),
+        }
+        Ok(IsNull::No)
+    }
+
+    fn accepts(ty: &Type) -> bool {
+        match *ty {
+            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
+            ref ty
+                if (ty.name() == "citext"
+                    || ty.name() == "ltree"
+                    || ty.name() == "lquery"
+                    || ty.name() == "ltxtquery") =>
+            {
+                true
+            }
+            _ => false,
+        }
+    }
+
+    to_sql_checked!();
+}
+
+macro_rules! simple_to {
+    ($t:ty, $f:ident, $($expected:ident),+) => {
+        impl ToSql for $t {
+            fn to_sql(&self,
+                      _: &Type,
+                      w: &mut BytesMut)
+                      -> Result<IsNull, Box<dyn Error + Sync + Send>> {
+                types::$f(*self, w);
+                Ok(IsNull::No)
+            }
+
+            accepts!($($expected),+);
+
+            to_sql_checked!();
+        }
+    }
+}
+
+simple_to!(u32, oid_to_sql, OID);
diff --git a/libs/proxy/postgres-types2/src/private.rs b/libs/proxy/postgres-types2/src/private.rs
new file mode 100644
index 0000000000..774f9a301c
--- /dev/null
+++ b/libs/proxy/postgres-types2/src/private.rs
@@ -0,0 +1,34 @@
+use crate::{FromSql, Type};
+pub use bytes::BytesMut;
+use std::error::Error;
+
+pub fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
+    if buf.len() < 4 {
+        return Err("invalid buffer size".into());
+    }
+    let mut bytes = [0; 4];
+    bytes.copy_from_slice(&buf[..4]);
+    *buf = &buf[4..];
+    Ok(i32::from_be_bytes(bytes))
+}
+
+pub fn read_value<'a, T>(
+    type_: &Type,
+    buf: &mut &'a [u8],
+) -> Result<T, Box<dyn Error + Sync + Send>>
+where
+    T: FromSql<'a>,
+{
+    let len = read_be_i32(buf)?;
+    let value = if len < 0 {
+        None
+    } else {
+        if len as usize > buf.len() {
+            return Err("invalid buffer size".into());
+        }
+        let (head, tail) = buf.split_at(len as usize);
+        *buf = tail;
+        Some(head)
+    };
+    T::from_sql_nullable(type_, value)
+}
diff --git a/libs/proxy/postgres-types2/src/type_gen.rs b/libs/proxy/postgres-types2/src/type_gen.rs
new file mode 100644
index 0000000000..a1bc3f85c0
--- /dev/null
+++ b/libs/proxy/postgres-types2/src/type_gen.rs
@@ -0,0 +1,1524 @@
+// Autogenerated file - DO NOT EDIT
+use std::sync::Arc;
+
+use crate::{Kind, Oid, Type};
+
+#[derive(PartialEq, Eq, Debug, Hash)]
+pub struct Other {
+    pub name: String,
+    pub oid: Oid,
+    pub kind: Kind,
+    pub schema: String,
+}
+
+#[derive(PartialEq, Eq, Clone, Debug, Hash)]
+pub enum Inner {
+    Bool,
+    Bytea,
+    Char,
+    Name,
+    Int8,
+    Int2,
+    Int2Vector,
+    Int4,
+    Regproc,
+    Text,
+    Oid,
+    Tid,
+    Xid,
+    Cid,
+    OidVector,
+    PgDdlCommand,
+    Json,
+    Xml,
+    XmlArray,
+    PgNodeTree,
+    JsonArray,
+    TableAmHandler,
+    Xid8Array,
+    IndexAmHandler,
+    Point,
+    Lseg,
+    Path,
+    Box,
+    Polygon,
+    Line,
+    LineArray,
+    Cidr,
+    CidrArray,
+    Float4,
+    Float8,
+    Unknown,
+    Circle,
+    CircleArray,
+    Macaddr8,
+    Macaddr8Array,
+    Money,
+    MoneyArray,
+    Macaddr,
+    Inet,
+    BoolArray,
+    ByteaArray,
+    CharArray,
+    NameArray,
+    Int2Array,
+    Int2VectorArray,
+    Int4Array,
+    RegprocArray,
+    TextArray,
+    TidArray,
+    XidArray,
+    CidArray,
OidVectorArray, + BpcharArray, + VarcharArray, + Int8Array, + PointArray, + LsegArray, + PathArray, + BoxArray, + Float4Array, + Float8Array, + PolygonArray, + OidArray, + Aclitem, + AclitemArray, + MacaddrArray, + InetArray, + Bpchar, + Varchar, + Date, + Time, + Timestamp, + TimestampArray, + DateArray, + TimeArray, + Timestamptz, + TimestamptzArray, + Interval, + IntervalArray, + NumericArray, + CstringArray, + Timetz, + TimetzArray, + Bit, + BitArray, + Varbit, + VarbitArray, + Numeric, + Refcursor, + RefcursorArray, + Regprocedure, + Regoper, + Regoperator, + Regclass, + Regtype, + RegprocedureArray, + RegoperArray, + RegoperatorArray, + RegclassArray, + RegtypeArray, + Record, + Cstring, + Any, + Anyarray, + Void, + Trigger, + LanguageHandler, + Internal, + Anyelement, + RecordArray, + Anynonarray, + TxidSnapshotArray, + Uuid, + UuidArray, + TxidSnapshot, + FdwHandler, + PgLsn, + PgLsnArray, + TsmHandler, + PgNdistinct, + PgDependencies, + Anyenum, + TsVector, + Tsquery, + GtsVector, + TsVectorArray, + GtsVectorArray, + TsqueryArray, + Regconfig, + RegconfigArray, + Regdictionary, + RegdictionaryArray, + Jsonb, + JsonbArray, + AnyRange, + EventTrigger, + Int4Range, + Int4RangeArray, + NumRange, + NumRangeArray, + TsRange, + TsRangeArray, + TstzRange, + TstzRangeArray, + DateRange, + DateRangeArray, + Int8Range, + Int8RangeArray, + Jsonpath, + JsonpathArray, + Regnamespace, + RegnamespaceArray, + Regrole, + RegroleArray, + Regcollation, + RegcollationArray, + Int4multiRange, + NummultiRange, + TsmultiRange, + TstzmultiRange, + DatemultiRange, + Int8multiRange, + AnymultiRange, + AnycompatiblemultiRange, + PgBrinBloomSummary, + PgBrinMinmaxMultiSummary, + PgMcvList, + PgSnapshot, + PgSnapshotArray, + Xid8, + Anycompatible, + Anycompatiblearray, + Anycompatiblenonarray, + AnycompatibleRange, + Int4multiRangeArray, + NummultiRangeArray, + TsmultiRangeArray, + TstzmultiRangeArray, + DatemultiRangeArray, + Int8multiRangeArray, + Other(Arc), +} + +impl Inner { + pub fn from_oid(oid: Oid) -> Option { + match oid { + 16 => Some(Inner::Bool), + 17 => Some(Inner::Bytea), + 18 => Some(Inner::Char), + 19 => Some(Inner::Name), + 20 => Some(Inner::Int8), + 21 => Some(Inner::Int2), + 22 => Some(Inner::Int2Vector), + 23 => Some(Inner::Int4), + 24 => Some(Inner::Regproc), + 25 => Some(Inner::Text), + 26 => Some(Inner::Oid), + 27 => Some(Inner::Tid), + 28 => Some(Inner::Xid), + 29 => Some(Inner::Cid), + 30 => Some(Inner::OidVector), + 32 => Some(Inner::PgDdlCommand), + 114 => Some(Inner::Json), + 142 => Some(Inner::Xml), + 143 => Some(Inner::XmlArray), + 194 => Some(Inner::PgNodeTree), + 199 => Some(Inner::JsonArray), + 269 => Some(Inner::TableAmHandler), + 271 => Some(Inner::Xid8Array), + 325 => Some(Inner::IndexAmHandler), + 600 => Some(Inner::Point), + 601 => Some(Inner::Lseg), + 602 => Some(Inner::Path), + 603 => Some(Inner::Box), + 604 => Some(Inner::Polygon), + 628 => Some(Inner::Line), + 629 => Some(Inner::LineArray), + 650 => Some(Inner::Cidr), + 651 => Some(Inner::CidrArray), + 700 => Some(Inner::Float4), + 701 => Some(Inner::Float8), + 705 => Some(Inner::Unknown), + 718 => Some(Inner::Circle), + 719 => Some(Inner::CircleArray), + 774 => Some(Inner::Macaddr8), + 775 => Some(Inner::Macaddr8Array), + 790 => Some(Inner::Money), + 791 => Some(Inner::MoneyArray), + 829 => Some(Inner::Macaddr), + 869 => Some(Inner::Inet), + 1000 => Some(Inner::BoolArray), + 1001 => Some(Inner::ByteaArray), + 1002 => Some(Inner::CharArray), + 1003 => Some(Inner::NameArray), + 1005 => Some(Inner::Int2Array), + 1006 => 
Some(Inner::Int2VectorArray), + 1007 => Some(Inner::Int4Array), + 1008 => Some(Inner::RegprocArray), + 1009 => Some(Inner::TextArray), + 1010 => Some(Inner::TidArray), + 1011 => Some(Inner::XidArray), + 1012 => Some(Inner::CidArray), + 1013 => Some(Inner::OidVectorArray), + 1014 => Some(Inner::BpcharArray), + 1015 => Some(Inner::VarcharArray), + 1016 => Some(Inner::Int8Array), + 1017 => Some(Inner::PointArray), + 1018 => Some(Inner::LsegArray), + 1019 => Some(Inner::PathArray), + 1020 => Some(Inner::BoxArray), + 1021 => Some(Inner::Float4Array), + 1022 => Some(Inner::Float8Array), + 1027 => Some(Inner::PolygonArray), + 1028 => Some(Inner::OidArray), + 1033 => Some(Inner::Aclitem), + 1034 => Some(Inner::AclitemArray), + 1040 => Some(Inner::MacaddrArray), + 1041 => Some(Inner::InetArray), + 1042 => Some(Inner::Bpchar), + 1043 => Some(Inner::Varchar), + 1082 => Some(Inner::Date), + 1083 => Some(Inner::Time), + 1114 => Some(Inner::Timestamp), + 1115 => Some(Inner::TimestampArray), + 1182 => Some(Inner::DateArray), + 1183 => Some(Inner::TimeArray), + 1184 => Some(Inner::Timestamptz), + 1185 => Some(Inner::TimestamptzArray), + 1186 => Some(Inner::Interval), + 1187 => Some(Inner::IntervalArray), + 1231 => Some(Inner::NumericArray), + 1263 => Some(Inner::CstringArray), + 1266 => Some(Inner::Timetz), + 1270 => Some(Inner::TimetzArray), + 1560 => Some(Inner::Bit), + 1561 => Some(Inner::BitArray), + 1562 => Some(Inner::Varbit), + 1563 => Some(Inner::VarbitArray), + 1700 => Some(Inner::Numeric), + 1790 => Some(Inner::Refcursor), + 2201 => Some(Inner::RefcursorArray), + 2202 => Some(Inner::Regprocedure), + 2203 => Some(Inner::Regoper), + 2204 => Some(Inner::Regoperator), + 2205 => Some(Inner::Regclass), + 2206 => Some(Inner::Regtype), + 2207 => Some(Inner::RegprocedureArray), + 2208 => Some(Inner::RegoperArray), + 2209 => Some(Inner::RegoperatorArray), + 2210 => Some(Inner::RegclassArray), + 2211 => Some(Inner::RegtypeArray), + 2249 => Some(Inner::Record), + 2275 => Some(Inner::Cstring), + 2276 => Some(Inner::Any), + 2277 => Some(Inner::Anyarray), + 2278 => Some(Inner::Void), + 2279 => Some(Inner::Trigger), + 2280 => Some(Inner::LanguageHandler), + 2281 => Some(Inner::Internal), + 2283 => Some(Inner::Anyelement), + 2287 => Some(Inner::RecordArray), + 2776 => Some(Inner::Anynonarray), + 2949 => Some(Inner::TxidSnapshotArray), + 2950 => Some(Inner::Uuid), + 2951 => Some(Inner::UuidArray), + 2970 => Some(Inner::TxidSnapshot), + 3115 => Some(Inner::FdwHandler), + 3220 => Some(Inner::PgLsn), + 3221 => Some(Inner::PgLsnArray), + 3310 => Some(Inner::TsmHandler), + 3361 => Some(Inner::PgNdistinct), + 3402 => Some(Inner::PgDependencies), + 3500 => Some(Inner::Anyenum), + 3614 => Some(Inner::TsVector), + 3615 => Some(Inner::Tsquery), + 3642 => Some(Inner::GtsVector), + 3643 => Some(Inner::TsVectorArray), + 3644 => Some(Inner::GtsVectorArray), + 3645 => Some(Inner::TsqueryArray), + 3734 => Some(Inner::Regconfig), + 3735 => Some(Inner::RegconfigArray), + 3769 => Some(Inner::Regdictionary), + 3770 => Some(Inner::RegdictionaryArray), + 3802 => Some(Inner::Jsonb), + 3807 => Some(Inner::JsonbArray), + 3831 => Some(Inner::AnyRange), + 3838 => Some(Inner::EventTrigger), + 3904 => Some(Inner::Int4Range), + 3905 => Some(Inner::Int4RangeArray), + 3906 => Some(Inner::NumRange), + 3907 => Some(Inner::NumRangeArray), + 3908 => Some(Inner::TsRange), + 3909 => Some(Inner::TsRangeArray), + 3910 => Some(Inner::TstzRange), + 3911 => Some(Inner::TstzRangeArray), + 3912 => Some(Inner::DateRange), + 3913 => 
Some(Inner::DateRangeArray), + 3926 => Some(Inner::Int8Range), + 3927 => Some(Inner::Int8RangeArray), + 4072 => Some(Inner::Jsonpath), + 4073 => Some(Inner::JsonpathArray), + 4089 => Some(Inner::Regnamespace), + 4090 => Some(Inner::RegnamespaceArray), + 4096 => Some(Inner::Regrole), + 4097 => Some(Inner::RegroleArray), + 4191 => Some(Inner::Regcollation), + 4192 => Some(Inner::RegcollationArray), + 4451 => Some(Inner::Int4multiRange), + 4532 => Some(Inner::NummultiRange), + 4533 => Some(Inner::TsmultiRange), + 4534 => Some(Inner::TstzmultiRange), + 4535 => Some(Inner::DatemultiRange), + 4536 => Some(Inner::Int8multiRange), + 4537 => Some(Inner::AnymultiRange), + 4538 => Some(Inner::AnycompatiblemultiRange), + 4600 => Some(Inner::PgBrinBloomSummary), + 4601 => Some(Inner::PgBrinMinmaxMultiSummary), + 5017 => Some(Inner::PgMcvList), + 5038 => Some(Inner::PgSnapshot), + 5039 => Some(Inner::PgSnapshotArray), + 5069 => Some(Inner::Xid8), + 5077 => Some(Inner::Anycompatible), + 5078 => Some(Inner::Anycompatiblearray), + 5079 => Some(Inner::Anycompatiblenonarray), + 5080 => Some(Inner::AnycompatibleRange), + 6150 => Some(Inner::Int4multiRangeArray), + 6151 => Some(Inner::NummultiRangeArray), + 6152 => Some(Inner::TsmultiRangeArray), + 6153 => Some(Inner::TstzmultiRangeArray), + 6155 => Some(Inner::DatemultiRangeArray), + 6157 => Some(Inner::Int8multiRangeArray), + _ => None, + } + } + + pub fn oid(&self) -> Oid { + match *self { + Inner::Bool => 16, + Inner::Bytea => 17, + Inner::Char => 18, + Inner::Name => 19, + Inner::Int8 => 20, + Inner::Int2 => 21, + Inner::Int2Vector => 22, + Inner::Int4 => 23, + Inner::Regproc => 24, + Inner::Text => 25, + Inner::Oid => 26, + Inner::Tid => 27, + Inner::Xid => 28, + Inner::Cid => 29, + Inner::OidVector => 30, + Inner::PgDdlCommand => 32, + Inner::Json => 114, + Inner::Xml => 142, + Inner::XmlArray => 143, + Inner::PgNodeTree => 194, + Inner::JsonArray => 199, + Inner::TableAmHandler => 269, + Inner::Xid8Array => 271, + Inner::IndexAmHandler => 325, + Inner::Point => 600, + Inner::Lseg => 601, + Inner::Path => 602, + Inner::Box => 603, + Inner::Polygon => 604, + Inner::Line => 628, + Inner::LineArray => 629, + Inner::Cidr => 650, + Inner::CidrArray => 651, + Inner::Float4 => 700, + Inner::Float8 => 701, + Inner::Unknown => 705, + Inner::Circle => 718, + Inner::CircleArray => 719, + Inner::Macaddr8 => 774, + Inner::Macaddr8Array => 775, + Inner::Money => 790, + Inner::MoneyArray => 791, + Inner::Macaddr => 829, + Inner::Inet => 869, + Inner::BoolArray => 1000, + Inner::ByteaArray => 1001, + Inner::CharArray => 1002, + Inner::NameArray => 1003, + Inner::Int2Array => 1005, + Inner::Int2VectorArray => 1006, + Inner::Int4Array => 1007, + Inner::RegprocArray => 1008, + Inner::TextArray => 1009, + Inner::TidArray => 1010, + Inner::XidArray => 1011, + Inner::CidArray => 1012, + Inner::OidVectorArray => 1013, + Inner::BpcharArray => 1014, + Inner::VarcharArray => 1015, + Inner::Int8Array => 1016, + Inner::PointArray => 1017, + Inner::LsegArray => 1018, + Inner::PathArray => 1019, + Inner::BoxArray => 1020, + Inner::Float4Array => 1021, + Inner::Float8Array => 1022, + Inner::PolygonArray => 1027, + Inner::OidArray => 1028, + Inner::Aclitem => 1033, + Inner::AclitemArray => 1034, + Inner::MacaddrArray => 1040, + Inner::InetArray => 1041, + Inner::Bpchar => 1042, + Inner::Varchar => 1043, + Inner::Date => 1082, + Inner::Time => 1083, + Inner::Timestamp => 1114, + Inner::TimestampArray => 1115, + Inner::DateArray => 1182, + Inner::TimeArray => 1183, + Inner::Timestamptz 
=> 1184, + Inner::TimestamptzArray => 1185, + Inner::Interval => 1186, + Inner::IntervalArray => 1187, + Inner::NumericArray => 1231, + Inner::CstringArray => 1263, + Inner::Timetz => 1266, + Inner::TimetzArray => 1270, + Inner::Bit => 1560, + Inner::BitArray => 1561, + Inner::Varbit => 1562, + Inner::VarbitArray => 1563, + Inner::Numeric => 1700, + Inner::Refcursor => 1790, + Inner::RefcursorArray => 2201, + Inner::Regprocedure => 2202, + Inner::Regoper => 2203, + Inner::Regoperator => 2204, + Inner::Regclass => 2205, + Inner::Regtype => 2206, + Inner::RegprocedureArray => 2207, + Inner::RegoperArray => 2208, + Inner::RegoperatorArray => 2209, + Inner::RegclassArray => 2210, + Inner::RegtypeArray => 2211, + Inner::Record => 2249, + Inner::Cstring => 2275, + Inner::Any => 2276, + Inner::Anyarray => 2277, + Inner::Void => 2278, + Inner::Trigger => 2279, + Inner::LanguageHandler => 2280, + Inner::Internal => 2281, + Inner::Anyelement => 2283, + Inner::RecordArray => 2287, + Inner::Anynonarray => 2776, + Inner::TxidSnapshotArray => 2949, + Inner::Uuid => 2950, + Inner::UuidArray => 2951, + Inner::TxidSnapshot => 2970, + Inner::FdwHandler => 3115, + Inner::PgLsn => 3220, + Inner::PgLsnArray => 3221, + Inner::TsmHandler => 3310, + Inner::PgNdistinct => 3361, + Inner::PgDependencies => 3402, + Inner::Anyenum => 3500, + Inner::TsVector => 3614, + Inner::Tsquery => 3615, + Inner::GtsVector => 3642, + Inner::TsVectorArray => 3643, + Inner::GtsVectorArray => 3644, + Inner::TsqueryArray => 3645, + Inner::Regconfig => 3734, + Inner::RegconfigArray => 3735, + Inner::Regdictionary => 3769, + Inner::RegdictionaryArray => 3770, + Inner::Jsonb => 3802, + Inner::JsonbArray => 3807, + Inner::AnyRange => 3831, + Inner::EventTrigger => 3838, + Inner::Int4Range => 3904, + Inner::Int4RangeArray => 3905, + Inner::NumRange => 3906, + Inner::NumRangeArray => 3907, + Inner::TsRange => 3908, + Inner::TsRangeArray => 3909, + Inner::TstzRange => 3910, + Inner::TstzRangeArray => 3911, + Inner::DateRange => 3912, + Inner::DateRangeArray => 3913, + Inner::Int8Range => 3926, + Inner::Int8RangeArray => 3927, + Inner::Jsonpath => 4072, + Inner::JsonpathArray => 4073, + Inner::Regnamespace => 4089, + Inner::RegnamespaceArray => 4090, + Inner::Regrole => 4096, + Inner::RegroleArray => 4097, + Inner::Regcollation => 4191, + Inner::RegcollationArray => 4192, + Inner::Int4multiRange => 4451, + Inner::NummultiRange => 4532, + Inner::TsmultiRange => 4533, + Inner::TstzmultiRange => 4534, + Inner::DatemultiRange => 4535, + Inner::Int8multiRange => 4536, + Inner::AnymultiRange => 4537, + Inner::AnycompatiblemultiRange => 4538, + Inner::PgBrinBloomSummary => 4600, + Inner::PgBrinMinmaxMultiSummary => 4601, + Inner::PgMcvList => 5017, + Inner::PgSnapshot => 5038, + Inner::PgSnapshotArray => 5039, + Inner::Xid8 => 5069, + Inner::Anycompatible => 5077, + Inner::Anycompatiblearray => 5078, + Inner::Anycompatiblenonarray => 5079, + Inner::AnycompatibleRange => 5080, + Inner::Int4multiRangeArray => 6150, + Inner::NummultiRangeArray => 6151, + Inner::TsmultiRangeArray => 6152, + Inner::TstzmultiRangeArray => 6153, + Inner::DatemultiRangeArray => 6155, + Inner::Int8multiRangeArray => 6157, + Inner::Other(ref u) => u.oid, + } + } + + pub fn kind(&self) -> &Kind { + match *self { + Inner::Bool => &Kind::Simple, + Inner::Bytea => &Kind::Simple, + Inner::Char => &Kind::Simple, + Inner::Name => &Kind::Simple, + Inner::Int8 => &Kind::Simple, + Inner::Int2 => &Kind::Simple, + Inner::Int2Vector => &Kind::Array(Type(Inner::Int2)), + Inner::Int4 => 
&Kind::Simple, + Inner::Regproc => &Kind::Simple, + Inner::Text => &Kind::Simple, + Inner::Oid => &Kind::Simple, + Inner::Tid => &Kind::Simple, + Inner::Xid => &Kind::Simple, + Inner::Cid => &Kind::Simple, + Inner::OidVector => &Kind::Array(Type(Inner::Oid)), + Inner::PgDdlCommand => &Kind::Pseudo, + Inner::Json => &Kind::Simple, + Inner::Xml => &Kind::Simple, + Inner::XmlArray => &Kind::Array(Type(Inner::Xml)), + Inner::PgNodeTree => &Kind::Simple, + Inner::JsonArray => &Kind::Array(Type(Inner::Json)), + Inner::TableAmHandler => &Kind::Pseudo, + Inner::Xid8Array => &Kind::Array(Type(Inner::Xid8)), + Inner::IndexAmHandler => &Kind::Pseudo, + Inner::Point => &Kind::Simple, + Inner::Lseg => &Kind::Simple, + Inner::Path => &Kind::Simple, + Inner::Box => &Kind::Simple, + Inner::Polygon => &Kind::Simple, + Inner::Line => &Kind::Simple, + Inner::LineArray => &Kind::Array(Type(Inner::Line)), + Inner::Cidr => &Kind::Simple, + Inner::CidrArray => &Kind::Array(Type(Inner::Cidr)), + Inner::Float4 => &Kind::Simple, + Inner::Float8 => &Kind::Simple, + Inner::Unknown => &Kind::Simple, + Inner::Circle => &Kind::Simple, + Inner::CircleArray => &Kind::Array(Type(Inner::Circle)), + Inner::Macaddr8 => &Kind::Simple, + Inner::Macaddr8Array => &Kind::Array(Type(Inner::Macaddr8)), + Inner::Money => &Kind::Simple, + Inner::MoneyArray => &Kind::Array(Type(Inner::Money)), + Inner::Macaddr => &Kind::Simple, + Inner::Inet => &Kind::Simple, + Inner::BoolArray => &Kind::Array(Type(Inner::Bool)), + Inner::ByteaArray => &Kind::Array(Type(Inner::Bytea)), + Inner::CharArray => &Kind::Array(Type(Inner::Char)), + Inner::NameArray => &Kind::Array(Type(Inner::Name)), + Inner::Int2Array => &Kind::Array(Type(Inner::Int2)), + Inner::Int2VectorArray => &Kind::Array(Type(Inner::Int2Vector)), + Inner::Int4Array => &Kind::Array(Type(Inner::Int4)), + Inner::RegprocArray => &Kind::Array(Type(Inner::Regproc)), + Inner::TextArray => &Kind::Array(Type(Inner::Text)), + Inner::TidArray => &Kind::Array(Type(Inner::Tid)), + Inner::XidArray => &Kind::Array(Type(Inner::Xid)), + Inner::CidArray => &Kind::Array(Type(Inner::Cid)), + Inner::OidVectorArray => &Kind::Array(Type(Inner::OidVector)), + Inner::BpcharArray => &Kind::Array(Type(Inner::Bpchar)), + Inner::VarcharArray => &Kind::Array(Type(Inner::Varchar)), + Inner::Int8Array => &Kind::Array(Type(Inner::Int8)), + Inner::PointArray => &Kind::Array(Type(Inner::Point)), + Inner::LsegArray => &Kind::Array(Type(Inner::Lseg)), + Inner::PathArray => &Kind::Array(Type(Inner::Path)), + Inner::BoxArray => &Kind::Array(Type(Inner::Box)), + Inner::Float4Array => &Kind::Array(Type(Inner::Float4)), + Inner::Float8Array => &Kind::Array(Type(Inner::Float8)), + Inner::PolygonArray => &Kind::Array(Type(Inner::Polygon)), + Inner::OidArray => &Kind::Array(Type(Inner::Oid)), + Inner::Aclitem => &Kind::Simple, + Inner::AclitemArray => &Kind::Array(Type(Inner::Aclitem)), + Inner::MacaddrArray => &Kind::Array(Type(Inner::Macaddr)), + Inner::InetArray => &Kind::Array(Type(Inner::Inet)), + Inner::Bpchar => &Kind::Simple, + Inner::Varchar => &Kind::Simple, + Inner::Date => &Kind::Simple, + Inner::Time => &Kind::Simple, + Inner::Timestamp => &Kind::Simple, + Inner::TimestampArray => &Kind::Array(Type(Inner::Timestamp)), + Inner::DateArray => &Kind::Array(Type(Inner::Date)), + Inner::TimeArray => &Kind::Array(Type(Inner::Time)), + Inner::Timestamptz => &Kind::Simple, + Inner::TimestamptzArray => &Kind::Array(Type(Inner::Timestamptz)), + Inner::Interval => &Kind::Simple, + Inner::IntervalArray => 
&Kind::Array(Type(Inner::Interval)), + Inner::NumericArray => &Kind::Array(Type(Inner::Numeric)), + Inner::CstringArray => &Kind::Array(Type(Inner::Cstring)), + Inner::Timetz => &Kind::Simple, + Inner::TimetzArray => &Kind::Array(Type(Inner::Timetz)), + Inner::Bit => &Kind::Simple, + Inner::BitArray => &Kind::Array(Type(Inner::Bit)), + Inner::Varbit => &Kind::Simple, + Inner::VarbitArray => &Kind::Array(Type(Inner::Varbit)), + Inner::Numeric => &Kind::Simple, + Inner::Refcursor => &Kind::Simple, + Inner::RefcursorArray => &Kind::Array(Type(Inner::Refcursor)), + Inner::Regprocedure => &Kind::Simple, + Inner::Regoper => &Kind::Simple, + Inner::Regoperator => &Kind::Simple, + Inner::Regclass => &Kind::Simple, + Inner::Regtype => &Kind::Simple, + Inner::RegprocedureArray => &Kind::Array(Type(Inner::Regprocedure)), + Inner::RegoperArray => &Kind::Array(Type(Inner::Regoper)), + Inner::RegoperatorArray => &Kind::Array(Type(Inner::Regoperator)), + Inner::RegclassArray => &Kind::Array(Type(Inner::Regclass)), + Inner::RegtypeArray => &Kind::Array(Type(Inner::Regtype)), + Inner::Record => &Kind::Pseudo, + Inner::Cstring => &Kind::Pseudo, + Inner::Any => &Kind::Pseudo, + Inner::Anyarray => &Kind::Pseudo, + Inner::Void => &Kind::Pseudo, + Inner::Trigger => &Kind::Pseudo, + Inner::LanguageHandler => &Kind::Pseudo, + Inner::Internal => &Kind::Pseudo, + Inner::Anyelement => &Kind::Pseudo, + Inner::RecordArray => &Kind::Pseudo, + Inner::Anynonarray => &Kind::Pseudo, + Inner::TxidSnapshotArray => &Kind::Array(Type(Inner::TxidSnapshot)), + Inner::Uuid => &Kind::Simple, + Inner::UuidArray => &Kind::Array(Type(Inner::Uuid)), + Inner::TxidSnapshot => &Kind::Simple, + Inner::FdwHandler => &Kind::Pseudo, + Inner::PgLsn => &Kind::Simple, + Inner::PgLsnArray => &Kind::Array(Type(Inner::PgLsn)), + Inner::TsmHandler => &Kind::Pseudo, + Inner::PgNdistinct => &Kind::Simple, + Inner::PgDependencies => &Kind::Simple, + Inner::Anyenum => &Kind::Pseudo, + Inner::TsVector => &Kind::Simple, + Inner::Tsquery => &Kind::Simple, + Inner::GtsVector => &Kind::Simple, + Inner::TsVectorArray => &Kind::Array(Type(Inner::TsVector)), + Inner::GtsVectorArray => &Kind::Array(Type(Inner::GtsVector)), + Inner::TsqueryArray => &Kind::Array(Type(Inner::Tsquery)), + Inner::Regconfig => &Kind::Simple, + Inner::RegconfigArray => &Kind::Array(Type(Inner::Regconfig)), + Inner::Regdictionary => &Kind::Simple, + Inner::RegdictionaryArray => &Kind::Array(Type(Inner::Regdictionary)), + Inner::Jsonb => &Kind::Simple, + Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), + Inner::AnyRange => &Kind::Pseudo, + Inner::EventTrigger => &Kind::Pseudo, + Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), + Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), + Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), + Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), + Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), + Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), + Inner::Jsonpath => &Kind::Simple, + Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), + Inner::Regnamespace => &Kind::Simple, + Inner::RegnamespaceArray => 
&Kind::Array(Type(Inner::Regnamespace)), + Inner::Regrole => &Kind::Simple, + Inner::RegroleArray => &Kind::Array(Type(Inner::Regrole)), + Inner::Regcollation => &Kind::Simple, + Inner::RegcollationArray => &Kind::Array(Type(Inner::Regcollation)), + Inner::Int4multiRange => &Kind::Multirange(Type(Inner::Int4)), + Inner::NummultiRange => &Kind::Multirange(Type(Inner::Numeric)), + Inner::TsmultiRange => &Kind::Multirange(Type(Inner::Timestamp)), + Inner::TstzmultiRange => &Kind::Multirange(Type(Inner::Timestamptz)), + Inner::DatemultiRange => &Kind::Multirange(Type(Inner::Date)), + Inner::Int8multiRange => &Kind::Multirange(Type(Inner::Int8)), + Inner::AnymultiRange => &Kind::Pseudo, + Inner::AnycompatiblemultiRange => &Kind::Pseudo, + Inner::PgBrinBloomSummary => &Kind::Simple, + Inner::PgBrinMinmaxMultiSummary => &Kind::Simple, + Inner::PgMcvList => &Kind::Simple, + Inner::PgSnapshot => &Kind::Simple, + Inner::PgSnapshotArray => &Kind::Array(Type(Inner::PgSnapshot)), + Inner::Xid8 => &Kind::Simple, + Inner::Anycompatible => &Kind::Pseudo, + Inner::Anycompatiblearray => &Kind::Pseudo, + Inner::Anycompatiblenonarray => &Kind::Pseudo, + Inner::AnycompatibleRange => &Kind::Pseudo, + Inner::Int4multiRangeArray => &Kind::Array(Type(Inner::Int4multiRange)), + Inner::NummultiRangeArray => &Kind::Array(Type(Inner::NummultiRange)), + Inner::TsmultiRangeArray => &Kind::Array(Type(Inner::TsmultiRange)), + Inner::TstzmultiRangeArray => &Kind::Array(Type(Inner::TstzmultiRange)), + Inner::DatemultiRangeArray => &Kind::Array(Type(Inner::DatemultiRange)), + Inner::Int8multiRangeArray => &Kind::Array(Type(Inner::Int8multiRange)), + Inner::Other(ref u) => &u.kind, + } + } + + pub fn name(&self) -> &str { + match *self { + Inner::Bool => "bool", + Inner::Bytea => "bytea", + Inner::Char => "char", + Inner::Name => "name", + Inner::Int8 => "int8", + Inner::Int2 => "int2", + Inner::Int2Vector => "int2vector", + Inner::Int4 => "int4", + Inner::Regproc => "regproc", + Inner::Text => "text", + Inner::Oid => "oid", + Inner::Tid => "tid", + Inner::Xid => "xid", + Inner::Cid => "cid", + Inner::OidVector => "oidvector", + Inner::PgDdlCommand => "pg_ddl_command", + Inner::Json => "json", + Inner::Xml => "xml", + Inner::XmlArray => "_xml", + Inner::PgNodeTree => "pg_node_tree", + Inner::JsonArray => "_json", + Inner::TableAmHandler => "table_am_handler", + Inner::Xid8Array => "_xid8", + Inner::IndexAmHandler => "index_am_handler", + Inner::Point => "point", + Inner::Lseg => "lseg", + Inner::Path => "path", + Inner::Box => "box", + Inner::Polygon => "polygon", + Inner::Line => "line", + Inner::LineArray => "_line", + Inner::Cidr => "cidr", + Inner::CidrArray => "_cidr", + Inner::Float4 => "float4", + Inner::Float8 => "float8", + Inner::Unknown => "unknown", + Inner::Circle => "circle", + Inner::CircleArray => "_circle", + Inner::Macaddr8 => "macaddr8", + Inner::Macaddr8Array => "_macaddr8", + Inner::Money => "money", + Inner::MoneyArray => "_money", + Inner::Macaddr => "macaddr", + Inner::Inet => "inet", + Inner::BoolArray => "_bool", + Inner::ByteaArray => "_bytea", + Inner::CharArray => "_char", + Inner::NameArray => "_name", + Inner::Int2Array => "_int2", + Inner::Int2VectorArray => "_int2vector", + Inner::Int4Array => "_int4", + Inner::RegprocArray => "_regproc", + Inner::TextArray => "_text", + Inner::TidArray => "_tid", + Inner::XidArray => "_xid", + Inner::CidArray => "_cid", + Inner::OidVectorArray => "_oidvector", + Inner::BpcharArray => "_bpchar", + Inner::VarcharArray => "_varchar", + Inner::Int8Array => 
"_int8", + Inner::PointArray => "_point", + Inner::LsegArray => "_lseg", + Inner::PathArray => "_path", + Inner::BoxArray => "_box", + Inner::Float4Array => "_float4", + Inner::Float8Array => "_float8", + Inner::PolygonArray => "_polygon", + Inner::OidArray => "_oid", + Inner::Aclitem => "aclitem", + Inner::AclitemArray => "_aclitem", + Inner::MacaddrArray => "_macaddr", + Inner::InetArray => "_inet", + Inner::Bpchar => "bpchar", + Inner::Varchar => "varchar", + Inner::Date => "date", + Inner::Time => "time", + Inner::Timestamp => "timestamp", + Inner::TimestampArray => "_timestamp", + Inner::DateArray => "_date", + Inner::TimeArray => "_time", + Inner::Timestamptz => "timestamptz", + Inner::TimestamptzArray => "_timestamptz", + Inner::Interval => "interval", + Inner::IntervalArray => "_interval", + Inner::NumericArray => "_numeric", + Inner::CstringArray => "_cstring", + Inner::Timetz => "timetz", + Inner::TimetzArray => "_timetz", + Inner::Bit => "bit", + Inner::BitArray => "_bit", + Inner::Varbit => "varbit", + Inner::VarbitArray => "_varbit", + Inner::Numeric => "numeric", + Inner::Refcursor => "refcursor", + Inner::RefcursorArray => "_refcursor", + Inner::Regprocedure => "regprocedure", + Inner::Regoper => "regoper", + Inner::Regoperator => "regoperator", + Inner::Regclass => "regclass", + Inner::Regtype => "regtype", + Inner::RegprocedureArray => "_regprocedure", + Inner::RegoperArray => "_regoper", + Inner::RegoperatorArray => "_regoperator", + Inner::RegclassArray => "_regclass", + Inner::RegtypeArray => "_regtype", + Inner::Record => "record", + Inner::Cstring => "cstring", + Inner::Any => "any", + Inner::Anyarray => "anyarray", + Inner::Void => "void", + Inner::Trigger => "trigger", + Inner::LanguageHandler => "language_handler", + Inner::Internal => "internal", + Inner::Anyelement => "anyelement", + Inner::RecordArray => "_record", + Inner::Anynonarray => "anynonarray", + Inner::TxidSnapshotArray => "_txid_snapshot", + Inner::Uuid => "uuid", + Inner::UuidArray => "_uuid", + Inner::TxidSnapshot => "txid_snapshot", + Inner::FdwHandler => "fdw_handler", + Inner::PgLsn => "pg_lsn", + Inner::PgLsnArray => "_pg_lsn", + Inner::TsmHandler => "tsm_handler", + Inner::PgNdistinct => "pg_ndistinct", + Inner::PgDependencies => "pg_dependencies", + Inner::Anyenum => "anyenum", + Inner::TsVector => "tsvector", + Inner::Tsquery => "tsquery", + Inner::GtsVector => "gtsvector", + Inner::TsVectorArray => "_tsvector", + Inner::GtsVectorArray => "_gtsvector", + Inner::TsqueryArray => "_tsquery", + Inner::Regconfig => "regconfig", + Inner::RegconfigArray => "_regconfig", + Inner::Regdictionary => "regdictionary", + Inner::RegdictionaryArray => "_regdictionary", + Inner::Jsonb => "jsonb", + Inner::JsonbArray => "_jsonb", + Inner::AnyRange => "anyrange", + Inner::EventTrigger => "event_trigger", + Inner::Int4Range => "int4range", + Inner::Int4RangeArray => "_int4range", + Inner::NumRange => "numrange", + Inner::NumRangeArray => "_numrange", + Inner::TsRange => "tsrange", + Inner::TsRangeArray => "_tsrange", + Inner::TstzRange => "tstzrange", + Inner::TstzRangeArray => "_tstzrange", + Inner::DateRange => "daterange", + Inner::DateRangeArray => "_daterange", + Inner::Int8Range => "int8range", + Inner::Int8RangeArray => "_int8range", + Inner::Jsonpath => "jsonpath", + Inner::JsonpathArray => "_jsonpath", + Inner::Regnamespace => "regnamespace", + Inner::RegnamespaceArray => "_regnamespace", + Inner::Regrole => "regrole", + Inner::RegroleArray => "_regrole", + Inner::Regcollation => "regcollation", + 
Inner::RegcollationArray => "_regcollation", + Inner::Int4multiRange => "int4multirange", + Inner::NummultiRange => "nummultirange", + Inner::TsmultiRange => "tsmultirange", + Inner::TstzmultiRange => "tstzmultirange", + Inner::DatemultiRange => "datemultirange", + Inner::Int8multiRange => "int8multirange", + Inner::AnymultiRange => "anymultirange", + Inner::AnycompatiblemultiRange => "anycompatiblemultirange", + Inner::PgBrinBloomSummary => "pg_brin_bloom_summary", + Inner::PgBrinMinmaxMultiSummary => "pg_brin_minmax_multi_summary", + Inner::PgMcvList => "pg_mcv_list", + Inner::PgSnapshot => "pg_snapshot", + Inner::PgSnapshotArray => "_pg_snapshot", + Inner::Xid8 => "xid8", + Inner::Anycompatible => "anycompatible", + Inner::Anycompatiblearray => "anycompatiblearray", + Inner::Anycompatiblenonarray => "anycompatiblenonarray", + Inner::AnycompatibleRange => "anycompatiblerange", + Inner::Int4multiRangeArray => "_int4multirange", + Inner::NummultiRangeArray => "_nummultirange", + Inner::TsmultiRangeArray => "_tsmultirange", + Inner::TstzmultiRangeArray => "_tstzmultirange", + Inner::DatemultiRangeArray => "_datemultirange", + Inner::Int8multiRangeArray => "_int8multirange", + Inner::Other(ref u) => &u.name, + } + } +} +impl Type { + /// BOOL - boolean, 'true'/'false' + pub const BOOL: Type = Type(Inner::Bool); + + /// BYTEA - variable-length string, binary values escaped + pub const BYTEA: Type = Type(Inner::Bytea); + + /// CHAR - single character + pub const CHAR: Type = Type(Inner::Char); + + /// NAME - 63-byte type for storing system identifiers + pub const NAME: Type = Type(Inner::Name); + + /// INT8 - ~18 digit integer, 8-byte storage + pub const INT8: Type = Type(Inner::Int8); + + /// INT2 - -32 thousand to 32 thousand, 2-byte storage + pub const INT2: Type = Type(Inner::Int2); + + /// INT2VECTOR - array of int2, used in system tables + pub const INT2_VECTOR: Type = Type(Inner::Int2Vector); + + /// INT4 - -2 billion to 2 billion integer, 4-byte storage + pub const INT4: Type = Type(Inner::Int4); + + /// REGPROC - registered procedure + pub const REGPROC: Type = Type(Inner::Regproc); + + /// TEXT - variable-length string, no limit specified + pub const TEXT: Type = Type(Inner::Text); + + /// OID - object identifier(oid), maximum 4 billion + pub const OID: Type = Type(Inner::Oid); + + /// TID - (block, offset), physical location of tuple + pub const TID: Type = Type(Inner::Tid); + + /// XID - transaction id + pub const XID: Type = Type(Inner::Xid); + + /// CID - command identifier type, sequence in transaction id + pub const CID: Type = Type(Inner::Cid); + + /// OIDVECTOR - array of oids, used in system tables + pub const OID_VECTOR: Type = Type(Inner::OidVector); + + /// PG_DDL_COMMAND - internal type for passing CollectedCommand + pub const PG_DDL_COMMAND: Type = Type(Inner::PgDdlCommand); + + /// JSON - JSON stored as text + pub const JSON: Type = Type(Inner::Json); + + /// XML - XML content + pub const XML: Type = Type(Inner::Xml); + + /// XML[] + pub const XML_ARRAY: Type = Type(Inner::XmlArray); + + /// PG_NODE_TREE - string representing an internal node tree + pub const PG_NODE_TREE: Type = Type(Inner::PgNodeTree); + + /// JSON[] + pub const JSON_ARRAY: Type = Type(Inner::JsonArray); + + /// TABLE_AM_HANDLER + pub const TABLE_AM_HANDLER: Type = Type(Inner::TableAmHandler); + + /// XID8[] + pub const XID8_ARRAY: Type = Type(Inner::Xid8Array); + + /// INDEX_AM_HANDLER - pseudo-type for the result of an index AM handler function + pub const INDEX_AM_HANDLER: Type = 
Type(Inner::IndexAmHandler); + + /// POINT - geometric point '(x, y)' + pub const POINT: Type = Type(Inner::Point); + + /// LSEG - geometric line segment '(pt1,pt2)' + pub const LSEG: Type = Type(Inner::Lseg); + + /// PATH - geometric path '(pt1,...)' + pub const PATH: Type = Type(Inner::Path); + + /// BOX - geometric box '(lower left,upper right)' + pub const BOX: Type = Type(Inner::Box); + + /// POLYGON - geometric polygon '(pt1,...)' + pub const POLYGON: Type = Type(Inner::Polygon); + + /// LINE - geometric line + pub const LINE: Type = Type(Inner::Line); + + /// LINE[] + pub const LINE_ARRAY: Type = Type(Inner::LineArray); + + /// CIDR - network IP address/netmask, network address + pub const CIDR: Type = Type(Inner::Cidr); + + /// CIDR[] + pub const CIDR_ARRAY: Type = Type(Inner::CidrArray); + + /// FLOAT4 - single-precision floating point number, 4-byte storage + pub const FLOAT4: Type = Type(Inner::Float4); + + /// FLOAT8 - double-precision floating point number, 8-byte storage + pub const FLOAT8: Type = Type(Inner::Float8); + + /// UNKNOWN - pseudo-type representing an undetermined type + pub const UNKNOWN: Type = Type(Inner::Unknown); + + /// CIRCLE - geometric circle '(center,radius)' + pub const CIRCLE: Type = Type(Inner::Circle); + + /// CIRCLE[] + pub const CIRCLE_ARRAY: Type = Type(Inner::CircleArray); + + /// MACADDR8 - XX:XX:XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR8: Type = Type(Inner::Macaddr8); + + /// MACADDR8[] + pub const MACADDR8_ARRAY: Type = Type(Inner::Macaddr8Array); + + /// MONEY - monetary amounts, $d,ddd.cc + pub const MONEY: Type = Type(Inner::Money); + + /// MONEY[] + pub const MONEY_ARRAY: Type = Type(Inner::MoneyArray); + + /// MACADDR - XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR: Type = Type(Inner::Macaddr); + + /// INET - IP address/netmask, host address, netmask optional + pub const INET: Type = Type(Inner::Inet); + + /// BOOL[] + pub const BOOL_ARRAY: Type = Type(Inner::BoolArray); + + /// BYTEA[] + pub const BYTEA_ARRAY: Type = Type(Inner::ByteaArray); + + /// CHAR[] + pub const CHAR_ARRAY: Type = Type(Inner::CharArray); + + /// NAME[] + pub const NAME_ARRAY: Type = Type(Inner::NameArray); + + /// INT2[] + pub const INT2_ARRAY: Type = Type(Inner::Int2Array); + + /// INT2VECTOR[] + pub const INT2_VECTOR_ARRAY: Type = Type(Inner::Int2VectorArray); + + /// INT4[] + pub const INT4_ARRAY: Type = Type(Inner::Int4Array); + + /// REGPROC[] + pub const REGPROC_ARRAY: Type = Type(Inner::RegprocArray); + + /// TEXT[] + pub const TEXT_ARRAY: Type = Type(Inner::TextArray); + + /// TID[] + pub const TID_ARRAY: Type = Type(Inner::TidArray); + + /// XID[] + pub const XID_ARRAY: Type = Type(Inner::XidArray); + + /// CID[] + pub const CID_ARRAY: Type = Type(Inner::CidArray); + + /// OIDVECTOR[] + pub const OID_VECTOR_ARRAY: Type = Type(Inner::OidVectorArray); + + /// BPCHAR[] + pub const BPCHAR_ARRAY: Type = Type(Inner::BpcharArray); + + /// VARCHAR[] + pub const VARCHAR_ARRAY: Type = Type(Inner::VarcharArray); + + /// INT8[] + pub const INT8_ARRAY: Type = Type(Inner::Int8Array); + + /// POINT[] + pub const POINT_ARRAY: Type = Type(Inner::PointArray); + + /// LSEG[] + pub const LSEG_ARRAY: Type = Type(Inner::LsegArray); + + /// PATH[] + pub const PATH_ARRAY: Type = Type(Inner::PathArray); + + /// BOX[] + pub const BOX_ARRAY: Type = Type(Inner::BoxArray); + + /// FLOAT4[] + pub const FLOAT4_ARRAY: Type = Type(Inner::Float4Array); + + /// FLOAT8[] + pub const FLOAT8_ARRAY: Type = Type(Inner::Float8Array); + + /// POLYGON[] + pub const POLYGON_ARRAY: Type 
= Type(Inner::PolygonArray); + + /// OID[] + pub const OID_ARRAY: Type = Type(Inner::OidArray); + + /// ACLITEM - access control list + pub const ACLITEM: Type = Type(Inner::Aclitem); + + /// ACLITEM[] + pub const ACLITEM_ARRAY: Type = Type(Inner::AclitemArray); + + /// MACADDR[] + pub const MACADDR_ARRAY: Type = Type(Inner::MacaddrArray); + + /// INET[] + pub const INET_ARRAY: Type = Type(Inner::InetArray); + + /// BPCHAR - char(length), blank-padded string, fixed storage length + pub const BPCHAR: Type = Type(Inner::Bpchar); + + /// VARCHAR - varchar(length), non-blank-padded string, variable storage length + pub const VARCHAR: Type = Type(Inner::Varchar); + + /// DATE - date + pub const DATE: Type = Type(Inner::Date); + + /// TIME - time of day + pub const TIME: Type = Type(Inner::Time); + + /// TIMESTAMP - date and time + pub const TIMESTAMP: Type = Type(Inner::Timestamp); + + /// TIMESTAMP[] + pub const TIMESTAMP_ARRAY: Type = Type(Inner::TimestampArray); + + /// DATE[] + pub const DATE_ARRAY: Type = Type(Inner::DateArray); + + /// TIME[] + pub const TIME_ARRAY: Type = Type(Inner::TimeArray); + + /// TIMESTAMPTZ - date and time with time zone + pub const TIMESTAMPTZ: Type = Type(Inner::Timestamptz); + + /// TIMESTAMPTZ[] + pub const TIMESTAMPTZ_ARRAY: Type = Type(Inner::TimestamptzArray); + + /// INTERVAL - @ <number> <units>, time interval + pub const INTERVAL: Type = Type(Inner::Interval); + + /// INTERVAL[] + pub const INTERVAL_ARRAY: Type = Type(Inner::IntervalArray); + + /// NUMERIC[] + pub const NUMERIC_ARRAY: Type = Type(Inner::NumericArray); + + /// CSTRING[] + pub const CSTRING_ARRAY: Type = Type(Inner::CstringArray); + + /// TIMETZ - time of day with time zone + pub const TIMETZ: Type = Type(Inner::Timetz); + + /// TIMETZ[] + pub const TIMETZ_ARRAY: Type = Type(Inner::TimetzArray); + + /// BIT - fixed-length bit string + pub const BIT: Type = Type(Inner::Bit); + + /// BIT[] + pub const BIT_ARRAY: Type = Type(Inner::BitArray); + + /// VARBIT - variable-length bit string + pub const VARBIT: Type = Type(Inner::Varbit); + + /// VARBIT[] + pub const VARBIT_ARRAY: Type = Type(Inner::VarbitArray); + + /// NUMERIC - numeric(precision, decimal), arbitrary precision number + pub const NUMERIC: Type = Type(Inner::Numeric); + + /// REFCURSOR - reference to cursor (portal name) + pub const REFCURSOR: Type = Type(Inner::Refcursor); + + /// REFCURSOR[] + pub const REFCURSOR_ARRAY: Type = Type(Inner::RefcursorArray); + + /// REGPROCEDURE - registered procedure (with args) + pub const REGPROCEDURE: Type = Type(Inner::Regprocedure); + + /// REGOPER - registered operator + pub const REGOPER: Type = Type(Inner::Regoper); + + /// REGOPERATOR - registered operator (with args) + pub const REGOPERATOR: Type = Type(Inner::Regoperator); + + /// REGCLASS - registered class + pub const REGCLASS: Type = Type(Inner::Regclass); + + /// REGTYPE - registered type + pub const REGTYPE: Type = Type(Inner::Regtype); + + /// REGPROCEDURE[] + pub const REGPROCEDURE_ARRAY: Type = Type(Inner::RegprocedureArray); + + /// REGOPER[] + pub const REGOPER_ARRAY: Type = Type(Inner::RegoperArray); + + /// REGOPERATOR[] + pub const REGOPERATOR_ARRAY: Type = Type(Inner::RegoperatorArray); + + /// REGCLASS[] + pub const REGCLASS_ARRAY: Type = Type(Inner::RegclassArray); + + /// REGTYPE[] + pub const REGTYPE_ARRAY: Type = Type(Inner::RegtypeArray); + + /// RECORD - pseudo-type representing any composite type + pub const RECORD: Type = Type(Inner::Record); + + /// CSTRING - C-style string + pub const CSTRING: Type = 
Type(Inner::Cstring); + + /// ANY - pseudo-type representing any type + pub const ANY: Type = Type(Inner::Any); + + /// ANYARRAY - pseudo-type representing a polymorphic array type + pub const ANYARRAY: Type = Type(Inner::Anyarray); + + /// VOID - pseudo-type for the result of a function with no real result + pub const VOID: Type = Type(Inner::Void); + + /// TRIGGER - pseudo-type for the result of a trigger function + pub const TRIGGER: Type = Type(Inner::Trigger); + + /// LANGUAGE_HANDLER - pseudo-type for the result of a language handler function + pub const LANGUAGE_HANDLER: Type = Type(Inner::LanguageHandler); + + /// INTERNAL - pseudo-type representing an internal data structure + pub const INTERNAL: Type = Type(Inner::Internal); + + /// ANYELEMENT - pseudo-type representing a polymorphic base type + pub const ANYELEMENT: Type = Type(Inner::Anyelement); + + /// RECORD[] + pub const RECORD_ARRAY: Type = Type(Inner::RecordArray); + + /// ANYNONARRAY - pseudo-type representing a polymorphic base type that is not an array + pub const ANYNONARRAY: Type = Type(Inner::Anynonarray); + + /// TXID_SNAPSHOT[] + pub const TXID_SNAPSHOT_ARRAY: Type = Type(Inner::TxidSnapshotArray); + + /// UUID - UUID datatype + pub const UUID: Type = Type(Inner::Uuid); + + /// UUID[] + pub const UUID_ARRAY: Type = Type(Inner::UuidArray); + + /// TXID_SNAPSHOT - txid snapshot + pub const TXID_SNAPSHOT: Type = Type(Inner::TxidSnapshot); + + /// FDW_HANDLER - pseudo-type for the result of an FDW handler function + pub const FDW_HANDLER: Type = Type(Inner::FdwHandler); + + /// PG_LSN - PostgreSQL LSN datatype + pub const PG_LSN: Type = Type(Inner::PgLsn); + + /// PG_LSN[] + pub const PG_LSN_ARRAY: Type = Type(Inner::PgLsnArray); + + /// TSM_HANDLER - pseudo-type for the result of a tablesample method function + pub const TSM_HANDLER: Type = Type(Inner::TsmHandler); + + /// PG_NDISTINCT - multivariate ndistinct coefficients + pub const PG_NDISTINCT: Type = Type(Inner::PgNdistinct); + + /// PG_DEPENDENCIES - multivariate dependencies + pub const PG_DEPENDENCIES: Type = Type(Inner::PgDependencies); + + /// ANYENUM - pseudo-type representing a polymorphic base type that is an enum + pub const ANYENUM: Type = Type(Inner::Anyenum); + + /// TSVECTOR - text representation for text search + pub const TS_VECTOR: Type = Type(Inner::TsVector); + + /// TSQUERY - query representation for text search + pub const TSQUERY: Type = Type(Inner::Tsquery); + + /// GTSVECTOR - GiST index internal text representation for text search + pub const GTS_VECTOR: Type = Type(Inner::GtsVector); + + /// TSVECTOR[] + pub const TS_VECTOR_ARRAY: Type = Type(Inner::TsVectorArray); + + /// GTSVECTOR[] + pub const GTS_VECTOR_ARRAY: Type = Type(Inner::GtsVectorArray); + + /// TSQUERY[] + pub const TSQUERY_ARRAY: Type = Type(Inner::TsqueryArray); + + /// REGCONFIG - registered text search configuration + pub const REGCONFIG: Type = Type(Inner::Regconfig); + + /// REGCONFIG[] + pub const REGCONFIG_ARRAY: Type = Type(Inner::RegconfigArray); + + /// REGDICTIONARY - registered text search dictionary + pub const REGDICTIONARY: Type = Type(Inner::Regdictionary); + + /// REGDICTIONARY[] + pub const REGDICTIONARY_ARRAY: Type = Type(Inner::RegdictionaryArray); + + /// JSONB - Binary JSON + pub const JSONB: Type = Type(Inner::Jsonb); + + /// JSONB[] + pub const JSONB_ARRAY: Type = Type(Inner::JsonbArray); + + /// ANYRANGE - pseudo-type representing a range over a polymorphic base type + pub const ANY_RANGE: Type = Type(Inner::AnyRange); + + /// EVENT_TRIGGER - 
pseudo-type for the result of an event trigger function + pub const EVENT_TRIGGER: Type = Type(Inner::EventTrigger); + + /// INT4RANGE - range of integers + pub const INT4_RANGE: Type = Type(Inner::Int4Range); + + /// INT4RANGE[] + pub const INT4_RANGE_ARRAY: Type = Type(Inner::Int4RangeArray); + + /// NUMRANGE - range of numerics + pub const NUM_RANGE: Type = Type(Inner::NumRange); + + /// NUMRANGE[] + pub const NUM_RANGE_ARRAY: Type = Type(Inner::NumRangeArray); + + /// TSRANGE - range of timestamps without time zone + pub const TS_RANGE: Type = Type(Inner::TsRange); + + /// TSRANGE[] + pub const TS_RANGE_ARRAY: Type = Type(Inner::TsRangeArray); + + /// TSTZRANGE - range of timestamps with time zone + pub const TSTZ_RANGE: Type = Type(Inner::TstzRange); + + /// TSTZRANGE[] + pub const TSTZ_RANGE_ARRAY: Type = Type(Inner::TstzRangeArray); + + /// DATERANGE - range of dates + pub const DATE_RANGE: Type = Type(Inner::DateRange); + + /// DATERANGE[] + pub const DATE_RANGE_ARRAY: Type = Type(Inner::DateRangeArray); + + /// INT8RANGE - range of bigints + pub const INT8_RANGE: Type = Type(Inner::Int8Range); + + /// INT8RANGE[] + pub const INT8_RANGE_ARRAY: Type = Type(Inner::Int8RangeArray); + + /// JSONPATH - JSON path + pub const JSONPATH: Type = Type(Inner::Jsonpath); + + /// JSONPATH[] + pub const JSONPATH_ARRAY: Type = Type(Inner::JsonpathArray); + + /// REGNAMESPACE - registered namespace + pub const REGNAMESPACE: Type = Type(Inner::Regnamespace); + + /// REGNAMESPACE[] + pub const REGNAMESPACE_ARRAY: Type = Type(Inner::RegnamespaceArray); + + /// REGROLE - registered role + pub const REGROLE: Type = Type(Inner::Regrole); + + /// REGROLE[] + pub const REGROLE_ARRAY: Type = Type(Inner::RegroleArray); + + /// REGCOLLATION - registered collation + pub const REGCOLLATION: Type = Type(Inner::Regcollation); + + /// REGCOLLATION[] + pub const REGCOLLATION_ARRAY: Type = Type(Inner::RegcollationArray); + + /// INT4MULTIRANGE - multirange of integers + pub const INT4MULTI_RANGE: Type = Type(Inner::Int4multiRange); + + /// NUMMULTIRANGE - multirange of numerics + pub const NUMMULTI_RANGE: Type = Type(Inner::NummultiRange); + + /// TSMULTIRANGE - multirange of timestamps without time zone + pub const TSMULTI_RANGE: Type = Type(Inner::TsmultiRange); + + /// TSTZMULTIRANGE - multirange of timestamps with time zone + pub const TSTZMULTI_RANGE: Type = Type(Inner::TstzmultiRange); + + /// DATEMULTIRANGE - multirange of dates + pub const DATEMULTI_RANGE: Type = Type(Inner::DatemultiRange); + + /// INT8MULTIRANGE - multirange of bigints + pub const INT8MULTI_RANGE: Type = Type(Inner::Int8multiRange); + + /// ANYMULTIRANGE - pseudo-type representing a polymorphic base type that is a multirange + pub const ANYMULTI_RANGE: Type = Type(Inner::AnymultiRange); + + /// ANYCOMPATIBLEMULTIRANGE - pseudo-type representing a multirange over a polymorphic common type + pub const ANYCOMPATIBLEMULTI_RANGE: Type = Type(Inner::AnycompatiblemultiRange); + + /// PG_BRIN_BLOOM_SUMMARY - BRIN bloom summary + pub const PG_BRIN_BLOOM_SUMMARY: Type = Type(Inner::PgBrinBloomSummary); + + /// PG_BRIN_MINMAX_MULTI_SUMMARY - BRIN minmax-multi summary + pub const PG_BRIN_MINMAX_MULTI_SUMMARY: Type = Type(Inner::PgBrinMinmaxMultiSummary); + + /// PG_MCV_LIST - multivariate MCV list + pub const PG_MCV_LIST: Type = Type(Inner::PgMcvList); + + /// PG_SNAPSHOT - snapshot + pub const PG_SNAPSHOT: Type = Type(Inner::PgSnapshot); + + /// PG_SNAPSHOT[] + pub const PG_SNAPSHOT_ARRAY: Type = Type(Inner::PgSnapshotArray); + + /// XID8 - full 
transaction id + pub const XID8: Type = Type(Inner::Xid8); + + /// ANYCOMPATIBLE - pseudo-type representing a polymorphic common type + pub const ANYCOMPATIBLE: Type = Type(Inner::Anycompatible); + + /// ANYCOMPATIBLEARRAY - pseudo-type representing an array of polymorphic common type elements + pub const ANYCOMPATIBLEARRAY: Type = Type(Inner::Anycompatiblearray); + + /// ANYCOMPATIBLENONARRAY - pseudo-type representing a polymorphic common type that is not an array + pub const ANYCOMPATIBLENONARRAY: Type = Type(Inner::Anycompatiblenonarray); + + /// ANYCOMPATIBLERANGE - pseudo-type representing a range over a polymorphic common type + pub const ANYCOMPATIBLE_RANGE: Type = Type(Inner::AnycompatibleRange); + + /// INT4MULTIRANGE[] + pub const INT4MULTI_RANGE_ARRAY: Type = Type(Inner::Int4multiRangeArray); + + /// NUMMULTIRANGE[] + pub const NUMMULTI_RANGE_ARRAY: Type = Type(Inner::NummultiRangeArray); + + /// TSMULTIRANGE[] + pub const TSMULTI_RANGE_ARRAY: Type = Type(Inner::TsmultiRangeArray); + + /// TSTZMULTIRANGE[] + pub const TSTZMULTI_RANGE_ARRAY: Type = Type(Inner::TstzmultiRangeArray); + + /// DATEMULTIRANGE[] + pub const DATEMULTI_RANGE_ARRAY: Type = Type(Inner::DatemultiRangeArray); + + /// INT8MULTIRANGE[] + pub const INT8MULTI_RANGE_ARRAY: Type = Type(Inner::Int8multiRangeArray); +} diff --git a/libs/proxy/tokio-postgres2/Cargo.toml b/libs/proxy/tokio-postgres2/Cargo.toml new file mode 100644 index 0000000000..7130c1b726 --- /dev/null +++ b/libs/proxy/tokio-postgres2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tokio-postgres2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +async-trait.workspace = true +bytes.workspace = true +byteorder.workspace = true +fallible-iterator.workspace = true +futures-util = { workspace = true, features = ["sink"] } +log = "0.4" +parking_lot.workspace = true +percent-encoding = "2.0" +pin-project-lite.workspace = true +phf = "0.11" +postgres-protocol2 = { path = "../postgres-protocol2" } +postgres-types2 = { path = "../postgres-types2" } +tokio = { workspace = true, features = ["io-util", "time", "net"] } +tokio-util = { workspace = true, features = ["codec"] } diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs new file mode 100644 index 0000000000..cddbf16336 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -0,0 +1,40 @@ +use tokio::net::TcpStream; + +use crate::client::SocketConfig; +use crate::config::{Host, SslMode}; +use crate::tls::MakeTlsConnect; +use crate::{cancel_query_raw, connect_socket, Error}; +use std::io; + +pub(crate) async fn cancel_query( + config: Option, + ssl_mode: SslMode, + mut tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + T: MakeTlsConnect, +{ + let config = match config { + Some(config) => config, + None => { + return Err(Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "unknown host", + ))) + } + }; + + let hostname = match &config.host { + Host::Tcp(host) => &**host, + }; + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + let socket = + connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?; + + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await +} diff --git a/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs new file mode 100644 index 0000000000..8c08296435 --- /dev/null +++ 
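The constants above are the values callers pass when they want to pin a parameter's type rather than rely on server-side inference. A rough sketch, assuming the `prepare_typed`/`execute` methods added to `Client` later in this patch and the usual `ToSql` impl for `&str` from postgres-types2; table and column names are made up:

    // Sketch only: force $1 to be treated as TEXT instead of letting the server infer it.
    async fn rename_user(client: &Client, new_name: &str) -> Result<u64, Error> {
        let stmt = client
            .prepare_typed("UPDATE users SET name = $1 WHERE id = 1", &[Type::TEXT])
            .await?;
        // Returns the number of rows the UPDATE touched.
        client.execute(&stmt, &[&new_name]).await
    }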
b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs @@ -0,0 +1,29 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; +use crate::{connect_tls, Error}; +use bytes::BytesMut; +use postgres_protocol2::message::frontend; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + +pub async fn cancel_query_raw( + stream: S, + mode: SslMode, + tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + + let mut buf = BytesMut::new(); + frontend::cancel_request(process_id, secret_key, &mut buf); + + stream.write_all(&buf).await.map_err(Error::io)?; + stream.flush().await.map_err(Error::io)?; + stream.shutdown().await.map_err(Error::io)?; + + Ok(()) +} diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs new file mode 100644 index 0000000000..b949bf358f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -0,0 +1,62 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; + +use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect}; +use crate::{cancel_query_raw, Error}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + +/// The capability to request cancellation of in-progress queries on a +/// connection. +#[derive(Clone)] +pub struct CancelToken { + pub(crate) socket_config: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) process_id: i32, + pub(crate) secret_key: i32, +} + +impl CancelToken { + /// Attempts to cancel the in-progress query on the connection associated + /// with this `CancelToken`. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Cancellation is inherently racy. There is no guarantee that the + /// cancellation request will reach the server before the query terminates + /// normally, or that the connection associated with this token is still + /// active. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + cancel_query::cancel_query( + self.socket_config.clone(), + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } + + /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new + /// connection itself. 
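Together, `cancel_query`, `cancel_query_raw`, and `CancelToken` implement out-of-band cancellation: a token captured from a live client opens a second, short-lived connection and sends a CancelRequest carrying the saved process id and secret key. A hedged usage sketch; `NoTls` stands in for whatever `MakeTlsConnect` implementation the proxy actually supplies and is not part of this diff:

    // Sketch only. `client.cancel_token()` is added on `Client` later in this patch.
    async fn interrupt(client: &Client) -> Result<(), Error> {
        let token = client.cancel_token();
        // The server never confirms whether the running query was actually cancelled;
        // an error here only means the cancel connection itself could not be made.
        token.cancel_query(NoTls).await
    }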
+ pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + cancel_query_raw::cancel_query_raw( + stream, + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } +} diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs new file mode 100644 index 0000000000..96200b71e7 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -0,0 +1,439 @@ +use crate::codec::{BackendMessages, FrontendMessage}; + +use crate::config::Host; +use crate::config::SslMode; +use crate::connection::{Request, RequestMessages}; + +use crate::query::RowStream; +use crate::simple_query::SimpleQueryStream; + +use crate::types::{Oid, ToSql, Type}; + +use crate::{ + prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row, + SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, +}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{future, ready, TryStreamExt}; +use parking_lot::Mutex; +use postgres_protocol2::message::{backend::Message, frontend}; +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; + +use std::time::Duration; + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Responses { + pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.cur.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), + Some(message) => return Poll::Ready(Ok(message)), + None => {} + } + + match ready!(self.receiver.poll_recv(cx)) { + Some(messages) => self.cur = messages, + None => return Poll::Ready(Err(Error::closed())), + } + } + } + + pub async fn next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_next(cx)).await + } +} + +/// A cache of type info and prepared statements for fetching type info +/// (corresponding to the queries in the [prepare] module). +#[derive(Default)] +struct CachedTypeInfo { + /// A statement for basic information for a type from its + /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its + /// fallback). + typeinfo: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY). + typeinfo_composite: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or + /// its fallback). + typeinfo_enum: Option, + + /// Cache of types already looked up. + types: HashMap, +} + +pub struct InnerClient { + sender: mpsc::UnboundedSender, + cached_typeinfo: Mutex, + + /// A buffer to use when writing out postgres commands. 
+ buffer: Mutex, +} + +impl InnerClient { + pub fn send(&self, messages: RequestMessages) -> Result { + let (sender, receiver) = mpsc::channel(1); + let request = Request { messages, sender }; + self.sender.send(request).map_err(|_| Error::closed())?; + + Ok(Responses { + receiver, + cur: BackendMessages::empty(), + }) + } + + pub fn typeinfo(&self) -> Option { + self.cached_typeinfo.lock().typeinfo.clone() + } + + pub fn set_typeinfo(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + } + + pub fn typeinfo_composite(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_composite.clone() + } + + pub fn set_typeinfo_composite(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + } + + pub fn typeinfo_enum(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_enum.clone() + } + + pub fn set_typeinfo_enum(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + } + + pub fn type_(&self, oid: Oid) -> Option { + self.cached_typeinfo.lock().types.get(&oid).cloned() + } + + pub fn set_type(&self, oid: Oid, type_: &Type) { + self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + } + + /// Call the given function with a buffer to be used when writing out + /// postgres commands. + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let mut buffer = self.buffer.lock(); + let r = f(&mut buffer); + buffer.clear(); + r + } +} + +#[derive(Clone)] +pub(crate) struct SocketConfig { + pub host: Host, + pub port: u16, + pub connect_timeout: Option, + // pub keepalive: Option, +} + +/// An asynchronous PostgreSQL client. +/// +/// The client is one half of what is returned when a connection is established. Users interact with the database +/// through this client object. +pub struct Client { + inner: Arc, + + socket_config: Option, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, +} + +impl Client { + pub(crate) fn new( + sender: mpsc::UnboundedSender, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, + ) -> Client { + Client { + inner: Arc::new(InnerClient { + sender, + cached_typeinfo: Default::default(), + buffer: Default::default(), + }), + + socket_config: None, + ssl_mode, + process_id, + secret_key, + } + } + + /// Returns process_id. + pub fn get_process_id(&self) -> i32 { + self.process_id + } + + pub(crate) fn inner(&self) -> &Arc { + &self.inner + } + + pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) { + self.socket_config = Some(socket_config); + } + + /// Creates a new prepared statement. + /// + /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), + /// which are set when executed. Prepared statements can only be used with the connection that created them. + pub async fn prepare(&self, query: &str) -> Result { + self.prepare_typed(query, &[]).await + } + + /// Like `prepare`, but allows the types of query parameters to be explicitly specified. + /// + /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be + /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`. 
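`InnerClient` is the hub of the pipelining design: each operation serializes a frontend message into the shared buffer, pushes it through the unbounded channel to the connection task, and polls the returned `Responses` for the reply. The drop guard in `transaction()` further down uses exactly this path; a condensed sketch of the same flow (crate-internal API, imports as in client.rs):

    // Sketch of the internal request path, shown for orientation only.
    async fn ping(inner: &InnerClient) -> Result<Message, Error> {
        let buf = inner.with_buf(|buf| {
            frontend::query("SELECT 1", buf).unwrap();
            buf.split().freeze()
        });
        let mut responses = inner.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
        // The connection task pushes parsed BackendMessages here as they arrive.
        responses.next().await
    }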
+ pub async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + prepare::prepare(&self.inner, query, parameter_types).await + } + + /// Executes a statement, returning a vector of the resulting rows. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_raw(statement, slice_iter(params)) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`query`]: #method.query + pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query(&self.inner, statement, params).await + } + + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + query::query_txt(&self.inner, statement, params).await + } + + /// Executes a statement, returning the number of rows modified. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn execute( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.execute_raw(statement, slice_iter(params)).await + } + + /// The maximally flexible version of [`execute`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. 
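As the doc comments above note, anything implementing `ToStatement` works where a statement is expected, so a `&str` can be passed directly for one-off commands. A minimal sketch with a made-up table:

    // Sketch: no explicit prepare step for a one-off statement.
    async fn bump(client: &Client) -> Result<u64, Error> {
        client.execute("UPDATE counters SET n = n + 1", &[]).await
    }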
If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`execute`]: #method.execute + pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::execute(self.inner(), statement, params).await + } + + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings, + /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the + /// rows, this method returns a list of an enum which indicates either the completion of one of the commands, + /// or a row of data. This preserves the framing between the separate statements in the request. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query_raw(query).await?.try_collect().await + } + + pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { + simple_query::simple_query(self.inner(), query).await + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn batch_execute(&self, query: &str) -> Result { + simple_query::batch_execute(self.inner(), query).await + } + + /// Begins a new database transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn transaction(&mut self) -> Result, Error> { + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl Drop for RollbackIfNotDone<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done, as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously send to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. 
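    // For orientation (not part of the patch): a typical caller of `transaction()`
    // looks roughly like the sketch below. `batch_execute` and `commit` on the
    // `Transaction` guard are assumed to mirror upstream tokio-postgres; without the
    // explicit `commit`, the transaction rolls back when the guard is dropped.
    //
    //     let tx = client.transaction().await?;
    //     tx.batch_execute("UPDATE accounts SET balance = balance - 10 WHERE id = 1").await?;
    //     tx.batch_execute("UPDATE accounts SET balance = balance + 10 WHERE id = 2").await?;
    //     tx.commit().await?;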
+ { + let mut cleaner = RollbackIfNotDone { + client: self, + done: false, + }; + self.batch_execute("BEGIN").await?; + cleaner.done = true; + } + + Ok(Transaction::new(self)) + } + + /// Returns a builder for a transaction with custom settings. + /// + /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other + /// attributes. + pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { + TransactionBuilder::new(self) + } + + /// Constructs a cancellation token that can later be used to request cancellation of a query running on the + /// connection associated with this client. + pub fn cancel_token(&self) -> CancelToken { + CancelToken { + socket_config: self.socket_config.clone(), + ssl_mode: self.ssl_mode, + process_id: self.process_id, + secret_key: self.secret_key, + } + } + + /// Query for type information + pub async fn get_type(&self, oid: Oid) -> Result { + crate::prepare::get_type(&self.inner, oid).await + } + + /// Determines if the connection to the server has already closed. + /// + /// In that case, all future queries will fail. + pub fn is_closed(&self) -> bool { + self.inner.sender.is_closed() + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs new file mode 100644 index 0000000000..7412db785b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -0,0 +1,109 @@ +use bytes::{Buf, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend; +use postgres_protocol2::message::frontend::CopyData; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +pub enum FrontendMessage { + Raw(Bytes), + CopyData(CopyData>), +} + +pub enum BackendMessage { + Normal { + messages: BackendMessages, + request_complete: bool, + }, + Async(backend::Message), +} + +pub struct BackendMessages(BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + +pub struct PostgresCodec { + pub max_message_size: Option, +} + +impl Encoder for PostgresCodec { + type Error = io::Error; + + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + + Ok(()) + } +} + +impl Decoder for PostgresCodec { + type Item = BackendMessage; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + let mut idx = 0; + let mut request_complete = false; + + while let Some(header) = backend::Header::parse(&src[idx..])? 
{ + let len = header.len() as usize + 1; + if src[idx..].len() < len { + break; + } + + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + + match header.tag() { + backend::NOTICE_RESPONSE_TAG + | backend::NOTIFICATION_RESPONSE_TAG + | backend::PARAMETER_STATUS_TAG => { + if idx == 0 { + let message = backend::Message::parse(src)?.unwrap(); + return Ok(Some(BackendMessage::Async(message))); + } else { + break; + } + } + _ => {} + } + + idx += len; + + if header.tag() == backend::READY_FOR_QUERY_TAG { + request_complete = true; + break; + } + } + + if idx == 0 { + Ok(None) + } else { + Ok(Some(BackendMessage::Normal { + messages: BackendMessages(src.split_to(idx)), + request_complete, + })) + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs new file mode 100644 index 0000000000..969c20ba47 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -0,0 +1,897 @@ +//! Connection configuration. + +use crate::connect::connect; +use crate::connect_raw::connect_raw; +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +use crate::{Client, Connection, Error}; +use std::borrow::Cow; +use std::str; +use std::str::FromStr; +use std::time::Duration; +use std::{error, fmt, iter, mem}; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub use postgres_protocol2::authentication::sasl::ScramKeys; +use tokio::net::TcpStream; + +/// Properties required of a session. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum TargetSessionAttrs { + /// No special properties are required. + Any, + /// The session must allow writes. + ReadWrite, +} + +/// TLS configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslMode { + /// Do not use TLS. + Disable, + /// Attempt to connect with TLS but allow sessions without. + Prefer, + /// Require the use of TLS. + Require, +} + +/// Channel binding configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ChannelBinding { + /// Do not use channel binding. + Disable, + /// Attempt to use channel binding but allow sessions without. + Prefer, + /// Require the use of channel binding. + Require, +} + +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + +/// A host specification. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A TCP hostname. + Tcp(String), +} + +/// Precomputed keys which may override password during auth. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthKeys { + /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`. + ScramSha256(ScramKeys<32>), +} + +/// Connection configuration. +/// +/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: +/// +/// # Key-Value +/// +/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain +/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. +/// +/// ## Keys +/// +/// * `user` - The username to authenticate with. Required. +/// * `password` - The password to authenticate with. +/// * `dbname` - The name of the database to connect to. Defaults to the username. +/// * `options` - Command line options used to configure the server. 
+/// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the +/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts +/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting +/// with the `connect` method. +/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be +/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if +/// omitted or the empty string. +/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames +/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. +/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that +/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server +/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// +/// ## Examples +/// +/// ```not_rust +/// host=localhost user=postgres connect_timeout=10 keepalives=0 +/// ``` +/// +/// ```not_rust +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write +/// ``` +/// +/// # Url +/// +/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple +/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, +/// as the path component of the URL specifies the database name. 
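Either format ends up in the same `Config`; a rough programmatic equivalent of the first key-value example above, using the builder methods defined below in this file (`keepalives` has no counterpart in this trimmed fork; imports assumed):

    // Roughly `host=localhost user=postgres connect_timeout=10`, built programmatically.
    let mut config = Config::new();
    config
        .host("localhost")
        .user("postgres")
        .connect_timeout(Duration::from_secs(10));
    // `config.connect(tls)` then tries each configured host in turn (see connect.rs).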
+/// +/// ## Examples +/// +/// ```not_rust +/// postgresql://user@localhost +/// ``` +/// +/// ```not_rust +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// postgresql:///mydb?user=user&host=/var/lib/postgresql +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Config { + pub(crate) user: Option, + pub(crate) password: Option>, + pub(crate) auth_keys: Option>, + pub(crate) dbname: Option, + pub(crate) options: Option, + pub(crate) application_name: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) host: Vec, + pub(crate) port: Vec, + pub(crate) connect_timeout: Option, + pub(crate) target_session_attrs: TargetSessionAttrs, + pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, +} + +impl Default for Config { + fn default() -> Config { + Config::new() + } +} + +impl Config { + /// Creates a new configuration. + pub fn new() -> Config { + Config { + user: None, + password: None, + auth_keys: None, + dbname: None, + options: None, + application_name: None, + ssl_mode: SslMode::Prefer, + host: vec![], + port: vec![], + connect_timeout: None, + target_session_attrs: TargetSessionAttrs::Any, + channel_binding: ChannelBinding::Prefer, + replication_mode: None, + max_backend_message_size: None, + } + } + + /// Sets the user to authenticate with. + /// + /// Required. + pub fn user(&mut self, user: &str) -> &mut Config { + self.user = Some(user.to_string()); + self + } + + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. + pub fn get_user(&self) -> Option<&str> { + self.user.as_deref() + } + + /// Sets the password to authenticate with. + pub fn password(&mut self, password: T) -> &mut Config + where + T: AsRef<[u8]>, + { + self.password = Some(password.as_ref().to_vec()); + self + } + + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.password.as_deref() + } + + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() + } + + /// Sets the name of the database to connect to. + /// + /// Defaults to the user. + pub fn dbname(&mut self, dbname: &str) -> &mut Config { + self.dbname = Some(dbname.to_string()); + self + } + + /// Gets the name of the database to connect to, if one has been configured + /// with the `dbname` method. + pub fn get_dbname(&self) -> Option<&str> { + self.dbname.as_deref() + } + + /// Sets command line options used to configure the server. + pub fn options(&mut self, options: &str) -> &mut Config { + self.options = Some(options.to_string()); + self + } + + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.options.as_deref() + } + + /// Sets the value of the `application_name` runtime parameter. 
+ pub fn application_name(&mut self, application_name: &str) -> &mut Config { + self.application_name = Some(application_name.to_string()); + self + } + + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + self.ssl_mode = ssl_mode; + self + } + + /// Gets the SSL configuration. + pub fn get_ssl_mode(&self) -> SslMode { + self.ssl_mode + } + + /// Adds a host to the configuration. + /// + /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. + pub fn host(&mut self, host: &str) -> &mut Config { + self.host.push(Host::Tcp(host.to_string())); + self + } + + /// Gets the hosts that have been added to the configuration with `host`. + pub fn get_hosts(&self) -> &[Host] { + &self.host + } + + /// Adds a port to the configuration. + /// + /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which + /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports + /// as hosts. + pub fn port(&mut self, port: u16) -> &mut Config { + self.port.push(port); + self + } + + /// Gets the ports that have been added to the configuration with `port`. + pub fn get_ports(&self) -> &[u16] { + &self.port + } + + /// Sets the timeout applied to socket-level connection attempts. + /// + /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each + /// host separately. Defaults to no limit. + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config { + self.connect_timeout = Some(connect_timeout); + self + } + + /// Gets the connection timeout, if one has been set with the + /// `connect_timeout` method. + pub fn get_connect_timeout(&self) -> Option<&Duration> { + self.connect_timeout.as_ref() + } + + /// Sets the requirements of the session. + /// + /// This can be used to connect to the primary server in a clustered database rather than one of the read-only + /// secondary servers. Defaults to `Any`. + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + self.target_session_attrs = target_session_attrs; + self + } + + /// Gets the requirements of the session. + pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { + self.target_session_attrs + } + + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.channel_binding = channel_binding; + self + } + + /// Gets the channel binding behavior. + pub fn get_channel_binding(&self) -> ChannelBinding { + self.channel_binding + } + + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. 
+ pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + match key { + "user" => { + self.user(value); + } + "password" => { + self.password(value); + } + "dbname" => { + self.dbname(value); + } + "options" => { + self.options(value); + } + "application_name" => { + self.application_name(value); + } + "sslmode" => { + let mode = match value { + "disable" => SslMode::Disable, + "prefer" => SslMode::Prefer, + "require" => SslMode::Require, + _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), + }; + self.ssl_mode(mode); + } + "host" => { + for host in value.split(',') { + self.host(host); + } + } + "port" => { + for port in value.split(',') { + let port = if port.is_empty() { + 5432 + } else { + port.parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))? + }; + self.port(port); + } + } + "connect_timeout" => { + let timeout = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?; + if timeout > 0 { + self.connect_timeout(Duration::from_secs(timeout as u64)); + } + } + "target_session_attrs" => { + let target_session_attrs = match value { + "any" => TargetSessionAttrs::Any, + "read-write" => TargetSessionAttrs::ReadWrite, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "target_session_attrs", + )))); + } + }; + self.target_session_attrs(target_session_attrs); + } + "channel_binding" => { + let channel_binding = match value { + "disable" => ChannelBinding::Disable, + "prefer" => ChannelBinding::Prefer, + "require" => ChannelBinding::Require, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "channel_binding", + )))) + } + }; + self.channel_binding(channel_binding); + } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } + key => { + return Err(Error::config_parse(Box::new(UnknownOption( + key.to_string(), + )))); + } + } + + Ok(()) + } + + /// Opens a connection to a PostgreSQL database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn connect( + &self, + tls: T, + ) -> Result<(Client, Connection), Error> + where + T: MakeTlsConnect, + { + connect(tls, self).await + } + + /// Connects to a PostgreSQL database over an arbitrary stream. + /// + /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored. + pub async fn connect_raw( + &self, + stream: S, + tls: T, + ) -> Result<(Client, Connection), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + connect_raw(stream, tls, self).await + } +} + +impl FromStr for Config { + type Err = Error; + + fn from_str(s: &str) -> Result { + match UrlParser::parse(s)? 
{ + Some(config) => Ok(config), + None => Parser::parse(s), + } + } +} + +// Omit password from debug output +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Redaction {} + impl fmt::Debug for Redaction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "_") + } + } + + f.debug_struct("Config") + .field("user", &self.user) + .field("password", &self.password.as_ref().map(|_| Redaction {})) + .field("dbname", &self.dbname) + .field("options", &self.options) + .field("application_name", &self.application_name) + .field("ssl_mode", &self.ssl_mode) + .field("host", &self.host) + .field("port", &self.port) + .field("connect_timeout", &self.connect_timeout) + .field("target_session_attrs", &self.target_session_attrs) + .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) + .finish() + } +} + +#[derive(Debug)] +struct UnknownOption(String); + +impl fmt::Display for UnknownOption { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "unknown option `{}`", self.0) + } +} + +impl error::Error for UnknownOption {} + +#[derive(Debug)] +struct InvalidValue(&'static str); + +impl fmt::Display for InvalidValue { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "invalid value for option `{}`", self.0) + } +} + +impl error::Error for InvalidValue {} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn parse(s: &'a str) -> Result { + let mut parser = Parser { + s, + it: s.char_indices().peekable(), + }; + + let mut config = Config::new(); + + while let Some((key, value)) = parser.parameter()? { + config.param(key, &value)?; + } + + Ok(config) + } + + fn skip_ws(&mut self) { + self.take_while(char::is_whitespace); + } + + fn take_while(&mut self, f: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return "", + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return &self.s[start..i], + None => return &self.s[start..], + } + } + } + + fn eat(&mut self, target: char) -> Result<(), Error> { + match self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}`", + i, target, c + ); + Err(Error::config_parse(m.into())) + } + None => Err(Error::config_parse("unexpected EOF".into())), + } + } + + fn eat_if(&mut self, target: char) -> bool { + match self.it.peek() { + Some(&(_, c)) if c == target => { + self.it.next(); + true + } + _ => false, + } + } + + fn keyword(&mut self) -> Option<&'a str> { + let s = self.take_while(|c| match c { + c if c.is_whitespace() => false, + '=' => false, + _ => true, + }); + + if s.is_empty() { + None + } else { + Some(s) + } + } + + fn value(&mut self) -> Result { + let value = if self.eat_if('\'') { + let value = self.quoted_value()?; + self.eat('\'')?; + value + } else { + self.simple_value()? 
+ }; + + Ok(value) + } + + fn simple_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c.is_whitespace() { + break; + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + if value.is_empty() { + return Err(Error::config_parse("unexpected EOF".into())); + } + + Ok(value) + } + + fn quoted_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c == '\'' { + return Ok(value); + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + Err(Error::config_parse( + "unterminated quoted connection parameter value".into(), + )) + } + + fn parameter(&mut self) -> Result, Error> { + self.skip_ws(); + let keyword = match self.keyword() { + Some(keyword) => keyword, + None => return Ok(None), + }; + self.skip_ws(); + self.eat('=')?; + self.skip_ws(); + let value = self.value()?; + + Ok(Some((keyword, value))) + } +} + +// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict +struct UrlParser<'a> { + s: &'a str, + config: Config, +} + +impl<'a> UrlParser<'a> { + fn parse(s: &'a str) -> Result, Error> { + let s = match Self::remove_url_prefix(s) { + Some(s) => s, + None => return Ok(None), + }; + + let mut parser = UrlParser { + s, + config: Config::new(), + }; + + parser.parse_credentials()?; + parser.parse_host()?; + parser.parse_path()?; + parser.parse_params()?; + + Ok(Some(parser.config)) + } + + fn remove_url_prefix(s: &str) -> Option<&str> { + for prefix in &["postgres://", "postgresql://"] { + if let Some(stripped) = s.strip_prefix(prefix) { + return Some(stripped); + } + } + + None + } + + fn take_until(&mut self, end: &[char]) -> Option<&'a str> { + match self.s.find(end) { + Some(pos) => { + let (head, tail) = self.s.split_at(pos); + self.s = tail; + Some(head) + } + None => None, + } + } + + fn take_all(&mut self) -> &'a str { + mem::take(&mut self.s) + } + + fn eat_byte(&mut self) { + self.s = &self.s[1..]; + } + + fn parse_credentials(&mut self) -> Result<(), Error> { + let creds = match self.take_until(&['@']) { + Some(creds) => creds, + None => return Ok(()), + }; + self.eat_byte(); + + let mut it = creds.splitn(2, ':'); + let user = self.decode(it.next().unwrap())?; + self.config.user(&user); + + if let Some(password) = it.next() { + let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); + self.config.password(password); + } + + Ok(()) + } + + fn parse_host(&mut self) -> Result<(), Error> { + let host = match self.take_until(&['/', '?']) { + Some(host) => host, + None => self.take_all(), + }; + + if host.is_empty() { + return Ok(()); + } + + for chunk in host.split(',') { + let (host, port) = if chunk.starts_with('[') { + let idx = match chunk.find(']') { + Some(idx) => idx, + None => return Err(Error::config_parse(InvalidValue("host").into())), + }; + + let host = &chunk[1..idx]; + let remaining = &chunk[idx + 1..]; + let port = if let Some(port) = remaining.strip_prefix(':') { + Some(port) + } else if remaining.is_empty() { + None + } else { + return Err(Error::config_parse(InvalidValue("host").into())); + }; + + (host, port) + } else { + let mut it = chunk.splitn(2, ':'); + (it.next().unwrap(), it.next()) + }; + + self.host_param(host)?; + let port = self.decode(port.unwrap_or("5432"))?; + 
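The `Parser` and `UrlParser` implementations in this file both funnel into `Config::param`, so either syntax can be checked through the same accessors. A small sketch against the `get_*` methods defined earlier; host, user, and database names are hypothetical:

    // Sketch: URL-style string in, accessors out.
    fn parse_example() -> Result<(), Error> {
        let config: Config =
            "postgresql://app@db.example.com:6432/mydb?sslmode=require".parse()?;
        assert_eq!(config.get_user(), Some("app"));
        assert_eq!(config.get_dbname(), Some("mydb"));
        assert_eq!(config.get_ports(), &[6432u16]);
        assert!(matches!(config.get_ssl_mode(), SslMode::Require));
        Ok(())
    }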
+            self.config.param("port", &port)?;
+        }
+
+        Ok(())
+    }
+
+    fn parse_path(&mut self) -> Result<(), Error> {
+        if !self.s.starts_with('/') {
+            return Ok(());
+        }
+        self.eat_byte();
+
+        let dbname = match self.take_until(&['?']) {
+            Some(dbname) => dbname,
+            None => self.take_all(),
+        };
+
+        if !dbname.is_empty() {
+            self.config.dbname(&self.decode(dbname)?);
+        }
+
+        Ok(())
+    }
+
+    fn parse_params(&mut self) -> Result<(), Error> {
+        if !self.s.starts_with('?') {
+            return Ok(());
+        }
+        self.eat_byte();
+
+        while !self.s.is_empty() {
+            let key = match self.take_until(&['=']) {
+                Some(key) => self.decode(key)?,
+                None => return Err(Error::config_parse("unterminated parameter".into())),
+            };
+            self.eat_byte();
+
+            let value = match self.take_until(&['&']) {
+                Some(value) => {
+                    self.eat_byte();
+                    value
+                }
+                None => self.take_all(),
+            };
+
+            if key == "host" {
+                self.host_param(value)?;
+            } else {
+                let value = self.decode(value)?;
+                self.config.param(&key, &value)?;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn host_param(&mut self, s: &str) -> Result<(), Error> {
+        let s = self.decode(s)?;
+        self.config.param("host", &s)
+    }
+
+    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
+        percent_encoding::percent_decode(s.as_bytes())
+            .decode_utf8()
+            .map_err(|e| Error::config_parse(e.into()))
+    }
+}
diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs
new file mode 100644
index 0000000000..7517fe0cde
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/connect.rs
@@ -0,0 +1,112 @@
+use crate::client::SocketConfig;
+use crate::config::{Host, TargetSessionAttrs};
+use crate::connect_raw::connect_raw;
+use crate::connect_socket::connect_socket;
+use crate::tls::{MakeTlsConnect, TlsConnect};
+use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
+use futures_util::{future, pin_mut, Future, FutureExt, Stream};
+use std::io;
+use std::task::Poll;
+use tokio::net::TcpStream;
+
+pub async fn connect<T>(
+    mut tls: T,
+    config: &Config,
+) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
+where
+    T: MakeTlsConnect<TcpStream>,
+{
+    if config.host.is_empty() {
+        return Err(Error::config("host missing".into()));
+    }
+
+    if config.port.len() > 1 && config.port.len() != config.host.len() {
+        return Err(Error::config("invalid number of ports".into()));
+    }
+
+    let mut error = None;
+    for (i, host) in config.host.iter().enumerate() {
+        let port = config
+            .port
+            .get(i)
+            .or_else(|| config.port.first())
+            .copied()
+            .unwrap_or(5432);
+
+        let hostname = match host {
+            Host::Tcp(host) => host.as_str(),
+        };
+
+        let tls = tls
+            .make_tls_connect(hostname)
+            .map_err(|e| Error::tls(e.into()))?;
+
+        match connect_once(host, port, tls, config).await {
+            Ok((client, connection)) => return Ok((client, connection)),
+            Err(e) => error = Some(e),
+        }
+    }
+
+    Err(error.unwrap())
+}
+
+async fn connect_once<T>(
+    host: &Host,
+    port: u16,
+    tls: T,
+    config: &Config,
+) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
+where
+    T: TlsConnect<TcpStream>,
+{
+    let socket = connect_socket(host, port, config.connect_timeout).await?;
+    let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
+
+    if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
+        let rows = client.simple_query_raw("SHOW transaction_read_only");
+        pin_mut!(rows);
+
+        let rows = future::poll_fn(|cx| {
+            if connection.poll_unpin(cx)?.is_ready() {
+                return Poll::Ready(Err(Error::closed()));
+            }
+
+            rows.as_mut().poll(cx)
+        })
+        .await?;
+        pin_mut!(rows);
+
+        loop {
+            let next = future::poll_fn(|cx| {
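+                // The Connection future owns the socket, so it has to be
+                // polled here alongside the row stream; otherwise the query
+                // below would never make progress.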
+                if connection.poll_unpin(cx)?.is_ready() {
+                    return Poll::Ready(Some(Err(Error::closed())));
+                }
+
+                rows.as_mut().poll_next(cx)
+            });
+
+            match next.await.transpose()? {
+                Some(SimpleQueryMessage::Row(row)) => {
+                    if row.try_get(0)? == Some("on") {
+                        return Err(Error::connect(io::Error::new(
+                            io::ErrorKind::PermissionDenied,
+                            "database does not allow writes",
+                        )));
+                    } else {
+                        break;
+                    }
+                }
+                Some(_) => {}
+                None => return Err(Error::unexpected_message()),
+            }
+        }
+    }
+
+    client.set_socket_config(SocketConfig {
+        host: host.clone(),
+        port,
+        connect_timeout: config.connect_timeout,
+    });
+
+    Ok((client, connection))
+}
diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs
new file mode 100644
index 0000000000..80677af969
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs
@@ -0,0 +1,359 @@
+use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
+use crate::config::{self, AuthKeys, Config, ReplicationMode};
+use crate::connect_tls::connect_tls;
+use crate::maybe_tls_stream::MaybeTlsStream;
+use crate::tls::{TlsConnect, TlsStream};
+use crate::{Client, Connection, Error};
+use bytes::BytesMut;
+use fallible_iterator::FallibleIterator;
+use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
+use postgres_protocol2::authentication;
+use postgres_protocol2::authentication::sasl;
+use postgres_protocol2::authentication::sasl::ScramSha256;
+use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
+use postgres_protocol2::message::frontend;
+use std::collections::{HashMap, VecDeque};
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::sync::mpsc;
+use tokio_util::codec::Framed;
+
+pub struct StartupStream<S, T> {
+    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
+    buf: BackendMessages,
+    delayed: VecDeque<BackendMessage>,
+}
+
+impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    type Error = io::Error;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.inner).poll_ready(cx)
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
+        Pin::new(&mut self.inner).start_send(item)
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.inner).poll_flush(cx)
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.inner).poll_close(cx)
+    }
+}
+
+impl<S, T> Stream for StartupStream<S, T>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    type Item = io::Result<Message>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<io::Result<Message>>> {
+        loop {
+            match self.buf.next() {
+                Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
+                Ok(None) => {}
+                Err(e) => return Poll::Ready(Some(Err(e))),
+            }
+
+            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
+                Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
+                Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                None => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+pub async fn connect_raw<S, T>(
+    stream: S,
+    tls: T,
+    config: &Config,
+) -> Result<(Client, Connection<S, T::Stream>), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: TlsConnect<S>,
+{
+    let stream = connect_tls(stream, config.ssl_mode, tls).await?;
+
+    let mut stream = StartupStream {
+        inner: Framed::new(
+            stream,
+            PostgresCodec {
+                max_message_size: config.max_backend_message_size,
+            },
+        ),
+        buf: BackendMessages::empty(),
+        delayed: VecDeque::new(),
+    };
+
+    startup(&mut stream, config).await?;
+    authenticate(&mut stream, config).await?;
+    let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
+
+    let (sender, receiver) = mpsc::unbounded_channel();
+    let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
+    let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
+
+    Ok((client, connection))
+}
+
+async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    let mut params = vec![("client_encoding", "UTF8")];
+    if let Some(user) = &config.user {
+        params.push(("user", &**user));
+    }
+    if let Some(dbname) = &config.dbname {
+        params.push(("database", &**dbname));
+    }
+    if let Some(options) = &config.options {
+        params.push(("options", &**options));
+    }
+    if let Some(application_name) = &config.application_name {
+        params.push(("application_name", &**application_name));
+    }
+    if let Some(replication_mode) = &config.replication_mode {
+        match replication_mode {
+            ReplicationMode::Physical => params.push(("replication", "true")),
+            ReplicationMode::Logical => params.push(("replication", "database")),
+        }
+    }
+
+    let mut buf = BytesMut::new();
+    frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
+
+    stream
+        .send(FrontendMessage::Raw(buf.freeze()))
+        .await
+        .map_err(Error::io)
+}
+
+async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: TlsStream + Unpin,
+{
+    match stream.try_next().await.map_err(Error::io)? {
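+        // Dispatch on the server's first authentication message; e.g. the
+        // cleartext exchange is AuthenticationCleartextPassword ->
+        // PasswordMessage -> AuthenticationOk.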
+        Some(Message::AuthenticationOk) => {
+            can_skip_channel_binding(config)?;
+            return Ok(());
+        }
+        Some(Message::AuthenticationCleartextPassword) => {
+            can_skip_channel_binding(config)?;
+
+            let pass = config
+                .password
+                .as_ref()
+                .ok_or_else(|| Error::config("password missing".into()))?;
+
+            authenticate_password(stream, pass).await?;
+        }
+        Some(Message::AuthenticationMd5Password(body)) => {
+            can_skip_channel_binding(config)?;
+
+            let user = config
+                .user
+                .as_ref()
+                .ok_or_else(|| Error::config("user missing".into()))?;
+            let pass = config
+                .password
+                .as_ref()
+                .ok_or_else(|| Error::config("password missing".into()))?;
+
+            let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
+            authenticate_password(stream, output.as_bytes()).await?;
+        }
+        Some(Message::AuthenticationSasl(body)) => {
+            authenticate_sasl(stream, body, config).await?;
+        }
+        Some(Message::AuthenticationKerberosV5)
+        | Some(Message::AuthenticationScmCredential)
+        | Some(Message::AuthenticationGss)
+        | Some(Message::AuthenticationSspi) => {
+            return Err(Error::authentication(
+                "unsupported authentication method".into(),
+            ))
+        }
+        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
+        Some(_) => return Err(Error::unexpected_message()),
+        None => return Err(Error::closed()),
+    }
+
+    match stream.try_next().await.map_err(Error::io)? {
+        Some(Message::AuthenticationOk) => Ok(()),
+        Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
+        Some(_) => Err(Error::unexpected_message()),
+        None => Err(Error::closed()),
+    }
+}
+
+fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
+    match config.channel_binding {
+        config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
+        config::ChannelBinding::Require => Err(Error::authentication(
+            "server did not use channel binding".into(),
+        )),
+    }
+}
+
+async fn authenticate_password<S, T>(
+    stream: &mut StartupStream<S, T>,
+    password: &[u8],
+) -> Result<(), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    let mut buf = BytesMut::new();
+    frontend::password_message(password, &mut buf).map_err(Error::encode)?;
+
+    stream
+        .send(FrontendMessage::Raw(buf.freeze()))
+        .await
+        .map_err(Error::io)
+}
+
+async fn authenticate_sasl<S, T>(
+    stream: &mut StartupStream<S, T>,
+    body: AuthenticationSaslBody,
+    config: &Config,
+) -> Result<(), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: TlsStream + Unpin,
+{
+    let mut has_scram = false;
+    let mut has_scram_plus = false;
+    let mut mechanisms = body.mechanisms();
+    while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
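+        // The server lists the SASL mechanisms it accepts; only
+        // SCRAM-SHA-256 and its channel-binding variant are recognized here.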
+        match mechanism {
+            sasl::SCRAM_SHA_256 => has_scram = true,
+            sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
+            _ => {}
+        }
+    }
+
+    let channel_binding = stream
+        .inner
+        .get_ref()
+        .channel_binding()
+        .tls_server_end_point
+        .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
+        .map(sasl::ChannelBinding::tls_server_end_point);
+
+    let (channel_binding, mechanism) = if has_scram_plus {
+        match channel_binding {
+            Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
+            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
+        }
+    } else if has_scram {
+        match channel_binding {
+            Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
+            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
+        }
+    } else {
+        return Err(Error::authentication("unsupported SASL mechanism".into()));
+    };
+
+    if mechanism != sasl::SCRAM_SHA_256_PLUS {
+        can_skip_channel_binding(config)?;
+    }
+
+    let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
+        ScramSha256::new_with_keys(keys, channel_binding)
+    } else if let Some(password) = config.get_password() {
+        ScramSha256::new(password, channel_binding)
+    } else {
+        return Err(Error::config("password or auth keys missing".into()));
+    };
+
+    let mut buf = BytesMut::new();
+    frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
+    stream
+        .send(FrontendMessage::Raw(buf.freeze()))
+        .await
+        .map_err(Error::io)?;
+
+    let body = match stream.try_next().await.map_err(Error::io)? {
+        Some(Message::AuthenticationSaslContinue(body)) => body,
+        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
+        Some(_) => return Err(Error::unexpected_message()),
+        None => return Err(Error::closed()),
+    };
+
+    scram
+        .update(body.data())
+        .await
+        .map_err(|e| Error::authentication(e.into()))?;
+
+    let mut buf = BytesMut::new();
+    frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
+    stream
+        .send(FrontendMessage::Raw(buf.freeze()))
+        .await
+        .map_err(Error::io)?;
+
+    let body = match stream.try_next().await.map_err(Error::io)? {
+        Some(Message::AuthenticationSaslFinal(body)) => body,
+        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
+        Some(_) => return Err(Error::unexpected_message()),
+        None => return Err(Error::closed()),
+    };
+
+    scram
+        .finish(body.data())
+        .map_err(|e| Error::authentication(e.into()))?;
+
+    Ok(())
+}
+
+async fn read_info<S, T>(
+    stream: &mut StartupStream<S, T>,
+) -> Result<(i32, i32, HashMap<String, String>), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    let mut process_id = 0;
+    let mut secret_key = 0;
+    let mut parameters = HashMap::new();
+
+    loop {
+        match stream.try_next().await.map_err(Error::io)? {
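+            // Drain the remaining startup messages until ReadyForQuery:
+            // BackendKeyData carries the (process_id, secret_key) pair used
+            // for query cancellation, and each ParameterStatus contributes
+            // one server setting (e.g. server_version).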
+            Some(Message::BackendKeyData(body)) => {
+                process_id = body.process_id();
+                secret_key = body.secret_key();
+            }
+            Some(Message::ParameterStatus(body)) => {
+                parameters.insert(
+                    body.name().map_err(Error::parse)?.to_string(),
+                    body.value().map_err(Error::parse)?.to_string(),
+                );
+            }
+            Some(msg @ Message::NoticeResponse(_)) => {
+                stream.delayed.push_back(BackendMessage::Async(msg))
+            }
+            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
+            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
+            Some(_) => return Err(Error::unexpected_message()),
+            None => return Err(Error::closed()),
+        }
+    }
+}
diff --git a/libs/proxy/tokio-postgres2/src/connect_socket.rs b/libs/proxy/tokio-postgres2/src/connect_socket.rs
new file mode 100644
index 0000000000..336a13317f
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/connect_socket.rs
@@ -0,0 +1,65 @@
+use crate::config::Host;
+use crate::Error;
+use std::future::Future;
+use std::io;
+use std::time::Duration;
+use tokio::net::{self, TcpStream};
+use tokio::time;
+
+pub(crate) async fn connect_socket(
+    host: &Host,
+    port: u16,
+    connect_timeout: Option<Duration>,
+) -> Result<TcpStream, Error> {
+    match host {
+        Host::Tcp(host) => {
+            let addrs = net::lookup_host((&**host, port))
+                .await
+                .map_err(Error::connect)?;
+
+            let mut last_err = None;
+
+            for addr in addrs {
+                let stream =
+                    match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
+                        Ok(stream) => stream,
+                        Err(e) => {
+                            last_err = Some(e);
+                            continue;
+                        }
+                    };
+
+                stream.set_nodelay(true).map_err(Error::connect)?;
+
+                return Ok(stream);
+            }
+
+            Err(last_err.unwrap_or_else(|| {
+                Error::connect(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    "could not resolve any addresses",
+                ))
+            }))
+        }
+    }
+}
+
+async fn connect_with_timeout<F>(connect: F, timeout: Option<Duration>) -> Result<TcpStream, Error>
+where
+    F: Future<Output = io::Result<TcpStream>>,
+{
+    match timeout {
+        Some(timeout) => match time::timeout(timeout, connect).await {
+            Ok(Ok(socket)) => Ok(socket),
+            Ok(Err(e)) => Err(Error::connect(e)),
+            Err(_) => Err(Error::connect(io::Error::new(
+                io::ErrorKind::TimedOut,
+                "connection timed out",
+            ))),
+        },
+        None => match connect.await {
+            Ok(socket) => Ok(socket),
+            Err(e) => Err(Error::connect(e)),
+        },
+    }
+}
diff --git a/libs/proxy/tokio-postgres2/src/connect_tls.rs b/libs/proxy/tokio-postgres2/src/connect_tls.rs
new file mode 100644
index 0000000000..64b0b68abc
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/connect_tls.rs
@@ -0,0 +1,48 @@
+use crate::config::SslMode;
+use crate::maybe_tls_stream::MaybeTlsStream;
+use crate::tls::private::ForcePrivateApi;
+use crate::tls::TlsConnect;
+use crate::Error;
+use bytes::BytesMut;
+use postgres_protocol2::message::frontend;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+pub async fn connect_tls<S, T>(
+    mut stream: S,
+    mode: SslMode,
+    tls: T,
+) -> Result<MaybeTlsStream<S, T::Stream>, Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: TlsConnect<S>,
+{
+    match mode {
+        SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
+        SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
+            return Ok(MaybeTlsStream::Raw(stream))
+        }
+        SslMode::Prefer | SslMode::Require => {}
+    }
+
+    let mut buf = BytesMut::new();
+    frontend::ssl_request(&mut buf);
+    stream.write_all(&buf).await.map_err(Error::io)?;
+
+    let mut buf = [0];
+    stream.read_exact(&mut buf).await.map_err(Error::io)?;
+
+    if buf[0] != b'S' {
+        if SslMode::Require == mode {
+            return Err(Error::tls("server does not support TLS".into()));
+        } else {
+            return Ok(MaybeTlsStream::Raw(stream));
+        }
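+        // The SSLRequest above is an 8-byte packet (length 8, code 80877103);
+        // the server answers with one raw byte, 'S' to start TLS or 'N' to
+        // continue in cleartext.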
+    }
+
+    let stream = tls
+        .connect(stream)
+        .await
+        .map_err(|e| Error::tls(e.into()))?;
+
+    Ok(MaybeTlsStream::Tls(stream))
+}
diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs
new file mode 100644
index 0000000000..0aa5c77e22
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/connection.rs
@@ -0,0 +1,323 @@
+use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
+use crate::error::DbError;
+use crate::maybe_tls_stream::MaybeTlsStream;
+use crate::{AsyncMessage, Error, Notification};
+use bytes::BytesMut;
+use fallible_iterator::FallibleIterator;
+use futures_util::{ready, Sink, Stream};
+use log::{info, trace};
+use postgres_protocol2::message::backend::Message;
+use postgres_protocol2::message::frontend;
+use std::collections::{HashMap, VecDeque};
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::sync::mpsc;
+use tokio_util::codec::Framed;
+use tokio_util::sync::PollSender;
+
+pub enum RequestMessages {
+    Single(FrontendMessage),
+}
+
+pub struct Request {
+    pub messages: RequestMessages,
+    pub sender: mpsc::Sender<BackendMessages>,
+}
+
+pub struct Response {
+    sender: PollSender<BackendMessages>,
+}
+
+#[derive(PartialEq, Debug)]
+enum State {
+    Active,
+    Terminating,
+    Closing,
+}
+
+/// A connection to a PostgreSQL database.
+///
+/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
+/// server, and should generally be spawned off onto an executor to run in the background.
+///
+/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
+/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
+#[must_use = "futures do nothing unless polled"]
+pub struct Connection<S, T> {
+    /// HACK: we need this in the Neon Proxy.
+    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
+    /// HACK: we need this in the Neon Proxy to forward params.
+    pub parameters: HashMap<String, String>,
+    receiver: mpsc::UnboundedReceiver<Request>,
+    pending_request: Option<RequestMessages>,
+    pending_responses: VecDeque<BackendMessage>,
+    responses: VecDeque<Response>,
+    state: State,
+}
+
+impl<S, T> Connection<S, T>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    pub(crate) fn new(
+        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
+        pending_responses: VecDeque<BackendMessage>,
+        parameters: HashMap<String, String>,
+        receiver: mpsc::UnboundedReceiver<Request>,
+    ) -> Connection<S, T> {
+        Connection {
+            stream,
+            parameters,
+            receiver,
+            pending_request: None,
+            pending_responses,
+            responses: VecDeque::new(),
+            state: State::Active,
+        }
+    }
+
+    fn poll_response(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<BackendMessage, Error>>> {
+        if let Some(message) = self.pending_responses.pop_front() {
+            trace!("retrying pending response");
+            return Poll::Ready(Some(Ok(message)));
+        }
+
+        Pin::new(&mut self.stream)
+            .poll_next(cx)
+            .map(|o| o.map(|r| r.map_err(Error::io)))
+    }
+
+    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
+        if self.state != State::Active {
+            trace!("poll_read: done");
+            return Ok(None);
+        }
+
+        loop {
+            let message = match self.poll_response(cx)? {
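+                // Each item is either an asynchronous message (notice,
+                // notification, parameter change) surfaced from this method,
+                // or a batch of response messages routed to the matching
+                // in-flight Response below.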
+                Poll::Ready(Some(message)) => message,
+                Poll::Ready(None) => return Err(Error::closed()),
+                Poll::Pending => {
+                    trace!("poll_read: waiting on response");
+                    return Ok(None);
+                }
+            };
+
+            let (mut messages, request_complete) = match message {
+                BackendMessage::Async(Message::NoticeResponse(body)) => {
+                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
+                    return Ok(Some(AsyncMessage::Notice(error)));
+                }
+                BackendMessage::Async(Message::NotificationResponse(body)) => {
+                    let notification = Notification {
+                        process_id: body.process_id(),
+                        channel: body.channel().map_err(Error::parse)?.to_string(),
+                        payload: body.message().map_err(Error::parse)?.to_string(),
+                    };
+                    return Ok(Some(AsyncMessage::Notification(notification)));
+                }
+                BackendMessage::Async(Message::ParameterStatus(body)) => {
+                    self.parameters.insert(
+                        body.name().map_err(Error::parse)?.to_string(),
+                        body.value().map_err(Error::parse)?.to_string(),
+                    );
+                    continue;
+                }
+                BackendMessage::Async(_) => unreachable!(),
+                BackendMessage::Normal {
+                    messages,
+                    request_complete,
+                } => (messages, request_complete),
+            };
+
+            let mut response = match self.responses.pop_front() {
+                Some(response) => response,
+                None => match messages.next().map_err(Error::parse)? {
+                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
+                    _ => return Err(Error::unexpected_message()),
+                },
+            };
+
+            match response.sender.poll_reserve(cx) {
+                Poll::Ready(Ok(())) => {
+                    let _ = response.sender.send_item(messages);
+                    if !request_complete {
+                        self.responses.push_front(response);
+                    }
+                }
+                Poll::Ready(Err(_)) => {
+                    // we need to keep paging through the rest of the messages even if the receiver's hung up
+                    if !request_complete {
+                        self.responses.push_front(response);
+                    }
+                }
+                Poll::Pending => {
+                    self.responses.push_front(response);
+                    self.pending_responses.push_back(BackendMessage::Normal {
+                        messages,
+                        request_complete,
+                    });
+                    trace!("poll_read: waiting on sender");
+                    return Ok(None);
+                }
+            }
+        }
+    }
+
+    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
+        if let Some(messages) = self.pending_request.take() {
+            trace!("retrying pending request");
+            return Poll::Ready(Some(messages));
+        }
+
+        if self.receiver.is_closed() {
+            return Poll::Ready(None);
+        }
+
+        match self.receiver.poll_recv(cx) {
+            Poll::Ready(Some(request)) => {
+                trace!("polled new request");
+                self.responses.push_back(Response {
+                    sender: PollSender::new(request.sender),
+                });
+                Poll::Ready(Some(request.messages))
+            }
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
+        loop {
+            if self.state == State::Closing {
+                trace!("poll_write: done");
+                return Ok(false);
+            }
+
+            if Pin::new(&mut self.stream)
+                .poll_ready(cx)
+                .map_err(Error::io)?
+                .is_pending()
+            {
+                trace!("poll_write: waiting on socket");
+                return Ok(false);
+            }
+
+            let request = match self.poll_request(cx) {
+                Poll::Ready(Some(request)) => request,
+                Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
+                    trace!("poll_write: at eof, terminating");
+                    self.state = State::Terminating;
+                    let mut request = BytesMut::new();
+                    frontend::terminate(&mut request);
+                    RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
+                }
+                Poll::Ready(None) => {
+                    trace!(
+                        "poll_write: at eof, pending responses {}",
+                        self.responses.len()
+                    );
+                    return Ok(true);
+                }
+                Poll::Pending => {
+                    trace!("poll_write: waiting on request");
+                    return Ok(true);
+                }
+            };
+
+            match request {
+                RequestMessages::Single(request) => {
+                    Pin::new(&mut self.stream)
+                        .start_send(request)
+                        .map_err(Error::io)?;
+                    if self.state == State::Terminating {
+                        trace!("poll_write: sent eof, closing");
+                        self.state = State::Closing;
+                    }
+                }
+            }
+        }
+    }
+
+    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
+        match Pin::new(&mut self.stream)
+            .poll_flush(cx)
+            .map_err(Error::io)?
+        {
+            Poll::Ready(()) => trace!("poll_flush: flushed"),
+            Poll::Pending => trace!("poll_flush: waiting on socket"),
+        }
+        Ok(())
+    }
+
+    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        if self.state != State::Closing {
+            return Poll::Pending;
+        }
+
+        match Pin::new(&mut self.stream)
+            .poll_close(cx)
+            .map_err(Error::io)?
+        {
+            Poll::Ready(()) => {
+                trace!("poll_shutdown: complete");
+                Poll::Ready(Ok(()))
+            }
+            Poll::Pending => {
+                trace!("poll_shutdown: waiting on socket");
+                Poll::Pending
+            }
+        }
+    }
+
+    /// Returns the value of a runtime parameter for this connection.
+    pub fn parameter(&self, name: &str) -> Option<&str> {
+        self.parameters.get(name).map(|s| &**s)
+    }
+
+    /// Polls for asynchronous messages from the server.
+    ///
+    /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
+    /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
+    pub fn poll_message(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<AsyncMessage, Error>>> {
+        let message = self.poll_read(cx)?;
+        let want_flush = self.poll_write(cx)?;
+        if want_flush {
+            self.poll_flush(cx)?;
+        }
+        match message {
+            Some(message) => Poll::Ready(Some(Ok(message))),
+            None => match self.poll_shutdown(cx) {
+                Poll::Ready(Ok(())) => Poll::Ready(None),
+                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
+                Poll::Pending => Poll::Pending,
+            },
+        }
+    }
+}
+
+impl<S, T> Future for Connection<S, T>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    type Output = Result<(), Error>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        while let Some(message) = ready!(self.poll_message(cx)?) {
+            if let AsyncMessage::Notice(notice) = message {
+                info!("{}: {}", notice.severity(), notice.message());
+            }
+        }
+        Poll::Ready(Ok(()))
+    }
+}
diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs
new file mode 100644
index 0000000000..6514322250
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/error/mod.rs
@@ -0,0 +1,501 @@
+//! Errors.
+
+use fallible_iterator::FallibleIterator;
+use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
+use std::error::{self, Error as _Error};
+use std::fmt;
+use std::io;
+
+pub use self::sqlstate::*;
+
+#[allow(clippy::unreadable_literal)]
+mod sqlstate;
+
+/// The severity of a Postgres error or notice.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum Severity {
+    /// PANIC
+    Panic,
+    /// FATAL
+    Fatal,
+    /// ERROR
+    Error,
+    /// WARNING
+    Warning,
+    /// NOTICE
+    Notice,
+    /// DEBUG
+    Debug,
+    /// INFO
+    Info,
+    /// LOG
+    Log,
+}
+
+impl fmt::Display for Severity {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let s = match *self {
+            Severity::Panic => "PANIC",
+            Severity::Fatal => "FATAL",
+            Severity::Error => "ERROR",
+            Severity::Warning => "WARNING",
+            Severity::Notice => "NOTICE",
+            Severity::Debug => "DEBUG",
+            Severity::Info => "INFO",
+            Severity::Log => "LOG",
+        };
+        fmt.write_str(s)
+    }
+}
+
+impl Severity {
+    fn from_str(s: &str) -> Option<Severity> {
+        match s {
+            "PANIC" => Some(Severity::Panic),
+            "FATAL" => Some(Severity::Fatal),
+            "ERROR" => Some(Severity::Error),
+            "WARNING" => Some(Severity::Warning),
+            "NOTICE" => Some(Severity::Notice),
+            "DEBUG" => Some(Severity::Debug),
+            "INFO" => Some(Severity::Info),
+            "LOG" => Some(Severity::Log),
+            _ => None,
+        }
+    }
+}
+
+/// A Postgres error or notice.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct DbError {
+    severity: String,
+    parsed_severity: Option<Severity>,
+    code: SqlState,
+    message: String,
+    detail: Option<String>,
+    hint: Option<String>,
+    position: Option<ErrorPosition>,
+    where_: Option<String>,
+    schema: Option<String>,
+    table: Option<String>,
+    column: Option<String>,
+    datatype: Option<String>,
+    constraint: Option<String>,
+    file: Option<String>,
+    line: Option<u32>,
+    routine: Option<String>,
+}
+
+impl DbError {
+    pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
+        let mut severity = None;
+        let mut parsed_severity = None;
+        let mut code = None;
+        let mut message = None;
+        let mut detail = None;
+        let mut hint = None;
+        let mut normal_position = None;
+        let mut internal_position = None;
+        let mut internal_query = None;
+        let mut where_ = None;
+        let mut schema = None;
+        let mut table = None;
+        let mut column = None;
+        let mut datatype = None;
+        let mut constraint = None;
+        let mut file = None;
+        let mut line = None;
+        let mut routine = None;
+
+        while let Some(field) = fields.next()? {
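+            // Each error/notice field carries a one-byte protocol tag, e.g.
+            // 'S' severity, 'C' SQLSTATE, 'M' message; unknown tags fall
+            // through to the catch-all arm and are ignored.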
+            match field.type_() {
+                b'S' => severity = Some(field.value().to_owned()),
+                b'C' => code = Some(SqlState::from_code(field.value())),
+                b'M' => message = Some(field.value().to_owned()),
+                b'D' => detail = Some(field.value().to_owned()),
+                b'H' => hint = Some(field.value().to_owned()),
+                b'P' => {
+                    normal_position = Some(field.value().parse::<u32>().map_err(|_| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            "`P` field did not contain an integer",
+                        )
+                    })?);
+                }
+                b'p' => {
+                    internal_position = Some(field.value().parse::<u32>().map_err(|_| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            "`p` field did not contain an integer",
+                        )
+                    })?);
+                }
+                b'q' => internal_query = Some(field.value().to_owned()),
+                b'W' => where_ = Some(field.value().to_owned()),
+                b's' => schema = Some(field.value().to_owned()),
+                b't' => table = Some(field.value().to_owned()),
+                b'c' => column = Some(field.value().to_owned()),
+                b'd' => datatype = Some(field.value().to_owned()),
+                b'n' => constraint = Some(field.value().to_owned()),
+                b'F' => file = Some(field.value().to_owned()),
+                b'L' => {
+                    line = Some(field.value().parse::<u32>().map_err(|_| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            "`L` field did not contain an integer",
+                        )
+                    })?);
+                }
+                b'R' => routine = Some(field.value().to_owned()),
+                b'V' => {
+                    parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            "`V` field contained an invalid value",
+                        )
+                    })?);
+                }
+                _ => {}
+            }
+        }
+
+        Ok(DbError {
+            severity: severity
+                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
+            parsed_severity,
+            code: code
+                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
+            message: message
+                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
+            detail,
+            hint,
+            position: match normal_position {
+                Some(position) => Some(ErrorPosition::Original(position)),
+                None => match internal_position {
+                    Some(position) => Some(ErrorPosition::Internal {
+                        position,
+                        query: internal_query.ok_or_else(|| {
+                            io::Error::new(
+                                io::ErrorKind::InvalidInput,
+                                "`q` field missing but `p` field present",
+                            )
+                        })?,
+                    }),
+                    None => None,
+                },
+            },
+            where_,
+            schema,
+            table,
+            column,
+            datatype,
+            constraint,
+            file,
+            line,
+            routine,
+        })
+    }
+
+    /// The field contents are ERROR, FATAL, or PANIC (in an error message),
+    /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
+    /// localized translation of one of these.
+    pub fn severity(&self) -> &str {
+        &self.severity
+    }
+
+    /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
+    pub fn parsed_severity(&self) -> Option<Severity> {
+        self.parsed_severity
+    }
+
+    /// The SQLSTATE code for the error.
+    pub fn code(&self) -> &SqlState {
+        &self.code
+    }
+
+    /// The primary human-readable error message.
+    ///
+    /// This should be accurate but terse (typically one line).
+    pub fn message(&self) -> &str {
+        &self.message
+    }
+
+    /// An optional secondary error message carrying more detail about the
+    /// problem.
+    ///
+    /// Might run to multiple lines.
+    pub fn detail(&self) -> Option<&str> {
+        self.detail.as_deref()
+    }
+
+    /// An optional suggestion what to do about the problem.
+    ///
+    /// This is intended to differ from `detail` in that it offers advice
+    /// (potentially inappropriate) rather than hard facts. Might run to
+    /// multiple lines.
+    pub fn hint(&self) -> Option<&str> {
+        self.hint.as_deref()
+    }
+
+    /// An optional error cursor position into either the original query string
+    /// or an internally generated query.
+    pub fn position(&self) -> Option<&ErrorPosition> {
+        self.position.as_ref()
+    }
+
+    /// An indication of the context in which the error occurred.
+    ///
+    /// Presently this includes a call stack traceback of active procedural
+    /// language functions and internally-generated queries. The trace is one
+    /// entry per line, most recent first.
+    pub fn where_(&self) -> Option<&str> {
+        self.where_.as_deref()
+    }
+
+    /// If the error was associated with a specific database object, the name
+    /// of the schema containing that object, if any. (PostgreSQL 9.3+)
+    pub fn schema(&self) -> Option<&str> {
+        self.schema.as_deref()
+    }
+
+    /// If the error was associated with a specific table, the name of the
+    /// table. (Refer to the schema name field for the name of the table's
+    /// schema.) (PostgreSQL 9.3+)
+    pub fn table(&self) -> Option<&str> {
+        self.table.as_deref()
+    }
+
+    /// If the error was associated with a specific table column, the name of
+    /// the column.
+    ///
+    /// (Refer to the schema and table name fields to identify the table.)
+    /// (PostgreSQL 9.3+)
+    pub fn column(&self) -> Option<&str> {
+        self.column.as_deref()
+    }
+
+    /// If the error was associated with a specific data type, the name of the
+    /// data type. (Refer to the schema name field for the name of the data
+    /// type's schema.) (PostgreSQL 9.3+)
+    pub fn datatype(&self) -> Option<&str> {
+        self.datatype.as_deref()
+    }
+
+    /// If the error was associated with a specific constraint, the name of the
+    /// constraint.
+    ///
+    /// Refer to fields listed above for the associated table or domain.
+    /// (For this purpose, indexes are treated as constraints, even if they
+    /// weren't created with constraint syntax.) (PostgreSQL 9.3+)
+    pub fn constraint(&self) -> Option<&str> {
+        self.constraint.as_deref()
+    }
+
+    /// The file name of the source-code location where the error was reported.
+    pub fn file(&self) -> Option<&str> {
+        self.file.as_deref()
+    }
+
+    /// The line number of the source-code location where the error was
+    /// reported.
+    pub fn line(&self) -> Option<u32> {
+        self.line
+    }
+
+    /// The name of the source-code routine reporting the error.
+    pub fn routine(&self) -> Option<&str> {
+        self.routine.as_deref()
+    }
+}
+
+impl fmt::Display for DbError {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(fmt, "{}: {}", self.severity, self.message)?;
+        if let Some(detail) = &self.detail {
+            write!(fmt, "\nDETAIL: {}", detail)?;
+        }
+        if let Some(hint) = &self.hint {
+            write!(fmt, "\nHINT: {}", hint)?;
+        }
+        Ok(())
+    }
+}
+
+impl error::Error for DbError {}
+
+/// Represents the position of an error in a query.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum ErrorPosition {
+    /// A position in the original query.
+    Original(u32),
+    /// A position in an internally generated query.
+    Internal {
+        /// The byte position.
+        position: u32,
+        /// A query generated by the Postgres server.
+        query: String,
+    },
+}
+
+#[derive(Debug, PartialEq)]
+enum Kind {
+    Io,
+    UnexpectedMessage,
+    Tls,
+    ToSql(usize),
+    FromSql(usize),
+    Column(String),
+    Closed,
+    Db,
+    Parse,
+    Encode,
+    Authentication,
+    ConfigParse,
+    Config,
+    Connect,
+    Timeout,
+}
+
+struct ErrorInner {
+    kind: Kind,
+    cause: Option<Box<dyn error::Error + Sync + Send>>,
+}
+
+/// An error communicating with the Postgres server.
+pub struct Error(Box<ErrorInner>);
+
+impl fmt::Debug for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt.debug_struct("Error")
+            .field("kind", &self.0.kind)
+            .field("cause", &self.0.cause)
+            .finish()
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match &self.0.kind {
+            Kind::Io => fmt.write_str("error communicating with the server")?,
+            Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
+            Kind::Tls => fmt.write_str("error performing TLS handshake")?,
+            Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
+            Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
+            Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
+            Kind::Closed => fmt.write_str("connection closed")?,
+            Kind::Db => fmt.write_str("db error")?,
+            Kind::Parse => fmt.write_str("error parsing response from server")?,
+            Kind::Encode => fmt.write_str("error encoding message to server")?,
+            Kind::Authentication => fmt.write_str("authentication error")?,
+            Kind::ConfigParse => fmt.write_str("invalid connection string")?,
+            Kind::Config => fmt.write_str("invalid configuration")?,
+            Kind::Connect => fmt.write_str("error connecting to server")?,
+            Kind::Timeout => fmt.write_str("timeout waiting for server")?,
+        };
+        if let Some(ref cause) = self.0.cause {
+            write!(fmt, ": {}", cause)?;
+        }
+        Ok(())
+    }
+}
+
+impl error::Error for Error {
+    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
+        self.0.cause.as_ref().map(|e| &**e as _)
+    }
+}
+
+impl Error {
+    /// Consumes the error, returning its cause.
+    pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
+        self.0.cause
+    }
+
+    /// Returns the source of this error if it was a `DbError`.
+    ///
+    /// This is a simple convenience method.
+    pub fn as_db_error(&self) -> Option<&DbError> {
+        self.source().and_then(|e| e.downcast_ref::<DbError>())
+    }
+
+    /// Determines if the error was associated with closed connection.
+    pub fn is_closed(&self) -> bool {
+        self.0.kind == Kind::Closed
+    }
+
+    /// Returns the SQLSTATE error code associated with the error.
+    ///
+    /// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
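+    // e.g. retrying only serialization failures (sketch; the retry policy is
+    // up to the caller):
+    //   if e.code() == Some(&SqlState::T_R_SERIALIZATION_FAILURE) { /* retry */ }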
+    pub fn code(&self) -> Option<&SqlState> {
+        self.as_db_error().map(DbError::code)
+    }
+
+    fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
+        Error(Box::new(ErrorInner { kind, cause }))
+    }
+
+    pub(crate) fn closed() -> Error {
+        Error::new(Kind::Closed, None)
+    }
+
+    pub(crate) fn unexpected_message() -> Error {
+        Error::new(Kind::UnexpectedMessage, None)
+    }
+
+    #[allow(clippy::needless_pass_by_value)]
+    pub(crate) fn db(error: ErrorResponseBody) -> Error {
+        match DbError::parse(&mut error.fields()) {
+            Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
+            Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
+        }
+    }
+
+    pub(crate) fn parse(e: io::Error) -> Error {
+        Error::new(Kind::Parse, Some(Box::new(e)))
+    }
+
+    pub(crate) fn encode(e: io::Error) -> Error {
+        Error::new(Kind::Encode, Some(Box::new(e)))
+    }
+
+    #[allow(clippy::wrong_self_convention)]
+    pub(crate) fn to_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
+        Error::new(Kind::ToSql(idx), Some(e))
+    }
+
+    pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
+        Error::new(Kind::FromSql(idx), Some(e))
+    }
+
+    pub(crate) fn column(column: String) -> Error {
+        Error::new(Kind::Column(column), None)
+    }
+
+    pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
+        Error::new(Kind::Tls, Some(e))
+    }
+
+    pub(crate) fn io(e: io::Error) -> Error {
+        Error::new(Kind::Io, Some(Box::new(e)))
+    }
+
+    pub(crate) fn authentication(e: Box<dyn error::Error + Sync + Send>) -> Error {
+        Error::new(Kind::Authentication, Some(e))
+    }
+
+    pub(crate) fn config_parse(e: Box<dyn error::Error + Sync + Send>) -> Error {
+        Error::new(Kind::ConfigParse, Some(e))
+    }
+
+    pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
+        Error::new(Kind::Config, Some(e))
+    }
+
+    pub(crate) fn connect(e: io::Error) -> Error {
+        Error::new(Kind::Connect, Some(Box::new(e)))
+    }
+
+    #[doc(hidden)]
+    pub fn __private_api_timeout() -> Error {
+        Error::new(Kind::Timeout, None)
+    }
+}
diff --git a/libs/proxy/tokio-postgres2/src/error/sqlstate.rs b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs
new file mode 100644
index 0000000000..13a1d75f95
--- /dev/null
+++ b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs
@@ -0,0 +1,1670 @@
+// Autogenerated file - DO NOT EDIT
+
+/// A SQLSTATE error code
+#[derive(PartialEq, Eq, Clone, Debug)]
+pub struct SqlState(Inner);
+
+impl SqlState {
+    /// Creates a `SqlState` from its error code.
+    pub fn from_code(s: &str) -> SqlState {
+        match SQLSTATE_MAP.get(s) {
+            Some(state) => state.clone(),
+            None => SqlState(Inner::Other(s.into())),
+        }
+    }
+
+    /// Returns the error code corresponding to the `SqlState`.
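+    // e.g. SqlState::from_code("23505") == SqlState::UNIQUE_VIOLATION, and
+    // code() maps it back to "23505"; unrecognized codes round-trip through
+    // the Inner::Other variant.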
+ pub fn code(&self) -> &str { + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => "22016", + Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => "22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + 
Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", 
+ Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + Inner::Other(code) => code, + } + } + + /// 00000 + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); + + /// 01000 + pub const WARNING: SqlState = SqlState(Inner::E01000); + + /// 0100C + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); + + /// 01008 + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); + + /// 01003 + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); + + /// 01007 + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); + + /// 01006 + pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Inner::E01006); + + /// 01004 + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); + + /// 01P01 + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); + + /// 02000 + pub const NO_DATA: SqlState = SqlState(Inner::E02000); + + /// 02001 + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); + + /// 03000 + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); + + /// 08000 + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); + + /// 08003 + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); + + /// 08006 + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); + + /// 08001 + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); + + /// 08004 + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); + + /// 08007 + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); + + /// 08P01 + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); + + /// 09000 + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); + + /// 0A000 + pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); + + /// 0B000 + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); + + /// 0F000 + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); + + /// 0F001 + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); + + /// 0L000 + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); + + /// 0LP01 + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); + + /// 0P000 + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); + + /// 0Z000 + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); + + /// 0Z002 + pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); + + /// 20000 + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); + + /// 21000 + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); + + /// 22000 + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); + + /// 2202E + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 2202E + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 22021 + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); + + /// 22008 + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); + + /// 22008 + pub const 
DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); + + /// 22012 + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); + + /// 22005 + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); + + /// 2200B + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); + + /// 22022 + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); + + /// 22015 + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); + + /// 2201E + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); + + /// 22014 + pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); + + /// 22016 + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); + + /// 2201F + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); + + /// 2201G + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); + + /// 22018 + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); + + /// 22007 + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); + + /// 22019 + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); + + /// 2200D + pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Inner::E2200D); + + /// 22025 + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); + + /// 22P06 + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); + + /// 22010 + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); + + /// 22023 + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); + + /// 2201B + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); + + /// 2201W + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); + + /// 2201X + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); + + /// 2202H + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); + + /// 2202G + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); + + /// 22009 + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); + + /// 2200C + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); + + /// 2200G + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); + + /// 22004 + pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E22004); + + /// 22002 + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); + + /// 22003 + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); + + /// 2200H + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); + + /// 22026 + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); + + /// 22001 + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); + + /// 22011 + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); + + /// 22027 + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); + + /// 22024 + pub const UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); + + /// 2200F + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); + + /// 22P01 + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); + + /// 22P02 + pub const 
INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); + + /// 22P03 + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); + + /// 22P04 + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); + + /// 22P05 + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); + + /// 2200L + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); + + /// 2200M + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); + + /// 2200N + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); + + /// 2200S + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); + + /// 2200T + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = SqlState(Inner::E22032); + + /// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); + + /// 23000 + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); + + /// 23001 + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); + + /// 23502 + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); + + /// 23503 + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); + + /// 23505 + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); + + /// 23514 + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); + + /// 23P01 + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); + + /// 24000 + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); + + /// 25000 + pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Inner::E25000); + + /// 25001 + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); + + /// 25002 + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); + + /// 25008 + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); + + /// 25003 + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); + + /// 25004 + pub const 
INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); + + /// 25005 + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); + + /// 25006 + pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); + + /// 25007 + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); + + /// 25P01 + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); + + /// 25P02 + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); + + /// 25P03 + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); + + /// 26000 + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); + + /// 26000 + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); + + /// 27000 + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); + + /// 28000 + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); + + /// 28P01 + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); + + /// 2B000 + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); + + /// 2BP01 + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); + + /// 2D000 + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); + + /// 2F000 + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); + + /// 2F005 + pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); + + /// 2F002 + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); + + /// 2F003 + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); + + /// 2F004 + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); + + /// 34000 + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); + + /// 34000 + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); + + /// 38000 + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); + + /// 38001 + pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); + + /// 38002 + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); + + /// 38003 + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); + + /// 38004 + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); + + /// 39000 + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); + + /// 39001 + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); + + /// 39004 + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); + + /// 39P01 + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); + + /// 39P02 + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); + + /// 39P03 + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); + + /// 3B000 + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); + + /// 3B001 + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); + + /// 3D000 + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); + + /// 3D000 + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); + + /// 3F000 + pub const INVALID_SCHEMA_NAME: 
SqlState = SqlState(Inner::E3F000); + + /// 3F000 + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); + + /// 40000 + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); + + /// 40002 + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); + + /// 40001 + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); + + /// 40003 + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); + + /// 40P01 + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); + + /// 42000 + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Inner::E42000); + + /// 42601 + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); + + /// 42501 + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); + + /// 42846 + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); + + /// 42803 + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); + + /// 42P20 + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); + + /// 42P19 + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); + + /// 42830 + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); + + /// 42602 + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); + + /// 42622 + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); + + /// 42939 + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); + + /// 42804 + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); + + /// 42P18 + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); + + /// 42P21 + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); + + /// 42P22 + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); + + /// 42809 + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); + + /// 428C9 + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); + + /// 42703 + pub const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); + + /// 42883 + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); + + /// 42P01 + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); + + /// 42P02 + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); + + /// 42704 + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); + + /// 42701 + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); + + /// 42P03 + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); + + /// 42P04 + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); + + /// 42723 + pub const DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); + + /// 42P05 + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); + + /// 42P06 + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); + + /// 42P07 + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); + + /// 42712 + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); + + /// 42710 + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); + + /// 42702 + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); + + /// 42725 + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); + + /// 42P08 + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); + + /// 42P09 + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); + + /// 42P10 + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); + + /// 42611 + pub const INVALID_COLUMN_DEFINITION: 
SqlState = SqlState(Inner::E42611); + + /// 42P11 + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); + + /// 42P12 + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); + + /// 42P13 + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); + + /// 42P14 + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); + + /// 42P15 + pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); + + /// 42P16 + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); + + /// 42P17 + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); + + /// 44000 + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); + + /// 53000 + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); + + /// 53100 + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); + + /// 53200 + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); + + /// 53300 + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); + + /// 53400 + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); + + /// 54000 + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); + + /// 54001 + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); + + /// 54011 + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); + + /// 54023 + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); + + /// 55000 + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); + + /// 55006 + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); + + /// 55P02 + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); + + /// 55P03 + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); + + /// 55P04 + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); + + /// 57000 + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); + + /// 57014 + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); + + /// 57P01 + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); + + /// 57P02 + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); + + /// 57P03 + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); + + /// 57P04 + pub const DATABASE_DROPPED: SqlState = SqlState(Inner::E57P04); + + /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); + + /// 58000 + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); + + /// 58030 + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); + + /// 58P01 + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); + + /// 58P02 + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); + + /// 72000 + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); + + /// F0000 + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); + + /// F0001 + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); + + /// HV000 + pub const FDW_ERROR: SqlState = SqlState(Inner::EHV000); + + /// HV005 + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); + + /// HV002 + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); + + /// HV010 + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); + + /// HV021 + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); + + /// HV024 + pub const 
FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); + + /// HV007 + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); + + /// HV008 + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); + + /// HV004 + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); + + /// HV006 + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); + + /// HV091 + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); + + /// HV00B + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); + + /// HV00C + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); + + /// HV00D + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); + + /// HV090 + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Inner::EHV090); + + /// HV00A + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); + + /// HV009 + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); + + /// HV014 + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); + + /// HV001 + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); + + /// HV00P + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); + + /// HV00J + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); + + /// HV00K + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); + + /// HV00Q + pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Inner::EHV00Q); + + /// HV00R + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); + + /// HV00L + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); + + /// HV00M + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); + + /// HV00N + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); + + /// P0000 + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); + + /// P0001 + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); + + /// P0002 + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); + + /// P0003 + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); + + /// P0004 + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); + + /// XX000 + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); + + /// XX001 + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); + + /// XX002 + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + E0A000, + E0B000, + E0F000, + E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + E22036, + E22037, 
+ E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + E58P01, + E58P02, + E72000, + EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, + Other(Box), +} + +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = +::phf::Map { + key: 12913932095322966823, + disps: &[ + (0, 24), + (0, 12), + (0, 74), + (0, 109), + (0, 11), + (0, 9), + (0, 0), + (4, 38), + (3, 155), + (0, 6), + (1, 242), + (0, 66), + (0, 53), + (5, 180), + (3, 221), + (7, 230), + (0, 125), + (1, 46), + (0, 11), + (1, 2), + (0, 5), + (0, 13), + (0, 171), + (0, 15), + (0, 4), + (0, 22), + (1, 85), + (0, 75), + (2, 0), + (1, 25), + (7, 47), + (0, 45), + (0, 35), + (0, 7), + (7, 124), + (0, 0), + (14, 104), + (1, 183), + (61, 50), + (3, 76), + (0, 12), + (0, 7), + (4, 189), + (0, 1), + (64, 102), + (0, 0), + (16, 192), + (24, 19), + (0, 5), + (0, 87), + (0, 89), + (0, 14), + ], + entries: &[ + ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), + ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), + ("42501", SqlState::INSUFFICIENT_PRIVILEGE), + ("22000", SqlState::DATA_EXCEPTION), + ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), + ("2200N", SqlState::INVALID_XML_CONTENT), + ("40001", SqlState::T_R_SERIALIZATION_FAILURE), + ("28P01", SqlState::INVALID_PASSWORD), + ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), + ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), + ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), + ("42P09", SqlState::AMBIGUOUS_ALIAS), + ("F0000", SqlState::CONFIG_FILE_ERROR), + ("42P18", SqlState::INDETERMINATE_DATATYPE), + ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), + ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), + ("42P08", SqlState::AMBIGUOUS_PARAMETER), + ("08000", SqlState::CONNECTION_EXCEPTION), + ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), + ("22024", SqlState::UNTERMINATED_C_STRING), + ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("25001", SqlState::ACTIVE_SQL_TRANSACTION), + ("03000", 
SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), + ("42710", SqlState::DUPLICATE_OBJECT), + ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22022", SqlState::INDICATOR_OVERFLOW), + ("55006", SqlState::OBJECT_IN_USE), + ("53200", SqlState::OUT_OF_MEMORY), + ("22012", SqlState::DIVISION_BY_ZERO), + ("P0002", SqlState::NO_DATA_FOUND), + ("XX001", SqlState::DATA_CORRUPTED), + ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), + ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), + ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), + ("25000", SqlState::INVALID_TRANSACTION_STATE), + ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), + ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), + ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), + ("42804", SqlState::DATATYPE_MISMATCH), + ("42803", SqlState::GROUPING_ERROR), + ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), + ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), + ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), + ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), + ("22P01", SqlState::FLOATING_POINT_EXCEPTION), + ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), + ("42723", SqlState::DUPLICATE_FUNCTION), + ("21000", SqlState::CARDINALITY_VIOLATION), + ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), + ("23505", SqlState::UNIQUE_VIOLATION), + ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), + ("23P01", SqlState::EXCLUSION_VIOLATION), + ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), + ("42P10", SqlState::INVALID_COLUMN_REFERENCE), + ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), + ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), + ("P0000", SqlState::PLPGSQL_ERROR), + ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), + ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), + ("0A000", SqlState::FEATURE_NOT_SUPPORTED), + ("24000", SqlState::INVALID_CURSOR_STATE), + ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), + ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), + ("42712", SqlState::DUPLICATE_ALIAS), + ("HV014", SqlState::FDW_TOO_MANY_HANDLES), + ("58030", SqlState::IO_ERROR), + ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), + ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), + ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), + ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), + ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), + ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), + ("20000", SqlState::CASE_NOT_FOUND), + ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), + ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), + ("22007", SqlState::INVALID_DATETIME_FORMAT), + ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), + ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), + ("P0004", SqlState::ASSERT_FAILURE), + ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), + ("0L000", SqlState::INVALID_GRANTOR), + ("22P04", SqlState::BAD_COPY_FILE_FORMAT), + ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), + ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), + ("0LP01", SqlState::INVALID_GRANT_OPERATION), + ("58P02", SqlState::DUPLICATE_FILE), + ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("54001", SqlState::STATEMENT_TOO_COMPLEX), + ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), + ("HV00C", 
SqlState::FDW_INVALID_OPTION_INDEX), + ("22008", SqlState::DATETIME_FIELD_OVERFLOW), + ("42P06", SqlState::DUPLICATE_SCHEMA), + ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), + ("42P20", SqlState::WINDOWING_ERROR), + ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), + ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), + ("42702", SqlState::AMBIGUOUS_COLUMN), + ("02000", SqlState::NO_DATA), + ("54011", SqlState::TOO_MANY_COLUMNS), + ("HV004", SqlState::FDW_INVALID_DATA_TYPE), + ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), + ("42701", SqlState::DUPLICATE_COLUMN), + ("08P01", SqlState::PROTOCOL_VIOLATION), + ("42622", SqlState::NAME_TOO_LONG), + ("P0003", SqlState::TOO_MANY_ROWS), + ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), + ("42P03", SqlState::DUPLICATE_CURSOR), + ("23001", SqlState::RESTRICT_VIOLATION), + ("57000", SqlState::OPERATOR_INTERVENTION), + ("22027", SqlState::TRIM_ERROR), + ("42P12", SqlState::INVALID_DATABASE_DEFINITION), + ("3B000", SqlState::SAVEPOINT_EXCEPTION), + ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), + ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), + ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("428C9", SqlState::GENERATED_ALWAYS), + ("2200S", SqlState::INVALID_XML_COMMENT), + ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), + ("42809", SqlState::WRONG_OBJECT_TYPE), + ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), + ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), + ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), + ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), + ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("53100", SqlState::DISK_FULL), + ("42601", SqlState::SYNTAX_ERROR), + ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), + ("HV00B", SqlState::FDW_INVALID_HANDLE), + ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), + ("01000", SqlState::WARNING), + ("42883", SqlState::UNDEFINED_FUNCTION), + ("57P01", SqlState::ADMIN_SHUTDOWN), + ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), + ("00000", SqlState::SUCCESSFUL_COMPLETION), + ("55P03", SqlState::LOCK_NOT_AVAILABLE), + ("42P01", SqlState::UNDEFINED_TABLE), + ("42830", SqlState::INVALID_FOREIGN_KEY), + ("22005", SqlState::ERROR_IN_ASSIGNMENT), + ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), + ("XX002", SqlState::INDEX_CORRUPTED), + ("42P16", SqlState::INVALID_TABLE_DEFINITION), + ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), + ("22019", SqlState::INVALID_ESCAPE_CHARACTER), + ("P0001", SqlState::RAISE_EXCEPTION), + ("72000", SqlState::SNAPSHOT_TOO_OLD), + ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + ("40P01", SqlState::T_R_DEADLOCK_DETECTED), + ("57P02", SqlState::CRASH_SHUTDOWN), + ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), + ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("23503", SqlState::FOREIGN_KEY_VIOLATION), + ("40000", SqlState::TRANSACTION_ROLLBACK), + ("22032", SqlState::INVALID_JSON_TEXT), + ("2202E", SqlState::ARRAY_ELEMENT_ERROR), + ("42P19", SqlState::INVALID_RECURSION), + ("42611", SqlState::INVALID_COLUMN_DEFINITION), + ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), + ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), + ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), + ("XX000", SqlState::INTERNAL_ERROR), + ("08006", SqlState::CONNECTION_FAILURE), + ("57P04", SqlState::DATABASE_DROPPED), + ("42P07", SqlState::DUPLICATE_TABLE), + ("22P03", 
SqlState::INVALID_BINARY_REPRESENTATION), + ("22035", SqlState::NO_SQL_JSON_ITEM), + ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), + ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), + ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("42P21", SqlState::COLLATION_MISMATCH), + ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), + ("HV001", SqlState::FDW_OUT_OF_MEMORY), + ("0F000", SqlState::LOCATOR_EXCEPTION), + ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), + ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), + ("22011", SqlState::SUBSTRING_ERROR), + ("42602", SqlState::INVALID_NAME), + ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), + ("42P02", SqlState::UNDEFINED_PARAMETER), + ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), + ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), + ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("58P01", SqlState::UNDEFINED_FILE), + ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), + ("42703", SqlState::UNDEFINED_COLUMN), + ("57P05", SqlState::IDLE_SESSION_TIMEOUT), + ("57P03", SqlState::CANNOT_CONNECT_NOW), + ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), + ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), + ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), + ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), + ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), + ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), + ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), + ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), + ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("F0001", SqlState::LOCK_FILE_EXISTS), + ("42P22", SqlState::INDETERMINATE_COLLATION), + ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), + ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), + ("23514", SqlState::CHECK_VIOLATION), + ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), + ("54023", SqlState::TOO_MANY_ARGUMENTS), + ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), + ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), + ("3B001", SqlState::S_E_INVALID_SPECIFICATION), + ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), + ("3F000", SqlState::INVALID_SCHEMA_NAME), + ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), + ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), + ("42P17", SqlState::INVALID_OBJECT_DEFINITION), + ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), + ("HV000", SqlState::FDW_ERROR), + ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), + ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), + ("34000", SqlState::INVALID_CURSOR_NAME), + ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), + ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), + ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), + ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), + ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), + ("3D000", SqlState::INVALID_CATALOG_NAME), + ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), + ("2200L", SqlState::NOT_AN_XML_DOCUMENT), + ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), + ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), + ("42939", SqlState::RESERVED_NAME), + ("58000", SqlState::SYSTEM_ERROR), + ("2200M", SqlState::INVALID_XML_DOCUMENT), + ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), + ("57014", SqlState::QUERY_CANCELED), + ("23502", SqlState::NOT_NULL_VIOLATION), + ("22002", 
SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), + ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), + ("HV00P", SqlState::FDW_NO_SCHEMAS), + ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), + ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), + ("HV00K", SqlState::FDW_REPLY_HANDLE), + ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), + ("2200D", SqlState::INVALID_ESCAPE_OCTET), + ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), + ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("42725", SqlState::AMBIGUOUS_FUNCTION), + ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), + ("42846", SqlState::CANNOT_COERCE), + ("42P04", SqlState::DUPLICATE_DATABASE), + ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), + ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), + ("42P05", SqlState::DUPLICATE_PSTATEMENT), + ("53300", SqlState::TOO_MANY_CONNECTIONS), + ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), + ("42704", SqlState::UNDEFINED_OBJECT), + ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), + ("22023", SqlState::INVALID_PARAMETER_VALUE), + ("53000", SqlState::INSUFFICIENT_RESOURCES), + ], +}; diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs new file mode 100644 index 0000000000..768213f8ed --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -0,0 +1,64 @@ +use crate::query::RowStream; +use crate::types::Type; +use crate::{Client, Error, Transaction}; +use async_trait::async_trait; +use postgres_protocol2::Oid; + +mod private { + pub trait Sealed {} +} + +/// A trait allowing abstraction over connections and transactions. +/// +/// This trait is "sealed", and cannot be implemented outside of this crate. +#[async_trait] +pub trait GenericClient: private::Sealed { + /// Like `Client::query_raw_txt`. + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result; +} + +impl private::Sealed for Client {} + +#[async_trait] +impl GenericClient for Client { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.get_type(oid).await + } +} + +impl private::Sealed for Transaction<'_> {} + +#[async_trait] +#[allow(clippy::needless_lifetimes)] +impl GenericClient for Transaction<'_> { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.client().get_type(oid).await + } +} diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs new file mode 100644 index 0000000000..72ba8172b2 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -0,0 +1,148 @@ +//! An asynchronous, pipelined, PostgreSQL client. 
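For illustration, a minimal sketch of how code generic over the sealed GenericClient trait above might drive the text-mode query path against either a Client or a Transaction. This is not part of the patch: the helper name first_value is made up, the generic bounds of query_raw_txt are assumed from the trait definition (several angle-bracketed parameters were lost in this rendering of the diff), and the pin_mut/try_next pattern mirrors the one used later in prepare.rs.

use futures_util::{pin_mut, TryStreamExt};

// Sketch only: fetch the first column of the first row back as text.
async fn first_value<C: GenericClient>(client: &C, sql: &str) -> Result<Option<String>, Error> {
    // No parameters here; query_raw_txt ships parameters as text and yields
    // a RowStream whose rows are in text format.
    let rows = client
        .query_raw_txt(sql, std::iter::empty::<Option<String>>())
        .await?;
    pin_mut!(rows); // RowStream is !Unpin, so pin it before polling
    match rows.try_next().await? {
        // as_text() is valid because the stream was produced in text format.
        Some(row) => Ok(row.as_text(0)?.map(str::to_owned)),
        None => Ok(None),
    }
}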
+#![warn(rust_2018_idioms, clippy::all, missing_docs)] + +pub use crate::cancel_token::CancelToken; +pub use crate::client::Client; +pub use crate::config::Config; +pub use crate::connection::Connection; +use crate::error::DbError; +pub use crate::error::Error; +pub use crate::generic_client::GenericClient; +pub use crate::query::RowStream; +pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::simple_query::SimpleQueryStream; +pub use crate::statement::{Column, Statement}; +use crate::tls::MakeTlsConnect; +pub use crate::tls::NoTls; +pub use crate::to_statement::ToStatement; +pub use crate::transaction::Transaction; +pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; +use crate::types::ToSql; +use postgres_protocol2::message::backend::ReadyForQueryBody; +use tokio::net::TcpStream; + +/// After executing a query, the connection will be in one of these states +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum ReadyForQueryStatus { + /// Connection state is unknown + Unknown, + /// Connection is idle (no transactions) + Idle = b'I', + /// Connection is in a transaction block + Transaction = b'T', + /// Connection is in a failed transaction block + FailedTransaction = b'E', +} + +impl From for ReadyForQueryStatus { + fn from(value: ReadyForQueryBody) -> Self { + match value.status() { + b'I' => Self::Idle, + b'T' => Self::Transaction, + b'E' => Self::FailedTransaction, + _ => Self::Unknown, + } + } +} + +mod cancel_query; +mod cancel_query_raw; +mod cancel_token; +mod client; +mod codec; +pub mod config; +mod connect; +mod connect_raw; +mod connect_socket; +mod connect_tls; +mod connection; +pub mod error; +mod generic_client; +pub mod maybe_tls_stream; +mod prepare; +mod query; +pub mod row; +mod simple_query; +mod statement; +pub mod tls; +mod to_statement; +mod transaction; +mod transaction_builder; +pub mod types; + +/// A convenience function which parses a connection string and connects to the database. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: config/struct.Config.html +pub async fn connect( + config: &str, + tls: T, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let config = config.parse::()?; + config.connect(tls).await +} + +/// An asynchronous notification. +#[derive(Clone, Debug)] +pub struct Notification { + process_id: i32, + channel: String, + payload: String, +} + +impl Notification { + /// The process ID of the notifying backend process. + pub fn process_id(&self) -> i32 { + self.process_id + } + + /// The name of the channel that the notify has been raised on. + pub fn channel(&self) -> &str { + &self.channel + } + + /// The "payload" string passed from the notifying process. + pub fn payload(&self) -> &str { + &self.payload + } +} + +/// An asynchronous message from the server. +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum AsyncMessage { + /// A notice. + /// + /// Notices use the same format as errors, but aren't "errors" per-se. + Notice(DbError), + /// A notification. + /// + /// Connections can subscribe to notifications with the `LISTEN` command. + Notification(Notification), +} + +/// Message returned by the `SimpleQuery` stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum SimpleQueryMessage { + /// A row of data. + Row(SimpleQueryRow), + /// A statement in the query has completed. 
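For illustration, a hedged sketch of the intended calling pattern for connect(): the returned Connection owns the socket and has to be awaited, typically on its own task, so that the Client can make progress. Not part of the patch; it assumes Connection still implements Future and that NoTls satisfies MakeTlsConnect for TcpStream, as in upstream tokio-postgres, and the connection string is a placeholder.

// Sketch only.
async fn run() -> Result<(), Error> {
    let (client, connection) =
        connect("host=localhost user=postgres dbname=postgres", NoTls).await?;

    // Drive the connection's I/O on a separate task; dropping it would stall the client.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    // `client` is now usable, e.g. via query_raw_txt as sketched earlier.
    let _ = client;
    Ok(())
}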
+ /// + /// The number of rows modified or selected is returned. + CommandComplete(u64), +} + +fn slice_iter<'a>( + s: &'a [&'a (dyn ToSql + Sync)], +) -> impl ExactSizeIterator + 'a { + s.iter().map(|s| *s as _) +} diff --git a/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs new file mode 100644 index 0000000000..9a7e248997 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs @@ -0,0 +1,77 @@ +//! MaybeTlsStream. +//! +//! Represents a stream that may or may not be encrypted with TLS. +use crate::tls::{ChannelBinding, TlsStream}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream that may or may not be encrypted with TLS. +pub enum MaybeTlsStream { + /// An unencrypted stream. + Raw(S), + /// An encrypted stream. + Tls(T), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + Unpin, + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncWrite + Unpin, + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +impl TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self { + MaybeTlsStream::Raw(_) => ChannelBinding::none(), + MaybeTlsStream::Tls(s) => s.channel_binding(), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs new file mode 100644 index 0000000000..da0c755c5b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -0,0 +1,262 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::error::SqlState; +use crate::types::{Field, Kind, Oid, Type}; +use crate::{query, slice_iter}; +use crate::{Column, Error, Statement}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{pin_mut, TryStreamExt}; +use log::debug; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +pub(crate) const TYPEINFO_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +// Range types weren't added until Postgres 9.2, so pg_range may not exist +const 
TYPEINFO_FALLBACK_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +const TYPEINFO_ENUM_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY enumsortorder +"; + +// Postgres 9.0 didn't have enumsortorder +const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY oid +"; + +pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\ +SELECT attname, atttypid +FROM pg_catalog.pg_attribute +WHERE attrelid = $1 +AND NOT attisdropped +AND attnum > 0 +ORDER BY attnum +"; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn prepare( + client: &Arc, + query: &str, + types: &[Type], +) -> Result { + let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let buf = encode(client, &name, query, types)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, oid).await?; + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(Statement::new(client, name, parameters, columns)) +} + +fn prepare_rec<'a>( + client: &'a Arc, + query: &'a str, + types: &'a [Type], +) -> Pin> + 'a + Send>> { + Box::pin(prepare(client, query, types)) +} + +fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { + if types.is_empty() { + debug!("preparing query {}: {}", name, query); + } else { + debug!("preparing query {} with types {:?}: {}", name, types, query); + } + + client.with_buf(|buf| { + frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub async fn get_type(client: &Arc, oid: Oid) -> Result { + if let Some(type_) = Type::from_oid(oid) { + return Ok(type_); + } + + if let Some(type_) = client.type_(oid) { + return Ok(type_); + } + + let stmt = typeinfo_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; + pin_mut!(rows); + + let row = match rows.try_next().await? 
{ + Some(row) => row, + None => return Err(Error::unexpected_message()), + }; + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + let variants = get_enum_variants(client, oid).await?; + Kind::Enum(variants) + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + let type_ = get_type_rec(client, basetype).await?; + Kind::Domain(type_) + } else if elem_oid != 0 { + let type_ = get_type_rec(client, elem_oid).await?; + Kind::Array(type_) + } else if relid != 0 { + let fields = get_composite_fields(client, relid).await?; + Kind::Composite(fields) + } else if let Some(rngsubtype) = rngsubtype { + let type_ = get_type_rec(client, rngsubtype).await?; + Kind::Range(type_) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + client.set_type(oid, &type_); + + Ok(type_) +} + +fn get_type_rec<'a>( + client: &'a Arc, + oid: Oid, +) -> Pin> + Send + 'a>> { + Box::pin(get_type(client, oid)) +} + +async fn typeinfo_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { + prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo(&stmt); + Ok(stmt) +} + +async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_enum_statement(client).await?; + + query::query(client, stmt, slice_iter(&[&oid])) + .await? + .and_then(|row| async move { row.try_get(0) }) + .try_collect() + .await +} + +async fn typeinfo_enum_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_enum() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { + prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo_enum(&stmt); + Ok(stmt) +} + +async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_composite_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])) + .await? 
+ .try_collect::>() + .await?; + + let mut fields = vec![]; + for row in rows { + let name = row.try_get(0)?; + let oid = row.try_get(1)?; + let type_ = get_type_rec(client, oid).await?; + fields.push(Field::new(name, type_)); + } + + Ok(fields) +} + +async fn typeinfo_composite_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_composite() { + return Ok(stmt); + } + + let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; + + client.set_typeinfo_composite(&stmt); + Ok(stmt) +} diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs new file mode 100644 index 0000000000..534195a707 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -0,0 +1,340 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::IsNull; +use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::{debug, log_enabled, Level}; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use postgres_types2::{Format, ToSql, Type}; +use std::fmt; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); + +impl fmt::Debug for BorrowToSqlParamsDebug<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + +pub async fn query<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? 
+ }; + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Binary, + _p: PhantomPinned, + }) +} + +pub async fn query_txt( + client: &Arc, + query: &str, + params: I, +) -> Result +where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + frontend::parse( + "", // unnamed prepared statement + query, // query to parse + std::iter::empty(), // give no type info + buf, + ) + .map_err(Error::encode)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol2::IsNull::No) + } + None => Ok(postgres_protocol2::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + // now read the responses + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(RowStream { + statement: Statement::new_anonymous(parameters, columns), + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Text, + _p: PhantomPinned, + }) +} + +pub async fn execute<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let mut responses = start(client, buf).await?; + + let mut rows = 0; + loop { + match responses.next().await? 
{ + Message::DataRow(_) => {} + Message::CommandComplete(body) => { + rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + } + Message::EmptyQueryResponse => rows = 0, + Message::ReadyForQuery(_) => return Ok(rows), + _ => return Err(Error::unexpected_message()), + } + } +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + client.with_buf(|buf| { + encode_bind(statement, params, "", buf)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub fn encode_bind<'a, I>( + statement: &Statement, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let param_types = statement.params(); + let params = params.into_iter(); + + assert!( + param_types.len() == params.len(), + "expected {} parameters but got {}", + param_types.len(), + params.len() + ); + + let (param_formats, params): (Vec<_>, Vec<_>) = params + .zip(param_types.iter()) + .map(|(p, ty)| (p.encode_format(ty) as i16, p)) + .unzip(); + + let params = params.into_iter(); + + let mut error_idx = 0; + let r = frontend::bind( + portal, + statement.name(), + param_formats, + params.zip(param_types).enumerate(), + |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { + Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), + Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), + Err(e) => { + error_idx = idx; + Err(e) + } + }, + Some(1), + buf, + ); + match r { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + } +} + +pin_project! { + /// A stream of table rows. + pub struct RowStream { + statement: Statement, + responses: Responses, + command_tag: Option, + output_format: Format, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for RowStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) + } + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } + Message::ReadyForQuery(status) => { + *this.status = status.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} + +impl RowStream { + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } + + /// Returns if the connection is ready for querying, with the status of the connection. 
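For illustration, a sketch of draining a RowStream and then reading the metadata that arrives behind the rows: CommandComplete (the command tag) and ReadyForQuery (the session status) are only delivered after the last DataRow, so command_tag() and ready_status() carry data only once the stream is exhausted. Not part of the patch; the helper name drain is an assumption.

use futures_util::{pin_mut, TryStreamExt};

// Sketch only: count rows, then read the trailing metadata.
async fn drain(rows: RowStream) -> Result<(u64, Option<String>, ReadyForQueryStatus), Error> {
    pin_mut!(rows);
    let mut n = 0u64;
    while let Some(_row) = rows.try_next().await? {
        n += 1;
    }
    // Populated only once ReadyForQuery has been received, i.e. after exhaustion.
    Ok((n, rows.command_tag(), rows.ready_status()))
}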
+ /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} diff --git a/libs/proxy/tokio-postgres2/src/row.rs b/libs/proxy/tokio-postgres2/src/row.rs new file mode 100644 index 0000000000..10e130707d --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/row.rs @@ -0,0 +1,300 @@ +//! Rows. + +use crate::row::sealed::{AsName, Sealed}; +use crate::simple_query::SimpleColumn; +use crate::statement::Column; +use crate::types::{FromSql, Type, WrongType}; +use crate::{Error, Statement}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend::DataRowBody; +use postgres_types2::{Format, WrongFormat}; +use std::fmt; +use std::ops::Range; +use std::str; +use std::sync::Arc; + +mod sealed { + pub trait Sealed {} + + pub trait AsName { + fn as_name(&self) -> &str; + } +} + +impl AsName for Column { + fn as_name(&self) -> &str { + self.name() + } +} + +impl AsName for String { + fn as_name(&self) -> &str { + self + } +} + +/// A trait implemented by types that can index into columns of a row. +/// +/// This cannot be implemented outside of this crate. +pub trait RowIndex: Sealed { + #[doc(hidden)] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName; +} + +impl Sealed for usize {} + +impl RowIndex for usize { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if *self >= columns.len() { + None + } else { + Some(*self) + } + } +} + +impl Sealed for str {} + +impl RowIndex for str { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if let Some(idx) = columns.iter().position(|d| d.as_name() == self) { + return Some(idx); + }; + + // FIXME ASCII-only case insensitivity isn't really the right thing to + // do. Postgres itself uses a dubious wrapper around tolower and JDBC + // uses the US locale. + columns + .iter() + .position(|d| d.as_name().eq_ignore_ascii_case(self)) + } +} + +impl Sealed for &T where T: ?Sized + Sealed {} + +impl RowIndex for &T +where + T: ?Sized + RowIndex, +{ + #[inline] + fn __idx(&self, columns: &[U]) -> Option + where + U: AsName, + { + T::__idx(*self, columns) + } +} + +/// A row of data returned from the database by a query. +pub struct Row { + statement: Statement, + output_format: Format, + body: DataRowBody, + ranges: Vec>>, +} + +impl fmt::Debug for Row { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Row") + .field("columns", &self.columns()) + .finish() + } +} + +impl Row { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(Row { + statement, + body, + ranges, + output_format, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns().len() + } + + /// Deserializes a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. 
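For illustration, how Row values are read back: try_get returns a Result while get panics, and a column can be addressed either by position or by name (exact match first, then an ASCII case-insensitive fallback, per RowIndex above). Not part of the patch: the column names and positions are placeholders, and the i32/bool FromSql impls are assumed to come from postgres_types2.

// Sketch only.
fn inspect(row: &Row) -> Result<(), Error> {
    // Fallible lookup by column name:
    let id: i32 = row.try_get("id")?;
    // Lookup by position; get() panics on a bad index or a type mismatch, so it
    // is only appropriate when the row shape is guaranteed:
    let active: bool = row.get(1);
    let _ = (id, active);
    Ok(())
}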
+ pub fn get<'a, I, T>(&'a self, idx: I) -> T + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `Row::get`, but returns a `Result` rather than panicking. + pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + self.get_inner(&idx) + } + + fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + let idx = match idx.__idx(self.columns()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let ty = self.columns()[idx].type_(); + if !T::accepts(ty) { + return Err(Error::from_sql( + Box::new(WrongType::new::(ty.clone())), + idx, + )); + } + + FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx)) + } + + /// Get the raw bytes for the column at the given index. + fn col_buffer(&self, idx: usize) -> Option<&[u8]> { + let range = self.ranges.get(idx)?.to_owned()?; + Some(&self.body.buffer()[range]) + } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.output_format == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } +} + +impl AsName for SimpleColumn { + fn as_name(&self) -> &str { + self.name() + } +} + +/// A row of data returned from the database by a simple query. +#[derive(Debug)] +pub struct SimpleQueryRow { + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ranges: Vec>>, +} + +impl SimpleQueryRow { + #[allow(clippy::new_ret_no_self)] + pub(crate) fn new( + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(SimpleQueryRow { + columns, + body, + ranges, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[SimpleColumn] { + &self.columns + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get(&self, idx: I) -> Option<&str> + where + I: RowIndex + fmt::Display, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking. 
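Both row types expose the same accessor pattern: a fallible `try_get` plus a `get` wrapper that panics with the offending column in the message. A standalone sketch of that wrapper pattern (toy types, not the crate's):

fn try_get(values: &[i32], idx: usize) -> Result<i32, String> {
    values
        .get(idx)
        .copied()
        .ok_or_else(|| format!("no column at index {idx}"))
}

fn get(values: &[i32], idx: usize) -> i32 {
    // Thin panicking wrapper over the fallible accessor, like Row::get.
    match try_get(values, idx) {
        Ok(v) => v,
        Err(err) => panic!("error retrieving column {idx}: {err}"),
    }
}

fn main() {
    let row = [10, 20, 30];
    assert_eq!(get(&row, 1), 20);
    assert!(try_get(&row, 9).is_err());
}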
+ pub fn try_get(&self, idx: I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + self.get_inner(&idx) + } + + fn get_inner(&self, idx: &I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + let idx = match idx.__idx(&self.columns) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]); + FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs new file mode 100644 index 0000000000..fb2550377b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -0,0 +1,142 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +/// Information about a column of a single query row. +#[derive(Debug)] +pub struct SimpleColumn { + name: String, +} + +impl SimpleColumn { + pub(crate) fn new(name: String) -> SimpleColumn { + SimpleColumn { name } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } +} + +pub async fn simple_query(client: &InnerClient, query: &str) -> Result { + debug!("executing simple query: {}", query); + + let buf = encode(client, query)?; + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + Ok(SimpleQueryStream { + responses, + columns: None, + status: ReadyForQueryStatus::Unknown, + _p: PhantomPinned, + }) +} + +pub async fn batch_execute( + client: &InnerClient, + query: &str, +) -> Result { + debug!("executing statement batch: {}", query); + + let buf = encode(client, query)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ReadyForQuery(status) => return Ok(status.into()), + Message::CommandComplete(_) + | Message::EmptyQueryResponse + | Message::RowDescription(_) + | Message::DataRow(_) => {} + _ => return Err(Error::unexpected_message()), + } + } +} + +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { + client.with_buf(|buf| { + frontend::query(query, buf).map_err(Error::encode)?; + Ok(buf.split().freeze()) + }) +} + +pin_project! { + /// A stream of simple query results. + pub struct SimpleQueryStream { + responses: Responses, + columns: Option>, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl SimpleQueryStream { + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} + +impl Stream for SimpleQueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? 
+ .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))); + } + Message::EmptyQueryResponse => { + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); + } + Message::RowDescription(body) => { + let columns = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); + + *this.columns = Some(columns); + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); + } + Message::ReadyForQuery(s) => { + *this.status = s.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs new file mode 100644 index 0000000000..22e160fc05 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -0,0 +1,157 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::Type; +use postgres_protocol2::{ + message::{backend::Field, frontend}, + Oid, +}; +use std::{ + fmt, + sync::{Arc, Weak}, +}; + +struct StatementInner { + client: Weak, + name: String, + params: Vec, + columns: Vec, +} + +impl Drop for StatementInner { + fn drop(&mut self) { + if let Some(client) = self.client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'S', &self.name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } +} + +/// A prepared statement. +/// +/// Prepared statements can only be used with the connection that created them. +#[derive(Clone)] +pub struct Statement(Arc); + +impl Statement { + pub(crate) fn new( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + })) + } + + pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + + pub(crate) fn name(&self) -> &str { + &self.0.name + } + + /// Returns the expected types of the statement's parameters. + pub fn params(&self) -> &[Type] { + &self.0.params + } + + /// Returns information about the columns returned when the statement is queried. + pub fn columns(&self) -> &[Column] { + &self.0.columns + } +} + +/// Information about a column of a query. +pub struct Column { + name: String, + type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, +} + +impl Column { + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the column. 
+ pub fn type_(&self) -> &Type { + &self.type_ + } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl fmt::Debug for Column { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Column") + .field("name", &self.name) + .field("type", &self.type_) + .finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs new file mode 100644 index 0000000000..dc8140719f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -0,0 +1,162 @@ +//! TLS support. + +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) mod private { + pub struct ForcePrivateApi; +} + +/// Channel binding information returned from a TLS handshake. +pub struct ChannelBinding { + pub(crate) tls_server_end_point: Option>, +} + +impl ChannelBinding { + /// Creates a `ChannelBinding` containing no information. + pub fn none() -> ChannelBinding { + ChannelBinding { + tls_server_end_point: None, + } + } + + /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information. + pub fn tls_server_end_point(tls_server_end_point: Vec) -> ChannelBinding { + ChannelBinding { + tls_server_end_point: Some(tls_server_end_point), + } + } +} + +/// A constructor of `TlsConnect`ors. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +pub trait MakeTlsConnect { + /// The stream type created by the `TlsConnect` implementation. + type Stream: TlsStream + Unpin; + /// The `TlsConnect` implementation created by this type. + type TlsConnect: TlsConnect; + /// The error type returned by the `TlsConnect` implementation. + type Error: Into>; + + /// Creates a new `TlsConnect`or. + /// + /// The domain name is provided for certificate verification and SNI. + fn make_tls_connect(&mut self, domain: &str) -> Result; +} + +/// An asynchronous function wrapping a stream in a TLS session. +pub trait TlsConnect { + /// The stream returned by the future. + type Stream: TlsStream + Unpin; + /// The error returned by the future. + type Error: Into>; + /// The future returned by the connector. + type Future: Future>; + + /// Returns a future performing a TLS handshake over the stream. + fn connect(self, stream: S) -> Self::Future; + + #[doc(hidden)] + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + true + } +} + +/// A TLS-wrapped connection to a PostgreSQL database. +pub trait TlsStream: AsyncRead + AsyncWrite { + /// Returns channel binding information for the session. + fn channel_binding(&self) -> ChannelBinding; +} + +/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error. +/// +/// This can be used when `sslmode` is `none` or `prefer`. 
+#[derive(Debug, Copy, Clone)] +pub struct NoTls; + +impl MakeTlsConnect for NoTls { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; + + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) + } +} + +impl TlsConnect for NoTls { + type Stream = NoTlsStream; + type Error = NoTlsError; + type Future = NoTlsFuture; + + fn connect(self, _: S) -> NoTlsFuture { + NoTlsFuture(()) + } + + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + false + } +} + +/// The future returned by `NoTls`. +pub struct NoTlsFuture(()); + +impl Future for NoTlsFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Err(NoTlsError(()))) + } +} + +/// The TLS "stream" type produced by the `NoTls` connector. +/// +/// Since `NoTls` doesn't support TLS, this type is uninhabited. +pub enum NoTlsStream {} + +impl AsyncRead for NoTlsStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + match *self {} + } +} + +impl AsyncWrite for NoTlsStream { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + match *self {} + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } +} + +impl TlsStream for NoTlsStream { + fn channel_binding(&self) -> ChannelBinding { + match *self {} + } +} + +/// The error returned by `NoTls`. +#[derive(Debug)] +pub struct NoTlsError(()); + +impl fmt::Display for NoTlsError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("no TLS implementation configured") + } +} + +impl Error for NoTlsError {} diff --git a/libs/proxy/tokio-postgres2/src/to_statement.rs b/libs/proxy/tokio-postgres2/src/to_statement.rs new file mode 100644 index 0000000000..427f77dd79 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/to_statement.rs @@ -0,0 +1,57 @@ +use crate::to_statement::private::{Sealed, ToStatementType}; +use crate::Statement; + +mod private { + use crate::{Client, Error, Statement}; + + pub trait Sealed {} + + pub enum ToStatementType<'a> { + Statement(&'a Statement), + Query(&'a str), + } + + impl<'a> ToStatementType<'a> { + pub async fn into_statement(self, client: &Client) -> Result { + match self { + ToStatementType::Statement(s) => Ok(s.clone()), + ToStatementType::Query(s) => client.prepare(s).await, + } + } + } +} + +/// A trait abstracting over prepared and unprepared statements. +/// +/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which +/// was prepared previously. +/// +/// This trait is "sealed" and cannot be implemented by anything outside this crate. 
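`NoTlsStream` above is an uninhabited enum, so its trait methods are written as `match *self {}` and can never actually execute. A minimal standalone sketch of that pattern:

// An uninhabited type: it has no values, so code paths that "handle"
// it can never run. The same trick backs NoTlsStream in the hunk above.
enum Never {}

trait Describe {
    fn describe(&self) -> String;
}

impl Describe for Never {
    fn describe(&self) -> String {
        // An empty match is exhaustive for an uninhabited enum, so no
        // return value is needed and this still type-checks.
        match *self {}
    }
}

fn assert_describe<T: Describe>() {}

fn main() {
    // The type can still be named (e.g. to satisfy an associated type)
    // even though no value of it can ever be constructed.
    let maybe: Option<Never> = None;
    assert!(maybe.is_none());
    assert_describe::<Never>();
}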
+pub trait ToStatement: Sealed { + #[doc(hidden)] + fn __convert(&self) -> ToStatementType<'_>; +} + +impl ToStatement for Statement { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Statement(self) + } +} + +impl Sealed for Statement {} + +impl ToStatement for str { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for str {} + +impl ToStatement for String { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for String {} diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs new file mode 100644 index 0000000000..03a57e4947 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -0,0 +1,74 @@ +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::query::RowStream; +use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; +use postgres_protocol2::message::frontend; + +/// A representation of a PostgreSQL database transaction. +/// +/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the +/// transaction. Transactions can be nested, with inner transactions implemented via safepoints. +pub struct Transaction<'a> { + client: &'a mut Client, + done: bool, +} + +impl Drop for Transaction<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } +} + +impl<'a> Transaction<'a> { + pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { + Transaction { + client, + done: false, + } + } + + /// Consumes the transaction, committing all changes made within it. + pub async fn commit(mut self) -> Result { + self.done = true; + self.client.batch_execute("COMMIT").await + } + + /// Rolls the transaction back, discarding all changes made within it. + /// + /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. + pub async fn rollback(mut self) -> Result { + self.done = true; + self.client.batch_execute("ROLLBACK").await + } + + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw_txt(statement, params).await + } + + /// Like `Client::cancel_token`. + pub fn cancel_token(&self) -> CancelToken { + self.client.cancel_token() + } + + /// Returns a reference to the underlying `Client`. + pub fn client(&self) -> &Client { + self.client + } +} diff --git a/libs/proxy/tokio-postgres2/src/transaction_builder.rs b/libs/proxy/tokio-postgres2/src/transaction_builder.rs new file mode 100644 index 0000000000..9718ac588c --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction_builder.rs @@ -0,0 +1,113 @@ +use crate::{Client, Error, Transaction}; + +/// The isolation level of a database transaction. +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum IsolationLevel { + /// Equivalent to `ReadCommitted`. + ReadUncommitted, + + /// An individual statement in the transaction will see rows committed before it began. + ReadCommitted, + + /// All statements in the transaction will see the same view of rows committed before the first query in the + /// transaction. 
+ RepeatableRead, + + /// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads + /// and writes of all other concurrent serializable transactions without interleaving. + Serializable, +} + +/// A builder for database transactions. +pub struct TransactionBuilder<'a> { + client: &'a mut Client, + isolation_level: Option, + read_only: Option, + deferrable: Option, +} + +impl<'a> TransactionBuilder<'a> { + pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> { + TransactionBuilder { + client, + isolation_level: None, + read_only: None, + deferrable: None, + } + } + + /// Sets the isolation level of the transaction. + pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self { + self.isolation_level = Some(isolation_level); + self + } + + /// Sets the access mode of the transaction. + pub fn read_only(mut self, read_only: bool) -> Self { + self.read_only = Some(read_only); + self + } + + /// Sets the deferrability of the transaction. + /// + /// If the transaction is also serializable and read only, creation of the transaction may block, but when it + /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to + /// serialization failure. + pub fn deferrable(mut self, deferrable: bool) -> Self { + self.deferrable = Some(deferrable); + self + } + + /// Begins the transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn start(self) -> Result, Error> { + let mut query = "START TRANSACTION".to_string(); + let mut first = true; + + if let Some(level) = self.isolation_level { + first = false; + + query.push_str(" ISOLATION LEVEL "); + let level = match level { + IsolationLevel::ReadUncommitted => "READ UNCOMMITTED", + IsolationLevel::ReadCommitted => "READ COMMITTED", + IsolationLevel::RepeatableRead => "REPEATABLE READ", + IsolationLevel::Serializable => "SERIALIZABLE", + }; + query.push_str(level); + } + + if let Some(read_only) = self.read_only { + if !first { + query.push(','); + } + first = false; + + let s = if read_only { + " READ ONLY" + } else { + " READ WRITE" + }; + query.push_str(s); + } + + if let Some(deferrable) = self.deferrable { + if !first { + query.push(','); + } + + let s = if deferrable { + " DEFERRABLE" + } else { + " NOT DEFERRABLE" + }; + query.push_str(s); + } + + self.client.batch_execute(&query).await?; + + Ok(Transaction::new(self.client)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/types.rs b/libs/proxy/tokio-postgres2/src/types.rs new file mode 100644 index 0000000000..e571d7ee00 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/types.rs @@ -0,0 +1,6 @@ +//! Types. +//! +//! This module is a reexport of the `postgres_types` crate. 
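The `start()` method above assembles the `START TRANSACTION` statement clause by clause, inserting a comma only after the first clause. A standalone sketch that mirrors that assembly and shows the resulting SQL (the enum here is a local copy, not the crate's type):

#[allow(dead_code)]
enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Serializable,
}

// Rebuilds the statement the same way the builder's start() does.
fn build_start(
    isolation_level: Option<IsolationLevel>,
    read_only: Option<bool>,
    deferrable: Option<bool>,
) -> String {
    let mut query = "START TRANSACTION".to_string();
    let mut first = true;

    if let Some(level) = isolation_level {
        first = false;
        query.push_str(" ISOLATION LEVEL ");
        query.push_str(match level {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        });
    }

    if let Some(read_only) = read_only {
        if !first {
            query.push(',');
        }
        first = false;
        query.push_str(if read_only { " READ ONLY" } else { " READ WRITE" });
    }

    if let Some(deferrable) = deferrable {
        if !first {
            query.push(',');
        }
        query.push_str(if deferrable { " DEFERRABLE" } else { " NOT DEFERRABLE" });
    }

    query
}

fn main() {
    assert_eq!(build_start(None, None, None), "START TRANSACTION");
    assert_eq!(
        build_start(Some(IsolationLevel::Serializable), Some(true), Some(true)),
        "START TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE"
    );
}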
+ +#[doc(inline)] +pub use postgres_types2::*; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 1665d6361a..0d774d529d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -55,6 +55,7 @@ parquet.workspace = true parquet_derive.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true prometheus.workspace = true rand.workspace = true @@ -80,8 +81,7 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } -tokio-postgres-rustls.workspace = true +tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" } tokio-rustls.workspace = true tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } @@ -96,7 +96,6 @@ utils.workspace = true uuid.workspace = true rustls-native-certs.workspace = true x509-parser.workspace = true -postgres-protocol.workspace = true redis.workspace = true zerocopy.workspace = true @@ -117,6 +116,5 @@ tokio-tungstenite.workspace = true pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true -tokio-postgres-rustls.workspace = true walkdir.workspace = true rand_distr = "0.4" diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 8408d4720b..2abe88ac88 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -13,7 +13,6 @@ use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::tls::MakeTlsConnect; -use tokio_postgres_rustls::MakeRustlsConnect; use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::postgres_rustls::MakeRustlsConnect; use crate::proxy::neon_option; use crate::types::Host; @@ -244,7 +244,6 @@ impl ConnCfg { let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432); let host = match host { Host::Tcp(host) => host.as_str(), - Host::Unix(_) => continue, // unix sockets are not welcome here }; match connect_once(host, *port).await { @@ -315,7 +314,7 @@ impl ConnCfg { }; let client_config = client_config.with_no_client_auth(); - let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); let tls = >::make_tls_connect( &mut mk_tls, host, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 5c19a23e36..4a063a5faa 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -414,6 +414,7 @@ impl RequestContextInner { outcome, }); } + if let Some(tx) = self.sender.take() { // If type changes, this error handling needs to be updated. 
let tx: mpsc::UnboundedSender = tx; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ad7e1d2771..ba69f9cf2d 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -88,6 +88,7 @@ pub mod jemalloc; pub mod logging; pub mod metrics; pub mod parse; +pub mod postgres_rustls; pub mod protocol2; pub mod proxy; pub mod rate_limiter; diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs new file mode 100644 index 0000000000..31e7915e89 --- /dev/null +++ b/proxy/src/postgres_rustls/mod.rs @@ -0,0 +1,158 @@ +use std::convert::TryFrom; +use std::sync::Arc; + +use rustls::pki_types::ServerName; +use rustls::ClientConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::future::Future; + use std::io; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::client::TlsStream; + use tokio_rustls::TlsConnector; + + use crate::config::TlsServerEndPoint; + + pub struct TlsConnectFuture { + inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + } + } + + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, + } + + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; + + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } + + pub struct RustlsStream(TlsStream); + + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some([cert, ..]) => TlsServerEndPoint::new(cert) + .ok() + .and_then(|cb| match cb { + TlsServerEndPoint::Sha256(hash) => Some(hash), + TlsServerEndPoint::Undefined => None, + }) + .map_or_else(ChannelBinding::none, |hash| { + ChannelBinding::tls_server_end_point(hash.to_vec()) + }), + _ => ChannelBinding::none(), + } + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } +} + +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. 
+#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} + +impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { + config: Arc::new(config), + } + } +} + +impl MakeTlsConnect for MakeRustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; + + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + }) + }) + } +} diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3de8ca8736..2c2c2964b6 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -14,7 +14,6 @@ use rustls::pki_types; use tokio::io::DuplexStream; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; -use tokio_postgres_rustls::MakeRustlsConnect; use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; @@ -29,6 +28,7 @@ use crate::control_plane::{ self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; +use crate::postgres_rustls::MakeRustlsConnect; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 3037e20888..75909f3358 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -333,7 +333,7 @@ impl PoolingBackend { debug!("setting up backend session state"); // initiates the auth session - if let Err(e) = client.query("select auth.init()", &[]).await { + if let Err(e) = client.execute("select auth.init()", &[]).await { discard.discard(); return Err(e.into()); } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index bd262f45ed..c302eac568 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -6,9 +6,10 @@ use std::task::{ready, Poll}; use futures::future::poll_fn; use futures::Future; use smallvec::SmallVec; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; #[cfg(test)] @@ -57,7 +58,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 9abe35db08..db9ac49dae 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -24,10 +24,11 @@ use p256::ecdsa::{Signature, SigningKey}; use parking_lot::RwLock; use serde_json::value::RawValue; use signature::Signer; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; use tokio_postgres::types::ToSql; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ 
-163,7 +164,7 @@ pub(crate) fn poll_client(
     ctx: &RequestContext,
     conn_info: ConnInfo,
     client: C,
-    mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
+    mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
     key: SigningKey,
     conn_id: uuid::Uuid,
     aux: MetricsAuxInfo,
@@ -286,11 +287,11 @@ impl ClientInnerCommon {
         let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
 
         // initiates the auth session
-        self.inner.simple_query("discard all").await?;
+        self.inner.batch_execute("discard all").await?;
         self.inner
-            .query(
+            .execute(
                 "select auth.jwt_session_init($1)",
-                &[&token as &(dyn ToSql + Sync)],
+                &[&&*token as &(dyn ToSql + Sync)],
             )
             .await?;
 
diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml
index a73d9d6352..c0a3abc377 100644
--- a/workspace_hack/Cargo.toml
+++ b/workspace_hack/Cargo.toml
@@ -60,7 +60,6 @@ num-integer = { version = "0.1", features = ["i128"] }
 num-traits = { version = "0.2", features = ["i128", "libm"] }
 once_cell = { version = "1" }
 parquet = { version = "53", default-features = false, features = ["zstd"] }
-postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", default-features = false, features = ["with-serde_json-1"] }
 prost = { version = "0.13", features = ["prost-derive"] }
 rand = { version = "0.8", features = ["small_rng"] }
 regex = { version = "1" }
@@ -79,8 +78,7 @@ subtle = { version = "2" }
 sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] }
 tikv-jemalloc-sys = { version = "0.6", features = ["stats"] }
 time = { version = "0.3", features = ["macros", "serde-well-known"] }
-tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
-tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", features = ["with-serde_json-1"] }
+tokio = { version = "1", features = ["full", "test-util"] }
 tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] }
 tokio-stream = { version = "0.1", features = ["net"] }
 tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
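Taken together, the proxy now wires its own `postgres_rustls::MakeRustlsConnect` into the vendored `tokio_postgres2` client, as in the compute.rs hunk earlier in this patch. A hedged sketch of that flow; `connect_raw`, the `Config` string parsing, and the `anyhow` error type are assumptions about the surrounding code, not verified against this patch:

use rustls::ClientConfig;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;

use crate::postgres_rustls::MakeRustlsConnect;

// Hedged sketch modelled on the compute.rs change above; the exact
// Config/connect_raw surface of the vendored crate is assumed here.
async fn open_tls_connection(
    client_config: ClientConfig,
    host: &str,
    port: u16,
) -> anyhow::Result<()> {
    let mut mk_tls = MakeRustlsConnect::new(client_config);
    // The host name drives SNI and certificate verification.
    let tls =
        <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(&mut mk_tls, host)?;

    let stream = TcpStream::connect((host, port)).await?;
    let config: tokio_postgres::Config = "user=postgres dbname=postgres".parse()?;
    let (_client, _connection) = config.connect_raw(stream, tls).await?;
    Ok(())
}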